Accept model weights as parameters

This commit is contained in:
Tobias Eidelpes 2023-01-19 10:51:08 +01:00
parent 920c1fd852
commit c735b01f8b

View File

@ -4,34 +4,36 @@ import torch
from torchvision import transforms from torchvision import transforms
def load_models(): def load_models(yolo_path: str, resnet_path: str):
"""Load the models for two-stage classification. """Load the models for two-stage classification.
:param yolo_path: path to yolo weights
:param resnet_path: path to resnet weights
:returns: tuple of models :returns: tuple of models
""" """
first_stage = torch.hub.load("WongKinYiu/yolov7", first_stage = torch.hub.load("WongKinYiu/yolov7",
"custom", "custom",
"runs/train/yolov7-custom7/weights/best.pt", yolo_path,
trust_repo=True) trust_repo=True)
second_stage = torch.load('resnet.pt') second_stage = torch.load(resnet_path)
return (first_stage, second_stage) return (first_stage, second_stage)
def detect(img_path: str): def detect(img_path: str, yolo_path: str, resnet_path: str):
"""Load an image, detect individual plants and label them as """Load an image, detect individual plants and label them as
healthy or wilted. healthy or wilted.
:param str img_path: path to image :param str img_path: path to image
:returns: tensor of confidence values per class :returns: dict of bounding boxes and their predictions
""" """
img = cv2.imread(img_path) img = cv2.imread(img_path)
original = img.copy() original = img.copy()
(first_stage, second_stage) = load_models() (first_stage, second_stage) = load_models(yolo_path, resnet_path)
box_coords = get_boxes(first_stage, img) box_coords = get_boxes(first_stage, img)
box_coords.sort_values(by=['xmin'], ignore_index=True, inplace=True) box_coords.sort_values(by=['xmin'], ignore_index=True, inplace=True)
print(box_coords) predictions = {}
for idx, row in box_coords.iterrows(): for idx, row in box_coords.iterrows():
xmin, xmax = int(row['xmin']), int(row['xmax']) xmin, xmax = int(row['xmin']), int(row['xmax'])