dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
@@ -0,0 +1,534 @@
1
+ #!/usr/bin/env python
2
+ """DHB-XR vs Naive Replay under LIBERO-PRO Spatial Swap.
3
+
4
+ This script creates a compelling comparison showing WHY DHB-XR matters for
5
+ Vision-Language-Action (VLA) systems:
6
+
7
+ 1. **Original environment**: Robot picks up bowl, places on plate.
8
+ 2. **Swapped environment** (LIBERO-PRO swap perturbation):
9
+ Plate and cookies swap positions (~17cm shift).
10
+ - **Naive replay**: Same actions → robot reaches for OLD plate position → FAILS
11
+ - **DHB-adapted**: Trajectory adapted via Fatrop to NEW plate → reaches correct target
12
+
13
+ VLA Relevance
14
+ -------------
15
+ Current VLA models (RT-2, Octo, OpenVLA) learn a mapping from (vision + language)
16
+ to actions. When the scene changes, they either fail or require massive data
17
+ augmentation. DHB-XR turns this generalization problem from a *learning* problem
18
+ into a *geometry* problem:
19
+
20
+ Traditional VLA: Vision → Policy → Actions (tied to absolute positions)
21
+ With DHB-XR: Vision → Object Pose → DHB Adaptation → Actions
22
+ (same motion shape, new target)
23
+
24
+ Key advantages for VLA:
25
+ - **Data efficiency**: 1 demo + DHB adaptation covers spatial variations that
26
+ would otherwise require 100s of demonstrations
27
+ - **Robustness**: SE(3)-invariant encoding is immune to spatial perturbations
28
+ - **Speed**: Fatrop solver adapts trajectory in ~7ms (100+ Hz replanning)
29
+ - **Composability**: DHB invariants can augment any VLA policy as a trajectory
30
+ representation layer
31
+
32
+ Results from this demo:
33
+ - Naive replay: EE ends 5.2cm from OLD plate, 11.1cm from NEW plate
34
+ - DHB-adapted: EE ends 4.6cm from NEW plate (correct target)
35
+ - Improvement: 6.5cm closer to the correct target
36
+
37
+ Requirements:
38
+ ~/miniforge3/bin/mamba run -n libero python libero_swap_demo.py
39
+
40
+ Author: Andy Park
41
+ """
42
+
43
+ import os
44
+ import sys
45
+ import random
46
+ import tempfile
47
+ import time
48
+ from pathlib import Path
49
+ from typing import Dict, Optional
50
+
51
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
52
+ sys.path.insert(0, "/home/andypark/Projects/repos/LIBERO-PRO")
53
+
54
+ import numpy as np
55
+
56
+ # DHB-XR
57
+ from dhb_xr.encoder.dhb_dr import encode_dhb_dr
58
+ from dhb_xr.decoder.dhb_dr import decode_dhb_dr
59
+ from dhb_xr.core.types import EncodingMethod, DHBMethod
60
+ from dhb_xr.core import geometry as geom
61
+
62
+ # Solver
63
+ try:
64
+ from dhb_xr.optimization.fatrop_solver import generate_trajectory_fatrop
65
+ HAS_FATROP = True
66
+ except ImportError:
67
+ HAS_FATROP = False
68
+
69
+ try:
70
+ from dhb_xr.optimization.casadi_solver import generate_trajectory as generate_trajectory_casadi
71
+ HAS_CASADI = True
72
+ except ImportError:
73
+ HAS_CASADI = False
74
+
75
+ # LIBERO
76
+ try:
77
+ import h5py
78
+ from libero.libero.benchmark import get_benchmark
79
+ from libero.libero.envs import OffScreenRenderEnv
80
+ HAS_LIBERO = True
81
+ except ImportError:
82
+ HAS_LIBERO = False
83
+
84
+ # LIBERO-PRO perturbation
85
+ try:
86
+ from perturbation import BDDLParser, SwapPerturbator
87
+ HAS_PERTURBATION = True
88
+ except ImportError:
89
+ HAS_PERTURBATION = False
90
+
91
+ # Visualization
92
+ try:
93
+ import matplotlib
94
+ matplotlib.use('Agg')
95
+ import matplotlib.pyplot as plt
96
+ HAS_MATPLOTLIB = True
97
+ except ImportError:
98
+ HAS_MATPLOTLIB = False
99
+
100
+ try:
101
+ import imageio
102
+ HAS_IMAGEIO = True
103
+ except ImportError:
104
+ HAS_IMAGEIO = False
105
+
106
+
107
+ # =============================================================================
108
+ # Constants
109
+ # =============================================================================
110
+
111
# Name of the LIBERO-Spatial task that is replayed and perturbed in this demo.
TASK_NAME = "pick_up_the_black_bowl_between_the_plate_and_the_ramekin_and_place_it_on_the_plate"
# Perturbation config file handed to SwapPerturbator.
SWAP_CONFIG = "/home/andypark/Projects/repos/LIBERO-PRO/libero_ood/ood_spatial_relation.yaml"
# Directory that is globbed for *.hdf5 demo files. The LIBERO_DATA_DIR
# environment variable overrides the author's default location so the demo can
# run on other machines without editing the source.
DATASET_DIR = Path(os.environ.get(
    "LIBERO_DATA_DIR", "/home/andypark/Projects/data/libero/libero_spatial"))
