diff --git a/yolo-second-run/model.py b/yolo-second-run/model.py index c1c3eb4..8e8ede6 100644 --- a/yolo-second-run/model.py +++ b/yolo-second-run/model.py @@ -37,6 +37,7 @@ def detect(img_path: str, yolo_path: str, resnet_path: str): # Get bounding boxes from object detection model box_coords = get_boxes(first_stage, img) + box_coords.sort_values(by=['xmin'], ignore_index=True, inplace=True) predictions = {} @@ -72,7 +73,8 @@ def get_boxes(model, img): :returns: pandas dataframe of matches """ - box_coords = model(img[..., ::-1], size=640) + with torch.no_grad(): + box_coords = model(img[..., ::-1], size=640) return box_coords.pandas().xyxy[0] @@ -91,9 +93,10 @@ def classify(model, img): ]) img = data_transforms(img.copy()) - out = model(img.unsqueeze(0)) - # Apply softmax to get percentage confidence of classes - out = torch.nn.functional.softmax(out, dim=1)[0] * 100 + with torch.no_grad(): + out = model(img.unsqueeze(0)) + # Apply softmax to get percentage confidence of classes + out = torch.nn.functional.softmax(out, dim=1)[0] * 100 return out @@ -119,6 +122,5 @@ if __name__ == '__main__': opt = parser.parse_args() if opt.source: - with torch.no_grad(): - detect(opt.source, 'runs/train/yolov7-custom7/weights/best.pt', - 'resnet.pt') + detect(opt.source, 'runs/train/yolov7-custom7/weights/best.pt', + 'resnet.pt')