smartpi 0.1.34__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smartpi/__init__.py +1 -1
- smartpi/base_driver.py +1 -0
- smartpi/camera.py +84 -0
- smartpi/color_sensor.py +1 -0
- smartpi/humidity.py +1 -0
- smartpi/led.py +1 -0
- smartpi/light_sensor.py +11 -6
- smartpi/motor.py +11 -2
- smartpi/move.py +63 -46
- smartpi/onnx_hand_workflow.py +201 -0
- smartpi/onnx_image_workflow.py +176 -0
- smartpi/onnx_pose_workflow.py +482 -0
- smartpi/onnx_text_workflow.py +173 -0
- smartpi/onnx_voice_workflow.py +437 -0
- smartpi/posenet_utils.py +222 -0
- smartpi/rknn_hand_workflow.py +245 -0
- smartpi/rknn_image_workflow.py +405 -0
- smartpi/rknn_pose_workflow.py +592 -0
- smartpi/rknn_text_workflow.py +240 -0
- smartpi/rknn_voice_workflow.py +394 -0
- smartpi/servo.py +10 -0
- smartpi/temperature.py +1 -0
- smartpi/touch_sensor.py +1 -0
- smartpi/trace.py +18 -11
- smartpi/ultrasonic.py +1 -0
- {smartpi-0.1.34.dist-info → smartpi-0.1.36.dist-info}/METADATA +1 -1
- smartpi-0.1.36.dist-info/RECORD +32 -0
- smartpi-0.1.34.dist-info/RECORD +0 -20
- {smartpi-0.1.34.dist-info → smartpi-0.1.36.dist-info}/WHEEL +0 -0
- {smartpi-0.1.34.dist-info → smartpi-0.1.36.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import onnxruntime as ort
|
|
2
|
+
import numpy as np
|
|
3
|
+
from PIL import Image
|
|
4
|
+
import onnx
|
|
5
|
+
import cv2
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
class ImageWorkflow:
    """ONNX image-classification workflow.

    Loads a classifier (plus class labels embedded in the model's ONNX
    metadata) and runs inference either on an image file path or directly
    on an OpenCV BGR frame.  Assumes an input layout of [N, H, W, C]
    (default 1x224x224x3) and a first output of shape [1, n_classes].
    """

    def __init__(self, model_path=None):
        """Create the workflow, optionally loading the model immediately.

        Args:
            model_path: path to an .onnx file; when None, the model must be
                supplied later via ``load_model`` or the ``model_path``
                argument of an inference call.
        """
        self.session = None            # ort.InferenceSession once loaded
        self.classes = []              # class labels parsed from model metadata
        self.metadata = {}             # raw ONNX metadata_props key/value pairs
        self.input_shape = [1, 224, 224, 3]  # default input shape [N, H, W, C]

        if model_path:
            self.load_model(model_path)

    def load_model(self, model_path):
        """Load the model and parse its metadata (including class labels).

        Errors are reported on stdout; the instance is left partially
        initialised if loading fails.
        """
        import ast  # local import: only needed while parsing metadata

        try:
            # Read the key/value metadata strings embedded in the ONNX file.
            onnx_model = onnx.load(model_path)
            for meta in onnx_model.metadata_props:
                self.metadata[meta.key] = meta.value

            # Parse the class-label list.  ast.literal_eval replaces the
            # original eval() so a crafted 'classes' metadata string cannot
            # execute arbitrary code; it still accepts any Python list literal.
            if 'classes' in self.metadata:
                self.classes = ast.literal_eval(self.metadata['classes'])

            # Initialise the inference session and derive the input shape.
            self.session = ort.InferenceSession(model_path)
            self._parse_input_shape()

        except Exception as e:
            print(f"模型加载失败: {e}")

    def _parse_input_shape(self):
        """Derive ``self.input_shape`` from the session's first input.

        Dynamic/symbolic dimensions (strings or negative values) are
        replaced with 1.
        """
        input_info = self.session.get_inputs()[0]
        shape = []
        for dim in input_info.shape:
            # Dynamic dimensions are substituted with 1 (batch of one).
            shape.append(1 if isinstance(dim, str) or dim < 0 else int(dim))
        self.input_shape = shape

    def _preprocess(self, image_path):
        """Standard preprocessing for an image file.

        Returns a float32 array of shape ``self.input_shape`` scaled to
        [0, 1], or None on failure.
        """
        try:
            img = Image.open(image_path).convert("RGB")

            # Target size, assuming an [N, H, W, C] layout.
            _, target_h, target_w, _ = self.input_shape

            # Resize to the model's expected resolution.
            img = img.resize((target_w, target_h), Image.BILINEAR)

            # Convert to a normalised float array.
            img_array = np.array(img).astype(np.float32) / 255.0

            # Prepend the batch dimension.
            return np.expand_dims(img_array, axis=0)

        except Exception as e:
            print(f"图像预处理失败: {e}")
            return None

    def inference(self, data, model_path=None):
        """Run inference on an image file path.

        Args:
            data: path to the image file.
            model_path: optional model to load lazily if none is loaded yet.

        Returns:
            (raw, formatted) where ``raw`` is the probability vector and
            ``formatted`` is the dict built by ``_format_result``; both are
            None on failure.
        """
        if model_path and not self.session:
            self.load_model(model_path)

        # NOTE(review): if the session is still None here, the failure is
        # only caught by the except below — consider an explicit guard.
        input_data = self._preprocess(data)
        if input_data is None:
            return None, None

        try:
            # Run the model; outputs[0] is assumed to be [1, n_classes].
            outputs = self.session.run(None, {self.session.get_inputs()[0].name: input_data})
            raw = outputs[0][0]

            # Build the standardised result dict.
            formatted = self._format_result(raw)

            return raw, formatted

        except Exception as e:
            print(f"推理失败: {e}")
            return None, None

    def inference_frame(self, frame_data, model_path=None):
        """Run inference directly on a frame (no file I/O).

        Args:
            frame_data: BGR frame as a numpy array (as produced by cv2).
            model_path: optional model to load lazily if none is loaded yet.

        Returns:
            (raw, formatted); ``formatted`` additionally carries
            ``preprocess_time`` and ``inference_time`` in seconds.
            Both are None on failure.
        """
        if model_path and not self.session:
            self.load_model(model_path)

        # Measure preprocessing time separately from inference time.
        preprocess_start = time.time()
        input_data = self._preprocess_frame(frame_data)
        preprocess_time = time.time() - preprocess_start

        if input_data is None:
            return None, None

        try:
            inference_start = time.time()
            outputs = self.session.run(None, {self.session.get_inputs()[0].name: input_data})
            inference_time = time.time() - inference_start

            raw = outputs[0][0]  # assumed output shape [1, n_classes]

            formatted = self._format_result(raw)
            # Attach timing information for callers that display stats.
            formatted['preprocess_time'] = preprocess_time
            formatted['inference_time'] = inference_time

            total_time = preprocess_time + inference_time
            print(f"帧推理耗时: {total_time:.4f}秒 - 识别结果: {formatted['class']} ({formatted['confidence']}%)")
            return raw, formatted

        except Exception as e:
            print(f"帧数据推理失败: {e}")
            return None, None

    def _preprocess_frame(self, frame_data):
        """Preprocess an in-memory BGR frame to model input, or None on error."""
        try:
            # Reject anything that is not already a numpy array.
            if not isinstance(frame_data, np.ndarray):
                print("错误: 帧数据必须是numpy数组")
                return None

            # cv2 frames are BGR; the model expects RGB.
            img = cv2.cvtColor(frame_data, cv2.COLOR_BGR2RGB)

            # Target size, assuming an [N, H, W, C] layout.
            _, target_h, target_w, _ = self.input_shape

            img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

            # Normalise to [0, 1] float32.
            img_array = img.astype(np.float32) / 255.0

            # Prepend the batch dimension.
            return np.expand_dims(img_array, axis=0)

        except Exception as e:
            print(f"帧数据预处理失败: {e}")
            return None

    def _format_result(self, predictions):
        """Build the standardised result dict from a probability vector.

        Returns:
            dict with 'class' (label or stringified index), 'confidence'
            (int percentage, truncated), and 'probabilities' (full list).
        """
        class_idx = np.argmax(predictions)
        confidence = int(predictions[class_idx] * 100)

        return {
            'class': self.classes[class_idx] if self.classes else str(class_idx),
            'confidence': confidence,
            'probabilities': predictions.tolist()
        }
|
|
167
|
+
|
|
168
|
+
# 使用示例
|
|
169
|
+
if __name__ == "__main__":
    # Demo: preload a classifier from disk.
    workflow = ImageWorkflow("model.onnx")

    # Frame-based inference example (frame obtained via cv2):
    # raw, res = workflow.inference_frame(frame)
    # print(f"识别结果: {res['class']} ({res['confidence']}%)")
|
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
import onnxruntime as ort
|
|
2
|
+
import numpy as np
|
|
3
|
+
from PIL import Image
|
|
4
|
+
import onnx
|
|
5
|
+
import json
|
|
6
|
+
import cv2
|
|
7
|
+
import time # 用于实时耗时计算
|
|
8
|
+
from lib.posenet_utils import get_posenet_output
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PoseWorkflow:
    """ONNX pose-classification workflow.

    Pipeline: image path or cv2 frame -> PoseNet keypoint extraction
    (``get_posenet_output``) -> normalisation/reshape -> ONNX classifier ->
    formatted result dict with class, confidence, probabilities and timings.
    """

    def __init__(self, model_path=None):
        """Create the workflow, optionally loading the model immediately."""
        self.session = None                     # ort.InferenceSession once loaded
        self.classes = []                       # class labels from model metadata
        self.metadata = {}                      # raw ONNX metadata_props
        self.input_shape = []                   # classifier input shape (dynamic dims -> 1)
        self.output_shape = []                  # classifier output shape
        self.result_image_path = "result.jpg"   # annotated result image destination
        self.processed_image = None             # last 257x257 RGB image fed to PoseNet

        if model_path:
            self.load_model(model_path)

    def load_model(self, model_path):
        """Load the ONNX model, its metadata and class labels."""
        import ast  # local import: only needed while parsing metadata

        try:
            onnx_model = onnx.load(model_path)
            for meta in onnx_model.metadata_props:
                self.metadata[meta.key] = meta.value

            # ast.literal_eval replaces the original eval(): the class list is
            # data, not code, so a crafted metadata string must not execute.
            if 'classes' in self.metadata:
                self.classes = ast.literal_eval(self.metadata['classes'])

            self.session = ort.InferenceSession(model_path)
            self._parse_input_output_shapes()
            print(f"ONNX模型加载完成:{model_path}")
        except Exception as e:
            print(f"模型加载失败: {e}")

    def _parse_input_output_shapes(self):
        """Cache the session's first input/output shapes (dynamic dims -> 1)."""
        input_info = self.session.get_inputs()[0]
        self.input_shape = self._process_shape(input_info.shape)

        output_info = self.session.get_outputs()[0]
        self.output_shape = self._process_shape(output_info.shape)

    def _process_shape(self, shape):
        """Return ``shape`` with symbolic/negative dims replaced by 1."""
        processed = []
        for dim in shape:
            processed.append(1 if isinstance(dim, str) or dim < 0 else int(dim))
        return processed

    def inference(self, data, model_path=None):
        """Classify a pose from an image path or pre-extracted pose data.

        Args:
            data: image file path (str) or an already-extracted pose vector.
            model_path: optional model to load lazily if none is loaded yet.

        Returns:
            (raw, formatted).  On success ``formatted`` carries timing keys
            (``pose_extract_time``, ``inference_time``, ``total_time``); on
            errors it may be None or an 'error'/'null' placeholder dict.
        """
        total_start = time.time()
        self.processed_image = None
        pose_data = None
        pose_extract_time = 0.0  # stays 0 when the input is already pose data

        if model_path and not self.session:
            self.load_model(model_path)
        if not self.session:
            print("推理失败:ONNX模型未初始化")
            total_time = time.time() - total_start
            print(f"推理终止 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
            return None, None

        if isinstance(data, str):
            # Input is an image path: extract pose data (with timing).
            pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_image(data)
            # Bail out early when too few keypoints were detected.
            if valid_keypoint_count < 1:
                raw = np.zeros(len(self.classes)) if self.classes else np.array([])
                formatted = {
                    'class': 'null',
                    'confidence': 0.0,
                    'probabilities': raw.tolist()
                }
                total_time = time.time() - total_start
                print(f"推理终止(有效关键点不足1个:{valid_keypoint_count}) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
                return raw, formatted

            if pose_data is None or self.processed_image is None:
                total_time = time.time() - total_start
                print(f"推理终止(姿态数据为空) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
                return None, None
        else:
            # Input is already pose data: skip image processing entirely.
            pose_data = data
            print("提示:输入为姿态数据,跳过图像处理 | 姿态提取耗时:0.0000s")

        # Normalise/reshape the pose vector for the classifier.
        input_data = self._preprocess(pose_data)
        if input_data is None:
            total_time = time.time() - total_start
            print(f"推理终止(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:0.0000s")
            return None, None

        try:
            # Core ONNX inference, timed separately.
            input_name = self.session.get_inputs()[0].name
            infer_start = time.time()
            outputs = self.session.run(None, {input_name: input_data})
            onnx_infer_time = time.time() - infer_start

            # Post-process: flatten the first output to a probability vector.
            raw_output = outputs[0]
            raw = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
            formatted = self._format_result(raw)

            # Visualise and save the annotated result image (image-path input only).
            if self.processed_image is not None and formatted:
                bgr_image = cv2.cvtColor(self.processed_image, cv2.COLOR_RGB2BGR)
                text = f"{formatted['class']} {formatted['confidence']:.1f}%"
                cv2.putText(
                    img=bgr_image,
                    text=text,
                    org=(10, 30),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.8,
                    color=(0, 255, 0),
                    thickness=2
                )
                cv2.imwrite(self.result_image_path, bgr_image)
                print(f"结果图已保存至: {self.result_image_path}")

            # Report the full timing breakdown.
            total_time = time.time() - total_start
            print(f"推理完成:")
            print(f"- 总耗时:{total_time:.4f}秒")
            print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
            print(f"- ONNX推理耗时:{onnx_infer_time:.4f}秒")

            # Attach the timings to the returned result.
            formatted['pose_extract_time'] = pose_extract_time
            formatted['inference_time'] = onnx_infer_time
            formatted['total_time'] = total_time

            return raw, formatted
        except Exception as e:
            # Still account for elapsed time on failure.
            total_time = time.time() - total_start
            onnx_infer_time = 0.0
            print(f"推理失败: {e}")
            print(f"- 总耗时:{total_time:.4f}秒")
            print(f"- 姿态提取耗时:{pose_extract_time:.4f}秒")
            print(f"- ONNX推理耗时:{onnx_infer_time:.4f}秒")

            # Return a placeholder result carrying the error and timings.
            raw = np.zeros(len(self.classes)) if self.classes else np.array([])
            formatted = {
                'class': 'error',
                'confidence': 0.0,
                'probabilities': raw.tolist(),
                'error': str(e),
                'pose_extract_time': pose_extract_time,
                'inference_time': onnx_infer_time,
                'total_time': total_time
            }
            return raw, formatted

    def _get_pose_from_image(self, image_path):
        """Extract pose data from an image file.

        Returns:
            (pose_array, processed_image, pose_extract_time, valid_keypoint_count);
            pose_array/processed_image are None on failure.
        """
        pose_extract_start = time.time()
        valid_keypoint_count = 0
        try:
            print(f"正在处理图像: {image_path}")
            img = Image.open(image_path).convert("RGB")
            # 257x257 is PoseNet's expected input resolution.
            target_h, target_w = 257, 257
            img_resized = img.resize((target_w, target_h), Image.BILINEAR)
            processed_image = np.array(img_resized, dtype=np.uint8)

            # Run PoseNet; also returns how many keypoints passed the threshold.
            pose_data, has_pose, valid_keypoint_count = get_posenet_output(image_path)

            # NOTE(review): the image path requires >= 3 keypoints while the
            # frame path requires >= 1 — confirm this asymmetry is intended.
            if valid_keypoint_count < 3:
                pose_extract_time = time.time() - pose_extract_start
                print(f"有效关键点数量不足({valid_keypoint_count} < 3)")
                return None, processed_image, pose_extract_time, valid_keypoint_count

            if pose_data is None:
                pose_extract_time = time.time() - pose_extract_start
                print(f"无法从图像中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s")
                return None, None, pose_extract_time, valid_keypoint_count

            pose_array = self._parse_pose_data(pose_data)
            pose_extract_time = time.time() - pose_extract_start
            print(f"图像姿态提取完成 | 有效关键点:{valid_keypoint_count} | 姿态提取耗时:{pose_extract_time:.4f}s")
            return pose_array, processed_image, pose_extract_time, valid_keypoint_count
        except Exception as e:
            pose_extract_time = time.time() - pose_extract_start
            print(f"获取姿态数据失败: {e} | 姿态提取耗时:{pose_extract_time:.4f}s")
            return None, None, pose_extract_time, valid_keypoint_count

    def _parse_pose_data(self, pose_data):
        """Normalise PoseNet output (JSON string or array-like) to float32 ndarray."""
        if isinstance(pose_data, str):
            try:
                return np.array(json.loads(pose_data), dtype=np.float32)
            except json.JSONDecodeError:
                print("无法解析PoseNet输出")
                return None
        else:
            return np.array(pose_data, dtype=np.float32)

    def _get_pose_from_frame(self, frame):
        """Extract pose data from a cv2 BGR frame.

        Returns the same 4-tuple as ``_get_pose_from_image``; on missing
        keypoints it substitutes a zero vector instead of None so live
        inference can continue.
        """
        pose_extract_start = time.time()
        valid_keypoint_count = 0
        try:
            # cv2 frames are BGR; PoseNet expects RGB.
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Same 257x257 resolution as the image path.
            target_h, target_w = 257, 257
            img_resized = cv2.resize(
                img_rgb,
                (target_w, target_h),
                interpolation=cv2.INTER_LINEAR
            )
            processed_image = img_resized.astype(np.uint8)

            pose_data, has_pose, valid_keypoint_count = get_posenet_output(processed_image)

            if valid_keypoint_count < 1:
                pose_extract_time = time.time() - pose_extract_start
                print(f"有效关键点数量不足({valid_keypoint_count} < 1)")
                # 1439 is presumably the flattened pose-feature length the
                # classifier expects — TODO confirm against the model.
                return np.zeros(1439, dtype=np.float32), processed_image, pose_extract_time, valid_keypoint_count

            if pose_data is None:
                pose_extract_time = time.time() - pose_extract_start
                print(f"无法从帧中获取姿态数据 | 姿态提取耗时:{pose_extract_time:.4f}s")
                return np.zeros(1439, dtype=np.float32), None, pose_extract_time, valid_keypoint_count

            pose_array = self._parse_pose_data(pose_data)
            pose_extract_time = time.time() - pose_extract_start
            print(f"帧姿态提取完成 | 有效关键点:{valid_keypoint_count} | 姿态提取耗时:{pose_extract_time:.4f}s")
            return pose_array, processed_image, pose_extract_time, valid_keypoint_count
        except Exception as e:
            pose_extract_time = time.time() - pose_extract_start
            print(f"从帧获取姿态数据失败: {e} | 姿态提取耗时:{pose_extract_time:.4f}s")
            return None, None, pose_extract_time, valid_keypoint_count

    def inference_frame(self, frame_data, model_path=None):
        """Real-time frame inference with full timing stats.

        Returns:
            (raw, formatted, result_frame) — ``result_frame`` is a copy of
            the input frame, annotated with the result (or an error notice).
        """
        total_start = time.time()
        result_frame = frame_data.copy()
        self.processed_image = None
        pose_data = None
        pose_extract_time = 0.0
        onnx_infer_time = 0.0
        valid_keypoint_count = 0

        # Lazy model loading / session guard.
        if model_path and not self.session:
            self.load_model(model_path)
        if not self.session:
            print("帧推理失败:ONNX模型未初始化")
            total_time = time.time() - total_start
            print(f"帧推理终止 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
            return None, None, result_frame

        # Pose extraction from the frame (with timing and keypoint count).
        pose_data, self.processed_image, pose_extract_time, valid_keypoint_count = self._get_pose_from_frame(frame_data)

        # Too few keypoints: annotate the frame and return a 'null' result.
        if valid_keypoint_count < 1:
            raw = np.zeros(len(self.classes)) if self.classes else np.array([])
            formatted = {
                'class': 'null',
                'confidence': 0.0,
                'probabilities': raw.tolist()
            }
            total_time = time.time() - total_start
            result_frame = self._draw_insufficient_keypoints(result_frame, valid_keypoint_count)

            return raw, formatted, result_frame

        if pose_data is None or self.processed_image is None:
            total_time = time.time() - total_start
            print(f"帧推理跳过 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
            return None, None, result_frame

        # Normalise/reshape the pose vector for the classifier.
        input_data = self._preprocess(pose_data)
        if input_data is None:
            total_time = time.time() - total_start
            print(f"帧推理失败(预处理失败) | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
            return None, None, result_frame

        try:
            # Core ONNX inference, timed separately.
            input_name = self.session.get_inputs()[0].name
            infer_start = time.time()
            outputs = self.session.run(None, {input_name: input_data})
            onnx_infer_time = time.time() - infer_start

            # Post-process: flatten the first output to a probability vector.
            raw_output = outputs[0]
            raw = raw_output[0].flatten() if raw_output.ndim > 1 else raw_output.flatten()
            formatted = self._format_result(raw)

            # Annotate the frame with class/confidence and FPS (from total time).
            total_time = time.time() - total_start
            fps = 1.0 / total_time if total_time > 1e-6 else 0.0
            result_frame = self._draw_result_on_frame(result_frame, formatted, fps)

            print(f"帧推理完成 | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s | FPS:{fps:.1f} | 结果:{formatted['class']}")

            return raw, formatted, result_frame
        except Exception as e:
            total_time = time.time() - total_start
            onnx_infer_time = 0.0
            print(f"帧推理失败: {e} | 总耗时:{total_time:.4f}s | 姿态提取耗时:{pose_extract_time:.4f}s | ONNX推理耗时:{onnx_infer_time:.4f}s")
            return None, None, result_frame

    def _draw_insufficient_keypoints(self, frame, count):
        """Draw a red 'not enough keypoints' notice onto the frame."""
        text = f"有效关键点不足:{count}/1"
        cv2.putText(
            img=frame,
            text=text,
            org=(20, 40),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1.0,
            color=(0, 0, 255),  # red warning
            thickness=2,
            lineType=cv2.LINE_AA
        )
        return frame

    def _draw_result_on_frame(self, frame, formatted_result, fps):
        """Draw class, confidence and FPS (derived from total latency) on the frame."""
        # Class + confidence (green, top-left).
        class_text = f"Pose: {formatted_result['class']} ({formatted_result['confidence']:.1f}%)"
        cv2.putText(
            img=frame,
            text=class_text,
            org=(20, 40),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1.0,
            color=(0, 255, 0),
            thickness=2,
            lineType=cv2.LINE_AA
        )

        # FPS (blue, below the class line).
        fps_text = f"FPS: {fps:.1f}"
        cv2.putText(
            img=frame,
            text=fps_text,
            org=(20, 80),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=0.9,
            color=(255, 0, 0),
            thickness=2,
            lineType=cv2.LINE_AA
        )
        return frame

    def _preprocess(self, pose_data):
        """Normalise the pose vector and reshape it to the model input shape.

        Returns None on failure.  NOTE(review): when the element count does
        not match, np.resize pads by repeating data — confirm the classifier
        tolerates this.
        """
        try:
            if not isinstance(pose_data, np.ndarray):
                pose_data = np.array(pose_data, dtype=np.float32)
            normalized_data = self._normalize_pose_points(pose_data)
            input_size = np.prod(self.input_shape)
            if normalized_data.size != input_size:
                normalized_data = np.resize(normalized_data, self.input_shape)
            else:
                normalized_data = normalized_data.reshape(self.input_shape)
            return normalized_data
        except Exception as e:
            print(f"姿态数据预处理失败: {e}")
            return None

    def _normalize_pose_points(self, pose_points):
        """Scale coordinates into [0, 1] by the 257px PoseNet input size.

        The first half of the vector is treated as x coords, the second as
        y coords; each half is divided by 257 only if it has positive values.
        """
        normalized_points = pose_points.copy().astype(np.float32)
        mid = len(normalized_points) // 2
        if mid > 0:
            if np.max(normalized_points[:mid]) > 0:
                normalized_points[:mid] /= 257.0
            if np.max(normalized_points[mid:]) > 0:
                normalized_points[mid:] /= 257.0
        return normalized_points

    def _format_result(self, predictions):
        """Build the standardised result dict from a probability vector.

        HACK(review): when the spread between max and min probability exceeds
        0.05, this inflates the top probability to a random value in
        [0.9, ...] and rescales the others — the reported confidence is
        cosmetic, not the model's true output, and ``predictions`` is mutated
        in place.  Flagging for removal/confirmation.
        """
        class_idx = np.argmax(predictions)
        current_max = predictions[class_idx]

        if len(predictions) > 0:
            max_val = np.max(predictions)
            min_val = np.min(predictions)
            max_min_diff = max_val - min_val

            if max_min_diff < 0.05:
                pass  # near-uniform distribution: report as-is
            else:
                max_possible = min(1.0, current_max + (np.sum(predictions) - current_max))
                target_max = np.random.uniform(0.9, max_possible)
                if current_max < target_max:
                    needed = target_max - current_max
                    other_sum = np.sum(predictions) - current_max
                    if other_sum > 0:
                        scale_factor = (other_sum - needed) / other_sum
                        for i in range(len(predictions)):
                            if i != class_idx:
                                predictions[i] *= scale_factor
                        predictions[class_idx] = target_max

        confidence = float(predictions[class_idx] * 100)
        return {
            'class': self.classes[class_idx] if 0 <= class_idx < len(self.classes) else str(class_idx),
            'confidence': confidence,
            'probabilities': predictions.tolist()
        }

    def show_result(self):
        """Display the last saved annotated result image (blocking)."""
        try:
            result_image = cv2.imread(self.result_image_path)
            if result_image is not None:
                cv2.imshow("Pose Inference Result", result_image)
                cv2.waitKey(0)
                cv2.destroyAllWindows()
            else:
                if self.processed_image is None:
                    print("无法显示结果:输入为姿态数据,未生成图像")
                else:
                    print("未找到结果图片,请先执行图像路径输入的推理")
        except Exception as e:
            print(f"显示图片失败: {e}")
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# 实时推理测试示例(摄像头版)
|
|
443
|
+
if __name__ == "__main__":
    # 1. Initialise the model (replace with your own ONNX model path).
    MODEL_PATH = "your_pose_model.onnx"
    pose_workflow = PoseWorkflow(model_path=MODEL_PATH)
    if not pose_workflow.session:
        print("模型加载失败,无法启动实时推理")
        exit(1)

    # 2. Open the camera (0 = default device; try 1, 2, ... for others).
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("无法打开摄像头")
        exit(1)

    # Optional: pin the capture resolution (adjust to your hardware).
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    print("实时姿态推理启动!按 'q' 键退出...")

    # 3. Grab frames and run inference until the stream ends or 'q' is hit.
    while True:
        ret, frame = cap.read()
        if not ret:
            print("无法读取摄像头帧,退出循环")
            break

        # Per-frame inference (timing details are printed by the workflow).
        _, _, result_frame = pose_workflow.inference_frame(frame)

        # Show the annotated frame.
        cv2.imshow("Real-Time Pose Inference", result_frame)

        # 1 ms wait keeps the UI responsive; 'q' quits.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # 4. Release camera and windows.
    cap.release()
    cv2.destroyAllWindows()
    print("实时推理结束,资源已释放")
|