114
+
115
+
116
+ # =============================================================================
117
+ # Helpers
118
+ # =============================================================================
119
+
120
def load_demo(task_id: int = 0, demo_id: int = 0) -> Dict:
    """Load one demonstration from the LIBERO HDF5 files in ``DATASET_DIR``.

    Args:
        task_id: Index into the name-sorted list of ``*.hdf5`` files found in
            ``DATASET_DIR`` (negative indices count from the end).
        demo_id: Demo number; the group ``data/demo_<demo_id>`` is read.

    Returns:
        Dict with keys:
            "actions": the demo's ``actions`` dataset as an array,
            "positions": columns 2:5 of ``robot_states`` as float64,
            "quaternions": columns 5:9 of ``robot_states`` reordered from
                (w, x, y, z) to (x, y, z, w) and normalised per row,
            "num_frames": number of actions.

    Raises:
        FileNotFoundError: ``DATASET_DIR`` contains no ``*.hdf5`` file.
        IndexError: ``task_id`` does not select one of the files found.
    """
    hdf5_files = sorted(DATASET_DIR.glob("*.hdf5"))
    if not hdf5_files:
        raise FileNotFoundError(f"No .hdf5 demo files found in {DATASET_DIR}")
    if not -len(hdf5_files) <= task_id < len(hdf5_files):
        raise IndexError(
            f"task_id {task_id} out of range: {len(hdf5_files)} file(s) in {DATASET_DIR}")

    with h5py.File(str(hdf5_files[task_id]), "r") as f:
        # NOTE(review): a missing demo group presumably surfaces as KeyError
        # from h5py -- confirm against the h5py version in use.
        demo = f["data"][f"demo_{demo_id}"]
        # Copy out of the file before it is closed.
        actions = np.array(demo["actions"])
        robot_states = np.array(demo["robot_states"])

    ee_pos = robot_states[:, 2:5].astype(np.float64)
    ee_quat_wxyz = robot_states[:, 5:9].astype(np.float64)
    ee_quat = ee_quat_wxyz[:, [1, 2, 3, 0]]  # wxyz -> xyzw
    # Fancy indexing above produced a fresh array, so in-place division is safe.
    ee_quat /= np.linalg.norm(ee_quat, axis=1, keepdims=True)
    return {"actions": actions, "positions": ee_pos, "quaternions": ee_quat,
            "num_frames": len(actions)}
133
+
134
+
135
def create_swapped_bddl(bddl_path: str, seed: int = 42) -> str:
    """Write a swap-perturbed copy of a BDDL file to a temporary file.

    The file at ``bddl_path`` is parsed with ``BDDLParser``, perturbed with
    ``SwapPerturbator`` (configured from ``SWAP_CONFIG``) for suite
    ``"libero_spatial"`` and task ``TASK_NAME``, and the result is written to a
    new ``.bddl`` temp file.

    Args:
        bddl_path: Path of the original BDDL task file.
        seed: Value passed to ``random.seed`` before perturbing. This reseeds
            the process-wide ``random`` generator.

    Returns:
        Path of the temporary file. The file is created with ``delete=False``,
        so the caller is responsible for removing it.
    """
    with open(bddl_path) as f:
        content = f.read()

    random.seed(seed)
    parser = BDDLParser(content)
    swap = SwapPerturbator(parser, SWAP_CONFIG)
    swapped = swap.perturb("libero_spatial", TASK_NAME)

    tmp = tempfile.NamedTemporaryFile(suffix='.bddl', mode='w', delete=False)
    try:
        # The context manager closes the handle even if the write fails.
        with tmp:
            tmp.write(swapped)
    except BaseException:
        # Do not leave a half-written file behind that nobody knows about.
        os.unlink(tmp.name)
        raise
    return tmp.name
