smartpi 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smartpi/__init__.py +1 -1
- smartpi/camera.py +84 -0
- smartpi/onnx_hand_workflow.py +201 -0
- smartpi/onnx_image_workflow.py +176 -0
- smartpi/onnx_pose_workflow.py +482 -0
- smartpi/onnx_text_workflow.py +173 -0
- smartpi/onnx_voice_workflow.py +437 -0
- smartpi/posenet_utils.py +222 -0
- smartpi/rknn_hand_workflow.py +245 -0
- smartpi/rknn_image_workflow.py +405 -0
- smartpi/rknn_pose_workflow.py +592 -0
- smartpi/rknn_text_workflow.py +240 -0
- smartpi/rknn_voice_workflow.py +394 -0
- {smartpi-0.1.35.dist-info → smartpi-0.1.36.dist-info}/METADATA +1 -1
- smartpi-0.1.36.dist-info/RECORD +32 -0
- smartpi-0.1.35.dist-info/RECORD +0 -20
- {smartpi-0.1.35.dist-info → smartpi-0.1.36.dist-info}/WHEEL +0 -0
- {smartpi-0.1.35.dist-info → smartpi-0.1.36.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
import onnxruntime as ort
|
|
2
|
+
import numpy as np
|
|
3
|
+
from PIL import Image
|
|
4
|
+
import onnx
|
|
5
|
+
import json
|
|
6
|
+
import cv2
|
|
7
|
+
import time # 用于实时耗时计算
|
|
8
|
+
from lib.posenet_utils import get_posenet_output
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PoseWorkflow:
    """Pose-classification workflow backed by an ONNX model.

    Holds the ONNX Runtime session, class labels parsed from the model's
    metadata, cached input/output shapes, and the last processed image
    used for result visualization.
    """

    def __init__(self, model_path=None):
        """Initialize empty state; optionally load a model immediately.

        Args:
            model_path: optional path to an ONNX model; when truthy,
                load_model() is called right away.
        """
        # Runtime session and label metadata are filled in by load_model().
        self.session = None
        self.classes = []
        self.metadata = {}
        # Sanitized model I/O shapes, cached by _parse_input_output_shapes().
        self.input_shape = []
        self.output_shape = []
        # Visualization state: annotated results are written here.
        self.result_image_path = "result.jpg"
        self.processed_image = None

        # Truthiness check on purpose: an empty string means "no model".
        if model_path:
            self.load_model(model_path)
|
|
23
|
+
|
|
24
|
+
def load_model(self, model_path):
    """Load an ONNX pose classifier and its embedded metadata.

    Copies all metadata_props into self.metadata, parses the 'classes'
    entry into self.classes, creates the inference session and caches
    the model's I/O shapes. Failures are reported, not raised.

    Args:
        model_path: filesystem path to the .onnx model file.
    """
    try:
        import ast  # stdlib; used only for safe metadata parsing

        onnx_model = onnx.load(model_path)
        for meta in onnx_model.metadata_props:
            self.metadata[meta.key] = meta.value

        if 'classes' in self.metadata:
            # SECURITY FIX: the original used eval() here, which executes
            # arbitrary code embedded in an untrusted .onnx file.
            # ast.literal_eval only accepts Python literals (e.g. a list
            # of class-name strings) and raises on anything else.
            self.classes = ast.literal_eval(self.metadata['classes'])

        self.session = ort.InferenceSession(model_path)
        self._parse_input_output_shapes()
        print(f"ONNX模型加载完成:{model_path}")
    except Exception as e:
        print(f"模型加载失败: {e}")
|
|
38
|
+
|
|
39
|
+
def _parse_input_output_shapes(self):
    """Cache sanitized shapes of the session's first input and output."""
    first_input = self.session.get_inputs()[0]
    first_output = self.session.get_outputs()[0]
    # Dynamic dimensions are collapsed to 1 by _process_shape.
    self.input_shape = self._process_shape(first_input.shape)
    self.output_shape = self._process_shape(first_output.shape)
|
|
45
|
+
|
|
46
|
+
def _process_shape(self, shape):
|
|
47
|
+
processed = []
|
|
48
|
+
for dim in shape:
|
|
49
|
+
processed.append(1 if isinstance(dim, str) or dim < 0 else int(dim))
|
|
50
|
+
return processed
|
|
51
|
+
|
|
52
|
+
def inference(self, data, model_path=None):
    """Classify a pose from an image path or pre-extracted pose data.

    Args:
        data: either a filesystem path to an image (str), or already
            extracted pose keypoint data (array-like) fed directly to
            the classifier.
        model_path: optional ONNX model path, loaded lazily when no
            session exists yet.

    Returns:
        (raw, formatted): the raw probability vector and the dict built
        by _format_result (augmented with timing fields), or
        (None, None) when inference could not run. On an inference
        exception, returns a zero vector and an 'error' dict instead.
    """
    # Wall-clock start for the total-time figures printed below.
    total_start = time.time()
    self.processed_image = None
    pose_data = None
    pose_extract_time = 0.0  # stays 0 when the caller passes pose data directly

    # Lazily load the model if a path was supplied and nothing is loaded.
    if model_path and not self.session:
        self.load_model(model_path)
    if not self.session:
        print("推理失败:ONNX模型未初始化")
        total_time = time.time() - total_start
        print(f"推理终止 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
        return None, None

    if isinstance(data, str):
        # Input is an image path: run PoseNet extraction first (timed).
        pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_image(data)
        # Bail out early when no valid keypoints were detected.
        if valid_keypoint_count < 1:
            raw = np.zeros(len(self.classes)) if self.classes else np.array([])
            formatted = {
                'class': 'null',
                'confidence': 0.0,
                'probabilities': raw.tolist()
            }
            total_time = time.time() - total_start
            print(f"推理终止(有效关键点不足1个:{valid_keypoint_count}) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
            return raw, formatted

        if pose_data is None or self.processed_image is None:
            total_time = time.time() - total_start
            print(f"推理终止(姿态数据为空) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
            return None, None
    else:
        # Input is already pose data: skip image processing entirely.
        pose_data = data
        print("提示:输入为姿态数据,跳过图像处理 | 姿态提取耗时:0.0000s")

    # Normalize/reshape the keypoints into the model's input tensor.
    input_data = self._preprocess(pose_data)
    if input_data is None:
        total_time = time.time() - total_start
        print(f"推理终止(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
        return None, None

    try:
        # Core ONNX inference, timed separately from pre/post-processing.
        input_name = self.session.get_inputs()[0].name
        infer_start = time.time()
        outputs = self.session.run(None, {input_name: input_data})
        onnx_infer_time = time.time() - infer_start

        # Post-process: flatten the first output into a 1-D probability vector.
        raw_output = outputs[0]
        raw = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
        formatted = self._format_result(raw)

        # Visualize: draw the label onto the resized input and save it.
        if self.processed_image is not None and formatted:
            bgr_image = cv2.cvtColor(self.processed_image, cv2.COLOR_RGB2BGR)
            text = f"{formatted['class']} {formatted['confidence']:.1f}%"
            cv2.putText(
                img=bgr_image,
                text=text,
                org=(10, 30),
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=0.8,
                color=(0, 255, 0),
                thickness=2
            )
            cv2.imwrite(self.result_image_path, bgr_image)
            print(f"结果图已保存至: {self.result_image_path}")

        # Report the full timing breakdown.
        total_time = time.time() - total_start
        print(f"推理完成:")
        print(f"- 总耗时:{total_time:.4f}秒")
        print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
        print(f"- ONNX推理耗时:{onnx_infer_time:.4f}秒")

        # Attach timing info so callers can log or report it.
        formatted['pose_extract_time'] = pose_extract_time
        formatted['inference_time'] = onnx_infer_time
        formatted['total_time'] = total_time

        return raw, formatted
    except Exception as e:
        # On failure, still report elapsed time and return a
        # structurally valid result carrying the error message.
        total_time = time.time() - total_start
        onnx_infer_time = 0.0
        print(f"推理失败: {e}")
        print(f"- 总耗时:{total_time:.4f}秒")
        print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
        print(f"- ONNX推理耗时:{onnx_infer_time:.4f}秒")

        raw = np.zeros(len(self.classes)) if self.classes else np.array([])
        formatted = {
            'class': 'error',
            'confidence': 0.0,
            'probabilities': raw.tolist(),
            'error': str(e),
            'pose_extract_time': pose_extract_time,
            'inference_time': onnx_infer_time,
            'total_time': total_time
        }
        return raw, formatted
|
|
160
|
+
|
|
161
|
+
def _get_pose_from_image(self, image_path):
    """Extract pose keypoints from an image file (timed).

    Args:
        image_path: path to the input image.

    Returns:
        (pose_array, processed_image, pose_extract_time, valid_keypoint_count);
        pose_array is None when extraction fails or too few keypoints
        are found.
    """
    pose_extract_start = time.time()
    valid_keypoint_count = 0  # reported even on the exception path
    try:
        print(f"正在处理图像: {image_path}")
        img = Image.open(image_path).convert("RGB")
        # 257x257 matches the frame path in _get_pose_from_frame.
        target_h, target_w = 257, 257
        img_resized = img.resize((target_w, target_h), Image.BILINEAR)
        processed_image = np.array(img_resized, dtype=np.uint8)

        # NOTE(review): the original image *path* is passed to PoseNet
        # here (the frame path passes the resized array) — presumably
        # get_posenet_output resizes internally; verify in posenet_utils.
        pose_data, has_pose, valid_keypoint_count = get_posenet_output(image_path)

        # Require at least 3 valid keypoints for a meaningful result.
        if valid_keypoint_count < 3:
            pose_extract_time = time.time() - pose_extract_start
            print(f"有效关键点数量不足({valid_keypoint_count} < 3)")
            return None, processed_image, pose_extract_time, valid_keypoint_count

        if pose_data is None:
            pose_extract_time = time.time() - pose_extract_start
            print(f"无法从图像中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s")
            return None, None, pose_extract_time, valid_keypoint_count

        # Convert the PoseNet output (JSON string or array) to float32.
        pose_array = self._parse_pose_data(pose_data)
        pose_extract_time = time.time() - pose_extract_start
        print(f"图像姿态提取完成 | 有效关键点:{valid_keypoint_count} | 姿态提取耗时:{pose_extract_time:.4f}s")
        return pose_array, processed_image, pose_extract_time, valid_keypoint_count
    except Exception as e:
        pose_extract_time = time.time() - pose_extract_start
        print(f"获取姿态数据失败: {e} | 姿态提取耗时:{pose_extract_time:.4f}s")
        return None, None, pose_extract_time, valid_keypoint_count
|
|
195
|
+
|
|
196
|
+
def _parse_pose_data(self, pose_data):
|
|
197
|
+
"""统一解析PoseNet输出(支持字符串/数组格式)"""
|
|
198
|
+
if isinstance(pose_data, str):
|
|
199
|
+
try:
|
|
200
|
+
return np.array(json.loads(pose_data), dtype=np.float32)
|
|
201
|
+
except json.JSONDecodeError:
|
|
202
|
+
print("无法解析PoseNet输出")
|
|
203
|
+
return None
|
|
204
|
+
else:
|
|
205
|
+
return np.array(pose_data, dtype=np.float32)
|
|
206
|
+
|
|
207
|
+
def _get_pose_from_frame(self, frame):
    """Extract pose keypoints from a BGR video frame (timed).

    Returns:
        (pose_array, processed_image, pose_extract_time, valid_keypoint_count).
        Unlike _get_pose_from_image, keypoint/data failures return a
        zero vector (not None) so the streaming caller can continue.
    """
    pose_extract_start = time.time()
    valid_keypoint_count = 0  # reported even on the exception path
    try:
        # OpenCV frames are BGR; PoseNet path works in RGB.
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Resize to 257x257 — same size the image path uses.
        target_h, target_w = 257, 257
        img_resized = cv2.resize(
            img_rgb,
            (target_w, target_h),
            interpolation=cv2.INTER_LINEAR
        )
        processed_image = img_resized.astype(np.uint8)

        # NOTE(review): here the resized *array* is passed to PoseNet,
        # while _get_pose_from_image passes the file path — confirm
        # get_posenet_output accepts both input kinds.
        pose_data, has_pose, valid_keypoint_count = get_posenet_output(processed_image)

        if valid_keypoint_count < 1:
            pose_extract_time = time.time() - pose_extract_start
            print(f"有效关键点数量不足({valid_keypoint_count} < 1)")

            # 1439 is presumably the flattened pose-feature length the
            # classifier expects — TODO confirm against the model input.
            return np.zeros(1439, dtype=np.float32), processed_image, pose_extract_time, valid_keypoint_count

        if pose_data is None:
            pose_extract_time = time.time() - pose_extract_start
            print(f"无法从帧中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s")
            return np.zeros(1439, dtype=np.float32), None, pose_extract_time, valid_keypoint_count

        # Convert the PoseNet output (JSON string or array) to float32.
        pose_array = self._parse_pose_data(pose_data)
        pose_extract_time = time.time() - pose_extract_start
        print(f"帧姿态提取完成 | 有效关键点:{valid_keypoint_count} | 姿态提取耗时:{pose_extract_time:.4f}s")
        return pose_array, processed_image, pose_extract_time, valid_keypoint_count
    except Exception as e:
        pose_extract_time = time.time() - pose_extract_start
        print(f"从帧获取姿态数据失败: {e} | 姿态提取耗时:{pose_extract_time:.4f}s")
        return None, None, pose_extract_time, valid_keypoint_count
|
|
247
|
+
|
|
248
|
+
def inference_frame(self, frame_data, model_path=None):
    """Run pose classification on a single BGR video frame (timed).

    Args:
        frame_data: BGR frame (e.g. from cv2.VideoCapture.read()).
        model_path: optional ONNX model path, loaded lazily.

    Returns:
        (raw, formatted, result_frame): probability vector, result dict,
        and a copy of the frame annotated with label/FPS (or with a
        warning overlay). (None, None, copy) when inference is skipped.
    """
    total_start = time.time()
    result_frame = frame_data.copy()  # never draw on the caller's frame
    self.processed_image = None
    pose_data = None
    pose_extract_time = 0.0
    onnx_infer_time = 0.0
    valid_keypoint_count = 0

    # Lazy model load.
    if model_path and not self.session:
        self.load_model(model_path)
    if not self.session:
        print("帧推理失败:ONNX模型未初始化")
        total_time = time.time() - total_start
        print(f"帧推理终止 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
        return None, None, result_frame

    # Extract pose keypoints from the frame (timed, with keypoint count).
    pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_frame(frame_data)

    # Too few keypoints: return a 'null' result plus a warning overlay.
    if valid_keypoint_count < 1:
        raw = np.zeros(len(self.classes)) if self.classes else np.array([])
        formatted = {
            'class': 'null',
            'confidence': 0.0,
            'probabilities': raw.tolist()
        }
        total_time = time.time() - total_start
        # Draw the "insufficient keypoints" notice on the frame.
        result_frame = self._draw_insufficient_keypoints(result_frame, valid_keypoint_count)

        return raw, formatted, result_frame

    if pose_data is None or self.processed_image is None:
        total_time = time.time() - total_start
        print(f"帧推理跳过 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
        return None, None, result_frame

    # Normalize/reshape the keypoints into the model's input tensor.
    input_data = self._preprocess(pose_data)
    if input_data is None:
        total_time = time.time() - total_start
        print(f"帧推理失败(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
        return None, None, result_frame

    try:
        # Core ONNX inference, timed separately.
        input_name = self.session.get_inputs()[0].name
        infer_start = time.time()
        outputs = self.session.run(None, {input_name: input_data})
        onnx_infer_time = time.time() - infer_start

        # Flatten the first output into a 1-D probability vector.
        raw_output = outputs[0]
        raw = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
        formatted = self._format_result(raw)

        # Annotate the frame with the label and FPS (FPS is derived from
        # this call's total wall-clock time, guarded against div-by-zero).
        total_time = time.time() - total_start
        fps = 1.0 / total_time if total_time > 1e-6 else 0.0
        result_frame = self._draw_result_on_frame(result_frame, formatted, fps)

        # Full timing breakdown for this frame.
        print(f"帧推理完成 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s | FPS:{fps:.1f} | 结果:{formatted['class']}")

        return raw, formatted, result_frame
    except Exception as e:
        # Report elapsed time even when inference itself failed.
        total_time = time.time() - total_start
        onnx_infer_time = 0.0
        print(f"帧推理失败: {e} | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
        return None, None, result_frame
|
|
324
|
+
|
|
325
|
+
def _draw_insufficient_keypoints(self, frame, count):
    """Stamp a red 'not enough keypoints' notice onto *frame*; return it."""
    warning = f"有效关键点不足:{count}/1"
    cv2.putText(frame, warning, (20, 40), cv2.FONT_HERSHEY_SIMPLEX,
                1.0, (0, 0, 255), 2, cv2.LINE_AA)
    return frame
|
|
339
|
+
|
|
340
|
+
def _draw_result_on_frame(self, frame, formatted_result, fps):
    """Overlay the predicted class/confidence (green) and FPS (blue).

    Args:
        frame: BGR frame to draw on (modified in place and returned).
        formatted_result: dict with 'class' and 'confidence' keys.
        fps: frames-per-second figure computed from total call time.
    """
    label_line = f"Pose: {formatted_result['class']} ({formatted_result['confidence']:.1f}%)"
    fps_line = f"FPS: {fps:.1f}"
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(frame, label_line, (20, 40), font, 1.0, (0, 255, 0), 2, cv2.LINE_AA)
    cv2.putText(frame, fps_line, (20, 80), font, 0.9, (255, 0, 0), 2, cv2.LINE_AA)
    return frame
|
|
368
|
+
|
|
369
|
+
def _preprocess(self, pose_data):
    """Normalize pose data and shape it to the model's input tensor.

    Returns the reshaped float32 array, or None on any failure.
    """
    try:
        arr = (pose_data if isinstance(pose_data, np.ndarray)
               else np.array(pose_data, dtype=np.float32))
        normed = self._normalize_pose_points(arr)
        # Pad/truncate via np.resize when sizes disagree; otherwise a
        # plain reshape preserves every element.
        if normed.size != np.prod(self.input_shape):
            return np.resize(normed, self.input_shape)
        return normed.reshape(self.input_shape)
    except Exception as e:
        print(f"姿态数据预处理失败: {e}")
        return None
|
|
383
|
+
|
|
384
|
+
def _normalize_pose_points(self, pose_points):
|
|
385
|
+
normalized_points = pose_points.copy().astype(np.float32)
|
|
386
|
+
mid = len(normalized_points) // 2
|
|
387
|
+
if mid > 0:
|
|
388
|
+
if np.max(normalized_points[:mid]) > 0:
|
|
389
|
+
normalized_points[:mid] /= 257.0
|
|
390
|
+
if np.max(normalized_points[mid:]) > 0:
|
|
391
|
+
normalized_points[mid:] /= 257.0
|
|
392
|
+
return normalized_points
|
|
393
|
+
|
|
394
|
+
def _format_result(self, predictions):
    """Convert a raw probability vector into a display dict.

    WARNING: when the max-min spread exceeds 0.05, this deliberately
    inflates the winning probability to a random value drawn from
    [0.9, max_possible] and rescales the remaining entries to
    compensate. *predictions* is mutated in place, the output is
    non-deterministic (np.random.uniform), and the reported confidence
    is cosmetic rather than the model's true probability.

    Returns:
        dict with 'class' (label or index string), 'confidence'
        (percentage) and 'probabilities' (possibly adjusted list).
    """
    class_idx = np.argmax(predictions)
    current_max = predictions[class_idx]

    # "Optimize" the displayed confidence — see warning above.
    if len(predictions) > 0:
        max_val = np.max(predictions)
        min_val = np.min(predictions)
        max_min_diff = max_val - min_val

        if max_min_diff < 0.05:
            pass  # near-uniform output: leave probabilities untouched
        else:
            # Upper bound == sum of all probabilities, capped at 1.0.
            max_possible = min(1.0, current_max + (np.sum(predictions) - current_max))
            target_max = np.random.uniform(0.9, max_possible)
            if current_max < target_max:
                needed = target_max - current_max
                other_sum = np.sum(predictions) - current_max
                if other_sum > 0:
                    # Shrink every non-winning entry proportionally so
                    # the vector total is (approximately) preserved.
                    scale_factor = (other_sum - needed) / other_sum
                    for i in range(len(predictions)):
                        if i != class_idx:
                            predictions[i] *= scale_factor
                predictions[class_idx] = target_max

    confidence = float(predictions[class_idx] * 100)
    return {
        'class': self.classes[class_idx] if 0 <= class_idx < len(self.classes) else str(class_idx),
        'confidence': confidence,
        'probabilities': predictions.tolist()
    }
|
|
425
|
+
|
|
426
|
+
def show_result(self):
    """Display the saved annotated result image, or explain why none exists.

    Blocks until a key is pressed; any display/read failure is printed
    rather than raised.
    """
    try:
        img = cv2.imread(self.result_image_path)
        if img is None:
            # No file on disk: distinguish "pose-data input" (no image
            # ever produced) from "inference not yet run".
            message = ("无法显示结果:输入为姿态数据,未生成图像"
                       if self.processed_image is None
                       else "未找到结果图片,请先执行图像路径输入的推理")
            print(message)
        else:
            cv2.imshow("Pose Inference Result", img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
    except Exception as e:
        print(f"显示图片失败: {e}")
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# Real-time inference demo (webcam version).
if __name__ == "__main__":
    # 1. Initialize the model (replace with your own ONNX model path).
    MODEL_PATH = "your_pose_model.onnx"
    pose_workflow = PoseWorkflow(model_path=MODEL_PATH)
    if not pose_workflow.session:
        print("模型加载失败,无法启动实时推理")
        exit(1)

    # 2. Open the camera (0 = default device; try 1, 2, ... for others).
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("无法打开摄像头")
        exit(1)

    # Optional capture resolution (tune for the hardware).
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    print("实时姿态推理启动!按 'q' 键退出...")

    # 3. Read frames and run inference in a loop.
    while True:
        ret, frame = cap.read()
        if not ret:
            print("无法读取摄像头帧,退出循环")
            break

        # Per-frame inference (prints its own timing breakdown).
        _, _, result_frame = pose_workflow.inference_frame(frame)

        # Show the annotated frame.
        cv2.imshow("Real-Time Pose Inference", result_frame)

        # Quit on 'q' (1 ms wait keeps the UI responsive).
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # 4. Release camera and window resources.
    cap.release()
    cv2.destroyAllWindows()
    print("实时推理结束,资源已释放")
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import onnxruntime as ort
|
|
3
|
+
import onnx
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import time
|
|
7
|
+
from transformers import AutoTokenizer
|
|
8
|
+
|
|
9
|
+
# Directory containing this module; used to locate the bundled GTE assets.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Default locations of the bundled GTE feature-extraction model and the
# tokenizer configuration shipped alongside it.
default_feature_model = os.path.join(current_dir, 'text_gte_model', 'gte', 'gte_model.onnx')
default_tokenizer_path = os.path.join(current_dir, 'text_gte_model', 'config')
|
|
14
|
+
|
|
15
|
+
class TextClassificationWorkflow:
    """Two-stage ONNX text classifier: GTE feature extractor + classifier head."""

    def __init__(self, class_model_path, feature_model_path=None, tokenizer_path=None):
        """Load the tokenizer, feature model, and classification model.

        Args:
            class_model_path: path to the classification-head ONNX model.
            feature_model_path: optional GTE feature-model path; falls
                back to the bundled default.
            tokenizer_path: optional tokenizer config directory; falls
                back to the bundled default.
        """
        # Fall back to the bundled GTE assets when paths are not provided.
        self.feature_model_path = feature_model_path or default_feature_model
        self.tokenizer_path = tokenizer_path or default_tokenizer_path
        self.class_model_path = class_model_path
        # Overall initialization timer.
        init_start_time = time.time()

        # Tokenizer (local files only — no network access).
        print("加载分词器...")
        tokenizer_start = time.time()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_path,
            local_files_only=True
        )
        tokenizer_time = time.time() - tokenizer_start
        print(f"分词器加载完成,耗时: {tokenizer_time:.3f} 秒")

        # Feature-extraction (embedding) model.
        print("加载特征提取模型...")
        feature_start = time.time()
        self.feature_session = ort.InferenceSession(self.feature_model_path)
        # FIX: loop variable renamed from `input`, which shadowed the builtin.
        self.feature_input_names = [inp.name for inp in self.feature_session.get_inputs()]
        feature_load_time = time.time() - feature_start
        print(f"特征提取模型加载完成,耗时: {feature_load_time:.3f} 秒")

        # Classification head.
        print("加载分类模型...")
        class_start = time.time()
        self.class_session = ort.InferenceSession(class_model_path)
        self.class_input_name = self.class_session.get_inputs()[0].name
        self.class_output_name = self.class_session.get_outputs()[0].name
        class_load_time = time.time() - class_start
        print(f"分类模型加载完成,耗时: {class_load_time:.3f} 秒")

        # Class labels embedded in the classifier's ONNX metadata.
        meta_start = time.time()
        self.label_names = self._load_metadata(class_model_path)
        meta_time = time.time() - meta_start

        init_total_time = time.time() - init_start_time

        print(f"元数据加载完成,耗时: {meta_time:.3f} 秒")
        print(f"分类模型加载成功,共 {len(self.label_names)} 个类别: {self.label_names}")
        print(f"模型初始化总耗时: {init_total_time:.3f} 秒")
|
|
62
|
+
|
|
63
|
+
def _load_metadata(self, model_path):
    """Load class labels from the ONNX model's embedded metadata.

    Search order: metadata_props key 'classes' (JSON list, then
    comma-separated string), then a JSON doc_string containing a
    'classes' key. Falls back to auto-generated "Class_i" names sized
    by the classifier's output dimension.

    Args:
        model_path: path to the classification ONNX model.
    Returns:
        list of label names.
    """
    try:
        onnx_model = onnx.load(model_path)

        # Preferred location: metadata_props['classes'].
        if onnx_model.metadata_props:
            for prop in onnx_model.metadata_props:
                if prop.key == 'classes':
                    try:
                        # JSON-encoded list of labels.
                        return json.loads(prop.value)
                    except json.JSONDecodeError:
                        # Not JSON — treat as a comma-separated string.
                        return prop.value.split(',')

        # Fallback: JSON doc_string with a 'classes' entry.
        if onnx_model.doc_string:
            try:
                doc_dict = json.loads(onnx_model.doc_string)
                if 'classes' in doc_dict:
                    return doc_dict['classes']
            except (json.JSONDecodeError, TypeError):
                # FIX: was a bare `except:`, which also swallowed
                # SystemExit / KeyboardInterrupt.
                pass
    except Exception as e:
        print(f"元数据读取错误: {e}")

    # Last resort: synthesize Class_i names from the output shape.
    num_classes = self.class_session.get_outputs()[0].shape[-1]
    label_names = [f"Class_{i}" for i in range(num_classes)]
    print(f"警告: 未在模型元数据中找到类别信息,使用自动生成的类别名称: {label_names}")
    return label_names
|
|
96
|
+
|
|
97
|
+
def _extract_features(self, texts):
    """Tokenize *texts* and return [CLS] embeddings from the GTE model.

    Args:
        texts: list of input strings.
    Returns:
        float32 array of shape (batch, hidden_dim) — the first token's
        hidden state of each input.
    """
    # Tokenize with padding and truncation to at most 512 tokens.
    inputs = self.tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="np"
    )

    # ONNX expects int64 ids; select only the names the model declares.
    # NOTE(review): assumes the tokenizer produced every name in
    # feature_input_names — a mismatch would raise KeyError here.
    onnx_inputs = {name: inputs[name].astype(np.int64) for name in self.feature_input_names}

    # Run the feature model; the first output is the last hidden state.
    onnx_outputs = self.feature_session.run(None, onnx_inputs)
    last_hidden_state = onnx_outputs[0]
    return last_hidden_state[:, 0, :].astype(np.float32)  # ensure float32
|
|
115
|
+
|
|
116
|
+
def _classify(self, embeddings):
|
|
117
|
+
"""对特征向量进行分类预测"""
|
|
118
|
+
# 分类模型推理
|
|
119
|
+
class_results = self.class_session.run(
|
|
120
|
+
[self.class_output_name],
|
|
121
|
+
{self.class_input_name: embeddings}
|
|
122
|
+
)[0]
|
|
123
|
+
|
|
124
|
+
# 应用softmax获取概率分布
|
|
125
|
+
probs = np.exp(class_results) / np.sum(np.exp(class_results), axis=1, keepdims=True)
|
|
126
|
+
return probs
|
|
127
|
+
|
|
128
|
+
def predict(self, texts):
    """Classify a batch of texts, measuring per-stage timing.

    Args:
        texts: list of strings; an empty/falsy input short-circuits.
            NOTE(review): a bare str would be iterated character by
            character in the result loop — callers should pass a list.
    Returns:
        (raw_results, formatted_results): per-text probability lists,
        and per-text dicts with class, confidence, class_id and timing
        fields averaged across the batch.
    """
    if not texts:
        return [], []

    # Overall timer for the whole batch.
    total_start_time = time.time()

    # Stage 1: GTE feature extraction.
    feature_start_time = time.time()
    embeddings = self._extract_features(texts)
    feature_time = time.time() - feature_start_time

    # Stage 2: classification head (softmax probabilities).
    classify_start_time = time.time()
    probs = self._classify(embeddings)
    classify_time = time.time() - classify_start_time

    total_time = time.time() - total_start_time

    predicted_indices = np.argmax(probs, axis=1)

    # Build parallel raw/formatted outputs per text.
    raw_results = []
    formatted_results = []

    for i, (text, idx, prob_vec) in enumerate(zip(texts, predicted_indices, probs)):
        label = self.label_names[idx] if idx < len(self.label_names) else f"未知类别 {idx}"
        confidence = float(prob_vec[idx])

        raw_results.append(prob_vec.tolist())
        formatted_results.append({
            'text': text,
            'class': label,
            'confidence': confidence,
            'class_id': int(idx),
            'probabilities': prob_vec.tolist(),
            # Timing fields are batch totals averaged per text.
            'preprocess_time': 0.0,  # no image-style preprocessing for text
            'feature_extract_time': feature_time / len(texts),
            'inference_time': classify_time / len(texts),
            'total_time': total_time / len(texts)
        })

    return raw_results, formatted_results
|