Return dict of predictions
This commit is contained in:
parent
c735b01f8b
commit
ad1bcc1dab
@ -25,6 +25,8 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
|
||||
healthy or wilted.
|
||||
|
||||
:param str img_path: path to image
|
||||
:param yolo_path: path to yolo weights
|
||||
:param resnet_path: path to resnet weights
|
||||
:returns: dict of bounding boxes and their predictions
|
||||
|
||||
"""
|
||||
@ -43,8 +45,7 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
|
||||
cropped_image = get_cutout(img.copy(), xmin, xmax, ymin, ymax)
|
||||
|
||||
# Classify ROI in RGB
|
||||
pred = classify(second_stage, cropped_image[..., ::-1])
|
||||
print(pred)
|
||||
predictions[idx] = classify(second_stage, cropped_image[..., ::-1])
|
||||
|
||||
# cv2.imshow('cropped ' + str(idx), cropped_image)
|
||||
# cv2.waitKey(0)
|
||||
@ -65,6 +66,7 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
|
||||
cv2.imshow('original with bounding box', original)
|
||||
cv2.waitKey(0)
|
||||
cv2.destroyAllWindows()
|
||||
return predictions
|
||||
|
||||
|
||||
def get_boxes(model, img):
|
||||
@ -124,4 +126,5 @@ if __name__ == '__main__':
|
||||
|
||||
if opt.source:
|
||||
with torch.no_grad():
|
||||
detect(opt.source)
|
||||
detect(opt.source, 'runs/train/yolov7-custom7/weights/best.pt',
|
||||
'resnet.pt')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user