Accept model weights as parameters

This commit is contained in:
Tobias Eidelpes 2023-01-19 10:51:08 +01:00
parent 920c1fd852
commit c735b01f8b

View File

@ -4,34 +4,36 @@ import torch
from torchvision import transforms
def load_models():
def load_models(yolo_path: str, resnet_path: str):
"""Load the models for two-stage classification.
:param yolo_path: path to yolo weights
:param resnet_path: path to resnet weights
:returns: tuple of models
"""
first_stage = torch.hub.load("WongKinYiu/yolov7",
"custom",
"runs/train/yolov7-custom7/weights/best.pt",
yolo_path,
trust_repo=True)
second_stage = torch.load('resnet.pt')
second_stage = torch.load(resnet_path)
return (first_stage, second_stage)
def detect(img_path: str):
def detect(img_path: str, yolo_path: str, resnet_path: str):
"""Load an image, detect individual plants and label them as
healthy or wilted.
:param str img_path: path to image
:returns: tensor of confidence values per class
:returns: dict of bounding boxes and their predictions
"""
img = cv2.imread(img_path)
original = img.copy()
(first_stage, second_stage) = load_models()
(first_stage, second_stage) = load_models(yolo_path, resnet_path)
box_coords = get_boxes(first_stage, img)
box_coords.sort_values(by=['xmin'], ignore_index=True, inplace=True)
print(box_coords)
predictions = {}
for idx, row in box_coords.iterrows():
xmin, xmax = int(row['xmin']), int(row['xmax'])