Implement API for predictions
This commit is contained in:
parent
8a9c3c1edc
commit
19e538bdff
@ -1,12 +1,22 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
import atexit
|
import atexit
|
||||||
import jetson.utils
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import datetime
|
||||||
|
from dateutil.parser import parse
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from apscheduler.schedulers.background import BackgroundScheduler
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
|
||||||
from multiprocessing import Manager
|
from multiprocessing import Manager
|
||||||
|
|
||||||
from code.evaluation.detection import detect
|
from evaluation.detection import detect
|
||||||
|
from utils.manipulations import draw_boxes
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
scheduler = BackgroundScheduler(daemon=True)
|
scheduler = BackgroundScheduler(daemon=True)
|
||||||
@ -14,39 +24,130 @@ manager = Manager()
|
|||||||
pred = manager.dict()
|
pred = manager.dict()
|
||||||
|
|
||||||
|
|
||||||
@scheduler.task('interval', id='get_pred', minutes=30, misfire_grace_time=900)
|
|
||||||
def get_pred():
|
def get_pred():
|
||||||
img = take_image('./current_image.jpg')
|
tmp = deepcopy(pred)
|
||||||
print('Job 1 executed')
|
take_image()
|
||||||
|
logging.debug('Starting image classification')
|
||||||
|
preds = detect('current.jpg', '../weights/yolo.onnx',
|
||||||
|
'../weights/resnet.onnx')
|
||||||
|
logging.debug('Finished image classification: %s', preds)
|
||||||
|
logging.debug('Reading current.jpg for drawing bounding boxes')
|
||||||
|
current = cv2.imread('current.jpg')
|
||||||
|
logging.debug('Drawing bounding boxes on current.jpg')
|
||||||
|
bbox_img = draw_boxes(
|
||||||
|
current, preds[['xmin', 'ymin', 'xmax',
|
||||||
|
'ymax']].itertuples(index=False, name=None))
|
||||||
|
logging.debug(
|
||||||
|
'Finished drawing bounding boxes. Saving to current_bbox.jpg ...')
|
||||||
|
cv2.imwrite('current_bbox.jpg', bbox_img)
|
||||||
|
|
||||||
|
# Clear superfluous bboxes if less detected
|
||||||
|
# if len(preds) < len(pred):
|
||||||
|
# logging.debug(
|
||||||
|
# 'Current round contains less bboxes than previous round: old: %s\nnew: %s',
|
||||||
|
# json.dumps(preds.copy()), json.dumps(pred.copy()))
|
||||||
|
# for key in pred:
|
||||||
|
# if key not in preds:
|
||||||
|
# pred.pop(key)
|
||||||
|
|
||||||
|
pred.clear()
|
||||||
|
for idx, row in preds.iterrows():
|
||||||
|
new = []
|
||||||
|
state = int(round(float(row['cls_conf']) / 10, 0))
|
||||||
|
new.append(state)
|
||||||
|
new.append(datetime.datetime.now(datetime.timezone.utc).isoformat())
|
||||||
|
try:
|
||||||
|
pred[str(idx)] = tmp[str(idx)]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
if tmp[idx][2] == -1 and state > 3:
|
||||||
|
logging.debug(
|
||||||
|
'State is worse than 3 for the first time: %s ... populating third field',
|
||||||
|
str(state))
|
||||||
|
new.append(
|
||||||
|
datetime.datetime.now(datetime.timezone.utc).isoformat())
|
||||||
|
elif tmp[idx][2] != -1 and state <= 3:
|
||||||
|
logging.debug(
|
||||||
|
'State changed from worse than 3 to better than 3')
|
||||||
|
new.append(-1)
|
||||||
|
elif tmp[idx][2] != -1 and state > 3:
|
||||||
|
logging.debug('State is still worse than 3')
|
||||||
|
new.append(tmp[idx][2])
|
||||||
|
except:
|
||||||
|
logging.debug('Third key does not exist')
|
||||||
|
if state > 3:
|
||||||
|
logging.debug(
|
||||||
|
'State is worse than 3. Populating third field with timestamp'
|
||||||
|
)
|
||||||
|
new.append(
|
||||||
|
datetime.datetime.now(datetime.timezone.utc).isoformat())
|
||||||
|
else:
|
||||||
|
logging.debug(
|
||||||
|
'State is better than 3. Populating third field with -1')
|
||||||
|
new.append(-1)
|
||||||
|
|
||||||
|
pred[idx] = new
|
||||||
|
|
||||||
|
bbox_img_b64 = base64.b64encode(cv2.imencode('.jpg', bbox_img)[1]).decode()
|
||||||
|
pred['image'] = bbox_img_b64
|
||||||
|
logging.debug('Saved bbox_img to json')
|
||||||
|
|
||||||
|
|
||||||
def take_image(img_path: str):
|
def take_image():
|
||||||
"""Take an image with the webcam and save it to the specified
|
"""Take an image with the webcam and save it to the specified
|
||||||
path.
|
path.
|
||||||
|
|
||||||
:param str img_path: path image should be saved to
|
:param str img_path: path image should be saved to
|
||||||
:returns: captured image
|
:returns: captured image
|
||||||
|
|
||||||
"""
|
"""
|
||||||
input = jetson.utils.videoSource('csi://0')
|
capture_path = os.path.join('.', 'image-capture', 'capture')
|
||||||
output = jetson.utils.videoOutput(img_path)
|
if os.path.isfile(capture_path) and os.access(capture_path, os.X_OK):
|
||||||
img = input.Capture()
|
logging.debug('Starting image capture')
|
||||||
output.Render(img)
|
os.system('./image-capture/capture')
|
||||||
return img
|
logging.debug('Finished image capture')
|
||||||
|
else:
|
||||||
|
logging.critical(
|
||||||
|
'Image capture binary is not at path %s. Shutting down server...',
|
||||||
|
capture_path)
|
||||||
|
app.terminate()
|
||||||
|
|
||||||
|
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index():
|
||||||
# TODO: call script and save initial image with bounding boxes
|
copy = pred.copy()
|
||||||
# TODO: get predictions and output them in JSON via API
|
for key, value in copy.items():
|
||||||
# TODO: periodically get image from webcam and go to beginning
|
if key == 'image':
|
||||||
# TODO: JSON format: [Nr, state (1-10), timestamp, time since below 3]
|
continue
|
||||||
return 'Server works'
|
if value[2] != -1:
|
||||||
|
logging.debug('value: %s', value)
|
||||||
|
# Calc difference to now
|
||||||
|
time_below_thresh = parse(value[2])
|
||||||
|
now = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
delta = now - time_below_thresh
|
||||||
|
value[2] = str(delta)
|
||||||
|
return json.dumps(copy)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
scheduler.add_job(func=get_pred, trigger='interval', minutes=30)
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--log',
|
||||||
|
type=str,
|
||||||
|
help='log level (debug, info, warning, error, critical)',
|
||||||
|
default='warning')
|
||||||
|
opt = parser.parse_args()
|
||||||
|
numeric_level = getattr(logging, opt.log.upper(), None)
|
||||||
|
logging.basicConfig(format='%(levelname)s::%(asctime)s::%(message)s',
|
||||||
|
datefmt='%Y-%m-%dT%H:%M:%S',
|
||||||
|
level=numeric_level)
|
||||||
|
|
||||||
|
scheduler.add_job(func=get_pred,
|
||||||
|
trigger='interval',
|
||||||
|
minutes=2,
|
||||||
|
next_run_time=datetime.datetime.now())
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
atexit.register(scheduler.shutdown())
|
atexit.register(scheduler.shutdown)
|
||||||
|
|
||||||
app.run()
|
app.run()
|
||||||
|
|||||||
@ -5,6 +5,7 @@ def draw_boxes(image, bboxes):
|
|||||||
img = image.copy()
|
img = image.copy()
|
||||||
for idx, bbox in enumerate(bboxes):
|
for idx, bbox in enumerate(bboxes):
|
||||||
xmin, ymin, xmax, ymax = bbox
|
xmin, ymin, xmax, ymax = bbox
|
||||||
|
xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
|
||||||
# Draw bounding box and number on original image
|
# Draw bounding box and number on original image
|
||||||
img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
|
img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
|
||||||
img = cv2.putText(img, str(idx), (xmin + 5, ymin + 25),
|
img = cv2.putText(img, str(idx), (xmin + 5, ymin + 25),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user