smartpi-0.1.38-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smartpi/__init__.py +8 -0
- smartpi/_gui.py +66 -0
- smartpi/base_driver.py +566 -0
- smartpi/camera.py +84 -0
- smartpi/color_sensor.py +18 -0
- smartpi/cw2015.py +179 -0
- smartpi/flash.py +130 -0
- smartpi/humidity.py +20 -0
- smartpi/led.py +19 -0
- smartpi/light_sensor.py +72 -0
- smartpi/motor.py +177 -0
- smartpi/move.py +218 -0
- smartpi/onnx_hand_workflow.py +201 -0
- smartpi/onnx_image_workflow.py +176 -0
- smartpi/onnx_pose_workflow.py +482 -0
- smartpi/onnx_text_workflow.py +173 -0
- smartpi/onnx_voice_workflow.py +437 -0
- smartpi/posemodel/__init__.py +0 -0
- smartpi/posemodel/posenet.tflite +0 -0
- smartpi/posenet_utils.py +222 -0
- smartpi/rknn_hand_workflow.py +245 -0
- smartpi/rknn_image_workflow.py +405 -0
- smartpi/rknn_pose_workflow.py +592 -0
- smartpi/rknn_text_workflow.py +240 -0
- smartpi/rknn_voice_workflow.py +394 -0
- smartpi/servo.py +178 -0
- smartpi/temperature.py +18 -0
- smartpi/text_gte_model/__init__.py +0 -0
- smartpi/text_gte_model/config/__init__.py +0 -0
- smartpi/text_gte_model/config/config.json +30 -0
- smartpi/text_gte_model/config/quantize_config.json +30 -0
- smartpi/text_gte_model/config/special_tokens_map.json +7 -0
- smartpi/text_gte_model/config/tokenizer.json +14924 -0
- smartpi/text_gte_model/config/tokenizer_config.json +23 -0
- smartpi/text_gte_model/config/vocab.txt +14760 -0
- smartpi/text_gte_model/gte/__init__.py +0 -0
- smartpi/text_gte_model/gte/gte_model.onnx +0 -0
- smartpi/touch_sensor.py +16 -0
- smartpi/trace.py +120 -0
- smartpi/ultrasonic.py +20 -0
- smartpi-0.1.38.dist-info/METADATA +17 -0
- smartpi-0.1.38.dist-info/RECORD +44 -0
- smartpi-0.1.38.dist-info/WHEEL +5 -0
- smartpi-0.1.38.dist-info/top_level.txt +1 -0
smartpi/move.py
ADDED
@@ -0,0 +1,218 @@
+# coding=utf-8
+import time
+from typing import List, Optional
+from smartpi import base_driver
+
+
+# Move at a given speed for x seconds. dir: forward, backward, turnright, turnleft; speed: 0~100; second: duration in seconds
+def run_second(dir: str, speed: int, second: int) -> Optional[int]:
+    move_str = [0xA0, 0x01, 0x11, 0x71, 0x00, 0x71, 0x00, 0x71, 0x00, 0xBE]
+
+    if dir == "forward":
+        move_str[4] = 0x01
+    elif dir == "backward":
+        move_str[4] = 0x02
+    elif dir == "turnright":
+        move_str[4] = 0x03
+    elif dir == "turnleft":
+        move_str[4] = 0x04
+
+    move_str[6] = speed
+    move_str[8] = second
+    time.sleep(0.005)
+    base_driver.write_data(0X01, 0X02, move_str)
+    # response = base_driver.single_operate_sensor(move_str, 0)
+    # if response == None:
+    #     return None
+    # else:
+    return 0
+
+# Move at a given speed through x degrees. dir: forward, backward, turnright, turnleft; speed: 0~100; angle: 0~65535
+def run_angle(dir: str, speed: int, angle: int) -> Optional[int]:
+    move_str = [0xA0, 0x01, 0x12, 0x71, 0x00, 0x71, 0x00, 0x81, 0x00, 0x00, 0xBE]
+
+    if dir == "forward":
+        move_str[4] = 0x01
+    elif dir == "backward":
+        move_str[4] = 0x02
+    elif dir == "turnright":
+        move_str[4] = 0x03
+    elif dir == "turnleft":
+        move_str[4] = 0x04
+
+    move_str[6] = speed
+    move_str[8] = angle // 256  # high byte of the 16-bit angle
+    move_str[9] = angle % 256   # low byte
+
+    time.sleep(0.005)
+    base_driver.write_data(0X01, 0X02, move_str)
+    # response = base_driver.single_operate_sensor(move_str, 0)
+    # if response == None:
+    #     return None
+    # else:
+    return 0
+
+# Move at a given speed. dir: forward, backward, turnright, turnleft; speed: 0~100
+def run(dir: str, speed: int) -> Optional[int]:
+    move_str = [0xA0, 0x01, 0x13, 0x71, 0x00, 0x71, 0x00, 0xBE]
+
+    if dir == "forward":
+        move_str[4] = 0x01
+    elif dir == "backward":
+        move_str[4] = 0x02
+    elif dir == "turnright":
+        move_str[4] = 0x03
+    elif dir == "turnleft":
+        move_str[4] = 0x04
+
+    move_str[6] = speed
+
+    time.sleep(0.005)
+    base_driver.write_data(0X01, 0X02, move_str)
+    # response = base_driver.single_operate_sensor(move_str, 0)
+    # if response == None:
+    #     return None
+    # else:
+    return 0
+
+# Set left/right wheel speeds and move for x seconds. Lspeed: -100~100; Rspeed: -100~100; second: 1~255
+def run_speed_second(Lspeed: int, Rspeed: int, second: int) -> Optional[int]:
+    move_str = [0xA0, 0x01, 0x14, 0x71, 0x00, 0x71, 0x00, 0x71, 0x00, 0xBE]
+
+    if Lspeed > 100:
+        m_par = 100
+    elif Lspeed >= 0 and Lspeed <= 100:
+        m_par = Lspeed
+    elif Lspeed < -100:
+        m_par = 156  # clamp to -100, encoded as 256 - 100
+    elif Lspeed <= 0 and Lspeed >= -100:
+        m_par = 256 + Lspeed  # negative speeds are sent as unsigned bytes (two's complement)
+
+    move_str[6] = m_par
+
+    if Rspeed > 100:
+        m_par = 100
+    elif Rspeed >= 0 and Rspeed <= 100:
+        m_par = Rspeed
+    elif Rspeed < -100:
+        m_par = 156
+    elif Rspeed <= 0 and Rspeed >= -100:
+        m_par = 256 + Rspeed
+
+    move_str[4] = m_par
+
+    move_str[8] = second
+
+    time.sleep(0.005)
+    base_driver.write_data(0X01, 0X02, move_str)
+    # response = base_driver.single_operate_sensor(move_str, 0)
+    # if response == None:
+    #     return None
+    # else:
+    return 0
+
+# Set left/right wheel speeds and move. Lspeed: -100~100; Rspeed: -100~100
+def run_speed(Lspeed: int, Rspeed: int) -> Optional[int]:
+    move_str = [0xA0, 0x01, 0x15, 0x71, 0x00, 0x71, 0x00, 0xBE]
+
+    if Lspeed > 100:
+        m_par = 100
+    elif Lspeed >= 0 and Lspeed <= 100:
+        m_par = Lspeed
+    elif Lspeed < -100:
+        m_par = 156
+    elif Lspeed <= 0 and Lspeed >= -100:
+        m_par = 256 + Lspeed
+
+    move_str[6] = m_par
+
+    if Rspeed > 100:
+        m_par = 100
+    elif Rspeed >= 0 and Rspeed <= 100:
+        m_par = Rspeed
+    elif Rspeed < -100:
+        m_par = 156
+    elif Rspeed <= 0 and Rspeed >= -100:
+        m_par = 256 + Rspeed
+
+    move_str[4] = m_par
+
+    time.sleep(0.005)
+    base_driver.write_data(0X01, 0X02, move_str)
+    # response = base_driver.single_operate_sensor(move_str, 0)
+    # if response == None:
+    #     return None
+    # else:
+    return 0
+
+# Set left/right wheel power and move. Lpower: 0~100; Rpower: 0~100
+def run_power(Lpower: int, Rpower: int) -> Optional[int]:
+    move_str = [0xA0, 0x01, 0x17, 0x71, 0x00, 0x71, 0x00, 0xBE]
+
+    move_str[4] = Rpower
+    move_str[6] = Lpower
+
+    time.sleep(0.005)
+    base_driver.write_data(0X01, 0X02, move_str)
+    # response = base_driver.single_operate_sensor(move_str, 0)
+    # if response == None:
+    #     return None
+    # else:
+    return 0
+
+# Set the maximum power of each motor. M1~M6: 0~100
+def set_maxpower(M1: int, M2: int, M3: int, M4: int, M5: int, M6: int) -> Optional[int]:
+    move_str = [0xA0, 0x01, 0x18, 0x71, 0x00, 0x71, 0x00, 0x71, 0x00, 0x71, 0x00, 0x71, 0x00, 0x71, 0x00, 0xBE]
+
+    move_str[4] = M1
+    move_str[6] = M2
+    move_str[8] = M3
+    move_str[10] = M4
+    move_str[12] = M5
+    move_str[14] = M6
+
+    time.sleep(0.005)
+    base_driver.write_data(0X01, 0X02, move_str)
+    # response = base_driver.single_operate_sensor(move_str, 0)
+    # if response == None:
+    #     return None
+    # else:
+    return 0
+
+# Stop the motors
+def stop() -> Optional[int]:
+    move_str = [0xA0, 0x01, 0x0A, 0xBE]
+
+    time.sleep(0.005)
+    base_driver.write_data(0X01, 0X02, move_str)
+    # response = base_driver.single_operate_sensor(move_str, 0)
+    # if response == None:
+    #     return None
+    # else:
+    return 0
+
+# Configure the drive motors. Lmotor: 1~6; Rmotor: 1~6; state: no_reversal, all_reversal, left_reversal, right_reversal
+def set_move_init(Lmotor: int, Rmotor: int, state: str) -> Optional[int]:
+    move_str = [0xA0, 0x01, 0x19, 0x71, 0x00, 0x71, 0x00, 0x71, 0x00, 0xBE]
+
+    if state == "no_reversal":
+        move_str[4] = 0x01
+    elif state == "all_reversal":
+        move_str[4] = 0x02
+    elif state == "left_reversal":
+        move_str[4] = 0x03
+    elif state == "right_reversal":
+        move_str[4] = 0x04
+
+    move_str[6] = Rmotor
+    move_str[8] = Lmotor
+
+    time.sleep(0.005)
+    base_driver.write_data(0X01, 0X02, move_str)
+    # response = base_driver.single_operate_sensor(move_str, 0)
+    # if response == None:
+    #     return None
+    # else:
+    return 0
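For orientation, a minimal usage sketch of this module; the direction strings and value ranges come from the comments above, while the concrete numbers are illustrative and running it assumes a robot base reachable through base_driver:

from smartpi import move

# run_second("forward", 50, 2) builds and sends the frame
# [0xA0, 0x01, 0x11, 0x71, 0x01, 0x71, 50, 0x71, 2, 0xBE]
move.run_second("forward", 50, 2)   # forward at speed 50 for 2 seconds
move.run_angle("turnleft", 40, 90)  # turn left at speed 40 through 90 degrees
move.run_speed(60, -60)             # spin in place; -60 is sent as 256 - 60 = 196
move.stop()                         # stop the motors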
smartpi/onnx_hand_workflow.py
ADDED
@@ -0,0 +1,201 @@
+import cv2
+import numpy as np
+import onnxruntime as ort
+import mediapipe as mp
+import json
+from PIL import Image
+import time  # used for timing measurements
+
+class GestureWorkflow:
+    def __init__(self, model_path):
+        # Initialize MediaPipe Hands
+        self.mp_hands = mp.solutions.hands
+        self.hands = self.mp_hands.Hands(
+            static_image_mode=False,  # video-stream mode; set to True when only extracting keypoints from still photos
+            max_num_hands=1,  # set to 2 to detect both hands
+            min_detection_confidence=0.5,  # confidence threshold for hand keypoints
+            model_complexity=0  # simplest model; if results are inaccurate, consider the more complex model 1
+        )
+
+        # Initialize metadata
+        self.min_vals = None
+        self.max_vals = None
+        self.class_labels = None
+
+        # Load the model and metadata
+        self.load_model(model_path)
+
+    def load_model(self, model_path):
+        """Load the model and parse its metadata."""
+        # Initialize the ONNX Runtime session
+        self.session = ort.InferenceSession(model_path)
+
+        # Load metadata
+        self._load_metadata()
+
+    def _load_metadata(self):
+        """Load normalization parameters and class labels from the ONNX model metadata."""
+        model_meta = self.session.get_modelmeta()
+
+        # Check whether custom_metadata_map exists
+        if hasattr(model_meta, 'custom_metadata_map'):
+            metadata = model_meta.custom_metadata_map
+            if 'minMaxValues' in metadata:
+                min_max_data = json.loads(metadata['minMaxValues'])
+                self.min_vals = min_max_data.get('min')
+                self.max_vals = min_max_data.get('max')
+
+            if 'classes' in metadata:
+                class_labels = json.loads(metadata['classes'])
+                self.class_labels = list(class_labels.values()) if isinstance(class_labels, dict) else class_labels
+        else:
+            # For older ONNX Runtime versions, use metadata_props
+            for prop in model_meta.metadata_props:
+                if prop.key == 'minMaxValues':
+                    min_max_data = json.loads(prop.value)
+                    self.min_vals = min_max_data.get('min')
+                    self.max_vals = min_max_data.get('max')
+                elif prop.key == 'classes':
+                    class_labels = json.loads(prop.value)
+                    self.class_labels = list(class_labels.values()) if isinstance(class_labels, dict) else class_labels
+
+        # Set default labels (thumbs up, thumbs down, victory, fist, I-love-you, palm)
+        if self.class_labels is None:
+            self.class_labels = ["点赞", "点踩", "胜利", "拳头", "我爱你", "手掌"]
+
+    def preprocess_image(self, image, target_width=224, target_height=224):
+        """
+        Preprocess the image: scale it while keeping the aspect ratio and center it on a canvas of the target size.
+        Returns the processed OpenCV image (BGR format).
+        """
+        # Convert the OpenCV image to PIL format
+        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        pil_image = Image.fromarray(image_rgb)
+
+        # Compute the scale factor
+        width, height = pil_image.size
+        scale = min(target_width / width, target_height / height)
+
+        # Compute the new size and position
+        new_width = int(width * scale)
+        new_height = int(height * scale)
+        x = (target_width - new_width) // 2
+        y = (target_height - new_height) // 2
+
+        # Create a white canvas and paste the resized image onto it
+        canvas = Image.new('RGB', (target_width, target_height), (255, 255, 255))
+        resized_image = pil_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+        canvas.paste(resized_image, (x, y))
+
+        # Convert back to OpenCV format
+        processed_image = np.array(canvas)
+        return cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR)
+
+    def extract_hand_keypoints(self, image):
+        """Extract hand keypoints from the image."""
+        # Convert the image to RGB and run MediaPipe on it
+        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        results = self.hands.process(image_rgb)
+
+        if results.multi_hand_landmarks:
+            # Use only the first detected hand (note: world-coordinate landmarks are read here)
+            landmarks = results.multi_hand_world_landmarks[0]
+
+            # Extract the keypoint coordinates
+            keypoints = []
+            for landmark in landmarks.landmark:
+                keypoints.extend([landmark.x, landmark.y, landmark.z])
+
+            return np.array(keypoints, dtype=np.float32)
+        return None
+
+    def normalize_keypoints(self, keypoints):
+        """Normalize the keypoint data."""
+        if self.min_vals is None or self.max_vals is None:
+            return keypoints  # no normalization parameters; return the raw data
+
+        normalized = []
+        for i, value in enumerate(keypoints):
+            if i < len(self.min_vals) and i < len(self.max_vals):
+                min_val = self.min_vals[i]
+                max_val = self.max_vals[i]
+                if max_val - min_val > 0:
+                    normalized.append((value - min_val) / (max_val - min_val))
+                else:
+                    normalized.append(0)
+            else:
+                normalized.append(value)
+
+        return np.array(normalized, dtype=np.float32)
+
+    def predict_frame(self, frame):
+        """Run gesture classification directly on an image frame."""
+        # Record the start time
+        start_time = time.time()
+        # Preprocess the image
+        processed_image = self.preprocess_image(frame, 224, 224)
+
+        # Extract keypoints
+        keypoints = self.extract_hand_keypoints(processed_image)
+        min_time = time.time()
+        hand_time = min_time - start_time
+        # print(f"Keypoint extraction time: {hand_time:.4f}s")
+        if keypoints is None:
+            return None, {"error": "no hand detected"}
+
+        # Normalize the keypoints
+        normalized_kps = self.normalize_keypoints(keypoints)
+
+        # Prepare the ONNX input
+        input_data = normalized_kps.reshape(1, -1).astype(np.float32)
+
+        # Run inference
+        input_name = self.session.get_inputs()[0].name
+        outputs = self.session.run(None, {input_name: input_data})
+        predictions = outputs[0][0]
+
+        # Get the prediction
+        class_id = np.argmax(predictions)
+        confidence = float(predictions[class_id])
+
+        # Get the class label
+        label = self.class_labels[class_id] if class_id < len(self.class_labels) else f"unknown class {class_id}"
+        end_time = time.time()
+        all_time = end_time - start_time
+        onnx_time = end_time - min_time
+        print(f"ONNX time: {onnx_time:.4f}s")
+        print(f"Total time: {all_time:.4f}s")
+        # Return both the raw and the formatted result
+        raw_result = predictions.tolist()
+        formatted_result = {
+            'class': label,
+            'confidence': confidence,
+            'class_id': class_id,
+            'probabilities': raw_result
+        }
+
+        return raw_result, formatted_result
+
+    # Keep the original method for backward compatibility
+    def predict(self, image_path):
+        """Run gesture classification on an image file."""
+        try:
+            # Read the image with PIL to avoid libpng version issues
+            pil_image = Image.open(image_path)
+            # Convert to RGB
+            rgb_image = pil_image.convert('RGB')
+            # Convert to a numpy array
+            image_array = np.array(rgb_image)
+            # Convert to BGR (the format OpenCV uses)
+            image = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
+
+            if image is None:
+                raise ValueError(f"Unable to read image: {image_path}")
+
+            return self.predict_frame(image)
+        except Exception as e:
+            # If PIL fails, fall back to cv2
+            image = cv2.imread(image_path)
+            if image is None:
+                raise ValueError(f"Unable to read image: {image_path}")
+            return self.predict_frame(image)
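A minimal usage sketch of this class; the class name and return shape are as defined above, the module path follows the file list (onnx_hand_workflow.py), and the model filename and webcam index are illustrative:

import cv2
from smartpi.onnx_hand_workflow import GestureWorkflow

wf = GestureWorkflow("gesture.onnx")  # hypothetical model path

cap = cv2.VideoCapture(0)   # default webcam
ok, frame = cap.read()
cap.release()
if ok:
    raw, result = wf.predict_frame(frame)
    if raw is None:
        print(result["error"])  # e.g. no hand detected
    else:
        print(result["class"], result["confidence"])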
smartpi/onnx_image_workflow.py
ADDED
@@ -0,0 +1,176 @@
+import onnxruntime as ort
+import numpy as np
+from PIL import Image
+import onnx
+import cv2
+import time
+
+class ImageWorkflow:
+    def __init__(self, model_path=None):
+        self.session = None
+        self.classes = []
+        self.metadata = {}
+        self.input_shape = [1, 224, 224, 3]  # default input shape
+
+        if model_path:
+            self.load_model(model_path)
+
+    def load_model(self, model_path):
+        """Load the model and parse its metadata."""
+        try:
+            # Read the ONNX metadata
+            onnx_model = onnx.load(model_path)
+            for meta in onnx_model.metadata_props:
+                self.metadata[meta.key] = meta.value
+
+            # Parse the class labels (note: eval executes whatever string the metadata contains)
+            if 'classes' in self.metadata:
+                self.classes = eval(self.metadata['classes'])
+
+            # Initialize the inference session
+            self.session = ort.InferenceSession(model_path)
+            self._parse_input_shape()
+
+        except Exception as e:
+            print(f"Model loading failed: {e}")
+
+    def _parse_input_shape(self):
+        """Automatically parse the input shape."""
+        input_info = self.session.get_inputs()[0]
+        shape = []
+        for dim in input_info.shape:
+            # Handle dynamic dimensions (replace them with 1)
+            shape.append(1 if isinstance(dim, str) or dim < 0 else int(dim))
+        self.input_shape = shape
+
+    def _preprocess(self, image_path):
+        """Standard preprocessing pipeline."""
+        try:
+            img = Image.open(image_path).convert("RGB")
+
+            # Get the target size (assumes shape [N, H, W, C])
+            _, target_h, target_w, _ = self.input_shape
+
+            # Resize
+            img = img.resize((target_w, target_h), Image.BILINEAR)
+
+            # Convert to a numpy array and normalize
+            img_array = np.array(img).astype(np.float32) / 255.0
+
+            # Add the batch dimension
+            return np.expand_dims(img_array, axis=0)
+
+        except Exception as e:
+            print(f"Image preprocessing failed: {e}")
+            return None
+
+    def inference(self, data, model_path=None):
+        """Run inference."""
+        if model_path and not self.session:
+            self.load_model(model_path)
+
+        input_data = self._preprocess(data)
+        if input_data is None:
+            return None, None
+
+        try:
+            # Run inference
+            outputs = self.session.run(None, {self.session.get_inputs()[0].name: input_data})
+            raw = outputs[0][0]  # assumes output shape [1, n_classes]
+
+            # Format the output
+            formatted = self._format_result(raw)
+
+            return raw, formatted
+
+        except Exception as e:
+            print(f"Inference failed: {e}")
+            return None, None
+
+    def inference_frame(self, frame_data, model_path=None):
+        """Run inference directly on frame data, without file I/O.
+        Returns: raw, formatted.
+        The formatted dict contains: class, confidence, probabilities, preprocess_time, inference_time.
+        """
+        if model_path and not self.session:
+            self.load_model(model_path)
+
+        # Measure the preprocessing time
+        preprocess_start = time.time()
+        input_data = self._preprocess_frame(frame_data)
+        preprocess_time = time.time() - preprocess_start
+
+        if input_data is None:
+            return None, None
+
+        try:
+            # Measure the inference time
+            inference_start = time.time()
+            # Run inference
+            outputs = self.session.run(None, {self.session.get_inputs()[0].name: input_data})
+            inference_time = time.time() - inference_start
+
+            raw = outputs[0][0]  # assumes output shape [1, n_classes]
+
+            # Format the output
+            formatted = self._format_result(raw)
+            # Add the timing information to the result
+            formatted['preprocess_time'] = preprocess_time
+            formatted['inference_time'] = inference_time
+
+            # Compute the total time
+            total_time = preprocess_time + inference_time
+            print(f"Frame inference time: {total_time:.4f}s - result: {formatted['class']} ({formatted['confidence']}%)")
+            return raw, formatted
+
+        except Exception as e:
+            print(f"Frame inference failed: {e}")
+            return None, None
+
+    def _preprocess_frame(self, frame_data):
+        """Preprocessing pipeline for frame data."""
+        try:
+            # Make sure the input is a numpy array
+            if not isinstance(frame_data, np.ndarray):
+                print("Error: frame data must be a numpy array")
+                return None
+
+            # Frames read by OpenCV are BGR; convert to RGB
+            img = cv2.cvtColor(frame_data, cv2.COLOR_BGR2RGB)
+
+            # Get the target size (assumes shape [N, H, W, C])
+            _, target_h, target_w, _ = self.input_shape
+
+            # Resize
+            img = cv2.resize(img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
+
+            # Convert to float and normalize
+            img_array = img.astype(np.float32) / 255.0
+
+            # Add the batch dimension
+            return np.expand_dims(img_array, axis=0)
+
+        except Exception as e:
+            print(f"Frame preprocessing failed: {e}")
+            return None
+
+    def _format_result(self, predictions):
+        """Produce the standardized output."""
+        class_idx = np.argmax(predictions)
+        confidence = int(predictions[class_idx] * 100)
+
+        return {
+            'class': self.classes[class_idx] if self.classes else str(class_idx),
+            'confidence': confidence,
+            'probabilities': predictions.tolist()
+        }
+
+# Usage example
+if __name__ == "__main__":
+    # Preload the model
+    model = ImageWorkflow("model.onnx")
+
+    # Run inference on frame data
+    # assume frame is a frame captured with cv2
+    # raw, res = model.inference_frame(frame)
+    # print(f"Result: {res['class']} ({res['confidence']}%)")