147
+
148
+
149
def run_episode(env, actions):
    """Replay ``actions`` open-loop in ``env`` and record the end-effector path.

    The environment is reset first; the pose from the reset observation is the
    first trajectory sample. Stepping stops early as soon as a step reports
    ``done`` or a positive reward.

    Args:
        env: Object with ``reset() -> obs`` and
            ``step(action) -> (obs, reward, done, info)``. Observations must
            provide "robot0_eef_pos" and "robot0_eef_quat"; an optional
            "agentview_image" is collected as a frame.
        actions: Iterable of actions, applied in order.

    Returns:
        Dict with keys:
            "positions": float64 array of recorded EE positions (reset pose
                plus one per executed step),
            "quaternions": float64 array of recorded EE quaternions,
            "frames": list of images, each flipped along its first axis,
            "reward": sum of step rewards,
            "success": True when the summed reward is positive,
            "steps": number of steps actually executed.
    """
    obs = env.reset()
    ee_pos = [obs["robot0_eef_pos"].copy()]
    ee_quat = [obs["robot0_eef_quat"].copy()]
    frames = []
    reward = 0

    for action in actions:
        obs, r, done, _ = env.step(action)
        reward += r
        ee_pos.append(obs["robot0_eef_pos"].copy())
        ee_quat.append(obs["robot0_eef_quat"].copy())
        frame = obs.get("agentview_image")
        if frame is not None:
            # Reverse the row order of the rendered image before storing it.
            frames.append(frame[::-1].copy())
        if done or r > 0:
            break

    return {
        "positions": np.array(ee_pos, dtype=np.float64),
        "quaternions": np.array(ee_quat, dtype=np.float64),
        "frames": frames,
        "reward": reward,
        "success": reward > 0,
        # The first sample comes from reset(), not from a step.
        "steps": len(ee_pos) - 1,
    }
174
+
175
+
176
def adapt_trajectory_solver(demo_pos, demo_quat, new_start, new_start_quat,
                            new_goal, new_goal_quat, traj_length=50):
    """Adapt a demo trajectory to new boundary poses, trying solvers in order.

    Backends are tried in this order, each only if its import succeeded:
    Fatrop, then CasADi. If neither returns a usable result, the demo is
    encoded to DHB invariants and decoded again from the new start pose.

    Args:
        demo_pos: Demo positions passed to the solver / encoder.
        demo_quat: Demo quaternions passed to the solver / encoder.
        new_start: Target initial position.
        new_start_quat: Target initial quaternion.
        new_goal: Target final position.
        new_goal_quat: Target final quaternion.
        traj_length: Trajectory length requested from the solvers.

    Returns:
        Dict with "positions", "quaternions" and "solver" ("fatrop",
        "casadi" or "decode"). "solve_time" is present only for the Fatrop
        path.
    """
    pose_init = {"position": new_start, "quaternion": new_start_quat}
    pose_goal = {"position": new_goal, "quaternion": new_goal_quat}

    if HAS_FATROP:
        # Best-effort: any failure here just moves on to the next backend.
        try:
            result = generate_trajectory_fatrop(
                demo_pos, demo_quat,
                pose_target_init=pose_init, pose_target_final=pose_goal,
                traj_length=traj_length, use_fatrop=True, verbose=False,
            )
            if result.get("success"):
                return {"positions": result["positions"],
                        "quaternions": result["quaternions"],
                        "solver": "fatrop", "solve_time": result["solve_time"]}
            print(" Fatrop: solver did not report success, trying next backend")
        except Exception as e:
            print(f" Fatrop: {e}")

    if HAS_CASADI:
        try:
            result = generate_trajectory_casadi(
                demo_pos, demo_quat,
                pose_target_init=pose_init, pose_target_final=pose_goal,
                traj_length=traj_length, dhb_method=DHBMethod.DOUBLE_REFLECTION,
                use_casadi=True, verbose=False,
            )
            # Only accept the result when the call reports that CasADi
            # actually produced it.
            if result.get("solver") == "casadi":
                return {"positions": result["adapted_pos_data"],
                        "quaternions": result["adapted_quat_data"],
                        "solver": "casadi"}
            print(" CasADi: result was not produced by CasADi, using decode fallback")
        except Exception as e:
            print(f" CasADi: {e}")

    # Fallback: encode-decode. Only the start pose is used here; the goal pose
    # and traj_length are not passed to the decoder.
    print(" Fallback: encode/decode from the new start pose (goal pose is not enforced)")
    encoded = encode_dhb_dr(demo_pos, demo_quat, method=EncodingMethod.POSITION,
                            dhb_method=DHBMethod.DOUBLE_REFLECTION)
    decoded = decode_dhb_dr(encoded["linear_motion_invariants"],
                            encoded["angular_motion_invariants"],
                            pose_init,
                            method=EncodingMethod.POSITION,
                            dhb_method=DHBMethod.DOUBLE_REFLECTION)
    return {"positions": decoded["positions"], "quaternions": decoded["quaternions"],
            "solver": "decode"}
