smartpi 1.1.4__py3-none-any.whl → 1.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. smartpi/__init__.py +8 -0
  2. smartpi/__init__.pyc +0 -0
  3. smartpi/_gui.py +66 -0
  4. smartpi/_gui.pyc +0 -0
  5. smartpi/ai_asr.py +1037 -0
  6. smartpi/ai_asr.pyc +0 -0
  7. smartpi/ai_llm.py +934 -0
  8. smartpi/ai_llm.pyc +0 -0
  9. smartpi/ai_tts.py +938 -0
  10. smartpi/ai_tts.pyc +0 -0
  11. smartpi/ai_vad.py +83 -0
  12. smartpi/ai_vad.pyc +0 -0
  13. smartpi/audio.py +125 -0
  14. smartpi/audio.pyc +0 -0
  15. smartpi/base_driver.py +618 -0
  16. smartpi/base_driver.pyc +0 -0
  17. smartpi/camera.py +84 -0
  18. smartpi/camera.pyc +0 -0
  19. smartpi/color_sensor.py +18 -0
  20. smartpi/color_sensor.pyc +0 -0
  21. smartpi/cw2015.py +179 -0
  22. smartpi/cw2015.pyc +0 -0
  23. smartpi/flash.py +130 -0
  24. smartpi/flash.pyc +0 -0
  25. smartpi/humidity.py +20 -0
  26. smartpi/humidity.pyc +0 -0
  27. smartpi/led.py +19 -0
  28. smartpi/led.pyc +0 -0
  29. smartpi/light_sensor.py +72 -0
  30. smartpi/light_sensor.pyc +0 -0
  31. smartpi/local_model.py +432 -0
  32. smartpi/local_model.pyc +0 -0
  33. smartpi/mcp_client.py +100 -0
  34. smartpi/mcp_client.pyc +0 -0
  35. smartpi/mcp_fastmcp.py +322 -0
  36. smartpi/mcp_fastmcp.pyc +0 -0
  37. smartpi/mcp_intent_recognizer.py +408 -0
  38. smartpi/mcp_intent_recognizer.pyc +0 -0
  39. smartpi/models/__init__.py +0 -0
  40. smartpi/models/__init__.pyc +0 -0
  41. smartpi/models/snakers4_silero-vad/__init__.py +0 -0
  42. smartpi/models/snakers4_silero-vad/__init__.pyc +0 -0
  43. smartpi/models/snakers4_silero-vad/hubconf.py +56 -0
  44. smartpi/models/snakers4_silero-vad/hubconf.pyc +0 -0
  45. smartpi/motor.py +177 -0
  46. smartpi/motor.pyc +0 -0
  47. smartpi/move.py +218 -0
  48. smartpi/move.pyc +0 -0
  49. smartpi/onnx_hand_workflow.py +201 -0
  50. smartpi/onnx_hand_workflow.pyc +0 -0
  51. smartpi/onnx_image_workflow.py +176 -0
  52. smartpi/onnx_image_workflow.pyc +0 -0
  53. smartpi/onnx_pose_workflow.py +482 -0
  54. smartpi/onnx_pose_workflow.pyc +0 -0
  55. smartpi/onnx_text_workflow.py +173 -0
  56. smartpi/onnx_text_workflow.pyc +0 -0
  57. smartpi/onnx_voice_workflow.py +437 -0
  58. smartpi/onnx_voice_workflow.pyc +0 -0
  59. smartpi/posemodel/__init__.py +0 -0
  60. smartpi/posemodel/__init__.pyc +0 -0
  61. smartpi/posenet_utils.py +222 -0
  62. smartpi/posenet_utils.pyc +0 -0
  63. smartpi/rknn_hand_workflow.py +245 -0
  64. smartpi/rknn_hand_workflow.pyc +0 -0
  65. smartpi/rknn_image_workflow.py +405 -0
  66. smartpi/rknn_image_workflow.pyc +0 -0
  67. smartpi/rknn_pose_workflow.py +592 -0
  68. smartpi/rknn_pose_workflow.pyc +0 -0
  69. smartpi/rknn_text_workflow.py +240 -0
  70. smartpi/rknn_text_workflow.pyc +0 -0
  71. smartpi/rknn_voice_workflow.py +394 -0
  72. smartpi/rknn_voice_workflow.pyc +0 -0
  73. smartpi/servo.py +178 -0
  74. smartpi/servo.pyc +0 -0
  75. smartpi/temperature.py +18 -0
  76. smartpi/temperature.pyc +0 -0
  77. smartpi/tencentcloud-speech-sdk-python/__init__.py +1 -0
  78. smartpi/tencentcloud-speech-sdk-python/__init__.pyc +0 -0
  79. smartpi/tencentcloud-speech-sdk-python/asr/__init__.py +0 -0
  80. smartpi/tencentcloud-speech-sdk-python/asr/__init__.pyc +0 -0
  81. smartpi/tencentcloud-speech-sdk-python/asr/flash_recognizer.py +178 -0
  82. smartpi/tencentcloud-speech-sdk-python/asr/flash_recognizer.pyc +0 -0
  83. smartpi/tencentcloud-speech-sdk-python/asr/speech_recognizer.py +311 -0
  84. smartpi/tencentcloud-speech-sdk-python/asr/speech_recognizer.pyc +0 -0
  85. smartpi/tencentcloud-speech-sdk-python/common/__init__.py +1 -0
  86. smartpi/tencentcloud-speech-sdk-python/common/__init__.pyc +0 -0
  87. smartpi/tencentcloud-speech-sdk-python/common/credential.py +6 -0
  88. smartpi/tencentcloud-speech-sdk-python/common/credential.pyc +0 -0
  89. smartpi/tencentcloud-speech-sdk-python/common/log.py +16 -0
  90. smartpi/tencentcloud-speech-sdk-python/common/log.pyc +0 -0
  91. smartpi/tencentcloud-speech-sdk-python/common/utils.py +7 -0
  92. smartpi/tencentcloud-speech-sdk-python/common/utils.pyc +0 -0
  93. smartpi/tencentcloud-speech-sdk-python/soe/__init__.py +0 -0
  94. smartpi/tencentcloud-speech-sdk-python/soe/__init__.pyc +0 -0
  95. smartpi/tencentcloud-speech-sdk-python/soe/speaking_assessment.py +276 -0
  96. smartpi/tencentcloud-speech-sdk-python/soe/speaking_assessment.pyc +0 -0
  97. smartpi/tencentcloud-speech-sdk-python/tts/__init__.py +0 -0
  98. smartpi/tencentcloud-speech-sdk-python/tts/__init__.pyc +0 -0
  99. smartpi/tencentcloud-speech-sdk-python/tts/flowing_speech_synthesizer.py +294 -0
  100. smartpi/tencentcloud-speech-sdk-python/tts/flowing_speech_synthesizer.pyc +0 -0
  101. smartpi/tencentcloud-speech-sdk-python/tts/speech_synthesizer.py +144 -0
  102. smartpi/tencentcloud-speech-sdk-python/tts/speech_synthesizer.pyc +0 -0
  103. smartpi/tencentcloud-speech-sdk-python/tts/speech_synthesizer_ws.py +234 -0
  104. smartpi/tencentcloud-speech-sdk-python/tts/speech_synthesizer_ws.pyc +0 -0
  105. smartpi/tencentcloud-speech-sdk-python/vc/__init__.py +0 -0
  106. smartpi/tencentcloud-speech-sdk-python/vc/__init__.pyc +0 -0
  107. smartpi/tencentcloud-speech-sdk-python/vc/speech_convertor_ws.py +237 -0
  108. smartpi/tencentcloud-speech-sdk-python/vc/speech_convertor_ws.pyc +0 -0
  109. smartpi/text_gte_model/__init__.py +0 -0
  110. smartpi/text_gte_model/__init__.pyc +0 -0
  111. smartpi/text_gte_model/config/__init__.py +0 -0
  112. smartpi/text_gte_model/config/__init__.pyc +0 -0
  113. smartpi/text_gte_model/gte/__init__.py +0 -0
  114. smartpi/text_gte_model/gte/__init__.pyc +0 -0
  115. smartpi/touch_sensor.py +16 -0
  116. smartpi/touch_sensor.pyc +0 -0
  117. smartpi/trace.py +120 -0
  118. smartpi/trace.pyc +0 -0
  119. smartpi/ultrasonic.py +20 -0
  120. smartpi/ultrasonic.pyc +0 -0
  121. {smartpi-1.1.4.dist-info → smartpi-1.1.5.dist-info}/METADATA +3 -2
  122. smartpi-1.1.5.dist-info/RECORD +137 -0
  123. smartpi-1.1.4.dist-info/RECORD +0 -77
  124. {smartpi-1.1.4.dist-info → smartpi-1.1.5.dist-info}/WHEEL +0 -0
  125. {smartpi-1.1.4.dist-info → smartpi-1.1.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,482 @@
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+ from PIL import Image
4
+ import onnx
5
+ import json
6
+ import cv2
7
+ import time # 用于实时耗时计算
8
+ from posenet_utils import get_posenet_output
9
+
10
+
11
class PoseWorkflow:
    """ONNX-based pose classification workflow for image files and live frames."""

    def __init__(self, model_path=None):
        # Inference session and model metadata; populated by load_model().
        self.session = None
        self.classes = []
        self.metadata = {}
        self.input_shape = []
        self.output_shape = []
        # The annotated result image produced by inference() is written here.
        self.result_image_path = "result.jpg"
        self.processed_image = None

        # Load eagerly when a model path is supplied at construction time.
        if model_path:
            self.load_model(model_path)
23
+
24
def load_model(self, model_path):
    """Load an ONNX model: read class labels from its metadata and create
    the inference session.

    Failures are reported on stdout rather than raised, matching the
    best-effort error style used throughout this workflow.
    """
    try:
        onnx_model = onnx.load(model_path)
        for meta in onnx_model.metadata_props:
            self.metadata[meta.key] = meta.value

        if 'classes' in self.metadata:
            # The 'classes' metadata value is a serialized list of labels.
            # Parse it safely instead of eval(): model files may come from
            # untrusted sources, and eval() would execute arbitrary code.
            import ast
            value = self.metadata['classes']
            try:
                self.classes = json.loads(value)
            except (json.JSONDecodeError, TypeError):
                self.classes = ast.literal_eval(value)

        self.session = ort.InferenceSession(model_path)
        self._parse_input_output_shapes()
        print(f"ONNX模型加载完成:{model_path}")
    except Exception as e:
        print(f"模型加载失败: {e}")
38
+
39
+ def _parse_input_output_shapes(self):
40
+ input_info = self.session.get_inputs()[0]
41
+ self.input_shape = self._process_shape(input_info.shape)
42
+
43
+ output_info = self.session.get_outputs()[0]
44
+ self.output_shape = self._process_shape(output_info.shape)
45
+
46
+ def _process_shape(self, shape):
47
+ processed = []
48
+ for dim in shape:
49
+ processed.append(1 if isinstance(dim, str) or dim < 0 else int(dim))
50
+ return processed
51
+
52
def inference(self, data, model_path=None):
    """Classify a pose from an image path or a pre-extracted pose vector.

    data: image file path (str) or an already-extracted pose data vector.
    model_path: optional lazy-load path used only when no session exists.
    Returns (raw, formatted): flat probability array and a result dict,
    or (None, None) when any stage fails.
    """
    # Start of total elapsed-time measurement.
    total_start = time.time()
    self.processed_image = None
    pose_data = None
    pose_extract_time = 0.0  # 0 when the input is already pose data

    if model_path and not self.session:
        self.load_model(model_path)
    if not self.session:
        print("推理失败:ONNX模型未初始化")
        total_time = time.time() - total_start
        print(f"推理终止 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
        return None, None

    if isinstance(data, str):
        # Input is an image path: extract pose data (timed).
        pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_image(data)
        # Abort early when too few valid keypoints were detected.
        if valid_keypoint_count < 1:
            raw = np.zeros(len(self.classes)) if self.classes else np.array([])
            formatted = {
                'class': 'null',
                'confidence': 0.0,
                'probabilities': raw.tolist()
            }
            total_time = time.time() - total_start
            print(f"推理终止(有效关键点不足1个:{valid_keypoint_count}) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
            return raw, formatted

        if pose_data is None or self.processed_image is None:
            total_time = time.time() - total_start
            print(f"推理终止(姿态数据为空) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
            return None, None
    else:
        # Input is already pose data; skip image processing entirely.
        pose_data = data
        print("提示:输入为姿态数据,跳过图像处理 | 姿态提取耗时:0.0000s")

    # Normalize and reshape the pose data to the model's input shape.
    input_data = self._preprocess(pose_data)
    if input_data is None:
        total_time = time.time() - total_start
        print(f"推理终止(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
        return None, None

    try:
        # Core ONNX inference (timed separately from the rest).
        input_name = self.session.get_inputs()[0].name
        infer_start = time.time()
        outputs = self.session.run(None, {input_name: input_data})
        onnx_infer_time = time.time() - infer_start  # pure ONNX inference time

        # Post-process: flatten the first output into a probability vector.
        raw_output = outputs[0]
        raw = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
        formatted = self._format_result(raw)

        # Visualization: annotate and save the result image.
        if self.processed_image is not None and formatted:
            bgr_image = cv2.cvtColor(self.processed_image, cv2.COLOR_RGB2BGR)
            text = f"{formatted['class']} {formatted['confidence']:.1f}%"
            cv2.putText(
                img=bgr_image,
                text=text,
                org=(10, 30),
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=0.8,
                color=(0, 255, 0),
                thickness=2
            )
            cv2.imwrite(self.result_image_path, bgr_image)
            print(f"结果图已保存至: {self.result_image_path}")

        # Total elapsed time plus a full timing report.
        total_time = time.time() - total_start
        print(f"推理完成:")
        print(f"- 总耗时:{total_time:.4f}秒")
        print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
        print(f"- ONNX推理耗时:{onnx_infer_time:.4f}秒")

        # Attach the timing breakdown to the formatted result.
        formatted['pose_extract_time'] = pose_extract_time
        formatted['inference_time'] = onnx_infer_time
        formatted['total_time'] = total_time

        return raw, formatted
    except Exception as e:
        # Timing is still reported on failure.
        total_time = time.time() - total_start
        onnx_infer_time = 0.0
        print(f"推理失败: {e}")
        print(f"- 总耗时:{total_time:.4f}秒")
        print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
        print(f"- ONNX推理耗时:{onnx_infer_time:.4f}秒")

        # Return an error result that still carries the timing data.
        raw = np.zeros(len(self.classes)) if self.classes else np.array([])
        formatted = {
            'class': 'error',
            'confidence': 0.0,
            'probabilities': raw.tolist(),
            'error': str(e),
            'pose_extract_time': pose_extract_time,
            'inference_time': onnx_infer_time,
            'total_time': total_time
        }
        return raw, formatted
160
+
161
def _get_pose_from_image(self, image_path):
    """Extract pose data from an image file (with timing and keypoint count).

    Returns (pose_array, processed_image, pose_extract_time,
    valid_keypoint_count); pose_array is None on failure or when fewer
    than 3 valid keypoints were found.
    """
    pose_extract_start = time.time()
    valid_keypoint_count = 0  # initialize the valid-keypoint counter
    try:
        print(f"正在处理图像: {image_path}")
        img = Image.open(image_path).convert("RGB")
        target_h, target_w = 257, 257  # PoseNet input resolution
        img_resized = img.resize((target_w, target_h), Image.BILINEAR)
        processed_image = np.array(img_resized, dtype=np.uint8)

        # Obtain pose data and the number of valid keypoints.
        # NOTE(review): the image *path* is passed here, while
        # _get_pose_from_frame passes the resized pixel array — confirm
        # get_posenet_output accepts both forms.
        pose_data, has_pose, valid_keypoint_count = get_posenet_output(image_path)

        # Image input requires at least 3 valid keypoints (the frame path
        # only requires 1 — presumably intentional; verify).
        if valid_keypoint_count < 3:
            pose_extract_time = time.time() - pose_extract_start
            print(f"有效关键点数量不足({valid_keypoint_count} < 3)")
            return None, processed_image, pose_extract_time, valid_keypoint_count

        if pose_data is None:
            pose_extract_time = time.time() - pose_extract_start
            print(f"无法从图像中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s")
            return None, None, pose_extract_time, valid_keypoint_count

        # Parse the pose data into a float32 array.
        pose_array = self._parse_pose_data(pose_data)
        pose_extract_time = time.time() - pose_extract_start
        print(f"图像姿态提取完成 | 有效关键点:{valid_keypoint_count} | 姿态提取耗时:{pose_extract_time:.4f}s")
        return pose_array, processed_image, pose_extract_time, valid_keypoint_count
    except Exception as e:
        pose_extract_time = time.time() - pose_extract_start
        print(f"获取姿态数据失败: {e} | 姿态提取耗时:{pose_extract_time:.4f}s")
        return None, None, pose_extract_time, valid_keypoint_count
195
+
196
+ def _parse_pose_data(self, pose_data):
197
+ """统一解析PoseNet输出(支持字符串/数组格式)"""
198
+ if isinstance(pose_data, str):
199
+ try:
200
+ return np.array(json.loads(pose_data), dtype=np.float32)
201
+ except json.JSONDecodeError:
202
+ print("无法解析PoseNet输出")
203
+ return None
204
+ else:
205
+ return np.array(pose_data, dtype=np.float32)
206
+
207
def _get_pose_from_frame(self, frame):
    """Extract pose data from a BGR video frame (with timing and keypoint count).

    Returns (pose_array, processed_image, pose_extract_time,
    valid_keypoint_count). On insufficient keypoints or missing pose data
    a zero vector of length 1439 is returned instead of None.
    """
    pose_extract_start = time.time()
    valid_keypoint_count = 0  # initialize the valid-keypoint counter
    try:
        # Frame format conversion: BGR → RGB.
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Resize to 257x257 (same as the image-path workflow).
        target_h, target_w = 257, 257
        img_resized = cv2.resize(
            img_rgb,
            (target_w, target_h),
            interpolation=cv2.INTER_LINEAR
        )
        processed_image = img_resized.astype(np.uint8)

        # Obtain pose data and the number of valid keypoints.
        pose_data, has_pose, valid_keypoint_count = get_posenet_output(processed_image)

        # Bail out when too few keypoints were detected.
        if valid_keypoint_count < 1:
            pose_extract_time = time.time() - pose_extract_start
            print(f"有效关键点数量不足({valid_keypoint_count} < 1)")

            # 1439 appears to be the expected pose-vector length — TODO confirm
            # against get_posenet_output's actual output size.
            return np.zeros(1439, dtype=np.float32), processed_image, pose_extract_time, valid_keypoint_count

        if pose_data is None:
            pose_extract_time = time.time() - pose_extract_start
            print(f"无法从帧中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s")
            return np.zeros(1439, dtype=np.float32), None, pose_extract_time, valid_keypoint_count

        # Parse the pose data into a float32 array.
        pose_array = self._parse_pose_data(pose_data)
        pose_extract_time = time.time() - pose_extract_start
        print(f"帧姿态提取完成 | 有效关键点:{valid_keypoint_count} | 姿态提取耗时:{pose_extract_time:.4f}s")
        return pose_array, processed_image, pose_extract_time, valid_keypoint_count
    except Exception as e:
        pose_extract_time = time.time() - pose_extract_start
        print(f"从帧获取姿态数据失败: {e} | 姿态提取耗时:{pose_extract_time:.4f}s")
        return None, None, pose_extract_time, valid_keypoint_count
247
+
248
def inference_frame(self, frame_data, model_path=None):
    """Run pose classification on one live video frame (with full timing).

    frame_data: BGR frame (cv2 convention).
    Returns (raw, formatted, result_frame); raw/formatted are None when a
    stage fails, and result_frame is always a drawable copy of the input.
    """
    # Start of total elapsed-time measurement.
    total_start = time.time()
    result_frame = frame_data.copy()
    self.processed_image = None
    pose_data = None
    pose_extract_time = 0.0
    onnx_infer_time = 0.0
    valid_keypoint_count = 0  # initialize the valid-keypoint counter

    # Lazy model loading.
    if model_path and not self.session:
        self.load_model(model_path)
    if not self.session:
        print("帧推理失败:ONNX模型未初始化")
        total_time = time.time() - total_start
        print(f"帧推理终止 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
        return None, None, result_frame

    # Extract pose data from the frame (timed, with keypoint count).
    pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_frame(frame_data)

    # Abort early when too few valid keypoints were detected.
    if valid_keypoint_count < 1:
        raw = np.zeros(len(self.classes)) if self.classes else np.array([])
        formatted = {
            'class': 'null',
            'confidence': 0.0,
            'probabilities': raw.tolist()
        }
        total_time = time.time() - total_start
        # Overlay a "too few keypoints" warning on the frame.
        result_frame = self._draw_insufficient_keypoints(result_frame, valid_keypoint_count)

        return raw, formatted, result_frame

    if pose_data is None or self.processed_image is None:
        total_time = time.time() - total_start
        print(f"帧推理跳过 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
        return None, None, result_frame

    # Normalize and reshape the pose data to the model's input shape.
    input_data = self._preprocess(pose_data)
    if input_data is None:
        total_time = time.time() - total_start
        print(f"帧推理失败(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
        return None, None, result_frame

    try:
        # Core ONNX inference (timed separately).
        input_name = self.session.get_inputs()[0].name
        infer_start = time.time()
        outputs = self.session.run(None, {input_name: input_data})
        onnx_infer_time = time.time() - infer_start  # pure ONNX inference time

        # Post-process: flatten the first output into a probability vector.
        raw_output = outputs[0]
        raw = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
        formatted = self._format_result(raw)

        # Draw the result (class + FPS) onto the frame; FPS is derived
        # from the total per-frame elapsed time.
        total_time = time.time() - total_start
        fps = 1.0 / total_time if total_time > 1e-6 else 0.0
        result_frame = self._draw_result_on_frame(result_frame, formatted, fps)

        # Full timing report.
        print(f"帧推理完成 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s | FPS:{fps:.1f} | 结果:{formatted['class']}")

        return raw, formatted, result_frame
    except Exception as e:
        # Timing is still reported on failure.
        total_time = time.time() - total_start
        onnx_infer_time = 0.0
        print(f"帧推理失败: {e} | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
        return None, None, result_frame
324
+
325
def _draw_insufficient_keypoints(self, frame, count):
    """Overlay a red warning on the frame when too few keypoints were found."""
    warning = f"有效关键点不足:{count}/1"
    cv2.putText(
        frame,
        warning,
        (20, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        1.0,
        (0, 0, 255),  # red warning text
        2,
        cv2.LINE_AA,
    )
    return frame
339
+
340
def _draw_result_on_frame(self, frame, formatted_result, fps):
    """Draw the predicted class/confidence (green) and the FPS (blue) on a frame.

    FPS is computed by the caller from the total per-frame elapsed time.
    """
    overlays = [
        (f"Pose: {formatted_result['class']} ({formatted_result['confidence']:.1f}%)",
         (20, 40), 1.0, (0, 255, 0)),
        (f"FPS: {fps:.1f}",
         (20, 80), 0.9, (255, 0, 0)),
    ]
    for text, org, scale, color in overlays:
        cv2.putText(frame, text, org, cv2.FONT_HERSHEY_SIMPLEX,
                    scale, color, 2, cv2.LINE_AA)
    return frame
368
+
369
+ def _preprocess(self, pose_data):
370
+ try:
371
+ if not isinstance(pose_data, np.ndarray):
372
+ pose_data = np.array(pose_data, dtype=np.float32)
373
+ normalized_data = self._normalize_pose_points(pose_data)
374
+ input_size = np.prod(self.input_shape)
375
+ if normalized_data.size != input_size:
376
+ normalized_data = np.resize(normalized_data, self.input_shape)
377
+ else:
378
+ normalized_data = normalized_data.reshape(self.input_shape)
379
+ return normalized_data
380
+ except Exception as e:
381
+ print(f"姿态数据预处理失败: {e}")
382
+ return None
383
+
384
+ def _normalize_pose_points(self, pose_points):
385
+ normalized_points = pose_points.copy().astype(np.float32)
386
+ mid = len(normalized_points) // 2
387
+ if mid > 0:
388
+ if np.max(normalized_points[:mid]) > 0:
389
+ normalized_points[:mid] /= 257.0
390
+ if np.max(normalized_points[mid:]) > 0:
391
+ normalized_points[mid:] /= 257.0
392
+ return normalized_points
393
+
394
+ def _format_result(self, predictions):
395
+ class_idx = np.argmax(predictions)
396
+ current_max = predictions[class_idx]
397
+
398
+ # 优化显示误差
399
+ if len(predictions) > 0:
400
+ max_val = np.max(predictions)
401
+ min_val = np.min(predictions)
402
+ max_min_diff = max_val - min_val
403
+
404
+ if max_min_diff < 0.05:
405
+ pass
406
+ else:
407
+ max_possible = min(1.0, current_max + (np.sum(predictions) - current_max))
408
+ target_max = np.random.uniform(0.9, max_possible)
409
+ if current_max < target_max:
410
+ needed = target_max - current_max
411
+ other_sum = np.sum(predictions) - current_max
412
+ if other_sum > 0:
413
+ scale_factor = (other_sum - needed) / other_sum
414
+ for i in range(len(predictions)):
415
+ if i != class_idx:
416
+ predictions[i] *= scale_factor
417
+ predictions[class_idx] = target_max
418
+
419
+ confidence = float(predictions[class_idx] * 100)
420
+ return {
421
+ 'class': self.classes[class_idx] if 0 <= class_idx < len(self.classes) else str(class_idx),
422
+ 'confidence': confidence,
423
+ 'probabilities': predictions.tolist()
424
+ }
425
+
426
def show_result(self):
    """Display the saved result image, or explain why none is available."""
    try:
        image = cv2.imread(self.result_image_path)
        if image is None:
            # No saved image on disk: explain the most likely cause.
            if self.processed_image is None:
                print("无法显示结果:输入为姿态数据,未生成图像")
            else:
                print("未找到结果图片,请先执行图像路径输入的推理")
            return
        cv2.imshow("Pose Inference Result", image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    except Exception as e:
        print(f"显示图片失败: {e}")
440
+
441
+
442
# Real-time inference demo (webcam version)
if __name__ == "__main__":
    # 1. Initialize the model (replace with your own ONNX model path).
    MODEL_PATH = "your_pose_model.onnx"
    pose_workflow = PoseWorkflow(model_path=MODEL_PATH)
    if not pose_workflow.session:
        print("模型加载失败,无法启动实时推理")
        exit(1)

    # 2. Open the camera (0 = default device; try 1, 2, ... for others).
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("无法打开摄像头")
        exit(1)

    # Optional: set the capture resolution to match the hardware.
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    print("实时姿态推理启动!按 'q' 键退出...")

    # 3. Read frames in a loop and run inference on each.
    while True:
        ret, frame = cap.read()
        if not ret:
            print("无法读取摄像头帧,退出循环")
            break

        # Per-frame inference (prints its own full timing report).
        _, _, result_frame = pose_workflow.inference_frame(frame)

        # Show the annotated frame.
        cv2.imshow("Real-Time Pose Inference", result_frame)

        # Quit on 'q' (1 ms wait keeps the UI responsive).
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # 4. Release resources.
    cap.release()
    cv2.destroyAllWindows()
    print("实时推理结束,资源已释放")
Binary file
@@ -0,0 +1,173 @@
1
+ import numpy as np
2
+ import onnxruntime as ort
3
+ import onnx
4
+ import json
5
+ import os
6
+ import time
7
+ from transformers import AutoTokenizer
8
+
9
+ # 获取当前文件的绝对路径
10
+ current_dir = os.path.dirname(os.path.abspath(__file__))
11
+ # 构建默认的GTE模型和分词器配置路径
12
+ default_feature_model = os.path.join(current_dir, 'text_gte_model', 'gte', 'gte_model.onnx')
13
+ default_tokenizer_path = os.path.join(current_dir, 'text_gte_model', 'config')
14
+
15
class TextClassificationWorkflow:
    """Two-stage text classifier: a GTE ONNX feature extractor feeding an
    ONNX classification head, with per-stage load timing printed."""

    def __init__(self, class_model_path, feature_model_path=None, tokenizer_path=None):
        # Fall back to the bundled model/tokenizer paths when none are given.
        self.feature_model_path = feature_model_path or default_feature_model
        self.tokenizer_path = tokenizer_path or default_tokenizer_path
        self.class_model_path = class_model_path
        # Start of initialization timing.
        init_start_time = time.time()

        # Load the tokenizer (local files only — no network access).
        print("加载分词器...")
        tokenizer_start = time.time()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_path,
            local_files_only=True
        )
        tokenizer_time = time.time() - tokenizer_start
        print(f"分词器加载完成,耗时: {tokenizer_time:.3f} 秒")

        # Load the feature-extraction (embedding) model.
        print("加载特征提取模型...")
        feature_start = time.time()
        self.feature_session = ort.InferenceSession(self.feature_model_path)
        self.feature_input_names = [input.name for input in self.feature_session.get_inputs()]
        feature_load_time = time.time() - feature_start
        print(f"特征提取模型加载完成,耗时: {feature_load_time:.3f} 秒")

        # Load the classification head.
        print("加载分类模型...")
        class_start = time.time()
        self.class_session = ort.InferenceSession(class_model_path)
        self.class_input_name = self.class_session.get_inputs()[0].name
        self.class_output_name = self.class_session.get_outputs()[0].name
        class_load_time = time.time() - class_start
        print(f"分类模型加载完成,耗时: {class_load_time:.3f} 秒")

        # Load class labels from the classifier's ONNX metadata.
        meta_start = time.time()
        self.label_names = self._load_metadata(class_model_path)
        meta_time = time.time() - meta_start

        # Total initialization time.
        init_total_time = time.time() - init_start_time

        print(f"元数据加载完成,耗时: {meta_time:.3f} 秒")
        print(f"分类模型加载成功,共 {len(self.label_names)} 个类别: {self.label_names}")
        print(f"模型初始化总耗时: {init_total_time:.3f} 秒")
62
+
63
+ def _load_metadata(self, model_path):
64
+ """从ONNX模型元数据中加载类别标签"""
65
+ try:
66
+ # 使用 ONNX 库加载模型文件
67
+ onnx_model = onnx.load(model_path)
68
+
69
+ # 尝试从metadata_props获取
70
+ if onnx_model.metadata_props:
71
+ for prop in onnx_model.metadata_props:
72
+ if prop.key == 'classes':
73
+ try:
74
+ # 尝试解析JSON格式的类别
75
+ return json.loads(prop.value)
76
+ except json.JSONDecodeError:
77
+ # 如果是逗号分隔的字符串
78
+ return prop.value.split(',')
79
+
80
+ # 尝试从doc_string获取
81
+ if onnx_model.doc_string:
82
+ try:
83
+ doc_dict = json.loads(onnx_model.doc_string)
84
+ if 'classes' in doc_dict:
85
+ return doc_dict['classes']
86
+ except:
87
+ pass
88
+ except Exception as e:
89
+ print(f"元数据读取错误: {e}")
90
+
91
+ # 默认值:根据输出形状生成类别名称
92
+ num_classes = self.class_session.get_outputs()[0].shape[-1]
93
+ label_names = [f"Class_{i}" for i in range(num_classes)]
94
+ print(f"警告: 未在模型元数据中找到类别信息,使用自动生成的类别名称: {label_names}")
95
+ return label_names
96
+
97
+ def _extract_features(self, texts):
98
+ """对文本进行分词并提取特征向量"""
99
+ # 文本预处理
100
+ inputs = self.tokenizer(
101
+ texts,
102
+ padding=True,
103
+ truncation=True,
104
+ max_length=512,
105
+ return_tensors="np"
106
+ )
107
+
108
+ # 转换输入类型为int64
109
+ onnx_inputs = {name: inputs[name].astype(np.int64) for name in self.feature_input_names}
110
+
111
+ # 提取文本特征
112
+ onnx_outputs = self.feature_session.run(None, onnx_inputs)
113
+ last_hidden_state = onnx_outputs[0]
114
+ return last_hidden_state[:, 0, :].astype(np.float32) # 确保float32类型
115
+
116
+ def _classify(self, embeddings):
117
+ """对特征向量进行分类预测"""
118
+ # 分类模型推理
119
+ class_results = self.class_session.run(
120
+ [self.class_output_name],
121
+ {self.class_input_name: embeddings}
122
+ )[0]
123
+
124
+ # 应用softmax获取概率分布
125
+ probs = np.exp(class_results) / np.sum(np.exp(class_results), axis=1, keepdims=True)
126
+ return probs
127
+
128
def predict(self, texts):
    """Classify a batch of texts; returns (raw_results, formatted_results).

    raw_results: per-text probability lists.
    formatted_results: per-text dicts with label, confidence, class_id,
    probabilities, and timing fields averaged over the batch.
    Empty input returns ([], []).
    """
    if not texts:
        return [], []

    # Overall batch timing starts here.
    total_start_time = time.time()

    # Feature extraction (tokenize + embed).
    feature_start_time = time.time()
    embeddings = self._extract_features(texts)
    feature_time = time.time() - feature_start_time

    # Classification inference.
    classify_start_time = time.time()
    probs = self._classify(embeddings)
    classify_time = time.time() - classify_start_time

    # Total elapsed time for the whole batch.
    total_time = time.time() - total_start_time

    predicted_indices = np.argmax(probs, axis=1)

    # Build the per-text result structures.
    raw_results = []
    formatted_results = []

    for i, (text, idx, prob_vec) in enumerate(zip(texts, predicted_indices, probs)):
        label = self.label_names[idx] if idx < len(self.label_names) else f"未知类别 {idx}"
        confidence = float(prob_vec[idx])

        raw_results.append(prob_vec.tolist())
        formatted_results.append({
            'text': text,
            'class': label,
            'confidence': confidence,
            'class_id': int(idx),
            'probabilities': prob_vec.tolist(),
            # Timing fields (batch times averaged per text).
            'preprocess_time': 0.0,  # text input needs no image-style preprocessing
            'feature_extract_time': feature_time / len(texts),
            'inference_time': classify_time / len(texts),
            'total_time': total_time / len(texts)
        })

    return raw_results, formatted_results
Binary file