smartpi 0.1.34__py3-none-any.whl → 0.1.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,592 @@
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+ import json
5
+ import os
6
+ import time
7
+ from rknnlite.api import RKNNLite # RKNN核心库
8
+ from lib.posenet_utils import get_posenet_output # 姿态关键点提取逻辑(需支持返回3个值)
9
+
10
+
11
class PoseWorkflow:
    """Pose-classification workflow backed by an RKNN Lite model running on an NPU."""

    def __init__(self, model_path=None):
        # Basic attribute initialisation
        self.model_path = os.path.abspath(model_path) if model_path else None
        self.rknn_lite = None  # RKNN Lite runtime instance (created in load_model)
        self.classes = []  # pose class labels (loaded from metadata)
        self.input_shape = []  # model input shape
        self.output_shape = []  # model output shape
        self.min_vals = None  # per-feature minima used for normalisation
        self.max_vals = None  # per-feature maxima used for normalisation

        # Valid-keypoint validation threshold (tune as needed):
        # inference is aborted when fewer keypoints are detected.
        self.min_valid_keypoints = 1

        # Result-related state
        self.result_image_path = "result.jpg"  # where inference() saves the annotated image
        self.processed_image = None  # last preprocessed RGB image (numpy array)
        self.last_infer_time = 0.0  # previous frame's total pipeline duration (for FPS)

        # Load eagerly when a model path was supplied.
        if model_path:
            self.load_model()
