dhb_xr-0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. dhb_xr/__init__.py +61 -0
  2. dhb_xr/cli.py +206 -0
  3. dhb_xr/core/__init__.py +28 -0
  4. dhb_xr/core/geometry.py +167 -0
  5. dhb_xr/core/geometry_torch.py +77 -0
  6. dhb_xr/core/types.py +113 -0
  7. dhb_xr/database/__init__.py +10 -0
  8. dhb_xr/database/motion_db.py +79 -0
  9. dhb_xr/database/retrieval.py +6 -0
  10. dhb_xr/database/similarity.py +71 -0
  11. dhb_xr/decoder/__init__.py +13 -0
  12. dhb_xr/decoder/decoder_torch.py +52 -0
  13. dhb_xr/decoder/dhb_dr.py +261 -0
  14. dhb_xr/decoder/dhb_qr.py +89 -0
  15. dhb_xr/encoder/__init__.py +27 -0
  16. dhb_xr/encoder/dhb_dr.py +418 -0
  17. dhb_xr/encoder/dhb_qr.py +129 -0
  18. dhb_xr/encoder/dhb_ti.py +204 -0
  19. dhb_xr/encoder/encoder_torch.py +54 -0
  20. dhb_xr/encoder/padding.py +82 -0
  21. dhb_xr/generative/__init__.py +78 -0
  22. dhb_xr/generative/flow_matching.py +705 -0
  23. dhb_xr/generative/latent_encoder.py +536 -0
  24. dhb_xr/generative/sampling.py +203 -0
  25. dhb_xr/generative/training.py +475 -0
  26. dhb_xr/generative/vfm_tokenizer.py +485 -0
  27. dhb_xr/integration/__init__.py +13 -0
  28. dhb_xr/integration/vla/__init__.py +11 -0
  29. dhb_xr/integration/vla/libero.py +132 -0
  30. dhb_xr/integration/vla/pipeline.py +85 -0
  31. dhb_xr/integration/vla/robocasa.py +85 -0
  32. dhb_xr/losses/__init__.py +16 -0
  33. dhb_xr/losses/geodesic_loss.py +91 -0
  34. dhb_xr/losses/hybrid_loss.py +36 -0
  35. dhb_xr/losses/invariant_loss.py +73 -0
  36. dhb_xr/optimization/__init__.py +72 -0
  37. dhb_xr/optimization/casadi_solver.py +342 -0
  38. dhb_xr/optimization/constraints.py +32 -0
  39. dhb_xr/optimization/cusadi_solver.py +311 -0
  40. dhb_xr/optimization/export_casadi_decode.py +111 -0
  41. dhb_xr/optimization/fatrop_solver.py +477 -0
  42. dhb_xr/optimization/torch_solver.py +85 -0
  43. dhb_xr/preprocessing/__init__.py +42 -0
  44. dhb_xr/preprocessing/diagnostics.py +330 -0
  45. dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
  46. dhb_xr/tokenization/__init__.py +56 -0
  47. dhb_xr/tokenization/causal_encoder.py +54 -0
  48. dhb_xr/tokenization/compression.py +749 -0
  49. dhb_xr/tokenization/hierarchical.py +359 -0
  50. dhb_xr/tokenization/rvq.py +178 -0
  51. dhb_xr/tokenization/vqvae.py +155 -0
  52. dhb_xr/utils/__init__.py +24 -0
  53. dhb_xr/utils/io.py +59 -0
  54. dhb_xr/utils/resampling.py +66 -0
  55. dhb_xr/utils/xdof_loader.py +89 -0
  56. dhb_xr/visualization/__init__.py +5 -0
  57. dhb_xr/visualization/plot.py +242 -0
  58. dhb_xr-0.2.1.dist-info/METADATA +784 -0
  59. dhb_xr-0.2.1.dist-info/RECORD +82 -0
  60. dhb_xr-0.2.1.dist-info/WHEEL +5 -0
  61. dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
  62. dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
  63. examples/__init__.py +54 -0
  64. examples/basic_encoding.py +82 -0
  65. examples/benchmark_backends.py +37 -0
  66. examples/dhb_qr_comparison.py +79 -0
  67. examples/dhb_ti_time_invariant.py +72 -0
  68. examples/gpu_batch_optimization.py +102 -0
  69. examples/imitation_learning.py +53 -0
  70. examples/integration/__init__.py +19 -0
  71. examples/integration/libero_full_demo.py +692 -0
  72. examples/integration/libero_pro_dhb_demo.py +1063 -0
  73. examples/integration/libero_simulation_demo.py +286 -0
  74. examples/integration/libero_swap_demo.py +534 -0
  75. examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
  76. examples/integration/test_libero_adapter.py +47 -0
  77. examples/integration/test_libero_encoding.py +75 -0
  78. examples/integration/test_libero_retrieval.py +105 -0
  79. examples/motion_database.py +88 -0
  80. examples/trajectory_adaptation.py +85 -0
  81. examples/vla_tokenization.py +107 -0
  82. notebooks/__init__.py +24 -0
@@ -0,0 +1,203 @@
+ """
+ ODE solvers for flow matching sampling.
+
+ Provides Euler and RK4 integrators for solving the flow ODE:
+     dz/dt = v_theta(z_t, t)
+
+ Starting from z_0 (noise) and integrating to z_1 (data).
+ """
+
+ from __future__ import annotations
+
+ import torch
+ from torch import Tensor
+ from typing import Callable, Optional, Tuple, List
+
+
+ def euler_solve(
+     velocity_fn: Callable[[Tensor, float], Tensor],
+     z_init: Tensor,
+     t_start: float = 0.0,
+     t_end: float = 1.0,
+     num_steps: int = 10,
+     return_trajectory: bool = False,
+ ) -> Tensor | Tuple[Tensor, List[Tensor]]:
+     """
+     Euler method for ODE integration.
+
+     Solves dz/dt = v(z, t) from t_start to t_end.
+
+     Args:
+         velocity_fn: Function (z, t) -> velocity. Takes tensor z and scalar t.
+         z_init: Initial state (B, T, D) or (B, D).
+         t_start: Starting time (typically 0 for noise).
+         t_end: Ending time (typically 1 for data).
+         num_steps: Number of Euler steps.
+         return_trajectory: If True, return list of intermediate states.
+
+     Returns:
+         Final state z_end, optionally with trajectory list.
+     """
+     dt = (t_end - t_start) / num_steps
+     z = z_init.clone()
+     trajectory = [z.clone()] if return_trajectory else None
+
+     for i in range(num_steps):
+         t = t_start + i * dt
+         v = velocity_fn(z, t)
+         z = z + dt * v
+         if return_trajectory:
+             trajectory.append(z.clone())
+
+     if return_trajectory:
+         return z, trajectory
+     return z
+
+
+ def rk4_solve(
+     velocity_fn: Callable[[Tensor, float], Tensor],
+     z_init: Tensor,
+     t_start: float = 0.0,
+     t_end: float = 1.0,
+     num_steps: int = 10,
+     return_trajectory: bool = False,
+ ) -> Tensor | Tuple[Tensor, List[Tensor]]:
+     """
+     4th-order Runge-Kutta method for ODE integration.
+
+     More accurate than Euler for the same step count, but 4x more expensive.
+
+     Args:
+         velocity_fn: Function (z, t) -> velocity.
+         z_init: Initial state (B, T, D) or (B, D).
+         t_start: Starting time.
+         t_end: Ending time.
+         num_steps: Number of RK4 steps.
+         return_trajectory: If True, return list of intermediate states.
+
+     Returns:
+         Final state z_end, optionally with trajectory list.
+     """
+     dt = (t_end - t_start) / num_steps
+     z = z_init.clone()
+     trajectory = [z.clone()] if return_trajectory else None
+
+     for i in range(num_steps):
+         t = t_start + i * dt
+
+         k1 = velocity_fn(z, t)
+         k2 = velocity_fn(z + 0.5 * dt * k1, t + 0.5 * dt)
+         k3 = velocity_fn(z + 0.5 * dt * k2, t + 0.5 * dt)
+         k4 = velocity_fn(z + dt * k3, t + dt)
+
+         z = z + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
+
+         if return_trajectory:
+             trajectory.append(z.clone())
+
+     if return_trajectory:
+         return z, trajectory
+     return z
+
+
+ def ode_solve(
+     velocity_fn: Callable[[Tensor, float], Tensor],
+     z_init: Tensor,
+     t_start: float = 0.0,
+     t_end: float = 1.0,
+     num_steps: int = 10,
+     method: str = "euler",
+     return_trajectory: bool = False,
+ ) -> Tensor | Tuple[Tensor, List[Tensor]]:
+     """
+     Unified ODE solver interface.
+
+     Args:
+         velocity_fn: Function (z, t) -> velocity.
+         z_init: Initial state.
+         t_start: Starting time.
+         t_end: Ending time.
+         num_steps: Number of integration steps.
+         method: 'euler' or 'rk4'.
+         return_trajectory: If True, return intermediate states.
+
+     Returns:
+         Final state, optionally with trajectory.
+     """
+     if method == "euler":
+         return euler_solve(
+             velocity_fn, z_init, t_start, t_end, num_steps, return_trajectory
+         )
+     elif method == "rk4":
+         return rk4_solve(
+             velocity_fn, z_init, t_start, t_end, num_steps, return_trajectory
+         )
+     else:
+         raise ValueError(f"Unknown ODE method: {method}. Use 'euler' or 'rk4'.")
+
+
+ def adaptive_euler_solve(
+     velocity_fn: Callable[[Tensor, float], Tensor],
+     z_init: Tensor,
+     t_start: float = 0.0,
+     t_end: float = 1.0,
+     atol: float = 1e-5,
+     rtol: float = 1e-5,
+     max_steps: int = 1000,
+     min_dt: float = 1e-6,
+ ) -> Tuple[Tensor, int]:
+     """
+     Adaptive step-size Euler method with error estimation.
+
+     Uses step doubling to estimate error and adjust step size.
+     Useful when velocity field has varying stiffness.
+
+     Args:
+         velocity_fn: Function (z, t) -> velocity.
+         z_init: Initial state.
+         t_start: Starting time.
+         t_end: Ending time.
+         atol: Absolute tolerance.
+         rtol: Relative tolerance.
+         max_steps: Maximum number of accepted steps.
+         min_dt: Minimum step size.
+
+     Returns:
+         Final state and number of steps taken.
+     """
+     z = z_init.clone()
+     t = t_start
+     dt = (t_end - t_start) / 10  # Initial step size
+     steps = 0
+
+     while t < t_end and steps < max_steps:
+         # Ensure we don't overshoot
+         dt = min(dt, t_end - t)
+
+         # Full step
+         v = velocity_fn(z, t)
+         z_full = z + dt * v
+
+         # Two half steps
+         z_half = z + 0.5 * dt * v
+         v_half = velocity_fn(z_half, t + 0.5 * dt)
+         z_double = z_half + 0.5 * dt * v_half
+
+         # Error estimate
+         error = (z_double - z_full).abs().max()
+         tol = atol + rtol * z.abs().max()
+
+         if error < tol or dt <= min_dt:
+             # Accept step (forced once dt reaches min_dt, so the loop terminates)
+             z = z_double  # Use the more accurate estimate
+             t = t + dt
+             steps += 1
+
+             # Increase step size
+             if error < 0.5 * tol:
+                 dt = min(dt * 1.5, t_end - t)
+         else:
+             # Reject step, reduce step size
+             dt = max(dt * 0.5, min_dt)
+
+     return z, steps
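
For orientation (not part of the packaged file), a minimal usage sketch of the solvers above. The constant-pull velocity field is purely illustrative and stands in for a trained flow-matching model's v_theta:

import torch
from dhb_xr.generative.sampling import ode_solve

# Illustrative velocity field: pulls every latent toward zero.
# In practice this would be the trained velocity network v_theta(z, t).
def velocity_fn(z: torch.Tensor, t: float) -> torch.Tensor:
    return -z

z0 = torch.randn(4, 50, 6)  # (B, T, D) noise sample
z1 = ode_solve(velocity_fn, z0, num_steps=20, method="rk4")
print(z1.shape)  # torch.Size([4, 50, 6])
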
@@ -0,0 +1,475 @@
+ """
+ Training utilities for VFM token generators.
+
+ Provides:
+ - InvariantDataset: PyTorch dataset for DHB invariant sequences
+ - train_vfm_tokenizer: Training loop for joint tokenizer + flow matching
+ - Evaluation metrics and logging utilities
+ """
+
+ from __future__ import annotations
+
+ import os
+ import time
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ from torch import Tensor
+ from typing import Optional, Dict, List, Callable, Any, Union
+ import numpy as np
+
+ try:
+     from tqdm import tqdm
+     HAS_TQDM = True
+ except ImportError:
+     HAS_TQDM = False
+     tqdm = lambda x, **kwargs: x
+
+
+ class InvariantDataset(Dataset):
+     """
+     PyTorch Dataset for DHB invariant sequences.
+
+     Supports loading from:
+     - NumPy arrays
+     - Lists of arrays
+     - Directory of .npy files
+     """
+
+     def __init__(
+         self,
+         data: Union[np.ndarray, List[np.ndarray], str],
+         seq_len: Optional[int] = None,
+         normalize: bool = True,
+         augment: bool = False,
+     ):
+         """
+         Args:
+             data: Invariant sequences as:
+                 - np.ndarray of shape (N, T, D)
+                 - List of (T_i, D) arrays (variable length)
+                 - Path to directory with .npy files
+             seq_len: Fixed sequence length. If provided, sequences are
+                 padded/truncated to this length.
+             normalize: Whether to normalize invariants.
+             augment: Whether to apply data augmentation.
+         """
+         super().__init__()
+         self.seq_len = seq_len
+         self.normalize = normalize
+         self.augment = augment
+
+         # Load data
+         if isinstance(data, str):
+             self.sequences = self._load_from_dir(data)
+         elif isinstance(data, np.ndarray):
+             self.sequences = [data[i] for i in range(len(data))]
+         else:
+             self.sequences = list(data)
+
+         # Compute normalization stats
+         if normalize:
+             all_data = np.concatenate(self.sequences, axis=0)
+             self.mean = all_data.mean(axis=0)
+             self.std = all_data.std(axis=0) + 1e-8
+         else:
+             self.mean = None
+             self.std = None
+
+     def _load_from_dir(self, path: str) -> List[np.ndarray]:
+         """Load sequences from directory of .npy files."""
+         sequences = []
+         for fname in sorted(os.listdir(path)):
+             if fname.endswith(".npy"):
+                 arr = np.load(os.path.join(path, fname))
+                 sequences.append(arr)
+         return sequences
+
+     def __len__(self) -> int:
+         return len(self.sequences)
+
+     def __getitem__(self, idx: int) -> Tensor:
+         seq = self.sequences[idx].copy()
+
+         # Normalize
+         if self.normalize:
+             seq = (seq - self.mean) / self.std
+
+         # Pad/truncate to fixed length
+         if self.seq_len is not None:
+             if len(seq) < self.seq_len:
+                 # Pad with zeros
+                 pad = np.zeros((self.seq_len - len(seq), seq.shape[-1]))
+                 seq = np.concatenate([seq, pad], axis=0)
+             elif len(seq) > self.seq_len:
+                 # Random crop
+                 if self.augment:
+                     start = np.random.randint(0, len(seq) - self.seq_len + 1)
+                 else:
+                     start = 0
+                 seq = seq[start:start + self.seq_len]
+
+         # Augmentation
+         if self.augment:
+             seq = self._augment(seq)
+
+         return torch.from_numpy(seq).float()
+
+     def _augment(self, seq: np.ndarray) -> np.ndarray:
+         """Apply data augmentation."""
+         # Random noise
+         if np.random.random() < 0.3:
+             noise = np.random.randn(*seq.shape) * 0.01
+             seq = seq + noise
+
+         # Random scale
+         if np.random.random() < 0.3:
+             scale = np.random.uniform(0.9, 1.1)
+             seq = seq * scale
+
+         return seq
+
+     def denormalize(self, seq: Tensor) -> Tensor:
+         """Convert normalized sequence back to original scale."""
+         if not self.normalize:
+             return seq
+         mean = torch.from_numpy(self.mean).to(seq.device).float()
+         std = torch.from_numpy(self.std).to(seq.device).float()
+         return seq * std + mean
+
+
+ def train_vfm_tokenizer(
+     model: nn.Module,
+     train_dataset: Dataset,
+     val_dataset: Optional[Dataset] = None,
+     num_epochs: int = 100,
+     batch_size: int = 32,
+     learning_rate: float = 1e-4,
+     tokenizer_weight: float = 1.0,
+     flow_weight: float = 1.0,
+     beta: float = 0.01,
+     beta_schedule: Optional[Callable[[int], float]] = None,
+     grad_clip: float = 1.0,
+     log_every: int = 100,
+     eval_every: int = 1000,
+     save_every: int = 5000,
+     save_dir: Optional[str] = None,
+     device: str = "cpu",
+     callback: Optional[Callable[[Dict], None]] = None,
+ ) -> Dict[str, List[float]]:
+     """
+     Train VFMTokenGenerator with joint tokenizer + flow matching loss.
+
+     Args:
+         model: VFMTokenGenerator or similar model with loss() method.
+         train_dataset: Training dataset of invariant sequences.
+         val_dataset: Optional validation dataset.
+         num_epochs: Number of training epochs.
+         batch_size: Batch size.
+         learning_rate: Learning rate.
+         tokenizer_weight: Weight for tokenizer reconstruction loss.
+         flow_weight: Weight for flow matching loss.
+         beta: KL divergence weight (for variational models).
+         beta_schedule: Optional function epoch -> beta for KL annealing.
+         grad_clip: Gradient clipping norm.
+         log_every: Log frequency (steps).
+         eval_every: Evaluation frequency (steps).
+         save_every: Checkpoint save frequency (steps).
+         save_dir: Directory for saving checkpoints.
+         device: Device for training.
+         callback: Optional callback function called with metrics dict.
+
+     Returns:
+         Dictionary of training history.
+     """
+     model = model.to(device)
+     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+         optimizer, T_max=num_epochs
+     )
+
+     train_loader = DataLoader(
+         train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
+     )
+
+     if val_dataset is not None:
+         val_loader = DataLoader(
+             val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
+         )
+
+     # Training history
+     history = {
+         "train_loss": [],
+         "train_tokenizer_loss": [],
+         "train_flow_loss": [],
+         "train_kl_loss": [],
+         "val_loss": [],
+         "learning_rate": [],
+     }
+
+     # Training loop
+     global_step = 0
+     best_val_loss = float("inf")
+
+     for epoch in range(num_epochs):
+         model.train()
+         epoch_losses = {"total": 0, "tokenizer": 0, "flow": 0, "kl": 0}
+         num_batches = 0
+
+         # Get current beta
+         current_beta = beta_schedule(epoch) if beta_schedule else beta
+
+         pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
+         for batch in pbar:
+             batch = batch.to(device)
+
+             # Forward and loss
+             optimizer.zero_grad()
+             losses = model.loss(
+                 batch,
+                 tokenizer_weight=tokenizer_weight,
+                 flow_weight=flow_weight,
+                 beta=current_beta,
+             )
+
+             # Backward
+             losses["total"].backward()
+             if grad_clip > 0:
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+             optimizer.step()
+
+             # Accumulate losses
+             epoch_losses["total"] += losses["total"].item()
+             epoch_losses["tokenizer"] += losses["tokenizer"].item()
+             epoch_losses["flow"] += losses["flow"].item()
+             epoch_losses["kl"] += losses["kl"].item()
+             num_batches += 1
+             global_step += 1
+
+             # Update progress bar (only when tqdm is available; the fallback
+             # returned by the import guard is a plain iterable)
+             if HAS_TQDM:
+                 pbar.set_postfix({
+                     "loss": f"{losses['total'].item():.4f}",
+                     "tok": f"{losses['tokenizer'].item():.4f}",
+                     "flow": f"{losses['flow'].item():.4f}",
+                 })
+
+             # Logging
+             if global_step % log_every == 0:
+                 avg_loss = epoch_losses["total"] / num_batches
+                 history["train_loss"].append(avg_loss)
+                 history["learning_rate"].append(optimizer.param_groups[0]["lr"])
+
+                 if callback:
+                     callback({
+                         "step": global_step,
+                         "epoch": epoch,
+                         "train_loss": avg_loss,
+                         "lr": optimizer.param_groups[0]["lr"],
+                     })
+
+             # Evaluation
+             if val_dataset is not None and global_step % eval_every == 0:
+                 val_loss = evaluate_model(model, val_loader, device, current_beta)
+                 history["val_loss"].append(val_loss)
+
+                 if val_loss < best_val_loss:
+                     best_val_loss = val_loss
+                     if save_dir:
+                         save_checkpoint(
+                             model, optimizer, epoch, global_step,
+                             os.path.join(save_dir, "best_model.pt")
+                         )
+
+             # Checkpointing
+             if save_dir and global_step % save_every == 0:
+                 save_checkpoint(
+                     model, optimizer, epoch, global_step,
+                     os.path.join(save_dir, f"checkpoint_{global_step}.pt")
+                 )
+
+         # End of epoch
+         scheduler.step()
+
+         # Record epoch averages
+         for key in epoch_losses:
+             epoch_losses[key] /= max(num_batches, 1)
+
+         history["train_tokenizer_loss"].append(epoch_losses["tokenizer"])
+         history["train_flow_loss"].append(epoch_losses["flow"])
+         history["train_kl_loss"].append(epoch_losses["kl"])
+
+         print(f"Epoch {epoch+1}: loss={epoch_losses['total']:.4f}, "
+               f"tok={epoch_losses['tokenizer']:.4f}, "
+               f"flow={epoch_losses['flow']:.4f}, "
+               f"kl={epoch_losses['kl']:.4f}")
+
+     return history
+
+
+ def evaluate_model(
+     model: nn.Module,
+     data_loader: DataLoader,
+     device: str,
+     beta: float = 0.01,
+ ) -> float:
+     """Evaluate model on validation data."""
+     model.eval()
+     total_loss = 0
+     num_batches = 0
+
+     with torch.no_grad():
+         for batch in data_loader:
+             batch = batch.to(device)
+             losses = model.loss(batch, beta=beta)
+             total_loss += losses["total"].item()
+             num_batches += 1
+
+     model.train()
+     return total_loss / max(num_batches, 1)
+
+
+ def save_checkpoint(
+     model: nn.Module,
+     optimizer: torch.optim.Optimizer,
+     epoch: int,
+     step: int,
+     path: str,
+ ):
+     """Save training checkpoint."""
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     torch.save({
+         "model_state_dict": model.state_dict(),
+         "optimizer_state_dict": optimizer.state_dict(),
+         "epoch": epoch,
+         "step": step,
+     }, path)
+
+
+ def load_checkpoint(
+     model: nn.Module,
+     optimizer: Optional[torch.optim.Optimizer],
+     path: str,
+ ) -> Dict:
+     """Load training checkpoint."""
+     checkpoint = torch.load(path)
+     model.load_state_dict(checkpoint["model_state_dict"])
+     if optimizer is not None:
+         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+     return {
+         "epoch": checkpoint["epoch"],
+         "step": checkpoint["step"],
+     }
+
+
+ # ---- Evaluation metrics ----
+
+ def compute_reconstruction_error(
+     model: nn.Module,
+     data_loader: DataLoader,
+     device: str,
+ ) -> Dict[str, float]:
+     """
+     Compute reconstruction error metrics.
+
+     Returns MSE, MAE, and max error.
+     """
+     model.eval()
+     all_mse = []
+     all_mae = []
+     all_max = []
+
+     with torch.no_grad():
+         for batch in data_loader:
+             batch = batch.to(device)
+             outputs = model(batch)
+             recon = outputs["invariants_reconstructed"]
+
+             mse = ((batch - recon) ** 2).mean(dim=(1, 2))
+             mae = (batch - recon).abs().mean(dim=(1, 2))
+             max_err = (batch - recon).abs().max(dim=-1)[0].max(dim=-1)[0]
+
+             all_mse.extend(mse.cpu().numpy())
+             all_mae.extend(mae.cpu().numpy())
+             all_max.extend(max_err.cpu().numpy())
+
+     model.train()
+     return {
+         "mse": np.mean(all_mse),
+         "mae": np.mean(all_mae),
+         "max_error": np.mean(all_max),
+     }
+
+
+ def compute_generation_diversity(
+     model: nn.Module,
+     num_samples: int = 100,
+     seq_len: int = 50,
+     device: str = "cpu",
+ ) -> Dict[str, float]:
+     """
+     Compute diversity metrics for generated samples.
+
+     Returns pairwise distance statistics.
+     """
+     model.eval()
+
+     with torch.no_grad():
+         # Generate samples
+         if hasattr(model, "generate_multimodal"):
+             samples = model.generate_multimodal(
+                 num_samples=num_samples // 4,
+                 seq_len=seq_len,
+                 num_modes=4,
+                 device=device,
+             )
+         else:
+             samples = model.generate(
+                 num_samples=num_samples,
+                 seq_len=seq_len,
+                 device=device,
+             )
+
+         # Compute pairwise distances
+         samples_flat = samples.reshape(samples.shape[0], -1)
+         dists = torch.cdist(samples_flat, samples_flat)
+
+         # Get upper triangle (excluding diagonal)
+         mask = torch.triu(torch.ones_like(dists), diagonal=1).bool()
+         pairwise_dists = dists[mask]
+
+     model.train()
+     return {
+         "mean_pairwise_dist": pairwise_dists.mean().item(),
+         "std_pairwise_dist": pairwise_dists.std().item(),
+         "min_pairwise_dist": pairwise_dists.min().item(),
+         "max_pairwise_dist": pairwise_dists.max().item(),
+     }
+
+
+ # ---- KL annealing schedules ----
+
+ def linear_kl_schedule(
+     warmup_epochs: int,
+     max_beta: float = 1.0,
+ ) -> Callable[[int], float]:
+     """Linear KL annealing from 0 to max_beta over warmup_epochs."""
+     def schedule(epoch: int) -> float:
+         if epoch < warmup_epochs:
+             return max_beta * epoch / warmup_epochs
+         return max_beta
+     return schedule
+
+
+ def cyclical_kl_schedule(
+     cycle_length: int = 10,
+     num_cycles: int = 4,
+     max_beta: float = 1.0,
+ ) -> Callable[[int], float]:
+     """Cyclical KL annealing."""
+     def schedule(epoch: int) -> float:
+         cycle = epoch // cycle_length
+         if cycle >= num_cycles:
+             return max_beta
+         progress = (epoch % cycle_length) / cycle_length
+         return max_beta * progress
+     return schedule
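
As a rough orientation (not part of the packaged file), a minimal end-to-end sketch of the training utilities above. The DummyModel below is a hypothetical stand-in that only illustrates the loss() contract train_vfm_tokenizer expects; in practice the package's own VFMTokenGenerator (dhb_xr/generative/vfm_tokenizer.py) would be used, but its constructor is not shown in this diff.

import numpy as np
import torch
import torch.nn as nn
from dhb_xr.generative.training import (
    InvariantDataset,
    train_vfm_tokenizer,
    linear_kl_schedule,
)

class DummyModel(nn.Module):
    """Hypothetical stand-in: any module exposing loss() -> dict of loss tensors."""
    def __init__(self, dim: int = 6):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def loss(self, batch, tokenizer_weight=1.0, flow_weight=1.0, beta=0.01):
        recon = self.net(batch)
        tok = ((recon - batch) ** 2).mean()
        flow = torch.zeros((), device=batch.device)  # placeholder flow-matching term
        kl = torch.zeros((), device=batch.device)    # placeholder KL term
        return {
            "total": tokenizer_weight * tok + flow_weight * flow + beta * kl,
            "tokenizer": tok,
            "flow": flow,
            "kl": kl,
        }

# Synthetic stand-in data: 200 invariant sequences, 60 timesteps, 6 channels.
data = np.random.randn(200, 60, 6).astype(np.float32)
dataset = InvariantDataset(data, seq_len=50, normalize=True, augment=True)

history = train_vfm_tokenizer(
    DummyModel(dim=6),
    train_dataset=dataset,
    num_epochs=2,
    batch_size=16,
    beta_schedule=linear_kl_schedule(warmup_epochs=1, max_beta=0.01),
    device="cpu",
)
print(f"final tokenizer loss: {history['train_tokenizer_loss'][-1]:.4f}")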