Accept model weights as parameters
This commit is contained in:
parent
920c1fd852
commit
c735b01f8b
@ -4,34 +4,36 @@ import torch
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def load_models():
|
||||
def load_models(yolo_path: str, resnet_path: str):
|
||||
"""Load the models for two-stage classification.
|
||||
|
||||
:param yolo_path: path to yolo weights
|
||||
:param resnet_path: path to resnet weights
|
||||
:returns: tuple of models
|
||||
|
||||
"""
|
||||
first_stage = torch.hub.load("WongKinYiu/yolov7",
|
||||
"custom",
|
||||
"runs/train/yolov7-custom7/weights/best.pt",
|
||||
yolo_path,
|
||||
trust_repo=True)
|
||||
second_stage = torch.load('resnet.pt')
|
||||
second_stage = torch.load(resnet_path)
|
||||
return (first_stage, second_stage)
|
||||
|
||||
|
||||
def detect(img_path: str):
|
||||
def detect(img_path: str, yolo_path: str, resnet_path: str):
|
||||
"""Load an image, detect individual plants and label them as
|
||||
healthy or wilted.
|
||||
|
||||
:param str img_path: path to image
|
||||
:returns: tensor of confidence values per class
|
||||
:returns: dict of bounding boxes and their predictions
|
||||
|
||||
"""
|
||||
img = cv2.imread(img_path)
|
||||
original = img.copy()
|
||||
(first_stage, second_stage) = load_models()
|
||||
(first_stage, second_stage) = load_models(yolo_path, resnet_path)
|
||||
box_coords = get_boxes(first_stage, img)
|
||||
box_coords.sort_values(by=['xmin'], ignore_index=True, inplace=True)
|
||||
print(box_coords)
|
||||
predictions = {}
|
||||
|
||||
for idx, row in box_coords.iterrows():
|
||||
xmin, xmax = int(row['xmin']), int(row['xmax'])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user