Do detections on the GPU

This commit is contained in:
Tobias Eidelpes 2023-02-27 16:28:07 +01:00
parent 3ec77e0558
commit 43cff6dff5

View File

@ -95,11 +95,28 @@ def classify(resnet_path, img):
batch = img.unsqueeze(0)
# Do inference
session = onnxruntime.InferenceSession(resnet_path)
providers = [('CUDAExecutionProvider', {
"cudnn_conv_algo_search": "DEFAULT"
}), 'CPUExecutionProvider']
session = onnxruntime.InferenceSession(resnet_path, providers=providers)
outname = [i.name for i in session.get_outputs()]
inname = [i.name for i in session.get_inputs()]
inp = {inname[0]: batch.numpy()}
out = torch.tensor(np.array(session.run(outname, inp)))[0]
io_binding = session.io_binding()
io_binding.bind_cpu_input(inname[0], inp[inname[0]])
io_binding.bind_output(outname[0])
session.run_with_iobinding(io_binding)
out = torch.tensor(io_binding.copy_outputs_to_cpu()[0])
# Do inference
# session = onnxruntime.InferenceSession(resnet_path)
# outname = [i.name for i in session.get_outputs()]
# inname = [i.name for i in session.get_inputs()]
# inp = {inname[0]: batch.numpy()}
# out = torch.tensor(np.array(session.run(outname, inp)))[0]
# Apply softmax to get percentage confidence of classes
out = torch.nn.functional.softmax(out, dim=1)[0] * 100
@ -167,14 +184,26 @@ def get_boxes(yolo_path, image):
img['image'] = img['image'].unsqueeze(0)
# Do inference
session = onnxruntime.InferenceSession(yolo_path)
providers = [('CUDAExecutionProvider', {
"cudnn_conv_algo_search": "DEFAULT"
}), 'CPUExecutionProvider']
session = onnxruntime.InferenceSession(yolo_path, providers=providers)
outname = [i.name for i in session.get_outputs()]
inname = [i.name for i in session.get_inputs()]
inp = {inname[0]: img['image'].numpy()}
out = torch.tensor(np.array(session.run(outname, inp)))[0]
io_binding = session.io_binding()
io_binding.bind_cpu_input(inname[0], inp[inname[0]])
io_binding.bind_output(outname[0])
session.run_with_iobinding(io_binding)
outs = torch.tensor(io_binding.copy_outputs_to_cpu()[0])
# out = torch.tensor(np.array(session.run(outname, inp)))[0]
# print(out.shape)
# Apply NMS to results
preds_nms = apply_nms([out])[0]
preds_nms = apply_nms([outs])[0]
# Convert boxes from resized img to original img
xyxy_boxes = preds_nms[:, [1, 2, 3, 4]] # xmin, ymin, xmax, ymax