dhb_xr-0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0

dhb_xr/generative/vfm_tokenizer.py
@@ -0,0 +1,485 @@
+"""
+VFM Token Generator: End-to-end integration of tokenizer + flow matching.
+
+Provides a unified interface for:
+1. Encoding trajectories to continuous latent space
+2. Training flow matching on latent representations
+3. Generating diverse trajectories via flow matching sampling
+4. Decoding back to invariants and SE(3) poses
+
+Supports both deterministic (FlowMatcher) and variational (VariationalFlowMatcher)
+generation, with optional conditioning on goals, context, or task embeddings.
+"""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from typing import Optional, Dict, List, Tuple, Any, Union
+
+from .flow_matching import FlowMatcher, VariationalFlowMatcher
+from .latent_encoder import LatentEncoder, CategoricalLatentEncoder
+from .sampling import ode_solve
+
+
+class VFMTokenGenerator(nn.Module):
+    """
+    End-to-end variational flow matching generator for DHB invariants.
+
+    Combines:
+    - Tokenizer encoder: invariants -> continuous latent z
+    - Flow matcher: generative model over z
+    - Tokenizer decoder: z -> invariants
+
+    Example:
+        >>> tokenizer = DHBTokenizer(invariant_dim=8, latent_dim=32, codebook_size=512)
+        >>> flow_matcher = VariationalFlowMatcher(latent_dim=32, condition_dim=16)
+        >>> generator = VFMTokenGenerator(tokenizer, flow_matcher)
+        >>>
+        >>> # Train on invariant sequences (loss() returns a dict of terms)
+        >>> losses = generator.loss(invariants)
+        >>> losses["total"].backward()
+        >>>
+        >>> # Generate diverse trajectories
+        >>> samples = generator.generate_multimodal(num_samples=4, seq_len=50)
+    """
+
+    def __init__(
+        self,
+        tokenizer: nn.Module,
+        flow_matcher: Union[FlowMatcher, VariationalFlowMatcher],
+        latent_encoder: Optional[LatentEncoder] = None,
+        condition_encoder: Optional[nn.Module] = None,
+    ):
+        """
+        Args:
+            tokenizer: DHBTokenizer or ResidualVQTokenizer with encode_continuous()
+                and decode_from_latent() methods.
+            flow_matcher: FlowMatcher or VariationalFlowMatcher for generation.
+            latent_encoder: Optional external latent encoder (overrides flow_matcher's
+                internal encoder for VariationalFlowMatcher).
+            condition_encoder: Optional encoder for external conditions (e.g., goal pose,
+                task embedding).
+        """
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.flow_matcher = flow_matcher
+        self.latent_encoder = latent_encoder
+        self.condition_encoder = condition_encoder
+
+        # Check if variational
+        self.is_variational = isinstance(flow_matcher, VariationalFlowMatcher)
+
+        # Get dimensions from tokenizer
+        self.latent_dim = getattr(tokenizer, "latent_dim", None)
+        if self.latent_dim is None:
+            raise ValueError("Tokenizer must have latent_dim attribute")
+
+    def encode(self, invariants: Tensor) -> Tensor:
+        """
+        Encode invariants to continuous latent space.
+
+        Args:
+            invariants: Input invariant sequences (B, T, invariant_dim).
+
+        Returns:
+            Continuous latent z (B, T, latent_dim).
+        """
+        return self.tokenizer.encode_continuous(invariants)
+
+    def decode(self, z: Tensor) -> Tensor:
+        """
+        Decode continuous latent to invariants.
+
+        Args:
+            z: Continuous latent (B, T, latent_dim).
+
+        Returns:
+            Reconstructed invariants (B, T, invariant_dim).
+        """
+        return self.tokenizer.decode_from_latent(z)
+
+    def encode_condition(self, condition: Any) -> Optional[Tensor]:
+        """
+        Encode external condition (if condition_encoder is provided).
+
+        Args:
+            condition: External condition (format depends on encoder).
+
+        Returns:
+            Condition embedding or None.
+        """
+        if self.condition_encoder is None or condition is None:
+            return None
+        return self.condition_encoder(condition)
+
+    def forward(
+        self,
+        invariants: Tensor,
+        condition: Optional[Any] = None,
+    ) -> Dict[str, Tensor]:
+        """
+        Forward pass: encode, flow match, decode.
+
+        Args:
+            invariants: Input invariant sequences (B, T, invariant_dim).
+            condition: Optional external condition.
+
+        Returns:
+            Dictionary with 'z', 'z_reconstructed', 'invariants_reconstructed'.
+        """
+        # Encode to latent
+        z = self.encode(invariants)
+
+        # Encode condition if provided
+        cond = self.encode_condition(condition)
+
+        # Pass through flow matcher (for analysis, not generation)
+        # During training, use loss() instead
+        z_recon = z  # No-op for now
+
+        # Decode back to invariants
+        invariants_recon = self.decode(z)
+
+        return {
+            "z": z,
+            "z_reconstructed": z_recon,
+            "invariants_reconstructed": invariants_recon,
+        }
+
+    def loss(
+        self,
+        invariants: Tensor,
+        condition: Optional[Any] = None,
+        tokenizer_weight: float = 1.0,
+        flow_weight: float = 1.0,
+        beta: float = 0.01,
+    ) -> Dict[str, Tensor]:
+        """
+        Compute joint training loss.
+
+        Combines:
+        - Tokenizer reconstruction loss
+        - Flow matching velocity loss
+        - KL divergence (for variational models)
+
+        Args:
+            invariants: Input invariant sequences (B, T, invariant_dim).
+            condition: Optional external condition.
+            tokenizer_weight: Weight for tokenizer loss.
+            flow_weight: Weight for flow matching loss.
+            beta: KL divergence weight (for variational models).
+
+        Returns:
+            Dictionary with 'total', 'tokenizer', 'flow', and optionally 'kl'.
+        """
+        # Encode to latent
+        z_1 = self.encode(invariants)  # "data" distribution
+
+        # Tokenizer loss (reconstruction through VQ)
+        indices, reconstructed, z, z_q = self.tokenizer(invariants)
+        tokenizer_loss = self.tokenizer.loss(invariants, reconstructed, z, z_q)
+
+        # Flow matching loss
+        if self.is_variational:
+            flow_losses = self.flow_matcher.loss(z_1, beta=beta)
+            flow_loss = flow_losses["total"]
+            kl_loss = flow_losses["kl"]
+        else:
+            flow_loss = self.flow_matcher.loss(z_1)
+            kl_loss = torch.tensor(0.0, device=invariants.device)
+
+        # Total loss
+        total_loss = tokenizer_weight * tokenizer_loss + flow_weight * flow_loss
+
+        return {
+            "total": total_loss,
+            "tokenizer": tokenizer_loss,
+            "flow": flow_loss,
+            "kl": kl_loss,
+        }
+
+    def generate(
+        self,
+        num_samples: int = 1,
+        seq_len: int = 50,
+        num_steps: int = 10,
+        method: str = "euler",
+        condition: Optional[Any] = None,
+        w: Optional[Tensor] = None,
+        device: str = "cpu",
+        return_latent: bool = False,
+    ) -> Tensor | Tuple[Tensor, Tensor]:
+        """
+        Generate invariant sequences via flow matching.
+
+        Args:
+            num_samples: Number of samples to generate.
+            seq_len: Sequence length.
+            num_steps: ODE integration steps.
+            method: ODE solver ('euler' or 'rk4').
+            condition: Optional external condition.
+            w: Optional conditioning latent for variational models.
+            device: Device for computation.
+            return_latent: If True, also return the generated latent z.
+
+        Returns:
+            Generated invariants (num_samples, seq_len, invariant_dim),
+            optionally with latent z.
+        """
+        # Generate latent via flow matching
+        if self.is_variational:
+            z = self.flow_matcher.sample(
+                num_samples=num_samples,
+                seq_len=seq_len,
+                num_steps=num_steps,
+                method=method,
+                w=w,
+                device=device,
+            )
+        else:
+            z = self.flow_matcher.sample(
+                num_samples=num_samples,
+                seq_len=seq_len,
+                num_steps=num_steps,
+                method=method,
+                device=device,
+            )
+
+        # Decode to invariants
+        invariants = self.decode(z)
+
+        if return_latent:
+            return invariants, z
+        return invariants
+
+    def generate_multimodal(
+        self,
+        num_samples: int = 1,
+        seq_len: int = 50,
+        num_modes: int = 4,
+        num_steps: int = 10,
+        method: str = "euler",
+        condition: Optional[Any] = None,
+        device: str = "cpu",
+    ) -> Tensor:
+        """
+        Generate diverse samples by sampling different conditioning latents.
+
+        Args:
+            num_samples: Number of samples per mode.
+            seq_len: Sequence length.
+            num_modes: Number of distinct modes to sample.
+            num_steps: ODE integration steps.
+            method: ODE solver method.
+            condition: Optional external condition.
+            device: Device for computation.
+
+        Returns:
+            Generated invariants (num_modes * num_samples, seq_len, invariant_dim).
+        """
+        if not self.is_variational:
+            # For non-variational, just generate multiple samples
+            return self.generate(
+                num_samples=num_modes * num_samples,
+                seq_len=seq_len,
+                num_steps=num_steps,
+                method=method,
+                device=device,
+            )
+
+        # Generate with different w for each mode
+        z = self.flow_matcher.sample_multimodal(
+            num_samples=num_samples,
+            seq_len=seq_len,
+            num_modes=num_modes,
+            num_steps=num_steps,
+            method=method,
+            device=device,
+        )
+
+        return self.decode(z)
+
+    def generate_continuation(
+        self,
+        prefix_invariants: Tensor,
+        continuation_len: int,
+        num_modes: int = 1,
+        num_steps: int = 10,
+        method: str = "euler",
+    ) -> Tensor:
+        """
+        Generate trajectory continuations from a prefix.
+
+        Args:
+            prefix_invariants: Prefix invariant sequence (B, prefix_len, invariant_dim).
+            continuation_len: Length of continuation to generate.
+            num_modes: Number of continuation modes (for variational).
+            num_steps: ODE integration steps.
+            method: ODE solver method.
+
+        Returns:
+            Continued trajectories (B * num_modes, prefix_len + continuation_len, invariant_dim).
+        """
+        B, prefix_len, _ = prefix_invariants.shape
+        device = prefix_invariants.device
+        total_len = prefix_len + continuation_len
+
+        # Encode prefix to latent
+        z_prefix = self.encode(prefix_invariants)
+
+        if self.is_variational:
+            # Sample continuations with different w
+            z_full = self.flow_matcher.sample_continuation(
+                z_prefix=z_prefix,
+                prefix_len=prefix_len,
+                total_len=total_len,
+                num_modes=num_modes,
+                num_steps=num_steps,
+                method=method,
+            )
+        else:
+            # Generate full sequences and keep prefix + continuation
+            all_z = []
+            for _ in range(num_modes):
+                z = self.flow_matcher.sample(
+                    num_samples=B,
+                    seq_len=total_len,
+                    num_steps=num_steps,
+                    method=method,
+                    device=device,
+                )
+                # Blend: use prefix from input, continuation from generated
+                z_out = torch.cat([z_prefix, z[:, prefix_len:]], dim=1)
+                all_z.append(z_out)
+            z_full = torch.cat(all_z, dim=0)
+
+        return self.decode(z_full)
+
+    def interpolate_latent(
+        self,
+        invariants_a: Tensor,
+        invariants_b: Tensor,
+        num_interp: int = 5,
+    ) -> Tensor:
+        """
+        Interpolate between two invariant sequences in latent space.
+
+        Args:
+            invariants_a: First sequence (1, T, invariant_dim).
+            invariants_b: Second sequence (1, T, invariant_dim).
+            num_interp: Number of interpolation steps.
+
+        Returns:
+            Interpolated invariants (num_interp, T, invariant_dim).
+        """
+        z_a = self.encode(invariants_a)
+        z_b = self.encode(invariants_b)
+
+        # Linear interpolation in latent space
+        alphas = torch.linspace(0, 1, num_interp, device=z_a.device)
+        z_interp = []
+        for alpha in alphas:
+            z = (1 - alpha) * z_a + alpha * z_b
+            z_interp.append(z)
+
+        z_interp = torch.cat(z_interp, dim=0)
+        return self.decode(z_interp)
+
+
+class ConditionalVFMGenerator(VFMTokenGenerator):
+    """
+    VFM generator with explicit goal/context conditioning.
+
+    Extends VFMTokenGenerator with:
+    - Goal pose conditioning
+    - Task embedding conditioning
+    - Context encoder for vision/language inputs
+    """
+
+    def __init__(
+        self,
+        tokenizer: nn.Module,
+        flow_matcher: VariationalFlowMatcher,
+        goal_dim: int = 7,  # position (3) + quaternion (4)
+        context_dim: int = 0,
+        hidden_dim: int = 128,
+    ):
+        """
+        Args:
+            tokenizer: DHBTokenizer or RVQ with encode_continuous/decode_from_latent.
+            flow_matcher: VariationalFlowMatcher.
+            goal_dim: Dimension of goal pose (default: 7 for pos + quat).
+            context_dim: Dimension of external context (e.g., vision embedding).
+            hidden_dim: Hidden dimension for condition encoding.
+        """
+        # Create condition encoder
+        cond_input_dim = goal_dim + context_dim
+        condition_encoder = nn.Sequential(
+            nn.Linear(cond_input_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, flow_matcher.condition_dim),
+        ) if cond_input_dim > 0 else None
+
+        super().__init__(
+            tokenizer=tokenizer,
+            flow_matcher=flow_matcher,
+            condition_encoder=condition_encoder,
+        )
+
+        self.goal_dim = goal_dim
+        self.context_dim = context_dim
+
+    def generate_to_goal(
+        self,
+        goal_pose: Tensor,
+        num_samples: int = 1,
+        seq_len: int = 50,
+        num_steps: int = 10,
+        method: str = "euler",
+        context: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Generate trajectories conditioned on goal pose.
+
+        Args:
+            goal_pose: Goal pose (B, 7) with position and quaternion.
+            num_samples: Samples per goal.
+            seq_len: Sequence length.
+            num_steps: ODE integration steps.
+            method: ODE solver.
+            context: Optional external context.
+
+        Returns:
+            Generated invariants reaching the goal.
+        """
+        B = goal_pose.shape[0]
+        device = goal_pose.device
+
+        # Build condition
+        if context is not None:
+            cond_input = torch.cat([goal_pose, context], dim=-1)
+        else:
+            cond_input = goal_pose
+
+        # Encode condition to w
+        w = self.condition_encoder(cond_input)
+
+        # Expand for num_samples
+        if num_samples > 1:
+            w = w.unsqueeze(1).expand(-1, num_samples, -1)
+            w = w.reshape(B * num_samples, -1)
+
+        # Generate with goal conditioning
+        z = self.flow_matcher.sample(
+            num_samples=B * num_samples,
+            seq_len=seq_len,
+            num_steps=num_steps,
+            method=method,
+            w=w,
+            device=device,
+        )
+
+        return self.decode(z)
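
A minimal end-to-end sketch of the module above, assuming the constructor signatures shown in the class docstring; only the dhb_xr.generative import paths are confirmed by this diff, and the DHBTokenizer import path is a guess based on the wheel layout:

    import torch
    from dhb_xr.generative.vfm_tokenizer import VFMTokenGenerator
    from dhb_xr.generative.flow_matching import VariationalFlowMatcher
    from dhb_xr.tokenization.vqvae import DHBTokenizer  # hypothetical import path

    tokenizer = DHBTokenizer(invariant_dim=8, latent_dim=32, codebook_size=512)
    flow_matcher = VariationalFlowMatcher(latent_dim=32, condition_dim=16)
    generator = VFMTokenGenerator(tokenizer, flow_matcher)

    invariants = torch.randn(16, 50, 8)  # dummy (B, T, invariant_dim) batch
    losses = generator.loss(invariants)  # dict with 'total', 'tokenizer', 'flow', 'kl'
    losses["total"].backward()

    # Draw 4 modes x 4 samples of length-50 invariant sequences
    samples = generator.generate_multimodal(num_samples=4, seq_len=50, num_modes=4)

Note that loss() returns a dictionary of terms, so training code backpropagates through losses["total"] rather than the return value itself.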

dhb_xr/integration/vla/__init__.py
@@ -0,0 +1,11 @@
+"""VLA framework integrations."""
+
+from dhb_xr.integration.vla.robocasa import RoboCASAAdapter
+from dhb_xr.integration.vla.libero import LiberoAdapter
+from dhb_xr.integration.vla.pipeline import DHBVLAPipeline
+
+__all__ = [
+    "RoboCASAAdapter",
+    "LiberoAdapter",
+    "DHBVLAPipeline",
+]
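
These re-exports let downstream code import the adapters from the subpackage root rather than the concrete modules, e.g. `from dhb_xr.integration.vla import LiberoAdapter, DHBVLAPipeline`.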

dhb_xr/integration/vla/libero.py
@@ -0,0 +1,132 @@
+"""Libero dataset adapter (robomimic-style HDF5).
+
+LIBERO stores end-effector data as:
+- Position: obs/ee_pos (N, 3)
+- Orientation: robot_states[:, 5:9] as quaternion (w, x, y, z)
+
+Note: obs/ee_ori contains Euler angles, not quaternions.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Dict, Iterator, Optional, Tuple
+
+import numpy as np
+
+try:
+    import h5py
+    HAS_H5PY = True
+except ImportError:  # pragma: no cover - optional dependency
+    HAS_H5PY = False
+
+
+DEFAULT_POS_KEYS = (
+    "robot0_eef_pos",
+    "eef_pos",
+    "ee_pos",
+)
+DEFAULT_QUAT_KEYS = (
+    "robot0_eef_quat",
+    "eef_quat",
+    "ee_quat",
+)
+
+
+@dataclass
+class LiberoAdapter:
+    """
+    Minimal Libero adapter that yields (positions, quaternions, metadata).
+
+    Libero datasets are typically robomimic-style HDF5 files.
+
+    LIBERO stores quaternions in robot_states[:, 5:9] as (w, x, y, z),
+    not in the obs group. This adapter handles both cases.
+    """
+
+    pos_keys: Tuple[str, ...] = DEFAULT_POS_KEYS
+    quat_keys: Tuple[str, ...] = DEFAULT_QUAT_KEYS
+    obs_group: str = "obs"
+    # Robot states quaternion extraction (LIBERO-specific)
+    robot_states_key: str = "robot_states"
+    robot_states_quat_slice: Tuple[int, int] = field(default=(5, 9))
+    robot_states_quat_format: str = "wxyz"  # LIBERO uses (w, x, y, z)
+
+    def _find_key(self, group: "h5py.Group", candidates: Tuple[str, ...]) -> Optional[str]:
+        for key in candidates:
+            if key in group:
+                return key
+        return None
+
+    def _extract_quaternion(
+        self, demo: "h5py.Group", obs: "h5py.Group"
+    ) -> Tuple[Optional[np.ndarray], str]:
+        """
+        Extract quaternions from demo, trying obs group first, then robot_states.
+
+        Returns:
+            (quaternions, source_key) where source_key describes where data came from.
+        """
+        # First try obs group
+        quat_key = self._find_key(obs, self.quat_keys)
+        if quat_key is not None:
+            quats = np.asarray(obs[quat_key], dtype=np.float64)
+            return quats, f"obs/{quat_key}"
+
+        # Fallback: extract from robot_states (LIBERO-specific)
+        if self.robot_states_key in demo:
+            robot_states = np.asarray(demo[self.robot_states_key], dtype=np.float64)
+            start, end = self.robot_states_quat_slice
+            quats = robot_states[:, start:end]
+
+            # Convert from wxyz to xyzw format (standard for scipy/transforms3d)
+            if self.robot_states_quat_format == "wxyz":
+                quats = quats[:, [1, 2, 3, 0]]  # w,x,y,z -> x,y,z,w
+
+            return quats, f"{self.robot_states_key}[{start}:{end}]"
+
+        return None, ""
+
+    def load_dataset(self, dataset_path: str) -> Iterator[Dict]:
+        if not HAS_H5PY:
+            raise ImportError("h5py is required for Libero adapter (pip install h5py).")
+
+        with h5py.File(dataset_path, "r") as h5:
+            data_group = h5.get("data")
+            if data_group is None:
+                raise ValueError("Libero HDF5 missing /data group.")
+
+            for demo_id in data_group.keys():
+                demo = data_group[demo_id]
+                obs = demo.get(self.obs_group)
+                if obs is None:
+                    continue
+
+                # Get position
+                pos_key = self._find_key(obs, self.pos_keys)
+                if pos_key is None:
+                    continue
+                positions = np.asarray(obs[pos_key], dtype=np.float64)
+
+                # Get quaternion (try obs first, then robot_states)
+                quaternions, quat_source = self._extract_quaternion(demo, obs)
+                if quaternions is None:
+                    continue
+
+                metadata = {
+                    "demo_id": demo_id,
+                    "pos_key": pos_key,
+                    "quat_source": quat_source,
+                    "source": "libero",
+                    "num_frames": len(positions),
+                }
+                if "task" in demo.attrs:
+                    metadata["task"] = demo.attrs["task"]
+                if "language" in demo.attrs:
+                    metadata["language_instruction"] = demo.attrs["language"]
+
+                yield {
+                    "positions": positions,
+                    "quaternions": quaternions,
+                    "metadata": metadata,
+                }
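
A short consumption sketch for the adapter; the HDF5 path is a placeholder, and everything else follows from the yield structure of load_dataset above:

    from dhb_xr.integration.vla import LiberoAdapter

    adapter = LiberoAdapter()
    for demo in adapter.load_dataset("path/to/libero_demos.hdf5"):  # placeholder path
        positions = demo["positions"]      # (N, 3) float64 end-effector positions
        quaternions = demo["quaternions"]  # (N, 4), xyzw order after conversion
        meta = demo["metadata"]
        print(meta["demo_id"], meta["num_frames"], meta["quat_source"])

Because quaternions pulled from robot_states are converted from wxyz to xyzw, consumers see a single quaternion convention regardless of which source key was found.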