diff --git a/yolo-second-run/model.py b/yolo-second-run/model.py index 8e8ede6..54b61ac 100644 --- a/yolo-second-run/model.py +++ b/yolo-second-run/model.py @@ -7,8 +7,8 @@ from torchvision import transforms def load_models(yolo_path: str, resnet_path: str): """Load the models for two-stage classification. - :param yolo_path: path to yolo weights - :param resnet_path: path to resnet weights + :param str yolo_path: path to yolo weights + :param str resnet_path: path to resnet weights :returns: tuple of models """ @@ -25,8 +25,8 @@ def detect(img_path: str, yolo_path: str, resnet_path: str): healthy or wilted. :param str img_path: path to image - :param yolo_path: path to yolo weights - :param resnet_path: path to resnet weights + :param str yolo_path: path to yolo weights + :param str resnet_path: path to resnet weights :returns: tuple of recent image and dict of bounding boxes and their predictions @@ -105,10 +105,10 @@ def get_cutout(img, xmin, xmax, ymin, ymax): object classification model. :param img: opencv2 image object in BGR - :param xmin: start of bounding box on x axis - :param xmax: end of bounding box on x axis - :param ymin: start of bounding box on y axis - :param ymax: end of bounding box on y axis + :param int xmin: start of bounding box on x axis + :param int xmax: end of bounding box on x axis + :param int ymin: start of bounding box on y axis + :param int ymax: end of bounding box on y axis :returns: tensor of cropped image in BGR """