33
+
34
+ def _get_metadata_path(self):
35
+ """获取元数据文件路径"""
36
+ if not self.model_path:
37
+ raise ValueError("模型路径未初始化")
38
+
39
+ # 元数据文件规则:与模型同目录,模型名+"_rknn_metadata.json"
40
+ base_dir = os.path.dirname(self.model_path)
41
+ base_name = os.path.basename(self.model_path)
42
+ metadata_name = os.path.splitext(base_name)[0] + "_rknn_metadata.json"
43
+ metadata_path = os.path.join(base_dir, metadata_name)
44
+
45
+ # 若自定义名不存在,尝试默认名
46
+ if not os.path.exists(metadata_path):
47
+ metadata_path = os.path.join(base_dir, "rknn_metadata.json")
48
+ print(f"自定义元数据文件不存在,尝试默认路径:{metadata_path}")
49
+
50
+ return metadata_path
51
+
52
+ def _load_metadata(self):
53
+ """加载元数据(classes、input_shape、minMax等)"""
54
+ metadata_path = self._get_metadata_path()
55
+ try:
56
+ with open(metadata_path, "r", encoding="utf-8") as f:
57
+ metadata = json.load(f)
58
+
59
+ # 读取核心元数据
60
+ self.classes = metadata.get("classes", []) # 姿态类别列表
61
+ self.input_shape = [1, 14739] # 输入形状,默认14739维
62
+ self.output_shape = metadata.get("output_shape", [1, len(self.classes)]) # 输出形状
63
+ min_max = metadata.get("minMax", {})
64
+ self.min_vals = np.array(min_max.get("min", []), dtype=np.float32) # 归一化最小值
65
+ self.max_vals = np.array(min_max.get("max", []), dtype=np.float32) # 归一化最大值
66
+
67
+
68
+ except Exception as e:
69
+ print(f"元数据加载失败:{e},使用默认配置")
70
+ self.classes = []
71
+ self.input_shape = [1, 14739]
72
+ self.min_vals = np.array([])
73
+ self.max_vals = np.array([])
74
+
75
+ def load_model(self):
76
+ """加载RKNN模型"""
77
+ try:
78
+ # 初始化RKNN Lite
79
+ self.rknn_lite = RKNNLite()
80
+
81
+ # 加载RKNN模型文件
82
+ ret = self.rknn_lite.load_rknn(self.model_path)
83
+ if ret != 0:
84
+ raise RuntimeError(f"RKNN模型加载失败,错误码:{ret}")
85
+
86
+ # 初始化NPU运行时
87
+ ret = self.rknn_lite.init_runtime()
88
+ if ret != 0:
89
+ raise RuntimeError(f"NPU运行时初始化失败,错误码:{ret}")
90
+
91
+ # 加载元数据
92
+ self._load_metadata()
93
+ print("RKNN模型加载完成")
94
+
95
+ except Exception as e:
96
+ print(f"模型加载总失败:{e}")
97
+ # 释放资源避免泄漏
98
+ if self.rknn_lite:
99
+ self.rknn_lite.release()
100
+
101
+ def _normalize_pose_points(self, pose_points):
102
+ """姿态关键点归一化"""
103
+ pose_points = np.array(pose_points, dtype=np.float32)
104
+ # 若元数据有minMax且长度匹配,使用minMax归一化
105
+ if len(self.min_vals) > 0 and len(self.max_vals) > 0 and len(self.min_vals) == len(pose_points):
106
+ # 避免除以零
107
+ ranges = self.max_vals - self.min_vals
108
+ ranges[ranges < 1e-6] = 1e-6
109
+ normalized = (pose_points - self.min_vals) / ranges
110
+ return normalized
111
+ return pose_points
112
+
113
+ def _preprocess(self, pose_data):
114
+ """姿态数据预处理"""
115
+ try:
116
+ # 转为numpy数组
117
+ if not isinstance(pose_data, np.ndarray):
118
+ pose_data = np.array(pose_data, dtype=np.float32)
119
+
120
+ # 归一化
121
+ normalized_data = self._normalize_pose_points(pose_data)
122
+
123
+ # 调整为模型输入形状
124
+ input_size = np.prod(self.input_shape)
125
+ if normalized_data.size != input_size:
126
+ print(f"输入数据长度不匹配(实际:{normalized_data.size} | 期望:{input_size}),自动调整维度")
127
+ normalized_data = np.resize(normalized_data, self.input_shape)
128
+ else:
129
+ normalized_data = normalized_data.reshape(self.input_shape)
130
+
131
+ # RKNN要求输入为float32
132
+ return normalized_data.astype(np.float32)
133
+
134
+ except Exception as e:
135
+ print(f"姿态数据预处理失败:{e}")
136
+ return None
137
+
138
+ def _get_pose_from_image(self, image_path):
139
+ """从图像提取姿态数据(含耗时统计+有效关键点计数)"""
140
+ try:
141
+ print(f"正在处理图像:{image_path}")
142
+ # 记录姿态提取开始时间
143
+ pose_extract_start = time.time()
144
+
145
+ # 图像读取与预处理
146
+ img = Image.open(image_path).convert("RGB")
147
+ target_h, target_w = 257, 257 # PoseNet默认输入尺寸
148
+ img_resized = img.resize((target_w, target_h), Image.BILINEAR)
149
+ self.processed_image = np.array(img_resized, dtype=np.uint8)
150
+
151
+ # 【关键修改1】调用PoseNet获取3个返回值:姿态数据、是否有姿态、有效关键点数量
152
+ pose_data, has_pose, valid_keypoint_count = get_posenet_output(image_path)
153
+ # 处理可能的None值(避免后续报错)
154
+ valid_keypoint_count = valid_keypoint_count if valid_keypoint_count is not None else 0
155
+
156
+ # 【关键校验1】有效关键点数量不足,返回无效结果
157
+ if valid_keypoint_count < self.min_valid_keypoints:
158
+ pose_extract_time = time.time() - pose_extract_start
159
+ print(f"图像有效关键点不足({valid_keypoint_count}/{self.min_valid_keypoints}),无法提取姿态")
160
+ return None, pose_extract_time, valid_keypoint_count
161
+
162
+ # 姿态数据为空的情况
163
+ if pose_data is None or not has_pose:
164
+ pose_extract_time = time.time() - pose_extract_start
165
+ print(f"无法从图像中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s")
166
+ return None, pose_extract_time, valid_keypoint_count
167
+
168
+ # 姿态数据格式转换
169
+ pose_array = self._parse_pose_data(pose_data)
170
+ # 计算姿态提取耗时
171
+ pose_extract_time = time.time() - pose_extract_start
172
+ print(f"图像姿态提取完成 | 有效关键点:{valid_keypoint_count}/{self.min_valid_keypoints} | 姿态提取耗时:{pose_extract_time:.4f}s")
173
+ return pose_array, pose_extract_time, valid_keypoint_count
174
+
175
+ except Exception as e:
176
+ pose_extract_time = time.time() - pose_extract_start
177
+ valid_keypoint_count = 0 # 异常时默认关键点数量为0
178
+ print(f"获取图像姿态数据失败:{e} | 姿态提取耗时:{pose_extract_time:.4f}s | 有效关键点:{valid_keypoint_count}")
179
+ return None, pose_extract_time, valid_keypoint_count
180
+
181
+ def _parse_pose_data(self, pose_data):
182
+ """统一解析PoseNet输出(支持字符串/数组格式)"""
183
+ if isinstance(pose_data, str):
184
+ try:
185
+ return np.array(json.loads(pose_data), dtype=np.float32)
186
+ except json.JSONDecodeError:
187
+ print("PoseNet输出JSON解析失败")
188
+ return None
189
+ else:
190
+ return np.array(pose_data, dtype=np.float32)
191
+
192
    def _get_pose_from_frame(self, frame):
        """Extract pose data from a video frame (with timing + keypoint count).

        Returns:
            (pose_array | None, processed_frame | None,
             pose_extract_time_seconds, valid_keypoint_count)
        """
        try:
            # Timer is the first statement in the try, so the except handler
            # can always compute the elapsed time.
            pose_extract_start = time.time()

            # cv2 frames are BGR; PoseNet expects RGB.
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Resize to PoseNet's input size (257x257).
            target_h, target_w = 257, 257
            img_resized = cv2.resize(
                img_rgb,
                (target_w, target_h),
                interpolation=cv2.INTER_LINEAR
            )
            processed_frame = img_resized.astype(np.uint8)  # kept for visualisation

            # PoseNet returns (pose data, has-pose flag, valid keypoint count).
            # NOTE(review): here the preprocessed array is passed, while
            # _get_pose_from_image passes a file path — get_posenet_output
            # presumably accepts both; confirm against lib.posenet_utils.
            pose_data, has_pose, valid_keypoint_count = get_posenet_output(processed_frame)
            # Guard against a None count so comparisons below are safe.
            valid_keypoint_count = valid_keypoint_count if valid_keypoint_count is not None else 0

            # Too few valid keypoints: report and return an invalid result.
            if valid_keypoint_count < self.min_valid_keypoints:
                pose_extract_time = time.time() - pose_extract_start
                print(f"帧有效关键点不足({valid_keypoint_count}/{self.min_valid_keypoints}),无法提取姿态")
                return None, None, pose_extract_time, valid_keypoint_count

            # No pose detected at all (distinct from "too few keypoints").
            if pose_data is None or not has_pose:
                pose_extract_time = time.time() - pose_extract_start
                print(f"无法从当前帧中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s | 有效关键点:{valid_keypoint_count}")
                return None, None, pose_extract_time, valid_keypoint_count

            # Coerce the PoseNet payload into a float32 array.
            pose_array = self._parse_pose_data(pose_data)
            # Total pose-extraction time for this frame.
            pose_extract_time = time.time() - pose_extract_start
            print(f"帧姿态提取完成 | 有效关键点:{valid_keypoint_count}/{self.min_valid_keypoints} | 姿态提取耗时:{pose_extract_time:.4f}s")
            return pose_array, processed_frame, pose_extract_time, valid_keypoint_count

        except Exception as e:
            pose_extract_time = time.time() - pose_extract_start
            valid_keypoint_count = 0  # default keypoint count on failure
            print(f"获取帧姿态数据失败:{e} | 姿态提取耗时:{pose_extract_time:.4f}s | 有效关键点:{valid_keypoint_count}")
            return None, None, pose_extract_time, valid_keypoint_count
