Return dict of predictions
This commit is contained in:
parent
c735b01f8b
commit
ad1bcc1dab
@ -25,6 +25,8 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
|
|||||||
healthy or wilted.
|
healthy or wilted.
|
||||||
|
|
||||||
:param str img_path: path to image
|
:param str img_path: path to image
|
||||||
|
:param yolo_path: path to yolo weights
|
||||||
|
:param resnet_path: path to resnet weights
|
||||||
:returns: dict of bounding boxes and their predictions
|
:returns: dict of bounding boxes and their predictions
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -43,8 +45,7 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
|
|||||||
cropped_image = get_cutout(img.copy(), xmin, xmax, ymin, ymax)
|
cropped_image = get_cutout(img.copy(), xmin, xmax, ymin, ymax)
|
||||||
|
|
||||||
# Classify ROI in RGB
|
# Classify ROI in RGB
|
||||||
pred = classify(second_stage, cropped_image[..., ::-1])
|
predictions[idx] = classify(second_stage, cropped_image[..., ::-1])
|
||||||
print(pred)
|
|
||||||
|
|
||||||
# cv2.imshow('cropped ' + str(idx), cropped_image)
|
# cv2.imshow('cropped ' + str(idx), cropped_image)
|
||||||
# cv2.waitKey(0)
|
# cv2.waitKey(0)
|
||||||
@ -65,6 +66,7 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
|
|||||||
cv2.imshow('original with bounding box', original)
|
cv2.imshow('original with bounding box', original)
|
||||||
cv2.waitKey(0)
|
cv2.waitKey(0)
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
def get_boxes(model, img):
|
def get_boxes(model, img):
|
||||||
@ -124,4 +126,5 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
if opt.source:
|
if opt.source:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
detect(opt.source)
|
detect(opt.source, 'runs/train/yolov7-custom7/weights/best.pt',
|
||||||
|
'resnet.pt')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user