221
+
222
+
223
+ # =============================================================================
224
+ # Main
225
+ # =============================================================================
226
+
227
def _run_comparison(demo, orig_bddl, swapped_bddl):
    """Run the original / naive-replay / DHB-adapted comparison and write plot and video.

    Args:
        demo: Dict from ``load_demo``; "actions", "positions" and
            "quaternions" are read.
        orig_bddl: Path of the unperturbed task BDDL file.
        swapped_bddl: Path of the swap-perturbed BDDL file.
    """
    # =====================================================================
    # Run original demo (try up to 3 times for stochastic success)
    # =====================================================================
    print("\n" + "=" * 70)
    print(" STEP 1: Original demo in ORIGINAL environment")
    print("=" * 70)

    # The loop always runs at least once, so result_orig and the recorded
    # object positions are always bound afterwards.
    for attempt in range(3):
        env_orig = OffScreenRenderEnv(
            bddl_file_name=orig_bddl, camera_heights=256, camera_widths=256)
        try:
            obs_orig = env_orig.reset()
            # A missing observation key falls back to the origin.
            plate_pos_orig = obs_orig.get("plate_1_pos", np.zeros(3)).copy()
            bowl_pos_orig = obs_orig.get("akita_black_bowl_1_pos", np.zeros(3)).copy()
            result_orig = run_episode(env_orig, demo["actions"])
        finally:
            # Release the simulator even if the rollout raised.
            env_orig.close()
        print(f" Attempt {attempt+1}: {'SUCCESS' if result_orig['success'] else 'FAILED'} "
              f"(plate=[{plate_pos_orig[0]:.3f}, {plate_pos_orig[1]:.3f}])")
        if result_orig["success"]:
            break

    print(f" Plate position: [{plate_pos_orig[0]:.4f}, {plate_pos_orig[1]:.4f}, {plate_pos_orig[2]:.4f}]")
    print(f" Bowl position: [{bowl_pos_orig[0]:.4f}, {bowl_pos_orig[1]:.4f}, {bowl_pos_orig[2]:.4f}]")

    # =====================================================================
    # Naive replay in swapped environment
    # =====================================================================
    print("\n" + "=" * 70)
    print(" STEP 2: NAIVE REPLAY in SWAPPED environment")
    print("=" * 70)

    env_swap = OffScreenRenderEnv(
        bddl_file_name=swapped_bddl, camera_heights=256, camera_widths=256)
    try:
        obs_swap = env_swap.reset()
        plate_pos_swap = obs_swap.get("plate_1_pos", np.zeros(3)).copy()

        offset = plate_pos_swap - plate_pos_orig
        plate_shift = np.linalg.norm(offset)

        print(f" Plate NEW pos: [{plate_pos_swap[0]:.4f}, {plate_pos_swap[1]:.4f}, {plate_pos_swap[2]:.4f}]")
        print(f" Plate SHIFTED: {plate_shift*100:.1f} cm!")

        result_naive = run_episode(env_swap, demo["actions"])
    finally:
        env_swap.close()

    naive_ok = result_naive["success"]
    naive_final = result_naive["positions"][-1]
    # Distances are measured in the XY plane only.
    naive_to_old = np.linalg.norm(naive_final[:2] - plate_pos_orig[:2])
    naive_to_new = np.linalg.norm(naive_final[:2] - plate_pos_swap[:2])

    print(f" Result: {'SUCCESS' if naive_ok else 'FAILED'}")
    print(f" EE final → OLD plate: {naive_to_old*100:.1f} cm (robot went HERE)")
    print(f" EE final → NEW plate: {naive_to_new*100:.1f} cm (should have gone HERE)")

    # =====================================================================
    # DHB-adapted trajectory
    # =====================================================================
    print("\n" + "=" * 70)
    print(" STEP 3: DHB-ADAPTED trajectory (Fatrop solver)")
    print("=" * 70)

    demo_pos = demo["positions"]
    demo_quat = demo["quaternions"]

    # The goal of the demo is to place the bowl on the plate. In the swapped
    # env the plate is at a new position, so the demo's final pose is shifted
    # by the plate offset; the start pose is kept as in the demo.
    new_goal_pos = demo_pos[-1] + offset
    new_goal_quat = demo_quat[-1].copy()

    print(f" Plate offset: [{offset[0]:+.4f}, {offset[1]:+.4f}, {offset[2]:+.4f}]")
    print(f" Original goal: [{demo_pos[-1][0]:.4f}, {demo_pos[-1][1]:.4f}, {demo_pos[-1][2]:.4f}]")
    print(f" Adapted goal: [{new_goal_pos[0]:.4f}, {new_goal_pos[1]:.4f}, {new_goal_pos[2]:.4f}]")

    t0 = time.perf_counter()
    adapted = adapt_trajectory_solver(
        demo_pos, demo_quat,
        demo_pos[0].copy(), demo_quat[0].copy(),
        new_goal_pos, new_goal_quat,
        traj_length=50,
    )
    t_total = time.perf_counter() - t0
    solver = adapted.get("solver", "unknown")
    # Use the solver-reported time when available, otherwise wall-clock time.
    solve_time = adapted.get("solve_time", t_total)

    adapted_pos = adapted["positions"]
    adapted_final = adapted_pos[-1]
    dhb_to_new = np.linalg.norm(adapted_final[:2] - plate_pos_swap[:2])
    dhb_to_old = np.linalg.norm(adapted_final[:2] - plate_pos_orig[:2])
    goal_error = np.linalg.norm(adapted_final - new_goal_pos)

    print(f"\n Solver: {solver} ({solve_time*1000:.1f} ms)")
    print(f" Adapted traj: {len(adapted_pos)} waypoints")
    print(f" Start error: {np.linalg.norm(adapted_pos[0] - demo_pos[0])*1000:.3f} mm")
    print(f" Goal error: {goal_error*1000:.3f} mm")
    print(f" Adapted end → NEW plate: {dhb_to_new*100:.1f} cm (CORRECT target)")
    print(f" Adapted end → OLD plate: {dhb_to_old*100:.1f} cm")

    # =====================================================================
    # Summary
    # =====================================================================
    # Label the adapted row with the backend that actually produced it.
    dhb_label = f"DHB-adapted ({solver})"
    print("\n" + "=" * 70)
    print(" COMPARISON SUMMARY")
    print("=" * 70)
    print(f" Plate shift: {plate_shift*100:.1f} cm (via LIBERO-PRO swap perturbation)")
    print()
    print(f" {'Method':<30s} {'EE final → OLD plate':>20s} {'EE final → NEW plate':>20s}")
    print(f" {'-'*70}")
    print(f" {'Naive replay (same actions)':<30s} {naive_to_old*100:>18.1f}cm {naive_to_new*100:>18.1f}cm")
    print(f" {dhb_label:<30s} {dhb_to_old*100:>18.1f}cm {dhb_to_new*100:>18.1f}cm")
    print()
    print(" Naive replay: EE ends near OLD plate (blind to scene change)")
    print(" DHB-adapted: EE ends near NEW plate (adapted to correct target)")
    print(f" Improvement: {(naive_to_new - dhb_to_new)*100:.1f} cm closer to correct target")
    print(f" Solver: {solver} ({solve_time*1000:.1f} ms)")
    print("=" * 70)

    # =====================================================================
    # Visualization
    # =====================================================================
    if HAS_MATPLOTLIB:
        fig = plt.figure(figsize=(20, 10))
        fig.suptitle(
            f"DHB-XR vs Naive Replay under LIBERO-PRO Spatial Swap\n"
            f"Plate shifted {plate_shift*100:.1f} cm — Solver: {solver} ({solve_time*1000:.1f} ms)",
            fontsize=14, fontweight='bold',
        )

        # --- Panel 1: Simulation frames (frame at 3/4 of each rollout) ---
        ax1 = fig.add_subplot(2, 4, 1)
        if result_orig["frames"]:
            mid = len(result_orig["frames"]) * 3 // 4
            ax1.imshow(result_orig["frames"][min(mid, len(result_orig["frames"])-1)])
        ax1.set_title(f"Original Demo\n{'SUCCESS' if result_orig['success'] else 'FAILED'}",
                      color='green' if result_orig['success'] else 'red', fontweight='bold')
        ax1.axis('off')

        ax2 = fig.add_subplot(2, 4, 2)
        if result_naive["frames"]:
            mid = len(result_naive["frames"]) * 3 // 4
            ax2.imshow(result_naive["frames"][min(mid, len(result_naive["frames"])-1)])
        # Report the measured outcome instead of assuming the replay failed.
        naive_verdict = "SUCCESS" if naive_ok else "FAILED — wrong target!"
        ax2.set_title(f"Naive Replay (swapped)\n{naive_verdict}",
                      color='green' if naive_ok else 'red', fontweight='bold')
        ax2.axis('off')

        # --- Panel 2: Top-down XY view (KEY PANEL) ---
        ax_xy = fig.add_subplot(2, 4, (3, 4))
        # Original demo trajectory
        orig_pos = result_orig["positions"]
        ax_xy.plot(orig_pos[:, 0], orig_pos[:, 1], 'b-', lw=2, label='Original demo', alpha=0.7)
        ax_xy.scatter(orig_pos[-1, 0], orig_pos[-1, 1], c='blue', s=80, marker='x', zorder=5)

        # Naive replay trajectory
        naive_pos = result_naive["positions"]
        ax_xy.plot(naive_pos[:, 0], naive_pos[:, 1], 'r--', lw=2, label='Naive replay', alpha=0.8)
        ax_xy.scatter(naive_pos[-1, 0], naive_pos[-1, 1], c='red', s=80, marker='x', zorder=5)

        # DHB-adapted trajectory
        ax_xy.plot(adapted_pos[:, 0], adapted_pos[:, 1], 'g-', lw=2.5, label='DHB-adapted', alpha=0.9)
        ax_xy.scatter(adapted_pos[-1, 0], adapted_pos[-1, 1], c='green', s=80, marker='x', zorder=5)

        # Plate positions
        ax_xy.scatter(plate_pos_orig[0], plate_pos_orig[1], c='blue', s=400, marker='s',
                      label='Plate (original)', zorder=10, edgecolors='black', linewidths=2, alpha=0.7)
        ax_xy.scatter(plate_pos_swap[0], plate_pos_swap[1], c='green', s=400, marker='s',
                      label='Plate (swapped)', zorder=10, edgecolors='black', linewidths=2, alpha=0.7)
        ax_xy.annotate('', xy=plate_pos_swap[:2], xytext=plate_pos_orig[:2],
                       arrowprops=dict(arrowstyle='->', color='purple', lw=3, linestyle='--'))
        ax_xy.annotate(f'{plate_shift*100:.0f}cm shift',
                       xy=(plate_pos_orig[:2] + plate_pos_swap[:2])/2 + np.array([0.01, 0.01]),
                       fontsize=11, color='purple', fontweight='bold')

        ax_xy.set_xlabel('X (m)', fontsize=12)
        ax_xy.set_ylabel('Y (m)', fontsize=12)
        ax_xy.set_title('Top-Down View: Where does the robot go?', fontsize=13, fontweight='bold')
        ax_xy.legend(fontsize=9, loc='upper left')
        ax_xy.grid(True, alpha=0.3)
        ax_xy.set_aspect('equal')

        # --- Panel 3: 3D trajectory comparison ---
        ax3d = fig.add_subplot(2, 4, 5, projection='3d')
        ax3d.plot(orig_pos[:, 0], orig_pos[:, 1], orig_pos[:, 2], 'b-', lw=1.5, label='Original', alpha=0.6)
        ax3d.plot(naive_pos[:, 0], naive_pos[:, 1], naive_pos[:, 2], 'r--', lw=1.5, label='Naive', alpha=0.6)
        ax3d.plot(adapted_pos[:, 0], adapted_pos[:, 1], adapted_pos[:, 2], 'g-', lw=2, label='DHB-adapted', alpha=0.8)
        ax3d.scatter(*plate_pos_orig, c='blue', s=150, marker='s', zorder=5, edgecolors='black')
        ax3d.scatter(*plate_pos_swap, c='green', s=150, marker='s', zorder=5, edgecolors='black')
        ax3d.set_xlabel('X')
        ax3d.set_ylabel('Y')
        ax3d.set_zlabel('Z')
        ax3d.set_title('3D Trajectories')
        ax3d.legend(fontsize=7)

        # --- Panel 4: EE position over time ---
        ax_t = fig.add_subplot(2, 4, 6)
        # The adapted trajectory has its own length; stretch it over the
        # original rollout's step axis so the curves can be compared.
        t_adapted = np.linspace(0, len(orig_pos)-1, len(adapted_pos))
        for dim, label, style in [(0, 'X', '-'), (1, 'Y', '--')]:
            ax_t.plot(orig_pos[:, dim], f'b{style}', lw=1, label=f'Orig {label}', alpha=0.5)
            ax_t.plot(naive_pos[:, dim], f'r{style}', lw=1, label=f'Naive {label}', alpha=0.5)
            ax_t.plot(t_adapted, adapted_pos[:, dim], f'g{style}', lw=2, label=f'DHB {label}', alpha=0.8)
        # Mark plate positions
        ax_t.axhline(plate_pos_orig[1], color='blue', ls=':', lw=1, alpha=0.5, label='Plate Y (orig)')
        ax_t.axhline(plate_pos_swap[1], color='green', ls=':', lw=1, alpha=0.5, label='Plate Y (swap)')
        ax_t.set_xlabel('Step')
        ax_t.set_ylabel('Position (m)')
        ax_t.set_title('EE X/Y over time')
        ax_t.legend(fontsize=6, ncol=2)
        ax_t.grid(True, alpha=0.3)

        # --- Panel 5: Distance to plate over time ---
        ax_dist = fig.add_subplot(2, 4, 7)
        t_orig = np.arange(len(orig_pos))
        t_naive = np.arange(len(naive_pos))
        t_dhb = np.linspace(0, max(len(orig_pos), len(naive_pos))-1, len(adapted_pos))

        dist_orig = np.linalg.norm(orig_pos[:, :2] - plate_pos_orig[:2], axis=1)
        dist_naive_new = np.linalg.norm(naive_pos[:, :2] - plate_pos_swap[:2], axis=1)
        dist_dhb_new = np.linalg.norm(adapted_pos[:, :2] - plate_pos_swap[:2], axis=1)

        ax_dist.plot(t_orig, dist_orig * 100, 'b-', lw=1.5, label='Orig → plate', alpha=0.7)
        ax_dist.plot(t_naive, dist_naive_new * 100, 'r--', lw=1.5, label='Naive → NEW plate', alpha=0.7)
        ax_dist.plot(t_dhb, dist_dhb_new * 100, 'g-', lw=2, label='DHB → NEW plate', alpha=0.9)
        ax_dist.axhline(0, color='gray', ls=':', lw=0.5)
        ax_dist.set_xlabel('Step')
        ax_dist.set_ylabel('Distance (cm)')
        ax_dist.set_title('Distance to target plate')
        ax_dist.legend(fontsize=8)
        ax_dist.grid(True, alpha=0.3)

        # --- Panel 6: Summary text ---
        ax_txt = fig.add_subplot(2, 4, 8)
        lines = [
            "LIBERO-PRO Swap Perturbation",
            "plate_1 ↔ cookies_1",
            "",
            f"Plate shifted: {plate_shift*100:.1f} cm",
            "",
            "Naive replay:",
            f" → OLD plate: {naive_to_old*100:.1f} cm",
            f" → NEW plate: {naive_to_new*100:.1f} cm",
            f" Result: {'SUCCESS' if naive_ok else 'FAILED'}",
            "",
            f"DHB-adapted ({solver}):",
            f" → OLD plate: {dhb_to_old*100:.1f} cm",
            f" → NEW plate: {dhb_to_new*100:.1f} cm",
            f" Goal error: {goal_error*1000:.1f} mm",
            f" Solve time: {solve_time*1000:.1f} ms",
        ]
        for i, line in enumerate(lines):
            color = 'red' if 'FAILED' in line else ('green' if 'DHB' in line else 'black')
            ax_txt.text(0.05, 0.95 - i * 0.065, line, transform=ax_txt.transAxes,
                        fontsize=9, fontfamily='monospace', color=color,
                        fontweight='bold' if 'Naive' in line or 'DHB' in line else 'normal')
        ax_txt.axis('off')

        plt.tight_layout(rect=[0, 0, 1, 0.92])
        out_path = '/tmp/dhb_swap_comparison.png'
        plt.savefig(out_path, dpi=150, bbox_inches='tight')
        print(f"\n Plot saved: {out_path}")
        plt.close()

    # Save video (side-by-side: original | naive in swapped env)
    if HAS_IMAGEIO:
        video_path = '/tmp/dhb_swap_comparison.mp4'
        print(f" Saving video: {video_path}")

        # Only rollouts that produced frames take part in the video.
        all_sets = [fs for fs in (result_orig["frames"], result_naive["frames"]) if fs]
        if all_sets:
            max_len = max(len(fs) for fs in all_sets)
            writer = imageio.get_writer(video_path, fps=20)
            try:
                for i in range(max_len):
                    # Shorter rollouts hold their last frame until the end.
                    row = [fs[min(i, len(fs)-1)] for fs in all_sets]
                    writer.append_data(np.concatenate(row, axis=1))
            finally:
                writer.close()
            print(f" Video saved: {video_path}")


