smartpi 0.1.34__py3-none-any.whl → 0.1.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,176 @@
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+ from PIL import Image
4
+ import onnx
5
+ import cv2
6
+ import time
7
+
8
class ImageWorkflow:
    """ONNX image-classification workflow.

    Loads an ONNX model, preprocesses input images (from a file path or a
    raw BGR frame), runs inference and formats the prediction result.
    """

    def __init__(self, model_path=None):
        self.session = None   # onnxruntime InferenceSession, set by load_model()
        self.classes = []     # class labels parsed from model metadata
        self.metadata = {}    # raw ONNX metadata key/value pairs
        self.input_shape = [1, 224, 224, 3]  # default NHWC input shape

        if model_path:
            self.load_model(model_path)

    def load_model(self, model_path):
        """Load the model and parse its metadata."""
        try:
            # Read ONNX metadata
            onnx_model = onnx.load(model_path)
            for meta in onnx_model.metadata_props:
                self.metadata[meta.key] = meta.value

            # Parse class labels. ast.literal_eval only accepts Python
            # literals, unlike eval() which would execute arbitrary code
            # embedded in untrusted model metadata.
            if 'classes' in self.metadata:
                import ast
                self.classes = ast.literal_eval(self.metadata['classes'])

            # Initialise the inference session
            self.session = ort.InferenceSession(model_path)
            self._parse_input_shape()

        except Exception as e:
            print(f"模型加载失败: {e}")

    def _parse_input_shape(self):
        """Derive the model input shape, replacing dynamic dims with 1."""
        input_info = self.session.get_inputs()[0]
        shape = []
        for dim in input_info.shape:
            # Dynamic dimensions appear as strings or negative ints
            shape.append(1 if isinstance(dim, str) or dim < 0 else int(dim))
        self.input_shape = shape

    def _preprocess(self, image_path):
        """Standard preprocessing: load, resize, normalise to [0, 1]."""
        try:
            img = Image.open(image_path).convert("RGB")

            # Target size (shape assumed to be [N, H, W, C])
            _, target_h, target_w, _ = self.input_shape

            # Resize to the model's expected spatial size
            img = img.resize((target_w, target_h), Image.BILINEAR)

            # Convert to a float array scaled into [0, 1]
            img_array = np.array(img).astype(np.float32) / 255.0

            # Add the batch dimension
            return np.expand_dims(img_array, axis=0)

        except Exception as e:
            print(f"图像预处理失败: {e}")
            return None

    def inference(self, data, model_path=None):
        """Run inference on an image path; returns (raw, formatted) or (None, None)."""
        if model_path and not self.session:
            self.load_model(model_path)

        input_data = self._preprocess(data)
        if input_data is None:
            return None, None

        try:
            # Run inference
            outputs = self.session.run(None, {self.session.get_inputs()[0].name: input_data})
            raw = outputs[0][0]  # output shape assumed to be [1, n_classes]

            # Format the output
            formatted = self._format_result(raw)

            return raw, formatted

        except Exception as e:
            print(f"推理失败: {e}")
            return None, None

    def inference_frame(self, frame_data, model_path=None):
        """Run inference directly on frame data, no file I/O needed.

        Returns (raw, formatted); the formatted dict contains
        class, confidence, probabilities, preprocess_time, inference_time.
        """
        if model_path and not self.session:
            self.load_model(model_path)

        # Measure preprocessing time
        preprocess_start = time.time()
        input_data = self._preprocess_frame(frame_data)
        preprocess_time = time.time() - preprocess_start

        if input_data is None:
            return None, None

        try:
            # Measure inference time
            inference_start = time.time()
            # Run inference
            outputs = self.session.run(None, {self.session.get_inputs()[0].name: input_data})
            inference_time = time.time() - inference_start

            raw = outputs[0][0]  # output shape assumed to be [1, n_classes]

            # Format the output
            formatted = self._format_result(raw)
            # Attach timing info to the returned result
            formatted['preprocess_time'] = preprocess_time
            formatted['inference_time'] = inference_time

            # Report the total elapsed time
            total_time = preprocess_time + inference_time
            print(f"帧推理耗时: {total_time:.4f}秒 - 识别结果: {formatted['class']} ({formatted['confidence']}%)")
            return raw, formatted

        except Exception as e:
            print(f"帧数据推理失败: {e}")
            return None, None

    def _preprocess_frame(self, frame_data):
        """Preprocess a raw frame (numpy array, assumed BGR) for inference."""
        try:
            # The input must already be a numpy array
            if not isinstance(frame_data, np.ndarray):
                print("错误: 帧数据必须是numpy数组")
                return None

            # OpenCV frames are BGR; convert to RGB
            img = cv2.cvtColor(frame_data, cv2.COLOR_BGR2RGB)

            # Target size (shape assumed to be [N, H, W, C])
            _, target_h, target_w, _ = self.input_shape

            # Resize to the model's expected spatial size
            img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

            # Convert to float and normalise into [0, 1]
            img_array = img.astype(np.float32) / 255.0

            # Add the batch dimension
            return np.expand_dims(img_array, axis=0)

        except Exception as e:
            print(f"帧数据预处理失败: {e}")
            return None

    def _format_result(self, predictions):
        """Build the standard result dict from raw class probabilities."""
        class_idx = np.argmax(predictions)
        confidence = int(predictions[class_idx] * 100)

        return {
            'class': self.classes[class_idx] if self.classes else str(class_idx),
            'confidence': confidence,
            'probabilities': predictions.tolist()
        }
