smartpi 0.1.34__py3-none-any.whl → 0.1.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,222 @@
+ import tensorflow as tf
+ import cv2
+ import numpy as np
+ import os
+ import sys
+ from pathlib import Path
+
+ # Absolute path of the current script
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ # Project root directory (assumes posenet_utils.py lives under the pose/lib directory)
+ project_root = os.path.abspath(os.path.join(script_dir, '..'))
+
+ # Global state holding the model interpreter and its tensor details
+ _interpreter = None
+ _input_details = None
+ _output_details = None
+ # Default model path, defined as an absolute path
+ _MODEL_PATH = os.path.join(project_root, 'posemodel', 'posenet.tflite')
+
+ # Human-pose detection parameters (tune as needed)
+ POSE_THRESHOLD = 0.3  # score threshold for a single keypoint
+ REQUIRED_KEYPOINTS = 3  # number of valid keypoints required to report a person
+ # Indices of key body joints (from the 17 COCO keypoints)
+ KEY_KEYPOINTS = [0, 1, 2, 3, 4, 5, 6, 7]  # head, neck, shoulder, elbow and similar key joints
+
+
+ def _load_posenet_model(model_path):
+     """Internal: load the PoseNet TFLite model."""
+     try:
+         interpreter = tf.lite.Interpreter(model_path=model_path)
+         interpreter.allocate_tensors()
+         input_details = interpreter.get_input_details()
+         output_details = interpreter.get_output_details()
+         return interpreter, input_details, output_details
+     except Exception as e:
+         raise FileNotFoundError(f"Failed to load model: {str(e)}")
+
+
+ def _preprocess_image(image_path, input_size=(257, 257)):
+     """Internal: preprocess an image, mirroring the web-side logic."""
+     img = cv2.imread(image_path)
+     if img is None:
+         raise FileNotFoundError(f"Cannot read image: {image_path}")
+
+     # Convert to RGB
+     img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+     return _preprocess_common(img_rgb, input_size)
+
+
+ def _preprocess_frame(frame, input_size=(257, 257)):
+     """Internal: preprocess a video frame (numpy array)."""
+     # The input must be a 3-channel BGR image (OpenCV's default format)
+     if len(frame.shape) != 3 or frame.shape[2] != 3:
+         raise ValueError(f"Invalid frame format; expected a 3-channel BGR image, got {frame.shape}")
+
+     # Convert to RGB
+     img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+     return _preprocess_common(img_rgb, input_size)
+
+
+ def _preprocess_common(img_rgb, input_size=(257, 257)):
+     """Shared preprocessing logic for images and frames."""
+     # Compute the scale factor
+     scale = min(input_size[0] / img_rgb.shape[1], input_size[1] / img_rgb.shape[0])
+     scaled_width = int(img_rgb.shape[1] * scale)
+     scaled_height = int(img_rgb.shape[0] * scale)
+
+     # Resize the image (linear interpolation balances speed and quality)
+     img_scaled = cv2.resize(img_rgb, (scaled_width, scaled_height), interpolation=cv2.INTER_LINEAR)
+
+     # Create a 257x257 canvas and center the resized image on it
+     img_padded = np.ones((input_size[1], input_size[0], 3), dtype=np.uint8) * 255
+     x_offset = (input_size[0] - scaled_width) // 2
+     y_offset = (input_size[1] - scaled_height) // 2
+     img_padded[y_offset:y_offset+scaled_height, x_offset:x_offset+scaled_width, :] = img_scaled
+
+     # Normalize to [-1, 1]
+     img_normalized = (img_padded.astype(np.float32) / 127.5) - 1.0
+
+     # Add the batch dimension
+     return np.expand_dims(img_normalized, axis=0)
+
+
+ def _has_human_pose(heatmap_scores):
+     """Decide whether a human pose is present."""
+     # heatmap_scores has shape (height, width, num_keypoints)
+     num_keypoints = heatmap_scores.shape[2]
+
+     # Keep only the key-joint indices that are valid for this model
+     valid_keypoints = [k for k in KEY_KEYPOINTS if k < num_keypoints]
+     if not valid_keypoints:
+         return False, 0
+
+     # Maximum score of each keypoint across the whole heatmap
+     keypoint_max_scores = []
+     for k in valid_keypoints:
+         # Peak score of this keypoint's channel
+         max_score = np.max(heatmap_scores[..., k])
+         keypoint_max_scores.append(max_score)
+
+     # Count the keypoints whose score exceeds the threshold
+     valid_count = sum(1 for score in keypoint_max_scores if score >= POSE_THRESHOLD)
+
+     # Enough valid keypoints means a pose was found
+     has_pose = valid_count >= REQUIRED_KEYPOINTS
+     return has_pose, valid_count
+
+
+ def get_posenet_output(input_data, model_path=None, output_file=None,
+                        heatmap_file="heatmap.txt", offsets_file="offsets.txt", precision=6):
+     """
+     Run PoseNet on the input, which may be an image path or a video frame (numpy array).
+
+     Args:
+         input_data: image file path (str) or video frame (numpy.ndarray, BGR format)
+         model_path: optional; model file path, defaults to _MODEL_PATH
+         output_file: optional; path of the concatenated output txt file, None skips saving
+         heatmap_file: optional; path for saving heatmap data (currently unused in the body)
+         offsets_file: optional; path for saving offsets data (currently unused in the body)
+         precision: number of decimal places when saving, default 6
+
+     Returns:
+         Tuple (posenet_output, has_pose, valid_keypoint_count)
+         posenet_output: flattened 1-D output array
+         has_pose: whether a human pose was detected (bool)
+         valid_keypoint_count: number of valid keypoints
+     """
+     global _interpreter, _input_details, _output_details, _MODEL_PATH
+
+     # Reload the model if a new path was given or no model is loaded yet
+     if model_path is not None or _interpreter is None:
+         model_to_load = model_path if model_path is not None else _MODEL_PATH
+         _interpreter, _input_details, _output_details = _load_posenet_model(model_to_load)
+
+     # Choose the preprocessing routine based on the input type
+     if isinstance(input_data, str):
+         # Image path
+         input_tensor = _preprocess_image(input_data)
+     elif isinstance(input_data, np.ndarray):
+         # Video frame (numpy array)
+         input_tensor = _preprocess_frame(input_data)
+     else:
+         raise TypeError(f"Unsupported input type: {type(input_data)}; provide an image path or a numpy array")
+
+     # Run inference (reuse the global interpreter to avoid repeated initialization)
+     _interpreter.set_tensor(_input_details[0]['index'], input_tensor)
+     _interpreter.invoke()
+
+     # Match output tensors by name
+     output_dict = {}
+     for output in _output_details:
+         output_name = output['name']
+         output_tensor = _interpreter.get_tensor(output['index']).squeeze(axis=0)
+         output_dict[output_name] = output_tensor
+
+     # Extract the heatmap and offsets; the heatmap gets a sigmoid activation below
+     heatmap = output_dict['MobilenetV1/heatmap_2/BiasAdd']
+     offsets = output_dict['MobilenetV1/offset_2/BiasAdd']
+
+     # Sigmoid activation for the heatmap, matching heatmapScores on the TFJS side
+     def sigmoid(x):
+         x = np.clip(x, -500, 500)  # clamp the input to keep exp from overflowing
+         return 1 / (1 + np.exp(-x))
+
+     # Activated heatmap scores (range [0, 1], consistent with the training data)
+     heatmap_scores = sigmoid(heatmap)
+
+     # Check whether a human pose is present
+     has_pose, valid_count = _has_human_pose(heatmap_scores)
+
+     # Concatenate the activated heatmap and the offsets (same order as the TFJS side)
+     concatenated = np.concatenate([heatmap_scores, offsets], axis=2)
+     posenet_output = concatenated.astype(np.float32).flatten()
+
+     # Save the concatenated output (only when a path was given and the input is an image)
+     if output_file is not None and isinstance(input_data, str):
+         output_dir = Path(output_file).parent
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         with open(output_file, 'w', encoding='utf-8') as f:
+             for value in posenet_output:
+                 f.write(f"{value:.{precision}f}\n")
+         print(f"Concatenated posenet output saved to: {output_file}")
+
+     return posenet_output, has_pose, valid_count
+
+
+ # Companion loader: read data back from a txt file saved line by line
+ def load_posenet_output(txt_path):
+     """Load posenet_output from a txt file saved line by line."""
+     if not Path(txt_path).exists():
+         raise FileNotFoundError(f"File not found: {txt_path}")
+
+     with open(txt_path, 'r', encoding='utf-8') as f:
+         # Read all lines, skip blanks, and convert to float
+         data = [float(line.strip()) for line in f if line.strip() and not line.strip().startswith('shape:')]
+
+     return np.array(data, dtype=np.float32)
+
+
+ # Loader for heatmap or offsets data
+ def load_posenet_component(txt_path):
+     """Load saved heatmap or offsets data, restoring the original shape."""
+     if not Path(txt_path).exists():
+         raise FileNotFoundError(f"File not found: {txt_path}")
+
+     with open(txt_path, 'r', encoding='utf-8') as f:
+         lines = [line.strip() for line in f if line.strip()]
+
+     # Parse the shape header
+     shape_line = next(line for line in lines if line.startswith('shape:'))
+     shape_str = shape_line.split('shape: ')[1].strip('()')
+     shape = tuple(map(int, shape_str.split(',')))
+
+     # Parse the data lines
+     data_lines = [line for line in lines if not line.startswith('shape:')]
+     data = np.array([float(line) for line in data_lines], dtype=np.float32)
+
+     # Reshape to the original shape
+     return data.reshape(shape)
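
The file above exposes `get_posenet_output` as its entry point. A minimal usage sketch, assuming the file is importable as `posenet_utils` (the name its own path comment suggests) and that the image and output paths exist; both are illustrative placeholders:

    # Minimal sketch — 'person.jpg' and 'out/posenet_output.txt' are placeholder paths.
    from posenet_utils import get_posenet_output, load_posenet_output

    output, has_pose, valid_count = get_posenet_output(
        "person.jpg", output_file="out/posenet_output.txt"
    )
    print(f"pose detected: {has_pose}, valid keypoints: {valid_count}")

    # Round-trip: reload the flattened values saved above (rounded to `precision` decimals).
    reloaded = load_posenet_output("out/posenet_output.txt")
    assert reloaded.shape == output.shape
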
@@ -0,0 +1,245 @@
+ import cv2
+ import numpy as np
+ import mediapipe as mp
+ import json
+ from PIL import Image
+ import os
+ from rknnlite.api import RKNNLite  # RKNN Lite runtime for NPU inference
+ import time  # for timing measurements
+
+ class GestureWorkflow:
+     def __init__(self, model_path):
+         # Make sure model_path is absolute
+         self.model_path = os.path.abspath(model_path)
+
+         # Initialize MediaPipe Hands
+         self.mp_hands = mp.solutions.hands
+         self.hands = self.mp_hands.Hands(
+             static_image_mode=False,  # video-stream mode; set to True when extracting keypoints from still photos
+             max_num_hands=1,  # set to 2 to detect both hands
+             min_detection_confidence=0.5,  # confidence threshold for hand keypoints
+             model_complexity=0  # simplest model; if results are inaccurate, consider the more complex model (1)
+         )
+
+         # Initialize metadata
+         self.min_vals = None
+         self.max_vals = None
+         self.class_labels = None
+
+         # Load the model and its metadata
+         self.load_model(self.model_path)
+
+     def load_model(self, model_path):
+         """Load the RKNN model and parse its metadata."""
+         # Create the RKNNLite instance
+         self.rknn_lite = RKNNLite()
+
+         # Load the RKNN model
+         ret = self.rknn_lite.load_rknn(model_path)
+         if ret != 0:
+             raise RuntimeError(f'Failed to load RKNN model, error code: {ret}')
+
+         # Initialize the runtime (to pin a specific NPU core, pass core_mask=RKNNLite.NPU_CORE_0)
+         ret = self.rknn_lite.init_runtime()
+         if ret != 0:
+             raise RuntimeError(f'Failed to initialize the NPU runtime, error code: {ret}')
+
+         # Load metadata from the JSON file in the same directory
+         metadata_path = self._get_metadata_path(model_path)
+         self._load_metadata(metadata_path)
+
+     def _get_metadata_path(self, model_path):
+         """Return the absolute path of the metadata file."""
+         # First try a JSON file in the model's own directory
+         base_dir = os.path.dirname(model_path)
+         base_name = os.path.basename(model_path)
+         metadata_name = os.path.splitext(base_name)[0] + 'rknn_metadata.json'  # note: stem and suffix are joined with no separator
+         metadata_path = os.path.join(base_dir, metadata_name)
+
+         # Fall back to the default name if that file does not exist
+         if not os.path.exists(metadata_path):
+             metadata_path = os.path.join(base_dir, 'rknn_metadata.json')
+
+         return metadata_path
+
+     def _load_metadata(self, metadata_path):
+         """Load metadata from the JSON file."""
+         try:
+             with open(metadata_path, 'r', encoding='utf-8') as f:
+                 metadata = json.load(f)
+
+             self.class_labels = metadata.get('classes', ["点赞", "点踩", "胜利", "拳头", "我爱你", "手掌"])  # thumbs-up, thumbs-down, victory, fist, I-love-you, palm
+             min_max = metadata.get('minMax', {})
+             self.min_vals = min_max.get('min', [])
+             self.max_vals = min_max.get('max', [])
+
+             print(f"Metadata loaded successfully from {metadata_path}")
+             print(f"Class labels: {self.class_labels}")
+
+         except Exception as e:
+             print(f"Failed to load metadata: {e}")
+             # Fall back to the default labels (thumbs-up, thumbs-down, victory, fist, I-love-you, palm)
+             self.class_labels = ["点赞", "点踩", "胜利", "拳头", "我爱你", "手掌"]
+             self.min_vals = []
+             self.max_vals = []
+
+     def preprocess_image(self, image, target_width=224, target_height=224):
+         """
+         Preprocess an image: scale it with the aspect ratio preserved and center it on a
+         canvas of the target size. Returns the processed OpenCV image (BGR format).
+         """
+         # Convert the OpenCV image to PIL format
+         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         pil_image = Image.fromarray(image_rgb)
+
+         # Compute the scale factor
+         width, height = pil_image.size
+         scale = min(target_width / width, target_height / height)
+
+         # Compute the new size and placement
+         new_width = int(width * scale)
+         new_height = int(height * scale)
+         x = (target_width - new_width) // 2
+         y = (target_height - new_height) // 2
+
+         # Create a white canvas and paste the resized image onto it
+         canvas = Image.new('RGB', (target_width, target_height), (255, 255, 255))
+         resized_image = pil_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+         canvas.paste(resized_image, (x, y))
+
+         # Convert back to OpenCV format
+         processed_image = np.array(canvas)
+         return cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR)
+
+     def extract_hand_keypoints(self, image):
+         """Extract hand keypoints from an image."""
+         # Convert the image to RGB and run hand detection
+         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         results = self.hands.process(image_rgb)
+
+         if results.multi_hand_world_landmarks:
+             # Use only the first detected hand
+             landmarks = results.multi_hand_world_landmarks[0]
+
+             # Collect the keypoint coordinates
+             keypoints = []
+             for landmark in landmarks.landmark:
+                 keypoints.extend([landmark.x, landmark.y, landmark.z])
+
+             return np.array(keypoints, dtype=np.float32)
+         return None
+
+     def normalize_keypoints(self, keypoints):
+         """Min-max normalize the keypoint data."""
+         if not self.min_vals or not self.max_vals or len(self.min_vals) != len(keypoints):
+             # Without normalization parameters, or on a length mismatch, return the raw data
+             return keypoints
+
+         normalized = []
+         for i, value in enumerate(keypoints):
+             min_val = self.min_vals[i]
+             max_val = self.max_vals[i]
+             if max_val - min_val > 1e-6:  # avoid division by zero
+                 normalized.append((value - min_val) / (max_val - min_val))
+             else:
+                 normalized.append(0.0)
+
+         return np.array(normalized, dtype=np.float32)
+
+     def predict_frame(self, frame):
+         """Run gesture classification directly on an image frame."""
+         # Record the start time
+         start_time = time.time()
+
+         # Preprocess the image
+         processed_image = self.preprocess_image(frame, 224, 224)
+
+         # Extract hand keypoints
+         keypoints = self.extract_hand_keypoints(processed_image)
+         kp_end_time = time.time()
+         hand_time = kp_end_time - start_time
+         # print(f"Keypoint extraction took: {hand_time:.4f}s")
+         if keypoints is None:
+             # Record the end time and compute the elapsed time
+             end_time = time.time()
+             # print(f"Recognition took: {end_time - start_time:.4f}s (no hand detected)")
+             return None, {"error": "No hand detected", "processing_time": end_time - start_time}
+
+         # Normalize the keypoints
+         normalized_kps = self.normalize_keypoints(keypoints)
+
+         # Prepare the input data with shape (1, 63)
+         input_data = normalized_kps.reshape(1, -1).astype(np.float32)
+
+         # Run inference with RKNN Lite
+         try:
+             outputs = self.rknn_lite.inference(inputs=[input_data])
+             predictions = outputs[0][0]
+
+             # Pick the predicted class
+             class_id = int(np.argmax(predictions))
+             confidence = float(predictions[class_id])
+
+             # Look up the class label
+             label = self.class_labels[class_id] if class_id < len(self.class_labels) else f"Unknown class {class_id}"
+
+             # Return both the raw result and a formatted result
+             raw_result = predictions.tolist()
+             formatted_result = {
+                 'class': label,
+                 'confidence': confidence,
+                 'class_id': class_id,
+                 'probabilities': raw_result
+             }
+
+             # Record the end time and compute the elapsed times
+             end_time = time.time()
+             rknn_time = end_time - kp_end_time
+             processing_time = end_time - start_time
+             print(f"RKNN inference took: {rknn_time:.4f}s")
+             print(f"Total recognition took: {processing_time:.4f}s - result: {label} (confidence: {confidence:.2f})")
+
+             return raw_result, formatted_result
+
+         except Exception as e:
+             # Record the end time and compute the elapsed time
+             end_time = time.time()
+             print(f"Inference failed: {e}, elapsed: {end_time - start_time:.4f}s")
+             return None, {"error": f"Inference failed: {str(e)}", "processing_time": end_time - start_time}
+
+     def release(self):
+         """Release resources."""
+         if hasattr(self, 'rknn_lite'):
+             self.rknn_lite.release()
+             print("NPU resources released")
+
+     def __del__(self):
+         """Destructor: release resources automatically."""
+         self.release()
+
+     # Kept for backward compatibility with older code
+     def predict(self, image_path):
+         """Run gesture classification from a file path."""
+         # Make sure the image path is absolute
+         absolute_image_path = os.path.abspath(image_path)
+
+         try:
+             # Read the image with PIL to avoid libpng version issues
+             pil_image = Image.open(absolute_image_path)
+             # Convert to RGB
+             rgb_image = pil_image.convert('RGB')
+             # Convert to a numpy array
+             image_array = np.array(rgb_image)
+             # Convert to BGR (the format OpenCV uses)
+             image = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
+
+             if image is None:
+                 raise ValueError(f"Cannot read image: {absolute_image_path}")
+
+             return self.predict_frame(image)
+         except Exception as e:
+             # Fall back to cv2 if PIL fails
+             image = cv2.imread(absolute_image_path)
+             if image is None:
+                 raise ValueError(f"Cannot read image: {absolute_image_path}")
+             return self.predict_frame(image)
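
For completeness, a minimal usage sketch of `GestureWorkflow`. The file names are illustrative placeholders; the constructor expects a compiled RKNN model with the metadata JSON that `_get_metadata_path` looks for next to it:

    # Minimal sketch — 'gesture.rknn' and 'thumbs_up.jpg' are placeholder paths.
    workflow = GestureWorkflow("gesture.rknn")
    try:
        raw, result = workflow.predict("thumbs_up.jpg")
        if raw is None:
            print(result["error"])  # e.g. no hand detected, or inference failure
        else:
            print(result["class"], f'{result["confidence"]:.2f}')
    finally:
        workflow.release()  # free the NPU explicitly rather than relying on __del__
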