Return dict of predictions

This commit is contained in:
Tobias Eidelpes 2023-01-19 10:52:31 +01:00
parent c735b01f8b
commit ad1bcc1dab

View File

@ -25,6 +25,8 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
healthy or wilted.
:param str img_path: path to image
:param yolo_path: path to yolo weights
:param resnet_path: path to resnet weights
:returns: dict of bounding boxes and their predictions
"""
@ -43,8 +45,7 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
cropped_image = get_cutout(img.copy(), xmin, xmax, ymin, ymax)
# Classify ROI in RGB
pred = classify(second_stage, cropped_image[..., ::-1])
print(pred)
predictions[idx] = classify(second_stage, cropped_image[..., ::-1])
# cv2.imshow('cropped ' + str(idx), cropped_image)
# cv2.waitKey(0)
@ -65,6 +66,7 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
cv2.imshow('original with bounding box', original)
cv2.waitKey(0)
cv2.destroyAllWindows()
return predictions
def get_boxes(model, img):
@ -124,4 +126,5 @@ if __name__ == '__main__':
if opt.source:
with torch.no_grad():
detect(opt.source)
detect(opt.source, 'runs/train/yolov7-custom7/weights/best.pt',
'resnet.pt')