dhb_xr-0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
dhb_xr/tokenization/hierarchical.py
@@ -0,0 +1,359 @@
+ """
+ Hierarchical tokenization for variable-rate DHB compression.
+ 
+ Extends RVQ with multi-level structure for lossy compression:
+ - Coarse tokens: Global trajectory structure
+ - Fine tokens: Local details and refinements
+ - Configurable depth for compression/quality tradeoff
+ 
+ Inspired by:
+ - BEAST (NeurIPS 2025): B-spline encoded action sequences
+ - VQ-VLA: Multi-level vector quantization
+ """
+ 
+ import numpy as np  # used by get_compression_stats below
+ 
+ try:
+     import torch
+     import torch.nn as nn
+     import torch.nn.functional as F
+     from dhb_xr.tokenization.vqvae import VectorQuantizer
+     from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder
+     HAS_TORCH = True
+ except ImportError:
+     HAS_TORCH = False
+ 
+ 
+ if HAS_TORCH:
+ 
+     class HierarchicalTokenizer(nn.Module):
+         """
+         Hierarchical RVQ with variable-rate output.
+ 
+         Provides coarse-to-fine tokenization:
+         - Level 0: Low-frequency global structure (high compression)
+         - Level 1-N: Residual details (configurable refinement)
+ 
+         For inference, can truncate to fewer levels for faster/coarser output.
+ 
+         Example:
+             >>> tokenizer = HierarchicalTokenizer(
+             ...     invariant_dim=8, latent_dim=32, codebook_size=256, num_levels=4
+             ... )
+             >>> tokens, recon = tokenizer(invariants)
+             >>>
+             >>> # Coarse only (4x fewer tokens)
+             >>> tokens_coarse, recon_coarse = tokenizer(invariants, max_level=1)
+         """
+ 
+         def __init__(
+             self,
+             invariant_dim: int,
+             latent_dim: int,
+             codebook_size: int,
+             num_levels: int = 4,
+             temporal_downsample: int = 2,
+             num_layers: int = 2,
+         ):
+             """
+             Args:
+                 invariant_dim: DHB invariant dimension (typically 8)
+                 latent_dim: Latent embedding dimension
+                 codebook_size: VQ codebook size per level
+                 num_levels: Number of hierarchy levels
+                 temporal_downsample: Downsample factor between levels
+                 num_layers: Conv layers per encoder/decoder
+             """
+             super().__init__()
+ 
+             self.invariant_dim = invariant_dim
+             self.latent_dim = latent_dim
+             self.codebook_size = codebook_size
+             self.num_levels = num_levels
+             self.temporal_downsample = temporal_downsample
+ 
+             # Per-level encoders (progressively downsample)
+             self.encoders = nn.ModuleList()
+             self.vqs = nn.ModuleList()
+             self.decoders = nn.ModuleList()
+ 
+             for level in range(num_levels):
+                 # Encoder: downsample temporally at each level
+                 if level == 0:
+                     enc = CausalConv1dEncoder(
+                         invariant_dim, latent_dim, latent_dim, num_layers
+                     )
+                 else:
+                     enc = nn.Sequential(
+                         CausalConv1dEncoder(latent_dim, latent_dim, latent_dim, num_layers),
+                         TemporalDownsample(temporal_downsample),
+                     )
+                 self.encoders.append(enc)
+ 
+                 # VQ at each level
+                 self.vqs.append(VectorQuantizer(codebook_size, latent_dim))
+ 
+                 # Decoder: upsample to match previous level
+                 if level == 0:
+                     dec = CausalConv1dEncoder(
+                         latent_dim, latent_dim, invariant_dim, num_layers
+                     )
+                 else:
+                     dec = nn.Sequential(
+                         TemporalUpsample(temporal_downsample),
+                         CausalConv1dEncoder(latent_dim, latent_dim, latent_dim, num_layers),
+                     )
+                 self.decoders.append(dec)
+ 
+             # Final projection back to invariant space
+             self.final_proj = nn.Linear(latent_dim, invariant_dim)
+ 
+         def forward(
+             self,
+             invariants: torch.Tensor,
+             max_level: int = None,
+             return_all_levels: bool = False,
+         ) -> tuple:
+             """
+             Hierarchical encoding and decoding.
+ 
+             Args:
+                 invariants: (B, T, invariant_dim) input
+                 max_level: Use only the first max_level levels (None = all)
+                 return_all_levels: Also return per-level details
+ 
+             Returns:
+                 all_tokens: List of (B, T_l) tokens per level
+                 reconstructed: (B, T, invariant_dim) reconstruction
+                 level_info: (only if return_all_levels) dict with per-level
+                     shapes plus the latents "all_z" / "all_z_q" needed by loss()
+             """
+             B, T, C = invariants.shape
+             if max_level is None:
+                 max_level = self.num_levels
+ 
+             all_tokens = []
+             all_z = []
+             all_z_q = []
+             level_info = {}
+ 
+             # Encode through hierarchy
+             x = invariants
+             for level in range(max_level):
+                 z = self.encoders[level](x if level == 0 else z_residual)
+                 indices, z_q_st, z_q = self.vqs[level](z)
+ 
+                 all_tokens.append(indices)
+                 all_z.append(z)
+                 all_z_q.append(z_q)
+ 
+                 if level < max_level - 1:
+                     z_residual = z - z_q.detach()
+ 
+                 level_info[f"level_{level}"] = {
+                     "shape": tuple(z.shape),
+                     "tokens": indices.shape[-1],
+                 }
+ 
+             # Decode through hierarchy (reverse order)
+             reconstructed = torch.zeros_like(invariants)
+             for level in reversed(range(max_level)):
+                 dec_out = self.decoders[level](all_z_q[level])
+ 
+                 # Match temporal dimension
+                 if dec_out.shape[1] > reconstructed.shape[1]:
+                     dec_out = dec_out[:, :reconstructed.shape[1], :]
+                 elif dec_out.shape[1] < reconstructed.shape[1]:
+                     # Upsample to match
+                     dec_out = F.interpolate(
+                         dec_out.transpose(1, 2),
+                         size=reconstructed.shape[1],
+                         mode='linear',
+                         align_corners=True,
+                     ).transpose(1, 2)
+ 
+                 if level == 0:
+                     # Add the base reconstruction; plain assignment here would
+                     # discard the refinements accumulated from higher levels
+                     reconstructed = reconstructed + dec_out
+                 else:
+                     reconstructed = reconstructed + self.final_proj(dec_out)
+ 
+             if return_all_levels:
+                 # Expose latents so loss() can be computed from forward outputs
+                 level_info["all_z"] = all_z
+                 level_info["all_z_q"] = all_z_q
+                 return all_tokens, reconstructed, level_info
+             return all_tokens, reconstructed
+ 
+         def loss(
+             self,
+             invariants: torch.Tensor,
+             reconstructed: torch.Tensor,
+             all_z: list,
+             all_z_q: list,
+             beta: float = 0.25,
+             level_weights: list = None,
+         ) -> torch.Tensor:
+             """
+             Compute hierarchical loss with per-level weighting.
+ 
+             Args:
+                 invariants: Original input
+                 reconstructed: Reconstruction
+                 all_z: Latents at each level (see level_info["all_z"])
+                 all_z_q: Quantized latents at each level
+                 beta: Commitment loss weight
+                 level_weights: Optional weights per level (default: exponential decay)
+ 
+             Returns:
+                 Total loss
+             """
+             # Reconstruction loss
+             rec_loss = F.mse_loss(reconstructed, invariants)
+ 
+             # Per-level VQ losses
+             if level_weights is None:
+                 level_weights = [0.5 ** i for i in range(len(all_z))]
+ 
+             commitment = 0
+             codebook = 0
+             for i, (z, z_q) in enumerate(zip(all_z, all_z_q)):
+                 commitment += level_weights[i] * F.mse_loss(z, z_q.detach())
+                 codebook += level_weights[i] * F.mse_loss(z_q, z.detach())
+ 
+             return rec_loss + beta * commitment + codebook
+ 
+         def get_compression_stats(self, T: int, max_level: int = None) -> dict:
+             """
+             Compute compression statistics.
+ 
+             Args:
+                 T: Original sequence length
+                 max_level: Number of levels to use
+ 
+             Returns:
+                 Compression statistics
+             """
+             if max_level is None:
+                 max_level = self.num_levels
+ 
+             total_tokens = 0
+             for level in range(max_level):
+                 level_T = T // (self.temporal_downsample ** level)
+                 total_tokens += level_T
+ 
+             original_values = T * self.invariant_dim
+             token_values = total_tokens  # each token is one codebook index
+ 
+             return {
+                 "original_values": original_values,
+                 "total_tokens": total_tokens,
+                 "tokens_per_level": [T // (self.temporal_downsample ** l) for l in range(max_level)],
+                 "compression_ratio": original_values / token_values if token_values > 0 else 1,
+                 "bits_per_value": (total_tokens * np.log2(self.codebook_size)) / original_values,
+             }
+ 
+ 
+     class TemporalDownsample(nn.Module):
+         """Temporal downsampling by strided subsampling (keeps every factor-th step)."""
+ 
+         def __init__(self, factor: int = 2):
+             super().__init__()
+             self.factor = factor
+ 
+         def forward(self, x: torch.Tensor) -> torch.Tensor:
+             # x: (B, T, D)
+             return x[:, ::self.factor, :]
+ 
+ 
+     class TemporalUpsample(nn.Module):
+         """Temporal upsampling via linear interpolation."""
+ 
+         def __init__(self, factor: int = 2):
+             super().__init__()
+             self.factor = factor
+ 
+         def forward(self, x: torch.Tensor) -> torch.Tensor:
+             # x: (B, T, D)
+             B, T, D = x.shape
+             x = x.transpose(1, 2)  # (B, D, T)
+             x = F.interpolate(x, size=T * self.factor, mode='linear', align_corners=True)
+             return x.transpose(1, 2)  # (B, T*factor, D)
+ 
+     class ProgressiveTokenizer(nn.Module):
+         """
+         Progressive refinement tokenizer.
+ 
+         Outputs can be truncated at any level for variable-rate decoding:
+         - 1 level: ~4x compression, coarse motion
+         - 2 levels: ~2x compression, medium detail
+         - 4 levels: ~1x compression, full fidelity
+ 
+         Ideal for streaming/bandwidth-adaptive applications.
+         """
+ 
+         def __init__(
+             self,
+             invariant_dim: int,
+             latent_dim: int,
+             codebook_size: int,
+             num_refinements: int = 3,
+         ):
+             super().__init__()
+ 
+             self.base_tokenizer = nn.Sequential(
+                 CausalConv1dEncoder(invariant_dim, latent_dim, latent_dim, 2),
+             )
+             self.base_vq = VectorQuantizer(codebook_size, latent_dim)
+             self.base_decoder = CausalConv1dEncoder(latent_dim, latent_dim, invariant_dim, 2)
+ 
+             # Refinement stages
+             self.refinements = nn.ModuleList()
+             self.refine_vqs = nn.ModuleList()
+             self.refine_decoders = nn.ModuleList()
+ 
+             for _ in range(num_refinements):
+                 self.refinements.append(
+                     CausalConv1dEncoder(invariant_dim, latent_dim, latent_dim, 1)
+                 )
+                 self.refine_vqs.append(VectorQuantizer(codebook_size, latent_dim))
+                 self.refine_decoders.append(
+                     CausalConv1dEncoder(latent_dim, latent_dim, invariant_dim, 1)
+                 )
+ 
+         def forward(self, invariants: torch.Tensor, num_refine: int = None) -> tuple:
+             """
+             Progressive tokenization.
+ 
+             Args:
+                 invariants: (B, T, D) input
+                 num_refine: Number of refinement levels (0 = base only)
+ 
+             Returns:
+                 all_tokens: List of token tensors
+                 reconstructed: Final reconstruction
+             """
+             if num_refine is None:
+                 num_refine = len(self.refinements)
+ 
+             # Base encoding
+             z_base = self.base_tokenizer(invariants)
+             tokens_base, z_q_st, z_q = self.base_vq(z_base)
+             recon = self.base_decoder(z_q_st)
+ 
+             all_tokens = [tokens_base]
+ 
+             # Progressive refinements
+             residual = invariants - recon
+             for i in range(min(num_refine, len(self.refinements))):
+                 z_ref = self.refinements[i](residual)
+                 tokens_ref, z_ref_st, z_ref_q = self.refine_vqs[i](z_ref)
+ 
+                 all_tokens.append(tokens_ref)
+ 
+                 recon = recon + self.refine_decoders[i](z_ref_st)
+                 residual = invariants - recon
+ 
+             return all_tokens, recon
+ 
+ else:
+     HierarchicalTokenizer = None
+     ProgressiveTokenizer = None
+     TemporalDownsample = None
+     TemporalUpsample = None
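
As a quick orientation for the two tokenizers in this file, here is a minimal usage sketch (an editor's illustration, not a file in the wheel). It assumes torch is available and that CausalConv1dEncoder maps (B, T, C_in) to (B, T, C_out) as the docstrings above state; the (2, 64, 8) batch is an arbitrary stand-in for encoded DHB invariants.

import torch
from dhb_xr.tokenization.hierarchical import HierarchicalTokenizer, ProgressiveTokenizer

tokenizer = HierarchicalTokenizer(
    invariant_dim=8, latent_dim=32, codebook_size=256, num_levels=4
)
invariants = torch.randn(2, 64, 8)  # (B, T, invariant_dim) stand-in

tokens, recon = tokenizer(invariants)                             # all 4 levels
tokens_coarse, recon_coarse = tokenizer(invariants, max_level=1)  # coarse only

# With T=64 and temporal_downsample=2, tokens_per_level is [64, 32, 16, 8]:
# 120 tokens for 64 * 8 = 512 invariant values, i.e. ~4.3x compression at
# 120 * log2(256) / 512 = 1.875 bits per value.
print(tokenizer.get_compression_stats(T=64))

# ProgressiveTokenizer exposes the same tradeoff via num_refine:
prog = ProgressiveTokenizer(invariant_dim=8, latent_dim=32, codebook_size=256)
base_tokens, base_recon = prog(invariants, num_refine=0)  # base stream only
all_tokens, full_recon = prog(invariants)                 # base + 3 refinements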
dhb_xr/tokenization/rvq.py
@@ -0,0 +1,178 @@
+ """Residual VQ (RVQ) tokenizer for higher capacity."""
+ 
+ try:
+     import torch
+     import torch.nn as nn
+     from dhb_xr.tokenization.vqvae import VectorQuantizer
+     from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder
+     HAS_TORCH = True
+ except ImportError:
+     HAS_TORCH = False
+ 
+ if HAS_TORCH:
+ 
+     class ResidualVQTokenizer(nn.Module):
+         """
+         Residual VQ: a stack of codebooks, each quantizing the residual left
+         by the previous one.
+ 
+         invariants (B, T, C) -> list of (B, T) token tensors and a (B, T, C)
+         reconstruction.
+         """
+ 
+         def __init__(
+             self,
+             invariant_dim: int,
+             latent_dim: int,
+             codebook_size: int,
+             num_codebooks: int = 2,
+             num_layers: int = 2,
+         ):
+             super().__init__()
+             self.encoder = CausalConv1dEncoder(
+                 invariant_dim, latent_dim, latent_dim, num_layers
+             )
+             self.vqs = nn.ModuleList([
+                 VectorQuantizer(codebook_size, latent_dim) for _ in range(num_codebooks)
+             ])
+             self.decoder = CausalConv1dEncoder(
+                 latent_dim, latent_dim, invariant_dim, num_layers
+             )
+             self.num_codebooks = num_codebooks
+ 
+         def forward(self, invariants: torch.Tensor) -> tuple:
+             z = self.encoder(invariants)
+             residuals = z
+             all_indices = []
+             z_sum = torch.zeros_like(z)
+             for vq in self.vqs:
+                 # Each codebook quantizes what the previous codebooks left over
+                 indices, z_q_st, z_q = vq(residuals)
+                 all_indices.append(indices)
+                 z_sum = z_sum + z_q_st
+                 residuals = residuals - z_q.detach()
+             reconstructed = self.decoder(z_sum)
+             return all_indices, reconstructed, z, z_sum
+ 
+         # ---- Flow matching integration API ----
+ 
+         def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
+             """
+             Encode invariants to the continuous latent space (before quantization).
+ 
+             Args:
+                 invariants: Input invariant sequences (B, T, C).
+ 
+             Returns:
+                 Continuous latent z (B, T, latent_dim).
+             """
+             return self.encoder(invariants)
+ 
+         def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
+             """
+             Decode from continuous latent to invariants.
+ 
+             Args:
+                 z: Continuous latent (B, T, latent_dim).
+ 
+             Returns:
+                 Reconstructed invariants (B, T, invariant_dim).
+             """
+             return self.decoder(z)
+ 
+         def quantize(self, z: torch.Tensor, num_codebooks: int = None) -> tuple:
+             """
+             Quantize continuous latent using RVQ.
+ 
+             Args:
+                 z: Continuous latent (B, T, latent_dim).
+                 num_codebooks: Number of codebooks to use (default: all).
+ 
+             Returns:
+                 Tuple of (all_indices, z_sum) where all_indices is a list of (B, T).
+             """
+             if num_codebooks is None:
+                 num_codebooks = self.num_codebooks
+ 
+             residuals = z
+             all_indices = []
+             z_sum = torch.zeros_like(z)
+ 
+             for vq in self.vqs[:num_codebooks]:
+                 indices, z_q_st, z_q = vq(residuals)
+                 all_indices.append(indices)
+                 z_sum = z_sum + z_q_st
+                 residuals = residuals - z_q.detach()
+ 
+             return all_indices, z_sum
+ 
+         def encode_partial(
+             self,
+             invariants: torch.Tensor,
+             num_codebooks: int,
+         ) -> tuple:
+             """
+             Encode with partial RVQ (for hierarchical VFM).
+ 
+             Uses only the first num_codebooks codebooks.
+ 
+             Args:
+                 invariants: Input invariant sequences (B, T, C).
+                 num_codebooks: Number of codebooks to use.
+ 
+             Returns:
+                 Tuple of (all_indices, z_sum, reconstructed).
+             """
+             z = self.encoder(invariants)
+             all_indices, z_sum = self.quantize(z, num_codebooks)
+             reconstructed = self.decoder(z_sum)
+             return all_indices, z_sum, reconstructed
+ 
+         def get_codebook_embeddings(self, codebook_idx: int = 0) -> torch.Tensor:
+             """
+             Get codebook embeddings for a specific codebook.
+ 
+             Args:
+                 codebook_idx: Index of the codebook (0 to num_codebooks - 1).
+ 
+             Returns:
+                 Codebook embeddings (codebook_size, latent_dim).
+             """
+             return self.vqs[codebook_idx].embedding.weight.data
+ 
+         def get_all_codebook_embeddings(self) -> list:
+             """
+             Get all codebook embeddings.
+ 
+             Returns:
+                 List of codebook embeddings, each (codebook_size, latent_dim).
+             """
+             return [vq.embedding.weight.data for vq in self.vqs]
+ 
+         def embed_tokens(self, all_indices: list) -> torch.Tensor:
+             """
+             Convert RVQ token indices to summed embeddings.
+ 
+             Args:
+                 all_indices: List of token indices, each (B, T).
+ 
+             Returns:
+                 Summed embeddings (B, T, latent_dim).
+             """
+             z_sum = None
+             for indices, vq in zip(all_indices, self.vqs):
+                 z_q = vq.embedding(indices)
+                 z_sum = z_q if z_sum is None else z_sum + z_q
+             return z_sum
+ 
+         def decode_tokens(self, all_indices: list) -> torch.Tensor:
+             """
+             Decode RVQ token indices to invariants.
+ 
+             Args:
+                 all_indices: List of token indices, each (B, T).
+ 
+             Returns:
+                 Reconstructed invariants (B, T, invariant_dim).
+             """
+             z_sum = self.embed_tokens(all_indices)
+             return self.decoder(z_sum)
+ 
+ else:
+     ResidualVQTokenizer = None
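
For the residual path specifically, a minimal round-trip sketch (again an editor's illustration under the same assumptions as above):

import torch
from dhb_xr.tokenization.rvq import ResidualVQTokenizer

rvq = ResidualVQTokenizer(
    invariant_dim=8, latent_dim=32, codebook_size=256, num_codebooks=4
)
invariants = torch.randn(2, 64, 8)

# Full RVQ: one token stream per codebook, each quantizing the previous residual
all_indices, recon, z, z_sum = rvq(invariants)
assert len(all_indices) == 4 and all_indices[0].shape == (2, 64)

# Coarser variant for the hierarchical/VFM path: first two codebooks only
idx2, z_sum2, recon2 = rvq.encode_partial(invariants, num_codebooks=2)

# Decode purely from stored token indices
recon_from_tokens = rvq.decode_tokens(all_indices)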
dhb_xr/tokenization/vqvae.py
@@ -0,0 +1,155 @@
+ """VQ-VAE tokenizer for DHB invariant sequences (DHB-Token)."""
+ 
+ try:
+     import torch
+     import torch.nn as nn
+     import torch.nn.functional as F
+     from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder
+     HAS_TORCH = True
+ except ImportError:
+     HAS_TORCH = False
+ 
+ if HAS_TORCH:
+ 
+     class VectorQuantizer(nn.Module):
+         def __init__(self, num_embeddings: int, embedding_dim: int):
+             super().__init__()
+             self.embedding = nn.Embedding(num_embeddings, embedding_dim)
+             self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)
+ 
+         def forward(self, z: torch.Tensor) -> tuple:
+             # z: (B, T, D)
+             B, T, D = z.shape
+             z_flat = z.reshape(-1, D)
+             # Nearest codebook entry per timestep (Euclidean distance)
+             d = torch.cdist(z_flat, self.embedding.weight)
+             indices = d.argmin(dim=1)
+             z_q = self.embedding(indices).reshape(B, T, D)
+             # Straight-through estimator: forward pass uses z_q, gradients
+             # flow back to z as if quantization were the identity
+             z_q_st = z + (z_q - z).detach()
+             return indices.reshape(B, T), z_q_st, z_q
+ 
+     class DHBTokenizer(nn.Module):
+         """
+         Causal VQ-VAE for invariant sequences.
+ 
+         invariants (B, T, C) -> tokens (B, T), reconstructed (B, T, C).
+         """
+ 
+         def __init__(
+             self,
+             invariant_dim: int,
+             latent_dim: int,
+             codebook_size: int,
+             num_layers: int = 2,
+             kernel_size: int = 3,
+         ):
+             super().__init__()
+             self.encoder = CausalConv1dEncoder(
+                 invariant_dim, latent_dim, latent_dim, num_layers, kernel_size
+             )
+             self.vq = VectorQuantizer(codebook_size, latent_dim)
+             self.decoder = CausalConv1dEncoder(
+                 latent_dim, latent_dim, invariant_dim, num_layers, kernel_size
+             )
+             self.invariant_dim = invariant_dim
+             self.latent_dim = latent_dim
+             self.codebook_size = codebook_size
+ 
+         def forward(self, invariants: torch.Tensor) -> tuple:
+             z = self.encoder(invariants)
+             indices, z_q_st, z_q = self.vq(z)
+             reconstructed = self.decoder(z_q_st)
+             return indices, reconstructed, z, z_q
+ 
+         def loss(
+             self,
+             invariants: torch.Tensor,
+             reconstructed: torch.Tensor,
+             z: torch.Tensor,
+             z_q: torch.Tensor,
+             beta: float = 0.25,
+         ) -> torch.Tensor:
+             rec_loss = F.mse_loss(reconstructed, invariants)
+             # Commitment term pulls the encoder toward the codebook; detach
+             # z_q so the codebook is updated only through the codebook term
+             # (matching the hierarchical loss above)
+             commitment = F.mse_loss(z, z_q.detach())
+             codebook = F.mse_loss(z_q, z.detach())
+             return rec_loss + beta * commitment + codebook
+ 
+         # ---- Flow matching integration API ----
+ 
+         def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
+             """
+             Encode invariants to the continuous latent space (before quantization).
+ 
+             Useful for flow matching, which operates in continuous space.
+ 
+             Args:
+                 invariants: Input invariant sequences (B, T, C).
+ 
+             Returns:
+                 Continuous latent z (B, T, latent_dim).
+             """
+             return self.encoder(invariants)
+ 
+         def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
+             """
+             Decode from continuous latent to invariants.
+ 
+             Bypasses the VQ step; useful for flow matching generation.
+ 
+             Args:
+                 z: Continuous latent (B, T, latent_dim).
+ 
+             Returns:
+                 Reconstructed invariants (B, T, invariant_dim).
+             """
+             return self.decoder(z)
+ 
+         def quantize(self, z: torch.Tensor) -> tuple:
+             """
+             Quantize continuous latent to discrete tokens.
+ 
+             Args:
+                 z: Continuous latent (B, T, latent_dim).
+ 
+             Returns:
+                 Tuple of (indices, z_q_st, z_q).
+             """
+             return self.vq(z)
+ 
+         def get_codebook_embeddings(self) -> torch.Tensor:
+             """
+             Get the VQ codebook embeddings.
+ 
+             Useful for flow matching in embedding space or for visualization.
+ 
+             Returns:
+                 Codebook embeddings (codebook_size, latent_dim).
+             """
+             return self.vq.embedding.weight.data
+ 
+         def embed_tokens(self, indices: torch.Tensor) -> torch.Tensor:
+             """
+             Convert token indices to embeddings.
+ 
+             Args:
+                 indices: Token indices (B, T).
+ 
+             Returns:
+                 Token embeddings (B, T, latent_dim).
+             """
+             return self.vq.embedding(indices)
+ 
+         def decode_tokens(self, indices: torch.Tensor) -> torch.Tensor:
+             """
+             Decode token indices to invariants.
+ 
+             Args:
+                 indices: Token indices (B, T).
+ 
+             Returns:
+                 Reconstructed invariants (B, T, invariant_dim).
+             """
+             z_q = self.embed_tokens(indices)
+             return self.decoder(z_q)
+ 
+ else:
+     DHBTokenizer = None
+     VectorQuantizer = None
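
Finally, a sketch of a single DHB-Token training step using the loss above (editor's illustration; the random batch stands in for real encoded invariants):

import torch
from dhb_xr.tokenization.vqvae import DHBTokenizer

model = DHBTokenizer(invariant_dim=8, latent_dim=32, codebook_size=256)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

invariants = torch.randn(16, 64, 8)                 # (B, T, invariant_dim)
indices, reconstructed, z, z_q = model(invariants)  # straight-through forward
loss = model.loss(invariants, reconstructed, z, z_q, beta=0.25)

opt.zero_grad()
loss.backward()  # reconstruction gradients reach the encoder via z_q_st
opt.step()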