239
+
240
    # Draw a red warning overlay when too few keypoints were detected
    def _draw_insufficient_keypoints(self, frame, valid_count):
        """Overlay a red "not enough valid keypoints" warning on *frame*.

        NOTE(review): cv2.putText only renders ASCII/Latin glyphs, so the
        Chinese characters in this label will likely show as "?" boxes on
        screen — confirm on-device, or switch to an ASCII label / PIL-based
        text rendering if that matters.
        """
        text = f"有效关键点不足:{valid_count}/{self.min_valid_keypoints}"
        cv2.putText(
            img=frame,
            text=text,
            org=(20, 40),  # same position as the normal result text, so it replaces it
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1.0,
            color=(0, 0, 255),  # red (BGR) warning colour
            thickness=2,
            lineType=cv2.LINE_AA
        )
        # Also draw the FPS so the frame rate stays visible even when
        # keypoints are insufficient.
        fps = 1.0 / self.last_infer_time if self.last_infer_time > 1e-6 else 0.0
        fps_text = f"FPS: {fps:.1f}"
        cv2.putText(
            img=frame,
            text=fps_text,
            org=(20, 80),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=0.9,
            color=(255, 0, 0),
            thickness=2,
            lineType=cv2.LINE_AA
        )
        return frame
268
+
269
+ def _draw_result_on_frame(self, frame, formatted_result):
270
+ """在cv2帧上绘制推理结果"""
271
+ # 计算实时FPS(基于总耗时)
272
+ fps = 1.0 / self.last_infer_time if self.last_infer_time > 1e-6 else 0.0
273
+
274
+ # 绘制类别+置信度
275
+ class_text = f"Pose: {formatted_result['class']} ({formatted_result['confidence']:.1f}%)"
276
+ cv2.putText(
277
+ img=frame,
278
+ text=class_text,
279
+ org=(20, 40),
280
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
281
+ fontScale=1.0,
282
+ color=(0, 255, 0),
283
+ thickness=2,
284
+ lineType=cv2.LINE_AA
285
+ )
286
+
287
+ # 绘制FPS(基于总耗时)
288
+ fps_text = f"FPS: {fps:.1f}"
289
+ cv2.putText(
290
+ img=frame,
291
+ text=fps_text,
292
+ org=(20, 80),
293
+ fontFace=cv2.FONT_HERSHEY_SIMPLEX,
294
+ fontScale=0.9,
295
+ color=(255, 0, 0),
296
+ thickness=2,
297
+ lineType=cv2.LINE_AA
298
+ )
299
+
300
+ return frame
301
+
302
    def inference_frame(self, frame, model_path=None):
        """Real-time per-frame inference entry point.

        Runs pose extraction, preprocessing, RKNN inference and result
        drawing on one cv2 frame, with full timing statistics and
        valid-keypoint validation.

        Returns:
            (raw_result, formatted_result, result_frame) — raw_result /
            formatted_result are None when the frame could not be processed;
            result_frame is always a drawable copy of the input frame.
        """
        # Start of the whole-pipeline timer.
        total_start = time.time()
        result_frame = frame.copy()
        self.processed_image = None
        pose_data = None
        valid_keypoint_count = 0  # initial keypoint count
        pose_extract_time = 0.0

        try:
            # (Re)load the model if a new path was supplied or nothing is loaded.
            if model_path and (not self.rknn_lite or self.model_path != os.path.abspath(model_path)):
                self.model_path = os.path.abspath(model_path)
                self.load_model()
            if not self.rknn_lite:
                raise RuntimeError("RKNN模型未加载,无法执行推理")

            # Frame pose extraction: 4-tuple including valid keypoint count.
            pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_frame(frame)

            # Not enough valid keypoints: return a "null" result immediately.
            if valid_keypoint_count < self.min_valid_keypoints:
                # Build the invalid-result payload (zero probabilities).
                raw_result = np.zeros(len(self.classes)) if self.classes else np.array([])
                formatted_result = {
                    "class": "null",
                    "confidence": 0.0,
                    "probabilities": raw_result.tolist(),
                    "class_id": -1
                }
                # Update total time / FPS bookkeeping.
                total_time = time.time() - total_start
                self.last_infer_time = total_time
                # Draw the "insufficient keypoints" warning.
                result_frame = self._draw_insufficient_keypoints(result_frame, valid_keypoint_count)
                # NOTE(review): 1/total_time assumes total_time > 0; time.time()
                # deltas are positive in practice but this is not guaranteed.
                print(f"帧推理终止(有效关键点不足) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | RKNN推理耗时:0.0000s | FPS:{1/total_time:.1f}")
                return raw_result, formatted_result, result_frame

            # Pose data missing for some other reason than keypoint count.
            if pose_data is None or self.processed_image is None:
                total_time = time.time() - total_start
                self.last_infer_time = total_time
                print(f"帧推理跳过 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | RKNN推理耗时:0.0000s | FPS:{1/total_time:.1f}")
                return None, None, result_frame

            # Normalise + reshape for the model.
            input_data = self._preprocess(pose_data)
            if input_data is None:
                total_time = time.time() - total_start
                self.last_infer_time = total_time
                print(f"帧推理失败(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | RKNN推理耗时:0.0000s | FPS:{1/total_time:.1f}")
                return None, None, result_frame

            # Core RKNN inference (timed separately from the pipeline).
            infer_start = time.time()
            outputs = self.rknn_lite.inference(inputs=[input_data])
            rknn_infer_time = time.time() - infer_start  # RKNN-only duration
            raw_output = outputs[0]

            # Post-processing: flatten scores and format the result dict.
            raw_result = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
            formatted_result = self._format_result(raw_result)

            # Draw the classification + FPS overlay on the frame.
            result_frame = self._draw_result_on_frame(result_frame, formatted_result)

            # Total time from method entry to inference completion.
            total_time = time.time() - total_start
            self.last_infer_time = total_time  # feeds the FPS shown on the next frame

            # Full timing log line.
            print(f"帧推理完成 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | RKNN推理耗时:{rknn_infer_time:.4f}s | FPS:{1/total_time:.1f} | 结果:{formatted_result['class']}")

            return raw_result, formatted_result, result_frame

        except Exception as e:
            # Even on failure, record timings so FPS stays meaningful.
            total_time = time.time() - total_start
            rknn_infer_time = 0.0  # inference did not run
            self.last_infer_time = total_time
            print(f"帧推理失败:{e} | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | RKNN推理耗时:{rknn_infer_time:.4f}s | FPS:{1/total_time:.1f}")
            return None, None, result_frame
