Accept model weights as parameters
This commit is contained in:
parent
920c1fd852
commit
c735b01f8b
@ -4,34 +4,36 @@ import torch
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
def load_models():
|
def load_models(yolo_path: str, resnet_path: str):
|
||||||
"""Load the models for two-stage classification.
|
"""Load the models for two-stage classification.
|
||||||
|
|
||||||
|
:param yolo_path: path to yolo weights
|
||||||
|
:param resnet_path: path to resnet weights
|
||||||
:returns: tuple of models
|
:returns: tuple of models
|
||||||
|
|
||||||
"""
|
"""
|
||||||
first_stage = torch.hub.load("WongKinYiu/yolov7",
|
first_stage = torch.hub.load("WongKinYiu/yolov7",
|
||||||
"custom",
|
"custom",
|
||||||
"runs/train/yolov7-custom7/weights/best.pt",
|
yolo_path,
|
||||||
trust_repo=True)
|
trust_repo=True)
|
||||||
second_stage = torch.load('resnet.pt')
|
second_stage = torch.load(resnet_path)
|
||||||
return (first_stage, second_stage)
|
return (first_stage, second_stage)
|
||||||
|
|
||||||
|
|
||||||
def detect(img_path: str):
|
def detect(img_path: str, yolo_path: str, resnet_path: str):
|
||||||
"""Load an image, detect individual plants and label them as
|
"""Load an image, detect individual plants and label them as
|
||||||
healthy or wilted.
|
healthy or wilted.
|
||||||
|
|
||||||
:param str img_path: path to image
|
:param str img_path: path to image
|
||||||
:returns: tensor of confidence values per class
|
:returns: dict of bounding boxes and their predictions
|
||||||
|
|
||||||
"""
|
"""
|
||||||
img = cv2.imread(img_path)
|
img = cv2.imread(img_path)
|
||||||
original = img.copy()
|
original = img.copy()
|
||||||
(first_stage, second_stage) = load_models()
|
(first_stage, second_stage) = load_models(yolo_path, resnet_path)
|
||||||
box_coords = get_boxes(first_stage, img)
|
box_coords = get_boxes(first_stage, img)
|
||||||
box_coords.sort_values(by=['xmin'], ignore_index=True, inplace=True)
|
box_coords.sort_values(by=['xmin'], ignore_index=True, inplace=True)
|
||||||
print(box_coords)
|
predictions = {}
|
||||||
|
|
||||||
for idx, row in box_coords.iterrows():
|
for idx, row in box_coords.iterrows():
|
||||||
xmin, xmax = int(row['xmin']), int(row['xmax'])
|
xmin, xmax = int(row['xmin']), int(row['xmax'])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user