smartpi-0.1.38-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smartpi/__init__.py +8 -0
- smartpi/_gui.py +66 -0
- smartpi/base_driver.py +566 -0
- smartpi/camera.py +84 -0
- smartpi/color_sensor.py +18 -0
- smartpi/cw2015.py +179 -0
- smartpi/flash.py +130 -0
- smartpi/humidity.py +20 -0
- smartpi/led.py +19 -0
- smartpi/light_sensor.py +72 -0
- smartpi/motor.py +177 -0
- smartpi/move.py +218 -0
- smartpi/onnx_hand_workflow.py +201 -0
- smartpi/onnx_image_workflow.py +176 -0
- smartpi/onnx_pose_workflow.py +482 -0
- smartpi/onnx_text_workflow.py +173 -0
- smartpi/onnx_voice_workflow.py +437 -0
- smartpi/posemodel/__init__.py +0 -0
- smartpi/posemodel/posenet.tflite +0 -0
- smartpi/posenet_utils.py +222 -0
- smartpi/rknn_hand_workflow.py +245 -0
- smartpi/rknn_image_workflow.py +405 -0
- smartpi/rknn_pose_workflow.py +592 -0
- smartpi/rknn_text_workflow.py +240 -0
- smartpi/rknn_voice_workflow.py +394 -0
- smartpi/servo.py +178 -0
- smartpi/temperature.py +18 -0
- smartpi/text_gte_model/__init__.py +0 -0
- smartpi/text_gte_model/config/__init__.py +0 -0
- smartpi/text_gte_model/config/config.json +30 -0
- smartpi/text_gte_model/config/quantize_config.json +30 -0
- smartpi/text_gte_model/config/special_tokens_map.json +7 -0
- smartpi/text_gte_model/config/tokenizer.json +14924 -0
- smartpi/text_gte_model/config/tokenizer_config.json +23 -0
- smartpi/text_gte_model/config/vocab.txt +14760 -0
- smartpi/text_gte_model/gte/__init__.py +0 -0
- smartpi/text_gte_model/gte/gte_model.onnx +0 -0
- smartpi/touch_sensor.py +16 -0
- smartpi/trace.py +120 -0
- smartpi/ultrasonic.py +20 -0
- smartpi-0.1.38.dist-info/METADATA +17 -0
- smartpi-0.1.38.dist-info/RECORD +44 -0
- smartpi-0.1.38.dist-info/WHEEL +5 -0
- smartpi-0.1.38.dist-info/top_level.txt +1 -0
smartpi/onnx_voice_workflow.py
ADDED

@@ -0,0 +1,437 @@
+import numpy as np
+import onnxruntime as ort
+import librosa
+import onnx
+import time
+import os
+
+
+class Workflow:
+    def __init__(self, model_path=None, smoothing_time_constant=0, step_size=43):
+        self.model = None
+        self.classes = []
+        self.metadata = {}
+        self.model_params = {
+            'fft_size': 2048,
+            'sample_rate': 44100,
+            'num_frames': 43,  # frames per block
+            'spec_features': 232
+        }
+        self.global_mean = None
+        self.global_std = None
+        self.smoothing_time_constant = smoothing_time_constant
+        self.step_size = step_size
+        self.frame_duration = None
+        self.hop_length = 735  # 44100/60 = 735 (each frame lasts ~16.67 ms)
+        self.previous_spec = None
+
+        if model_path:
+            self.load_model(model_path)
+
+        # Compute frame timing information
+        self.frame_duration = self.hop_length / self.model_params['sample_rate']
+        self.block_duration = self.model_params['num_frames'] * self.frame_duration
+
+    def load_model(self, model_path):
+        """Load the model and parse its metadata."""
+        onnx_model = onnx.load(model_path)
+        for meta in onnx_model.metadata_props:
+            self.metadata[meta.key] = meta.value
+
+        if 'classes' in self.metadata:
+            self.classes = eval(self.metadata['classes'])
+
+        if 'global_mean' in self.metadata:
+            self.global_mean = np.array(eval(self.metadata['global_mean']))
+        if 'global_std' in self.metadata:
+            self.global_std = np.array(eval(self.metadata['global_std']))
+
+        self.session = ort.InferenceSession(model_path)
+        self.input_shape = self._get_fixed_shape(self.session.get_inputs()[0].shape)
+
+    def _get_fixed_shape(self, shape):
+        fixed = []
+        for dim in shape:
+            if isinstance(dim, str) or dim < 0:
+                fixed.append(1)
+            else:
+                fixed.append(int(dim))
+        return fixed
+
+    def _apply_hann_window(self, frame):
+        """Apply a Hann window to the frame."""
+        return frame * np.hanning(len(frame))
+
+    def _apply_temporal_smoothing(self, current_spec):
+        """Apply exponential smoothing along the time axis."""
+        if self.previous_spec is None:
+            self.previous_spec = current_spec
+            return current_spec
+
+        smoothed = (self.smoothing_time_constant * self.previous_spec
+                    + (1 - self.smoothing_time_constant) * current_spec)
+
+        self.previous_spec = smoothed.copy()
+        return smoothed
+
+    def _load_audio(self, audio_path):
+        """Load an audio file (.wav and .webm supported); return the audio array and sample rate."""
+        ext = os.path.splitext(audio_path)[1].lower()
+
+        if ext == '.wav':
+            # Load the wav file with librosa
+            audio, sr = librosa.load(audio_path, sr=self.model_params['sample_rate'])
+            return audio, sr
+
+        elif ext == '.webm':
+            # Load the webm file with pydub (requires ffmpeg)
+            try:
+                from pydub import AudioSegment
+            except ImportError:
+                raise ImportError("Handling webm audio requires pydub; install it first: pip install pydub")
+
+            try:
+                # Load the webm file
+                audio_segment = AudioSegment.from_file(audio_path, format='webm')
+
+                # Convert to mono
+                audio_segment = audio_segment.set_channels(1)
+
+                # Resample
+                audio_segment = audio_segment.set_frame_rate(self.model_params['sample_rate'])
+
+                # Convert to a numpy array in the range [-1, 1]
+                samples = np.array(audio_segment.get_array_of_samples(), dtype=np.float32)
+                samples = samples / 32768.0  # normalization for 16-bit audio
+
+                return samples, self.model_params['sample_rate']
+
+            except FileNotFoundError as e:
+                if 'ffmpeg' in str(e).lower() or 'avconv' in str(e).lower():
+                    print("\n" + "="*60)
+                    print("Error detected: ffmpeg is missing, so webm audio cannot be processed")
+                    print("="*60)
+                    print("ffmpeg is required for audio formats such as webm; install it by following these steps:\n")
+
+                    print("[Linux installation]")
+                    print("1. Ubuntu/Debian:")
+                    print("   sudo apt update")
+                    print("   sudo apt install ffmpeg\n")
+
+                    print("2. CentOS/RHEL:")
+                    print("   sudo yum install epel-release")
+                    print("   sudo yum install ffmpeg ffmpeg-devel\n")
+
+                    print("3. Fedora:")
+                    print("   sudo dnf install ffmpeg\n")
+
+                    print("4. Arch Linux:")
+                    print("   sudo pacman -S ffmpeg\n")
+
+                    print("[Windows installation]")
+                    print("1. Visit the ffmpeg download page: https://ffmpeg.org/download.html#build-windows")
+                    print("2. Recommended downloads:")
+                    print("   - From Gyan.dev: https://www.gyan.dev/ffmpeg/builds/")
+                    print("   - Choose the 'ffmpeg-release-essentials.zip' build")
+                    print("3. Unzip the downloaded archive to any directory (e.g. C:\\ffmpeg)")
+                    print("4. Configure the environment variables:")
+                    print("   - Right-click 'This PC' -> 'Properties' -> 'Advanced system settings' -> 'Environment Variables'")
+                    print("   - Under 'System variables', find 'Path' and click 'Edit'")
+                    print("   - Click 'New' and add the path to ffmpeg's bin directory (e.g. C:\\ffmpeg\\bin)")
+                    print("   - Click 'OK' in every window to save the settings")
+                    print("5. Verify: open a new command prompt and run 'ffmpeg -version'; if version info appears, the install succeeded\n")
+
+                    print("After installing, rerun the program.")
+                    print("="*60 + "\n")
+                    raise  # Re-raise the exception to terminate the program
+                else:
+                    raise  # Some other file-not-found error; propagate as usual
+            except Exception as e:
+                print(f"Another error occurred while processing webm audio: {str(e)}")
+                raise
+
+        else:
+            raise ValueError(f"Unsupported audio format: {ext}; only .wav and .webm are supported")
+
+    def _preprocess_audio(self, audio_path):
+        """Preprocess the whole audio file; return the decibel spectrogram."""
+        audio, sr = self._load_audio(audio_path)
+        assert sr == self.model_params['sample_rate'], f"Sample rate mismatch; {self.model_params['sample_rate']} Hz is required"
+
+        # Compute the STFT with the new parameters
+        hop_length = self.hop_length
+        win_length = self.model_params['fft_size']
+        n_fft = self.model_params['fft_size']
+
+        # Frame the signal manually and apply the window
+        frames = librosa.util.frame(audio, frame_length=win_length, hop_length=hop_length)
+        windowed_frames = np.zeros_like(frames)
+        for i in range(frames.shape[1]):
+            windowed_frames[:, i] = self._apply_hann_window(frames[:, i])
+
+        # Run the FFT
+        D = np.fft.rfft(windowed_frames, n=n_fft, axis=0)
+
+        # Magnitude spectrum, converted to decibels
+        magnitude = np.abs(D)
+        db = 20 * np.log10(np.maximum(1e-5, magnitude))
+
+        # Keep only the required feature bins and transpose
+        db = db[:self.model_params['spec_features'], :]
+        spec = db.T  # transpose to [time frames, frequency features]
+
+        return spec
+
+    def preprocess_audio_segment(self, audio_segment):
+        """Preprocess an audio segment (for real-time use); return the decibel spectrogram."""
+        # The audio is assumed to be mono at the correct sample rate
+        sr = self.model_params['sample_rate']
+
+        # Compute the STFT with the new parameters
+        hop_length = self.hop_length
+        win_length = self.model_params['fft_size']
+        n_fft = self.model_params['fft_size']
+
+        # Frame the signal manually and apply the window
+        frames = librosa.util.frame(audio_segment, frame_length=win_length, hop_length=hop_length)
+        windowed_frames = np.zeros_like(frames)
+        for i in range(frames.shape[1]):
+            windowed_frames[:, i] = self._apply_hann_window(frames[:, i])
+
+        # Run the FFT
+        D = np.fft.rfft(windowed_frames, n=n_fft, axis=0)
+
+        # Magnitude spectrum, converted to decibels
+        magnitude = np.abs(D)
+        db = 20 * np.log10(np.maximum(1e-5, magnitude))
+
+        # Keep only the required feature bins and transpose
+        db = db[:self.model_params['spec_features'], :]
+        spec = db.T  # transpose to [time frames, frequency features]
+
+        return spec
+
+    def _extract_blocks(self, full_spec):
+        """Extract blocks with the configured number of frames from the full spectrogram."""
+        total_frames = full_spec.shape[0]
+        blocks = []
+        start_indices = []
+
+        num_blocks = (total_frames - self.model_params['num_frames']) // self.step_size + 1
+
+        for i in range(num_blocks):
+            start = i * self.step_size
+            end = start + self.model_params['num_frames']
+
+            block = full_spec[start:end, :]
+
+            if block.shape[0] < self.model_params['num_frames']:
+                padded = np.zeros((self.model_params['num_frames'], self.model_params['spec_features']))
+                padded[:block.shape[0]] = block
+                block = padded
+
+            blocks.append(block)
+            start_indices.append(start)
+
+        return blocks, start_indices
+
+    def _normalize(self, spec):
+        """Normalize the spectrogram."""
+        epsilon = 1e-8
+        mean = np.mean(spec)
+        variance = np.var(spec)
+        std = np.sqrt(variance)
+        normalized = (spec - mean) / (std + epsilon)
+        return normalized.astype(np.float32)
+
+    def inference(self, audio_path, model_path=None):
+        if model_path and not hasattr(self, 'session'):
+            self.load_model(model_path)
+
+        full_spec = self._preprocess_audio(audio_path)
+        blocks, start_indices = self._extract_blocks(full_spec)
+
+        block_results = []
+
+        print(f"Processing audio: {audio_path}")
+        print(f"Total frames: {full_spec.shape[0]}, total duration: {full_spec.shape[0] * self.frame_duration:.2f}s")
+        print(f"Will process {len(blocks)} blocks ({self.model_params['num_frames']} frames each = {self.block_duration:.3f}s)")
+        print("=" * 60)
+
+        for i, block in enumerate(blocks):
+            start_time = time.time()
+
+            normalized_block = self._normalize(block)
+            input_tensor = normalized_block.flatten().reshape(self.input_shape)
+
+            input_name = self.session.get_inputs()[0].name
+            outputs = self.session.run(None, {input_name: input_tensor})
+
+            raw_output = outputs[0][0]
+            result = self._format_output(raw_output)
+
+            process_time = time.time() - start_time
+            start_frame = start_indices[i]
+            end_frame = start_frame + self.model_params['num_frames']
+            start_time_sec = start_frame * self.frame_duration
+            end_time_sec = end_frame * self.frame_duration
+
+            block_result = {
+                'block_index': i,
+                'start_frame': start_frame,
+                'end_frame': end_frame,
+                'start_time': start_time_sec,
+                'end_time': end_time_sec,
+                'process_time': process_time,
+                'result': result,
+                'raw_output': raw_output
+            }
+
+            block_results.append(block_result)
+
+            print(f"Block #{i+1} [time: {start_time_sec:.2f}-{end_time_sec:.2f}s]")
+            print(f"  Class: {result['class']}, confidence: {result['confidence']}%")
+            print(f"  Processing time: {process_time * 1000:.2f}ms")
+            print("-" * 50)
+
+        final_result = self._aggregate_results(block_results)
+        return block_results, final_result
+
+    def process_audio_segment(self, audio_segment):
+        """Process an audio segment (for real-time use), with timing measurements."""
+        if not hasattr(self, 'session'):
+            raise ValueError("Load a model first")
+
+        # Record the overall start time
+        total_start_time = time.time()
+
+        # Time the preprocessing
+        preprocess_start_time = time.time()
+        full_spec = self.preprocess_audio_segment(audio_segment)
+        blocks, start_indices = self._extract_blocks(full_spec)
+        preprocess_time = time.time() - preprocess_start_time
+
+        block_results = []
+        inference_time = 0.0
+
+        for i, block in enumerate(blocks):
+            # Time the normalization
+            normalize_start_time = time.time()
+            normalized_block = self._normalize(block)
+            input_tensor = normalized_block.flatten().reshape(self.input_shape)
+            normalize_time = time.time() - normalize_start_time
+
+            # Time the inference
+            inference_start_time = time.time()
+            input_name = self.session.get_inputs()[0].name
+            outputs = self.session.run(None, {input_name: input_tensor})
+            block_inference_time = time.time() - inference_start_time
+            inference_time += block_inference_time
+
+            raw_output = outputs[0][0]
+            result = self._format_output(raw_output)
+
+            start_frame = start_indices[i]
+            end_frame = start_frame + self.model_params['num_frames']
+            start_time_sec = start_frame * self.frame_duration
+            end_time_sec = end_frame * self.frame_duration
+
+            block_result = {
+                'block_index': i,
+                'start_frame': start_frame,
+                'end_frame': end_frame,
+                'start_time': start_time_sec,
+                'end_time': end_time_sec,
+                'result': result,
+                'raw_output': raw_output,
+                'normalize_time': normalize_time,
+                'inference_time': block_inference_time
+            }
+
+            block_results.append(block_result)
+
+        final_result = self._aggregate_results(block_results)
+
+        # Compute the total elapsed time
+        total_time = time.time() - total_start_time
+
+        # If there is a final result, attach the timing information
+        if final_result:
+            final_result['preprocess_time'] = preprocess_time
+            final_result['inference_time'] = inference_time
+            final_result['total_time'] = total_time
+
+        return block_results, final_result
+
+    def _format_output(self, predictions):
+        class_idx = np.argmax(predictions)
+        confidence = int(predictions[class_idx] * 100)
+        if len(self.classes) > 0:
+            label = self.classes[class_idx] if class_idx < len(self.classes) else "unknown"
+        else:
+            label = str(class_idx)
+        return {
+            'class': label,
+            'confidence': confidence,
+            'probabilities': predictions.tolist()
+        }
+
+    def _aggregate_results(self, block_results):
+        """Aggregate the results from all blocks."""
+        if len(block_results) == 2:
+            # With exactly two blocks, take the one with the highest confidence
+            max_confidence = -1
+            best_result = None
+            for result in block_results:
+                if result['result']['confidence'] > max_confidence:
+                    max_confidence = result['result']['confidence']
+                    best_result = result
+            return {
+                'class': best_result['result']['class'],
+                'confidence': best_result['result']['confidence'],
+                'occurrence_percentage': 100.0,
+                'total_blocks': len(block_results),
+                'best_raw_output': best_result['raw_output'],
+                'class_distribution': {best_result['result']['class']: 1},
+                'aggregation_method': 'highest_confidence'
+            }
+
+        # Normal case: count how often each class appears
+        class_counts = {}
+        max_confidence = {}
+
+        for result in block_results:
+            class_label = result['result']['class']
+            confidence = result['result']['confidence']
+
+            class_counts[class_label] = class_counts.get(class_label, 0) + 1
+            if class_label not in max_confidence or confidence > max_confidence[class_label]:
+                max_confidence[class_label] = confidence
+
+        if not class_counts:
+            return None
+
+        # Find the most frequent class
+        most_common = max(class_counts.items(), key=lambda x: x[1])
+        most_common_class = most_common[0]
+        count = most_common[1]
+        percentage = (count / len(block_results)) * 100
+        confidence = max_confidence[most_common_class]
+
+        # Find the highest-confidence raw output for that class
+        best_raw_output = None
+        for result in block_results:
+            if result['result']['class'] == most_common_class:
+                if best_raw_output is None or result['result']['confidence'] > best_raw_output['result']['confidence']:
+                    best_raw_output = result
+
+        return {
+            'class': most_common_class,
+            'confidence': confidence,
+            'occurrence_percentage': percentage,
+            'total_blocks': len(block_results),
+            'best_raw_output': best_raw_output['raw_output'] if best_raw_output else None,
+            'class_distribution': class_counts,
+            'aggregation_method': 'majority_vote'
+        }
+
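A minimal usage sketch for the Workflow class above, assuming a classifier exported with the metadata keys the loader reads ('classes', 'global_mean', 'global_std'); voice_model.onnx and sample.wav are hypothetical placeholders:

from smartpi.onnx_voice_workflow import Workflow

wf = Workflow(model_path='voice_model.onnx')       # hypothetical model file
block_results, final = wf.inference('sample.wav')  # hypothetical 44.1 kHz recording
if final:  # None when the audio is too short to yield any block
    print(final['class'], f"{final['confidence']}%", final['aggregation_method'])

Each block spans num_frames * hop_length / sample_rate = 43 * 735 / 44100 ≈ 0.72 s of audio, so recordings shorter than that produce no blocks.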
smartpi/posemodel/__init__.py
File without changes

smartpi/posemodel/posenet.tflite
Binary file
smartpi/posenet_utils.py
ADDED
@@ -0,0 +1,222 @@
+import tensorflow as tf
+import cv2
+import numpy as np
+import os
+import sys
+from pathlib import Path
+
+# Absolute path of the current script
+script_dir = os.path.dirname(os.path.abspath(__file__))
+# Project root (assumes posenet_utils.py lives in the pose/lib directory)
+project_root = os.path.abspath(os.path.join(script_dir, '..'))
+
+# Globals holding the model interpreter and related details
+_interpreter = None
+_input_details = None
+_output_details = None
+# Model path defined as an absolute path
+_MODEL_PATH = os.path.join(project_root, 'posemodel', 'posenet.tflite')  # default model path
+
+# Human-pose detection parameters (adjust as needed)
+POSE_THRESHOLD = 0.3  # score threshold for a single keypoint
+REQUIRED_KEYPOINTS = 3  # number of valid keypoints required to declare a person present
+# Indices of the key body joints (out of the 17 COCO keypoints)
+KEY_KEYPOINTS = [0, 1, 2, 3, 4, 5, 6, 7]  # head, neck, shoulders, elbows, and other key joints
+
+
+def _load_posenet_model(model_path):
+    """Internal: load the Posenet TFLite model."""
+    try:
+        interpreter = tf.lite.Interpreter(model_path=model_path)
+        interpreter.allocate_tensors()
+        input_details = interpreter.get_input_details()
+        output_details = interpreter.get_output_details()
+        return interpreter, input_details, output_details
+    except Exception as e:
+        raise FileNotFoundError(f"Failed to load model: {str(e)}")
+
+
+def _preprocess_image(image_path, input_size=(257, 257)):
+    """Internal: preprocess an image, matching the web-side logic."""
+    img = cv2.imread(image_path)
+    if img is None:
+        raise FileNotFoundError(f"Cannot read image: {image_path}")
+
+    # Convert to RGB
+    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+    return _preprocess_common(img_rgb, input_size)
+
+
+def _preprocess_frame(frame, input_size=(257, 257)):
+    """Internal: preprocess a video frame (numpy array)."""
+    # The input must be in BGR format (OpenCV's default)
+    if len(frame.shape) != 3 or frame.shape[2] != 3:
+        raise ValueError(f"Invalid frame format; expected a 3-channel BGR image, got {frame.shape}")
+
+    # Convert to RGB
+    img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+    return _preprocess_common(img_rgb, input_size)
+
+
+def _preprocess_common(img_rgb, input_size=(257, 257)):
+    """Shared preprocessing logic for images and frames."""
+    # Compute the scaling factor
+    scale = min(input_size[0]/img_rgb.shape[1], input_size[1]/img_rgb.shape[0])
+    scaled_width = int(img_rgb.shape[1] * scale)
+    scaled_height = int(img_rgb.shape[0] * scale)
+
+    # Resize the image (linear interpolation balances speed and quality)
+    img_scaled = cv2.resize(img_rgb, (scaled_width, scaled_height), interpolation=cv2.INTER_LINEAR)
+
+    # Create a 257x257 canvas and center the scaled image on it
+    img_padded = np.ones((input_size[1], input_size[0], 3), dtype=np.uint8) * 255
+    x_offset = (input_size[0] - scaled_width) // 2
+    y_offset = (input_size[1] - scaled_height) // 2
+    img_padded[y_offset:y_offset+scaled_height, x_offset:x_offset+scaled_width, :] = img_scaled
+
+    # Normalize
+    img_normalized = (img_padded.astype(np.float32) / 127.5) - 1.0
+
+    # Add the batch dimension
+    return np.expand_dims(img_normalized, axis=0)
+
+
+def _has_human_pose(heatmap_scores):
+    """Decide whether a human pose is present."""
+    # heatmap_scores has shape (height, width, num_keypoints)
+    num_keypoints = heatmap_scores.shape[2]
+
+    # Check that the key joint indices are valid
+    valid_keypoints = [k for k in KEY_KEYPOINTS if k < num_keypoints]
+    if not valid_keypoints:
+        return False, 0
+
+    # Compute each keypoint's maximum score (over the whole heatmap)
+    keypoint_max_scores = []
+    for k in valid_keypoints:
+        # Take the maximum score in the current keypoint's channel
+        max_score = np.max(heatmap_scores[..., k])
+        keypoint_max_scores.append(max_score)
+
+    # Count the keypoints whose score exceeds the threshold
+    valid_count = sum(1 for score in keypoint_max_scores if score >= POSE_THRESHOLD)
+
+    # Check whether the required count is reached
+    has_pose = valid_count >= REQUIRED_KEYPOINTS
+    return has_pose, valid_count
+
+
+def get_posenet_output(input_data, model_path=None, output_file=None,
+                       heatmap_file="heatmap.txt", offsets_file="offsets.txt", precision=6):
+    """
+    Run posenet on the input; supports an image path or a video frame (numpy array).
+
+    Parameters:
+        input_data: image file path (str) or video frame (numpy.ndarray, BGR format)
+        model_path: optional; model file path, defaults to _MODEL_PATH
+        output_file: optional; path for the concatenated output txt file, or None to skip saving
+        heatmap_file: optional; path for saving the heatmap data, or None to skip saving
+        offsets_file: optional; path for saving the offsets data, or None to skip saving
+        precision: number of decimal places when saving, default 6
+
+    Returns:
+        tuple (posenet_output, has_pose, valid_keypoint_count)
+        posenet_output: the processed one-dimensional array
+        has_pose: whether a human pose was detected (bool)
+        valid_keypoint_count: number of valid keypoints
+    """
+    global _interpreter, _input_details, _output_details, _MODEL_PATH
+
+    # Reload the model if a new path was given or no model is loaded yet
+    if model_path is not None or _interpreter is None:
+        model_to_load = model_path if model_path is not None else _MODEL_PATH
+        _interpreter, _input_details, _output_details = _load_posenet_model(model_to_load)
+
+    # Choose the preprocessing path based on the input type
+    if isinstance(input_data, str):
+        # An image path
+        input_tensor = _preprocess_image(input_data)
+    elif isinstance(input_data, np.ndarray):
+        # A video frame (numpy array)
+        input_tensor = _preprocess_frame(input_data)
+    else:
+        raise TypeError(f"Unsupported input type: {type(input_data)}; provide an image path or a numpy array")
+
+    # Run inference (reusing the global interpreter avoids repeated initialization)
+    _interpreter.set_tensor(_input_details[0]['index'], input_tensor)
+    _interpreter.invoke()
+
+    # Match the output tensors by name
+    output_dict = {}
+    for output in _output_details:
+        output_name = output['name']
+        output_tensor = _interpreter.get_tensor(output['index']).squeeze(axis=0)
+        output_dict[output_name] = output_tensor
+
+    # Extract the heatmap and offsets; apply sigmoid activation to the heatmap
+    heatmap = output_dict['MobilenetV1/heatmap_2/BiasAdd']
+    offsets = output_dict['MobilenetV1/offset_2/BiasAdd']
+
+    # Apply sigmoid to the heatmap, matching heatmapScores on the TFJS side
+    def sigmoid(x):
+        x = np.clip(x, -500, 500)  # clamp the input to prevent overflow in exp
+        return 1 / (1 + np.exp(-x))
+
+    # Activated heatmap scores (range [0, 1], consistent with the training data)
+    heatmap_scores = sigmoid(heatmap)
+
+    # Decide whether a human pose is present
+    has_pose, valid_count = _has_human_pose(heatmap_scores)
+
+    # Concatenate the activated heatmap and the offsets (same order as the TFJS side)
+    concatenated = np.concatenate([heatmap_scores, offsets], axis=2)
+    posenet_output = concatenated.astype(np.float32).flatten()
+
+    # Save the concatenated output (only when a path is given and the input is an image)
+    if output_file is not None and isinstance(input_data, str):
+        output_dir = Path(output_file).parent
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        with open(output_file, 'w', encoding='utf-8') as f:
+            for value in posenet_output:
+                f.write(f"{value:.{precision}f}\n")
+        print(f"Concatenated posenet output saved to: {output_file}")
+
+    return posenet_output, has_pose, valid_count
+
+
+# Companion loader: read data from a txt file saved line by line
+def load_posenet_output(txt_path):
+    """Load posenet_output from a txt file saved line by line."""
+    if not Path(txt_path).exists():
+        raise FileNotFoundError(f"File does not exist: {txt_path}")
+
+    with open(txt_path, 'r', encoding='utf-8') as f:
+        # Read all lines, skipping blanks, and convert to float
+        data = [float(line.strip()) for line in f if line.strip() and not line.strip().startswith('shape:')]
+
+    return np.array(data, dtype=np.float32)
+
+
+# Loader for heatmap or offsets data
+def load_posenet_component(txt_path):
+    """Load saved heatmap or offsets data, preserving the original shape."""
+    if not Path(txt_path).exists():
+        raise FileNotFoundError(f"File does not exist: {txt_path}")
+
+    with open(txt_path, 'r', encoding='utf-8') as f:
+        lines = [line.strip() for line in f if line.strip()]
+
+    # Parse the shape information
+    shape_line = next(line for line in lines if line.startswith('shape:'))
+    shape_str = shape_line.split('shape: ')[1].strip('()')
+    shape = tuple(map(int, shape_str.split(',')))
+
+    # Parse the data
+    data_lines = [line for line in lines if not line.startswith('shape:')]
+    data = np.array([float(line) for line in data_lines], dtype=np.float32)
+
+    # Reshape to the original shape
+    return data.reshape(shape)