zyworkflow 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zyworkflow/__init__.py +0 -0
- zyworkflow/api_server.py +630 -0
- zyworkflow/data/__init__.py +0 -0
- zyworkflow/data/collection.py +1241 -0
- zyworkflow/data/process.py +72 -0
- zyworkflow/doc/api.md +461 -0
- zyworkflow/example/__init__.py +0 -0
- zyworkflow/example/train_client.py +301 -0
- zyworkflow/example/train_client_example.py +43 -0
- zyworkflow/policy/__init__.py +0 -0
- zyworkflow/policy/train_pick_policy.py +834 -0
- zyworkflow/utils/__init__.py +0 -0
- zyworkflow/utils/logger_config.py +50 -0
- zyworkflow/utils/pose.py +131 -0
- zyworkflow/utils/utils.py +264 -0
- zyworkflow-0.0.1.dist-info/METADATA +11 -0
- zyworkflow-0.0.1.dist-info/RECORD +19 -0
- zyworkflow-0.0.1.dist-info/WHEEL +5 -0
- zyworkflow-0.0.1.dist-info/top_level.txt +1 -0
zyworkflow/data/collection.py
@@ -0,0 +1,1241 @@
# coding: utf-8
import os
import cv2
import csv
import time
import threading
import re
import numpy as np
import concurrent.futures
from collections import deque
from typing import Optional, List, Tuple, Dict, Any

from zyworkflow.utils.utils import *
from zyworkflow.utils.pose import get_target_pose
from zyworkflow.data.process import process_dataset
from zyworkflow.utils.logger_config import setup_data_collection_logger
logger = setup_data_collection_logger()


class ImageCaptureThread:
    def __init__(self, image_url: str, buffer_size: int = 200):
        self.image_url = image_url
        self.buffer_size = buffer_size
        self.running = False
        self.thread: Optional[threading.Thread] = None

        self.image_buffer: deque = deque(maxlen=buffer_size)
        self.buffer_lock = threading.RLock()

        self.capture_count = 0
        self.error_count = 0
        self.last_capture_time = 0
        self.capture_durations = deque(maxlen=100)

        import requests
        self._session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            pool_connections=10,
            pool_maxsize=50,
            max_retries=2,
            pool_block=False
        )
        self._session.mount('http://', adapter)
        self._session.mount('https://', adapter)

    def _get_image_with_session(self, image_url):
        try:
            response = self._session.get(image_url, timeout=0.5)
            if response.status_code == 200:
                img_array = np.frombuffer(response.content, np.uint8)
                img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
                return img, "Success"
            else:
                return None, f"HTTP {response.status_code}"
        except Exception as e:
            return None, str(e)

    def start(self):
        if self.running:
            logger.warning("Image capture thread is already running")
            return

        self.running = True
        self.thread = threading.Thread(
            target=self._capture_loop,
            name="ImageCaptureThread",
            daemon=True
        )
        self.thread.start()
        logger.info(f"Image capture thread started, target frequency: 20Hz")

    def stop(self):
        self.running = False
        if self._session:
            self._session.close()
        if self.thread and self.thread.is_alive():
            self.thread.join(timeout=2.0)
        logger.info(f"Image capture thread stopped; captured {self.capture_count} images, {self.error_count} errors")

    def _capture_loop(self):
        capture_interval = 0.01
        consecutive_errors = 0
        max_consecutive_errors = 5

        while self.running:
            try:
                start_time = time.perf_counter()
                img, info = self._get_image_with_session(self.image_url)
                if img is not None:
                    timestamp = time.time()
                    with self.buffer_lock:
                        self.image_buffer.append((timestamp, img))
                    self.capture_count += 1
                    self.last_capture_time = timestamp
                    consecutive_errors = 0

                    capture_time = time.perf_counter() - start_time
                    self.capture_durations.append(capture_time)
                else:
                    self.error_count += 1
                    consecutive_errors += 1
                    if consecutive_errors >= max_consecutive_errors:
                        logger.error(f"Capture failed {consecutive_errors} times in a row: {info}")
                        consecutive_errors = 0

                process_time = time.perf_counter() - start_time
                if process_time < capture_interval:
                    sleep_time = capture_interval - process_time
                    if sleep_time > 0.001:
                        time.sleep(sleep_time * 0.9)
                else:
                    if process_time > capture_interval * 1.5:
                        logger.warning(f"Image capture overran: {process_time * 1000:.1f}ms > {capture_interval * 1000:.1f}ms")

                    time.sleep(0.001)

            except Exception as e:
                self.error_count += 1
                logger.error(f"Image capture exception: {e}")
                time.sleep(0.01)

    def get_image_at(self, target_time: float, time_tolerance: float = 0.02):
        with self.buffer_lock:
            if not self.image_buffer:
                return None

            buffer_list = list(self.image_buffer)
            left, right = 0, len(buffer_list) - 1
            best_match = None
            min_time_diff = float('inf')

            while left <= right:
                mid = (left + right) // 2
                timestamp, img = buffer_list[mid]
                time_diff = abs(timestamp - target_time)

                if time_diff < min_time_diff:
                    min_time_diff = time_diff
                    best_match = (timestamp, img)

                if timestamp < target_time:
                    left = mid + 1
                else:
                    right = mid - 1

            if best_match and min_time_diff <= time_tolerance:
                return best_match[1]

            return None

    def get_latest_image(self):
        with self.buffer_lock:
            if self.image_buffer:
                return self.image_buffer[-1]
            return None

    def clear_buffer(self):
        with self.buffer_lock:
            self.image_buffer.clear()
        logger.debug("Image buffer cleared")

    def get_stats(self) -> Dict:
        with self.buffer_lock:
            buffer_len = len(self.image_buffer)
            recent_durations = list(self.capture_durations)

        avg_duration = np.mean(recent_durations) * 1000 if recent_durations else 0
        max_duration = np.max(recent_durations) * 1000 if recent_durations else 0

        return {
            'running': self.running,
            'buffer_size': buffer_len,
            'total_captured': self.capture_count,
            'error_count': self.error_count,
            'last_capture_time': self.last_capture_time,
            'avg_capture_time_ms': f"{avg_duration:.1f}",
            'max_capture_time_ms': f"{max_duration:.1f}",
            'current_fps': self.capture_count / max(1, time.time() - self.last_capture_time) if self.capture_count > 0 else 0
        }

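
# --- Illustrative usage sketch (assumption; not part of the packaged collection.py) ---
# Minimal way ImageCaptureThread might be driven on its own. The snapshot URL below
# is a placeholder; the endpoint actually used by the classes that follow
# (rgb_image_url) is presumably provided by the star import from zyworkflow.utils.utils.
def _demo_image_capture(snapshot_url: str = "http://127.0.0.1:8080/snapshot"):
    capture = ImageCaptureThread(snapshot_url, buffer_size=100)
    capture.start()
    time.sleep(1.0)  # give the background loop a moment to fill the buffer
    # Ask for the frame closest to "now", within a 50 ms alignment window.
    frame = capture.get_image_at(time.time(), time_tolerance=0.05)
    if frame is not None:
        cv2.imwrite("/tmp/snapshot_sample.png", frame)
    print(capture.get_stats())
    capture.stop()
# ---------------------------------------------------------------------------------------
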
@register(name='bnn-pick')
class Pick:
    def __init__(self, ability_id, dataset_id, init_pose, speed=40, sampling_rate=20):
        self.speed = speed
        self.ability_id = ability_id
        self.dataset_id = dataset_id
        self.init_pose = init_pose
        self._stop_requested = False
        self._stop_lock = threading.Lock()
        self.msg = {'code': 0, 'msg': 'Execution completed', 'action': {}}

        self.recording = False
        self.record_thread = None
        self.image_thread = None
        self.data_dir = os.path.join("/workspace/dataset", self.dataset_id, self.ability_id)

        self.sampling_rate = sampling_rate
        self.sampling_interval = 1.0 / self.sampling_rate

        self.time_alignment_tolerance = 0.01

        self._last_arm_status = None
        self._last_arm_time = 0
        self._arm_cache_duration = 0.05

        self._stats = {
            'total_samples': 0,
            'arm_calls': 0,
            'image_hits': 0,
            'image_misses': 0,
            'processing_times': deque(maxlen=100)
        }

        self._target_positions = []
        self._target_reached = False
        self._target_position_lock = threading.Lock()

        os.makedirs(self.data_dir, exist_ok=True)

        self.record_count = self._get_max_record_number()
        logger.info(f"Starting data collection task: {self.dataset_id}-{self.ability_id}")

        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=3,
            thread_name_prefix="PickAsync"
        )

    def _get_max_record_number(self) -> int:
        if not os.path.exists(self.data_dir):
            return 0

        max_num = 0
        pattern = re.compile(r'traj_(\d{3})')

        for item in os.listdir(self.data_dir):
            item_path = os.path.join(self.data_dir, item)
            if os.path.isdir(item_path):
                match = pattern.match(item)
                if match:
                    num = int(match.group(1))
                    if num > max_num:
                        max_num = num

        logger.debug(f"Highest existing record number: {max_num}")
        return max_num

    def set_target_position(self, position):
        with self._target_position_lock:
            self._target_positions = position
            self._target_reached = False

    def mark_target_reached(self):
        with self._target_position_lock:
            self._target_reached = True

    def _is_position_reached(self, current_pose, tolerance=0.01):
        with self._target_position_lock:
            if not self._target_positions:
                return False

            if self._target_reached:
                return True

            if len(current_pose) >= 3 and len(self._target_positions) >= 3:
                dx = abs(current_pose[0] - self._target_positions[0])
                dy = abs(current_pose[1] - self._target_positions[1])
                dz = abs(current_pose[2] - self._target_positions[2])

                distance = (dx ** 2 + dy ** 2 + dz ** 2) ** 0.5
                reached = distance < tolerance

                if reached:
                    self._target_reached = True

                return reached

            return False

    def _get_gripper_state(self, current_pose):
        if self._is_position_reached(current_pose):
            return 0, 0
        else:
            return 1000, 1000

    def request_stop(self):
        with self._stop_lock:
            self._stop_requested = True
            logger.debug(f"Task {self.dataset_id}-{self.ability_id} received a stop request")

    def clear_stop_flag(self):
        with self._stop_lock:
            self._stop_requested = False
            logger.debug(f"Task {self.dataset_id}-{self.ability_id} stop flag cleared")

    def should_stop(self):
        with self._stop_lock:
            return self._stop_requested

    def preprocess_image(self, img):
        if img is None:
            return None
        try:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return img
        except Exception as e:
            logger.error(f"Image preprocessing failed: {e}")
            return None

    def _start_async_capture(self):
        logger.info("Starting async capture thread...")

        self.image_thread = ImageCaptureThread(rgb_image_url, buffer_size=300)
        self.image_thread.start()

        time.sleep(0.3)
        logger.info("Async capture thread started")

    def _stop_async_capture(self):
        logger.info("Stopping async capture thread...")

        if self.image_thread:
            self.image_thread.stop()
            self.image_thread = None

        logger.info("Async capture thread stopped")

    def _get_arm_status_cached(self) -> Tuple[List[float], List[float]]:
        current_time = time.time()

        if (self._last_arm_status is None or
                (current_time - self._last_arm_time) > self._arm_cache_duration):
            try:
                status, _ = get_arm_status()
                if status:
                    pose = status.get("pose", [0.0] * 6)
                    joint_deg = status.get("joint", [0.0] * 6)
                    joint_rad = [j / 180.0 * 3.1415926535 for j in joint_deg]

                    self._last_arm_status = (pose, joint_rad)
                    self._last_arm_time = current_time
                    self._stats['arm_calls'] += 1

                    return pose, joint_rad
            except Exception as e:
                logger.debug(f"Failed to fetch arm status: {e}")

        if self._last_arm_status:
            return self._last_arm_status

        return [0.0] * 6, [0.0] * 6

    def _record_trajectory_async(self, filename: str, sku: str, img_dir: str):
        try:
            logger.info(f"Starting high-performance recording: {filename}, SKU: {sku}")

            import shutil
            if os.path.exists(img_dir):
                shutil.rmtree(img_dir)
            os.makedirs(img_dir, exist_ok=True)

            self._start_async_capture()
            time.sleep(0.3)

            csv_file = open(filename, "w", newline="")
            csv_writer = csv.writer(csv_file)

            csv_writer.writerow([
                "Time(s)", "X(m)", "Y(m)", "Z(m)",
                "Rx(rad)", "Ry(rad)", "Rz(rad)",
                "j1(rad)", "j2(rad)", "j3(rad)",
                "j4(rad)", "j5(rad)", "j6(rad)",
                "Gripper_Set", "Gripper_Real", "SKU", "Image_Filename"
            ])

            start_time = time.perf_counter()
            next_sample_time = start_time
            sample_count = 0

            image_queue = deque()
            image_save_thread_running = True

            def image_save_worker():
                saved_count = 0
                while image_save_thread_running or image_queue:
                    if image_queue:
                        img_path, image_data = image_queue.popleft()
                        try:
                            if image_data is not None:
                                cv2.imwrite(img_path, image_data)
                                saved_count += 1
                        except Exception as e:
                            logger.debug(f"Failed to save image {img_path}: {e}")
                    else:
                        time.sleep(0.001)
                logger.debug(f"Image save thread finished, saved {saved_count} images")

            image_save_thread = threading.Thread(
                target=image_save_worker,
                name="ImageSaveThread",
                daemon=True
            )
            image_save_thread.start()

            logger.info(f"Starting data collection, target sampling rate: {self.sampling_rate}Hz")

            last_log_time = start_time
            last_sample_count = 0

            while self.recording and not self.should_stop():
                current_time = time.perf_counter()

                if current_time >= next_sample_time:
                    loop_start = current_time
                    try:
                        rel_time = current_time - start_time

                        pose_data, joint_data = self._get_arm_status_cached()
                        gripper_set, gripper_real = self._get_gripper_state(pose_data)

                        image_filename = ""
                        if self.image_thread:
                            image = self.image_thread.get_image_at(time.time(), self.time_alignment_tolerance)
                            if image is not None:
                                img_filename = f"{rel_time:.6f}.png"
                                image_filename = img_filename
                                img_path = os.path.join(img_dir, img_filename)

                                image_queue.append((img_path, image))
                                self._stats['image_hits'] += 1
                            else:
                                self._stats['image_misses'] += 1

                        csv_writer.writerow([
                            f"{rel_time:.6f}",
                            *[f"{p:.6f}" for p in pose_data],
                            *[f"{j:.6f}" for j in joint_data],
                            f"{gripper_set}",
                            f"{gripper_real}",
                            f"{sku}",
                            image_filename
                        ])

                        sample_count += 1
                        self._stats['total_samples'] += 1
                        if sample_count % 30 == 0:
                            csv_file.flush()
                            elapsed = time.perf_counter() - start_time
                            actual_rate = sample_count / elapsed if elapsed > 0 else 0

                            recent_samples = sample_count - last_sample_count
                            recent_time = current_time - last_log_time
                            recent_rate = recent_samples / recent_time if recent_time > 0 else 0

                            logger.info(
                                f"Progress: {sample_count} samples, "
                                f"elapsed: {elapsed:.1f}s, "
                                f"average rate: {actual_rate:.1f}Hz, "
                                f"recent rate: {recent_rate:.1f}Hz, "
                                f"image hits: {self._stats['image_hits']}/{sample_count}"
                            )

                            last_log_time = current_time
                            last_sample_count = sample_count

                        next_sample_time = start_time + (sample_count * self.sampling_interval)
                        process_time = time.perf_counter() - loop_start
                        self._stats['processing_times'].append(process_time)

                        if process_time > self.sampling_interval * 1.5:
                            logger.warning(
                                f"Sample processing took too long: {process_time * 1000:.1f}ms > "
                                f"interval {self.sampling_interval * 1000:.1f}ms"
                            )
                    except Exception as e:
                        logger.error(f"Sampling error: {e}")
                        sample_count += 1
                        next_sample_time = start_time + (sample_count * self.sampling_interval)

                else:
                    sleep_time = next_sample_time - time.perf_counter()
                    if sleep_time > 0.0005:
                        end_time = time.perf_counter() + sleep_time
                        while time.perf_counter() < end_time:
                            time.sleep(0.0001)

            image_save_thread_running = False
            if image_save_thread.is_alive():
                image_save_thread.join(timeout=2.0)

            self._stop_async_capture()
            csv_file.flush()
            csv_file.close()

            total_time = time.perf_counter() - start_time
            actual_rate = sample_count / total_time if total_time > 0 else 0

            stats_file = filename.replace('.txt', '_stats.txt')
            with open(stats_file, 'w') as sf:
                sf.write("=== Trajectory recording statistics ===\n")
                sf.write(f"Total samples: {sample_count}\n")
                sf.write(f"Total time: {total_time:.3f}s\n")
                sf.write(f"Target sampling rate: {self.sampling_rate}Hz\n")
                sf.write(f"Actual sampling rate: {actual_rate:.2f}Hz\n")
                sf.write(f"Arm status calls: {self._stats['arm_calls']}\n")
                sf.write(f"Image hits: {self._stats['image_hits']}\n")
                sf.write(f"Image misses: {self._stats['image_misses']}\n")
                sf.write(f"SKU: {sku}\n")

                if self._stats['processing_times']:
                    avg_process = np.mean(self._stats['processing_times']) * 1000
                    max_process = np.max(self._stats['processing_times']) * 1000
                    sf.write(f"Average processing time: {avg_process:.2f}ms\n")
                    sf.write(f"Max processing time: {max_process:.2f}ms\n")

            logger.info(f"Recording complete: {filename}")
            logger.info(f"Final stats: {sample_count} samples, {total_time:.2f}s, {actual_rate:.1f}Hz")

        except Exception as e:
            logger.error(f"Trajectory recording exception: {e}", exc_info=True)
            self._stop_async_capture()
            if 'csv_file' in locals() and not csv_file.closed:
                csv_file.close()

    def _start_recording(self, sku: str):
        if self.recording:
            logger.warning("Already recording, ignoring duplicate start")
            return

        self.recording = True
        self.record_count += 1

        self._stats = {
            'total_samples': 0,
            'arm_calls': 0,
            'image_hits': 0,
            'image_misses': 0,
            'processing_times': deque(maxlen=100)
        }

        self._last_arm_status = None
        self._last_arm_time = 0

        with self._target_position_lock:
            self._target_positions = []
            self._target_reached = False

        imgs_dir = f"images"
        filename = f"actions.txt"
        traj_dir = f"traj_{self.record_count:03d}"
        filepath = os.path.join(self.data_dir, traj_dir, filename)
        imgs_filepath = os.path.join(self.data_dir, traj_dir, imgs_dir)
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        os.makedirs(imgs_filepath, exist_ok=True)
        self.record_thread = threading.Thread(
            target=self._record_trajectory_async,
            args=(filepath, sku, imgs_filepath),
            name=f"RecordThread_{self.record_count}",
            daemon=True
        )
        self.record_thread.start()

        logger.info(f"Starting recording #{self.record_count} -> {filepath}, SKU: {sku}")

    def _stop_recording(self):
        if not self.recording:
            return

        self.recording = False
        if self.record_thread and self.record_thread.is_alive():
            self.record_thread.join(timeout=5.0)

        logger.info(f"Recording #{self.record_count} finished")

    def run_from_http_camera(self, sku):
        try:
            logger.info(f"Running task {self.dataset_id}-{self.ability_id}")

            self.set_sampling_rate(20.0)
            start_time = time.time()
            self.clear_stop_flag()

            code, _ = post_arm_movej(self.init_pose, speed=self.speed)
            time.sleep(2)

            dof_time = time.time()
            xyz, quat_xyzw, _ = post_6dof(str(sku))
            logger.info(f'{self.dataset_id}-{self.ability_id} - 6dof took {time.time() - dof_time}')

            if xyz is None:
                logger.info(f"{self.dataset_id}-{self.ability_id} task finished: no 6dof result detected")
                return {'code': -4, 'msg': "No 6dof result detected", 'dataset_id': self.dataset_id, 'traj_path': ''}

            coords = xyz
            quat = quat_xyzw
            seven_dim_coord = np.concatenate((coords, quat))

            state, _ = get_arm_status()
            end_pose = state["pose"]
            logger.info(f'Starting pose: {end_pose}')

            pos_target = get_target_pose(end_pose, seven_dim_coord)
            pos_target = [float(pos) for pos in pos_target]
            logger.info(f'pos_target: {pos_target}')

            move_posy = [state["pose"][0], pos_target[1], state["pose"][2],
                         state["pose"][3], state["pose"][4], state["pose"][5]]
            move_posz = [pos_target[0] + 0.3, pos_target[1], pos_target[2],
                         state["pose"][3], state["pose"][4], state["pose"][5]]
            move_posx = [pos_target[0] + 0.17, pos_target[1], pos_target[2],
                         state["pose"][3], state["pose"][4], state["pose"][5]]

            self.set_target_position(move_posx)
            self._start_recording(sku)
            time.sleep(0.5)

            code, _ = post_arm_movel(move_posy, speed=30)
            code, _ = post_arm_movel(move_posz, speed=self.speed)
            code, _ = post_arm_movel(move_posx, speed=self.speed)

            self.mark_target_reached()
            post_gripper_move(0)

            keep_time = 2
            keep_start = time.time()
            while (time.time() - keep_start) < keep_time and not self.should_stop():
                time.sleep(0.1)

            self._stop_recording()

            time.sleep(1)
            grip_state, set_position, position, _ = get_gripper_status()
            if grip_state and position:
                if grip_state == 3 or position == 0:
                    logger.info('--------- Object drop detected ---------')
                    return {'code': -5, 'msg': "Failed to grasp the object", 'dataset_id': self.dataset_id, 'traj_path': ''}

            post_gripper_move(1000)
            state, _ = get_arm_status()
            pose = state['pose']
            pose[0] = pose[0] + 0.2
            code, _ = post_arm_movel(pose, speed=self.speed)
            code, _ = post_arm_movej(self.init_pose, speed=self.speed)

            total_time = time.time() - start_time
            logger.info(f'{self.dataset_id}-{self.ability_id} - total time {total_time}')
            logger.info(f"{self.dataset_id}-{self.ability_id} task finished, arm returned to initial pose")
            traj_path = os.path.join(self.data_dir, f"traj_{self.record_count:03d}")
            process_dataset(self.data_dir, self.record_count)
            self.msg['traj_path'] = traj_path
            self.msg['dataset_id'] = self.dataset_id
            self.msg['ability_id'] = self.ability_id
            return self.msg

        except RuntimeError as e:
            msg = {'code': -1, 'msg': str(e), 'dataset_id': self.dataset_id, 'traj_path': '', 'ability_id': self.ability_id}
            self.msg = msg
            self._stop_recording()
            return self.msg

        except Exception as e:
            msg = {'code': -1, 'msg': f"Unknown error: {str(e)}", 'dataset_id': self.dataset_id, 'traj_path': '', 'ability_id': self.ability_id}
            self.msg = msg
            self._stop_recording()
            return self.msg

        finally:
            if self.msg.get('code') != 0:
                traj_dir_to_clean = os.path.join(self.data_dir, f"traj_{self.record_count:03d}")
                if os.path.isdir(traj_dir_to_clean):
                    try:
                        import shutil
                        shutil.rmtree(traj_dir_to_clean)
                        logger.warning(f"{self.dataset_id}-{self.ability_id} collection failed, removed directory: {traj_dir_to_clean}")
                    except Exception as e:
                        logger.error(f"{self.dataset_id}-{self.ability_id} error while cleaning up {traj_dir_to_clean}: {e}")

            self._stop_recording()
            if hasattr(self, '_thread_pool'):
                self._thread_pool.shutdown(wait=False)

    def set_time_alignment_tolerance(self, tolerance_ms: float):
        self.time_alignment_tolerance = tolerance_ms / 1000.0
        logger.info(f"Time alignment tolerance set to: {tolerance_ms}ms")

    def set_sampling_rate(self, rate_hz: float):
        if rate_hz <= 0:
            logger.warning(f"Invalid sampling rate: {rate_hz} Hz, falling back to 20Hz")
            rate_hz = 20.0

        self.sampling_rate = rate_hz
        self.sampling_interval = 1.0 / rate_hz
        logger.info(f"Sampling rate set to: {rate_hz} Hz (interval: {self.sampling_interval:.6f} s)")

    def get_capture_stats(self) -> Dict:
        stats = {
            'recording': self.recording,
            'record_count': self.record_count,
            'sampling_rate': self.sampling_rate,
            'time_alignment_tolerance_ms': self.time_alignment_tolerance * 1000,
            'performance_stats': self._stats.copy()
        }

        if self.image_thread:
            stats['image_capture'] = self.image_thread.get_stats()

        if self._stats['processing_times']:
            proc_times = list(self._stats['processing_times'])
            stats['avg_processing_time_ms'] = np.mean(proc_times) * 1000
            stats['max_processing_time_ms'] = np.max(proc_times) * 1000

        return stats

    def get_record_count(self) -> int:
        return self.record_count

    def cleanup(self):
        logger.info("Cleaning up Pick resources...")
        self._stop_recording()
        if hasattr(self, '_thread_pool'):
            self._thread_pool.shutdown(wait=False)
        logger.info("Pick resources cleaned up")

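
# --- Illustrative usage sketch (assumption; not part of the packaged collection.py) ---
# How a caller might drive a Pick task and request a cooperative stop from another
# thread. The ability_id/dataset_id/init_pose/sku values are placeholders, and the
# arm/camera helpers used inside Pick (post_arm_movej, post_6dof, rgb_image_url, ...)
# are presumably provided by the star import from zyworkflow.utils.utils.
def _demo_pick_with_watchdog(timeout_s: float = 60.0):
    task = Pick(ability_id="pick-demo", dataset_id="ds-001",
                init_pose=[0.3, 0.0, 0.4, 3.14, 0.0, 0.0], speed=40)
    # Flip the cooperative stop flag if the task runs longer than timeout_s.
    watchdog = threading.Timer(timeout_s, task.request_stop)
    watchdog.start()
    try:
        result = task.run_from_http_camera(sku="demo-sku")
        print(result['code'], result['msg'], result.get('traj_path', ''))
    finally:
        watchdog.cancel()
        task.cleanup()
# ---------------------------------------------------------------------------------------
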
@register(name='bnn-place')
class Place:
    def __init__(self, ability_id, dataset_id, init_pose, speed=40, sampling_rate=20):
        self.speed = speed
        self.ability_id = ability_id
        self.dataset_id = dataset_id
        self.init_pose = init_pose
        self._stop_requested = False
        self._stop_lock = threading.Lock()
        self.msg = {'code': 0, 'msg': 'Execution completed', 'action': {}}

        self.recording = False
        self.record_thread = None
        self.image_thread = None
        self.data_dir = os.path.join("/workspace/dataset", self.dataset_id, self.ability_id)

        self.sampling_rate = sampling_rate
        self.sampling_interval = 1.0 / self.sampling_rate

        self.time_alignment_tolerance = 0.01
        self._last_arm_status = None
        self._last_arm_time = 0
        self._arm_cache_duration = 0.05

        self._stats = {
            'total_samples': 0,
            'arm_calls': 0,
            'image_hits': 0,
            'image_misses': 0,
            'processing_times': deque(maxlen=100)
        }

        self._target_positions = []
        self._target_reached = False
        self._target_position_lock = threading.Lock()

        os.makedirs(self.data_dir, exist_ok=True)

        self.record_count = self._get_max_record_number()
        logger.info(f"Starting data collection task: {self.dataset_id}-{self.ability_id}")

        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=3,
            thread_name_prefix="PickAsync"
        )

    def _get_max_record_number(self) -> int:
        if not os.path.exists(self.data_dir):
            return 0

        max_num = 0
        pattern = re.compile(r'traj_(\d{3})')

        for item in os.listdir(self.data_dir):
            item_path = os.path.join(self.data_dir, item)
            if os.path.isdir(item_path):
                match = pattern.match(item)
                if match:
                    num = int(match.group(1))
                    if num > max_num:
                        max_num = num

        logger.debug(f"Highest existing record number: {max_num}")
        return max_num

    def set_target_position(self, position):
        with self._target_position_lock:
            self._target_positions = position
            self._target_reached = False

    def mark_target_reached(self):
        with self._target_position_lock:
            self._target_reached = True

    def _is_position_reached(self, current_pose, tolerance=0.01):
        with self._target_position_lock:
            if not self._target_positions:
                return False

            if self._target_reached:
                return True

            if len(current_pose) >= 3 and len(self._target_positions) >= 3:
                dx = abs(current_pose[0] - self._target_positions[0])
                dy = abs(current_pose[1] - self._target_positions[1])
                dz = abs(current_pose[2] - self._target_positions[2])

                distance = (dx ** 2 + dy ** 2 + dz ** 2) ** 0.5
                reached = distance < tolerance

                if reached:
                    self._target_reached = True

                return reached

            return False

    def _get_gripper_state(self, current_pose):
        if self._is_position_reached(current_pose):
            return 1000, 1000
        else:
            return 0, 0

    def request_stop(self):
        with self._stop_lock:
            self._stop_requested = True
            logger.debug(f"Task {self.dataset_id}-{self.ability_id} received a stop request")

    def clear_stop_flag(self):
        with self._stop_lock:
            self._stop_requested = False
            logger.debug(f"Task {self.dataset_id}-{self.ability_id} stop flag cleared")

    def should_stop(self):
        with self._stop_lock:
            return self._stop_requested

    def preprocess_image(self, img):
        if img is None:
            return None
        try:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return img
        except Exception as e:
            logger.error(f"Image preprocessing failed: {e}")
            return None

    def _start_async_capture(self):
        logger.info("Starting async capture thread...")

        self.image_thread = ImageCaptureThread(rgb_image_url, buffer_size=300)
        self.image_thread.start()

        time.sleep(0.3)
        logger.info("Async capture thread started")

    def _stop_async_capture(self):
        logger.info("Stopping async capture thread...")

        if self.image_thread:
            self.image_thread.stop()
            self.image_thread = None

        logger.info("Async capture thread stopped")

    def _get_arm_status_cached(self) -> Tuple[List[float], List[float]]:
        current_time = time.time()

        if (self._last_arm_status is None or
                (current_time - self._last_arm_time) > self._arm_cache_duration):
            try:
                status, _ = get_arm_status()
                if status:
                    pose = status.get("pose", [0.0] * 6)
                    joint_deg = status.get("joint", [0.0] * 6)
                    joint_rad = [j / 180.0 * 3.1415926535 for j in joint_deg]

                    self._last_arm_status = (pose, joint_rad)
                    self._last_arm_time = current_time
                    self._stats['arm_calls'] += 1

                    return pose, joint_rad
            except Exception as e:
                logger.debug(f"Failed to fetch arm status: {e}")

        if self._last_arm_status:
            return self._last_arm_status

        return [0.0] * 6, [0.0] * 6

    def _record_trajectory_async(self, filename: str, sku: str, img_dir: str):
        try:
            logger.info(f"Starting high-performance recording: {filename}, SKU: {sku}")

            import shutil
            if os.path.exists(img_dir):
                shutil.rmtree(img_dir)
            os.makedirs(img_dir, exist_ok=True)

            self._start_async_capture()
            time.sleep(0.3)

            csv_file = open(filename, "w", newline="")
            csv_writer = csv.writer(csv_file)

            csv_writer.writerow([
                "Time(s)", "X(m)", "Y(m)", "Z(m)",
                "Rx(rad)", "Ry(rad)", "Rz(rad)",
                "j1(rad)", "j2(rad)", "j3(rad)",
                "j4(rad)", "j5(rad)", "j6(rad)",
                "Gripper_Set", "Gripper_Real", "SKU", "Image_Filename"
            ])

            start_time = time.perf_counter()
            next_sample_time = start_time
            sample_count = 0

            image_queue = deque()
            image_save_thread_running = True

            def image_save_worker():
                saved_count = 0
                while image_save_thread_running or image_queue:
                    if image_queue:
                        img_path, image_data = image_queue.popleft()
                        try:
                            if image_data is not None:
                                cv2.imwrite(img_path, image_data)
                                saved_count += 1
                        except Exception as e:
                            logger.debug(f"Failed to save image {img_path}: {e}")
                    else:
                        time.sleep(0.001)
                logger.debug(f"Image save thread finished, saved {saved_count} images")

            image_save_thread = threading.Thread(
                target=image_save_worker,
                name="ImageSaveThread",
                daemon=True
            )
            image_save_thread.start()

            logger.info(f"Starting data collection, target sampling rate: {self.sampling_rate}Hz")

            last_log_time = start_time
            last_sample_count = 0

            while self.recording and not self.should_stop():
                current_time = time.perf_counter()

                if current_time >= next_sample_time:
                    loop_start = current_time

                    try:
                        rel_time = current_time - start_time
                        pose_data, joint_data = self._get_arm_status_cached()
                        gripper_set, gripper_real = self._get_gripper_state(pose_data)

                        image_filename = ""
                        if self.image_thread:
                            image = self.image_thread.get_image_at(time.time(), self.time_alignment_tolerance)
                            if image is not None:
                                img_filename = f"{rel_time:.6f}.png"
                                image_filename = img_filename
                                img_path = os.path.join(img_dir, img_filename)

                                image_queue.append((img_path, image))
                                self._stats['image_hits'] += 1
                            else:
                                self._stats['image_misses'] += 1

                        csv_writer.writerow([
                            f"{rel_time:.6f}",
                            *[f"{p:.6f}" for p in pose_data],
                            *[f"{j:.6f}" for j in joint_data],
                            f"{gripper_set}",
                            f"{gripper_real}",
                            f"{sku}",
                            image_filename
                        ])

                        sample_count += 1
                        self._stats['total_samples'] += 1

                        if sample_count % 30 == 0:
                            csv_file.flush()
                            elapsed = time.perf_counter() - start_time
                            actual_rate = sample_count / elapsed if elapsed > 0 else 0

                            recent_samples = sample_count - last_sample_count
                            recent_time = current_time - last_log_time
                            recent_rate = recent_samples / recent_time if recent_time > 0 else 0

                            logger.info(
                                f"Progress: {sample_count} samples, "
                                f"elapsed: {elapsed:.1f}s, "
                                f"average rate: {actual_rate:.1f}Hz, "
                                f"recent rate: {recent_rate:.1f}Hz, "
                                f"image hits: {self._stats['image_hits']}/{sample_count}"
                            )

                            last_log_time = current_time
                            last_sample_count = sample_count

                        next_sample_time = start_time + (sample_count * self.sampling_interval)
                        process_time = time.perf_counter() - loop_start
                        self._stats['processing_times'].append(process_time)

                        if process_time > self.sampling_interval * 1.5:
                            logger.warning(
                                f"Sample processing took too long: {process_time * 1000:.1f}ms > "
                                f"interval {self.sampling_interval * 1000:.1f}ms"
                            )

                    except Exception as e:
                        logger.error(f"Sampling error: {e}")
                        sample_count += 1
                        next_sample_time = start_time + (sample_count * self.sampling_interval)

                else:
                    sleep_time = next_sample_time - time.perf_counter()
                    if sleep_time > 0.0005:
                        end_time = time.perf_counter() + sleep_time
                        while time.perf_counter() < end_time:
                            time.sleep(0.0001)

            image_save_thread_running = False
            if image_save_thread.is_alive():
                image_save_thread.join(timeout=2.0)

            self._stop_async_capture()

            csv_file.flush()
            csv_file.close()

            total_time = time.perf_counter() - start_time
            actual_rate = sample_count / total_time if total_time > 0 else 0

            stats_file = filename.replace('.txt', '_stats.txt')
            with open(stats_file, 'w') as sf:
                sf.write("=== Trajectory recording statistics ===\n")
                sf.write(f"Total samples: {sample_count}\n")
                sf.write(f"Total time: {total_time:.3f}s\n")
                sf.write(f"Target sampling rate: {self.sampling_rate}Hz\n")
                sf.write(f"Actual sampling rate: {actual_rate:.2f}Hz\n")
                sf.write(f"Arm status calls: {self._stats['arm_calls']}\n")
                sf.write(f"Image hits: {self._stats['image_hits']}\n")
                sf.write(f"Image misses: {self._stats['image_misses']}\n")
                sf.write(f"SKU: {sku}\n")

                if self._stats['processing_times']:
                    avg_process = np.mean(self._stats['processing_times']) * 1000
                    max_process = np.max(self._stats['processing_times']) * 1000
                    sf.write(f"Average processing time: {avg_process:.2f}ms\n")
                    sf.write(f"Max processing time: {max_process:.2f}ms\n")

            logger.info(f"Recording complete: {filename}")
            logger.info(f"Final stats: {sample_count} samples, {total_time:.2f}s, {actual_rate:.1f}Hz")

        except Exception as e:
            logger.error(f"Trajectory recording exception: {e}", exc_info=True)
            self._stop_async_capture()
            if 'csv_file' in locals() and not csv_file.closed:
                csv_file.close()

    def _start_recording(self, sku: str):
        if self.recording:
            logger.warning("Already recording, ignoring duplicate start")
            return

        self.recording = True
        self.record_count += 1

        self._stats = {
            'total_samples': 0,
            'arm_calls': 0,
            'image_hits': 0,
            'image_misses': 0,
            'processing_times': deque(maxlen=100)
        }

        self._last_arm_status = None
        self._last_arm_time = 0

        with self._target_position_lock:
            self._target_positions = []
            self._target_reached = False

        imgs_dir = f"images"
        filename = f"actions.txt"
        traj_dir = f"traj_{self.record_count:03d}"
        filepath = os.path.join(self.data_dir, traj_dir, filename)
        imgs_filepath = os.path.join(self.data_dir, traj_dir, imgs_dir)
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        os.makedirs(imgs_filepath, exist_ok=True)

        self.record_thread = threading.Thread(
            target=self._record_trajectory_async,
            args=(filepath, sku, imgs_filepath),
            name=f"RecordThread_{self.record_count}",
            daemon=True
        )
        self.record_thread.start()
        logger.info(f"Starting recording #{self.record_count} -> {filepath}, SKU: {sku}")

    def _stop_recording(self):
        if not self.recording:
            return

        self.recording = False

        if self.record_thread and self.record_thread.is_alive():
            self.record_thread.join(timeout=5.0)
        logger.info(f"Recording #{self.record_count} finished")

    def run_from_http_camera(self, sku):
        try:
            logger.info(f"Running task {self.dataset_id}-{self.ability_id}")
            start_time = time.time()
            self.set_sampling_rate(20.0)
            self.clear_stop_flag()
            dof_time = time.time()
            xyz, quat_xyzw, _ = post_find_basket()
            logger.info(f'{self.dataset_id}-{self.ability_id} - find_basket took {time.time() - dof_time}')
            if xyz is None:
                logger.info(f"{self.dataset_id}-{self.ability_id} task finished: no basket detected")
                return {'code': -4, 'msg': "No basket detected", 'dataset_id': self.dataset_id, 'traj_path': ''}
            coords = xyz
            quat = quat_xyzw
            seven_dim_coord = np.concatenate((coords, quat))

            state, _ = get_arm_status()
            end_pose = state["pose"]
            logger.info(f'Starting pose: {end_pose}')

            pos_target = get_target_pose(end_pose, seven_dim_coord)
            pos_target = [float(pos) for pos in pos_target]
            logger.info(f'pos_target: {pos_target}')

            move_posy = [state["pose"][0], pos_target[1], state["pose"][2], state["pose"][3], state["pose"][4],
                         state["pose"][5]]
            move_posx = [pos_target[0] + 0.17, pos_target[1], pos_target[2] + 0.35, state["pose"][3], state["pose"][4],
                         state["pose"][5]]

            self.set_target_position(move_posx)
            self._start_recording(sku)
            time.sleep(0.5)

            code, _ = post_arm_movel(move_posy, speed=self.speed)
            code, _ = post_arm_movel(move_posx, speed=self.speed)

            self.mark_target_reached()
            post_gripper_move(1000)

            keep_time = 2
            keep_start = time.time()
            while (time.time() - keep_start) < keep_time and not self.should_stop():
                time.sleep(0.1)

            self._stop_recording()
            time.sleep(1)

            code, _ = post_arm_movej(self.init_pose, speed=self.speed)
            # post_gripper_move(0)
            logger.info(f'{self.dataset_id}-{self.ability_id} - total time {time.time() - start_time}')
            logger.info(f"{self.dataset_id}-{self.ability_id} task finished, arm returned to initial pose")
            traj_path = os.path.join(self.data_dir, f"traj_{self.record_count:03d}")
            process_dataset(self.data_dir, self.record_count)
            self.msg['traj_path'] = traj_path
            self.msg['dataset_id'] = self.dataset_id
            self.msg['ability_id'] = self.ability_id
            return self.msg
        except RuntimeError as e:
            msg = {'code': -1, 'msg': str(e), 'dataset_id': self.dataset_id, 'traj_path': '', 'ability_id': self.ability_id}
            self.msg = msg
            return self.msg
        except Exception as e:
            msg = {'code': -1, 'msg': f"Unknown error: {str(e)}", 'dataset_id': self.dataset_id, 'traj_path': '', 'ability_id': self.ability_id}
            self.msg = msg
            return self.msg

        finally:
            if self.msg.get('code') != 0:
                traj_dir_to_clean = os.path.join(self.data_dir, f"traj_{self.record_count:03d}")
                if os.path.isdir(traj_dir_to_clean):
                    try:
                        import shutil
                        shutil.rmtree(traj_dir_to_clean)
                        logger.warning(f"{self.dataset_id}-{self.ability_id} collection failed, removed directory: {traj_dir_to_clean}")
                    except Exception as e:
                        logger.error(f"{self.dataset_id}-{self.ability_id} error while cleaning up {traj_dir_to_clean}: {e}")

            self._stop_recording()
            if hasattr(self, '_thread_pool'):
                self._thread_pool.shutdown(wait=False)

    def set_time_alignment_tolerance(self, tolerance_ms: float):
        self.time_alignment_tolerance = tolerance_ms / 1000.0
        logger.info(f"Time alignment tolerance set to: {tolerance_ms}ms")

    def set_sampling_rate(self, rate_hz: float):
        if rate_hz <= 0:
            logger.warning(f"Invalid sampling rate: {rate_hz} Hz, falling back to 20Hz")
            rate_hz = 20.0

        self.sampling_rate = rate_hz
        self.sampling_interval = 1.0 / rate_hz
        logger.info(f"Sampling rate set to: {rate_hz} Hz (interval: {self.sampling_interval:.6f} s)")

    def get_capture_stats(self) -> Dict:
        stats = {
            'recording': self.recording,
            'record_count': self.record_count,
            'sampling_rate': self.sampling_rate,
            'time_alignment_tolerance_ms': self.time_alignment_tolerance * 1000,
            'performance_stats': self._stats.copy()
        }

        if self.image_thread:
            stats['image_capture'] = self.image_thread.get_stats()

        if self._stats['processing_times']:
            proc_times = list(self._stats['processing_times'])
            stats['avg_processing_time_ms'] = np.mean(proc_times) * 1000
            stats['max_processing_time_ms'] = np.max(proc_times) * 1000

        return stats

    def get_record_count(self) -> int:
        return self.record_count

    def cleanup(self):
        logger.info("Cleaning up Place resources...")
        self._stop_recording()
        if hasattr(self, '_thread_pool'):
            self._thread_pool.shutdown(wait=False)
        logger.info("Place resources cleaned up")