171 lines
5.3 KiB
Python

import argparse
import cv2
import torch
from torchvision import transforms
import tkinter
def load_models(yolo_path: str, resnet_path: str):
    """Load the models for two-stage classification.

    The first stage is loaded through ``torch.hub`` from the
    ``WongKinYiu/yolov7`` repository using custom weights. The second
    stage is deserialized with ``torch.load`` and switched to evaluation
    mode so it is ready for inference.

    :param str yolo_path: path to yolo weights
    :param str resnet_path: path to resnet weights
    :returns: tuple ``(first_stage, second_stage)`` of models
    """
    first_stage = torch.hub.load("WongKinYiu/yolov7",
                                 "custom",
                                 yolo_path,
                                 trust_repo=True)
    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code -- only load weights from a trusted source. Recent
    # PyTorch releases default to ``weights_only=True``, which rejects a
    # fully pickled module; confirm against the PyTorch version in use.
    second_stage = torch.load(resnet_path)
    # Inference must run in eval mode, otherwise dropout and batch
    # normalization layers keep their training behaviour. Calling eval()
    # on a model that is already in eval mode is a no-op.
    second_stage.eval()
    return (first_stage, second_stage)
def detect(img_path: str, yolo_path: str, resnet_path: str):
    """Load an image, detect individual plants and label them as
    healthy or wilted.

    Every detected bounding box is cropped, classified by the second
    stage and drawn (with its index) onto a copy of the image. The
    annotated image is shown in a window that blocks until a key is
    pressed.

    :param str img_path: path to image
    :param str yolo_path: path to yolo weights
    :param str resnet_path: path to resnet weights
    :returns: tuple of annotated image and dict mapping each bounding
        box index (boxes ordered by ``xmin``) to its predictions
    :raises FileNotFoundError: if the image cannot be read
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising; fail early with a clear message.
        raise FileNotFoundError(f'could not read image: {img_path!r}')
    original = img.copy()

    (first_stage, second_stage) = load_models(yolo_path, resnet_path)

    # Get bounding boxes from object detection model, ordered left to
    # right; ignore_index renumbers the rows 0..n-1 in that order
    box_coords = get_boxes(first_stage, img)
    box_coords = box_coords.sort_values(by=['xmin'], ignore_index=True)
    print(box_coords)

    predictions = {}
    for idx, row in box_coords.iterrows():
        xmin, xmax = int(row['xmin']), int(row['xmax'])
        ymin, ymax = int(row['ymin']), int(row['ymax'])

        # Get ROI in BGR. img is never drawn on (annotations go to
        # `original`), so no per-box copy of the full image is needed.
        cropped_image = get_cutout(img, xmin, xmax, ymin, ymax)

        # Classify ROI in RGB
        predictions[idx] = classify(second_stage, cropped_image[..., ::-1])

        # Draw bounding box and number on original image
        original = cv2.rectangle(original, (xmin, ymin), (xmax, ymax),
                                 (0, 255, 0), 2)
        # The index is drawn twice: a thick black pass as outline, then
        # a thinner white pass on top so it stays readable
        original = cv2.putText(original, str(idx), (xmin + 5, ymin + 25),
                               cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 0), 4,
                               cv2.LINE_AA)
        original = cv2.putText(original, str(idx), (xmin + 5, ymin + 25),
                               cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 255),
                               2, cv2.LINE_AA)

    cv2.imshow('original', original)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return (original, predictions)
def get_boxes(model, img):
    """Detect objects in an image and return their bounding boxes.

    The channel axis of ``img`` is reversed before it is handed to the
    model (BGR -> RGB for an OpenCV image), and inference runs with
    gradient tracking disabled at an inference size of 640.

    :param model: object detection model (YOLO)
    :param img: opencv2 image object
    :returns: pandas dataframe of matches for the (single) input image
    """
    rgb_view = img[..., ::-1]
    with torch.no_grad():
        detections = model(rgb_view, size=640)
    # One dataframe per input image; only one image was passed in
    per_image_frames = detections.pandas().xyxy
    return per_image_frames[0]
def classify(model, img):
    """Run the classification model on a single image.

    The image is converted to a tensor, normalized per channel, wrapped
    in a batch of one and passed through the model without gradient
    tracking.

    :param model: object classification model
    :param img: opencv2 image object in RGB
    :returns: tensor of class predictions as percentages (softmax * 100)
    """
    # Per-channel mean/std normalization; these are the values
    # torchvision documents for its ImageNet-pretrained models
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Copy first: callers may pass a reversed (negatively strided) view
    batch = preprocess(img.copy()).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        # Softmax over the class dimension turns logits into confidences
        confidences = torch.nn.functional.softmax(logits, dim=1)

    # Single-image batch: take its row and scale to percent
    return confidences[0] * 100
def get_cutout(img, xmin, xmax, ymin, ymax):
    """Cut out a bounding box from an image.

    Coordinates are clamped to the image: values below zero are treated
    as zero, and values past the image edge are truncated by slicing. A
    box lying completely outside the image yields an empty cutout.

    :param img: opencv2 image object in BGR
    :param int xmin: start of bounding box on x axis
    :param int xmax: end of bounding box on x axis (exclusive)
    :param int ymin: start of bounding box on y axis
    :param int ymax: end of bounding box on y axis (exclusive)
    :returns: cropped image in BGR (a slice of ``img``, not a copy)
    """
    # A negative index would be interpreted relative to the opposite
    # edge and silently select the wrong region, so clamp at zero.
    x_start, x_end = max(xmin, 0), max(xmax, 0)
    y_start, y_end = max(ymin, 0), max(ymax, 0)
    return img[y_start:y_end, x_start:x_end]
def export_to_onnx(yolo_path: str, resnet_path: str,
                   yolo_out: str = 'yolo.onnx',
                   resnet_out: str = 'resnet.onnx'):
    """Export the models to onnx.

    Both models are put into evaluation mode and traced with a random
    dummy input: ``(1, 3, 640, 640)`` for the first stage and
    ``(1, 3, 224, 224)`` for the second stage.

    :param yolo_path: path to yolo weights
    :param resnet_path: path to resnet weights
    :param yolo_out: file the first-stage ONNX model is written to
    :param resnet_out: file the second-stage ONNX model is written to
    :returns: None
    """
    (first, second) = load_models(yolo_path, resnet_path)

    # (model, dummy input shape, output file) for each stage
    export_jobs = (
        (first, (1, 3, 640, 640), yolo_out),
        (second, (1, 3, 224, 224), resnet_out),
    )
    for model, input_shape, out_path in export_jobs:
        model.eval()
        dummy_input = torch.randn(input_shape, requires_grad=True)
        torch.onnx.export(model,
                          dummy_input,
                          out_path,
                          export_params=True,
                          do_constant_folding=True,
                          input_names=['input'],
                          output_names=['output'])
if __name__ == '__main__':
    # Command-line entry point: run detection on an image and/or export
    # both models to ONNX. The weight files default to the paths that
    # used to be hard-coded, so existing invocations keep working.
    parser = argparse.ArgumentParser()
    parser.add_argument('--source',
                        type=str,
                        help='path to the image file to analyse')
    parser.add_argument('--onnx',
                        action='store_true',
                        dest='onnx',
                        help='export models to onnx')
    parser.add_argument('--yolo',
                        type=str,
                        default='yolo.pt',
                        help='path to yolo weights (default: yolo.pt)')
    parser.add_argument('--resnet',
                        type=str,
                        default='resnet.pt',
                        help='path to resnet weights (default: resnet.pt)')
    opt = parser.parse_args()

    if opt.source:
        detect(opt.source, opt.yolo, opt.resnet)
    if opt.onnx:
        export_to_onnx(opt.yolo, opt.resnet)
    if not (opt.source or opt.onnx):
        # Nothing was requested; show usage instead of exiting silently
        parser.print_help()