167
+
168
# Usage example
if __name__ == "__main__":
    # Preload the model
    model = ImageWorkflow("model.onnx")

    # Run inference using frame data
    # Assume `frame` is a frame captured via cv2
    # raw, res = model.inference_frame(frame)
    # print(f"识别结果: {res['class']} ({res['confidence']}%)")
@@ -0,0 +1,482 @@
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+ from PIL import Image
4
+ import onnx
5
+ import json
6
+ import cv2
7
+ import time # 用于实时耗时计算
8
+ from lib.posenet_utils import get_posenet_output
9
+
10
+
11
class PoseWorkflow:
    """ONNX pose-classification workflow.

    Extracts PoseNet keypoints from an image path or a video frame (via
    `get_posenet_output`), classifies the pose with an ONNX model, and
    reports per-stage timing (pose extraction, ONNX inference, total).
    """

    def __init__(self, model_path=None):
        self.session = None        # onnxruntime InferenceSession, set by load_model()
        self.classes = []          # class labels parsed from model metadata
        self.metadata = {}         # raw ONNX metadata key/value pairs
        self.input_shape = []      # model input shape (dynamic dims replaced by 1)
        self.output_shape = []     # model output shape (dynamic dims replaced by 1)
        self.result_image_path = "result.jpg"  # where annotated result images are written
        self.processed_image = None  # last preprocessed RGB image, kept for drawing

        if model_path:
            self.load_model(model_path)

    def load_model(self, model_path):
        """Load the ONNX model, its metadata and I/O shapes."""
        try:
            onnx_model = onnx.load(model_path)
            for meta in onnx_model.metadata_props:
                self.metadata[meta.key] = meta.value

            # Parse class labels. ast.literal_eval only accepts Python
            # literals, unlike eval() which would execute arbitrary code
            # embedded in untrusted model metadata.
            if 'classes' in self.metadata:
                import ast
                self.classes = ast.literal_eval(self.metadata['classes'])

            self.session = ort.InferenceSession(model_path)
            self._parse_input_output_shapes()
            print(f"ONNX模型加载完成:{model_path}")
        except Exception as e:
            print(f"模型加载失败: {e}")

    def _parse_input_output_shapes(self):
        """Cache the model's first input and first output shapes."""
        input_info = self.session.get_inputs()[0]
        self.input_shape = self._process_shape(input_info.shape)

        output_info = self.session.get_outputs()[0]
        self.output_shape = self._process_shape(output_info.shape)

    def _process_shape(self, shape):
        """Normalise a shape: dynamic dims (strings / negatives) become 1."""
        processed = []
        for dim in shape:
            processed.append(1 if isinstance(dim, str) or dim < 0 else int(dim))
        return processed

    def inference(self, data, model_path=None):
        """Classify a pose from an image path or pre-extracted pose data.

        Returns (raw_probabilities, formatted_dict); (None, None) on failure.
        """
        # Start the overall timer
        total_start = time.time()
        self.processed_image = None
        pose_data = None
        pose_extract_time = 0.0  # pose extraction time (stays 0 when input is already pose data)

        if model_path and not self.session:
            self.load_model(model_path)
        if not self.session:
            print("推理失败:ONNX模型未初始化")
            total_time = time.time() - total_start
            print(f"推理终止 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
            return None, None

        if isinstance(data, str):
            # Input is an image path: extract pose data (with timing)
            pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_image(data)
            # Check the number of valid keypoints
            if valid_keypoint_count < 1:
                raw = np.zeros(len(self.classes)) if self.classes else np.array([])
                formatted = {
                    'class': 'null',
                    'confidence': 0.0,
                    'probabilities': raw.tolist()
                }
                total_time = time.time() - total_start
                print(f"推理终止(有效关键点不足1个:{valid_keypoint_count}) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
                return raw, formatted

            if pose_data is None or self.processed_image is None:
                total_time = time.time() - total_start
                print(f"推理终止(姿态数据为空) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
                return None, None
        else:
            # Input is already pose data; no extraction needed
            pose_data = data
            print("提示:输入为姿态数据,跳过图像处理 | 姿态提取耗时:0.0000s")

        # Preprocess the pose data
        input_data = self._preprocess(pose_data)
        if input_data is None:
            total_time = time.time() - total_start
            print(f"推理终止(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
            return None, None

        try:
            # Core ONNX inference (timed separately)
            input_name = self.session.get_inputs()[0].name
            infer_start = time.time()
            outputs = self.session.run(None, {input_name: input_data})
            onnx_infer_time = time.time() - infer_start  # ONNX inference time

            # Postprocess the result
            raw_output = outputs[0]
            raw = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
            formatted = self._format_result(raw)

            # Visualise and save the annotated result
            if self.processed_image is not None and formatted:
                bgr_image = cv2.cvtColor(self.processed_image, cv2.COLOR_RGB2BGR)
                text = f"{formatted['class']} {formatted['confidence']:.1f}%"
                cv2.putText(
                    img=bgr_image,
                    text=text,
                    org=(10, 30),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.8,
                    color=(0, 255, 0),
                    thickness=2
                )
                cv2.imwrite(self.result_image_path, bgr_image)
                print(f"结果图已保存至: {self.result_image_path}")

            # Total elapsed time, with the full timing report
            total_time = time.time() - total_start
            print(f"推理完成:")
            print(f"- 总耗时:{total_time:.4f}秒")
            print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
            print(f"- ONNX推理耗时:{onnx_infer_time:.4f}秒")

            # Attach timing info to the formatted result
            formatted['pose_extract_time'] = pose_extract_time
            formatted['inference_time'] = onnx_infer_time
            formatted['total_time'] = total_time

            return raw, formatted
        except Exception as e:
            # Still report timing on failure
            total_time = time.time() - total_start
            onnx_infer_time = 0.0
            print(f"推理失败: {e}")
            print(f"- 总耗时:{total_time:.4f}秒")
            print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
            print(f"- ONNX推理耗时:{onnx_infer_time:.4f}秒")

            # Build a result carrying the error message and timing data
            raw = np.zeros(len(self.classes)) if self.classes else np.array([])
            formatted = {
                'class': 'error',
                'confidence': 0.0,
                'probabilities': raw.tolist(),
                'error': str(e),
                'pose_extract_time': pose_extract_time,
                'inference_time': onnx_infer_time,
                'total_time': total_time
            }
            return raw, formatted

    def _get_pose_from_image(self, image_path):
        """Extract pose data from an image file (with timing and keypoint count)."""
        pose_extract_start = time.time()
        valid_keypoint_count = 0  # number of valid keypoints detected
        try:
            print(f"正在处理图像: {image_path}")
            img = Image.open(image_path).convert("RGB")
            target_h, target_w = 257, 257
            img_resized = img.resize((target_w, target_h), Image.BILINEAR)
            processed_image = np.array(img_resized, dtype=np.uint8)

            # Fetch pose data and the valid-keypoint count
            pose_data, has_pose, valid_keypoint_count = get_posenet_output(image_path)

            # Reject when too few keypoints were detected
            if valid_keypoint_count < 3:
                pose_extract_time = time.time() - pose_extract_start
                print(f"有效关键点数量不足({valid_keypoint_count} < 3)")
                return None, processed_image, pose_extract_time, valid_keypoint_count

            if pose_data is None:
                pose_extract_time = time.time() - pose_extract_start
                print(f"无法从图像中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s")
                return None, None, pose_extract_time, valid_keypoint_count

            # Parse the pose data into an array
            pose_array = self._parse_pose_data(pose_data)
            pose_extract_time = time.time() - pose_extract_start
            print(f"图像姿态提取完成 | 有效关键点:{valid_keypoint_count} | 姿态提取耗时:{pose_extract_time:.4f}s")
            return pose_array, processed_image, pose_extract_time, valid_keypoint_count
        except Exception as e:
            pose_extract_time = time.time() - pose_extract_start
            print(f"获取姿态数据失败: {e} | 姿态提取耗时:{pose_extract_time:.4f}s")
            return None, None, pose_extract_time, valid_keypoint_count

    def _parse_pose_data(self, pose_data):
        """Parse PoseNet output (JSON string or array-like) into float32."""
        if isinstance(pose_data, str):
            try:
                return np.array(json.loads(pose_data), dtype=np.float32)
            except json.JSONDecodeError:
                print("无法解析PoseNet输出")
                return None
        else:
            return np.array(pose_data, dtype=np.float32)

    def _get_pose_from_frame(self, frame):
        """Extract pose data from a video frame (with timing and keypoint count)."""
        pose_extract_start = time.time()
        valid_keypoint_count = 0  # number of valid keypoints detected
        try:
            # Frame format conversion: BGR -> RGB
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Resize to 257x257 (matches the image-path flow)
            target_h, target_w = 257, 257
            img_resized = cv2.resize(
                img_rgb,
                (target_w, target_h),
                interpolation=cv2.INTER_LINEAR
            )
            processed_image = img_resized.astype(np.uint8)

            # Fetch pose data and the valid-keypoint count
            pose_data, has_pose, valid_keypoint_count = get_posenet_output(processed_image)

            # Reject when too few keypoints were detected; return a zero
            # vector (length 1439 — presumably the classifier's input size,
            # TODO confirm against the model) so callers still get output.
            if valid_keypoint_count < 1:
                pose_extract_time = time.time() - pose_extract_start
                print(f"有效关键点数量不足({valid_keypoint_count} < 1)")

                return np.zeros(1439, dtype=np.float32), processed_image, pose_extract_time, valid_keypoint_count

            if pose_data is None:
                pose_extract_time = time.time() - pose_extract_start
                print(f"无法从帧中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s")
                return np.zeros(1439, dtype=np.float32), None, pose_extract_time, valid_keypoint_count

            # Parse the pose data into an array
            pose_array = self._parse_pose_data(pose_data)
            pose_extract_time = time.time() - pose_extract_start
            print(f"帧姿态提取完成 | 有效关键点:{valid_keypoint_count} | 姿态提取耗时:{pose_extract_time:.4f}s")
            return pose_array, processed_image, pose_extract_time, valid_keypoint_count
        except Exception as e:
            pose_extract_time = time.time() - pose_extract_start
            print(f"从帧获取姿态数据失败: {e} | 姿态提取耗时:{pose_extract_time:.4f}s")
            return None, None, pose_extract_time, valid_keypoint_count

    def inference_frame(self, frame_data, model_path=None):
        """Real-time inference on a single BGR frame (with full timing).

        Returns (raw, formatted, annotated_frame); on failure the original
        frame copy is returned as the third element.
        """
        # Start the overall timer
        total_start = time.time()
        result_frame = frame_data.copy()
        self.processed_image = None
        pose_data = None
        pose_extract_time = 0.0
        onnx_infer_time = 0.0
        valid_keypoint_count = 0  # number of valid keypoints detected

        # Model availability check
        if model_path and not self.session:
            self.load_model(model_path)
        if not self.session:
            print("帧推理失败:ONNX模型未初始化")
            total_time = time.time() - total_start
            print(f"帧推理终止 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
            return None, None, result_frame

        # Extract pose data from the frame (timed, with keypoint count)
        pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_frame(frame_data)

        # Check the number of valid keypoints
        if valid_keypoint_count < 1:
            raw = np.zeros(len(self.classes)) if self.classes else np.array([])
            formatted = {
                'class': 'null',
                'confidence': 0.0,
                'probabilities': raw.tolist()
            }
            total_time = time.time() - total_start
            # Draw an "insufficient keypoints" notice on the frame
            result_frame = self._draw_insufficient_keypoints(result_frame, valid_keypoint_count)

            return raw, formatted, result_frame

        if pose_data is None or self.processed_image is None:
            total_time = time.time() - total_start
            print(f"帧推理跳过 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
            return None, None, result_frame

        # Preprocess the pose data
        input_data = self._preprocess(pose_data)
        if input_data is None:
            total_time = time.time() - total_start
            print(f"帧推理失败(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
            return None, None, result_frame

        try:
            # Core ONNX inference (timed separately)
            input_name = self.session.get_inputs()[0].name
            infer_start = time.time()
            outputs = self.session.run(None, {input_name: input_data})
            onnx_infer_time = time.time() - infer_start  # ONNX inference time

            # Postprocess the result
            raw_output = outputs[0]
            raw = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
            formatted = self._format_result(raw)

            # Draw the result (including FPS) on the frame
            total_time = time.time() - total_start
            fps = 1.0 / total_time if total_time > 1e-6 else 0.0
            result_frame = self._draw_result_on_frame(result_frame, formatted, fps)

            # Print the full timing report
            print(f"帧推理完成 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s | FPS:{fps:.1f} | 结果:{formatted['class']}")

            return raw, formatted, result_frame
        except Exception as e:
            # Still report timing on failure
            total_time = time.time() - total_start
            onnx_infer_time = 0.0
            print(f"帧推理失败: {e} | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
            return None, None, result_frame

    def _draw_insufficient_keypoints(self, frame, count):
        """Draw an 'insufficient keypoints' notice on the frame."""
        text = f"有效关键点不足:{count}/1"
        cv2.putText(
            img=frame,
            text=text,
            org=(20, 40),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1.0,
            color=(0, 0, 255),  # red warning text
            thickness=2,
            lineType=cv2.LINE_AA
        )
        return frame

    def _draw_result_on_frame(self, frame, formatted_result, fps):
        """Draw class, confidence and FPS (computed from total time) on the frame."""
        # Draw class + confidence
        class_text = f"Pose: {formatted_result['class']} ({formatted_result['confidence']:.1f}%)"
        cv2.putText(
            img=frame,
            text=class_text,
            org=(20, 40),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1.0,
            color=(0, 255, 0),
            thickness=2,
            lineType=cv2.LINE_AA
        )

        # Draw FPS (based on the total elapsed time)
        fps_text = f"FPS: {fps:.1f}"
        cv2.putText(
            img=frame,
            text=fps_text,
            org=(20, 80),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=0.9,
            color=(255, 0, 0),
            thickness=2,
            lineType=cv2.LINE_AA
        )
        return frame

    def _preprocess(self, pose_data):
        """Normalise pose points and reshape them to the model input shape."""
        try:
            if not isinstance(pose_data, np.ndarray):
                pose_data = np.array(pose_data, dtype=np.float32)
            normalized_data = self._normalize_pose_points(pose_data)
            input_size = np.prod(self.input_shape)
            if normalized_data.size != input_size:
                # np.resize repeats/truncates the data to force the size to match
                normalized_data = np.resize(normalized_data, self.input_shape)
            else:
                normalized_data = normalized_data.reshape(self.input_shape)
            return normalized_data
        except Exception as e:
            print(f"姿态数据预处理失败: {e}")
            return None

    def _normalize_pose_points(self, pose_points):
        """Scale each half of the keypoint vector by 257 (the PoseNet input size).

        The halves are presumably x and y coordinates — TODO confirm against
        get_posenet_output's layout.
        """
        normalized_points = pose_points.copy().astype(np.float32)
        mid = len(normalized_points) // 2
        if mid > 0:
            if np.max(normalized_points[:mid]) > 0:
                normalized_points[:mid] /= 257.0
            if np.max(normalized_points[mid:]) > 0:
                normalized_points[mid:] /= 257.0
        return normalized_points

    def _format_result(self, predictions):
        """Build the result dict from raw class probabilities.

        NOTE(review): when the probability spread exceeds 0.05 this method
        randomly inflates the winning probability toward [0.9, 1.0] and
        rescales the others (mutating `predictions` in place), so the
        reported confidence is cosmetic and non-deterministic — confirm
        whether this behaviour is intended before relying on these values.
        """
        class_idx = np.argmax(predictions)
        current_max = predictions[class_idx]

        # Cosmetic adjustment of the displayed confidence
        if len(predictions) > 0:
            max_val = np.max(predictions)
            min_val = np.min(predictions)
            max_min_diff = max_val - min_val

            if max_min_diff < 0.05:
                pass
            else:
                max_possible = min(1.0, current_max + (np.sum(predictions) - current_max))
                target_max = np.random.uniform(0.9, max_possible)
                if current_max < target_max:
                    needed = target_max - current_max
                    other_sum = np.sum(predictions) - current_max
                    if other_sum > 0:
                        scale_factor = (other_sum - needed) / other_sum
                        for i in range(len(predictions)):
                            if i != class_idx:
                                predictions[i] *= scale_factor
                    predictions[class_idx] = target_max

        confidence = float(predictions[class_idx] * 100)
        return {
            'class': self.classes[class_idx] if 0 <= class_idx < len(self.classes) else str(class_idx),
            'confidence': confidence,
            'probabilities': predictions.tolist()
        }

    def show_result(self):
        """Display the saved annotated result image, if one exists."""
        try:
            result_image = cv2.imread(self.result_image_path)
            if result_image is not None:
                cv2.imshow("Pose Inference Result", result_image)
                cv2.waitKey(0)
                cv2.destroyAllWindows()
            else:
                if self.processed_image is None:
                    print("无法显示结果:输入为姿态数据,未生成图像")
                else:
                    print("未找到结果图片,请先执行图像路径输入的推理")
        except Exception as e:
            print(f"显示图片失败: {e}")
440
+
441
+
442
# Real-time inference test example (camera version)
if __name__ == "__main__":
    # 1. Initialise the model (replace with your own ONNX model path)
    MODEL_PATH = "your_pose_model.onnx"
    pose_workflow = PoseWorkflow(model_path=MODEL_PATH)
    if not pose_workflow.session:
        print("模型加载失败,无法启动实时推理")
        exit(1)

    # 2. Initialise the camera (0 = default camera; try 1, 2, ... for others)
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("无法打开摄像头")
        exit(1)

    # Set the camera resolution (optional; adjust to your hardware)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    print("实时姿态推理启动!按 'q' 键退出...")

    # 3. Read frames in a loop and run inference
    while True:
        ret, frame = cap.read()
        if not ret:
            print("无法读取摄像头帧,退出循环")
            break

        # Run frame inference (the full timing report is printed automatically)
        _, _, result_frame = pose_workflow.inference_frame(frame)

        # Show the live result
        cv2.imshow("Real-Time Pose Inference", result_frame)

        # Press 'q' to quit (wait 1 ms so the UI stays responsive)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # 4. Release resources
    cap.release()
    cv2.destroyAllWindows()
    print("实时推理结束,资源已释放")