386
+
387
    def inference(self, data, model_path=None):
        """RKNN inference entry point for an image path or raw pose data.

        *data* may be a path to an image file (pose extraction is run first)
        or an already-extracted pose vector (extraction is skipped). Includes
        full timing statistics and valid-keypoint validation; when an image
        was processed, an annotated result image is saved to
        self.result_image_path.

        Returns:
            (raw_result, formatted_result) — both None on failure.
        """
        # Start of the whole-pipeline timer.
        total_start = time.time()
        self.processed_image = None
        pose_data = None
        raw_result = None
        formatted_result = None
        pose_extract_time = 0.0  # stays 0 when input is already pose data
        valid_keypoint_count = 0  # initial keypoint count

        try:
            # (Re)load the model if a new path was supplied or nothing is loaded.
            if model_path and (not self.rknn_lite or self.model_path != os.path.abspath(model_path)):
                self.model_path = os.path.abspath(model_path)
                self.load_model()
            if not self.rknn_lite:
                raise RuntimeError("RKNN模型未加载,无法执行推理")

            # Dispatch on input type.
            if isinstance(data, str):
                # Input is an image path: extract pose (3-tuple incl. count).
                pose_data, pose_extract_time, valid_keypoint_count = self._get_pose_from_image(data)

                # Not enough valid keypoints: return a "null" result.
                if valid_keypoint_count < self.min_valid_keypoints:
                    raw_result = np.zeros(len(self.classes)) if self.classes else np.array([])
                    formatted_result = {
                        "class": "null",
                        "confidence": 0.0,
                        "probabilities": raw_result.tolist(),
                        "class_id": -1
                    }
                    total_time = time.time() - total_start
                    print(f"推理终止(有效关键点不足) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | RKNN推理耗时:0.0000s")
                    return raw_result, formatted_result

                # Pose data missing for some other reason than keypoint count.
                if pose_data is None:
                    total_time = time.time() - total_start
                    print(f"推理终止(姿态数据为空) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | RKNN推理耗时:0.0000s")
                    return None, None
            else:
                # Input is pre-extracted pose data: skip image processing and
                # the keypoint check. valid_keypoint_count becomes a display
                # string here — it is only interpolated into log text below,
                # never compared numerically on this path.
                pose_data = data
                valid_keypoint_count = "未知(输入为姿态数据)"
                print(f"输入为姿态数据,跳过图像处理 | 姿态提取耗时:0.0000s | 有效关键点:{valid_keypoint_count}")

            # Normalise + reshape for the model.
            input_data = self._preprocess(pose_data)
            if input_data is None:
                total_time = time.time() - total_start
                print(f"推理终止(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | RKNN推理耗时:0.0000s")
                return None, None

            # Core RKNN inference (timed separately from the pipeline).
            infer_start = time.time()
            outputs = self.rknn_lite.inference(inputs=[input_data])
            rknn_infer_time = time.time() - infer_start  # RKNN-only duration
            raw_output = outputs[0]

            # Post-processing: flatten scores and format the result dict.
            raw_result = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
            formatted_result = self._format_result(raw_result)

            # Visualise & save when an input image was processed.
            if self.processed_image is not None and formatted_result:
                bgr_image = cv2.cvtColor(self.processed_image, cv2.COLOR_RGB2BGR)
                # Red warning for a "null" (keypoint-deficient) result,
                # green text for a normal classification.
                if formatted_result["class"] == "null":
                    text = f"有效关键点不足:{valid_keypoint_count}/{self.min_valid_keypoints}"
                    cv2.putText(bgr_image, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
                else:
                    text = f"{formatted_result['class']} {formatted_result['confidence']:.1f}%"
                    cv2.putText(bgr_image, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
                cv2.imwrite(self.result_image_path, bgr_image)
                print(f"结果图已保存至:{self.result_image_path} | {text}")

            # Final timing summary.
            total_time = time.time() - total_start
            print(f"推理完成:")
            print(f"- 总耗时:{total_time:.4f}秒")
            print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
            print(f"- RKNN推理耗时:{rknn_infer_time:.4f}秒")
            print(f"- 有效关键点:{valid_keypoint_count}/{self.min_valid_keypoints}")

            # Attach timing info so downstream callers (pose.py) can read it.
            if formatted_result:
                formatted_result['pose_extract_time'] = pose_extract_time
                formatted_result['inference_time'] = rknn_infer_time
                formatted_result['total_time'] = total_time

        except Exception as e:
            # Even on failure, report the timings gathered so far.
            total_time = time.time() - total_start
            rknn_infer_time = 0.0  # inference did not run
            print(f"推理失败:{e}")
            print(f"- 总耗时:{total_time:.4f}秒")
            print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
            print(f"- RKNN推理耗时:{rknn_infer_time:.4f}秒")
            print(f"- 有效关键点:{valid_keypoint_count}/{self.min_valid_keypoints}")

        return raw_result, formatted_result
490
+
491
+ def _format_result(self, predictions):
492
+ """结果格式化逻辑"""
493
+ predictions = np.array(predictions, dtype=np.float32)
494
+ class_idx = np.argmax(predictions)
495
+ current_max = predictions[class_idx]
496
+
497
+ # 生成格式化结果
498
+ confidence = float(predictions[class_idx] * 100)
499
+ return {
500
+ "class": self.classes[class_idx] if 0 <= class_idx < len(self.classes) else f"未知类别_{class_idx}",
501
+ "confidence": confidence,
502
+ "probabilities": predictions.tolist(),
503
+ "class_id": int(class_idx)
504
+ }
505
+
506
+ def show_result(self):
507
+ """结果显示逻辑"""
508
+ try:
509
+ result_image = cv2.imread(self.result_image_path)
510
+ if result_image is not None:
511
+ cv2.imshow("Pose RKNN Inference Result", result_image)
512
+ print("按任意键关闭窗口...")
513
+ cv2.waitKey(0)
514
+ cv2.destroyAllWindows()
515
+ else:
516
+ if self.processed_image is None:
517
+ print("无法显示结果:输入为姿态数据,未生成图像")
518
+ else:
519
+ print("未找到结果图片,请先执行图像路径输入的推理")
520
+ except Exception as e:
521
+ print(f"显示结果失败:{e}")
522
+
523
+ def release(self):
524
+ """释放RKNN资源"""
525
+ if hasattr(self, "rknn_lite") and self.rknn_lite:
526
+ self.rknn_lite.release()
527
+ print("RKNN NPU资源已释放")
528
+
529
+ def __del__(self):
530
+ """析构函数自动释放资源"""
531
+ self.release()
532
+
533
+ # 便捷接口
534
+ def predict(self, image_path):
535
+ """从图像路径推理(同步接口)"""
536
+ return self.inference(image_path)
537
+
538
+ def predict_pose_data(self, pose_data):
539
+ """从已提取的姿态数据推理(同步接口)"""
540
+ return self.inference(pose_data)
541
+
542
+
543
# Real-time camera inference test
if __name__ == "__main__":
    # Configuration (replace with your RKNN model path)
    RKNN_MODEL_PATH = "your_pose_model.rknn"
    CAMERA_INDEX = 0  # 0 = default camera

    # Initialise the pose-inference workflow
    pose_workflow = PoseWorkflow(model_path=RKNN_MODEL_PATH)
    if not pose_workflow.rknn_lite:
        print("RKNN模型初始化失败,无法启动实时推理")
        exit(1)

    # Open the camera
    cap = cv2.VideoCapture(CAMERA_INDEX)
    if not cap.isOpened():
        print(f"无法打开摄像头(索引:{CAMERA_INDEX})")
        pose_workflow.release()
        exit(1)

    # Set the capture resolution
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    print("实时姿态推理启动成功!")
    print(f"模型路径:{RKNN_MODEL_PATH}")
    print(f"最小有效关键点:{pose_workflow.min_valid_keypoints}个")
    print("按 'q' 键退出实时推理...")

    # Read camera frames and run inference in a loop
    while True:
        ret, frame = cap.read()
        if not ret:
            print("无法读取摄像头帧,退出循环")
            break

        # Per-frame inference (full timings are printed inside)
        _, _, result_frame = pose_workflow.inference_frame(frame)

        # Show the annotated frame
        cv2.imshow("RKNN Real-Time Pose Inference", result_frame)

        # Quit on 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            print("用户触发退出")
            break

    # Release resources
    cap.release()
    cv2.destroyAllWindows()
    pose_workflow.release()
    print("实时推理结束,资源已释放")