smartpi-0.1.38-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smartpi/__init__.py +8 -0
- smartpi/_gui.py +66 -0
- smartpi/base_driver.py +566 -0
- smartpi/camera.py +84 -0
- smartpi/color_sensor.py +18 -0
- smartpi/cw2015.py +179 -0
- smartpi/flash.py +130 -0
- smartpi/humidity.py +20 -0
- smartpi/led.py +19 -0
- smartpi/light_sensor.py +72 -0
- smartpi/motor.py +177 -0
- smartpi/move.py +218 -0
- smartpi/onnx_hand_workflow.py +201 -0
- smartpi/onnx_image_workflow.py +176 -0
- smartpi/onnx_pose_workflow.py +482 -0
- smartpi/onnx_text_workflow.py +173 -0
- smartpi/onnx_voice_workflow.py +437 -0
- smartpi/posemodel/__init__.py +0 -0
- smartpi/posemodel/posenet.tflite +0 -0
- smartpi/posenet_utils.py +222 -0
- smartpi/rknn_hand_workflow.py +245 -0
- smartpi/rknn_image_workflow.py +405 -0
- smartpi/rknn_pose_workflow.py +592 -0
- smartpi/rknn_text_workflow.py +240 -0
- smartpi/rknn_voice_workflow.py +394 -0
- smartpi/servo.py +178 -0
- smartpi/temperature.py +18 -0
- smartpi/text_gte_model/__init__.py +0 -0
- smartpi/text_gte_model/config/__init__.py +0 -0
- smartpi/text_gte_model/config/config.json +30 -0
- smartpi/text_gte_model/config/quantize_config.json +30 -0
- smartpi/text_gte_model/config/special_tokens_map.json +7 -0
- smartpi/text_gte_model/config/tokenizer.json +14924 -0
- smartpi/text_gte_model/config/tokenizer_config.json +23 -0
- smartpi/text_gte_model/config/vocab.txt +14760 -0
- smartpi/text_gte_model/gte/__init__.py +0 -0
- smartpi/text_gte_model/gte/gte_model.onnx +0 -0
- smartpi/touch_sensor.py +16 -0
- smartpi/trace.py +120 -0
- smartpi/ultrasonic.py +20 -0
- smartpi-0.1.38.dist-info/METADATA +17 -0
- smartpi-0.1.38.dist-info/RECORD +44 -0
- smartpi-0.1.38.dist-info/WHEEL +5 -0
- smartpi-0.1.38.dist-info/top_level.txt +1 -0
@@ -0,0 +1,592 @@
import cv2
import numpy as np
from PIL import Image
import json
import os
import time
from rknnlite.api import RKNNLite  # RKNN core library
from smartpi.posenet_utils import get_posenet_output  # pose keypoint extraction (must return 3 values)


class PoseWorkflow:
    def __init__(self, model_path=None):
        # Basic attributes
        self.model_path = os.path.abspath(model_path) if model_path else None
        self.rknn_lite = None  # RKNN Lite instance
        self.classes = []  # pose class labels
        self.input_shape = []  # model input shape
        self.output_shape = []  # model output shape
        self.min_vals = None  # minimum values used for normalization
        self.max_vals = None  # maximum values used for normalization

        # Valid-keypoint check configuration (threshold can be adjusted as needed)
        self.min_valid_keypoints = 1  # minimum number of valid keypoints (inference is aborted below this)

        # Result-related attributes
        self.result_image_path = "result.jpg"
        self.processed_image = None
        self.last_infer_time = 0.0  # time spent on the previous frame (used for FPS calculation)

        # If a model path was given, load it immediately
        if model_path:
            self.load_model()

    def _get_metadata_path(self):
        """Return the metadata file path."""
        if not self.model_path:
            raise ValueError("Model path is not initialized")

        # Metadata file convention: same directory as the model, named "<model name>_rknn_metadata.json"
        base_dir = os.path.dirname(self.model_path)
        base_name = os.path.basename(self.model_path)
        metadata_name = os.path.splitext(base_name)[0] + "_rknn_metadata.json"
        metadata_path = os.path.join(base_dir, metadata_name)

        # If the model-specific name does not exist, fall back to the default name
        if not os.path.exists(metadata_path):
            metadata_path = os.path.join(base_dir, "rknn_metadata.json")
            print(f"Model-specific metadata file not found, trying default path: {metadata_path}")

        return metadata_path

    def _load_metadata(self):
        """Load metadata (classes, input_shape, minMax, etc.)."""
        metadata_path = self._get_metadata_path()
        try:
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)

            # Read the core metadata
            self.classes = metadata.get("classes", [])  # pose class list
            self.input_shape = [1, 14739]  # input shape, fixed at 14739 features
            self.output_shape = metadata.get("output_shape", [1, len(self.classes)])  # output shape
            min_max = metadata.get("minMax", {})
            self.min_vals = np.array(min_max.get("min", []), dtype=np.float32)  # normalization minimums
            self.max_vals = np.array(min_max.get("max", []), dtype=np.float32)  # normalization maximums

        except Exception as e:
            print(f"Failed to load metadata: {e}, using default configuration")
            self.classes = []
            self.input_shape = [1, 14739]
            self.min_vals = np.array([])
            self.max_vals = np.array([])

    def load_model(self):
        """Load the RKNN model."""
        try:
            # Initialize RKNN Lite
            self.rknn_lite = RKNNLite()

            # Load the RKNN model file
            ret = self.rknn_lite.load_rknn(self.model_path)
            if ret != 0:
                raise RuntimeError(f"Failed to load RKNN model, error code: {ret}")

            # Initialize the NPU runtime
            ret = self.rknn_lite.init_runtime()
            if ret != 0:
                raise RuntimeError(f"Failed to initialize NPU runtime, error code: {ret}")

            # Load metadata
            self._load_metadata()
            print("RKNN model loaded")

        except Exception as e:
            print(f"Model loading failed: {e}")
            # Release resources to avoid leaks
            if self.rknn_lite:
                self.rknn_lite.release()
                self.rknn_lite = None  # reset so later checks treat the model as unloaded

    def _normalize_pose_points(self, pose_points):
        """Normalize pose keypoints."""
        pose_points = np.array(pose_points, dtype=np.float32)
        # If the metadata provides minMax values of matching length, use them for normalization
        if len(self.min_vals) > 0 and len(self.max_vals) > 0 and len(self.min_vals) == len(pose_points):
            # Avoid division by zero
            ranges = self.max_vals - self.min_vals
            ranges[ranges < 1e-6] = 1e-6
            normalized = (pose_points - self.min_vals) / ranges
            return normalized
        return pose_points

    def _preprocess(self, pose_data):
        """Preprocess pose data."""
        try:
            # Convert to a numpy array
            if not isinstance(pose_data, np.ndarray):
                pose_data = np.array(pose_data, dtype=np.float32)

            # Normalize
            normalized_data = self._normalize_pose_points(pose_data)

            # Adjust to the model input shape
            input_size = np.prod(self.input_shape)
            if normalized_data.size != input_size:
                print(f"Input length mismatch (actual: {normalized_data.size} | expected: {input_size}), resizing automatically")
                normalized_data = np.resize(normalized_data, self.input_shape)
            else:
                normalized_data = normalized_data.reshape(self.input_shape)

            # RKNN expects float32 input
            return normalized_data.astype(np.float32)

        except Exception as e:
            print(f"Pose data preprocessing failed: {e}")
            return None

    def _get_pose_from_image(self, image_path):
        """Extract pose data from an image (with timing and valid-keypoint count)."""
        try:
            print(f"Processing image: {image_path}")
            # Record when pose extraction started
            pose_extract_start = time.time()

            # Read and preprocess the image
            img = Image.open(image_path).convert("RGB")
            target_h, target_w = 257, 257  # default PoseNet input size
            img_resized = img.resize((target_w, target_h), Image.BILINEAR)
            self.processed_image = np.array(img_resized, dtype=np.uint8)

            # [Key change 1] PoseNet returns 3 values: pose data, whether a pose was found, valid-keypoint count
            pose_data, has_pose, valid_keypoint_count = get_posenet_output(image_path)
            # Guard against a possible None value (avoids errors later)
            valid_keypoint_count = valid_keypoint_count if valid_keypoint_count is not None else 0

            # [Key check 1] Too few valid keypoints: return an invalid result
            if valid_keypoint_count < self.min_valid_keypoints:
                pose_extract_time = time.time() - pose_extract_start
                print(f"Too few valid keypoints in image ({valid_keypoint_count}/{self.min_valid_keypoints}), cannot extract a pose")
                return None, pose_extract_time, valid_keypoint_count

            # No pose data was returned
            if pose_data is None or not has_pose:
                pose_extract_time = time.time() - pose_extract_start
                print(f"Could not extract pose data from the image | pose extraction time: {pose_extract_time:.4f}s")
                return None, pose_extract_time, valid_keypoint_count

            # Convert the pose data format
            pose_array = self._parse_pose_data(pose_data)
            # Compute the pose extraction time
            pose_extract_time = time.time() - pose_extract_start
            print(f"Image pose extraction done | valid keypoints: {valid_keypoint_count}/{self.min_valid_keypoints} | pose extraction time: {pose_extract_time:.4f}s")
            return pose_array, pose_extract_time, valid_keypoint_count

        except Exception as e:
            pose_extract_time = time.time() - pose_extract_start
            valid_keypoint_count = 0  # default the keypoint count to 0 on error
            print(f"Failed to get pose data from image: {e} | pose extraction time: {pose_extract_time:.4f}s | valid keypoints: {valid_keypoint_count}")
            return None, pose_extract_time, valid_keypoint_count

    def _parse_pose_data(self, pose_data):
        """Parse PoseNet output consistently (supports string or array input)."""
        if isinstance(pose_data, str):
            try:
                return np.array(json.loads(pose_data), dtype=np.float32)
            except json.JSONDecodeError:
                print("Failed to parse PoseNet JSON output")
                return None
        else:
            return np.array(pose_data, dtype=np.float32)

    def _get_pose_from_frame(self, frame):
        """Extract pose data from a video frame (with timing and valid-keypoint count)."""
        try:
            # Record when pose extraction started
            pose_extract_start = time.time()

            # Frame format conversion: cv2 uses BGR, PoseNet expects RGB
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Resize to the PoseNet input size (257x257)
            target_h, target_w = 257, 257
            img_resized = cv2.resize(
                img_rgb,
                (target_w, target_h),
                interpolation=cv2.INTER_LINEAR
            )
            processed_frame = img_resized.astype(np.uint8)  # keep the preprocessed frame for visualization

            # [Key change 2] PoseNet returns 3 values: pose data, whether a pose was found, valid-keypoint count
            pose_data, has_pose, valid_keypoint_count = get_posenet_output(processed_frame)
            # Guard against a possible None value
            valid_keypoint_count = valid_keypoint_count if valid_keypoint_count is not None else 0

            # [Key check 2] Too few valid keypoints: return an invalid result
            if valid_keypoint_count < self.min_valid_keypoints:
                pose_extract_time = time.time() - pose_extract_start
                print(f"Too few valid keypoints in frame ({valid_keypoint_count}/{self.min_valid_keypoints}), cannot extract a pose")
                return None, None, pose_extract_time, valid_keypoint_count

            # No pose data was returned
            if pose_data is None or not has_pose:
                pose_extract_time = time.time() - pose_extract_start
                print(f"Could not extract pose data from the current frame | pose extraction time: {pose_extract_time:.4f}s | valid keypoints: {valid_keypoint_count}")
                return None, None, pose_extract_time, valid_keypoint_count

            # Parse the pose data
            pose_array = self._parse_pose_data(pose_data)
            # Compute the pose extraction time
            pose_extract_time = time.time() - pose_extract_start
            print(f"Frame pose extraction done | valid keypoints: {valid_keypoint_count}/{self.min_valid_keypoints} | pose extraction time: {pose_extract_time:.4f}s")
            return pose_array, processed_frame, pose_extract_time, valid_keypoint_count

        except Exception as e:
            pose_extract_time = time.time() - pose_extract_start
            valid_keypoint_count = 0  # default the keypoint count to 0 on error
            print(f"Failed to get pose data from frame: {e} | pose extraction time: {pose_extract_time:.4f}s | valid keypoints: {valid_keypoint_count}")
            return None, None, pose_extract_time, valid_keypoint_count

    # [New method 1] Draw a red warning when there are too few valid keypoints
    def _draw_insufficient_keypoints(self, frame, valid_count):
        """Draw a red 'insufficient valid keypoints' warning on the frame."""
        text = f"Insufficient keypoints: {valid_count}/{self.min_valid_keypoints}"
        cv2.putText(
            img=frame,
            text=text,
            org=(20, 40),  # same position as the normal result text, so it replaces it
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1.0,
            color=(0, 0, 255),  # red warning
            thickness=2,
            lineType=cv2.LINE_AA
        )
        # Optional: draw FPS (shown even when keypoints are insufficient)
        fps = 1.0 / self.last_infer_time if self.last_infer_time > 1e-6 else 0.0
        fps_text = f"FPS: {fps:.1f}"
        cv2.putText(
            img=frame,
            text=fps_text,
            org=(20, 80),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=0.9,
            color=(255, 0, 0),
            thickness=2,
            lineType=cv2.LINE_AA
        )
        return frame

    def _draw_result_on_frame(self, frame, formatted_result):
        """Draw the inference result on a cv2 frame."""
        # Compute live FPS (based on total time)
        fps = 1.0 / self.last_infer_time if self.last_infer_time > 1e-6 else 0.0

        # Draw class + confidence
        class_text = f"Pose: {formatted_result['class']} ({formatted_result['confidence']:.1f}%)"
        cv2.putText(
            img=frame,
            text=class_text,
            org=(20, 40),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1.0,
            color=(0, 255, 0),
            thickness=2,
            lineType=cv2.LINE_AA
        )

        # Draw FPS (based on total time)
        fps_text = f"FPS: {fps:.1f}"
        cv2.putText(
            img=frame,
            text=fps_text,
            org=(20, 80),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=0.9,
            color=(255, 0, 0),
            thickness=2,
            lineType=cv2.LINE_AA
        )

        return frame

    def inference_frame(self, frame, model_path=None):
        """Real-time frame inference entry point (with full timing and valid-keypoint checks)."""
        # Record when total timing started
        total_start = time.time()
        result_frame = frame.copy()
        self.processed_image = None
        pose_data = None
        valid_keypoint_count = 0  # initialize the valid-keypoint count
        pose_extract_time = 0.0

        try:
            # Model loading check
            if model_path and (not self.rknn_lite or self.model_path != os.path.abspath(model_path)):
                self.model_path = os.path.abspath(model_path)
                self.load_model()
            if not self.rknn_lite:
                raise RuntimeError("RKNN model is not loaded, cannot run inference")

            # [Key change 3] Frame pose extraction now returns 4 values (adds the valid-keypoint count)
            pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_frame(frame)

            # [Key check 3] Too few valid keypoints: return an invalid result immediately
            if valid_keypoint_count < self.min_valid_keypoints:
                # Build the invalid-result structure
                raw_result = np.zeros(len(self.classes)) if self.classes else np.array([])
                formatted_result = {
                    "class": "null",
                    "confidence": 0.0,
                    "probabilities": raw_result.tolist(),
                    "class_id": -1
                }
                # Compute total time and update FPS
                total_time = time.time() - total_start
                self.last_infer_time = total_time
                # Draw the insufficient-keypoints warning
                result_frame = self._draw_insufficient_keypoints(result_frame, valid_keypoint_count)
                # Log
                print(f"Frame inference aborted (insufficient valid keypoints) | total: {total_time:.4f}s | pose extraction: {pose_extract_time:.4f}s | RKNN inference: 0.0000s | FPS: {1/total_time:.1f}")
                return raw_result, formatted_result, result_frame

            # Pose data is empty (for reasons other than insufficient keypoints)
            if pose_data is None or self.processed_image is None:
                total_time = time.time() - total_start
                self.last_infer_time = total_time
                print(f"Frame inference skipped | total: {total_time:.4f}s | pose extraction: {pose_extract_time:.4f}s | RKNN inference: 0.0000s | FPS: {1/total_time:.1f}")
                return None, None, result_frame

            # Preprocess the pose data
            input_data = self._preprocess(pose_data)
            if input_data is None:
                total_time = time.time() - total_start
                self.last_infer_time = total_time
                print(f"Frame inference failed (preprocessing failed) | total: {total_time:.4f}s | pose extraction: {pose_extract_time:.4f}s | RKNN inference: 0.0000s | FPS: {1/total_time:.1f}")
                return None, None, result_frame

            # Core RKNN inference (timed separately)
            infer_start = time.time()
            outputs = self.rknn_lite.inference(inputs=[input_data])
            rknn_infer_time = time.time() - infer_start  # RKNN inference time
            raw_output = outputs[0]

            # Post-process the result
            raw_result = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
            formatted_result = self._format_result(raw_result)

            # Visualize the result on the frame
            result_frame = self._draw_result_on_frame(result_frame, formatted_result)

            # Compute total time (from method start to inference end)
            total_time = time.time() - total_start
            self.last_infer_time = total_time  # used for FPS calculation on the next frame

            # Print the complete timing information
            print(f"Frame inference done | total: {total_time:.4f}s | pose extraction: {pose_extract_time:.4f}s | RKNN inference: {rknn_infer_time:.4f}s | FPS: {1/total_time:.1f} | result: {formatted_result['class']}")

            return raw_result, formatted_result, result_frame

        except Exception as e:
            # Timing is still recorded on error
            total_time = time.time() - total_start
            rknn_infer_time = 0.0  # inference did not run, so its time is 0
            self.last_infer_time = total_time
            print(f"Frame inference failed: {e} | total: {total_time:.4f}s | pose extraction: {pose_extract_time:.4f}s | RKNN inference: {rknn_infer_time:.4f}s | FPS: {1/total_time:.1f}")
            return None, None, result_frame

    def inference(self, data, model_path=None):
        """RKNN inference (with full timing and valid-keypoint checks)."""
        # Record when total timing started
        total_start = time.time()
        self.processed_image = None
        pose_data = None
        raw_result = None
        formatted_result = None
        pose_extract_time = 0.0  # defaults to 0 (no extraction needed if the input is already pose data)
        valid_keypoint_count = 0  # initialize the valid-keypoint count

        try:
            # Model loading check
            if model_path and (not self.rknn_lite or self.model_path != os.path.abspath(model_path)):
                self.model_path = os.path.abspath(model_path)
                self.load_model()
            if not self.rknn_lite:
                raise RuntimeError("RKNN model is not loaded, cannot run inference")

            # Handle the input data
            if isinstance(data, str):
                # [Key change 4] Input is an image path: receive 3 return values (adds the valid-keypoint count)
                pose_data, pose_extract_time, valid_keypoint_count = self._get_pose_from_image(data)

                # [Key check 4] Too few valid keypoints: return an invalid result
                if valid_keypoint_count < self.min_valid_keypoints:
                    raw_result = np.zeros(len(self.classes)) if self.classes else np.array([])
                    formatted_result = {
                        "class": "null",
                        "confidence": 0.0,
                        "probabilities": raw_result.tolist(),
                        "class_id": -1
                    }
                    total_time = time.time() - total_start
                    print(f"Inference aborted (insufficient valid keypoints) | total: {total_time:.4f}s | pose extraction: {pose_extract_time:.4f}s | RKNN inference: 0.0000s")
                    return raw_result, formatted_result

                # Pose data is empty (for reasons other than insufficient keypoints)
                if pose_data is None:
                    total_time = time.time() - total_start
                    print(f"Inference aborted (pose data is empty) | total: {total_time:.4f}s | pose extraction: {pose_extract_time:.4f}s | RKNN inference: 0.0000s")
                    return None, None
            else:
                # Input is already extracted pose data, so no extraction is needed (keypoints assumed sufficient, check skipped)
                pose_data = data
                valid_keypoint_count = "unknown (input is pose data)"
                print(f"Input is pose data, skipping image processing | pose extraction: 0.0000s | valid keypoints: {valid_keypoint_count}")

            # Preprocess the data
            input_data = self._preprocess(pose_data)
            if input_data is None:
                total_time = time.time() - total_start
                print(f"Inference aborted (preprocessing failed) | total: {total_time:.4f}s | pose extraction: {pose_extract_time:.4f}s | RKNN inference: 0.0000s")
                return None, None

            # RKNN inference (timed separately)
            infer_start = time.time()
            outputs = self.rknn_lite.inference(inputs=[input_data])
            rknn_infer_time = time.time() - infer_start  # RKNN inference time
            raw_output = outputs[0]

            # Post-process the result
            raw_result = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
            formatted_result = self._format_result(raw_result)

            # Visualize the result
            if self.processed_image is not None and formatted_result:
                bgr_image = cv2.cvtColor(self.processed_image, cv2.COLOR_RGB2BGR)
                # Draw a red warning for an invalid (insufficient keypoints) result, otherwise draw the normal result
                if formatted_result["class"] == "null":
                    text = f"Insufficient keypoints: {valid_keypoint_count}/{self.min_valid_keypoints}"
                    cv2.putText(bgr_image, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
                else:
                    text = f"{formatted_result['class']} {formatted_result['confidence']:.1f}%"
                    cv2.putText(bgr_image, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
                cv2.imwrite(self.result_image_path, bgr_image)
                print(f"Result image saved to: {self.result_image_path} | {text}")

            # Compute total time and print the full summary
            total_time = time.time() - total_start
            print("Inference finished:")
            print(f"- total time: {total_time:.4f}s")
            print(f"- pose extraction time: {pose_extract_time:.4f}s")
            print(f"- RKNN inference time: {rknn_infer_time:.4f}s")
            print(f"- valid keypoints: {valid_keypoint_count}/{self.min_valid_keypoints}")

            # Attach the timing information to the formatted result so that pose.py can read it
            if formatted_result:
                formatted_result['pose_extract_time'] = pose_extract_time
                formatted_result['inference_time'] = rknn_infer_time
                formatted_result['total_time'] = total_time

        except Exception as e:
            # Timing is still recorded on error
            total_time = time.time() - total_start
            rknn_infer_time = 0.0  # inference did not run, so its time is 0
            print(f"Inference failed: {e}")
            print(f"- total time: {total_time:.4f}s")
            print(f"- pose extraction time: {pose_extract_time:.4f}s")
            print(f"- RKNN inference time: {rknn_infer_time:.4f}s")
            print(f"- valid keypoints: {valid_keypoint_count}/{self.min_valid_keypoints}")

        return raw_result, formatted_result

    def _format_result(self, predictions):
        """Result formatting logic."""
        predictions = np.array(predictions, dtype=np.float32)
        class_idx = np.argmax(predictions)

        # Build the formatted result
        confidence = float(predictions[class_idx] * 100)
        return {
            "class": self.classes[class_idx] if 0 <= class_idx < len(self.classes) else f"unknown_class_{class_idx}",
            "confidence": confidence,
            "probabilities": predictions.tolist(),
            "class_id": int(class_idx)
        }

    def show_result(self):
        """Display the saved result image."""
        try:
            result_image = cv2.imread(self.result_image_path)
            if result_image is not None:
                cv2.imshow("Pose RKNN Inference Result", result_image)
                print("Press any key to close the window...")
                cv2.waitKey(0)
                cv2.destroyAllWindows()
            else:
                if self.processed_image is None:
                    print("Cannot display result: the input was pose data, so no image was generated")
                else:
                    print("Result image not found, run inference with an image path first")
        except Exception as e:
            print(f"Failed to display result: {e}")

    def release(self):
        """Release RKNN resources."""
        if hasattr(self, "rknn_lite") and self.rknn_lite:
            self.rknn_lite.release()
            print("RKNN NPU resources released")

    def __del__(self):
        """Release resources automatically on destruction."""
        self.release()

    # Convenience interfaces
    def predict(self, image_path):
        """Run inference from an image path (synchronous interface)."""
        return self.inference(image_path)

    def predict_pose_data(self, pose_data):
        """Run inference from already extracted pose data (synchronous interface)."""
        return self.inference(pose_data)


# Real-time camera inference test
if __name__ == "__main__":
    # Configuration (replace with your own RKNN model path)
    RKNN_MODEL_PATH = "your_pose_model.rknn"
    CAMERA_INDEX = 0  # 0 = default camera

    # Initialize the pose inference workflow
    pose_workflow = PoseWorkflow(model_path=RKNN_MODEL_PATH)
    if not pose_workflow.rknn_lite:
        print("RKNN model initialization failed, cannot start real-time inference")
        exit(1)

    # Initialize the camera
    cap = cv2.VideoCapture(CAMERA_INDEX)
    if not cap.isOpened():
        print(f"Cannot open camera (index: {CAMERA_INDEX})")
        pose_workflow.release()
        exit(1)

    # Set the camera resolution
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    print("Real-time pose inference started!")
    print(f"Model path: {RKNN_MODEL_PATH}")
    print(f"Minimum valid keypoints: {pose_workflow.min_valid_keypoints}")
    print("Press 'q' to quit real-time inference...")

    # Read camera frames in a loop and run inference
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Cannot read camera frame, leaving the loop")
            break

        # Run frame inference (prints full timing automatically)
        _, _, result_frame = pose_workflow.inference_frame(frame)

        # Show the annotated frame
        cv2.imshow("RKNN Real-Time Pose Inference", result_frame)

        # Quit on 'q'
        if cv2.waitKey(1) & 0xFF == ord('q'):
            print("Exit requested by user")
            break

    # Release resources
    cap.release()
    cv2.destroyAllWindows()
    pose_workflow.release()
    print("Real-time inference finished, resources released")
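For quick reference, a minimal usage sketch of the PoseWorkflow class above, assuming this hunk is smartpi/rknn_pose_workflow.py (the only file in the listing with +592 lines). Every path, class name, and metadata value below is a placeholder: per _get_metadata_path(), the loader looks for "<model name>_rknn_metadata.json" (falling back to "rknn_metadata.json") next to the .rknn file, carrying the "classes" and "minMax" keys that _load_metadata() reads. In practice that file ships with the exported model; it is written out here only to show the expected layout.

# Illustrative sketch only; "models/pose.rknn", the class names, and the image path are made up.
import json
from smartpi.rknn_pose_workflow import PoseWorkflow

# Metadata file expected next to the model (keys mirror what _load_metadata() reads).
metadata = {
    "classes": ["standing", "sitting"],  # pose class labels (placeholders)
    "output_shape": [1, 2],              # one score per class
    "minMax": {"min": [], "max": []},    # per-feature ranges; empty here, so keypoints pass through unscaled
}
with open("models/pose_rknn_metadata.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f)

workflow = PoseWorkflow(model_path="models/pose.rknn")  # loads the model, NPU runtime, and metadata
raw, result = workflow.predict("person.jpg")            # single-image inference via the convenience interface
if result:
    print(result["class"], result["confidence"], result.get("total_time"))
workflow.show_result()  # displays the annotated result.jpg written by inference()
workflow.release()      # free the NPU explicitly instead of relying on __del__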