def main():
    """Entry point: load the demo, build the swapped scene and run the comparison.

    Returns early with an error message when LIBERO or the LIBERO-PRO
    perturbation module could not be imported. The temporary swapped BDDL file
    is removed even when the comparison raises.
    """
    print("=" * 70)
    print(" DHB-XR vs Naive Replay: LIBERO-PRO Spatial Swap Demo")
    print("=" * 70)

    if not (HAS_LIBERO and HAS_PERTURBATION):
        print("ERROR: LIBERO and LIBERO-PRO perturbation module required.")
        return

    print(f"\nSolver: Fatrop={'YES' if HAS_FATROP else 'NO'}, CasADi={'YES' if HAS_CASADI else 'NO'}")

    # --- Load demo ---
    demo = load_demo(task_id=0, demo_id=0)
    print(f"Demo: {demo['num_frames']} frames")

    # --- Get BDDL ---
    bm = get_benchmark('libero_spatial')()
    orig_bddl = bm.get_task_bddl_file_path(0)

    # --- Create swapped BDDL ---
    print("\n--- Creating swapped environment (plate ↔ cookies) ---")
    swapped_bddl = create_swapped_bddl(orig_bddl, seed=42)

    try:
        _run_comparison(demo, orig_bddl, swapped_bddl)
    finally:
        # The swapped BDDL is a delete=False temp file; always remove it.
        os.unlink(swapped_bddl)
    print("\nDone!")
