dhb_xr-0.2.1-py3-none-any.whl
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0
dhb_xr/generative/latent_encoder.py
@@ -0,0 +1,536 @@
"""
Latent encoders for variational flow matching.

Provides inference networks q(w | z_t, t) for the variational formulation:
- LatentEncoder: Gaussian posterior (continuous w)
- CategoricalLatentEncoder: Categorical posterior with Gumbel-Softmax (discrete w)

The latent w captures "intent" or "mode" information that helps the velocity
network distinguish between multiple valid trajectory continuations.
"""

from __future__ import annotations

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Tuple, Dict

from .flow_matching import SinusoidalTimeEmbedding


class LatentEncoder(nn.Module):
    """
    Gaussian latent encoder for variational flow matching.

    Maps (z_t, t) to a Gaussian distribution over the conditioning latent w.
    Uses an MLP with optional temporal attention for sequence inputs.
    """

    def __init__(
        self,
        input_dim: int,
        latent_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 2,
        time_embed_dim: int = 64,
        use_attention: bool = False,
        num_heads: int = 4,
    ):
        """
        Args:
            input_dim: Dimension of input z_t.
            latent_dim: Dimension of output latent w.
            hidden_dim: Hidden layer dimension.
            num_layers: Number of hidden layers.
            time_embed_dim: Dimension of time embedding.
            use_attention: Whether to use attention for sequence pooling.
            num_heads: Number of attention heads (if use_attention=True).
        """
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.use_attention = use_attention

        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(time_embed_dim)

        # Optional attention for sequence pooling
        if use_attention:
            self.attention = nn.MultiheadAttention(
                embed_dim=input_dim,
                num_heads=num_heads,
                batch_first=True,
            )
            self.pool_query = nn.Parameter(torch.randn(1, 1, input_dim))

        # MLP encoder
        layers = []
        in_dim = input_dim + time_embed_dim
        for _ in range(num_layers):
            out_dim = hidden_dim  # every hidden layer has the same width
            layers.extend([
                nn.Linear(in_dim, out_dim),
                nn.LayerNorm(out_dim),
                nn.GELU(),
            ])
            in_dim = out_dim

        self.encoder = nn.Sequential(*layers)

        # Output heads for mean and log_std
        self.mean_head = nn.Linear(hidden_dim, latent_dim)
        self.log_std_head = nn.Linear(hidden_dim, latent_dim)

    def pool_sequence(self, z_t: Tensor) -> Tensor:
        """Pool sequence dimension to get batch-level representation."""
        if z_t.dim() == 2:
            return z_t  # Already (B, D)

        if self.use_attention:
            # Use attention pooling
            B, T, D = z_t.shape
            query = self.pool_query.expand(B, -1, -1)
            pooled, _ = self.attention(query, z_t, z_t)
            return pooled.squeeze(1)  # (B, D)
        else:
            # Simple mean pooling
            return z_t.mean(dim=1)

    def forward(
        self,
        z_t: Tensor,
        t: Tensor | float,
    ) -> Tuple[Tensor, Tensor]:
        """
        Encode to posterior parameters.

        Args:
            z_t: Input state (B, T, D) or (B, D).
            t: Time (B,) or scalar in [0, 1].

        Returns:
            Tuple of (mean, log_std) for q(w | z_t, t).
        """
        # Handle scalar time
        if isinstance(t, float):
            t = torch.tensor([t], device=z_t.device, dtype=z_t.dtype)
        if t.dim() == 0:
            t = t.unsqueeze(0)

        # Pool sequence
        z_pooled = self.pool_sequence(z_t)

        # Time embedding
        t_embed = self.time_embed(t)

        # Expand t_embed to batch size if needed
        if t_embed.shape[0] == 1 and z_pooled.shape[0] > 1:
            t_embed = t_embed.expand(z_pooled.shape[0], -1)

        # Concatenate and encode
        x = torch.cat([z_pooled, t_embed], dim=-1)
        h = self.encoder(x)

        # Output parameters
        mean = self.mean_head(h)
        log_std = self.log_std_head(h)
        log_std = torch.clamp(log_std, min=-10, max=2)

        return mean, log_std

    def sample(
        self,
        z_t: Tensor,
        t: Tensor | float,
        num_samples: int = 1,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Sample from posterior using reparameterization trick.

        Args:
            z_t: Input state.
            t: Time.
            num_samples: Number of samples per input.

        Returns:
            Tuple of (w, mean, log_std) where w is (B, num_samples, latent_dim)
            or (B, latent_dim) if num_samples=1.
        """
        mean, log_std = self.forward(z_t, t)
        std = torch.exp(log_std)

        if num_samples == 1:
            eps = torch.randn_like(mean)
            w = mean + std * eps
        else:
            # (B, num_samples, latent_dim)
            eps = torch.randn(
                mean.shape[0], num_samples, mean.shape[1],
                device=mean.device, dtype=mean.dtype,
            )
            w = mean.unsqueeze(1) + std.unsqueeze(1) * eps

        return w, mean, log_std

    def kl_divergence(
        self,
        mean: Tensor,
        log_std: Tensor,
        prior_mean: float = 0.0,
        prior_std: float = 1.0,
    ) -> Tensor:
        """
        Compute KL divergence from posterior to prior.

        KL(N(mean, std^2) || N(prior_mean, prior_std^2))
        """
        std = torch.exp(log_std)
        prior_var = prior_std ** 2

        kl = 0.5 * (
            (std ** 2 + (mean - prior_mean) ** 2) / prior_var
            - 1
            - 2 * log_std
            + 2 * math.log(prior_std)
        )

        return kl.sum(dim=-1).mean()
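A minimal usage sketch for the Gaussian encoder above (not part of the packaged file; the dimensions are illustrative, not values from dhb-xr): sample w with the reparameterization trick, then penalize the KL to the standard normal prior.

    from dhb_xr.generative.latent_encoder import LatentEncoder
    import torch

    enc = LatentEncoder(input_dim=32, latent_dim=8)
    z_t = torch.randn(4, 10, 32)               # (B, T, D) sequence; mean-pooled internally
    w, mean, log_std = enc.sample(z_t, t=0.5)  # w: (4, 8), reparameterized sample
    kl = enc.kl_divergence(mean, log_std)      # scalar KL(q(w | z_t, t) || N(0, I))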

class CategoricalLatentEncoder(nn.Module):
    """
    Categorical latent encoder with Gumbel-Softmax for discrete modes.

    Useful when the multi-modality has a clear categorical structure
    (e.g., left vs. right grasp, different skill types).

    Uses Gumbel-Softmax for differentiable sampling during training
    and argmax during inference.
    """

    def __init__(
        self,
        input_dim: int,
        num_categories: int,
        embedding_dim: int = 16,
        hidden_dim: int = 256,
        num_layers: int = 2,
        time_embed_dim: int = 64,
        temperature: float = 1.0,
    ):
        """
        Args:
            input_dim: Dimension of input z_t.
            num_categories: Number of discrete categories (modes).
            embedding_dim: Dimension of category embeddings.
            hidden_dim: Hidden layer dimension.
            num_layers: Number of hidden layers.
            time_embed_dim: Dimension of time embedding.
            temperature: Gumbel-Softmax temperature.
        """
        super().__init__()
        self.input_dim = input_dim
        self.num_categories = num_categories
        self.embedding_dim = embedding_dim
        self.temperature = temperature

        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(time_embed_dim)

        # MLP encoder
        layers = []
        in_dim = input_dim + time_embed_dim
        for _ in range(num_layers):
            out_dim = hidden_dim
            layers.extend([
                nn.Linear(in_dim, out_dim),
                nn.LayerNorm(out_dim),
                nn.GELU(),
            ])
            in_dim = out_dim

        self.encoder = nn.Sequential(*layers)

        # Logits head
        self.logits_head = nn.Linear(hidden_dim, num_categories)

        # Category embeddings (for converting one-hot to continuous)
        self.category_embeddings = nn.Embedding(num_categories, embedding_dim)

    @property
    def latent_dim(self) -> int:
        """Return the effective latent dimension (embedding_dim)."""
        return self.embedding_dim

    def forward(
        self,
        z_t: Tensor,
        t: Tensor | float,
    ) -> Tensor:
        """
        Compute category logits.

        Args:
            z_t: Input state (B, T, D) or (B, D).
            t: Time (B,) or scalar.

        Returns:
            Logits (B, num_categories).
        """
        # Handle scalar time
        if isinstance(t, float):
            t = torch.tensor([t], device=z_t.device, dtype=z_t.dtype)
        if t.dim() == 0:
            t = t.unsqueeze(0)

        # Pool sequence
        if z_t.dim() == 3:
            z_pooled = z_t.mean(dim=1)
        else:
            z_pooled = z_t

        # Time embedding
        t_embed = self.time_embed(t)
        if t_embed.shape[0] == 1 and z_pooled.shape[0] > 1:
            t_embed = t_embed.expand(z_pooled.shape[0], -1)

        # Encode
        x = torch.cat([z_pooled, t_embed], dim=-1)
        h = self.encoder(x)

        # Logits
        logits = self.logits_head(h)

        return logits

    def sample(
        self,
        z_t: Tensor,
        t: Tensor | float,
        hard: bool = False,
        temperature: Optional[float] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Sample using Gumbel-Softmax.

        Args:
            z_t: Input state.
            t: Time.
            hard: If True, use straight-through estimator for hard samples.
            temperature: Override default temperature.

        Returns:
            Tuple of (w, probs, logits) where:
            - w: Category embedding (B, embedding_dim)
            - probs: Category probabilities (B, num_categories)
            - logits: Raw logits (B, num_categories)
        """
        if temperature is None:
            temperature = self.temperature

        logits = self.forward(z_t, t)
        probs = F.softmax(logits, dim=-1)

        # Gumbel-Softmax sampling
        if self.training:
            gumbel_probs = F.gumbel_softmax(logits, tau=temperature, hard=hard)
        else:
            # During inference, use argmax
            if hard:
                indices = logits.argmax(dim=-1)
                gumbel_probs = F.one_hot(indices, self.num_categories).float()
            else:
                gumbel_probs = probs

        # Convert to embedding
        # gumbel_probs: (B, num_categories)
        # category_embeddings: (num_categories, embedding_dim)
        w = gumbel_probs @ self.category_embeddings.weight

        return w, probs, logits

    def sample_prior(
        self,
        batch_size: int,
        device: str = "cpu",
        hard: bool = True,
    ) -> Tensor:
        """
        Sample from uniform categorical prior.

        Args:
            batch_size: Number of samples.
            device: Device for tensor.
            hard: If True, return embeddings for hard categories.

        Returns:
            Category embeddings (batch_size, embedding_dim).
        """
        if hard:
            # Sample uniform categories
            indices = torch.randint(
                0, self.num_categories, (batch_size,), device=device
            )
            w = self.category_embeddings(indices)
        else:
            # Soft uniform mixture
            probs = torch.ones(batch_size, self.num_categories, device=device)
            probs = probs / self.num_categories
            w = probs @ self.category_embeddings.weight

        return w

    def kl_divergence(
        self,
        probs: Tensor,
        prior_probs: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Compute KL divergence from posterior to prior.

        If prior_probs is None, uses uniform prior.
        """
        if prior_probs is None:
            # Uniform prior
            prior_probs = torch.ones_like(probs) / self.num_categories

        # KL = sum_k p_k * log(p_k / q_k)
        kl = probs * (torch.log(probs + 1e-10) - torch.log(prior_probs + 1e-10))

        return kl.sum(dim=-1).mean()

    def entropy(self, probs: Tensor) -> Tensor:
        """Compute entropy of the categorical distribution."""
        return -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean()
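A similar sketch for the categorical encoder (illustrative dimensions, not package values): in training mode Gumbel-Softmax keeps the sample differentiable, while eval mode with hard=True reduces to a deterministic argmax.

    from dhb_xr.generative.latent_encoder import CategoricalLatentEncoder
    import torch

    enc = CategoricalLatentEncoder(input_dim=32, num_categories=4)
    z_t = torch.randn(8, 32)

    enc.train()
    w, probs, logits = enc.sample(z_t, t=0.5, hard=True)  # straight-through Gumbel-Softmax
    enc.eval()
    w, probs, logits = enc.sample(z_t, t=0.5, hard=True)  # deterministic argmax one-hot
    kl = enc.kl_divergence(probs)                         # KL to the uniform prior over 4 modes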

class HybridLatentEncoder(nn.Module):
    """
    Hybrid encoder combining categorical and continuous latents.

    Useful for representing both discrete mode selection and
    continuous variations within each mode.
    """

    def __init__(
        self,
        input_dim: int,
        num_categories: int,
        continuous_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 2,
        time_embed_dim: int = 64,
        temperature: float = 1.0,
    ):
        """
        Args:
            input_dim: Dimension of input z_t.
            num_categories: Number of discrete modes.
            continuous_dim: Dimension of continuous latent per mode.
            hidden_dim: Hidden layer dimension.
            num_layers: Number of hidden layers.
            time_embed_dim: Dimension of time embedding.
            temperature: Gumbel-Softmax temperature.
        """
        super().__init__()
        self.num_categories = num_categories
        self.continuous_dim = continuous_dim

        # Categorical encoder
        self.categorical = CategoricalLatentEncoder(
            input_dim=input_dim,
            num_categories=num_categories,
            embedding_dim=hidden_dim // 4,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            time_embed_dim=time_embed_dim,
            temperature=temperature,
        )

        # Continuous encoder (shared across modes)
        self.continuous = LatentEncoder(
            input_dim=input_dim,
            latent_dim=continuous_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            time_embed_dim=time_embed_dim,
        )

        # Combine categorical embedding and continuous latent
        self.combine = nn.Linear(
            self.categorical.embedding_dim + continuous_dim,
            continuous_dim,
        )

    @property
    def latent_dim(self) -> int:
        """Return the combined latent dimension."""
        return self.continuous_dim

    def forward(
        self,
        z_t: Tensor,
        t: Tensor | float,
    ) -> Dict[str, Tensor]:
        """
        Encode to both categorical and continuous latents.

        Returns dictionary with all components.
        """
        # Categorical
        cat_logits = self.categorical(z_t, t)
        cat_probs = F.softmax(cat_logits, dim=-1)

        # Continuous
        cont_mean, cont_log_std = self.continuous(z_t, t)

        return {
            "categorical_logits": cat_logits,
            "categorical_probs": cat_probs,
            "continuous_mean": cont_mean,
            "continuous_log_std": cont_log_std,
        }

    def sample(
        self,
        z_t: Tensor,
        t: Tensor | float,
        hard: bool = False,
        temperature: Optional[float] = None,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """
        Sample combined latent.

        Returns:
            Tuple of (w, info_dict) where w is (B, latent_dim).
        """
        # Sample categorical
        cat_w, cat_probs, cat_logits = self.categorical.sample(
            z_t, t, hard=hard, temperature=temperature
        )

        # Sample continuous
        cont_w, cont_mean, cont_log_std = self.continuous.sample(z_t, t)

        # Combine
        combined = torch.cat([cat_w, cont_w], dim=-1)
        w = self.combine(combined)

        info = {
            "categorical_w": cat_w,
            "categorical_probs": cat_probs,
            "categorical_logits": cat_logits,
            "continuous_w": cont_w,
            "continuous_mean": cont_mean,
            "continuous_log_std": cont_log_std,
        }

        return w, info

    def kl_divergence(self, info: Dict[str, Tensor]) -> Tensor:
        """Compute total KL divergence."""
        cat_kl = self.categorical.kl_divergence(info["categorical_probs"])
        cont_kl = self.continuous.kl_divergence(
            info["continuous_mean"], info["continuous_log_std"]
        )
        return cat_kl + cont_kl
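Putting the pieces together, a sketch of how a hybrid latent would feed a variational flow-matching objective. velocity_net, v_target, and the weight beta are hypothetical stand-ins for illustration, not APIs defined in this package:

    from dhb_xr.generative.latent_encoder import HybridLatentEncoder
    import torch

    enc = HybridLatentEncoder(input_dim=32, num_categories=4, continuous_dim=8)
    z_t = torch.randn(16, 32)   # interpolated states at time t
    t = torch.rand(16)          # per-sample times in [0, 1]

    w, info = enc.sample(z_t, t)   # conditioning latent w: (16, 8)
    kl = enc.kl_divergence(info)   # categorical KL + Gaussian KL
    # v_pred = velocity_net(z_t, t, w)                        # hypothetical velocity network
    # loss = F.mse_loss(v_pred, v_target) + beta * kl         # ELBO-style training objective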