zyworkflow 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,834 @@
+ import os
+ import time
+ import glob
+ import torch
+ import traceback
+ import numpy as np
+ import pandas as pd
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.multiprocessing as mp
+ from PIL import Image
+ from tqdm import tqdm
+ from bnn.simulate import Simulate
+ from torch.utils.data import Dataset, DataLoader
+ from concurrent.futures import ThreadPoolExecutor
+ from zyworkflow.utils.logger_config import setup_train_pick_policy_logger
+
+
+ logger = setup_train_pick_policy_logger()
+ GRIPPER_CLOSE_THRESH = 500.0
+ SUCCESS_THRESH = 0.5
+ COL_JOINTS = ["j1", "j2", "j3", "j4", "j5", "j6"]
+ COL_GRIPPER = "Gripper_Set_Position(‰)"
+ COL_SUCCESS = "success_flag"
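+ # Note: Gripper_Set_Position(‰) is recorded in per-mille; values below
+ # GRIPPER_CLOSE_THRESH are treated as the "closed" gripper class when the
+ # dataset and training loop below build classification targets.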
+
+
+ def build_success_targets_and_mask(
+     raw_success_flags: torch.Tensor,
+     t: int,
+     traj_len: int,
+     chunk_len: int,
+     pad_len: int,
+     mode: str,
+     thr: float,
+ ):
+     device = raw_success_flags.device
+     succ = (raw_success_flags > thr).float()
+
+     if mode == "within_horizon":
+         rev = torch.flip(succ, dims=[0])
+         rev_cum = torch.cumsum(rev, dim=0)
+         suffix_any = (torch.flip(rev_cum, dims=[0]) > 0).float()
+         c_s = suffix_any
+         c_sm = torch.ones_like(c_s)
+
+     elif mode == "terminal_only":
+         c_s = torch.zeros((chunk_len, 1), dtype=torch.float32, device=device)
+         c_sm = torch.zeros((chunk_len, 1), dtype=torch.float32, device=device)
+         final_idx = traj_len - 1
+         if t <= final_idx < (t + chunk_len):
+             off = final_idx - t
+             c_s[off, 0] = succ[off, 0]
+             c_sm[off, 0] = 1.0
+     else:
+         c_s = succ
+         c_sm = torch.ones_like(c_s)
+
+     if pad_len > 0:
+         if chunk_len > 0:
+             last_s = c_s[-1:].repeat(pad_len, 1)
+         else:
+             last_s = torch.zeros((pad_len, 1), device=device)
+
+         c_s = torch.cat([c_s, last_s], dim=0)
+         pad_zeros = torch.zeros((pad_len, 1), dtype=torch.float32, device=device)
+         c_sm = torch.cat([c_sm, pad_zeros], dim=0)
+
+     return c_s, c_sm
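+
+ # A minimal usage sketch for the "within_horizon" mode (illustrative values only):
+ #
+ #   flags = torch.tensor([[0.0], [1.0], [0.0]])   # raw success flags for one chunk
+ #   c_s, c_sm = build_success_targets_and_mask(
+ #       flags, t=5, traj_len=8, chunk_len=3, pad_len=1,
+ #       mode="within_horizon", thr=SUCCESS_THRESH)
+ #   # A step is labelled positive if success occurs at or after it within the chunk,
+ #   # so c_s == [[1.], [1.], [0.], [0.]] and c_sm == [[1.], [1.], [1.], [0.]],
+ #   # i.e. the padded step is masked out.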
+
+
+ class BNNWorker(mp.Process):
+     def __init__(self, pipe):
+         super().__init__()
+         self.pipe = pipe
+         self.bnn_instance = None
+
+     def run(self):
+         try:
+             os.environ["OMP_NUM_THREADS"] = "1"
+             torch.set_num_threads(1)
+
+             self.bnn_instance = Simulate()
+
+             while True:
+                 cmd, data = self.pipe.recv()
+                 if cmd == "STEP":
+                     inp = torch.tensor(data, dtype=torch.float32)
+                     with torch.no_grad():
+                         out = self.bnn_instance.run_simulation(inp)
+                     if isinstance(out, torch.Tensor):
+                         out = out.detach().cpu().numpy()
+                     self.pipe.send(out)
+                 elif cmd == "RESET":
+                     if self.bnn_instance:
+                         self.bnn_instance.reset_state()
+                     self.pipe.send("OK")
+                 elif cmd == "CLOSE":
+                     break
+         except Exception as e:
+             logger.error(f"BNN Worker Error: {e}")
+             logger.error(traceback.format_exc())
+             try:
+                 self.pipe.send(None)
+             except Exception:
+                 pass
+
+
+ class PersistentBNNPool:
+     def __init__(self, num_workers):
+         self.num_workers = num_workers
+         self.workers = []
+         self.pipes = []
+         logger.info(f"Starting {num_workers} persistent BNN worker processes...")
+         for _ in range(num_workers):
+             parent_conn, child_conn = mp.Pipe()
+             p = BNNWorker(child_conn)
+             p.daemon = True
+             p.start()
+             self.workers.append(p)
+             self.pipes.append(parent_conn)
+
+     def reset_all(self, n_used=None):
+         if n_used is None:
+             n_used = len(self.pipes)
+         for p in self.pipes[:n_used]:
+             p.send(("RESET", None))
+         for p in self.pipes[:n_used]:
+             p.recv()
+
+     def step_batch(self, batch_inputs_np):
+         n = len(batch_inputs_np)
+         for i, inp in enumerate(batch_inputs_np):
+             self.pipes[i].send(("STEP", inp))
+         return [self.pipes[i].recv() for i in range(n)]
+
+     def close(self, timeout: float = 2.0):
+         for p in self.pipes:
+             try:
+                 p.send(("CLOSE", None))
+             except Exception:
+                 pass
+
+         deadline = time.time() + timeout
+         for w in self.workers:
+             remaining = max(0.0, deadline - time.time())
+             try:
+                 w.join(timeout=remaining)
+             except Exception:
+                 pass
+
+         for w in self.workers:
+             if w.is_alive():
+                 try:
+                     logger.warning(f"Force-terminating leftover BNN process: pid={w.pid}")
+                     w.terminate()
+                     w.join(timeout=1.0)
+                 except Exception as e:
+                     logger.error(f"Failed to force-terminate BNN process: pid={w.pid}, err={e}")
+
+         for p in self.pipes:
+             try:
+                 p.close()
+             except Exception:
+                 pass
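+
+ # A minimal interaction sketch (illustrative; assumes the `spawn` start method
+ # is set before the pool is created, as done in __main__ below):
+ #
+ #   pool = PersistentBNNPool(num_workers=2)
+ #   pool.reset_all()
+ #   outs = pool.step_batch([np.zeros((1, 15), dtype=np.float32)] * 2)  # one (1, 15) feature per worker
+ #   pool.close()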
+
+
+ class SingleViewRobotTrajectoryDataset(Dataset):
+     def __init__(self, root_dir, min_frames_per_traj=6, time_round=3, time_tol=1e-3, debug_max_bad=3):
+         self.root_dir = root_dir
+         self.trajectories = []
+
+         self.min_frames_per_traj = min_frames_per_traj
+         self.time_round = time_round
+         self.time_tol = time_tol
+
+         self.debug_stats = {
+             "root_dir_exists": os.path.exists(root_dir),
+             "traj_dirs_found": 0,
+             "traj_used": 0,
+             "skip_missing_csv": 0,
+             "skip_missing_imgdir": 0,
+             "skip_csv_read_error": 0,
+             "skip_missing_cols": 0,
+             "skip_no_images": 0,
+             "skip_no_parsable_images": 0,
+             "skip_no_matched": 0,
+             "skip_too_few_images": 0,
+             "total_frames_csv": 0,
+             "total_images_in_dir": 0,
+             "matched_frames": 0,
+         }
+         self.bad_examples = []
+
+         traj_dirs = sorted([os.path.join(root_dir, d) for d in os.listdir(root_dir) if d.startswith("traj_")])
+         self.debug_stats["traj_dirs_found"] = len(traj_dirs)
+         logger.info(f"Dataset: scanning dataset: {root_dir} | traj_dirs={len(traj_dirs)}")
+
+         num_pos = 0
+         num_neg = 0
+         all_joints_list = []
+
+         for traj_path in tqdm(traj_dirs):
+             csv_path = os.path.join(traj_path, "actions.csv")
+             if not os.path.exists(csv_path):
+                 self.debug_stats["skip_missing_csv"] += 1
+                 continue
+
+             img_dir = None
+             for cand in ["images", "image"]:
+                 p = os.path.join(traj_path, cand)
+                 if os.path.isdir(p):
+                     img_dir = p
+                     break
+             if img_dir is None:
+                 self.debug_stats["skip_missing_imgdir"] += 1
+                 continue
+
+             try:
+                 df = pd.read_csv(csv_path, header=0)
+             except Exception:
+                 self.debug_stats["skip_csv_read_error"] += 1
+                 continue
+
+             if "Time(s)" not in df.columns:
+                 self.debug_stats["skip_missing_cols"] += 1
+                 if len(self.bad_examples) < debug_max_bad:
+                     self.bad_examples.append({
+                         "traj": traj_path,
+                         "reason": "missing Time(s) col",
+                         "df_cols": list(df.columns)[:30],
+                     })
+                 continue
+
+             need_cols = [COL_GRIPPER, COL_SUCCESS] + COL_JOINTS
+             if any(c not in df.columns for c in need_cols):
+                 self.debug_stats["skip_missing_cols"] += 1
+                 if len(self.bad_examples) < debug_max_bad:
+                     self.bad_examples.append({
+                         "traj": traj_path,
+                         "reason": "missing required cols",
+                         "need_cols": need_cols,
+                         "df_cols": list(df.columns),
+                     })
+                 continue
+
+             df = df.sort_values("Time(s)").reset_index(drop=True)
+             T = len(df)
+             self.debug_stats["total_frames_csv"] += T
+
+             img_files = glob.glob(os.path.join(img_dir, "*.png"))
+             self.debug_stats["total_images_in_dir"] += len(img_files)
+             if len(img_files) == 0:
+                 self.debug_stats["skip_no_images"] += 1
+                 continue
+
+             img_entries = []
+             img_name_examples = []
+             for fp in img_files:
+                 stem = os.path.splitext(os.path.basename(fp))[0]
+                 if len(img_name_examples) < 8:
+                     img_name_examples.append(stem)
+                 try:
+                     tf = float(stem)
+                     img_entries.append((tf, fp))
+                 except Exception:
+                     pass
+
+             if len(img_entries) == 0:
+                 self.debug_stats["skip_no_parsable_images"] += 1
+                 if len(self.bad_examples) < debug_max_bad:
+                     self.bad_examples.append({
+                         "traj": traj_path,
+                         "reason": "no parsable image filenames (stem->float failed)",
+                         "img_dir": img_dir,
+                         "img_stems_sample": img_name_examples,
+                     })
+                 continue
+
+             img_entries.sort(key=lambda x: x[0])
+             img_times = np.array([x[0] for x in img_entries], dtype=np.float64)
+             img_paths = [x[1] for x in img_entries]
+
+             img_map = {}
+             for tf, fp in img_entries:
+                 k = round(float(tf), self.time_round)
+                 if k not in img_map:
+                     img_map[k] = fp
+
+             joints_np = df[COL_JOINTS].to_numpy(dtype=np.float32)
+             gripper_np = df[COL_GRIPPER].to_numpy(dtype=np.float32)
+             success_np = df[COL_SUCCESS].to_numpy(dtype=np.float32)
+             times_np = df["Time(s)"].to_numpy(dtype=np.float64)
+
+             valid_img_paths = []
+             valid_targets = []
+
+             for i in range(T):
+                 t_csv = float(times_np[i])
+                 key = round(t_csv, self.time_round)
+
+                 fp = img_map.get(key, None)
+
+                 if fp is None:
+                     idx = int(np.searchsorted(img_times, t_csv))
+                     cand = []
+                     if 0 <= idx < len(img_times):
+                         cand.append(idx)
+                     if 0 <= idx - 1 < len(img_times):
+                         cand.append(idx - 1)
+                     best_fp = None
+                     best_dt = 1e9
+                     for ci in cand:
+                         dt = abs(float(img_times[ci]) - t_csv)
+                         if dt < best_dt:
+                             best_dt = dt
+                             best_fp = img_paths[ci]
+                     if best_fp is not None and best_dt <= self.time_tol:
+                         fp = best_fp
+
+                 if fp is None or (not os.path.exists(fp)):
+                     continue
+
+                 joints_val = joints_np[i]
+                 gripper_val = float(gripper_np[i])
+                 success_val = float(success_np[i])
+
+                 target_vec = np.concatenate([joints_val, [gripper_val], [success_val]]).astype(np.float32)
+                 valid_img_paths.append(fp)
+                 valid_targets.append(target_vec)
+
+                 if gripper_val < GRIPPER_CLOSE_THRESH:
+                     num_pos += 1
+                 else:
+                     num_neg += 1
+
+             self.debug_stats["matched_frames"] += len(valid_img_paths)
+
+             if len(valid_img_paths) == 0:
+                 self.debug_stats["skip_no_matched"] += 1
+                 if len(self.bad_examples) < debug_max_bad:
+                     self.bad_examples.append({
+                         "traj": traj_path,
+                         "reason": "matched_frames=0 (Time(s) vs img filename mismatch)",
+                         "csv_times_sample": [float(x) for x in times_np[:8]],
+                         "img_stems_sample": img_name_examples,
+                         "time_round": self.time_round,
+                         "time_tol": self.time_tol,
+                     })
+                 continue
+
+             if len(valid_img_paths) < self.min_frames_per_traj:
+                 self.debug_stats["skip_too_few_images"] += 1
+                 if len(self.bad_examples) < debug_max_bad:
+                     self.bad_examples.append({
+                         "traj": traj_path,
+                         "reason": f"too few matched frames (<{self.min_frames_per_traj})",
+                         "matched": len(valid_img_paths),
+                         "csv_len": T,
+                         "images_in_dir": len(img_files),
+                     })
+                 continue
+
+             try:
+                 with Image.open(valid_img_paths[0]) as im:
+                     exp_hw = (im.height, im.width)
+             except Exception:
+                 exp_hw = (480, 640)
+
+             self.trajectories.append({
+                 "traj_id": traj_path,
+                 "view_paths": valid_img_paths,
+                 "targets": np.array(valid_targets, dtype=np.float32),
+                 "length": len(valid_img_paths),
+                 "exp_hw": exp_hw
+             })
+             self.debug_stats["traj_used"] += 1
+
+             all_joints_list.append(np.array(valid_targets, dtype=np.float32)[:, :6])
+
+         if len(all_joints_list) > 0:
+             all_joints_np = np.concatenate(all_joints_list, axis=0)
+             self.joint_mean = torch.tensor(np.mean(all_joints_np, axis=0), dtype=torch.float32)
+             self.joint_std = torch.tensor(np.std(all_joints_np, axis=0), dtype=torch.float32)
+             self.joint_std = torch.where(self.joint_std < 1e-6, torch.ones_like(self.joint_std), self.joint_std)
+         else:
+             self.joint_mean = torch.zeros(6)
+             self.joint_std = torch.ones(6)
+
+         self.pos_weight = (num_neg / max(num_pos, 1)) if (num_pos + num_neg) > 0 else 1.0
+
+         logger.success(f"Data loading finished: {len(self.trajectories)} trajectories.")
+         logger.info(f"PosWeight={self.pos_weight:.2f}")
+         logger.info(f"Joint Mean: {self.joint_mean.numpy().round(3)}")
+
+         logger.debug("Dataset build summary:")
+         for k, v in self.debug_stats.items():
+             logger.debug(f" - {k}: {v}")
+         if len(self.bad_examples) > 0:
+             logger.warning(f"Found {len(self.bad_examples)} bad examples:")
+             for ex in self.bad_examples:
+                 for kk, vv in ex.items():
+                     logger.debug(f"{kk}: {vv}")
+
+     def __len__(self):
+         return len(self.trajectories)
+
+     def __getitem__(self, idx):
+         return self.trajectories[idx]
+
+
+ def traj_collate_fn(batch):
+     return batch
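+
+ # Illustrative loader setup (paths are placeholders); the training loop below
+ # builds the same pipeline plus the BNN pool and chunked targets:
+ #
+ #   ds = SingleViewRobotTrajectoryDataset("/path/to/dataset_root")
+ #   dl = DataLoader(ds, batch_size=4, shuffle=True, collate_fn=traj_collate_fn)
+ #   for batch in dl:  # each batch is a plain list of trajectory dicts
+ #       print(batch[0]["length"], batch[0]["targets"].shape)  # targets are (length, 8): 6 joints + gripper + success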
+
+
+ class FiLM(nn.Module):
+     def __init__(self, dim_in, dim_out, num_layers=2, hidden_dim=128):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(dim_in, hidden_dim), nn.ReLU(),
+             nn.Linear(hidden_dim, dim_out * 2)
+         )
+
+     def forward(self, x, cond):
+         cond_flat = cond.reshape(-1, cond.size(-1))
+         params = self.mlp(cond_flat)
+         gamma, beta = params.chunk(2, dim=-1)
+         B, S, D = x.shape
+         gamma = gamma.view(B, S, D)
+         beta = beta.view(B, S, D)
+         return gamma * x + beta
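+
+ # Shape sketch for FiLM as used below: the conditioning MLP maps each visual
+ # feature vector to a per-feature (gamma, beta) pair for the BNN output.
+ #
+ #   x    : (B, S, 80)  BNN outputs over the window
+ #   cond : (B, S, 15)  visual features fed to the BNN
+ #   out  : (B, S, 80)  gamma * x + beta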
+
+
+ class SingleViewBNNActionPolicy(nn.Module):
+     def __init__(self, seq_len=4, action_chunk=8, dim_to_bnn=15, dim_bnn_output=80):
+         super().__init__()
+         self.seq_len = seq_len
+         self.action_chunk = action_chunk
+
+         self.conv_layers = nn.Sequential(
+             nn.Conv2d(3, 24, kernel_size=5, stride=2, padding=2), nn.ReLU(),
+             nn.Conv2d(24, 36, kernel_size=5, stride=2, padding=2), nn.ReLU(),
+             nn.Conv2d(36, 48, kernel_size=3, stride=2, padding=1), nn.ReLU(),
+             nn.Conv2d(48, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
+             nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1), nn.ReLU()
+         )
+
+         self.before_bnn_mlp = nn.Sequential(
+             nn.Linear(80 * 60, 512), nn.ReLU(),
+             nn.Linear(512, dim_to_bnn)
+         )
+         self.bnn_adapter = nn.Sequential(nn.LayerNorm(dim_to_bnn), nn.Tanh())
+
+         self.film = FiLM(dim_in=dim_to_bnn, dim_out=dim_bnn_output)
+
+         feature_dim = (dim_to_bnn + dim_bnn_output) * seq_len
+         self.shared_backbone = nn.Sequential(
+             nn.Linear(feature_dim, 512), nn.ReLU(),
+             nn.Linear(512, 256)
+         )
+
+         self.head_joints = nn.Sequential(
+             nn.Linear(256, 128), nn.ReLU(),
+             nn.Linear(128, 64), nn.ReLU(),
+             nn.Linear(64, 6 * action_chunk)
+         )
+         self.head_gripper = nn.Sequential(
+             nn.Linear(256, 64), nn.ReLU(),
+             nn.Linear(64, 1 * action_chunk)
+         )
+         self.head_success = nn.Sequential(
+             nn.Linear(256, 64), nn.ReLU(),
+             nn.Linear(64, 1 * action_chunk)
+         )
+
+     def encode_visual(self, x):
+         B, S, C, H, W = x.shape
+         x = x.contiguous().view(B * S, C, H, W)
+         x = self.conv_layers(x)
+         x = x.view(B, S, -1)
+         feat = self.before_bnn_mlp(x)
+         return self.bnn_adapter(feat)
+
+     def decode_action(self, bnn_in, bnn_out):
+         bnn_out_film = self.film(bnn_out, bnn_in)
+         bnn_out_film = F.layer_norm(bnn_out_film, (bnn_out_film.size(-1),))
+         feats = torch.cat([bnn_in, bnn_out_film], dim=-1).view(bnn_in.size(0), -1)
+         shared = self.shared_backbone(feats)
+
+         j = self.head_joints(shared).view(-1, self.action_chunk, 6)
+         g = self.head_gripper(shared).view(-1, self.action_chunk, 1)
+         s = self.head_success(shared).view(-1, self.action_chunk, 1)
+         return j, g, s
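+
+ # Shape walkthrough for the defaults (seq_len=4, action_chunk=8), matching how
+ # the training loop drives the model:
+ #
+ #   images (B, 4, 3, 480, 640) --encode_visual--> bnn_in (B, 4, 15)
+ #   bnn_in[:, -1] / 10 is stepped through the external BNN pool; each step's
+ #   80-dim output is stacked over the window into bnn_out (B, 4, 80)
+ #   decode_action(bnn_in, bnn_out) -> joints (B, 8, 6), gripper logits (B, 8, 1),
+ #                                     success logits (B, 8, 1)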
+
+
+ def _load_single_window(args):
+     v_paths, t, seq_len, exp_hw = args
+     target_w, target_h = 640, 480
+
+     v_list = []
+     start_idx = t - seq_len + 1
+
+     for i in range(start_idx, t + 1):
+         idx = i if i >= 0 else 0
+         try:
+             with Image.open(v_paths[idx]) as img:
+                 arr = np.array(img.convert("RGB").resize((target_w, target_h)), dtype=np.float32) / 255.0
+                 v_list.append(arr.transpose(2, 0, 1))
+         except Exception:
+             v_list.append(np.zeros((3, target_h, target_w), dtype=np.float32))
+
+     return np.stack(v_list)
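+
+ # _load_single_window returns a (seq_len, 3, 480, 640) float32 window ending at
+ # frame t; indices before the start of the trajectory repeat frame 0, and frames
+ # that fail to load fall back to zeros.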
+
+
+ def train_single_view_parallel_chunk(
+     task_name,
+     root_dir,
+     batch_size=50,
+     seq_len=4,
+     action_chunk=8,
+     lr=1e-4,
+     num_epochs=500,
+     start_epoch=0,
+     lambda_joints=10.0, lambda_grip=5.0, lambda_success=2.0,
+     log_path=None, ckpt_dir=None, success_mode="within_horizon",
+     report_url=None,
+ ):
+     logger.info(f"Starting single-view training | BS={batch_size} | Chunk={action_chunk} | Mode={success_mode}")
+
+     bnn_pool = None
+     io_pool = None
+     device = None
+
+     last_saved_ckpt_path = None
+     try:
+         if log_path:
+             os.makedirs(os.path.dirname(log_path), exist_ok=True)
+         os.makedirs(ckpt_dir, exist_ok=True)
+
+         dataset = SingleViewRobotTrajectoryDataset(
+             root_dir=root_dir,
+             min_frames_per_traj=6,
+             time_round=3,
+             time_tol=1e-3,
+             debug_max_bad=3
+         )
+
+         if len(dataset) == 0:
+             raise RuntimeError(
+                 "Dataset size = 0, training cannot start.\n"
+                 "See the [DEBUG] bad examples above: in most cases Time(s) does not match the image filenames, leaving matched_frames=0.\n"
+             )
+
+         loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
+                             collate_fn=traj_collate_fn, drop_last=True)
+
+         bnn_pool = PersistentBNNPool(num_workers=batch_size)
+         io_pool = ThreadPoolExecutor(max_workers=min(64, batch_size * 2))
+
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         joint_mean_gpu = dataset.joint_mean.to(device)
+         joint_std_gpu = dataset.joint_std.to(device)
+
+         model = SingleViewBNNActionPolicy(seq_len, action_chunk).to(device)
+         optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+
+         crit_mse = nn.MSELoss(reduction='none')
+         crit_bce_w = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([dataset.pos_weight], device=device), reduction='none')
+         crit_bce = nn.BCEWithLogitsLoss(reduction='none')
+
+         if start_epoch > 0:
+             p = os.path.join(ckpt_dir, f"epoch_{start_epoch}.pth")
+             if os.path.exists(p):
+                 checkpoint = torch.load(p, map_location=device)
+                 if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                     model.load_state_dict(checkpoint['model_state_dict'])
+                     logger.info(f"Loaded Checkpoint Epoch {checkpoint.get('epoch')}")
+                 else:
+                     model.load_state_dict(checkpoint)
+                     logger.info("Loaded Legacy State Dict")
+
+         update_freq = 5
+
+         for epoch in range(start_epoch, num_epochs):
+             epoch_start_time = time.time()
+             model.train()
+             total_loss, steps = 0.0, 0
+             real_joint_err_accum = 0.0
+
+             # pbar = tqdm(loader, desc=f"Ep {epoch+1}")
+             for batch_trajs in loader:
+                 curr_bs = len(batch_trajs)
+                 bnn_pool.reset_all(curr_bs)
+
+                 optimizer.zero_grad(set_to_none=True)
+                 accum_steps = 0
+
+                 lengths = [t["length"] for t in batch_trajs]
+                 max_len = max(lengths)
+                 exp_hw = batch_trajs[0]["exp_hw"]
+
+                 bnn_hist = torch.zeros(curr_bs, max_len, 80, device=device)
+
+                 for t in range(max_len):
+
+                     futures = [
+                         io_pool.submit(_load_single_window, (tr["view_paths"], t, seq_len, exp_hw))
+                         if t < tr["length"] else None
+                         for tr in batch_trajs
+                     ]
+
+                     imgs, masks = [], []
+                     for f in futures:
+                         if f:
+                             v = f.result()
+                             imgs.append(torch.from_numpy(v))
+                             masks.append(True)
+                         else:
+                             z = torch.zeros(seq_len, 3, 480, 640)
+                             imgs.append(z)
+                             masks.append(False)
+
+                     b_imgs = torch.stack(imgs).to(device, non_blocking=True)
+                     bnn_in = model.encode_visual(b_imgs)
+
+                     curr_feat = bnn_in[:, -1, :].detach().cpu().numpy() / 10.0
+                     bnn_outs = bnn_pool.step_batch([curr_feat[b].reshape(1, -1) for b in range(curr_bs)])
+
+                     clean_outs = [
+                         o.T if o is not None and hasattr(o, "shape") and o.shape == (1, 80)
+                         else (o if o is not None else np.zeros(80))
+                         for o in bnn_outs
+                     ]
+                     clean_outs = [np.array(o).squeeze() for o in clean_outs]
+                     bnn_curr = torch.tensor(np.stack(clean_outs), device=device, dtype=torch.float32)
+                     bnn_hist[:, t, :] = bnn_curr
+
+                     s_idx = t - seq_len + 1
+                     if s_idx >= 0:
+                         bnn_seq = bnn_hist[:, s_idx:t+1]
+                     else:
+                         first = bnn_hist[:, 0:1, :].repeat(1, -s_idx, 1)
+                         bnn_seq = torch.cat([first, bnn_hist[:, 0:t+1]], dim=1)
+
+                     p_j, p_g, p_s = model.decode_action(bnn_in, bnn_seq)
+
+                     t_j_list, t_g_list, t_s_list, m_list = [], [], [], []
+                     for b_idx in range(curr_bs):
+                         if not masks[b_idx]:
+                             z = torch.zeros(action_chunk, 1, device=device)
+                             t_j_list.append(torch.zeros(action_chunk, 6, device=device))
+                             t_g_list.append(z)
+                             t_s_list.append(z)
+                             m_list.append(z)
+                             continue
+
+                         traj = batch_trajs[b_idx]
+                         real_end = min(t + action_chunk, traj["length"])
+                         chunk_sz = real_end - t
+                         pad_sz = action_chunk - chunk_sz
+
+                         raw = torch.from_numpy(traj["targets"][t:real_end]).to(device)
+                         c_j = raw[:, 0:6]
+                         c_g = (raw[:, 6:7] < GRIPPER_CLOSE_THRESH).float()
+                         c_s, _ = build_success_targets_and_mask(
+                             raw[:, 7:8], t, traj["length"], chunk_sz, pad_sz, success_mode, SUCCESS_THRESH
+                         )
+                         c_s = c_s.to(device)
+
+                         if pad_sz > 0:
+                             c_j = torch.cat([c_j, c_j[-1:].repeat(pad_sz, 1)], 0)
+                             c_g = torch.cat([c_g, c_g[-1:].repeat(pad_sz, 1)], 0)
+
+                         mask = torch.cat([torch.ones(chunk_sz, 1), torch.zeros(pad_sz, 1)], 0).to(device)
+                         t_j_list.append(c_j)
+                         t_g_list.append(c_g)
+                         t_s_list.append(c_s)
+                         m_list.append(mask)
+
+                     t_j = torch.stack(t_j_list)
+                     t_g = torch.stack(t_g_list)
+                     t_s = torch.stack(t_s_list)
+                     loss_mask = torch.stack(m_list)
+
+                     t_j_norm = (t_j - joint_mean_gpu) / joint_std_gpu
+
+                     valid = loss_mask.sum()
+                     if valid > 0:
+                         l_j = (crit_mse(p_j, t_j_norm).mean(-1, keepdim=True) * loss_mask).sum() / valid
+                         l_g = (crit_bce_w(p_g, t_g) * loss_mask).sum() / valid
+                         l_s = (crit_bce(p_s, t_s) * loss_mask).sum() / valid
+
+                         loss = lambda_joints * l_j + lambda_grip * l_g + lambda_success * l_s
+                         (loss / update_freq).backward()
+
+                         accum_steps += 1
+                         total_loss += loss.item()
+                         steps += 1
+
+                         if accum_steps % update_freq == 0:
+                             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20.0)
+                             optimizer.step()
+                             optimizer.zero_grad(set_to_none=True)
+
+                         with torch.no_grad():
+                             real_p = p_j[:, 0, :] * joint_std_gpu + joint_mean_gpu
+                             real_t = t_j[:, 0, :]
+                             v0 = loss_mask[:, 0, :]
+                             if v0.sum() > 0:
+                                 err = torch.abs(real_p - real_t).mean(dim=1)
+                                 real_joint_err_accum += (err * v0.squeeze()).sum().item() / (v0.sum().item() + 1e-6)
+
+                 if accum_steps % update_freq != 0:
+                     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20.0)
+                     optimizer.step()
+                     optimizer.zero_grad(set_to_none=True)
+
+                 # pbar.set_postfix({
+                 #     "Loss": f"{total_loss/steps:.3f}" if steps > 0 else "0",
+                 #     "J_Err": f"{real_joint_err_accum/steps:.3f}" if steps > 0 else "0"
+                 # })
+
+             epoch_duration = time.time() - epoch_start_time
+
+             checkpoint = {
+                 'model_state_dict': model.state_dict(),
+                 'optimizer_state_dict': optimizer.state_dict(),
+                 'joint_mean': dataset.joint_mean,
+                 'joint_std': dataset.joint_std,
+                 'epoch': epoch + 1
+             }
+             last_saved_ckpt_path = os.path.join(ckpt_dir, f"epoch_{epoch+1}.pth")
+             torch.save(checkpoint, last_saved_ckpt_path)
+
+             avg_loss = total_loss / steps if steps > 0 else 0.0
+             avg_j_err = real_joint_err_accum / steps if steps > 0 else 0.0
+             msg_core = f"Ep {epoch+1} Saved. Time: {epoch_duration:.2f}s, Avg Loss: {avg_loss:.4f}, J_Err: {avg_j_err:.4f}"
+             msg = f"[{task_name}] {msg_core}" if task_name else msg_core
+
+             logger.info(msg)
+             # if log_path:
+             #     with open(log_path, "a") as f:
+             #         f.write(msg + "\n")
+
+             if report_url and task_name:
+                 try:
+                     import requests
+                     payload = {
+                         "task_name": task_name,
+                         "epoch": epoch + 1,
+                         "duration_sec": epoch_duration,
+                         "avg_loss": avg_loss,
+                         "j_err": avg_j_err,
+                         "msg": msg_core,
+                         "is_finished": epoch + 1 >= num_epochs,
+                         "model_path": last_saved_ckpt_path,
+                     }
+                     requests.post(report_url, json=payload, timeout=3)
+                 except Exception as e:
+                     logger.warning(f"Failed to report training progress: {e}")
+     except KeyboardInterrupt:
+         logger.warning("Training interrupted by user")
+         raise
+     except Exception as e:
+         logger.error(f"Error during training: {str(e)}\n{traceback.format_exc()}")
+         raise
+     finally:
+         logger.info("Cleaning up training resources...")
+         if bnn_pool is not None:
+             try:
+                 bnn_pool.close()
+                 logger.info("BNN process pool closed")
+             except Exception as e:
+                 logger.error(f"Error while closing BNN process pool: {str(e)}")
+
+         if io_pool is not None:
+             try:
+                 io_pool.shutdown(wait=True, cancel_futures=True)
+                 logger.info("IO thread pool shut down")
+             except Exception as e:
+                 logger.error(f"Error while shutting down IO thread pool: {str(e)}")
+
+         if device is not None and device.type == 'cuda':
+             try:
+                 torch.cuda.empty_cache()
+                 torch.cuda.ipc_collect()
+                 logger.info("GPU cache cleared")
+             except Exception as e:
+                 logger.error(f"Error while clearing GPU cache: {str(e)}")
+
+         logger.info("Resource cleanup complete")
+
+
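+ # A minimal inference-side loading sketch (illustrative; checkpoints are saved
+ # above as epoch_{n}.pth with the keys shown here):
+ #
+ #   ckpt = torch.load("epoch_100.pth", map_location="cpu")
+ #   model = SingleViewBNNActionPolicy(seq_len=4, action_chunk=8)
+ #   model.load_state_dict(ckpt["model_state_dict"])
+ #   joint_mean, joint_std = ckpt["joint_mean"], ckpt["joint_std"]  # de-normalize joint predictions
+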
794
+ if __name__ == "__main__":
795
+ mp.set_start_method('spawn', force=True)
796
+
797
+ import argparse
798
+
799
+ parser = argparse.ArgumentParser(description="单视角训练脚本")
800
+ parser.add_argument("--task_name", type=str, required=False, default=None)
801
+ parser.add_argument("--report_url", type=str, required=False, default=None)
802
+ parser.add_argument("--root_dir", type=str, required=True)
803
+ parser.add_argument("--batch_size", type=int, default=48)
804
+ parser.add_argument("--seq_len", type=int, default=4)
805
+ parser.add_argument("--action_chunk", type=int, default=8)
806
+ parser.add_argument("--lr", type=float, default=1e-4)
807
+ parser.add_argument("--num_epochs", type=int, default=500)
808
+ parser.add_argument("--start_epoch", type=int, default=0)
809
+ parser.add_argument("--lambda_joints", type=float, default=10.0)
810
+ parser.add_argument("--lambda_grip", type=float, default=5.0)
811
+ parser.add_argument("--lambda_success", type=float, default=2.0)
812
+ parser.add_argument("--log_path", type=str, default=None)
813
+ parser.add_argument("--ckpt_dir", type=str, required=True)
814
+ parser.add_argument("--success_mode", type=str, default="within_horizon")
815
+
816
+ args = parser.parse_args()
817
+
818
+ train_single_view_parallel_chunk(
819
+ task_name=args.task_name,
820
+ root_dir=args.root_dir,
821
+ batch_size=args.batch_size,
822
+ seq_len=args.seq_len,
823
+ action_chunk=args.action_chunk,
824
+ lr=args.lr,
825
+ num_epochs=args.num_epochs,
826
+ start_epoch=args.start_epoch,
827
+ log_path=args.log_path,
828
+ ckpt_dir=args.ckpt_dir,
829
+ lambda_joints=args.lambda_joints,
830
+ lambda_grip=args.lambda_grip,
831
+ lambda_success=args.lambda_success,
832
+ success_mode=args.success_mode,
833
+ report_url=args.report_url,
834
+ )
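+
+ # Example invocation (script name and data paths are placeholders):
+ #
+ #   python train_pick_policy.py \
+ #       --root_dir /data/pick_dataset \
+ #       --ckpt_dir /data/checkpoints/pick \
+ #       --task_name pick_demo \
+ #       --batch_size 48 --num_epochs 500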