dhb-xr 0.2.1__py3-none-any.whl
This diff shows the content of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between released package versions.
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0

dhb_xr/tokenization/hierarchical.py
@@ -0,0 +1,359 @@
"""
Hierarchical tokenization for variable-rate DHB compression.

Extends RVQ with multi-level structure for lossy compression:
- Coarse tokens: Global trajectory structure
- Fine tokens: Local details and refinements
- Configurable depth for compression/quality tradeoff

Inspired by:
- BEAST (NeurIPS 2025): B-spline encoded action sequences
- VQ-VLA: Multi-level vector quantization
"""

# numpy is used by get_compression_stats and is needed even without torch
import numpy as np

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from dhb_xr.tokenization.vqvae import VectorQuantizer
    from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False


if HAS_TORCH:

    class HierarchicalTokenizer(nn.Module):
        """
        Hierarchical RVQ with variable-rate output.

        Provides coarse-to-fine tokenization:
        - Level 0: Low-frequency global structure (high compression)
        - Level 1-N: Residual details (configurable refinement)

        For inference, can truncate to fewer levels for faster/coarser output.

        Example:
            >>> tokenizer = HierarchicalTokenizer(
            ...     invariant_dim=8, latent_dim=32, codebook_size=256, num_levels=4
            ... )
            >>> tokens, recon = tokenizer(invariants)
            >>>
            >>> # Coarse only (~2x fewer tokens at default settings)
            >>> tokens_coarse, recon_coarse = tokenizer(invariants, max_level=1)
        """

        def __init__(
            self,
            invariant_dim: int,
            latent_dim: int,
            codebook_size: int,
            num_levels: int = 4,
            temporal_downsample: int = 2,
            num_layers: int = 2,
        ):
            """
            Args:
                invariant_dim: DHB invariant dimension (typically 8)
                latent_dim: Latent embedding dimension
                codebook_size: VQ codebook size per level
                num_levels: Number of hierarchy levels
                temporal_downsample: Downsample factor between levels
                num_layers: Conv layers per encoder/decoder
            """
            super().__init__()

            self.invariant_dim = invariant_dim
            self.latent_dim = latent_dim
            self.codebook_size = codebook_size
            self.num_levels = num_levels
            self.temporal_downsample = temporal_downsample

            # Per-level encoders (progressively downsample)
            self.encoders = nn.ModuleList()
            self.vqs = nn.ModuleList()
            self.decoders = nn.ModuleList()

            for level in range(num_levels):
                # Encoder: downsample temporally at each level
                if level == 0:
                    enc = CausalConv1dEncoder(
                        invariant_dim, latent_dim, latent_dim, num_layers
                    )
                else:
                    enc = nn.Sequential(
                        CausalConv1dEncoder(latent_dim, latent_dim, latent_dim, num_layers),
                        TemporalDownsample(temporal_downsample),
                    )
                self.encoders.append(enc)

                # VQ at each level
                self.vqs.append(VectorQuantizer(codebook_size, latent_dim))

                # Decoder: upsample to match previous level
                if level == 0:
                    dec = CausalConv1dEncoder(
                        latent_dim, latent_dim, invariant_dim, num_layers
                    )
                else:
                    dec = nn.Sequential(
                        TemporalUpsample(temporal_downsample),
                        CausalConv1dEncoder(latent_dim, latent_dim, latent_dim, num_layers),
                    )
                self.decoders.append(dec)

            # Final projection back to invariant space
            self.final_proj = nn.Linear(latent_dim, invariant_dim)

        def forward(
            self,
            invariants: torch.Tensor,
            max_level: int = None,
            return_all_levels: bool = False,
        ) -> tuple:
            """
            Hierarchical encoding and decoding.

            Args:
                invariants: (B, T, invariant_dim) input
                max_level: Stop at this level (None = all levels)
                return_all_levels: Return tokens/recon at each level

            Returns:
                all_tokens: List of (B, T_l) tokens per level
                reconstructed: (B, T, invariant_dim) reconstruction
                level_info: Optional dict with per-level details
            """
            B, T, C = invariants.shape
            if max_level is None:
                max_level = self.num_levels

            all_tokens = []
            all_z = []
            all_z_q = []
            level_info = {}

            # Encode through hierarchy
            x = invariants
            for level in range(max_level):
                z = self.encoders[level](x if level == 0 else z_residual)
                indices, z_q_st, z_q = self.vqs[level](z)

                all_tokens.append(indices)
                all_z.append(z)
                all_z_q.append(z_q)

                if level < max_level - 1:
                    z_residual = z - z_q.detach()

                level_info[f"level_{level}"] = {
                    "shape": tuple(z.shape),
                    "tokens": indices.shape[-1],
                }

            # Decode through hierarchy (reverse order)
            reconstructed = torch.zeros_like(invariants)
            for level in reversed(range(max_level)):
                dec_out = self.decoders[level](all_z_q[level])

                # Match temporal dimension
                if dec_out.shape[1] > reconstructed.shape[1]:
                    dec_out = dec_out[:, :reconstructed.shape[1], :]
                elif dec_out.shape[1] < reconstructed.shape[1]:
                    # Upsample to match
                    dec_out = F.interpolate(
                        dec_out.transpose(1, 2),
                        size=reconstructed.shape[1],
                        mode='linear',
                        align_corners=True,
                    ).transpose(1, 2)

                if level == 0:
                    # Accumulate rather than assign: the finer levels were
                    # decoded earlier in this reversed loop and have already
                    # contributed their residual corrections.
                    reconstructed = reconstructed + dec_out
                else:
                    reconstructed = reconstructed + self.final_proj(dec_out)

            if return_all_levels:
                return all_tokens, reconstructed, level_info
            return all_tokens, reconstructed

        def loss(
            self,
            invariants: torch.Tensor,
            reconstructed: torch.Tensor,
            all_z: list,
            all_z_q: list,
            beta: float = 0.25,
            level_weights: list = None,
        ) -> torch.Tensor:
            """
            Compute hierarchical loss with per-level weighting.

            Args:
                invariants: Original input
                reconstructed: Reconstruction
                all_z: Latents at each level
                all_z_q: Quantized latents at each level
                beta: Commitment loss weight
                level_weights: Optional weights per level (default: exponential decay)

            Returns:
                Total loss
            """
            # Reconstruction loss
            rec_loss = F.mse_loss(reconstructed, invariants)

            # Per-level VQ losses
            if level_weights is None:
                level_weights = [0.5 ** i for i in range(len(all_z))]

            commitment = 0
            codebook = 0
            for i, (z, z_q) in enumerate(zip(all_z, all_z_q)):
                commitment += level_weights[i] * F.mse_loss(z, z_q.detach())
                codebook += level_weights[i] * F.mse_loss(z_q, z.detach())

            return rec_loss + beta * commitment + codebook

        def get_compression_stats(self, T: int, max_level: int = None) -> dict:
            """
            Compute compression statistics.

            Args:
                T: Original sequence length
                max_level: Number of levels to use

            Returns:
                Compression statistics
            """
            if max_level is None:
                max_level = self.num_levels

            total_tokens = 0
            for level in range(max_level):
                level_T = T // (self.temporal_downsample ** level)
                total_tokens += level_T

            original_values = T * self.invariant_dim
            token_values = total_tokens  # Each token is one index

            return {
                "original_values": original_values,
                "total_tokens": total_tokens,
                "tokens_per_level": [T // (self.temporal_downsample ** l) for l in range(max_level)],
                "compression_ratio": original_values / token_values if token_values > 0 else 1,
                "bits_per_value": (total_tokens * np.log2(self.codebook_size)) / original_values,
            }


    class TemporalDownsample(nn.Module):
        """Temporal downsampling by keeping every `factor`-th timestep."""

        def __init__(self, factor: int = 2):
            super().__init__()
            self.factor = factor

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (B, T, D)
            return x[:, ::self.factor, :]


    class TemporalUpsample(nn.Module):
        """Temporal upsampling via linear interpolation."""

        def __init__(self, factor: int = 2):
            super().__init__()
            self.factor = factor

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (B, T, D)
            B, T, D = x.shape
            x = x.transpose(1, 2)  # (B, D, T)
            x = F.interpolate(x, size=T * self.factor, mode='linear', align_corners=True)
            return x.transpose(1, 2)  # (B, T*factor, D)


    class ProgressiveTokenizer(nn.Module):
        """
        Progressive refinement tokenizer.

        Outputs can be truncated at any level for variable-rate decoding:
        - 1 level: ~4x compression, coarse motion
        - 2 levels: ~2x compression, medium detail
        - 4 levels: ~1x compression, full fidelity

        Ideal for streaming/bandwidth-adaptive applications.
        """

        def __init__(
            self,
            invariant_dim: int,
            latent_dim: int,
            codebook_size: int,
            num_refinements: int = 3,
        ):
            super().__init__()

            self.base_tokenizer = nn.Sequential(
                CausalConv1dEncoder(invariant_dim, latent_dim, latent_dim, 2),
            )
            self.base_vq = VectorQuantizer(codebook_size, latent_dim)
            self.base_decoder = CausalConv1dEncoder(latent_dim, latent_dim, invariant_dim, 2)

            # Refinement stages
            self.refinements = nn.ModuleList()
            self.refine_vqs = nn.ModuleList()
            self.refine_decoders = nn.ModuleList()

            for _ in range(num_refinements):
                self.refinements.append(
                    CausalConv1dEncoder(invariant_dim, latent_dim, latent_dim, 1)
                )
                self.refine_vqs.append(VectorQuantizer(codebook_size, latent_dim))
                self.refine_decoders.append(
                    CausalConv1dEncoder(latent_dim, latent_dim, invariant_dim, 1)
                )

        def forward(self, invariants: torch.Tensor, num_refine: int = None) -> tuple:
            """
            Progressive tokenization.

            Args:
                invariants: (B, T, D) input
                num_refine: Number of refinement levels (0 = base only)

            Returns:
                all_tokens: List of token tensors
                reconstructed: Final reconstruction
            """
            if num_refine is None:
                num_refine = len(self.refinements)

            # Base encoding
            z_base = self.base_tokenizer(invariants)
            tokens_base, z_q_st, z_q = self.base_vq(z_base)
            recon = self.base_decoder(z_q_st)

            all_tokens = [tokens_base]

            # Progressive refinements: each stage re-encodes the residual
            residual = invariants - recon
            for i in range(min(num_refine, len(self.refinements))):
                z_ref = self.refinements[i](residual)
                tokens_ref, z_ref_st, z_ref_q = self.refine_vqs[i](z_ref)

                all_tokens.append(tokens_ref)

                recon = recon + self.refine_decoders[i](z_ref_st)
                residual = invariants - recon

            return all_tokens, recon

else:
    HierarchicalTokenizer = None
    ProgressiveTokenizer = None
    TemporalDownsample = None
    TemporalUpsample = None
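
For orientation, here is a minimal usage sketch of the hunk above. It is not taken from the package's documentation: the tensors are random stand-ins for real DHB invariants, and it assumes `torch` is installed and that `CausalConv1dEncoder` accepts and returns `(B, T, channels)` tensors, as its call sites above imply.

# Hedged usage sketch for HierarchicalTokenizer (illustrative, not from the docs).
import torch
from dhb_xr.tokenization.hierarchical import HierarchicalTokenizer

tok = HierarchicalTokenizer(
    invariant_dim=8, latent_dim=32, codebook_size=256, num_levels=4
)
invariants = torch.randn(2, 64, 8)  # (B, T, invariant_dim); random stand-in data

# Full rate: one token stream per level (T, T/2, T/4, T/8 tokens)
tokens, recon = tok(invariants)
print(len(tokens), recon.shape)  # 4, torch.Size([2, 64, 8])

# Variable rate: truncate the hierarchy for a coarser, cheaper encoding
tokens_coarse, recon_coarse = tok(invariants, max_level=1)
print(len(tokens_coarse))  # 1

# Rate/quality bookkeeping; for max_level=2 and T=64:
# tokens_per_level=[64, 32], bits_per_value = (96 * log2(256)) / (64 * 8) = 1.5
print(tok.get_compression_stats(T=64, max_level=2))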

dhb_xr/tokenization/rvq.py
@@ -0,0 +1,178 @@
"""Residual VQ (RVQ) tokenizer for higher capacity."""

try:
    import torch
    import torch.nn as nn
    from dhb_xr.tokenization.vqvae import VectorQuantizer
    from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

if HAS_TORCH:

    class ResidualVQTokenizer(nn.Module):
        """
        RVQ: multiple codebooks quantize successive residuals.

        invariants (B, T, C) -> list of (B, T) token tensors and a
        (B, T, C) reconstruction.
        """

        def __init__(
            self,
            invariant_dim: int,
            latent_dim: int,
            codebook_size: int,
            num_codebooks: int = 2,
            num_layers: int = 2,
        ):
            super().__init__()
            self.encoder = CausalConv1dEncoder(
                invariant_dim, latent_dim, latent_dim, num_layers
            )
            self.vqs = nn.ModuleList([
                VectorQuantizer(codebook_size, latent_dim) for _ in range(num_codebooks)
            ])
            self.decoder = CausalConv1dEncoder(
                latent_dim, latent_dim, invariant_dim, num_layers
            )
            self.num_codebooks = num_codebooks

        def forward(self, invariants: torch.Tensor) -> tuple:
            z = self.encoder(invariants)
            residuals = z
            all_indices = []
            z_sum = torch.zeros_like(z)
            for vq in self.vqs:
                indices, z_q_st, z_q = vq(residuals)
                all_indices.append(indices)
                z_sum = z_sum + z_q_st
                residuals = residuals - z_q.detach()
            reconstructed = self.decoder(z_sum)
            return all_indices, reconstructed, z, z_sum

        # ---- Flow matching integration API ----

        def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
            """
            Encode invariants to continuous latent space (before quantization).

            Args:
                invariants: Input invariant sequences (B, T, C).

            Returns:
                Continuous latent z (B, T, latent_dim).
            """
            return self.encoder(invariants)

        def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
            """
            Decode from continuous latent to invariants.

            Args:
                z: Continuous latent (B, T, latent_dim).

            Returns:
                Reconstructed invariants (B, T, invariant_dim).
            """
            return self.decoder(z)

        def quantize(self, z: torch.Tensor, num_codebooks: int = None) -> tuple:
            """
            Quantize continuous latent using RVQ.

            Args:
                z: Continuous latent (B, T, latent_dim).
                num_codebooks: Number of codebooks to use (default: all).

            Returns:
                Tuple of (all_indices, z_sum) where all_indices is a list of (B, T).
            """
            if num_codebooks is None:
                num_codebooks = self.num_codebooks

            residuals = z
            all_indices = []
            z_sum = torch.zeros_like(z)

            for vq in self.vqs[:num_codebooks]:
                indices, z_q_st, z_q = vq(residuals)
                all_indices.append(indices)
                z_sum = z_sum + z_q_st
                residuals = residuals - z_q.detach()

            return all_indices, z_sum

        def encode_partial(
            self,
            invariants: torch.Tensor,
            num_codebooks: int,
        ) -> tuple:
            """
            Encode with partial RVQ (for hierarchical VFM).

            Uses only the first num_codebooks codebooks.

            Args:
                invariants: Input invariant sequences (B, T, C).
                num_codebooks: Number of codebooks to use.

            Returns:
                Tuple of (all_indices, z_sum, reconstructed).
            """
            z = self.encoder(invariants)
            all_indices, z_sum = self.quantize(z, num_codebooks)
            reconstructed = self.decoder(z_sum)
            return all_indices, z_sum, reconstructed

        def get_codebook_embeddings(self, codebook_idx: int = 0) -> torch.Tensor:
            """
            Get codebook embeddings for a specific codebook.

            Args:
                codebook_idx: Index of the codebook (0 to num_codebooks-1).

            Returns:
                Codebook embeddings (codebook_size, latent_dim).
            """
            return self.vqs[codebook_idx].embedding.weight.data

        def get_all_codebook_embeddings(self) -> list:
            """
            Get all codebook embeddings.

            Returns:
                List of codebook embeddings, each (codebook_size, latent_dim).
            """
            return [vq.embedding.weight.data for vq in self.vqs]

        def embed_tokens(self, all_indices: list) -> torch.Tensor:
            """
            Convert RVQ token indices to summed embeddings.

            Args:
                all_indices: List of token indices, each (B, T).

            Returns:
                Summed embeddings (B, T, latent_dim).
            """
            z_sum = None
            for indices, vq in zip(all_indices, self.vqs):
                z_q = vq.embedding(indices)
                z_sum = z_q if z_sum is None else z_sum + z_q
            return z_sum

        def decode_tokens(self, all_indices: list) -> torch.Tensor:
            """
            Decode RVQ token indices to invariants.

            Args:
                all_indices: List of token indices, each (B, T).

            Returns:
                Reconstructed invariants (B, T, invariant_dim).
            """
            z_sum = self.embed_tokens(all_indices)
            return self.decoder(z_sum)

else:
    ResidualVQTokenizer = None
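
A similar hedged round-trip sketch for the RVQ hunk above, again with random stand-in invariants; `encode_partial` and `decode_tokens` are used exactly as defined in the hunk, and `torch` plus the package's `CausalConv1dEncoder` are assumed available.

# Hedged round-trip sketch for ResidualVQTokenizer (illustrative).
import torch
from dhb_xr.tokenization.rvq import ResidualVQTokenizer

rvq = ResidualVQTokenizer(
    invariant_dim=8, latent_dim=32, codebook_size=256, num_codebooks=4
)
invariants = torch.randn(2, 64, 8)  # random stand-in invariants

# Full encode: one (B, T) token stream per codebook, plus a reconstruction
all_indices, recon, z, z_sum = rvq(invariants)

# Coarse encode for hierarchical VFM: only the first 2 of 4 codebooks
idx2, z_sum2, recon2 = rvq.encode_partial(invariants, num_codebooks=2)

# Tokens alone suffice to reconstruct: indices -> summed embeddings -> decoder
recovered = rvq.decode_tokens(all_indices)
assert recovered.shape == invariants.shape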

dhb_xr/tokenization/vqvae.py
@@ -0,0 +1,155 @@
"""VQ-VAE tokenizer for DHB invariant sequences (DHB-Token)."""

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

if HAS_TORCH:

    class VectorQuantizer(nn.Module):
        def __init__(self, num_embeddings: int, embedding_dim: int):
            super().__init__()
            self.embedding = nn.Embedding(num_embeddings, embedding_dim)
            self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

        def forward(self, z: torch.Tensor) -> tuple:
            # z: (B, T, D)
            B, T, D = z.shape
            z_flat = z.reshape(-1, D)
            # Nearest codeword by Euclidean distance
            d = torch.cdist(z_flat, self.embedding.weight)
            indices = d.argmin(dim=1)
            z_q = self.embedding(indices).reshape(B, T, D)
            # Straight-through estimator: gradients bypass the quantization
            z_q_st = z + (z_q - z).detach()
            return indices.reshape(B, T), z_q_st, z_q

    class DHBTokenizer(nn.Module):
        """
        Causal VQ-VAE for invariant sequences.
        invariants (B, T, C) -> tokens (B, T), reconstructed (B, T, C).
        """

        def __init__(
            self,
            invariant_dim: int,
            latent_dim: int,
            codebook_size: int,
            num_layers: int = 2,
            kernel_size: int = 3,
        ):
            super().__init__()
            self.encoder = CausalConv1dEncoder(
                invariant_dim, latent_dim, latent_dim, num_layers, kernel_size
            )
            self.vq = VectorQuantizer(codebook_size, latent_dim)
            self.decoder = CausalConv1dEncoder(
                latent_dim, latent_dim, invariant_dim, num_layers, kernel_size
            )
            self.invariant_dim = invariant_dim
            self.latent_dim = latent_dim
            self.codebook_size = codebook_size

        def forward(self, invariants: torch.Tensor) -> tuple:
            z = self.encoder(invariants)
            indices, z_q_st, z_q = self.vq(z)
            reconstructed = self.decoder(z_q_st)
            return indices, reconstructed, z, z_q

        def loss(
            self,
            invariants: torch.Tensor,
            reconstructed: torch.Tensor,
            z: torch.Tensor,
            z_q: torch.Tensor,
            beta: float = 0.25,
        ) -> torch.Tensor:
            rec_loss = F.mse_loss(reconstructed, invariants)
            # Commitment term detaches the codebook side (standard VQ-VAE);
            # the codebook term detaches the encoder side.
            commitment = F.mse_loss(z, z_q.detach())
            codebook = F.mse_loss(z_q, z.detach())
            return rec_loss + beta * commitment + codebook

        # ---- Flow matching integration API ----

        def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
            """
            Encode invariants to continuous latent space (before quantization).

            This is useful for flow matching, which operates in continuous space.

            Args:
                invariants: Input invariant sequences (B, T, C).

            Returns:
                Continuous latent z (B, T, latent_dim).
            """
            return self.encoder(invariants)

        def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
            """
            Decode from continuous latent to invariants.

            Bypasses the VQ step; useful for flow matching generation.

            Args:
                z: Continuous latent (B, T, latent_dim).

            Returns:
                Reconstructed invariants (B, T, invariant_dim).
            """
            return self.decoder(z)

        def quantize(self, z: torch.Tensor) -> tuple:
            """
            Quantize continuous latent to discrete tokens.

            Args:
                z: Continuous latent (B, T, latent_dim).

            Returns:
                Tuple of (indices, z_q_st, z_q).
            """
            return self.vq(z)

        def get_codebook_embeddings(self) -> torch.Tensor:
            """
            Get the VQ codebook embeddings.

            Useful for flow matching in embedding space or visualization.

            Returns:
                Codebook embeddings (codebook_size, latent_dim).
            """
            return self.vq.embedding.weight.data

        def embed_tokens(self, indices: torch.Tensor) -> torch.Tensor:
            """
            Convert token indices to embeddings.

            Args:
                indices: Token indices (B, T).

            Returns:
                Token embeddings (B, T, latent_dim).
            """
            return self.vq.embedding(indices)

        def decode_tokens(self, indices: torch.Tensor) -> torch.Tensor:
            """
            Decode token indices to invariants.

            Args:
                indices: Token indices (B, T).

            Returns:
                Reconstructed invariants (B, T, invariant_dim).
            """
            z_q = self.embed_tokens(indices)
            return self.decoder(z_q)

else:
    DHBTokenizer = None
    VectorQuantizer = None
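
Finally, a single hedged training step for the base VQ-VAE above. The optimizer, learning rate, and random batch are illustrative choices, not prescribed by the package; only `DHBTokenizer`'s own `forward`, `loss`, and `decode_tokens` from the hunk are exercised.

# Hedged training-step sketch for DHBTokenizer (optimizer/data are illustrative).
import torch
from dhb_xr.tokenization.vqvae import DHBTokenizer

tok = DHBTokenizer(invariant_dim=8, latent_dim=32, codebook_size=256)
opt = torch.optim.Adam(tok.parameters(), lr=1e-3)

invariants = torch.randn(16, 64, 8)  # stand-in batch of invariant sequences
indices, recon, z, z_q = tok(invariants)
loss = tok.loss(invariants, recon, z, z_q, beta=0.25)

opt.zero_grad()
loss.backward()
opt.step()

# Discrete round-trip for downstream (e.g. VLA-style) models
recovered = tok.decode_tokens(indices)  # (16, 64, 8)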