dhb-xr 0.2.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0
dhb_xr/generative/sampling.py
@@ -0,0 +1,203 @@
+"""
+ODE solvers for flow matching sampling.
+
+Provides Euler and RK4 integrators for solving the flow ODE:
+    dz/dt = v_theta(z_t, t)
+
+Starting from z_0 (noise) and integrating to z_1 (data).
+"""
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor
+from typing import Callable, Optional, Tuple, List
+
+
+def euler_solve(
+    velocity_fn: Callable[[Tensor, float], Tensor],
+    z_init: Tensor,
+    t_start: float = 0.0,
+    t_end: float = 1.0,
+    num_steps: int = 10,
+    return_trajectory: bool = False,
+) -> Tensor | Tuple[Tensor, List[Tensor]]:
+    """
+    Euler method for ODE integration.
+
+    Solves dz/dt = v(z, t) from t_start to t_end.
+
+    Args:
+        velocity_fn: Function (z, t) -> velocity. Takes tensor z and scalar t.
+        z_init: Initial state (B, T, D) or (B, D).
+        t_start: Starting time (typically 0 for noise).
+        t_end: Ending time (typically 1 for data).
+        num_steps: Number of Euler steps.
+        return_trajectory: If True, return list of intermediate states.
+
+    Returns:
+        Final state z_end, optionally with trajectory list.
+    """
+    dt = (t_end - t_start) / num_steps
+    z = z_init.clone()
+    trajectory = [z.clone()] if return_trajectory else None
+
+    for i in range(num_steps):
+        t = t_start + i * dt
+        v = velocity_fn(z, t)
+        z = z + dt * v
+        if return_trajectory:
+            trajectory.append(z.clone())
+
+    if return_trajectory:
+        return z, trajectory
+    return z
+
+
+def rk4_solve(
+    velocity_fn: Callable[[Tensor, float], Tensor],
+    z_init: Tensor,
+    t_start: float = 0.0,
+    t_end: float = 1.0,
+    num_steps: int = 10,
+    return_trajectory: bool = False,
+) -> Tensor | Tuple[Tensor, List[Tensor]]:
+    """
+    4th-order Runge-Kutta method for ODE integration.
+
+    More accurate than Euler for the same step count, but 4x more expensive.
+
+    Args:
+        velocity_fn: Function (z, t) -> velocity.
+        z_init: Initial state (B, T, D) or (B, D).
+        t_start: Starting time.
+        t_end: Ending time.
+        num_steps: Number of RK4 steps.
+        return_trajectory: If True, return list of intermediate states.
+
+    Returns:
+        Final state z_end, optionally with trajectory list.
+    """
+    dt = (t_end - t_start) / num_steps
+    z = z_init.clone()
+    trajectory = [z.clone()] if return_trajectory else None
+
+    for i in range(num_steps):
+        t = t_start + i * dt
+
+        k1 = velocity_fn(z, t)
+        k2 = velocity_fn(z + 0.5 * dt * k1, t + 0.5 * dt)
+        k3 = velocity_fn(z + 0.5 * dt * k2, t + 0.5 * dt)
+        k4 = velocity_fn(z + dt * k3, t + dt)
+
+        z = z + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
+
+        if return_trajectory:
+            trajectory.append(z.clone())
+
+    if return_trajectory:
+        return z, trajectory
+    return z
+
+
+def ode_solve(
+    velocity_fn: Callable[[Tensor, float], Tensor],
+    z_init: Tensor,
+    t_start: float = 0.0,
+    t_end: float = 1.0,
+    num_steps: int = 10,
+    method: str = "euler",
+    return_trajectory: bool = False,
+) -> Tensor | Tuple[Tensor, List[Tensor]]:
+    """
+    Unified ODE solver interface.
+
+    Args:
+        velocity_fn: Function (z, t) -> velocity.
+        z_init: Initial state.
+        t_start: Starting time.
+        t_end: Ending time.
+        num_steps: Number of integration steps.
+        method: 'euler' or 'rk4'.
+        return_trajectory: If True, return intermediate states.
+
+    Returns:
+        Final state, optionally with trajectory.
+    """
+    if method == "euler":
+        return euler_solve(
+            velocity_fn, z_init, t_start, t_end, num_steps, return_trajectory
+        )
+    elif method == "rk4":
+        return rk4_solve(
+            velocity_fn, z_init, t_start, t_end, num_steps, return_trajectory
+        )
+    else:
+        raise ValueError(f"Unknown ODE method: {method}. Use 'euler' or 'rk4'.")
+
+
+def adaptive_euler_solve(
+    velocity_fn: Callable[[Tensor, float], Tensor],
+    z_init: Tensor,
+    t_start: float = 0.0,
+    t_end: float = 1.0,
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+    max_steps: int = 1000,
+    min_dt: float = 1e-6,
+) -> Tuple[Tensor, int]:
+    """
+    Adaptive step-size Euler method with error estimation.
+
+    Uses step doubling to estimate error and adjust step size.
+    Useful when velocity field has varying stiffness.
+
+    Args:
+        velocity_fn: Function (z, t) -> velocity.
+        z_init: Initial state.
+        t_start: Starting time.
+        t_end: Ending time.
+        atol: Absolute tolerance.
+        rtol: Relative tolerance.
+        max_steps: Maximum number of steps.
+        min_dt: Minimum step size.
+
+    Returns:
+        Final state and number of steps taken.
+    """
+    z = z_init.clone()
+    t = t_start
+    dt = (t_end - t_start) / 10  # Initial step size
+    steps = 0
+
+    while t < t_end and steps < max_steps:
+        # Ensure we don't overshoot
+        dt = min(dt, t_end - t)
+
+        # Full step
+        v = velocity_fn(z, t)
+        z_full = z + dt * v
+
+        # Two half steps
+        z_half = z + 0.5 * dt * v
+        v_half = velocity_fn(z_half, t + 0.5 * dt)
+        z_double = z_half + 0.5 * dt * v_half
+
+        # Error estimate
+        error = (z_double - z_full).abs().max()
+        tol = atol + rtol * z.abs().max()
+
+        if error < tol:
+            # Accept step
+            z = z_double  # Use the more accurate estimate
+            t = t + dt
+            steps += 1
+
+            # Increase step size
+            if error < 0.5 * tol:
+                dt = min(dt * 1.5, t_end - t)
+        else:
+            # Reject step, reduce step size
+            dt = max(dt * 0.5, min_dt)
+
+    return z, steps
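
For orientation, a minimal usage sketch of the solvers in the hunk above (which the file list identifies as dhb_xr/generative/sampling.py with +203 lines). The import path, the toy velocity field, and the tensor shapes are illustrative assumptions, not part of the package:

import torch

# Assumed import path, inferred from the file list above.
from dhb_xr.generative.sampling import ode_solve

# Toy linear velocity field v(z, t) = -z; its exact flow is z(t) = z0 * exp(-t).
def velocity_fn(z: torch.Tensor, t: float) -> torch.Tensor:
    return -z

z0 = torch.randn(8, 50, 6)  # (B, T, D) "noise" sample at t = 0

# Fixed-step integration from t = 0 to t = 1 with either integrator.
z1_euler = ode_solve(velocity_fn, z0, num_steps=10, method="euler")
z1_rk4, path = ode_solve(velocity_fn, z0, num_steps=10, method="rk4",
                         return_trajectory=True)

# RK4 should track the analytic endpoint z0 * exp(-1) more closely than Euler.
target = z0 * torch.exp(torch.tensor(-1.0))
print((z1_euler - target).abs().max(), (z1_rk4 - target).abs().max())
print(len(path))  # num_steps + 1 states: the initial state plus one per step
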
dhb_xr/generative/training.py
@@ -0,0 +1,475 @@
+"""
+Training utilities for VFM token generators.
+
+Provides:
+- InvariantDataset: PyTorch dataset for DHB invariant sequences
+- train_vfm_tokenizer: Training loop for joint tokenizer + flow matching
+- Evaluation metrics and logging utilities
+"""
+
+from __future__ import annotations
+
+import os
+import time
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset, DataLoader
+from torch import Tensor
+from typing import Optional, Dict, List, Callable, Any, Union
+import numpy as np
+
+try:
+    from tqdm import tqdm
+    HAS_TQDM = True
+except ImportError:
+    HAS_TQDM = False
+    tqdm = lambda x, **kwargs: x
+
+
+class InvariantDataset(Dataset):
+    """
+    PyTorch Dataset for DHB invariant sequences.
+
+    Supports loading from:
+    - NumPy arrays
+    - Lists of arrays
+    - Directory of .npy files
+    """
+
+    def __init__(
+        self,
+        data: Union[np.ndarray, List[np.ndarray], str],
+        seq_len: Optional[int] = None,
+        normalize: bool = True,
+        augment: bool = False,
+    ):
+        """
+        Args:
+            data: Invariant sequences as:
+                - np.ndarray of shape (N, T, D)
+                - List of (T_i, D) arrays (variable length)
+                - Path to directory with .npy files
+            seq_len: Fixed sequence length. If provided, sequences are
+                padded/truncated to this length.
+            normalize: Whether to normalize invariants.
+            augment: Whether to apply data augmentation.
+        """
+        super().__init__()
+        self.seq_len = seq_len
+        self.normalize = normalize
+        self.augment = augment
+
+        # Load data
+        if isinstance(data, str):
+            self.sequences = self._load_from_dir(data)
+        elif isinstance(data, np.ndarray):
+            self.sequences = [data[i] for i in range(len(data))]
+        else:
+            self.sequences = list(data)
+
+        # Compute normalization stats
+        if normalize:
+            all_data = np.concatenate(self.sequences, axis=0)
+            self.mean = all_data.mean(axis=0)
+            self.std = all_data.std(axis=0) + 1e-8
+        else:
+            self.mean = None
+            self.std = None
+
+    def _load_from_dir(self, path: str) -> List[np.ndarray]:
+        """Load sequences from directory of .npy files."""
+        sequences = []
+        for fname in sorted(os.listdir(path)):
+            if fname.endswith(".npy"):
+                arr = np.load(os.path.join(path, fname))
+                sequences.append(arr)
+        return sequences
+
+    def __len__(self) -> int:
+        return len(self.sequences)
+
+    def __getitem__(self, idx: int) -> Tensor:
+        seq = self.sequences[idx].copy()
+
+        # Normalize
+        if self.normalize:
+            seq = (seq - self.mean) / self.std
+
+        # Pad/truncate to fixed length
+        if self.seq_len is not None:
+            if len(seq) < self.seq_len:
+                # Pad with zeros
+                pad = np.zeros((self.seq_len - len(seq), seq.shape[-1]))
+                seq = np.concatenate([seq, pad], axis=0)
+            elif len(seq) > self.seq_len:
+                # Random crop
+                if self.augment:
+                    start = np.random.randint(0, len(seq) - self.seq_len + 1)
+                else:
+                    start = 0
+                seq = seq[start:start + self.seq_len]
+
+        # Augmentation
+        if self.augment:
+            seq = self._augment(seq)
+
+        return torch.from_numpy(seq).float()
+
+    def _augment(self, seq: np.ndarray) -> np.ndarray:
+        """Apply data augmentation."""
+        # Random noise
+        if np.random.random() < 0.3:
+            noise = np.random.randn(*seq.shape) * 0.01
+            seq = seq + noise
+
+        # Random scale
+        if np.random.random() < 0.3:
+            scale = np.random.uniform(0.9, 1.1)
+            seq = seq * scale
+
+        return seq
+
+    def denormalize(self, seq: Tensor) -> Tensor:
+        """Convert normalized sequence back to original scale."""
+        if not self.normalize:
+            return seq
+        mean = torch.from_numpy(self.mean).to(seq.device).float()
+        std = torch.from_numpy(self.std).to(seq.device).float()
+        return seq * std + mean
+
+
+def train_vfm_tokenizer(
+    model: nn.Module,
+    train_dataset: Dataset,
+    val_dataset: Optional[Dataset] = None,
+    num_epochs: int = 100,
+    batch_size: int = 32,
+    learning_rate: float = 1e-4,
+    tokenizer_weight: float = 1.0,
+    flow_weight: float = 1.0,
+    beta: float = 0.01,
+    beta_schedule: Optional[Callable[[int], float]] = None,
+    grad_clip: float = 1.0,
+    log_every: int = 100,
+    eval_every: int = 1000,
+    save_every: int = 5000,
+    save_dir: Optional[str] = None,
+    device: str = "cpu",
+    callback: Optional[Callable[[Dict], None]] = None,
+) -> Dict[str, List[float]]:
+    """
+    Train VFMTokenGenerator with joint tokenizer + flow matching loss.
+
+    Args:
+        model: VFMTokenGenerator or similar model with loss() method.
+        train_dataset: Training dataset of invariant sequences.
+        val_dataset: Optional validation dataset.
+        num_epochs: Number of training epochs.
+        batch_size: Batch size.
+        learning_rate: Learning rate.
+        tokenizer_weight: Weight for tokenizer reconstruction loss.
+        flow_weight: Weight for flow matching loss.
+        beta: KL divergence weight (for variational models).
+        beta_schedule: Optional function epoch -> beta for KL annealing.
+        grad_clip: Gradient clipping norm.
+        log_every: Log frequency (steps).
+        eval_every: Evaluation frequency (steps).
+        save_every: Checkpoint save frequency (steps).
+        save_dir: Directory for saving checkpoints.
+        device: Device for training.
+        callback: Optional callback function called with metrics dict.
+
+    Returns:
+        Dictionary of training history.
+    """
+    model = model.to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+        optimizer, T_max=num_epochs
+    )
+
+    train_loader = DataLoader(
+        train_dataset, batch_size=batch_size, shuffle=True, num_workers=0
+    )
+
+    if val_dataset is not None:
+        val_loader = DataLoader(
+            val_dataset, batch_size=batch_size, shuffle=False, num_workers=0
+        )
+
+    # Training history
+    history = {
+        "train_loss": [],
+        "train_tokenizer_loss": [],
+        "train_flow_loss": [],
+        "train_kl_loss": [],
+        "val_loss": [],
+        "learning_rate": [],
+    }
+
+    # Training loop
+    global_step = 0
+    best_val_loss = float("inf")
+
+    for epoch in range(num_epochs):
+        model.train()
+        epoch_losses = {"total": 0, "tokenizer": 0, "flow": 0, "kl": 0}
+        num_batches = 0
+
+        # Get current beta
+        current_beta = beta_schedule(epoch) if beta_schedule else beta
+
+        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
+        for batch in pbar:
+            batch = batch.to(device)
+
+            # Forward and loss
+            optimizer.zero_grad()
+            losses = model.loss(
+                batch,
+                tokenizer_weight=tokenizer_weight,
+                flow_weight=flow_weight,
+                beta=current_beta,
+            )
+
+            # Backward
+            losses["total"].backward()
+            if grad_clip > 0:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+            optimizer.step()
+
+            # Accumulate losses
+            epoch_losses["total"] += losses["total"].item()
+            epoch_losses["tokenizer"] += losses["tokenizer"].item()
+            epoch_losses["flow"] += losses["flow"].item()
+            epoch_losses["kl"] += losses["kl"].item()
+            num_batches += 1
+            global_step += 1
+
+            # Update progress bar
+            pbar.set_postfix({
+                "loss": f"{losses['total'].item():.4f}",
+                "tok": f"{losses['tokenizer'].item():.4f}",
+                "flow": f"{losses['flow'].item():.4f}",
+            })
+
+            # Logging
+            if global_step % log_every == 0:
+                avg_loss = epoch_losses["total"] / num_batches
+                history["train_loss"].append(avg_loss)
+                history["learning_rate"].append(optimizer.param_groups[0]["lr"])
+
+                if callback:
+                    callback({
+                        "step": global_step,
+                        "epoch": epoch,
+                        "train_loss": avg_loss,
+                        "lr": optimizer.param_groups[0]["lr"],
+                    })
+
+            # Evaluation
+            if val_dataset is not None and global_step % eval_every == 0:
+                val_loss = evaluate_model(model, val_loader, device, current_beta)
+                history["val_loss"].append(val_loss)
+
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss
+                    if save_dir:
+                        save_checkpoint(
+                            model, optimizer, epoch, global_step,
+                            os.path.join(save_dir, "best_model.pt")
+                        )
+
+            # Checkpointing
+            if save_dir and global_step % save_every == 0:
+                save_checkpoint(
+                    model, optimizer, epoch, global_step,
+                    os.path.join(save_dir, f"checkpoint_{global_step}.pt")
+                )
+
+        # End of epoch
+        scheduler.step()
+
+        # Record epoch averages
+        for key in epoch_losses:
+            epoch_losses[key] /= max(num_batches, 1)
+
+        history["train_tokenizer_loss"].append(epoch_losses["tokenizer"])
+        history["train_flow_loss"].append(epoch_losses["flow"])
+        history["train_kl_loss"].append(epoch_losses["kl"])
+
+        print(f"Epoch {epoch+1}: loss={epoch_losses['total']:.4f}, "
+              f"tok={epoch_losses['tokenizer']:.4f}, "
+              f"flow={epoch_losses['flow']:.4f}, "
+              f"kl={epoch_losses['kl']:.4f}")
+
+    return history
+
+
+def evaluate_model(
+    model: nn.Module,
+    data_loader: DataLoader,
+    device: str,
+    beta: float = 0.01,
+) -> float:
+    """Evaluate model on validation data."""
+    model.eval()
+    total_loss = 0
+    num_batches = 0
+
+    with torch.no_grad():
+        for batch in data_loader:
+            batch = batch.to(device)
+            losses = model.loss(batch, beta=beta)
+            total_loss += losses["total"].item()
+            num_batches += 1
+
+    model.train()
+    return total_loss / max(num_batches, 1)
+
+
+def save_checkpoint(
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    epoch: int,
+    step: int,
+    path: str,
+):
+    """Save training checkpoint."""
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    torch.save({
+        "model_state_dict": model.state_dict(),
+        "optimizer_state_dict": optimizer.state_dict(),
+        "epoch": epoch,
+        "step": step,
+    }, path)
+
+
+def load_checkpoint(
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer],
+    path: str,
+) -> Dict:
+    """Load training checkpoint."""
+    checkpoint = torch.load(path)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    if optimizer is not None:
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    return {
+        "epoch": checkpoint["epoch"],
+        "step": checkpoint["step"],
+    }
+
+
+# ---- Evaluation metrics ----
+
+def compute_reconstruction_error(
+    model: nn.Module,
+    data_loader: DataLoader,
+    device: str,
+) -> Dict[str, float]:
+    """
+    Compute reconstruction error metrics.
+
+    Returns MSE, MAE, and max error.
+    """
+    model.eval()
+    all_mse = []
+    all_mae = []
+    all_max = []
+
+    with torch.no_grad():
+        for batch in data_loader:
+            batch = batch.to(device)
+            outputs = model(batch)
+            recon = outputs["invariants_reconstructed"]
+
+            mse = ((batch - recon) ** 2).mean(dim=(1, 2))
+            mae = (batch - recon).abs().mean(dim=(1, 2))
+            max_err = (batch - recon).abs().max(dim=-1)[0].max(dim=-1)[0]
+
+            all_mse.extend(mse.cpu().numpy())
+            all_mae.extend(mae.cpu().numpy())
+            all_max.extend(max_err.cpu().numpy())
+
+    model.train()
+    return {
+        "mse": np.mean(all_mse),
+        "mae": np.mean(all_mae),
+        "max_error": np.mean(all_max),
+    }
+
+
+def compute_generation_diversity(
+    model: nn.Module,
+    num_samples: int = 100,
+    seq_len: int = 50,
+    device: str = "cpu",
+) -> Dict[str, float]:
+    """
+    Compute diversity metrics for generated samples.
+
+    Returns pairwise distance statistics.
+    """
+    model.eval()
+
+    with torch.no_grad():
+        # Generate samples
+        if hasattr(model, "generate_multimodal"):
+            samples = model.generate_multimodal(
+                num_samples=num_samples // 4,
+                seq_len=seq_len,
+                num_modes=4,
+                device=device,
+            )
+        else:
+            samples = model.generate(
+                num_samples=num_samples,
+                seq_len=seq_len,
+                device=device,
+            )
+
+        # Compute pairwise distances
+        samples_flat = samples.reshape(samples.shape[0], -1)
+        dists = torch.cdist(samples_flat, samples_flat)
+
+        # Get upper triangle (excluding diagonal)
+        mask = torch.triu(torch.ones_like(dists), diagonal=1).bool()
+        pairwise_dists = dists[mask]
+
+    model.train()
+    return {
+        "mean_pairwise_dist": pairwise_dists.mean().item(),
+        "std_pairwise_dist": pairwise_dists.std().item(),
+        "min_pairwise_dist": pairwise_dists.min().item(),
+        "max_pairwise_dist": pairwise_dists.max().item(),
+    }
+
+
+# ---- KL annealing schedules ----
+
+def linear_kl_schedule(
+    warmup_epochs: int,
+    max_beta: float = 1.0,
+) -> Callable[[int], float]:
+    """Linear KL annealing from 0 to max_beta over warmup_epochs."""
+    def schedule(epoch: int) -> float:
+        if epoch < warmup_epochs:
+            return max_beta * epoch / warmup_epochs
+        return max_beta
+    return schedule


+def cyclical_kl_schedule(
+    cycle_length: int = 10,
+    num_cycles: int = 4,
+    max_beta: float = 1.0,
+) -> Callable[[int], float]:
+    """Cyclical KL annealing."""
+    def schedule(epoch: int) -> float:
+        cycle = epoch // cycle_length
+        if cycle >= num_cycles:
+            return max_beta
+        progress = (epoch % cycle_length) / cycle_length
+        return max_beta * progress
+    return schedule