dhb_xr-0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
dhb_xr/generative/flow_matching.py
@@ -0,0 +1,705 @@
"""
Variational Flow Matching for DHB-Token generation.

Implements:
- FlowMatcher: base flow matching (FM/RFM) for deterministic velocity regression
- VariationalFlowMatcher: V-RFM with latent conditioning for multi-modal generation

The variational extension addresses the "mode averaging" problem in standard FM:
when multiple valid trajectories exist from the same state, a deterministic
velocity regressor averages the modes, producing implausible intermediate
behavior. By conditioning on a latent variable w, V-RFM can represent and
sample from multiple modes.

Reference: "Flow Matching for Generative Modeling" (Lipman et al., 2023)
"""
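Before the code, a concrete view of the mode-averaging failure described in the docstring — a toy sketch for this review, not part of the wheel: with two equally likely targets reachable from the same start state, the least-squares optimal deterministic velocity is their mean, which points toward neither mode.

import torch

z_0 = torch.zeros(2)                             # shared start state
z_1 = torch.tensor([[1.0, 0.0], [-1.0, 0.0]])    # two valid target modes

# A deterministic regressor trained with MSE converges to the mean target:
v_avg = (z_1 - z_0).mean(dim=0)
print(v_avg)                                     # tensor([0., 0.]) -- neither mode

Conditioning on a latent w lets V-RFM commit to one target per sample instead of this average.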
from __future__ import annotations

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Tuple, Dict

from .sampling import ode_solve


class SinusoidalTimeEmbedding(nn.Module):
    """
    Sinusoidal positional embedding for time t in [0, 1].

    Maps scalar time to a high-dimensional representation using
    sinusoidal functions at different frequencies.
    """

    def __init__(self, dim: int, max_period: float = 10000.0):
        """
        Args:
            dim: Output embedding dimension (must be even).
            max_period: Maximum period for sinusoidal functions.
        """
        super().__init__()
        self.dim = dim
        self.max_period = max_period

        # Precompute frequency factors (geometric progression, as in
        # transformer positional encodings)
        half_dim = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / half_dim
        )
        self.register_buffer("freqs", freqs)

    def forward(self, t: Tensor) -> Tensor:
        """
        Args:
            t: Time values (B,) or scalar, in [0, 1].

        Returns:
            Time embeddings (B, dim); a scalar input yields (1, dim).
        """
        if t.dim() == 0:
            t = t.unsqueeze(0)

        # Scale t to match frequency range
        args = t.unsqueeze(-1) * self.freqs * 2 * math.pi

        # Concatenate sin and cos halves
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

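A quick shape check for the embedding — a minimal sketch assuming only torch and the class above:

import torch

embed = SinusoidalTimeEmbedding(dim=64)
t = torch.rand(8)                    # one time value per batch element
e = embed(t)
assert e.shape == (8, 64)            # first half sine features, second half cosine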
class VelocityNetwork(nn.Module):
    """
    Neural network predicting the velocity field v_theta(z_t, t, w).

    Architecture: MLP with residual connections, processing temporal sequences.
    Supports optional latent conditioning for variational flow matching.
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 4,
        time_embed_dim: int = 64,
        condition_dim: int = 0,
        dropout: float = 0.1,
        use_layer_norm: bool = True,
    ):
        """
        Args:
            latent_dim: Dimension of the latent space z.
            hidden_dim: Hidden layer dimension.
            num_layers: Number of residual blocks.
            time_embed_dim: Dimension of time embedding.
            condition_dim: Dimension of conditioning latent w (0 for unconditional).
            dropout: Dropout rate.
            use_layer_norm: Whether to use layer normalization.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.condition_dim = condition_dim

        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(time_embed_dim)

        # Input projection: z_t + time embedding + (optional) condition
        input_dim = latent_dim + time_embed_dim + condition_dim
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Residual blocks
        self.blocks = nn.ModuleList()
        for _ in range(num_layers):
            block = nn.Sequential(
                nn.LayerNorm(hidden_dim) if use_layer_norm else nn.Identity(),
                nn.Linear(hidden_dim, hidden_dim * 2),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.Dropout(dropout),
            )
            self.blocks.append(block)

        # Output projection to velocity
        self.output_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim) if use_layer_norm else nn.Identity(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(
        self,
        z_t: Tensor,
        t: Tensor | float,
        w: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Predict velocity at state z_t and time t.

        Args:
            z_t: Current state (B, T, D) or (B, D).
            t: Time (B,) or scalar in [0, 1].
            w: Optional conditioning latent (B, condition_dim).

        Returns:
            Predicted velocity v_theta (same shape as z_t).
        """
        # Handle scalar time
        if isinstance(t, float):
            t = torch.tensor([t], device=z_t.device, dtype=z_t.dtype)
        if t.dim() == 0:
            t = t.unsqueeze(0)

        # Get batch size and flatten 3D input to (B*T, D)
        is_3d = z_t.dim() == 3
        if is_3d:
            B, T, D = z_t.shape
            z_flat = z_t.reshape(B * T, D)
        else:
            B, D = z_t.shape
            T = 1
            z_flat = z_t

        # Time embedding, broadcast across the batch first (a scalar time
        # embeds to a single row), then across all timesteps
        t_embed = self.time_embed(t).expand(B, -1)  # (B, time_embed_dim)
        if is_3d:
            t_embed = t_embed.unsqueeze(1).expand(-1, T, -1).reshape(B * T, -1)

        # Build input
        inputs = [z_flat, t_embed]

        # Add conditioning if provided
        if w is not None:
            if is_3d:
                w_expanded = w.unsqueeze(1).expand(-1, T, -1).reshape(B * T, -1)
            else:
                w_expanded = w
            inputs.append(w_expanded)

        x = torch.cat(inputs, dim=-1)
        x = self.input_proj(x)

        # Residual blocks
        for block in self.blocks:
            x = x + block(x)

        # Output velocity
        v = self.output_proj(x)

        # Restore (B, T, D) shape if the input was 3D
        if is_3d:
            v = v.reshape(B, T, D)

        return v

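As a shape sanity check on the velocity network — a minimal sketch with hypothetical dimensions, assuming the class above is importable:

import torch

net = VelocityNetwork(latent_dim=8, condition_dim=16)
z_t = torch.randn(4, 50, 8)          # (B=4, T=50, D=8) latent sequence
t = torch.rand(4)                    # one flow time per batch element
w = torch.randn(4, 16)               # conditioning latent
v = net(z_t, t, w)
assert v.shape == z_t.shape          # velocity matches the state shape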
class FlowMatcher(nn.Module):
    """
    Base Flow Matching model for trajectory generation.

    Learns a velocity field that transports samples from noise (z_0)
    to data (z_1) along straight interpolation paths.

    Training objective:
        L = E[||v_theta(z_t, t) - (z_1 - z_0)||^2]

    where z_t = (1-t)*z_0 + t*z_1 is the linear interpolation.
    """

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 4,
        time_embed_dim: int = 64,
        dropout: float = 0.1,
    ):
        """
        Args:
            latent_dim: Dimension of the latent space.
            hidden_dim: Hidden layer dimension.
            num_layers: Number of residual blocks.
            time_embed_dim: Dimension of time embedding.
            dropout: Dropout rate.
        """
        super().__init__()
        self.latent_dim = latent_dim

        self.velocity_net = VelocityNetwork(
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            time_embed_dim=time_embed_dim,
            condition_dim=0,
            dropout=dropout,
        )

    def forward(self, z_t: Tensor, t: Tensor | float) -> Tensor:
        """Predict velocity at state z_t and time t."""
        return self.velocity_net(z_t, t)

    def interpolate(self, z_0: Tensor, z_1: Tensor, t: Tensor | float) -> Tensor:
        """
        Linear interpolation between noise (z_0) and data (z_1).

        Args:
            z_0: Noise samples (B, T, D) or (B, D).
            z_1: Data samples (same shape).
            t: Interpolation time in [0, 1].

        Returns:
            Interpolated state z_t.
        """
        if isinstance(t, float):
            return (1 - t) * z_0 + t * z_1

        # Reshape tensor t for broadcasting against (B, T, D) or (B, D)
        if z_0.dim() == 3:
            t = t.view(-1, 1, 1)
        else:
            t = t.view(-1, 1)

        return (1 - t) * z_0 + t * z_1

    def target_velocity(self, z_0: Tensor, z_1: Tensor) -> Tensor:
        """
        Compute the target velocity for flow matching (straight path).

        The optimal-transport displacement is simply z_1 - z_0.
        """
        return z_1 - z_0

    def loss(
        self,
        z_1: Tensor,
        z_0: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Compute the flow matching loss.

        Args:
            z_1: Data samples (B, T, D) or (B, D).
            z_0: Noise samples. If None, sampled from N(0, 1).

        Returns:
            Scalar loss value.
        """
        # Sample noise if not provided
        if z_0 is None:
            z_0 = torch.randn_like(z_1)

        # Sample a random time per batch element
        B = z_1.shape[0]
        t = torch.rand(B, device=z_1.device, dtype=z_1.dtype)

        # Interpolate between noise and data
        z_t = self.interpolate(z_0, z_1, t)

        # Regress the predicted velocity onto the straight-path target
        v_target = self.target_velocity(z_0, z_1)
        v_pred = self.forward(z_t, t)

        # MSE loss
        return F.mse_loss(v_pred, v_target)

    def sample(
        self,
        num_samples: int,
        seq_len: int,
        num_steps: int = 10,
        method: str = "euler",
        device: str = "cpu",
        return_trajectory: bool = False,
    ) -> Tensor | Tuple[Tensor, list]:
        """
        Generate samples by solving the flow ODE from noise to data.

        Args:
            num_samples: Number of samples to generate.
            seq_len: Sequence length T.
            num_steps: Number of ODE integration steps.
            method: ODE solver ('euler' or 'rk4').
            device: Device for computation.
            return_trajectory: If True, also return intermediate states.

        Returns:
            Generated samples (num_samples, seq_len, latent_dim), plus the
            list of intermediate states when return_trajectory is True.
        """
        # Start from noise
        z_0 = torch.randn(num_samples, seq_len, self.latent_dim, device=device)

        # Wrap the network as a velocity function for the ODE solver
        def velocity_fn(z: Tensor, t: float) -> Tensor:
            t_tensor = torch.tensor([t], device=z.device, dtype=z.dtype).expand(z.shape[0])
            return self.forward(z, t_tensor)

        # Solve the ODE from t=0 (noise) to t=1 (data)
        return ode_solve(
            velocity_fn,
            z_0,
            t_start=0.0,
            t_end=1.0,
            num_steps=num_steps,
            method=method,
            return_trajectory=return_trajectory,
        )

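To make the train/sample cycle concrete — a minimal sketch, assuming the classes above are importable, with illustrative hyperparameters and random stand-in data rather than anything prescribed by the package:

import torch

model = FlowMatcher(latent_dim=8)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)   # illustrative settings

for step in range(100):
    z_1 = torch.randn(32, 50, 8)    # stand-in for a batch of latent trajectories
    loss = model.loss(z_1)          # regress v_theta(z_t, t) onto z_1 - z_0
    opt.zero_grad()
    loss.backward()
    opt.step()

# Inference: integrate the learned field from noise (t=0) to data (t=1)
samples = model.sample(num_samples=16, seq_len=50, num_steps=10, method="euler")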
class VariationalFlowMatcher(nn.Module):
    """
    Variational Flow Matching (V-RFM) for multi-modal trajectory generation.

    Extends FlowMatcher with a latent variable w that conditions the velocity
    field, enabling representation of multiple modes in the trajectory space.

    Training objective:
        L = E[||v_theta(z_t, t, w) - (z_1 - z_0)||^2] + beta * KL(q(w|z_t,t) || p(w))

    At inference, sampling different w values produces diverse trajectories.
    """

    def __init__(
        self,
        latent_dim: int,
        condition_dim: int = 16,
        hidden_dim: int = 256,
        num_layers: int = 4,
        time_embed_dim: int = 64,
        dropout: float = 0.1,
        prior_std: float = 1.0,
    ):
        """
        Args:
            latent_dim: Dimension of the trajectory latent space z.
            condition_dim: Dimension of the conditioning latent w.
            hidden_dim: Hidden layer dimension.
            num_layers: Number of residual blocks.
            time_embed_dim: Dimension of time embedding.
            dropout: Dropout rate.
            prior_std: Standard deviation of the prior p(w) = N(0, prior_std^2).
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.prior_std = prior_std

        # Velocity network with latent conditioning
        self.velocity_net = VelocityNetwork(
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            time_embed_dim=time_embed_dim,
            condition_dim=condition_dim,
            dropout=dropout,
        )

        # Posterior encoder q(w | z_t, t): a simple MLP that maps (z_t, t)
        # to the mean and log-std of w
        self.posterior_net = nn.Sequential(
            nn.Linear(latent_dim + time_embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, condition_dim * 2),  # mean and log_std
        )

        self.time_embed = SinusoidalTimeEmbedding(time_embed_dim)

    def encode_posterior(
        self,
        z_t: Tensor,
        t: Tensor | float,
    ) -> Tuple[Tensor, Tensor]:
        """
        Encode the posterior q(w | z_t, t).

        Args:
            z_t: Current state (B, T, D) or (B, D).
            t: Time (B,) or scalar.

        Returns:
            Tuple of (mean, log_std) for the posterior distribution.
        """
        # Handle scalar time
        if isinstance(t, float):
            t = torch.tensor([t], device=z_t.device, dtype=z_t.dtype)
        if t.dim() == 0:
            t = t.unsqueeze(0)

        # Mean-pool over the sequence dimension if 3D
        if z_t.dim() == 3:
            z_pooled = z_t.mean(dim=1)  # (B, D)
        else:
            z_pooled = z_t

        # Time embedding
        t_embed = self.time_embed(t)  # (B, time_embed_dim)

        # Concatenate and encode
        x = torch.cat([z_pooled, t_embed], dim=-1)
        out = self.posterior_net(x)

        mean, log_std = out.chunk(2, dim=-1)
        log_std = torch.clamp(log_std, min=-10, max=2)  # numerical stability

        return mean, log_std

    def sample_posterior(
        self,
        z_t: Tensor,
        t: Tensor | float,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Sample from the posterior using the reparameterization trick.

        Returns:
            Tuple of (w, mean, log_std).
        """
        mean, log_std = self.encode_posterior(z_t, t)
        std = torch.exp(log_std)

        # Reparameterization: w = mean + std * eps, eps ~ N(0, I)
        eps = torch.randn_like(mean)
        w = mean + std * eps

        return w, mean, log_std

    def kl_divergence(self, mean: Tensor, log_std: Tensor) -> Tensor:
        """
        Compute the KL divergence from the posterior to the prior:

            KL(N(mean, std^2) || N(0, prior_std^2))
        """
        std = torch.exp(log_std)
        prior_var = self.prior_std ** 2

        # Closed form for diagonal Gaussians
        kl = 0.5 * (
            (std ** 2 + mean ** 2) / prior_var
            - 1
            - 2 * log_std
            + 2 * math.log(self.prior_std)
        )

        # Sum over latent dimensions, average over the batch
        return kl.sum(dim=-1).mean()

    def forward(
        self,
        z_t: Tensor,
        t: Tensor | float,
        w: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Predict velocity at state z_t and time t, conditioned on w.

        If w is None, it is drawn from the prior p(w); during training,
        `loss` samples w from the posterior and passes it in explicitly.
        """
        if w is None:
            # Fall back to the prior so the conditioned network always
            # receives a latent (condition_dim > 0 requires one)
            w = (
                torch.randn(z_t.shape[0], self.condition_dim, device=z_t.device, dtype=z_t.dtype)
                * self.prior_std
            )
        return self.velocity_net(z_t, t, w)

    def interpolate(self, z_0: Tensor, z_1: Tensor, t: Tensor | float) -> Tensor:
        """Linear interpolation between noise and data."""
        if isinstance(t, float):
            return (1 - t) * z_0 + t * z_1

        # Reshape tensor t for broadcasting
        if z_0.dim() == 3:
            t = t.view(-1, 1, 1)
        else:
            t = t.view(-1, 1)

        return (1 - t) * z_0 + t * z_1

    def target_velocity(self, z_0: Tensor, z_1: Tensor) -> Tensor:
        """Target velocity (straight path)."""
        return z_1 - z_0

    def loss(
        self,
        z_1: Tensor,
        z_0: Optional[Tensor] = None,
        beta: float = 0.01,
    ) -> Dict[str, Tensor]:
        """
        Compute the V-RFM loss with KL regularization.

        Args:
            z_1: Data samples (B, T, D) or (B, D).
            z_0: Noise samples. If None, sampled from N(0, 1).
            beta: Weight for the KL divergence term.

        Returns:
            Dictionary with 'total', 'reconstruction', and 'kl' losses.
        """
        if z_0 is None:
            z_0 = torch.randn_like(z_1)

        B = z_1.shape[0]
        t = torch.rand(B, device=z_1.device, dtype=z_1.dtype)

        # Interpolate
        z_t = self.interpolate(z_0, z_1, t)

        # Sample the conditioning latent from the posterior
        w, mean, log_std = self.sample_posterior(z_t, t)

        # Target velocity
        v_target = self.target_velocity(z_0, z_1)

        # Predicted velocity (conditioned on w)
        v_pred = self.forward(z_t, t, w)

        # Reconstruction loss
        recon_loss = F.mse_loss(v_pred, v_target)

        # KL divergence
        kl_loss = self.kl_divergence(mean, log_std)

        # Total loss
        total_loss = recon_loss + beta * kl_loss

        return {
            "total": total_loss,
            "reconstruction": recon_loss,
            "kl": kl_loss,
        }

    def sample(
        self,
        num_samples: int,
        seq_len: int,
        num_steps: int = 10,
        method: str = "euler",
        w: Optional[Tensor] = None,
        device: str = "cpu",
        return_trajectory: bool = False,
    ) -> Tensor | Tuple[Tensor, list]:
        """
        Generate samples by solving the flow ODE.

        Args:
            num_samples: Number of samples to generate.
            seq_len: Sequence length T.
            num_steps: Number of ODE integration steps.
            method: ODE solver ('euler' or 'rk4').
            w: Conditioning latent. If None, sampled from the prior.
            device: Device for computation.
            return_trajectory: If True, also return intermediate states.

        Returns:
            Generated samples (num_samples, seq_len, latent_dim), plus the
            list of intermediate states when return_trajectory is True.
        """
        # Sample w from the prior if not provided
        if w is None:
            w = torch.randn(num_samples, self.condition_dim, device=device) * self.prior_std

        # Start from noise
        z_0 = torch.randn(num_samples, seq_len, self.latent_dim, device=device)

        # Velocity function closed over the fixed conditioning latent w
        def velocity_fn(z: Tensor, t: float) -> Tensor:
            t_tensor = torch.tensor([t], device=z.device, dtype=z.dtype).expand(z.shape[0])
            return self.forward(z, t_tensor, w)

        return ode_solve(
            velocity_fn,
            z_0,
            t_start=0.0,
            t_end=1.0,
            num_steps=num_steps,
            method=method,
            return_trajectory=return_trajectory,
        )

    def sample_multimodal(
        self,
        num_samples: int,
        seq_len: int,
        num_modes: int = 4,
        num_steps: int = 10,
        method: str = "euler",
        device: str = "cpu",
    ) -> Tensor:
        """
        Generate diverse samples by sampling different w values.

        Args:
            num_samples: Number of samples per mode.
            seq_len: Sequence length.
            num_modes: Number of distinct modes to sample.
            num_steps: ODE integration steps.
            method: ODE solver method.
            device: Device for computation.

        Returns:
            Generated samples (num_modes * num_samples, seq_len, latent_dim).
        """
        all_samples = []

        for _ in range(num_modes):
            # Draw fresh conditioning latents for each mode
            w = torch.randn(num_samples, self.condition_dim, device=device) * self.prior_std
            samples = self.sample(
                num_samples, seq_len, num_steps, method, w, device
            )
            all_samples.append(samples)

        return torch.cat(all_samples, dim=0)

    def sample_continuation(
        self,
        z_prefix: Tensor,
        prefix_len: int,
        total_len: int,
        num_modes: int = 1,
        num_steps: int = 10,
        method: str = "euler",
    ) -> Tensor:
        """
        Generate trajectory continuations from a prefix.

        Args:
            z_prefix: Prefix latent sequence (B, prefix_len, D).
            prefix_len: Length of the prefix to condition on.
            total_len: Total output sequence length.
            num_modes: Number of continuation modes to generate.
            num_steps: ODE integration steps.
            method: ODE solver method.

        Returns:
            Continued sequences (B * num_modes, total_len, D).
        """
        B = z_prefix.shape[0]
        device = z_prefix.device

        all_continuations = []

        for _ in range(num_modes):
            # Sample w from the prior for this mode
            w = torch.randn(B, self.condition_dim, device=device) * self.prior_std

            # Simple approach: generate a full-length sequence conditioned on w
            # and keep only its continuation segment; the prefix itself is not
            # enforced during ODE integration, only spliced back in afterwards.
            z_0 = torch.randn(B, total_len, self.latent_dim, device=device)

            def velocity_fn(z: Tensor, t: float) -> Tensor:
                t_tensor = torch.tensor([t], device=z.device, dtype=z.dtype).expand(z.shape[0])
                return self.forward(z, t_tensor, w)

            z_1 = ode_solve(velocity_fn, z_0, 0.0, 1.0, num_steps, method)

            # Blend: keep the prefix, take the generated continuation
            z_out = torch.cat([z_prefix, z_1[:, prefix_len:]], dim=1)
            all_continuations.append(z_out)

        return torch.cat(all_continuations, dim=0)
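And the variational counterpart end-to-end — the same kind of sketch, with the KL weight beta chosen purely for illustration:

import torch

model = VariationalFlowMatcher(latent_dim=8, condition_dim=16)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)   # illustrative settings

for step in range(100):
    z_1 = torch.randn(32, 50, 8)            # stand-in latent trajectory batch
    losses = model.loss(z_1, beta=0.01)     # 'total', 'reconstruction', 'kl'
    opt.zero_grad()
    losses["total"].backward()
    opt.step()

# Diversity comes from the conditioning latent: each mode draws a fresh w
diverse = model.sample_multimodal(num_samples=4, seq_len=50, num_modes=4)
assert diverse.shape == (16, 50, 8)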