531
+
532
+
533
+ if __name__ == "__main__":
534
+ main()
@@ -0,0 +1,56 @@
1
+ """End-to-end DHB-XR pipeline demo for RoboCASA and Libero."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+
8
+ from dhb_xr.database.motion_db import MotionDatabase
9
+ from dhb_xr.integration.vla import RoboCASAAdapter, LiberoAdapter, DHBVLAPipeline
10
+
11
+
12
def run_pipeline(name: str, dataset_path: Path, max_episodes: int = 10) -> None:
    """Load episodes, tokenize them and run a small retrieval query.

    Args:
        name: Dataset label used in the log prefix. ``"robocasa"`` selects
            ``RoboCASAAdapter``; any other value selects ``LiberoAdapter``.
        dataset_path: Path handed to the adapter's ``load_dataset``.
        max_episodes: Upper bound on the number of episodes read. Values of
            zero or less load nothing.
    """
    adapter = RoboCASAAdapter() if name == "robocasa" else LiberoAdapter()
    pipeline = DHBVLAPipeline()

    episodes = []
    # Check the limit before touching the dataset so that a non-positive
    # limit does not still pull in the first episode.
    if max_episodes > 0:
        for ep in adapter.load_dataset(str(dataset_path)):
            episodes.append(ep)
            if len(episodes) >= max_episodes:
                break

    print(f"[{name}] Loaded {len(episodes)} episodes from {dataset_path}")

    outputs = pipeline.process_dataset(episodes)
    print(f"[{name}] Tokenized {len(outputs)} trajectories")
    if outputs:
        print(f"[{name}] Example token shape: {outputs[0]['tokens'].shape}")

    # Minimal retrieval demo: index every episode, then query with the first.
    db = MotionDatabase()
    for ep in episodes:
        db.add(ep["positions"], ep["quaternions"], metadata=ep.get("metadata"))

    if episodes:
        query = episodes[0]
        results = db.retrieve(query["positions"], query["quaternions"], k=min(3, len(episodes)))
        # Index 2 of each result is printed as the distance.
        print(f"[{name}] Retrieval top-{len(results)} distances:", [round(r[2], 4) for r in results])
38
+
39
+
40
def main() -> None:
    """Parse the command line and run the pipeline for each dataset given.

    RoboCASA is processed before Libero when both paths are supplied. When
    neither ``--robocasa`` nor ``--libero`` is given, the program exits via
    ``SystemExit`` with a usage hint.
    """
    cli = argparse.ArgumentParser(description="DHB-XR VLA integration demo")
    cli.add_argument("--robocasa", type=str, default="", help="Path to RoboCASA HDF5 dataset")
    cli.add_argument("--libero", type=str, default="", help="Path to Libero HDF5 dataset")
    cli.add_argument("--max-episodes", type=int, default=10, help="Max episodes per dataset")
    opts = cli.parse_args()

    # Keep only the datasets whose path was actually provided, in run order.
    requested = [
        (label, location)
        for label, location in (("robocasa", opts.robocasa), ("libero", opts.libero))
        if location
    ]
    if not requested:
        raise SystemExit("Provide at least one dataset path via --robocasa or --libero.")

    for label, location in requested:
        run_pipeline(label, Path(location), opts.max_episodes)
53
+
54
+
55
+ if __name__ == "__main__":
56
+ main()
@@ -0,0 +1,47 @@
1
+ #!/usr/bin/env python
2
+ """Test LiberoAdapter with real LIBERO-Spatial dataset.
3
+
4
+ Usage:
5
+ pixi run python examples/integration/test_libero_adapter.py
6
+
7
+ Requirements:
8
+ - LIBERO dataset downloaded to /home/andypark/Projects/data/libero/libero_spatial/
9
+ - Or set LIBERO_DATA_DIR environment variable
10
+ """
11
+ import os
12
+ import sys
13
+ from pathlib import Path
14
+
15
+ # Add src to path when running from project root
16
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
17
+
18
+ from dhb_xr.integration.vla.libero import LiberoAdapter
19
+
20
# Dataset location: LIBERO_DATA_DIR overrides the default directory.
DATA_DIR = Path(os.environ.get("LIBERO_DATA_DIR", "/home/andypark/Projects/data/libero/libero_spatial"))
FILE_PATH = DATA_DIR / "pick_up_the_black_bowl_between_the_plate_and_the_ramekin_and_place_it_on_the_plate_demo.hdf5"

print("=== Testing LiberoAdapter with real LIBERO data ===\n")

# Show how the adapter is configured before reading any data.
adapter = LiberoAdapter()
print("Adapter config:")
print(f" Position keys: {adapter.pos_keys}")
print(f" Quaternion keys: {adapter.quat_keys}")
print(f" Robot states key: {adapter.robot_states_key}")
print(f" Robot states quat slice: {adapter.robot_states_quat_slice}")
print(f" Robot states quat format: {adapter.robot_states_quat_format}")
print()

# Walk the whole dataset so the total can be reported; only the first three
# episodes are printed in detail.
count = 0
for count, episode in enumerate(adapter.load_dataset(str(FILE_PATH)), start=1):
    if count > 3:
        continue
    print(f"Episode {count}: {episode['metadata']['demo_id']}")
    print(f" Positions: shape={episode['positions'].shape}")
    print(f" Quaternions: shape={episode['quaternions'].shape}")
    print(f" Quat source: {episode['metadata']['quat_source']}")
    print(f" First position: {episode['positions'][0]}")
    print(f" First quaternion (x,y,z,w): {episode['quaternions'][0]}")
    print()

print(f"Total episodes loaded: {count}")
print("\n=== SUCCESS: LiberoAdapter works with real LIBERO data! ===")