You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
204 lines
7.6 KiB
204 lines
7.6 KiB
import torch
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import os
|
|
import importlib
|
|
from model.plugins.ModelBase import ModelBase
|
|
from loguru import logger
|
|
|
|
'''
|
|
class ModelManager_tmp():
|
|
def __init__(self):
|
|
print("ModelInit")
|
|
|
|
def __del__(self):
|
|
print("ModelManager DEL")
|
|
|
|
def __preprocess_image(self,image, cfg, bgr2rgb=True):
|
|
"""图片预处理"""
|
|
img, scale_ratio, pad_size = letterbox(image, new_shape=cfg['input_shape'])
|
|
if bgr2rgb:
|
|
img = img[:, :, ::-1]
|
|
img = img.transpose(2, 0, 1) # HWC2CHW
|
|
img = np.ascontiguousarray(img, dtype=np.float32)
|
|
return img, scale_ratio, pad_size
|
|
|
|
def __draw_bbox(self,bbox, img0, color, wt, names):
|
|
"""在图片上画预测框"""
|
|
det_result_str = ''
|
|
for idx, class_id in enumerate(bbox[:, 5]):
|
|
if float(bbox[idx][4] < float(0.05)):
|
|
continue
|
|
img0 = cv2.rectangle(img0, (int(bbox[idx][0]), int(bbox[idx][1])), (int(bbox[idx][2]), int(bbox[idx][3])),
|
|
color, wt)
|
|
img0 = cv2.putText(img0, str(idx) + ' ' + names[int(class_id)], (int(bbox[idx][0]), int(bbox[idx][1] + 16)),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
|
|
img0 = cv2.putText(img0, '{:.4f}'.format(bbox[idx][4]), (int(bbox[idx][0]), int(bbox[idx][1] + 32)),
|
|
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
|
|
det_result_str += '{} {} {} {} {} {}\n'.format(
|
|
names[bbox[idx][5]], str(bbox[idx][4]), bbox[idx][0], bbox[idx][1], bbox[idx][2], bbox[idx][3])
|
|
return img0
|
|
|
|
def __get_labels_from_txt(self,path):
|
|
"""从txt文件获取图片标签"""
|
|
labels_dict = dict()
|
|
with open(path) as f:
|
|
for cat_id, label in enumerate(f.readlines()):
|
|
labels_dict[cat_id] = label.strip()
|
|
return labels_dict
|
|
|
|
def __draw_prediction(self,pred, image, labels):
|
|
"""在图片上画出预测框并进行可视化展示"""
|
|
imgbox = widgets.Image(format='jpg', height=720, width=1280)
|
|
img_dw = self.__draw_bbox(pred, image, (0, 255, 0), 2, labels)
|
|
imgbox.value = cv2.imencode('.jpg', img_dw)[1].tobytes()
|
|
display(imgbox)
|
|
|
|
def __infer_image(self,img_path, model, class_names, cfg):
|
|
"""图片推理"""
|
|
# 图片载入
|
|
image = cv2.imread(img_path)
|
|
# 数据预处理
|
|
img, scale_ratio, pad_size = self.__preprocess_image(image, cfg)
|
|
# 模型推理
|
|
output = model.infer([img])[0]
|
|
|
|
output = torch.tensor(output)
|
|
# 非极大值抑制后处理
|
|
boxout = nms(output, conf_thres=cfg["conf_thres"], iou_thres=cfg["iou_thres"])
|
|
pred_all = boxout[0].numpy()
|
|
# 预测坐标转换
|
|
scale_coords(cfg['input_shape'], pred_all[:, :4], image.shape, ratio_pad=(scale_ratio, pad_size))
|
|
# 图片预测结果可视化
|
|
self.__draw_prediction(pred_all, image, class_names)
|
|
|
|
def __infer_frame_with_vis(self,image, model, labels_dict, cfg, bgr2rgb=True):
|
|
# 数据预处理
|
|
img, scale_ratio, pad_size = self.__preprocess_image(image, cfg, bgr2rgb)
|
|
# 模型推理
|
|
output = model.infer([img])[0]
|
|
|
|
output = torch.tensor(output)
|
|
# 非极大值抑制后处理
|
|
boxout = nms(output, conf_thres=cfg["conf_thres"], iou_thres=cfg["iou_thres"])
|
|
pred_all = boxout[0].numpy()
|
|
# 预测坐标转换
|
|
scale_coords(cfg['input_shape'], pred_all[:, :4], image.shape, ratio_pad=(scale_ratio, pad_size))
|
|
# 图片预测结果可视化
|
|
img_vis = self.__draw_bbox(pred_all, image, (0, 255, 0), 2, labels_dict)
|
|
return img_vis
|
|
|
|
def __img2bytes(self,image):
|
|
"""将图片转换为字节码"""
|
|
return bytes(cv2.imencode('.jpg', image)[1])
|
|
def __infer_camera(self,model, labels_dict, cfg):
|
|
"""外设摄像头实时推理"""
|
|
|
|
def find_camera_index():
|
|
max_index_to_check = 10 # Maximum index to check for camera
|
|
|
|
for index in range(max_index_to_check):
|
|
cap = cv2.VideoCapture(index)
|
|
if cap.read()[0]:
|
|
cap.release()
|
|
return index
|
|
|
|
# If no camera is found
|
|
raise ValueError("No camera found.")
|
|
|
|
# 获取摄像头 --这里可以换成RTSP流
|
|
camera_index = find_camera_index()
|
|
cap = cv2.VideoCapture(camera_index)
|
|
# 初始化可视化对象
|
|
image_widget = widgets.Image(format='jpeg', width=1280, height=720)
|
|
display(image_widget)
|
|
while True:
|
|
# 对摄像头每一帧进行推理和可视化
|
|
_, img_frame = cap.read()
|
|
image_pred = self.__infer_frame_with_vis(img_frame, model, labels_dict, cfg)
|
|
image_widget.value = self.__img2bytes(image_pred)
|
|
|
|
def __infer_video(self,video_path, model, labels_dict, cfg):
|
|
"""视频推理"""
|
|
image_widget = widgets.Image(format='jpeg', width=800, height=600)
|
|
display(image_widget)
|
|
|
|
# 读入视频
|
|
cap = cv2.VideoCapture(video_path)
|
|
while True:
|
|
ret, img_frame = cap.read()
|
|
if not ret:
|
|
break
|
|
# 对视频帧进行推理
|
|
image_pred = self.__infer_frame_with_vis(img_frame, model, labels_dict, cfg, bgr2rgb=True)
|
|
image_widget.value = self.__img2bytes(image_pred)
|
|
|
|
def startWork(self,infer_mode,file_paht = ""):
|
|
cfg = {
|
|
'conf_thres': 0.4, # 模型置信度阈值,阈值越低,得到的预测框越多
|
|
'iou_thres': 0.5, # IOU阈值,高于这个阈值的重叠预测框会被过滤掉
|
|
'input_shape': [640, 640], # 模型输入尺寸
|
|
}
|
|
|
|
model_path = 'yolo.om'
|
|
label_path = './coco_names.txt'
|
|
# 初始化推理模型
|
|
model = InferSession(0, model_path)
|
|
labels_dict = self.__get_labels_from_txt(label_path)
|
|
|
|
#执行验证
|
|
if infer_mode == 'image':
|
|
img_path = 'world_cup.jpg'
|
|
self.__infer_image(img_path, model, labels_dict, cfg)
|
|
elif infer_mode == 'camera':
|
|
self.__infer_camera(model, labels_dict, cfg)
|
|
elif infer_mode == 'video':
|
|
video_path = 'racing.mp4'
|
|
self.__infer_video(video_path, model, labels_dict, cfg)
|
|
'''
|
|
|
|
'''
|
|
算法实现类,实现算法执行线程,根据配内容,以线程方式执行算法模块
|
|
'''
|
|
class ModelManager:
    """Algorithm implementation class (currently a stub).

    According to the original author's description, this class is meant to
    run algorithm modules as threads, driven by configuration. The code
    here only announces construction and destruction on stdout and exposes
    an empty ``doWork`` hook.
    """

    def __init__(self):
        # Trace message emitted when a manager is created.
        print("ModelManager init")

    def __del__(self):
        # Trace message emitted when the manager is garbage-collected.
        print("ModelManager del")

    def doWork(self):
        """Placeholder work hook; performs nothing and returns ``None``."""
        return None
|
|
|
|
# Dynamically import a source file -- approach two -- the comparatively recommended one, although `spec` seems to serve little purpose
|
|
def import_source(spec, plgpath):
    """Load the Python source file at ``plgpath`` as a module and return it.

    Args:
        spec: Module name handed to the import machinery; it becomes the
            loaded module's ``__name__``.
        plgpath: File-system path of the source file to load.

    Returns:
        The executed module object, or ``None`` when ``plgpath`` does not
        exist or no import spec can be built for it (both cases are logged).

    Raises:
        Whatever the plugin's own top-level code raises while it is being
        executed; such errors are not swallowed here.
    """
    # A bare ``import importlib`` does not guarantee that the ``util``
    # submodule is bound as an attribute, so import it explicitly.
    import importlib.util

    if not os.path.exists(plgpath):
        logger.error("{}文件不存在".format(plgpath))
        return None

    module_spec = importlib.util.spec_from_file_location(spec, plgpath)
    # spec_from_file_location returns None when it cannot pick a loader
    # (e.g. an unrecognised file suffix); fail the same way as a missing file.
    if module_spec is None or module_spec.loader is None:
        logger.error("{}不是可导入的Python文件".format(plgpath))
        return None

    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return module
|
|
|
|
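# plgpath is described as a list [poc][file_name][name]; NOTE(review): test() below passes a plain file-path string -- confirm which is intended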
|
|
def run_plugin(plgpath, target, copy_flag=True):
    """Load a plugin file, instantiate its ``Model`` class and run it.

    Args:
        plgpath: Path of the plugin source file, forwarded to
            ``import_source``.
        target: Accepted for interface compatibility; not used here.
        copy_flag: Accepted for interface compatibility; not used here.

    Returns:
        Whatever the plugin's ``doWork("", "", "", "")`` call returns, or
        ``None`` when the plugin module could not be loaded.

    Raises:
        Exception: If the plugin's ``Model`` instance is not a ``ModelBase``.
    """
    module = import_source("", plgpath)

    # Loading failed: report on stdout and give up.
    if not module:
        print("模型加载失败")
        return None

    # Every plugin is expected to expose a class literally named "Model".
    plg = getattr(module, "Model")()
    if not isinstance(plg, ModelBase):
        raise Exception("{} not rx_Model".format(plg))

    # Run the plugin's entry point and hand its result back to the caller.
    return plg.doWork("", "", "", "")
|
|
|
|
def test():
    """Smoke-run the bundled ACL plugin through ``run_plugin``.

    The plugin's result is discarded; this function returns ``None``.
    """
    plugin_file = "plugins/RYRQ_Model_ACL.py"
    run_plugin(plugin_file, "")
|
|
|
|
# Script entry point: run the plugin smoke test only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    test()
|