dhb-xr 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
@@ -0,0 +1,536 @@
1
+ """
2
+ Latent encoders for variational flow matching.
3
+
4
+ Provides inference networks q(w | z_t, t) for the variational formulation:
5
+ - LatentEncoder: Gaussian posterior (continuous w)
6
+ - CategoricalLatentEncoder: Categorical posterior with Gumbel-Softmax (discrete w)
7
+
8
+ The latent w captures "intent" or "mode" information that helps the velocity
9
+ network distinguish between multiple valid trajectory continuations.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import math
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch import Tensor
19
+ from typing import Optional, Tuple, Dict
20
+
21
+ from .flow_matching import SinusoidalTimeEmbedding
22
+
23
+
24
+ class LatentEncoder(nn.Module):
25
+ """
26
+ Gaussian latent encoder for variational flow matching.
27
+
28
+ Maps (z_t, t) to a Gaussian distribution over the conditioning latent w.
29
+ Uses an MLP with optional temporal attention for sequence inputs.
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ input_dim: int,
35
+ latent_dim: int,
36
+ hidden_dim: int = 256,
37
+ num_layers: int = 2,
38
+ time_embed_dim: int = 64,
39
+ use_attention: bool = False,
40
+ num_heads: int = 4,
41
+ ):
42
+ """
43
+ Args:
44
+ input_dim: Dimension of input z_t.
45
+ latent_dim: Dimension of output latent w.
46
+ hidden_dim: Hidden layer dimension.
47
+ num_layers: Number of hidden layers.
48
+ time_embed_dim: Dimension of time embedding.
49
+ use_attention: Whether to use attention for sequence pooling.
50
+ num_heads: Number of attention heads (if use_attention=True).
51
+ """
52
+ super().__init__()
53
+ self.input_dim = input_dim
54
+ self.latent_dim = latent_dim
55
+ self.use_attention = use_attention
56
+
57
+ # Time embedding
58
+ self.time_embed = SinusoidalTimeEmbedding(time_embed_dim)
59
+
60
+ # Optional attention for sequence pooling
61
+ if use_attention:
62
+ self.attention = nn.MultiheadAttention(
63
+ embed_dim=input_dim,
64
+ num_heads=num_heads,
65
+ batch_first=True,
66
+ )
67
+ self.pool_query = nn.Parameter(torch.randn(1, 1, input_dim))
68
+
69
+ # MLP encoder
70
+ layers = []
71
+ in_dim = input_dim + time_embed_dim
72
+ for i in range(num_layers):
73
+ out_dim = hidden_dim if i < num_layers - 1 else hidden_dim
74
+ layers.extend([
75
+ nn.Linear(in_dim, out_dim),
76
+ nn.LayerNorm(out_dim),
77
+ nn.GELU(),
78
+ ])
79
+ in_dim = out_dim
80
+
81
+ self.encoder = nn.Sequential(*layers)
82
+
83
+ # Output heads for mean and log_std
84
+ self.mean_head = nn.Linear(hidden_dim, latent_dim)
85
+ self.log_std_head = nn.Linear(hidden_dim, latent_dim)
86
+
87
+ def pool_sequence(self, z_t: Tensor) -> Tensor:
88
+ """Pool sequence dimension to get batch-level representation."""
89
+ if z_t.dim() == 2:
90
+ return z_t # Already (B, D)
91
+
92
+ if self.use_attention:
93
+ # Use attention pooling
94
+ B, T, D = z_t.shape
95
+ query = self.pool_query.expand(B, -1, -1)
96
+ pooled, _ = self.attention(query, z_t, z_t)
97
+ return pooled.squeeze(1) # (B, D)
98
+ else:
99
+ # Simple mean pooling
100
+ return z_t.mean(dim=1)
101
+
102
+ def forward(
103
+ self,
104
+ z_t: Tensor,
105
+ t: Tensor | float,
106
+ ) -> Tuple[Tensor, Tensor]:
107
+ """
108
+ Encode to posterior parameters.
109
+
110
+ Args:
111
+ z_t: Input state (B, T, D) or (B, D).
112
+ t: Time (B,) or scalar in [0, 1].
113
+
114
+ Returns:
115
+ Tuple of (mean, log_std) for q(w | z_t, t).
116
+ """
117
+ # Handle scalar time
118
+ if isinstance(t, float):
119
+ t = torch.tensor([t], device=z_t.device, dtype=z_t.dtype)
120
+ if t.dim() == 0:
121
+ t = t.unsqueeze(0)
122
+
123
+ # Pool sequence
124
+ z_pooled = self.pool_sequence(z_t)
125
+
126
+ # Time embedding
127
+ t_embed = self.time_embed(t)
128
+
129
+ # Expand t_embed to batch size if needed
130
+ if t_embed.shape[0] == 1 and z_pooled.shape[0] > 1:
131
+ t_embed = t_embed.expand(z_pooled.shape[0], -1)
132
+
133
+ # Concatenate and encode
134
+ x = torch.cat([z_pooled, t_embed], dim=-1)
135
+ h = self.encoder(x)
136
+
137
+ # Output parameters
138
+ mean = self.mean_head(h)
139
+ log_std = self.log_std_head(h)
140
+ log_std = torch.clamp(log_std, min=-10, max=2)
141
+
142
+ return mean, log_std
143
+
144
+ def sample(
145
+ self,
146
+ z_t: Tensor,
147
+ t: Tensor | float,
148
+ num_samples: int = 1,
149
+ ) -> Tuple[Tensor, Tensor, Tensor]:
150
+ """
151
+ Sample from posterior using reparameterization trick.
152
+
153
+ Args:
154
+ z_t: Input state.
155
+ t: Time.
156
+ num_samples: Number of samples per input.
157
+
158
+ Returns:
159
+ Tuple of (w, mean, log_std) where w is (B, num_samples, latent_dim)
160
+ or (B, latent_dim) if num_samples=1.
161
+ """
162
+ mean, log_std = self.forward(z_t, t)
163
+ std = torch.exp(log_std)
164
+
165
+ if num_samples == 1:
166
+ eps = torch.randn_like(mean)
167
+ w = mean + std * eps
168
+ else:
169
+ # (B, num_samples, latent_dim)
170
+ eps = torch.randn(mean.shape[0], num_samples, mean.shape[1], device=mean.device)
171
+ w = mean.unsqueeze(1) + std.unsqueeze(1) * eps
172
+
173
+ return w, mean, log_std
174
+
175
+ def kl_divergence(
176
+ self,
177
+ mean: Tensor,
178
+ log_std: Tensor,
179
+ prior_mean: float = 0.0,
180
+ prior_std: float = 1.0,
181
+ ) -> Tensor:
182
+ """
183
+ Compute KL divergence from posterior to prior.
184
+
185
+ KL(N(mean, std^2) || N(prior_mean, prior_std^2))
186
+ """
187
+ std = torch.exp(log_std)
188
+ prior_var = prior_std ** 2
189
+
190
+ kl = 0.5 * (
191
+ (std ** 2 + (mean - prior_mean) ** 2) / prior_var
192
+ - 1
193
+ - 2 * log_std
194
+ + 2 * math.log(prior_std)
195
+ )
196
+
197
+ return kl.sum(dim=-1).mean()
198
+
199
+
200
+ class CategoricalLatentEncoder(nn.Module):
201
+ """
202
+ Categorical latent encoder with Gumbel-Softmax for discrete modes.
203
+
204
+ Useful when the multi-modality has a clear categorical structure
205
+ (e.g., left vs. right grasp, different skill types).
206
+
207
+ Uses Gumbel-Softmax for differentiable sampling during training
208
+ and argmax during inference.
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ input_dim: int,
214
+ num_categories: int,
215
+ embedding_dim: int = 16,
216
+ hidden_dim: int = 256,
217
+ num_layers: int = 2,
218
+ time_embed_dim: int = 64,
219
+ temperature: float = 1.0,
220
+ ):
221
+ """
222
+ Args:
223
+ input_dim: Dimension of input z_t.
224
+ num_categories: Number of discrete categories (modes).
225
+ embedding_dim: Dimension of category embeddings.
226
+ hidden_dim: Hidden layer dimension.
227
+ num_layers: Number of hidden layers.
228
+ time_embed_dim: Dimension of time embedding.
229
+ temperature: Gumbel-Softmax temperature.
230
+ """
231
+ super().__init__()
232
+ self.input_dim = input_dim
233
+ self.num_categories = num_categories
234
+ self.embedding_dim = embedding_dim
235
+ self.temperature = temperature
236
+
237
+ # Time embedding
238
+ self.time_embed = SinusoidalTimeEmbedding(time_embed_dim)
239
+
240
+ # MLP encoder
241
+ layers = []
242
+ in_dim = input_dim + time_embed_dim
243
+ for i in range(num_layers):
244
+ out_dim = hidden_dim
245
+ layers.extend([
246
+ nn.Linear(in_dim, out_dim),
247
+ nn.LayerNorm(out_dim),
248
+ nn.GELU(),
249
+ ])
250
+ in_dim = out_dim
251
+
252
+ self.encoder = nn.Sequential(*layers)
253
+
254
+ # Logits head
255
+ self.logits_head = nn.Linear(hidden_dim, num_categories)
256
+
257
+ # Category embeddings (for converting one-hot to continuous)
258
+ self.category_embeddings = nn.Embedding(num_categories, embedding_dim)
259
+
260
+ @property
261
+ def latent_dim(self) -> int:
262
+ """Return the effective latent dimension (embedding_dim)."""
263
+ return self.embedding_dim
264
+
265
+ def forward(
266
+ self,
267
+ z_t: Tensor,
268
+ t: Tensor | float,
269
+ ) -> Tensor:
270
+ """
271
+ Compute category logits.
272
+
273
+ Args:
274
+ z_t: Input state (B, T, D) or (B, D).
275
+ t: Time (B,) or scalar.
276
+
277
+ Returns:
278
+ Logits (B, num_categories).
279
+ """
280
+ # Handle scalar time
281
+ if isinstance(t, float):
282
+ t = torch.tensor([t], device=z_t.device, dtype=z_t.dtype)
283
+ if t.dim() == 0:
284
+ t = t.unsqueeze(0)
285
+
286
+ # Pool sequence
287
+ if z_t.dim() == 3:
288
+ z_pooled = z_t.mean(dim=1)
289
+ else:
290
+ z_pooled = z_t
291
+
292
+ # Time embedding
293
+ t_embed = self.time_embed(t)
294
+ if t_embed.shape[0] == 1 and z_pooled.shape[0] > 1:
295
+ t_embed = t_embed.expand(z_pooled.shape[0], -1)
296
+
297
+ # Encode
298
+ x = torch.cat([z_pooled, t_embed], dim=-1)
299
+ h = self.encoder(x)
300
+
301
+ # Logits
302
+ logits = self.logits_head(h)
303
+
304
+ return logits
305
+
306
+ def sample(
307
+ self,
308
+ z_t: Tensor,
309
+ t: Tensor | float,
310
+ hard: bool = False,
311
+ temperature: Optional[float] = None,
312
+ ) -> Tuple[Tensor, Tensor, Tensor]:
313
+ """
314
+ Sample using Gumbel-Softmax.
315
+
316
+ Args:
317
+ z_t: Input state.
318
+ t: Time.
319
+ hard: If True, use straight-through estimator for hard samples.
320
+ temperature: Override default temperature.
321
+
322
+ Returns:
323
+ Tuple of (w, probs, logits) where:
324
+ - w: Category embedding (B, embedding_dim)
325
+ - probs: Category probabilities (B, num_categories)
326
+ - logits: Raw logits (B, num_categories)
327
+ """
328
+ if temperature is None:
329
+ temperature = self.temperature
330
+
331
+ logits = self.forward(z_t, t)
332
+ probs = F.softmax(logits, dim=-1)
333
+
334
+ # Gumbel-Softmax sampling
335
+ if self.training:
336
+ gumbel_probs = F.gumbel_softmax(logits, tau=temperature, hard=hard)
337
+ else:
338
+ # During inference, use argmax
339
+ if hard:
340
+ indices = logits.argmax(dim=-1)
341
+ gumbel_probs = F.one_hot(indices, self.num_categories).float()
342
+ else:
343
+ gumbel_probs = probs
344
+
345
+ # Convert to embedding
346
+ # gumbel_probs: (B, num_categories)
347
+ # category_embeddings: (num_categories, embedding_dim)
348
+ w = gumbel_probs @ self.category_embeddings.weight
349
+
350
+ return w, probs, logits
351
+
352
+ def sample_prior(
353
+ self,
354
+ batch_size: int,
355
+ device: str = "cpu",
356
+ hard: bool = True,
357
+ ) -> Tensor:
358
+ """
359
+ Sample from uniform categorical prior.
360
+
361
+ Args:
362
+ batch_size: Number of samples.
363
+ device: Device for tensor.
364
+ hard: If True, return embeddings for hard categories.
365
+
366
+ Returns:
367
+ Category embeddings (batch_size, embedding_dim).
368
+ """
369
+ if hard:
370
+ # Sample uniform categories
371
+ indices = torch.randint(
372
+ 0, self.num_categories, (batch_size,), device=device
373
+ )
374
+ w = self.category_embeddings(indices)
375
+ else:
376
+ # Soft uniform mixture
377
+ probs = torch.ones(batch_size, self.num_categories, device=device)
378
+ probs = probs / self.num_categories
379
+ w = probs @ self.category_embeddings.weight
380
+
381
+ return w
382
+
383
+ def kl_divergence(
384
+ self,
385
+ probs: Tensor,
386
+ prior_probs: Optional[Tensor] = None,
387
+ ) -> Tensor:
388
+ """
389
+ Compute KL divergence from posterior to prior.
390
+
391
+ If prior_probs is None, uses uniform prior.
392
+ """
393
+ if prior_probs is None:
394
+ # Uniform prior
395
+ prior_probs = torch.ones_like(probs) / self.num_categories
396
+
397
+ # KL = sum_k p_k * log(p_k / q_k)
398
+ kl = probs * (torch.log(probs + 1e-10) - torch.log(prior_probs + 1e-10))
399
+
400
+ return kl.sum(dim=-1).mean()
401
+
402
+ def entropy(self, probs: Tensor) -> Tensor:
403
+ """Compute entropy of the categorical distribution."""
404
+ return -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean()
405
+
406
+
407
class HybridLatentEncoder(nn.Module):
    """
    Hybrid encoder combining categorical and continuous latents.

    Represents discrete mode selection (categorical branch) together with
    continuous variation within each mode (Gaussian branch); a linear
    layer fuses the two samples into a single latent of size continuous_dim.
    """

    def __init__(
        self,
        input_dim: int,
        num_categories: int,
        continuous_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 2,
        time_embed_dim: int = 64,
        temperature: float = 1.0,
    ):
        """
        Args:
            input_dim: Dimension of input z_t.
            num_categories: Number of discrete modes.
            continuous_dim: Dimension of continuous latent per mode.
            hidden_dim: Hidden layer dimension.
            num_layers: Number of hidden layers.
            time_embed_dim: Dimension of time embedding.
            temperature: Gumbel-Softmax temperature.
        """
        super().__init__()
        self.num_categories = num_categories
        self.continuous_dim = continuous_dim

        # Discrete-mode branch (Gumbel-Softmax over num_categories).
        self.categorical = CategoricalLatentEncoder(
            input_dim=input_dim,
            num_categories=num_categories,
            embedding_dim=hidden_dim // 4,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            time_embed_dim=time_embed_dim,
            temperature=temperature,
        )

        # Continuous Gaussian branch, shared across modes.
        self.continuous = LatentEncoder(
            input_dim=input_dim,
            latent_dim=continuous_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            time_embed_dim=time_embed_dim,
        )

        # Fuses [category embedding ; continuous sample] -> continuous_dim.
        self.combine = nn.Linear(
            self.categorical.embedding_dim + continuous_dim,
            continuous_dim,
        )

    @property
    def latent_dim(self) -> int:
        """Dimension of the fused latent returned by `sample`."""
        return self.continuous_dim

    def forward(
        self,
        z_t: Tensor,
        t: Tensor | float,
    ) -> Dict[str, Tensor]:
        """
        Encode to posterior parameters of both branches.

        Returns a dictionary with the categorical logits/probs and the
        Gaussian mean/log-std of the continuous branch.
        """
        logits = self.categorical(z_t, t)
        mu, log_sigma = self.continuous(z_t, t)

        return {
            "categorical_logits": logits,
            "categorical_probs": F.softmax(logits, dim=-1),
            "continuous_mean": mu,
            "continuous_log_std": log_sigma,
        }

    def sample(
        self,
        z_t: Tensor,
        t: Tensor | float,
        hard: bool = False,
        temperature: Optional[float] = None,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """
        Sample a fused latent.

        Returns:
            Tuple of (w, info_dict) where w is (B, latent_dim) and
            info_dict carries the per-branch samples and parameters.
        """
        # Per-branch samples.
        cat_sample, cat_probs, cat_logits = self.categorical.sample(
            z_t, t, hard=hard, temperature=temperature
        )
        cont_sample, mu, log_sigma = self.continuous.sample(z_t, t)

        # Fuse the two samples with the linear combiner.
        fused = self.combine(torch.cat([cat_sample, cont_sample], dim=-1))

        info = {
            "categorical_w": cat_sample,
            "categorical_probs": cat_probs,
            "categorical_logits": cat_logits,
            "continuous_w": cont_sample,
            "continuous_mean": mu,
            "continuous_log_std": log_sigma,
        }

        return fused, info

    def kl_divergence(self, info: Dict[str, Tensor]) -> Tensor:
        """Total KL: categorical term plus continuous Gaussian term."""
        return self.categorical.kl_divergence(
            info["categorical_probs"]
        ) + self.continuous.kl_divergence(
            info["continuous_mean"], info["continuous_log_std"]
        )