# You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 204 lines
# 7.6 KiB
#
# 11 months ago
import torch
import cv2
import numpy as np
import torch
import os
import importlib
from model.plugins.ModelBase import ModelBase
from loguru import logger
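# Disabled legacy implementation (ModelManager_tmp), kept inside a bare string
# literal so that none of it is executed.
'''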
class ModelManager_tmp():
def __init__(self):
print("ModelInit")
def __del__(self):
print("ModelManager DEL")
def __preprocess_image(self,image, cfg, bgr2rgb=True):
"""图片预处理"""
img, scale_ratio, pad_size = letterbox(image, new_shape=cfg['input_shape'])
if bgr2rgb:
img = img[:, :, ::-1]
img = img.transpose(2, 0, 1) # HWC2CHW
img = np.ascontiguousarray(img, dtype=np.float32)
return img, scale_ratio, pad_size
def __draw_bbox(self,bbox, img0, color, wt, names):
"""在图片上画预测框"""
det_result_str = ''
for idx, class_id in enumerate(bbox[:, 5]):
if float(bbox[idx][4] < float(0.05)):
continue
img0 = cv2.rectangle(img0, (int(bbox[idx][0]), int(bbox[idx][1])), (int(bbox[idx][2]), int(bbox[idx][3])),
color, wt)
img0 = cv2.putText(img0, str(idx) + ' ' + names[int(class_id)], (int(bbox[idx][0]), int(bbox[idx][1] + 16)),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
img0 = cv2.putText(img0, '{:.4f}'.format(bbox[idx][4]), (int(bbox[idx][0]), int(bbox[idx][1] + 32)),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
det_result_str += '{} {} {} {} {} {}\n'.format(
names[bbox[idx][5]], str(bbox[idx][4]), bbox[idx][0], bbox[idx][1], bbox[idx][2], bbox[idx][3])
return img0
def __get_labels_from_txt(self,path):
"""从txt文件获取图片标签"""
labels_dict = dict()
with open(path) as f:
for cat_id, label in enumerate(f.readlines()):
labels_dict[cat_id] = label.strip()
return labels_dict
def __draw_prediction(self,pred, image, labels):
"""在图片上画出预测框并进行可视化展示"""
imgbox = widgets.Image(format='jpg', height=720, width=1280)
img_dw = self.__draw_bbox(pred, image, (0, 255, 0), 2, labels)
imgbox.value = cv2.imencode('.jpg', img_dw)[1].tobytes()
display(imgbox)
def __infer_image(self,img_path, model, class_names, cfg):
"""图片推理"""
# 图片载入
image = cv2.imread(img_path)
# 数据预处理
img, scale_ratio, pad_size = self.__preprocess_image(image, cfg)
# 模型推理
output = model.infer([img])[0]
output = torch.tensor(output)
# 非极大值抑制后处理
boxout = nms(output, conf_thres=cfg["conf_thres"], iou_thres=cfg["iou_thres"])
pred_all = boxout[0].numpy()
# 预测坐标转换
scale_coords(cfg['input_shape'], pred_all[:, :4], image.shape, ratio_pad=(scale_ratio, pad_size))
# 图片预测结果可视化
self.__draw_prediction(pred_all, image, class_names)
def __infer_frame_with_vis(self,image, model, labels_dict, cfg, bgr2rgb=True):
# 数据预处理
img, scale_ratio, pad_size = self.__preprocess_image(image, cfg, bgr2rgb)
# 模型推理
output = model.infer([img])[0]
output = torch.tensor(output)
# 非极大值抑制后处理
boxout = nms(output, conf_thres=cfg["conf_thres"], iou_thres=cfg["iou_thres"])
pred_all = boxout[0].numpy()
# 预测坐标转换
scale_coords(cfg['input_shape'], pred_all[:, :4], image.shape, ratio_pad=(scale_ratio, pad_size))
# 图片预测结果可视化
img_vis = self.__draw_bbox(pred_all, image, (0, 255, 0), 2, labels_dict)
return img_vis
def __img2bytes(self,image):
"""将图片转换为字节码"""
return bytes(cv2.imencode('.jpg', image)[1])
def __infer_camera(self,model, labels_dict, cfg):
"""外设摄像头实时推理"""
def find_camera_index():
max_index_to_check = 10 # Maximum index to check for camera
for index in range(max_index_to_check):
cap = cv2.VideoCapture(index)
if cap.read()[0]:
cap.release()
return index
# If no camera is found
raise ValueError("No camera found.")
# 获取摄像头 --这里可以换成RTSP流
camera_index = find_camera_index()
cap = cv2.VideoCapture(camera_index)
# 初始化可视化对象
image_widget = widgets.Image(format='jpeg', width=1280, height=720)
display(image_widget)
while True:
# 对摄像头每一帧进行推理和可视化
_, img_frame = cap.read()
image_pred = self.__infer_frame_with_vis(img_frame, model, labels_dict, cfg)
image_widget.value = self.__img2bytes(image_pred)
def __infer_video(self,video_path, model, labels_dict, cfg):
"""视频推理"""
image_widget = widgets.Image(format='jpeg', width=800, height=600)
display(image_widget)
# 读入视频
cap = cv2.VideoCapture(video_path)
while True:
ret, img_frame = cap.read()
if not ret:
break
# 对视频帧进行推理
image_pred = self.__infer_frame_with_vis(img_frame, model, labels_dict, cfg, bgr2rgb=True)
image_widget.value = self.__img2bytes(image_pred)
def startWork(self,infer_mode,file_paht = ""):
cfg = {
'conf_thres': 0.4, # 模型置信度阈值,阈值越低,得到的预测框越多
'iou_thres': 0.5, # IOU阈值,高于这个阈值的重叠预测框会被过滤掉
'input_shape': [640, 640], # 模型输入尺寸
}
model_path = 'yolo.om'
label_path = './coco_names.txt'
# 初始化推理模型
model = InferSession(0, model_path)
labels_dict = self.__get_labels_from_txt(label_path)
#执行验证
if infer_mode == 'image':
img_path = 'world_cup.jpg'
self.__infer_image(img_path, model, labels_dict, cfg)
elif infer_mode == 'camera':
self.__infer_camera(model, labels_dict, cfg)
elif infer_mode == 'video':
video_path = 'racing.mp4'
self.__infer_video(video_path, model, labels_dict, cfg)
'''
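'''
Algorithm implementation class: implements the algorithm execution thread and runs the algorithm module as a thread according to the configuration.
'''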
class ModelManager:
    """Placeholder manager whose hooks only announce themselves on stdout."""

    def __init__(self):
        """Print a marker line when an instance is created."""
        print("ModelManager init")

    def __del__(self):
        """Print a marker line when the instance is finalized."""
        print("ModelManager del")

    def doWork(self):
        """Do nothing; no work is implemented here yet."""
        return None
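# Dynamically import a source file -- approach two -- the comparatively recommended approach, although the spec argument does not seem to be of much use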
def import_source(spec, plgpath):
    """Load a Python source file from disk and return it as a module object.

    The module is executed but is not inserted into ``sys.modules`` by this
    function.

    Args:
        spec: Module name passed to ``spec_from_file_location``; it becomes
            the loaded module's ``__name__``.
        plgpath: Filesystem path of the source file to load.

    Returns:
        The executed module, or ``None`` when ``plgpath`` is not an existing
        file or when no import spec/loader can be built for it. An error is
        logged in both of those cases.

    Raises:
        Any exception raised by the loaded file's own top-level code.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is loaded, so bind it explicitly before using it.
    import importlib.util

    if not os.path.isfile(plgpath):
        logger.error("{}文件不存在".format(plgpath))
        return None

    module_spec = importlib.util.spec_from_file_location(spec, plgpath)
    # spec_from_file_location returns None when it cannot choose a loader
    # for the path (for example an unrecognised file suffix).
    if module_spec is None or module_spec.loader is None:
        logger.error("cannot build an import spec for {}".format(plgpath))
        return None

    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return module
# plgpath is a list [poc][file_name][name]
def run_plugin(plgpath, target, copy_flag=True):
    """Load the plugin file at ``plgpath`` and run its ``Model`` class once.

    ``target`` and ``copy_flag`` are accepted but are not read by this
    function.

    Returns:
        The value returned by ``Model().doWork("", "", "", "")``, or ``None``
        when the plugin module could not be loaded.

    Raises:
        Exception: if the instantiated ``Model`` is not a ``ModelBase``.
    """
    plugin_module = import_source("", plgpath)
    if not plugin_module:
        print("模型加载失败")
        return None

    model_cls = getattr(plugin_module, "Model")
    instance = model_cls()
    if not isinstance(instance, ModelBase):
        raise Exception("{} not rx_Model".format(instance))

    # Call the plugin entry point with four empty-string arguments and hand
    # its result straight back to the caller.
    return instance.doWork("", "", "", "")
def test():
    """Run ``run_plugin`` once against the hard-coded plugin file path."""
    plugin_file = "plugins/RYRQ_Model_ACL.py"
    run_plugin(plugin_file, "")
# 11 months ago
# Run the plugin smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test()