dhb_xr-0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
dhb_xr/generative/vfm_tokenizer.py
@@ -0,0 +1,485 @@
+ """
+ VFM Token Generator: End-to-end integration of tokenizer + flow matching.
+
+ Provides a unified interface for:
+ 1. Encoding trajectories to continuous latent space
+ 2. Training flow matching on latent representations
+ 3. Generating diverse trajectories via flow matching sampling
+ 4. Decoding back to invariants and SE(3) poses
+
+ Supports both deterministic (FlowMatcher) and variational (VariationalFlowMatcher)
+ generation, with optional conditioning on goals, context, or task embeddings.
+ """
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ from typing import Optional, Dict, List, Tuple, Any, Union
+
+ from .flow_matching import FlowMatcher, VariationalFlowMatcher
+ from .latent_encoder import LatentEncoder, CategoricalLatentEncoder
+ from .sampling import ode_solve
+
+
+ class VFMTokenGenerator(nn.Module):
+     """
+     End-to-end variational flow matching generator for DHB invariants.
+
+     Combines:
+     - Tokenizer encoder: invariants -> continuous latent z
+     - Flow matcher: generative model over z
+     - Tokenizer decoder: z -> invariants
+
+     Example:
+         >>> tokenizer = DHBTokenizer(invariant_dim=8, latent_dim=32, codebook_size=512)
+         >>> flow_matcher = VariationalFlowMatcher(latent_dim=32, condition_dim=16)
+         >>> generator = VFMTokenGenerator(tokenizer, flow_matcher)
+         >>>
+         >>> # Train on invariant sequences (loss() returns a dict of terms)
+         >>> losses = generator.loss(invariants)
+         >>> losses["total"].backward()
+         >>>
+         >>> # Generate diverse trajectories
+         >>> samples = generator.generate_multimodal(num_samples=4, seq_len=50)
+     """
+
+     def __init__(
+         self,
+         tokenizer: nn.Module,
+         flow_matcher: Union[FlowMatcher, VariationalFlowMatcher],
+         latent_encoder: Optional[LatentEncoder] = None,
+         condition_encoder: Optional[nn.Module] = None,
+     ):
+         """
+         Args:
+             tokenizer: DHBTokenizer or ResidualVQTokenizer with encode_continuous()
+                 and decode_from_latent() methods.
+             flow_matcher: FlowMatcher or VariationalFlowMatcher for generation.
+             latent_encoder: Optional external latent encoder (overrides flow_matcher's
+                 internal encoder for VariationalFlowMatcher).
+             condition_encoder: Optional encoder for external conditions (e.g., goal pose,
+                 task embedding).
+         """
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.flow_matcher = flow_matcher
+         self.latent_encoder = latent_encoder
+         self.condition_encoder = condition_encoder
+
+         # Check if variational
+         self.is_variational = isinstance(flow_matcher, VariationalFlowMatcher)
+
+         # Get dimensions from tokenizer
+         self.latent_dim = getattr(tokenizer, "latent_dim", None)
+         if self.latent_dim is None:
+             raise ValueError("Tokenizer must have latent_dim attribute")
+
+     def encode(self, invariants: Tensor) -> Tensor:
+         """
+         Encode invariants to continuous latent space.
+
+         Args:
+             invariants: Input invariant sequences (B, T, invariant_dim).
+
+         Returns:
+             Continuous latent z (B, T, latent_dim).
+         """
+         return self.tokenizer.encode_continuous(invariants)
+
+     def decode(self, z: Tensor) -> Tensor:
+         """
+         Decode continuous latent to invariants.
+
+         Args:
+             z: Continuous latent (B, T, latent_dim).
+
+         Returns:
+             Reconstructed invariants (B, T, invariant_dim).
+         """
+         return self.tokenizer.decode_from_latent(z)
+
+     def encode_condition(self, condition: Any) -> Optional[Tensor]:
+         """
+         Encode external condition (if condition_encoder is provided).
+
+         Args:
+             condition: External condition (format depends on encoder).
+
+         Returns:
+             Condition embedding or None.
+         """
+         if self.condition_encoder is None or condition is None:
+             return None
+         return self.condition_encoder(condition)
+
+     def forward(
+         self,
+         invariants: Tensor,
+         condition: Optional[Any] = None,
+     ) -> Dict[str, Tensor]:
+         """
+         Forward pass: encode, flow match, decode.
+
+         Args:
+             invariants: Input invariant sequences (B, T, invariant_dim).
+             condition: Optional external condition.
+
+         Returns:
+             Dictionary with 'z', 'z_reconstructed', 'invariants_reconstructed'.
+         """
+         # Encode to latent
+         z = self.encode(invariants)
+
+         # Encode condition if provided
+         cond = self.encode_condition(condition)
+
+         # Pass through flow matcher (for analysis, not generation)
+         # During training, use loss() instead
+         z_recon = z  # No-op for now
+
+         # Decode back to invariants
+         invariants_recon = self.decode(z)
+
+         return {
+             "z": z,
+             "z_reconstructed": z_recon,
+             "invariants_reconstructed": invariants_recon,
+         }
+
+     def loss(
+         self,
+         invariants: Tensor,
+         condition: Optional[Any] = None,
+         tokenizer_weight: float = 1.0,
+         flow_weight: float = 1.0,
+         beta: float = 0.01,
+     ) -> Dict[str, Tensor]:
+         """
+         Compute joint training loss.
+
+         Combines:
+         - Tokenizer reconstruction loss
+         - Flow matching velocity loss
+         - KL divergence (for variational models)
+
+         Args:
+             invariants: Input invariant sequences (B, T, invariant_dim).
+             condition: Optional external condition.
+             tokenizer_weight: Weight for tokenizer loss.
+             flow_weight: Weight for flow matching loss.
+             beta: KL divergence weight (for variational models).
+
+         Returns:
+             Dictionary with 'total', 'tokenizer', 'flow', and optionally 'kl'.
+         """
+         # Encode to latent
+         z_1 = self.encode(invariants)  # "data" distribution
+
+         # Tokenizer loss (reconstruction through VQ)
+         indices, reconstructed, z, z_q = self.tokenizer(invariants)
+         tokenizer_loss = self.tokenizer.loss(invariants, reconstructed, z, z_q)
+
+         # Flow matching loss
+         if self.is_variational:
+             flow_losses = self.flow_matcher.loss(z_1, beta=beta)
+             flow_loss = flow_losses["total"]
+             kl_loss = flow_losses["kl"]
+         else:
+             flow_loss = self.flow_matcher.loss(z_1)
+             kl_loss = torch.tensor(0.0, device=invariants.device)
+
+         # Total loss
+         total_loss = tokenizer_weight * tokenizer_loss + flow_weight * flow_loss
+
+         return {
+             "total": total_loss,
+             "tokenizer": tokenizer_loss,
+             "flow": flow_loss,
+             "kl": kl_loss,
+         }
+
+     def generate(
+         self,
+         num_samples: int = 1,
+         seq_len: int = 50,
+         num_steps: int = 10,
+         method: str = "euler",
+         condition: Optional[Any] = None,
+         w: Optional[Tensor] = None,
+         device: str = "cpu",
+         return_latent: bool = False,
+     ) -> Tensor | Tuple[Tensor, Tensor]:
+         """
+         Generate invariant sequences via flow matching.
+
+         Args:
+             num_samples: Number of samples to generate.
+             seq_len: Sequence length.
+             num_steps: ODE integration steps.
+             method: ODE solver ('euler' or 'rk4').
+             condition: Optional external condition.
+             w: Optional conditioning latent for variational models.
+             device: Device for computation.
+             return_latent: If True, also return the generated latent z.
+
+         Returns:
+             Generated invariants (num_samples, seq_len, invariant_dim),
+             optionally with latent z.
+         """
+         # Generate latent via flow matching
+         if self.is_variational:
+             z = self.flow_matcher.sample(
+                 num_samples=num_samples,
+                 seq_len=seq_len,
+                 num_steps=num_steps,
+                 method=method,
+                 w=w,
+                 device=device,
+             )
+         else:
+             z = self.flow_matcher.sample(
+                 num_samples=num_samples,
+                 seq_len=seq_len,
+                 num_steps=num_steps,
+                 method=method,
+                 device=device,
+             )
+
+         # Decode to invariants
+         invariants = self.decode(z)
+
+         if return_latent:
+             return invariants, z
+         return invariants
+
+     def generate_multimodal(
+         self,
+         num_samples: int = 1,
+         seq_len: int = 50,
+         num_modes: int = 4,
+         num_steps: int = 10,
+         method: str = "euler",
+         condition: Optional[Any] = None,
+         device: str = "cpu",
+     ) -> Tensor:
+         """
+         Generate diverse samples by sampling different conditioning latents.
+
+         Args:
+             num_samples: Number of samples per mode.
+             seq_len: Sequence length.
+             num_modes: Number of distinct modes to sample.
+             num_steps: ODE integration steps.
+             method: ODE solver method.
+             condition: Optional external condition.
+             device: Device for computation.
+
+         Returns:
+             Generated invariants (num_modes * num_samples, seq_len, invariant_dim).
+         """
+         if not self.is_variational:
+             # For non-variational, just generate multiple samples
+             return self.generate(
+                 num_samples=num_modes * num_samples,
+                 seq_len=seq_len,
+                 num_steps=num_steps,
+                 method=method,
+                 device=device,
+             )
+
+         # Generate with different w for each mode
+         z = self.flow_matcher.sample_multimodal(
+             num_samples=num_samples,
+             seq_len=seq_len,
+             num_modes=num_modes,
+             num_steps=num_steps,
+             method=method,
+             device=device,
+         )
+
+         return self.decode(z)
+
+     def generate_continuation(
+         self,
+         prefix_invariants: Tensor,
+         continuation_len: int,
+         num_modes: int = 1,
+         num_steps: int = 10,
+         method: str = "euler",
+     ) -> Tensor:
+         """
+         Generate trajectory continuations from a prefix.
+
+         Args:
+             prefix_invariants: Prefix invariant sequence (B, prefix_len, invariant_dim).
+             continuation_len: Length of continuation to generate.
+             num_modes: Number of continuation modes (for variational).
+             num_steps: ODE integration steps.
+             method: ODE solver method.
+
+         Returns:
+             Continued trajectories (B * num_modes, prefix_len + continuation_len, invariant_dim).
+         """
+         B, prefix_len, _ = prefix_invariants.shape
+         device = prefix_invariants.device
+         total_len = prefix_len + continuation_len
+
+         # Encode prefix to latent
+         z_prefix = self.encode(prefix_invariants)
+
+         if self.is_variational:
+             # Sample continuations with different w
+             z_full = self.flow_matcher.sample_continuation(
+                 z_prefix=z_prefix,
+                 prefix_len=prefix_len,
+                 total_len=total_len,
+                 num_modes=num_modes,
+                 num_steps=num_steps,
+                 method=method,
+             )
+         else:
+             # Generate full sequences and keep prefix + continuation
+             all_z = []
+             for _ in range(num_modes):
+                 z = self.flow_matcher.sample(
+                     num_samples=B,
+                     seq_len=total_len,
+                     num_steps=num_steps,
+                     method=method,
+                     device=device,
+                 )
+                 # Blend: use prefix from input, continuation from generated
+                 z_out = torch.cat([z_prefix, z[:, prefix_len:]], dim=1)
+                 all_z.append(z_out)
+             z_full = torch.cat(all_z, dim=0)
+
+         return self.decode(z_full)
+
+     def interpolate_latent(
+         self,
+         invariants_a: Tensor,
+         invariants_b: Tensor,
+         num_interp: int = 5,
+     ) -> Tensor:
+         """
+         Interpolate between two invariant sequences in latent space.
+
+         Args:
+             invariants_a: First sequence (1, T, invariant_dim).
+             invariants_b: Second sequence (1, T, invariant_dim).
+             num_interp: Number of interpolation steps.
+
+         Returns:
+             Interpolated invariants (num_interp, T, invariant_dim).
+         """
+         z_a = self.encode(invariants_a)
+         z_b = self.encode(invariants_b)
+
+         # Linear interpolation in latent space
+         alphas = torch.linspace(0, 1, num_interp, device=z_a.device)
+         z_interp = []
+         for alpha in alphas:
+             z = (1 - alpha) * z_a + alpha * z_b
+             z_interp.append(z)
+
+         z_interp = torch.cat(z_interp, dim=0)
+         return self.decode(z_interp)
+
+
+ class ConditionalVFMGenerator(VFMTokenGenerator):
+     """
+     VFM generator with explicit goal/context conditioning.
+
+     Extends VFMTokenGenerator with:
+     - Goal pose conditioning
+     - Task embedding conditioning
+     - Context encoder for vision/language inputs
+     """
+
+     def __init__(
+         self,
+         tokenizer: nn.Module,
+         flow_matcher: VariationalFlowMatcher,
+         goal_dim: int = 7,  # position (3) + quaternion (4)
+         context_dim: int = 0,
+         hidden_dim: int = 128,
+     ):
+         """
+         Args:
+             tokenizer: DHBTokenizer or RVQ with encode_continuous/decode_from_latent.
+             flow_matcher: VariationalFlowMatcher.
+             goal_dim: Dimension of goal pose (default: 7 for pos + quat).
+             context_dim: Dimension of external context (e.g., vision embedding).
+             hidden_dim: Hidden dimension for condition encoding.
+         """
+         # Create condition encoder
+         cond_input_dim = goal_dim + context_dim
+         condition_encoder = nn.Sequential(
+             nn.Linear(cond_input_dim, hidden_dim),
+             nn.GELU(),
+             nn.Linear(hidden_dim, flow_matcher.condition_dim),
+         ) if cond_input_dim > 0 else None
+
+         super().__init__(
+             tokenizer=tokenizer,
+             flow_matcher=flow_matcher,
+             condition_encoder=condition_encoder,
+         )
+
+         self.goal_dim = goal_dim
+         self.context_dim = context_dim
+
+     def generate_to_goal(
+         self,
+         goal_pose: Tensor,
+         num_samples: int = 1,
+         seq_len: int = 50,
+         num_steps: int = 10,
+         method: str = "euler",
+         context: Optional[Tensor] = None,
+     ) -> Tensor:
+         """
+         Generate trajectories conditioned on goal pose.
+
+         Args:
+             goal_pose: Goal pose (B, 7) with position and quaternion.
+             num_samples: Samples per goal.
+             seq_len: Sequence length.
+             num_steps: ODE integration steps.
+             method: ODE solver.
+             context: Optional external context.
+
+         Returns:
+             Generated invariants reaching the goal.
+         """
+         B = goal_pose.shape[0]
+         device = goal_pose.device
+
+         # Build condition
+         if context is not None:
+             cond_input = torch.cat([goal_pose, context], dim=-1)
+         else:
+             cond_input = goal_pose
+
+         # Encode condition to w
+         w = self.condition_encoder(cond_input)
+
+         # Expand for num_samples
+         if num_samples > 1:
+             w = w.unsqueeze(1).expand(-1, num_samples, -1)
+             w = w.reshape(B * num_samples, -1)
+
+         # Generate with goal conditioning
+         z = self.flow_matcher.sample(
+             num_samples=B * num_samples,
+             seq_len=seq_len,
+             num_steps=num_steps,
+             method=method,
+             w=w,
+             device=device,
+         )
+
+         return self.decode(z)
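
For orientation, here is a minimal sketch of how the pieces in this file compose. It assumes DHBTokenizer is importable from dhb_xr.tokenization and that the generative classes are re-exported from dhb_xr.generative; the constructor keyword arguments mirror the class docstring above but are otherwise illustrative, not a verified API.

import torch
from dhb_xr.generative import (  # assumed re-export path
    VFMTokenGenerator,
    ConditionalVFMGenerator,
    VariationalFlowMatcher,
)
from dhb_xr.tokenization import DHBTokenizer  # assumed import path

# Illustrative dimensions; the real tokenizer/flow-matcher kwargs may differ.
tokenizer = DHBTokenizer(invariant_dim=8, latent_dim=32, codebook_size=512)
flow_matcher = VariationalFlowMatcher(latent_dim=32, condition_dim=16)
generator = VFMTokenGenerator(tokenizer, flow_matcher)

optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
invariants = torch.randn(16, 50, 8)  # dummy (B, T, invariant_dim) batch

# One joint training step: tokenizer reconstruction + flow matching + KL.
losses = generator.loss(invariants, tokenizer_weight=1.0, flow_weight=1.0, beta=0.01)
optimizer.zero_grad()
losses["total"].backward()
optimizer.step()

# Diverse sampling: 4 modes x 2 samples per mode -> (8, 50, 8).
samples = generator.generate_multimodal(num_samples=2, seq_len=50, num_modes=4)

# Goal-conditioned variant: a (B, 7) pos+quat goal is encoded to the
# conditioning latent w and broadcast across samples.
cond_gen = ConditionalVFMGenerator(tokenizer, flow_matcher, goal_dim=7)
goal = torch.randn(2, 7)
trajs = cond_gen.generate_to_goal(goal, num_samples=3, seq_len=50)  # (6, 50, 8)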
dhb_xr/integration/__init__.py
@@ -0,0 +1,13 @@
+ """Integration utilities for external frameworks."""
+
+ from dhb_xr.integration.vla import (
+     RoboCASAAdapter,
+     LiberoAdapter,
+     DHBVLAPipeline,
+ )
+
+ __all__ = [
+     "RoboCASAAdapter",
+     "LiberoAdapter",
+     "DHBVLAPipeline",
+ ]
dhb_xr/integration/vla/__init__.py
@@ -0,0 +1,11 @@
+ """VLA framework integrations."""
+
+ from dhb_xr.integration.vla.robocasa import RoboCASAAdapter
+ from dhb_xr.integration.vla.libero import LiberoAdapter
+ from dhb_xr.integration.vla.pipeline import DHBVLAPipeline
+
+ __all__ = [
+     "RoboCASAAdapter",
+     "LiberoAdapter",
+     "DHBVLAPipeline",
+ ]
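
Because dhb_xr/integration/__init__.py re-exports these vla symbols, downstream code can import the adapters from either level of the namespace:

from dhb_xr.integration import LiberoAdapter, DHBVLAPipeline
# equivalently:
from dhb_xr.integration.vla import LiberoAdapter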
dhb_xr/integration/vla/libero.py
@@ -0,0 +1,132 @@
+ """Libero dataset adapter (robomimic-style HDF5).
+
+ LIBERO stores end-effector data as:
+ - Position: obs/ee_pos (N, 3)
+ - Orientation: robot_states[:, 5:9] as quaternion (w, x, y, z)
+
+ Note: obs/ee_ori contains Euler angles, not quaternions.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Dict, Iterator, Optional, Tuple
+
+ import numpy as np
+
+ try:
+     import h5py
+     HAS_H5PY = True
+ except ImportError:  # pragma: no cover - optional dependency
+     HAS_H5PY = False
+
+
+ DEFAULT_POS_KEYS = (
+     "robot0_eef_pos",
+     "eef_pos",
+     "ee_pos",
+ )
+ DEFAULT_QUAT_KEYS = (
+     "robot0_eef_quat",
+     "eef_quat",
+     "ee_quat",
+ )
+
+
+ @dataclass
+ class LiberoAdapter:
+     """
+     Minimal Libero adapter that yields (positions, quaternions, metadata).
+
+     Libero datasets are typically robomimic-style HDF5 files.
+
+     LIBERO stores quaternions in robot_states[:, 5:9] as (w, x, y, z),
+     not in the obs group. This adapter handles both cases.
+     """
+
+     pos_keys: Tuple[str, ...] = DEFAULT_POS_KEYS
+     quat_keys: Tuple[str, ...] = DEFAULT_QUAT_KEYS
+     obs_group: str = "obs"
+     # Robot states quaternion extraction (LIBERO-specific)
+     robot_states_key: str = "robot_states"
+     robot_states_quat_slice: Tuple[int, int] = field(default=(5, 9))
+     robot_states_quat_format: str = "wxyz"  # LIBERO uses (w, x, y, z)
+
+     def _find_key(self, group: "h5py.Group", candidates: Tuple[str, ...]) -> Optional[str]:
+         for key in candidates:
+             if key in group:
+                 return key
+         return None
+
+     def _extract_quaternion(
+         self, demo: "h5py.Group", obs: "h5py.Group"
+     ) -> Tuple[Optional[np.ndarray], str]:
+         """
+         Extract quaternions from demo, trying obs group first, then robot_states.
+
+         Returns:
+             (quaternions, source_key) where source_key describes where data came from.
+         """
+         # First try obs group
+         quat_key = self._find_key(obs, self.quat_keys)
+         if quat_key is not None:
+             quats = np.asarray(obs[quat_key], dtype=np.float64)
+             return quats, f"obs/{quat_key}"
+
+         # Fallback: extract from robot_states (LIBERO-specific)
+         if self.robot_states_key in demo:
+             robot_states = np.asarray(demo[self.robot_states_key], dtype=np.float64)
+             start, end = self.robot_states_quat_slice
+             quats = robot_states[:, start:end]
+
+             # Convert from wxyz to xyzw format (standard for scipy/transforms3d)
+             if self.robot_states_quat_format == "wxyz":
+                 quats = quats[:, [1, 2, 3, 0]]  # w,x,y,z -> x,y,z,w
+
+             return quats, f"{self.robot_states_key}[{start}:{end}]"
+
+         return None, ""
+
+     def load_dataset(self, dataset_path: str) -> Iterator[Dict]:
+         if not HAS_H5PY:
+             raise ImportError("h5py is required for Libero adapter (pip install h5py).")
+
+         with h5py.File(dataset_path, "r") as h5:
+             data_group = h5.get("data")
+             if data_group is None:
+                 raise ValueError("Libero HDF5 missing /data group.")
+
+             for demo_id in data_group.keys():
+                 demo = data_group[demo_id]
+                 obs = demo.get(self.obs_group)
+                 if obs is None:
+                     continue
+
+                 # Get position
+                 pos_key = self._find_key(obs, self.pos_keys)
+                 if pos_key is None:
+                     continue
+                 positions = np.asarray(obs[pos_key], dtype=np.float64)
+
+                 # Get quaternion (try obs first, then robot_states)
+                 quaternions, quat_source = self._extract_quaternion(demo, obs)
+                 if quaternions is None:
+                     continue
+
+                 metadata = {
+                     "demo_id": demo_id,
+                     "pos_key": pos_key,
+                     "quat_source": quat_source,
+                     "source": "libero",
+                     "num_frames": len(positions),
+                 }
+                 if "task" in demo.attrs:
+                     metadata["task"] = demo.attrs["task"]
+                 if "language" in demo.attrs:
+                     metadata["language_instruction"] = demo.attrs["language"]
+
+                 yield {
+                     "positions": positions,
+                     "quaternions": quaternions,
+                     "metadata": metadata,
+                 }
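
A minimal consumption sketch for the adapter above, assuming h5py is installed; the dataset path is a placeholder, not a file shipped with the package.

import numpy as np
from dhb_xr.integration.vla import LiberoAdapter

adapter = LiberoAdapter()  # defaults: try obs/* keys, fall back to robot_states[:, 5:9]

# "libero_demo.hdf5" is a hypothetical robomimic-style LIBERO file.
for demo in adapter.load_dataset("libero_demo.hdf5"):
    positions = demo["positions"]      # (N, 3) float64
    quaternions = demo["quaternions"]  # (N, 4), xyzw after the wxyz reindexing
    meta = demo["metadata"]
    print(meta["demo_id"], meta["quat_source"], positions.shape)

# The wxyz -> xyzw reindexing the adapter applies to robot_states quaternions:
wxyz = np.array([[1.0, 0.0, 0.0, 0.0]])  # identity quaternion as (w, x, y, z)
xyzw = wxyz[:, [1, 2, 3, 0]]             # -> [[0.0, 0.0, 0.0, 1.0]]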