Create function for object detection only
This commit is contained in:
parent
d88a6766c9
commit
3ec77e0558
@ -11,6 +11,7 @@ from albumentations.pytorch import ToTensorV2
|
||||
from utils.conversions import scale_bboxes
|
||||
from utils.manipulations import get_cutout
|
||||
|
||||
|
||||
def detect(img_path: str, yolo_path: str, resnet_path: str):
|
||||
"""Load an image, detect individual plants and label them as
|
||||
healthy or wilted.
|
||||
@ -55,6 +56,25 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
|
||||
return box_coords
|
||||
|
||||
|
||||
def detect_yolo_only(img_path: str, yolo_path: str):
|
||||
"""Load an image and detect individual plants.
|
||||
|
||||
:param str img_path: path to image
|
||||
:param str yolo_path: path to yolo weights
|
||||
:returns: tuple of recent image and dict of bounding boxes and
|
||||
their predictions
|
||||
|
||||
"""
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
# Get bounding boxes from object detection model
|
||||
box_coords = get_boxes(yolo_path, img.copy())
|
||||
|
||||
box_coords.sort_values(by=['xmin'], ignore_index=True, inplace=True)
|
||||
|
||||
return box_coords
|
||||
|
||||
|
||||
def classify(resnet_path, img):
|
||||
"""Classify img with object classification model.
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user