smartpi 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smartpi/__init__.py +1 -1
- smartpi/camera.py +84 -0
- smartpi/onnx_hand_workflow.py +201 -0
- smartpi/onnx_image_workflow.py +176 -0
- smartpi/onnx_pose_workflow.py +482 -0
- smartpi/onnx_text_workflow.py +173 -0
- smartpi/onnx_voice_workflow.py +437 -0
- smartpi/posenet_utils.py +222 -0
- smartpi/rknn_hand_workflow.py +245 -0
- smartpi/rknn_image_workflow.py +405 -0
- smartpi/rknn_pose_workflow.py +592 -0
- smartpi/rknn_text_workflow.py +240 -0
- smartpi/rknn_voice_workflow.py +394 -0
- {smartpi-0.1.35.dist-info → smartpi-0.1.36.dist-info}/METADATA +1 -1
- smartpi-0.1.36.dist-info/RECORD +32 -0
- smartpi-0.1.35.dist-info/RECORD +0 -20
- {smartpi-0.1.35.dist-info → smartpi-0.1.36.dist-info}/WHEEL +0 -0
- {smartpi-0.1.35.dist-info → smartpi-0.1.36.dist-info}/top_level.txt +0 -0
smartpi/__init__.py
CHANGED
|
@@ -4,5 +4,5 @@ from .base_driver import P1, P2, P3, P4, P5, P6, M1, M2, M3, M4, M5, M6
|
|
|
4
4
|
__all__ = ["base_driver","gui","ultrasonic","touch_sensor","temperature","humidity","light_sensor","color_sensor","motor","servo","led","flash",
|
|
5
5
|
"P1", "P2", "P3", "P4", "P5", "P6", "M1", "M2", "M3", "M4", "M5", "M6"]
|
|
6
6
|
|
|
7
|
-
__version__ = "0.1.
|
|
7
|
+
__version__ = "0.1.36"
|
|
8
8
|
|
smartpi/camera.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
import cv2
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
import platform
|
|
6
|
+
|
|
7
|
+
class Camera:
|
|
8
|
+
def __init__(self, indexes=[0, 1, 2, 3], target_width=640, target_height=480):
|
|
9
|
+
self.cap = None
|
|
10
|
+
self.indexes = indexes
|
|
11
|
+
self.target_width = target_width
|
|
12
|
+
self.target_height = target_height
|
|
13
|
+
self.open_camera()
|
|
14
|
+
|
|
15
|
+
def open_camera(self):
|
|
16
|
+
"""打开摄像头(硬件加速+参数优化)"""
|
|
17
|
+
for idx in self.indexes:
|
|
18
|
+
try:
|
|
19
|
+
# 适配linux/Android的V4L2硬件加速(RK芯片优先)
|
|
20
|
+
if platform.system() == "Linux":
|
|
21
|
+
cap = cv2.VideoCapture(idx, cv2.CAP_V4L2)
|
|
22
|
+
# 尝试启用硬件加速(兼容不同OpenCV版本)
|
|
23
|
+
try:
|
|
24
|
+
# 对于较新版本的OpenCV
|
|
25
|
+
if hasattr(cv2, 'CAP_PROP_HW_ACCELERATION') and hasattr(cv2, 'VIDEO_ACCELERATION_ANY'):
|
|
26
|
+
cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_ANY)
|
|
27
|
+
except AttributeError as ae:
|
|
28
|
+
print(f"硬件加速设置不支持,使用默认配置: {ae}")
|
|
29
|
+
else:
|
|
30
|
+
cap = cv2.VideoCapture(idx)
|
|
31
|
+
|
|
32
|
+
if cap.isOpened():
|
|
33
|
+
# 尝试设置分辨率
|
|
34
|
+
cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.target_width)
|
|
35
|
+
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.target_height)
|
|
36
|
+
|
|
37
|
+
# 获取实际设置的分辨率
|
|
38
|
+
actual_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
39
|
+
actual_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
40
|
+
|
|
41
|
+
print(f"摄像头 {idx} 已打开, 分辨率: {actual_width}x{actual_height}")
|
|
42
|
+
self.cap = cap
|
|
43
|
+
return True
|
|
44
|
+
except Exception as e:
|
|
45
|
+
print(f"尝试打开摄像头 {idx} 失败: {e}")
|
|
46
|
+
continue
|
|
47
|
+
|
|
48
|
+
print("无法打开任何摄像头")
|
|
49
|
+
return False
|
|
50
|
+
|
|
51
|
+
def read_frame(self):
|
|
52
|
+
"""读取一帧并自动处理错误"""
|
|
53
|
+
if not self.cap or not self.cap.isOpened():
|
|
54
|
+
return False, None
|
|
55
|
+
|
|
56
|
+
ret, frame = self.cap.read()
|
|
57
|
+
if not ret:
|
|
58
|
+
print("读取帧失败,尝试重新打开摄像头...")
|
|
59
|
+
self.release()
|
|
60
|
+
time.sleep(1)
|
|
61
|
+
if self.open_camera():
|
|
62
|
+
return self.read_frame()
|
|
63
|
+
return False, None
|
|
64
|
+
|
|
65
|
+
# 调整到目标分辨率
|
|
66
|
+
if frame.shape[1] != self.target_width or frame.shape[0] != self.target_height:
|
|
67
|
+
frame = cv2.resize(frame, (self.target_width, self.target_height))
|
|
68
|
+
|
|
69
|
+
return True, frame
|
|
70
|
+
|
|
71
|
+
def get_resolution(self):
|
|
72
|
+
"""获取当前分辨率"""
|
|
73
|
+
if self.cap and self.cap.isOpened():
|
|
74
|
+
return (
|
|
75
|
+
int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
|
|
76
|
+
int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
77
|
+
)
|
|
78
|
+
return self.target_width, self.target_height
|
|
79
|
+
|
|
80
|
+
def release(self):
|
|
81
|
+
"""释放摄像头资源"""
|
|
82
|
+
if self.cap and self.cap.isOpened():
|
|
83
|
+
self.cap.release()
|
|
84
|
+
self.cap = None
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
import cv2
|
|
2
|
+
import numpy as np
|
|
3
|
+
import onnxruntime as ort
|
|
4
|
+
import mediapipe as mp
|
|
5
|
+
import json
|
|
6
|
+
from PIL import Image
|
|
7
|
+
import time # 用于时间测量
|
|
8
|
+
|
|
9
|
+
class GestureWorkflow:
|
|
10
|
+
def __init__(self, model_path):
|
|
11
|
+
# 初始化MediaPipe Hands
|
|
12
|
+
self.mp_hands = mp.solutions.hands
|
|
13
|
+
self.hands = self.mp_hands.Hands(
|
|
14
|
+
static_image_mode=False, # 视频流模式 如果只是获取照片的手势关键点 请设置为True
|
|
15
|
+
max_num_hands=1,#如果想要检测双手,请设置成2
|
|
16
|
+
min_detection_confidence=0.5,#手势关键点的阈值
|
|
17
|
+
model_complexity=0#使用最简单的模型 如果效果不准确 可以考虑设置比较复制的模型 1
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# 初始化元数据
|
|
21
|
+
self.min_vals = None
|
|
22
|
+
self.max_vals = None
|
|
23
|
+
self.class_labels = None
|
|
24
|
+
|
|
25
|
+
# 加载模型和元数据
|
|
26
|
+
self.load_model(model_path)
|
|
27
|
+
|
|
28
|
+
def load_model(self, model_path):
|
|
29
|
+
"""加载模型并解析元数据"""
|
|
30
|
+
# 初始化ONNX Runtime会话
|
|
31
|
+
self.session = ort.InferenceSession(model_path)
|
|
32
|
+
|
|
33
|
+
# 加载元数据
|
|
34
|
+
self._load_metadata()
|
|
35
|
+
|
|
36
|
+
def _load_metadata(self):
|
|
37
|
+
"""从ONNX模型元数据中加载归一化参数和类别标签"""
|
|
38
|
+
model_meta = self.session.get_modelmeta()
|
|
39
|
+
|
|
40
|
+
# 检查custom_metadata_map是否存在
|
|
41
|
+
if hasattr(model_meta, 'custom_metadata_map'):
|
|
42
|
+
metadata = model_meta.custom_metadata_map
|
|
43
|
+
if 'minMaxValues' in metadata:
|
|
44
|
+
min_max_data = json.loads(metadata['minMaxValues'])
|
|
45
|
+
self.min_vals = min_max_data.get('min')
|
|
46
|
+
self.max_vals = min_max_data.get('max')
|
|
47
|
+
|
|
48
|
+
if 'classes' in metadata:
|
|
49
|
+
class_labels = json.loads(metadata['classes'])
|
|
50
|
+
self.class_labels = list(class_labels.values()) if isinstance(class_labels, dict) else class_labels
|
|
51
|
+
else:
|
|
52
|
+
# 对于旧版本的ONNX Runtime,使用metadata_props
|
|
53
|
+
for prop in model_meta.metadata_props:
|
|
54
|
+
if prop.key == 'minMaxValues':
|
|
55
|
+
min_max_data = json.loads(prop.value)
|
|
56
|
+
self.min_vals = min_max_data.get('min')
|
|
57
|
+
self.max_vals = min_max_data.get('max')
|
|
58
|
+
elif prop.key == 'classes':
|
|
59
|
+
class_labels = json.loads(prop.value)
|
|
60
|
+
self.class_labels = list(class_labels.values()) if isinstance(class_labels, dict) else class_labels
|
|
61
|
+
|
|
62
|
+
# 设置默认值
|
|
63
|
+
if self.class_labels is None:
|
|
64
|
+
self.class_labels = ["点赞", "点踩", "胜利", "拳头", "我爱你", "手掌"]
|
|
65
|
+
|
|
66
|
+
def preprocess_image(self, image, target_width=224, target_height=224):
|
|
67
|
+
"""
|
|
68
|
+
预处理图像:保持比例缩放并居中放置在目标尺寸的画布上
|
|
69
|
+
返回处理后的OpenCV图像 (BGR格式)
|
|
70
|
+
"""
|
|
71
|
+
# 将OpenCV图像转换为PIL格式
|
|
72
|
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
73
|
+
pil_image = Image.fromarray(image_rgb)
|
|
74
|
+
|
|
75
|
+
# 计算缩放比例
|
|
76
|
+
width, height = pil_image.size
|
|
77
|
+
scale = min(target_width / width, target_height / height)
|
|
78
|
+
|
|
79
|
+
# 计算新尺寸和位置
|
|
80
|
+
new_width = int(width * scale)
|
|
81
|
+
new_height = int(height * scale)
|
|
82
|
+
x = (target_width - new_width) // 2
|
|
83
|
+
y = (target_height - new_height) // 2
|
|
84
|
+
|
|
85
|
+
# 创建白色背景画布并粘贴缩放后的图像
|
|
86
|
+
canvas = Image.new('RGB', (target_width, target_height), (255, 255, 255))
|
|
87
|
+
resized_image = pil_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
|
88
|
+
canvas.paste(resized_image, (x, y))
|
|
89
|
+
|
|
90
|
+
# 转换回OpenCV格式
|
|
91
|
+
processed_image = np.array(canvas)
|
|
92
|
+
return cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR)
|
|
93
|
+
|
|
94
|
+
def extract_hand_keypoints(self, image):
|
|
95
|
+
"""从图像中提取手部关键点"""
|
|
96
|
+
# 转换图像为RGB格式并处理
|
|
97
|
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
98
|
+
results = self.hands.process(image_rgb)
|
|
99
|
+
|
|
100
|
+
if results.multi_hand_landmarks:
|
|
101
|
+
# 只使用检测到的第一只手
|
|
102
|
+
landmarks = results.multi_hand_world_landmarks[0]
|
|
103
|
+
|
|
104
|
+
# 提取关键点坐标
|
|
105
|
+
keypoints = []
|
|
106
|
+
for landmark in landmarks.landmark:
|
|
107
|
+
keypoints.extend([landmark.x, landmark.y, landmark.z])
|
|
108
|
+
|
|
109
|
+
return np.array(keypoints, dtype=np.float32)
|
|
110
|
+
return None
|
|
111
|
+
|
|
112
|
+
def normalize_keypoints(self, keypoints):
|
|
113
|
+
"""归一化关键点数据"""
|
|
114
|
+
if self.min_vals is None or self.max_vals is None:
|
|
115
|
+
return keypoints # 如果没有归一化参数,返回原始数据
|
|
116
|
+
|
|
117
|
+
normalized = []
|
|
118
|
+
for i, value in enumerate(keypoints):
|
|
119
|
+
if i < len(self.min_vals) and i < len(self.max_vals):
|
|
120
|
+
min_val = self.min_vals[i]
|
|
121
|
+
max_val = self.max_vals[i]
|
|
122
|
+
if max_val - min_val > 0:
|
|
123
|
+
normalized.append((value - min_val) / (max_val - min_val))
|
|
124
|
+
else:
|
|
125
|
+
normalized.append(0)
|
|
126
|
+
else:
|
|
127
|
+
normalized.append(value)
|
|
128
|
+
|
|
129
|
+
return np.array(normalized, dtype=np.float32)
|
|
130
|
+
|
|
131
|
+
def predict_frame(self, frame):
|
|
132
|
+
"""执行手势分类预测(直接处理图像帧)"""
|
|
133
|
+
# 记录开始时间
|
|
134
|
+
start_time = time.time()
|
|
135
|
+
# 预处理图像
|
|
136
|
+
processed_image = self.preprocess_image(frame, 224, 224)
|
|
137
|
+
|
|
138
|
+
# 提取关键点
|
|
139
|
+
keypoints = self.extract_hand_keypoints(processed_image)
|
|
140
|
+
min_time = time.time()
|
|
141
|
+
hand_time = min_time - start_time
|
|
142
|
+
#print(f"关键点识别耗时: {hand_time:.4f}秒")
|
|
143
|
+
if keypoints is None:
|
|
144
|
+
return None, {"error": "未检测到手部"}
|
|
145
|
+
|
|
146
|
+
# 归一化关键点
|
|
147
|
+
normalized_kps = self.normalize_keypoints(keypoints)
|
|
148
|
+
|
|
149
|
+
# 准备ONNX输入
|
|
150
|
+
input_data = normalized_kps.reshape(1, -1).astype(np.float32)
|
|
151
|
+
|
|
152
|
+
# 运行推理
|
|
153
|
+
input_name = self.session.get_inputs()[0].name
|
|
154
|
+
outputs = self.session.run(None, {input_name: input_data})
|
|
155
|
+
predictions = outputs[0][0]
|
|
156
|
+
|
|
157
|
+
# 获取预测结果
|
|
158
|
+
class_id = np.argmax(predictions)
|
|
159
|
+
confidence = float(predictions[class_id])
|
|
160
|
+
|
|
161
|
+
# 获取类别标签
|
|
162
|
+
label = self.class_labels[class_id] if class_id < len(self.class_labels) else f"未知类别 {class_id}"
|
|
163
|
+
end_time = time.time()
|
|
164
|
+
all_time = end_time - start_time
|
|
165
|
+
onnx_time = end_time - min_time
|
|
166
|
+
print(f"onnx耗时: {onnx_time:.4f}秒")
|
|
167
|
+
print(f"总耗时: {all_time:.4f}秒")
|
|
168
|
+
# 返回原始结果和格式化结果
|
|
169
|
+
raw_result = predictions.tolist()
|
|
170
|
+
formatted_result = {
|
|
171
|
+
'class': label,
|
|
172
|
+
'confidence': confidence,
|
|
173
|
+
'class_id': class_id,
|
|
174
|
+
'probabilities': raw_result
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
return raw_result, formatted_result
|
|
178
|
+
|
|
179
|
+
# 保留原始方法以兼容旧代码
|
|
180
|
+
def predict(self, image_path):
|
|
181
|
+
"""执行手势分类预测(从文件路径)"""
|
|
182
|
+
try:
|
|
183
|
+
# 使用PIL库读取图像,避免libpng版本问题
|
|
184
|
+
pil_image = Image.open(image_path)
|
|
185
|
+
# 转换为RGB格式
|
|
186
|
+
rgb_image = pil_image.convert('RGB')
|
|
187
|
+
# 转换为numpy数组
|
|
188
|
+
image_array = np.array(rgb_image)
|
|
189
|
+
# 转换为BGR格式(OpenCV使用的格式)
|
|
190
|
+
image = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
|
|
191
|
+
|
|
192
|
+
if image is None:
|
|
193
|
+
raise ValueError(f"无法读取图像: {image_path}")
|
|
194
|
+
|
|
195
|
+
return self.predict_frame(image)
|
|
196
|
+
except Exception as e:
|
|
197
|
+
# 如果PIL失败,尝试使用cv2作为备选
|
|
198
|
+
image = cv2.imread(image_path)
|
|
199
|
+
if image is None:
|
|
200
|
+
raise ValueError(f"无法读取图像: {image_path}")
|
|
201
|
+
return self.predict_frame(image)
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import onnxruntime as ort
|
|
2
|
+
import numpy as np
|
|
3
|
+
from PIL import Image
|
|
4
|
+
import onnx
|
|
5
|
+
import cv2
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
class ImageWorkflow:
|
|
9
|
+
def __init__(self, model_path=None):
|
|
10
|
+
self.session = None
|
|
11
|
+
self.classes = []
|
|
12
|
+
self.metadata = {}
|
|
13
|
+
self.input_shape = [1, 224, 224, 3] # 默认输入形状
|
|
14
|
+
|
|
15
|
+
if model_path:
|
|
16
|
+
self.load_model(model_path)
|
|
17
|
+
|
|
18
|
+
def load_model(self, model_path):
|
|
19
|
+
"""加载模型并解析元数据"""
|
|
20
|
+
try:
|
|
21
|
+
# 读取ONNX元数据
|
|
22
|
+
onnx_model = onnx.load(model_path)
|
|
23
|
+
for meta in onnx_model.metadata_props:
|
|
24
|
+
self.metadata[meta.key] = meta.value
|
|
25
|
+
|
|
26
|
+
# 解析类别标签
|
|
27
|
+
if 'classes' in self.metadata:
|
|
28
|
+
self.classes = eval(self.metadata['classes'])
|
|
29
|
+
|
|
30
|
+
# 初始化推理会话
|
|
31
|
+
self.session = ort.InferenceSession(model_path)
|
|
32
|
+
self._parse_input_shape()
|
|
33
|
+
|
|
34
|
+
except Exception as e:
|
|
35
|
+
print(f"模型加载失败: {e}")
|
|
36
|
+
|
|
37
|
+
def _parse_input_shape(self):
|
|
38
|
+
"""自动解析输入形状"""
|
|
39
|
+
input_info = self.session.get_inputs()[0]
|
|
40
|
+
shape = []
|
|
41
|
+
for dim in input_info.shape:
|
|
42
|
+
# 处理动态维度(用1替代)
|
|
43
|
+
shape.append(1 if isinstance(dim, str) or dim < 0 else int(dim))
|
|
44
|
+
self.input_shape = shape
|
|
45
|
+
|
|
46
|
+
def _preprocess(self, image_path):
|
|
47
|
+
"""标准化预处理流程"""
|
|
48
|
+
try:
|
|
49
|
+
img = Image.open(image_path).convert("RGB")
|
|
50
|
+
|
|
51
|
+
# 获取目标尺寸(假设形状为 [N, H, W, C])
|
|
52
|
+
_, target_h, target_w, _ = self.input_shape
|
|
53
|
+
|
|
54
|
+
# 调整尺寸
|
|
55
|
+
img = img.resize((target_w, target_h), Image.BILINEAR)
|
|
56
|
+
|
|
57
|
+
# 转换为numpy数组并归一化
|
|
58
|
+
img_array = np.array(img).astype(np.float32) / 255.0
|
|
59
|
+
|
|
60
|
+
# 添加batch维度
|
|
61
|
+
return np.expand_dims(img_array, axis=0)
|
|
62
|
+
|
|
63
|
+
except Exception as e:
|
|
64
|
+
print(f"图像预处理失败: {e}")
|
|
65
|
+
return None
|
|
66
|
+
|
|
67
|
+
def inference(self, data, model_path=None):
|
|
68
|
+
"""执行推理"""
|
|
69
|
+
if model_path and not self.session:
|
|
70
|
+
self.load_model(model_path)
|
|
71
|
+
|
|
72
|
+
input_data = self._preprocess(data)
|
|
73
|
+
if input_data is None:
|
|
74
|
+
return None, None
|
|
75
|
+
|
|
76
|
+
try:
|
|
77
|
+
# 运行推理
|
|
78
|
+
outputs = self.session.run(None, {self.session.get_inputs()[0].name: input_data})
|
|
79
|
+
raw = outputs[0][0] # 假设输出形状为 [1, n_classes]
|
|
80
|
+
|
|
81
|
+
# 格式化输出
|
|
82
|
+
formatted = self._format_result(raw)
|
|
83
|
+
|
|
84
|
+
return raw, formatted
|
|
85
|
+
|
|
86
|
+
except Exception as e:
|
|
87
|
+
print(f"推理失败: {e}")
|
|
88
|
+
return None, None
|
|
89
|
+
|
|
90
|
+
def inference_frame(self, frame_data, model_path=None):
|
|
91
|
+
"""直接使用帧数据进行推理,无需文件IO
|
|
92
|
+
返回值:raw, formatted
|
|
93
|
+
formatted字典包含:class, confidence, probabilities, preprocess_time, inference_time
|
|
94
|
+
"""
|
|
95
|
+
if model_path and not self.session:
|
|
96
|
+
self.load_model(model_path)
|
|
97
|
+
|
|
98
|
+
# 测量预处理时间
|
|
99
|
+
preprocess_start = time.time()
|
|
100
|
+
input_data = self._preprocess_frame(frame_data)
|
|
101
|
+
preprocess_time = time.time() - preprocess_start
|
|
102
|
+
|
|
103
|
+
if input_data is None:
|
|
104
|
+
return None, None
|
|
105
|
+
|
|
106
|
+
try:
|
|
107
|
+
# 测量推理时间
|
|
108
|
+
inference_start = time.time()
|
|
109
|
+
# 运行推理
|
|
110
|
+
outputs = self.session.run(None, {self.session.get_inputs()[0].name: input_data})
|
|
111
|
+
inference_time = time.time() - inference_start
|
|
112
|
+
|
|
113
|
+
raw = outputs[0][0] # 假设输出形状为 [1, n_classes]
|
|
114
|
+
|
|
115
|
+
# 格式化输出
|
|
116
|
+
formatted = self._format_result(raw)
|
|
117
|
+
# 添加时间信息到返回结果
|
|
118
|
+
formatted['preprocess_time'] = preprocess_time
|
|
119
|
+
formatted['inference_time'] = inference_time
|
|
120
|
+
|
|
121
|
+
# 计算总耗时
|
|
122
|
+
total_time = preprocess_time + inference_time
|
|
123
|
+
print(f"帧推理耗时: {total_time:.4f}秒 - 识别结果: {formatted['class']} ({formatted['confidence']}%)")
|
|
124
|
+
return raw, formatted
|
|
125
|
+
|
|
126
|
+
except Exception as e:
|
|
127
|
+
print(f"帧数据推理失败: {e}")
|
|
128
|
+
return None, None
|
|
129
|
+
|
|
130
|
+
def _preprocess_frame(self, frame_data):
|
|
131
|
+
"""处理帧数据的预处理流程"""
|
|
132
|
+
try:
|
|
133
|
+
# 确保输入是numpy数组
|
|
134
|
+
if not isinstance(frame_data, np.ndarray):
|
|
135
|
+
print("错误: 帧数据必须是numpy数组")
|
|
136
|
+
return None
|
|
137
|
+
|
|
138
|
+
# OpenCV读取的帧是BGR格式,转换为RGB
|
|
139
|
+
img = cv2.cvtColor(frame_data, cv2.COLOR_BGR2RGB)
|
|
140
|
+
|
|
141
|
+
# 获取目标尺寸(假设形状为 [N, H, W, C])
|
|
142
|
+
_, target_h, target_w, _ = self.input_shape
|
|
143
|
+
|
|
144
|
+
# 调整尺寸
|
|
145
|
+
img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
|
|
146
|
+
|
|
147
|
+
# 转换为numpy数组并归一化
|
|
148
|
+
img_array = img.astype(np.float32) / 255.0
|
|
149
|
+
|
|
150
|
+
# 添加batch维度
|
|
151
|
+
return np.expand_dims(img_array, axis=0)
|
|
152
|
+
|
|
153
|
+
except Exception as e:
|
|
154
|
+
print(f"帧数据预处理失败: {e}")
|
|
155
|
+
return None
|
|
156
|
+
|
|
157
|
+
def _format_result(self, predictions):
|
|
158
|
+
"""生成标准化输出"""
|
|
159
|
+
class_idx = np.argmax(predictions)
|
|
160
|
+
confidence = int(predictions[class_idx] * 100)
|
|
161
|
+
|
|
162
|
+
return {
|
|
163
|
+
'class': self.classes[class_idx] if self.classes else str(class_idx),
|
|
164
|
+
'confidence': confidence,
|
|
165
|
+
'probabilities': predictions.tolist()
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
# 使用示例
|
|
169
|
+
if __name__ == "__main__":
|
|
170
|
+
# 预加载模型
|
|
171
|
+
model = ImageWorkflow("model.onnx")
|
|
172
|
+
|
|
173
|
+
# 使用帧数据进行推理
|
|
174
|
+
# 假设frame是通过cv2获取的帧
|
|
175
|
+
# raw, res = model.inference_frame(frame)
|
|
176
|
+
# print(f"识别结果: {res['class']} ({res['confidence']}%)")
|