Do not perform gradient calculation during inference
This commit is contained in:
parent
0ef8eb0cd4
commit
548d7a1f9c
@ -37,6 +37,7 @@ def detect(img_path: str, yolo_path: str, resnet_path: str):
|
||||
|
||||
# Get bounding boxes from object detection model
|
||||
box_coords = get_boxes(first_stage, img)
|
||||
|
||||
box_coords.sort_values(by=['xmin'], ignore_index=True, inplace=True)
|
||||
|
||||
predictions = {}
|
||||
@ -72,6 +73,7 @@ def get_boxes(model, img):
|
||||
:returns: pandas dataframe of matches
|
||||
|
||||
"""
|
||||
with torch.no_grad():
|
||||
box_coords = model(img[..., ::-1], size=640)
|
||||
return box_coords.pandas().xyxy[0]
|
||||
|
||||
@ -91,6 +93,7 @@ def classify(model, img):
|
||||
])
|
||||
|
||||
img = data_transforms(img.copy())
|
||||
with torch.no_grad():
|
||||
out = model(img.unsqueeze(0))
|
||||
# Apply softmax to get percentage confidence of classes
|
||||
out = torch.nn.functional.softmax(out, dim=1)[0] * 100
|
||||
@ -119,6 +122,5 @@ if __name__ == '__main__':
|
||||
opt = parser.parse_args()
|
||||
|
||||
if opt.source:
|
||||
with torch.no_grad():
|
||||
detect(opt.source, 'runs/train/yolov7-custom7/weights/best.pt',
|
||||
'resnet.pt')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user