dhb_xr-0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0
dhb_xr/generative/flow_matching.py
@@ -0,0 +1,705 @@
+"""
+Variational Flow Matching for DHB-Token generation.
+
+Implements:
+- FlowMatcher: Base flow matching (FM/RFM) for deterministic velocity regression
+- VariationalFlowMatcher: V-RFM with latent conditioning for multi-modal generation
+
+The variational extension addresses the "mode averaging" problem in standard FM:
+when multiple valid trajectories exist from the same state, a deterministic velocity
+regressor averages modes, producing implausible intermediate behavior. By conditioning
+on a latent variable w, V-RFM can represent and sample from multiple modes.
+
+Reference: "Flow Matching for Generative Modeling" (Lipman et al., 2023)
+"""
+
+from __future__ import annotations
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from typing import Optional, Tuple, Dict, Any
+
+from .sampling import ode_solve
+
+
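A toy calculation (not from the package) makes the mode-averaging point above concrete: with two equally likely targets +1 and -1 from the same input, the MSE-optimal deterministic prediction is their mean, 0, which is neither mode.

import torch

targets = torch.tensor([1.0, -1.0])            # two valid modes from the same state
candidates = torch.linspace(-1.5, 1.5, 301)    # candidate constant predictions
mse = ((candidates[:, None] - targets[None, :]) ** 2).mean(dim=1)
print(candidates[mse.argmin()])                # ~0.0: the implausible average that the latent w avoids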
+class SinusoidalTimeEmbedding(nn.Module):
+    """
+    Sinusoidal positional embedding for time t in [0, 1].
+
+    Maps scalar time to a high-dimensional representation using
+    sinusoidal functions at different frequencies.
+    """
+
+    def __init__(self, dim: int, max_period: float = 10000.0):
+        """
+        Args:
+            dim: Output embedding dimension (should be even).
+            max_period: Maximum period for sinusoidal functions.
+        """
+        super().__init__()
+        self.dim = dim
+        self.max_period = max_period
+
+        # Precompute frequency factors
+        half_dim = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / half_dim
+        )
+        self.register_buffer("freqs", freqs)
+
+    def forward(self, t: Tensor) -> Tensor:
+        """
+        Args:
+            t: Time values (B,) or scalar, in [0, 1].
+
+        Returns:
+            Time embeddings (B, dim) or (dim,).
+        """
+        if t.dim() == 0:
+            t = t.unsqueeze(0)
+
+        # Scale t to match frequency range
+        args = t.unsqueeze(-1) * self.freqs * 2 * math.pi
+
+        # Concatenate sin and cos
+        embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+
+        return embedding
+
+
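A quick shape check for the embedding (an assumed usage sketch; the import path is the module shown in this diff):

import torch
from dhb_xr.generative.flow_matching import SinusoidalTimeEmbedding

embed = SinusoidalTimeEmbedding(dim=64)
t = torch.rand(8)         # batch of flow times in [0, 1]
print(embed(t).shape)     # torch.Size([8, 64]): 32 sin features followed by 32 cos features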
+class VelocityNetwork(nn.Module):
+    """
+    Neural network predicting the velocity field v_theta(z_t, t, w).
+
+    Architecture: MLP with residual connections, processing temporal sequences.
+    Supports optional latent conditioning for variational flow matching.
+    """
+
+    def __init__(
+        self,
+        latent_dim: int,
+        hidden_dim: int = 256,
+        num_layers: int = 4,
+        time_embed_dim: int = 64,
+        condition_dim: int = 0,
+        dropout: float = 0.1,
+        use_layer_norm: bool = True,
+    ):
+        """
+        Args:
+            latent_dim: Dimension of the latent space z.
+            hidden_dim: Hidden layer dimension.
+            num_layers: Number of residual blocks.
+            time_embed_dim: Dimension of time embedding.
+            condition_dim: Dimension of conditioning latent w (0 for unconditional).
+            dropout: Dropout rate.
+            use_layer_norm: Whether to use layer normalization.
+        """
+        super().__init__()
+        self.latent_dim = latent_dim
+        self.hidden_dim = hidden_dim
+        self.condition_dim = condition_dim
+
+        # Time embedding
+        self.time_embed = SinusoidalTimeEmbedding(time_embed_dim)
+
+        # Input projection: z_t + time_embed + (optional) condition
+        input_dim = latent_dim + time_embed_dim + condition_dim
+        self.input_proj = nn.Linear(input_dim, hidden_dim)
+
+        # Residual blocks
+        self.blocks = nn.ModuleList()
+        for _ in range(num_layers):
+            block = nn.Sequential(
+                nn.LayerNorm(hidden_dim) if use_layer_norm else nn.Identity(),
+                nn.Linear(hidden_dim, hidden_dim * 2),
+                nn.GELU(),
+                nn.Dropout(dropout),
+                nn.Linear(hidden_dim * 2, hidden_dim),
+                nn.Dropout(dropout),
+            )
+            self.blocks.append(block)
+
+        # Output projection to velocity
+        self.output_proj = nn.Sequential(
+            nn.LayerNorm(hidden_dim) if use_layer_norm else nn.Identity(),
+            nn.Linear(hidden_dim, latent_dim),
+        )
+
+    def forward(
+        self,
+        z_t: Tensor,
+        t: Tensor | float,
+        w: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Predict velocity at state z_t and time t.
+
+        Args:
+            z_t: Current state (B, T, D) or (B, D).
+            t: Time (B,) or scalar in [0, 1].
+            w: Optional conditioning latent (B, condition_dim).
+
+        Returns:
+            Predicted velocity v_theta (same shape as z_t).
+        """
+        # Handle scalar time
+        if isinstance(t, float):
+            t = torch.tensor([t], device=z_t.device, dtype=z_t.dtype)
+        if t.dim() == 0:
+            t = t.unsqueeze(0)
+
+        # Get batch size and handle 2D/3D input
+        is_3d = z_t.dim() == 3
+        if is_3d:
+            B, T, D = z_t.shape
+            z_flat = z_t.reshape(B * T, D)
+        else:
+            B, D = z_t.shape
+            T = 1
+            z_flat = z_t
+
+        # Time embedding (broadcast to all timesteps)
+        t_embed = self.time_embed(t)  # (B, time_embed_dim)
+        if is_3d:
+            t_embed = t_embed.unsqueeze(1).expand(-1, T, -1).reshape(B * T, -1)
+        else:
+            t_embed = t_embed.expand(B, -1)
+
+        # Build input
+        inputs = [z_flat, t_embed]
+
+        # Add conditioning if provided
+        if w is not None:
+            if is_3d:
+                w_expanded = w.unsqueeze(1).expand(-1, T, -1).reshape(B * T, -1)
+            else:
+                w_expanded = w
+            inputs.append(w_expanded)
+
+        x = torch.cat(inputs, dim=-1)
+        x = self.input_proj(x)
+
+        # Residual blocks
+        for block in self.blocks:
+            x = x + block(x)
+
+        # Output velocity
+        v = self.output_proj(x)
+
+        # Reshape back if needed
+        if is_3d:
+            v = v.reshape(B, T, D)
+
+        return v
+
+
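How the velocity network is meant to be called (an illustrative sketch, not from the package): a (B, T, D) latent trajectory, one flow time per batch element, and an optional conditioning latent w.

import torch
from dhb_xr.generative.flow_matching import VelocityNetwork

net = VelocityNetwork(latent_dim=32, condition_dim=16)
z_t = torch.randn(4, 50, 32)    # (B, T, D) trajectory latents
t = torch.rand(4)               # one flow time per batch element
w = torch.randn(4, 16)          # conditioning latent, required here since condition_dim > 0
print(net(z_t, t, w).shape)     # torch.Size([4, 50, 32]): velocity has the same shape as z_t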
+class FlowMatcher(nn.Module):
+    """
+    Base Flow Matching model for trajectory generation.
+
+    Learns a velocity field that transports samples from noise (z_0)
+    to data (z_1) along straight interpolation paths.
+
+    Training objective:
+        L = E[||v_theta(z_t, t) - (z_1 - z_0)||^2]
+
+    where z_t = (1-t)*z_0 + t*z_1 is the linear interpolation.
+    """
+
+    def __init__(
+        self,
+        latent_dim: int,
+        hidden_dim: int = 256,
+        num_layers: int = 4,
+        time_embed_dim: int = 64,
+        dropout: float = 0.1,
+    ):
+        """
+        Args:
+            latent_dim: Dimension of the latent space.
+            hidden_dim: Hidden layer dimension.
+            num_layers: Number of residual blocks.
+            time_embed_dim: Dimension of time embedding.
+            dropout: Dropout rate.
+        """
+        super().__init__()
+        self.latent_dim = latent_dim
+
+        self.velocity_net = VelocityNetwork(
+            latent_dim=latent_dim,
+            hidden_dim=hidden_dim,
+            num_layers=num_layers,
+            time_embed_dim=time_embed_dim,
+            condition_dim=0,
+            dropout=dropout,
+        )
+
+    def forward(self, z_t: Tensor, t: Tensor | float) -> Tensor:
+        """Predict velocity at state z_t and time t."""
+        return self.velocity_net(z_t, t)
+
+    def interpolate(self, z_0: Tensor, z_1: Tensor, t: Tensor | float) -> Tensor:
+        """
+        Linear interpolation between noise (z_0) and data (z_1).
+
+        Args:
+            z_0: Noise samples (B, T, D) or (B, D).
+            z_1: Data samples (same shape).
+            t: Interpolation time in [0, 1].
+
+        Returns:
+            Interpolated state z_t.
+        """
+        if isinstance(t, float):
+            return (1 - t) * z_0 + t * z_1
+
+        # Handle tensor t with proper broadcasting
+        if z_0.dim() == 3:
+            t = t.view(-1, 1, 1)
+        else:
+            t = t.view(-1, 1)
+
+        return (1 - t) * z_0 + t * z_1
+
+    def target_velocity(self, z_0: Tensor, z_1: Tensor) -> Tensor:
+        """
+        Compute target velocity for flow matching (straight path).
+
+        The optimal transport direction is simply z_1 - z_0.
+        """
+        return z_1 - z_0
+
+    def loss(
+        self,
+        z_1: Tensor,
+        z_0: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Compute flow matching loss.
+
+        Args:
+            z_1: Data samples (B, T, D) or (B, D).
+            z_0: Noise samples. If None, sampled from N(0, 1).
+
+        Returns:
+            Scalar loss value.
+        """
+        # Sample noise if not provided
+        if z_0 is None:
+            z_0 = torch.randn_like(z_1)
+
+        # Sample random time
+        B = z_1.shape[0]
+        t = torch.rand(B, device=z_1.device, dtype=z_1.dtype)
+
+        # Interpolate
+        z_t = self.interpolate(z_0, z_1, t)
+
+        # Target velocity
+        v_target = self.target_velocity(z_0, z_1)
+
+        # Predicted velocity
+        v_pred = self.forward(z_t, t)
+
+        # MSE loss
+        loss = F.mse_loss(v_pred, v_target)
+
+        return loss
+
+    def sample(
+        self,
+        num_samples: int,
+        seq_len: int,
+        num_steps: int = 10,
+        method: str = "euler",
+        device: str = "cpu",
+        return_trajectory: bool = False,
+    ) -> Tensor | Tuple[Tensor, list]:
+        """
+        Generate samples by solving the flow ODE from noise to data.
+
+        Args:
+            num_samples: Number of samples to generate.
+            seq_len: Sequence length T.
+            num_steps: Number of ODE integration steps.
+            method: ODE solver ('euler' or 'rk4').
+            device: Device for computation.
+            return_trajectory: If True, return intermediate states.
+
+        Returns:
+            Generated samples (num_samples, seq_len, latent_dim).
+        """
+        # Start from noise
+        z_0 = torch.randn(num_samples, seq_len, self.latent_dim, device=device)
+
+        # Define velocity function for ODE solver
+        def velocity_fn(z: Tensor, t: float) -> Tensor:
+            t_tensor = torch.tensor([t], device=z.device, dtype=z.dtype).expand(z.shape[0])
+            return self.forward(z, t_tensor)
+
+        # Solve ODE from t=0 to t=1
+        return ode_solve(
+            velocity_fn,
+            z_0,
+            t_start=0.0,
+            t_end=1.0,
+            num_steps=num_steps,
+            method=method,
+            return_trajectory=return_trajectory,
+        )
+
+
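A minimal training-and-sampling sketch for FlowMatcher (assumed usage; ode_solve comes from dhb_xr.generative.sampling and, per the docstring above, returns the generated samples when return_trajectory is False):

import torch
from dhb_xr.generative.flow_matching import FlowMatcher

model = FlowMatcher(latent_dim=32)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

z_1 = torch.randn(16, 50, 32)   # stand-in batch of data latents (B, T, D)
loss = model.loss(z_1)          # z_0 ~ N(0, I) and t ~ U[0, 1] are drawn internally
loss.backward()
opt.step()

samples = model.sample(num_samples=4, seq_len=50, num_steps=10)   # (4, 50, 32)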
+class VariationalFlowMatcher(nn.Module):
+    """
+    Variational Flow Matching (V-RFM) for multi-modal trajectory generation.
+
+    Extends FlowMatcher with a latent variable w that conditions the velocity
+    field, enabling representation of multiple modes in the trajectory space.
+
+    Training objective:
+        L = E[||v_theta(z_t, t, w) - (z_1 - z_0)||^2] + beta * KL(q(w|z_t,t) || p(w))
+
+    At inference, sampling different w values produces diverse trajectories.
+    """
+
+    def __init__(
+        self,
+        latent_dim: int,
+        condition_dim: int = 16,
+        hidden_dim: int = 256,
+        num_layers: int = 4,
+        time_embed_dim: int = 64,
+        dropout: float = 0.1,
+        prior_std: float = 1.0,
+    ):
+        """
+        Args:
+            latent_dim: Dimension of the trajectory latent space z.
+            condition_dim: Dimension of the conditioning latent w.
+            hidden_dim: Hidden layer dimension.
+            num_layers: Number of residual blocks.
+            time_embed_dim: Dimension of time embedding.
+            dropout: Dropout rate.
+            prior_std: Standard deviation of prior p(w) = N(0, prior_std^2).
+        """
+        super().__init__()
+        self.latent_dim = latent_dim
+        self.condition_dim = condition_dim
+        self.prior_std = prior_std
+
+        # Velocity network with conditioning
+        self.velocity_net = VelocityNetwork(
+            latent_dim=latent_dim,
+            hidden_dim=hidden_dim,
+            num_layers=num_layers,
+            time_embed_dim=time_embed_dim,
+            condition_dim=condition_dim,
+            dropout=dropout,
+        )
+
+        # Posterior encoder q(w | z_t, t)
+        # Simple MLP that maps (z_t, t) to mean and log_std of w
+        self.posterior_net = nn.Sequential(
+            nn.Linear(latent_dim + time_embed_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, condition_dim * 2),  # mean and log_std
+        )
+
+        self.time_embed = SinusoidalTimeEmbedding(time_embed_dim)
+
+    def encode_posterior(
+        self,
+        z_t: Tensor,
+        t: Tensor | float,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Encode posterior q(w | z_t, t).
+
+        Args:
+            z_t: Current state (B, T, D) or (B, D).
+            t: Time (B,) or scalar.
+
+        Returns:
+            Tuple of (mean, log_std) for the posterior distribution.
+        """
+        # Handle scalar time
+        if isinstance(t, float):
+            t = torch.tensor([t], device=z_t.device, dtype=z_t.dtype)
+        if t.dim() == 0:
+            t = t.unsqueeze(0)
+
+        # Pool over sequence if 3D
+        if z_t.dim() == 3:
+            z_pooled = z_t.mean(dim=1)  # (B, D)
+        else:
+            z_pooled = z_t
+
+        # Time embedding
+        t_embed = self.time_embed(t)  # (B, time_embed_dim)
+
+        # Concatenate and encode
+        x = torch.cat([z_pooled, t_embed], dim=-1)
+        out = self.posterior_net(x)
+
+        mean, log_std = out.chunk(2, dim=-1)
+        log_std = torch.clamp(log_std, min=-10, max=2)  # Stability
+
+        return mean, log_std
+
+    def sample_posterior(
+        self,
+        z_t: Tensor,
+        t: Tensor | float,
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        Sample from posterior using reparameterization trick.
+
+        Returns:
+            Tuple of (w, mean, log_std).
+        """
+        mean, log_std = self.encode_posterior(z_t, t)
+        std = torch.exp(log_std)
+
+        # Reparameterization
+        eps = torch.randn_like(mean)
+        w = mean + std * eps
+
+        return w, mean, log_std
+
+    def kl_divergence(self, mean: Tensor, log_std: Tensor) -> Tensor:
+        """
+        Compute KL divergence from posterior to prior.
+
+        KL(N(mean, std) || N(0, prior_std))
+        """
+        std = torch.exp(log_std)
+        prior_var = self.prior_std ** 2
+
+        kl = 0.5 * (
+            (std ** 2 + mean ** 2) / prior_var
+            - 1
+            - 2 * log_std
+            + 2 * math.log(self.prior_std)
+        )
+
+        return kl.sum(dim=-1).mean()
+
+    def forward(
+        self,
+        z_t: Tensor,
+        t: Tensor | float,
+        w: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+        Predict velocity at state z_t and time t.
+
+        If w is None during training, it's sampled from the posterior.
+        """
+        return self.velocity_net(z_t, t, w)
+
+    def interpolate(self, z_0: Tensor, z_1: Tensor, t: Tensor | float) -> Tensor:
+        """Linear interpolation between noise and data."""
+        if isinstance(t, float):
+            return (1 - t) * z_0 + t * z_1
+
+        if z_0.dim() == 3:
+            t = t.view(-1, 1, 1)
+        else:
+            t = t.view(-1, 1)
+
+        return (1 - t) * z_0 + t * z_1
+
+    def target_velocity(self, z_0: Tensor, z_1: Tensor) -> Tensor:
+        """Target velocity (straight path)."""
+        return z_1 - z_0
+
+    def loss(
+        self,
+        z_1: Tensor,
+        z_0: Optional[Tensor] = None,
+        beta: float = 0.01,
+    ) -> Dict[str, Tensor]:
+        """
+        Compute V-RFM loss with KL regularization.
+
+        Args:
+            z_1: Data samples (B, T, D) or (B, D).
+            z_0: Noise samples. If None, sampled from N(0, 1).
+            beta: Weight for KL divergence term.
+
+        Returns:
+            Dictionary with 'total', 'reconstruction', and 'kl' losses.
+        """
+        if z_0 is None:
+            z_0 = torch.randn_like(z_1)
+
+        B = z_1.shape[0]
+        t = torch.rand(B, device=z_1.device, dtype=z_1.dtype)
+
+        # Interpolate
+        z_t = self.interpolate(z_0, z_1, t)
+
+        # Sample from posterior
+        w, mean, log_std = self.sample_posterior(z_t, t)
+
+        # Target velocity
+        v_target = self.target_velocity(z_0, z_1)
+
+        # Predicted velocity (conditioned on w)
+        v_pred = self.forward(z_t, t, w)
+
+        # Reconstruction loss
+        recon_loss = F.mse_loss(v_pred, v_target)
+
+        # KL divergence
+        kl_loss = self.kl_divergence(mean, log_std)
+
+        # Total loss
+        total_loss = recon_loss + beta * kl_loss
+
+        return {
+            "total": total_loss,
+            "reconstruction": recon_loss,
+            "kl": kl_loss,
+        }
+
+    def sample(
+        self,
+        num_samples: int,
+        seq_len: int,
+        num_steps: int = 10,
+        method: str = "euler",
+        w: Optional[Tensor] = None,
+        device: str = "cpu",
+        return_trajectory: bool = False,
+    ) -> Tensor | Tuple[Tensor, list]:
+        """
+        Generate samples by solving the flow ODE.
+
+        Args:
+            num_samples: Number of samples to generate.
+            seq_len: Sequence length T.
+            num_steps: Number of ODE integration steps.
+            method: ODE solver ('euler' or 'rk4').
+            w: Conditioning latent. If None, sampled from prior.
+            device: Device for computation.
+            return_trajectory: If True, return intermediate states.
+
+        Returns:
+            Generated samples (num_samples, seq_len, latent_dim).
+        """
+        # Sample w from prior if not provided
+        if w is None:
+            w = torch.randn(num_samples, self.condition_dim, device=device) * self.prior_std
+
+        # Start from noise
+        z_0 = torch.randn(num_samples, seq_len, self.latent_dim, device=device)
+
+        # Define velocity function
+        def velocity_fn(z: Tensor, t: float) -> Tensor:
+            t_tensor = torch.tensor([t], device=z.device, dtype=z.dtype).expand(z.shape[0])
+            return self.forward(z, t_tensor, w)
+
+        return ode_solve(
+            velocity_fn,
+            z_0,
+            t_start=0.0,
+            t_end=1.0,
+            num_steps=num_steps,
+            method=method,
+            return_trajectory=return_trajectory,
+        )
+
+    def sample_multimodal(
+        self,
+        num_samples: int,
+        seq_len: int,
+        num_modes: int = 4,
+        num_steps: int = 10,
+        method: str = "euler",
+        device: str = "cpu",
+    ) -> Tensor:
+        """
+        Generate diverse samples by sampling different w values.
+
+        Args:
+            num_samples: Number of samples per mode.
+            seq_len: Sequence length.
+            num_modes: Number of distinct modes to sample.
+            num_steps: ODE integration steps.
+            method: ODE solver method.
+            device: Device for computation.
+
+        Returns:
+            Generated samples (num_modes * num_samples, seq_len, latent_dim).
+        """
+        all_samples = []
+
+        for _ in range(num_modes):
+            # Sample a different w for each mode
+            w = torch.randn(num_samples, self.condition_dim, device=device) * self.prior_std
+            samples = self.sample(
+                num_samples, seq_len, num_steps, method, w, device
+            )
+            all_samples.append(samples)
+
+        return torch.cat(all_samples, dim=0)
+
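Assumed training and multi-modal sampling flow for the methods above (illustrative, not from the package). Note that, as written, forward() does not itself sample w despite its docstring: loss() draws w from the posterior and sample() from the prior, so direct forward() callers must pass w whenever condition_dim > 0.

import torch
from dhb_xr.generative.flow_matching import VariationalFlowMatcher

model = VariationalFlowMatcher(latent_dim=32, condition_dim=16)
z_1 = torch.randn(16, 50, 32)          # data latents (B, T, D)

losses = model.loss(z_1, beta=0.01)    # dict with 'total', 'reconstruction', 'kl'
losses["total"].backward()

# Each mode redraws w ~ N(0, prior_std^2 I), giving num_modes * num_samples outputs
samples = model.sample_multimodal(num_samples=2, seq_len=50, num_modes=4)
print(samples.shape)                   # torch.Size([8, 50, 32])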
+    def sample_continuation(
+        self,
+        z_prefix: Tensor,
+        prefix_len: int,
+        total_len: int,
+        num_modes: int = 1,
+        num_steps: int = 10,
+        method: str = "euler",
+    ) -> Tensor:
+        """
+        Generate trajectory continuations from a prefix.
+
+        Args:
+            z_prefix: Prefix latent sequence (B, prefix_len, D).
+            prefix_len: Length of prefix to condition on.
+            total_len: Total output sequence length.
+            num_modes: Number of continuation modes to generate.
+            num_steps: ODE integration steps.
+            method: ODE solver method.
+
+        Returns:
+            Continued sequences (B * num_modes, total_len, D).
+        """
+        B = z_prefix.shape[0]
+        device = z_prefix.device
+        continuation_len = total_len - prefix_len
+
+        all_continuations = []
+
+        for _ in range(num_modes):
+            # Sample w from prior
+            w = torch.randn(B, self.condition_dim, device=device) * self.prior_std
+
+            # Start from noise for continuation part
+            z_cont_noise = torch.randn(B, continuation_len, self.latent_dim, device=device)
+
+            # Full sequence: prefix (fixed) + continuation (to be generated)
+            # For now, simple approach: generate full sequence conditioned on w
+            # and use the continuation part
+            z_0 = torch.randn(B, total_len, self.latent_dim, device=device)
+
+            def velocity_fn(z: Tensor, t: float) -> Tensor:
+                t_tensor = torch.tensor([t], device=z.device, dtype=z.dtype).expand(z.shape[0])
+                return self.forward(z, t_tensor, w)
+
+            z_1 = ode_solve(velocity_fn, z_0, 0.0, 1.0, num_steps, method)
+
+            # Blend: keep prefix, take generated continuation
+            z_out = torch.cat([z_prefix, z_1[:, prefix_len:]], dim=1)
+            all_continuations.append(z_out)
+
+        return torch.cat(all_continuations, dim=0)
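And a sketch of prefix-conditioned generation (assumed usage). Per the comments in the method, it generates a full-length sequence conditioned on w and splices the given prefix back in, so the seam at prefix_len is not explicitly constrained; z_cont_noise above is currently unused.

import torch
from dhb_xr.generative.flow_matching import VariationalFlowMatcher

model = VariationalFlowMatcher(latent_dim=32, condition_dim=16)
z_prefix = torch.randn(2, 10, 32)   # observed prefix latents (B, prefix_len, D)

conts = model.sample_continuation(z_prefix, prefix_len=10, total_len=50, num_modes=3)
print(conts.shape)                  # torch.Size([6, 50, 32]): 3 modes x 2 sequences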