dhb-xr 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dhb_xr/__init__.py +61 -0
- dhb_xr/cli.py +206 -0
- dhb_xr/core/__init__.py +28 -0
- dhb_xr/core/geometry.py +167 -0
- dhb_xr/core/geometry_torch.py +77 -0
- dhb_xr/core/types.py +113 -0
- dhb_xr/database/__init__.py +10 -0
- dhb_xr/database/motion_db.py +79 -0
- dhb_xr/database/retrieval.py +6 -0
- dhb_xr/database/similarity.py +71 -0
- dhb_xr/decoder/__init__.py +13 -0
- dhb_xr/decoder/decoder_torch.py +52 -0
- dhb_xr/decoder/dhb_dr.py +261 -0
- dhb_xr/decoder/dhb_qr.py +89 -0
- dhb_xr/encoder/__init__.py +27 -0
- dhb_xr/encoder/dhb_dr.py +418 -0
- dhb_xr/encoder/dhb_qr.py +129 -0
- dhb_xr/encoder/dhb_ti.py +204 -0
- dhb_xr/encoder/encoder_torch.py +54 -0
- dhb_xr/encoder/padding.py +82 -0
- dhb_xr/generative/__init__.py +78 -0
- dhb_xr/generative/flow_matching.py +705 -0
- dhb_xr/generative/latent_encoder.py +536 -0
- dhb_xr/generative/sampling.py +203 -0
- dhb_xr/generative/training.py +475 -0
- dhb_xr/generative/vfm_tokenizer.py +485 -0
- dhb_xr/integration/__init__.py +13 -0
- dhb_xr/integration/vla/__init__.py +11 -0
- dhb_xr/integration/vla/libero.py +132 -0
- dhb_xr/integration/vla/pipeline.py +85 -0
- dhb_xr/integration/vla/robocasa.py +85 -0
- dhb_xr/losses/__init__.py +16 -0
- dhb_xr/losses/geodesic_loss.py +91 -0
- dhb_xr/losses/hybrid_loss.py +36 -0
- dhb_xr/losses/invariant_loss.py +73 -0
- dhb_xr/optimization/__init__.py +72 -0
- dhb_xr/optimization/casadi_solver.py +342 -0
- dhb_xr/optimization/constraints.py +32 -0
- dhb_xr/optimization/cusadi_solver.py +311 -0
- dhb_xr/optimization/export_casadi_decode.py +111 -0
- dhb_xr/optimization/fatrop_solver.py +477 -0
- dhb_xr/optimization/torch_solver.py +85 -0
- dhb_xr/preprocessing/__init__.py +42 -0
- dhb_xr/preprocessing/diagnostics.py +330 -0
- dhb_xr/preprocessing/trajectory_cleaner.py +485 -0
- dhb_xr/tokenization/__init__.py +56 -0
- dhb_xr/tokenization/causal_encoder.py +54 -0
- dhb_xr/tokenization/compression.py +749 -0
- dhb_xr/tokenization/hierarchical.py +359 -0
- dhb_xr/tokenization/rvq.py +178 -0
- dhb_xr/tokenization/vqvae.py +155 -0
- dhb_xr/utils/__init__.py +24 -0
- dhb_xr/utils/io.py +59 -0
- dhb_xr/utils/resampling.py +66 -0
- dhb_xr/utils/xdof_loader.py +89 -0
- dhb_xr/visualization/__init__.py +5 -0
- dhb_xr/visualization/plot.py +242 -0
- dhb_xr-0.2.1.dist-info/METADATA +784 -0
- dhb_xr-0.2.1.dist-info/RECORD +82 -0
- dhb_xr-0.2.1.dist-info/WHEEL +5 -0
- dhb_xr-0.2.1.dist-info/entry_points.txt +2 -0
- dhb_xr-0.2.1.dist-info/top_level.txt +3 -0
- examples/__init__.py +54 -0
- examples/basic_encoding.py +82 -0
- examples/benchmark_backends.py +37 -0
- examples/dhb_qr_comparison.py +79 -0
- examples/dhb_ti_time_invariant.py +72 -0
- examples/gpu_batch_optimization.py +102 -0
- examples/imitation_learning.py +53 -0
- examples/integration/__init__.py +19 -0
- examples/integration/libero_full_demo.py +692 -0
- examples/integration/libero_pro_dhb_demo.py +1063 -0
- examples/integration/libero_simulation_demo.py +286 -0
- examples/integration/libero_swap_demo.py +534 -0
- examples/integration/robocasa_libero_dhb_pipeline.py +56 -0
- examples/integration/test_libero_adapter.py +47 -0
- examples/integration/test_libero_encoding.py +75 -0
- examples/integration/test_libero_retrieval.py +105 -0
- examples/motion_database.py +88 -0
- examples/trajectory_adaptation.py +85 -0
- examples/vla_tokenization.py +107 -0
- notebooks/__init__.py +24 -0
dhb_xr/encoder/dhb_ti.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DHB-TI: Time-invariant reparameterization.
|
|
3
|
+
|
|
4
|
+
Reparameterize a trajectory by a geometric progress variable (translational arc-length,
|
|
5
|
+
angular, or hybrid) and resample at uniform progress knots so that DHB-DR/DHB-QR
|
|
6
|
+
invariants are approximately independent of execution speed and sampling rate.
|
|
7
|
+
|
|
8
|
+
Progress variables:
|
|
9
|
+
- translation: s_{i+1} = s_i + ||Δp_i||
|
|
10
|
+
- angular: θ_{i+1} = θ_i + ||Δr_i||
|
|
11
|
+
- hybrid: σ_{i+1} = σ_i + α||Δp_i|| + (1-α)||Δr_i||, α in [0,1]
|
|
12
|
+
|
|
13
|
+
Uniform knots σ_k = k * Σ/(M-1), then interpolate poses at σ_k (position spline, quat SLERP).
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations

from typing import Any, Dict, Literal, Optional, Union

import numpy as np
from scipy.interpolate import CubicSpline

from dhb_xr.core import geometry as geom
from dhb_xr.core.types import DHBMethod, EncodingMethod
|
|
24
|
+
|
|
25
|
+
# Accepted values for the ``kind`` / ``progress_kind`` arguments below.
ProgressKind = Literal["translation", "angular", "hybrid"]

# Comparison tolerance on progress values (knot at an endpoint, zero-length segment).
_EPS = 1.0e-12
# Floor for each per-segment progress increment, so progress keeps increasing and
# the spline interpolation over it is not degenerate.
_MIN_STEP = 1.0e-10
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def compute_progress(
    positions: np.ndarray,
    quaternions: np.ndarray,
    kind: ProgressKind = "hybrid",
    alpha: float = 0.5,
    min_step: float = _MIN_STEP,
) -> np.ndarray:
    """
    Compute cumulative progress along the trajectory.

    positions: (N, 3), quaternions: (N, 4) wxyz.
    kind: 'translation' (arc-length), 'angular' (rotation magnitude), 'hybrid'.
    alpha: weight for translation in hybrid; (1-alpha) for rotation. Ignored if kind != 'hybrid'.
    min_step: lower bound applied to every per-segment translational and angular
        increment (before the hybrid mix), so progress strictly increases when
        min_step > 0.

    Returns progress (N,) with progress[0] = 0.

    Raises:
        ValueError: if ``kind`` is not 'translation', 'angular' or 'hybrid', or if
            positions and quaternions have different lengths.
    """
    # Reject unknown kinds instead of silently treating them as 'hybrid'.
    if kind not in ("translation", "angular", "hybrid"):
        raise ValueError(
            f"unknown progress kind {kind!r}; expected 'translation', 'angular' or 'hybrid'"
        )

    positions = np.asarray(positions, dtype=np.float64)
    quaternions = np.asarray(quaternions, dtype=np.float64)
    n = positions.shape[0]
    if quaternions.shape[0] != n:
        raise ValueError("positions and quaternions must have the same length")
    if n < 2:
        return np.zeros(n)

    # Only evaluate the terms the chosen kind needs; the angular term is a
    # Python-level loop over geom.quat_relative_axis_angle, one call per segment.
    step_p = step_r = None
    if kind in ("translation", "hybrid"):
        step_p = np.linalg.norm(np.diff(positions, axis=0), axis=1)
        step_p = np.maximum(step_p, min_step)
    if kind in ("angular", "hybrid"):
        delta_r = np.array([
            geom.quat_relative_axis_angle(q_prev, q_next)
            for q_prev, q_next in zip(quaternions[:-1], quaternions[1:])
        ])
        step_r = np.maximum(np.linalg.norm(delta_r, axis=1), min_step)

    if kind == "translation":
        steps = step_p
    elif kind == "angular":
        steps = step_r
    else:
        steps = alpha * step_p + (1.0 - alpha) * step_r

    return np.concatenate([[0.0], np.cumsum(steps)])
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def resample_by_progress(
    positions: np.ndarray,
    quaternions: np.ndarray,
    M: int,
    progress_kind: ProgressKind = "hybrid",
    alpha: float = 0.5,
    progress: Optional[np.ndarray] = None,
    min_step: float = _MIN_STEP,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Resample trajectory to M poses at uniform progress knots.

    progress_knots: σ_k = σ_0 + k * (σ_{N-1} - σ_0)/(M-1), k = 0,...,M-1. For progress
    produced by compute_progress (σ_0 = 0) this is σ_k = k * Σ/(M-1).
    Positions are interpolated with a cubic spline in progress; orientations with SLERP
    between the two input samples bracketing each knot.

    positions: (N, 3), N >= 2; quaternions: (N, 4) wxyz.
    progress_kind, alpha, min_step: forwarded to compute_progress when ``progress`` is None.
    progress: optional precomputed (N,) progress; must be strictly increasing.

    Returns (positions_M, quaternions_M) (M, 3), (M, 4) wxyz.

    Raises:
        ValueError: fewer than 2 poses, mismatched lengths, or progress that is not
            strictly increasing.
    """
    positions = np.asarray(positions, dtype=np.float64)
    quaternions = np.asarray(quaternions, dtype=np.float64)
    n = positions.shape[0]
    if n < 2:
        raise ValueError("resample_by_progress requires at least 2 poses")
    if quaternions.shape[0] != n:
        raise ValueError("quaternions length must match positions")

    if progress is None:
        progress = compute_progress(
            positions, quaternions, kind=progress_kind, alpha=alpha, min_step=min_step
        )
    else:
        progress = np.asarray(progress, dtype=np.float64)
        if progress.shape[0] != n:
            raise ValueError("progress length must match positions/quaternions")

    # CubicSpline requires strictly increasing abscissae; fail early with a clear message.
    if not np.all(np.diff(progress) > 0):
        raise ValueError("progress must be strictly increasing")

    # Span the actual progress range so no knot falls outside the data
    # (previously knots always started at 0, assuming progress[0] == 0).
    sigma_knots = np.linspace(progress[0], progress[-1], M, dtype=np.float64)

    # One spline along axis 0 interpolates all three coordinates at once.
    pos_resample = CubicSpline(progress, positions, axis=0)(sigma_knots)

    quat_resample = np.zeros((M, 4))
    # Index i of the segment [progress[i], progress[i+1]] bracketing each knot.
    seg_idx = np.clip(np.searchsorted(progress, sigma_knots, side="right") - 1, 0, n - 2)
    for k, (s, i) in enumerate(zip(sigma_knots, seg_idx)):
        if s <= progress[0] + _EPS:
            quat_resample[k] = quaternions[0]
            continue
        if s >= progress[-1] - _EPS:
            quat_resample[k] = quaternions[-1]
            continue
        p_lo, p_hi = progress[i], progress[i + 1]
        segment = p_hi - p_lo
        t = 0.0 if segment <= _EPS else float(np.clip((s - p_lo) / segment, 0.0, 1.0))
        quat_resample[k] = geom.quat_slerp(quaternions[i], quaternions[i + 1], t)

    return pos_resample, quat_resample
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def encode_dhb_dr_ti(
    positions: np.ndarray,
    quaternions: np.ndarray,
    M: int,
    progress_kind: ProgressKind = "hybrid",
    alpha: float = 0.5,
    method: Union[str, EncodingMethod] = EncodingMethod.POSITION,
    use_default_initial_frames: bool = True,
    init_pose: Optional[Dict[str, np.ndarray]] = None,
    dhb_method: DHBMethod = DHBMethod.DOUBLE_REFLECTION,
    min_step: float = _MIN_STEP,
    **encode_kw,
) -> Dict[str, Any]:
    """
    Time-invariant encode: reparameterize by progress to M samples, then DHB-DR encode.

    positions: (N, 3); quaternions: (N, 4) wxyz.
    M: number of poses after resampling at uniform progress knots.
    progress_kind, alpha, min_step: control the progress variable (see compute_progress).
    method, use_default_initial_frames, init_pose, dhb_method, **encode_kw:
        forwarded unchanged to encode_dhb_dr.

    Returns same structure as encode_dhb_dr (linear_motion_invariants,
    angular_motion_invariants, initial_pose, ...), computed on the resampled trajectory.

    Raises:
        ValueError: propagated from resample_by_progress (e.g. fewer than 2 poses).
    """
    # Imported lazily, as in the original module layout.
    from dhb_xr.encoder.dhb_dr import encode_dhb_dr

    pos_m, quat_m = resample_by_progress(
        positions, quaternions, M,
        progress_kind=progress_kind, alpha=alpha, min_step=min_step,
    )
    return encode_dhb_dr(
        pos_m, quat_m,
        method=method,
        use_default_initial_frames=use_default_initial_frames,
        init_pose=init_pose,
        dhb_method=dhb_method,
        **encode_kw,
    )
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def encode_dhb_qr_ti(
    positions: np.ndarray,
    quaternions: np.ndarray,
    M: int,
    progress_kind: ProgressKind = "hybrid",
    alpha: float = 0.5,
    method: EncodingMethod = EncodingMethod.POSITION,
    use_default_initial_frames: bool = True,
    init_pose: Optional[Dict[str, np.ndarray]] = None,
    min_step: float = _MIN_STEP,
    **encode_kw,
) -> Dict[str, Any]:
    """
    DHB-QR encoding of a trajectory after time-invariant reparameterization.

    The input poses are first resampled to ``M`` poses at uniform progress knots
    (resample_by_progress with ``progress_kind``, ``alpha`` and ``min_step``). The
    resampled trajectory is then passed to encode_dhb_qr together with ``method``,
    ``use_default_initial_frames``, ``init_pose`` and any extra keyword arguments.

    Returns same structure as encode_dhb_qr.
    """
    from dhb_xr.encoder.dhb_qr import encode_dhb_qr as qr_encoder

    uniform_positions, uniform_quaternions = resample_by_progress(
        positions,
        quaternions,
        M,
        progress_kind=progress_kind,
        alpha=alpha,
        min_step=min_step,
    )
    return qr_encoder(
        uniform_positions,
        uniform_quaternions,
        method=method,
        use_default_initial_frames=use_default_initial_frames,
        init_pose=init_pose,
        **encode_kw,
    )
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Batched DHB encoder in PyTorch: wrapper over numpy encode for GPU-friendly batch API."""
|
|
2
|
+
|
|
3
|
+
try:
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
except ImportError:
|
|
7
|
+
torch = None
|
|
8
|
+
nn = None
|
|
9
|
+
|
|
10
|
+
if torch is not None:

    import numpy as np
    from dhb_xr.encoder.dhb_dr import encode_dhb_dr
    from dhb_xr.core.types import DHBMethod, EncodingMethod

    class DHBEncoderTorch(nn.Module):
        """
        Batched DHB-DR encoder.
        positions (B, N, 3), quaternions (B, N, 4) wxyz -> invariants (B, N-2, 8).
        Uses numpy encode_dhb_dr per batch item and stacks; supports .to(device) for moving data to GPU.

        Each output row concatenates (along the last axis) the "linear_motion_invariants"
        and "angular_motion_invariants" returned by encode_dhb_dr for that item.
        Inputs are detached and processed on CPU, so no gradients flow through this
        module; the result is moved to the input's device and dtype.
        """

        def __init__(self, dhb_method: "str | DHBMethod" = "double_reflection"):
            """
            dhb_method: a DHBMethod member, or a string. The string "double_reflection"
                selects DHBMethod.DOUBLE_REFLECTION; any other string selects
                DHBMethod.ORIGINAL (kept for backward compatibility).
            """
            super().__init__()
            if isinstance(dhb_method, DHBMethod):
                # Use enum members as given, instead of comparing them to a string.
                self.dhb_method = dhb_method
            elif dhb_method == "double_reflection":
                self.dhb_method = DHBMethod.DOUBLE_REFLECTION
            else:
                self.dhb_method = DHBMethod.ORIGINAL
            # NOTE(review): k is not read in this class; presumably consumed elsewhere — confirm.
            self.k = 4 if self.dhb_method == DHBMethod.DOUBLE_REFLECTION else 3

        def forward(
            self,
            positions: torch.Tensor,
            quaternions: torch.Tensor,
        ) -> torch.Tensor:
            """
            Encode a batch of trajectories.

            positions: (B, N, 3); quaternions: (B, N, 4) wxyz.
            Returns the stacked per-item invariants on positions' device/dtype.

            Raises:
                ValueError: if positions is not (B, N, 3) or quaternions is not (B, N, 4).
            """
            # Explicit checks instead of assert, which is stripped under ``python -O``.
            if positions.dim() != 3 or positions.shape[-1] != 3:
                raise ValueError(
                    f"positions must have shape (B, N, 3), got {tuple(positions.shape)}"
                )
            B, N, _ = positions.shape
            if tuple(quaternions.shape) != (B, N, 4):
                raise ValueError(
                    f"quaternions must have shape {(B, N, 4)}, got {tuple(quaternions.shape)}"
                )
            device = positions.device
            dtype = positions.dtype

            # Move the whole batch to host memory once rather than per item.
            pos_np = positions.detach().cpu().numpy()
            quat_np = quaternions.detach().cpu().numpy()

            inv_list = []
            for pos, quat in zip(pos_np, quat_np):
                out = encode_dhb_dr(
                    pos, quat,
                    method=EncodingMethod.POSITION,
                    use_default_initial_frames=True,
                    dhb_method=self.dhb_method,
                )
                lin = out["linear_motion_invariants"]
                ang = out["angular_motion_invariants"]
                inv_list.append(torch.from_numpy(np.concatenate([lin, ang], axis=1)))
            return torch.stack(inv_list, dim=0).to(device=device, dtype=dtype)

else:
    DHBEncoderTorch = None
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Boundary extrapolation for length preservation.
|
|
3
|
+
Prepends 2 poses and appends 1 pose so invariant length aligns with N-1 steps.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from typing import Dict, Any
|
|
8
|
+
|
|
9
|
+
from dhb_xr.core import geometry as geom
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _rotation_step(q_from: np.ndarray, q_to: np.ndarray) -> np.ndarray:
    """Return rot_to_axis_angle(R(q_to)^T @ R(q_from)) as a float64 3-vector."""
    R_from = geom.quat_to_rot(q_from)
    R_to = geom.quat_to_rot(q_to)
    return np.asarray(geom.rot_to_axis_angle(R_to.T @ R_from), dtype=np.float64).reshape(3)


def extrapolate_boundary_poses(
    positions: np.ndarray,
    quaternions: np.ndarray,
) -> tuple:
    """
    Extrapolate boundary poses for length-preserved encoding.

    Prepends g_{-2}, g_{-1} and appends g_N using first-order extrapolation:
    - Position: p_{-1} = 2*p_0 - p_1, p_{-2} = 3*p_0 - 2*p_1, p_N = 2*p_{N-1} - p_{N-2}
    - Rotation: Lie-algebra consistent; R_{-1} = R_0 * exp(-[Δr_0]_×), etc.
      The first/last rotation differentials are reused unchanged (constant angular
      step at each boundary).

    positions: (N, 3), N >= 2
    quaternions: (N, 4) wxyz

    Returns (positions_ext, quaternions_ext) with shape (N+3, 3) and (N+3, 4).

    Raises:
        ValueError: if N < 2 or quaternions does not have N rows.
    """
    positions = np.asarray(positions, dtype=np.float64)
    quaternions = np.asarray(quaternions, dtype=np.float64)
    n = positions.shape[0]
    # Explicit check instead of assert, which is stripped under ``python -O``.
    if n < 2 or quaternions.shape[0] != n:
        raise ValueError(
            "extrapolate_boundary_poses requires at least 2 poses and "
            "matching positions/quaternions lengths"
        )

    # Only the first and last rotation differentials are used, so compute just
    # those two instead of all N-1. Each equals rot_to_axis_angle(R_i^T R_{i-1}).
    dr0 = _rotation_step(quaternions[0], quaternions[1])
    dr_last = _rotation_step(quaternions[-2], quaternions[-1])

    # Prepend: p_{-1} = 2*p_0 - p_1, p_{-2} = 3*p_0 - 2*p_1
    dp0 = positions[1] - positions[0]
    p_minus1 = positions[0] - dp0
    p_minus2 = p_minus1 - dp0

    # Step backwards from R_0 twice by the same differential (first-order).
    # NOTE(review): correctness of the sign relies on geom.axis_angle_to_rot(v)
    # being exp([v]_x) so that its transpose is exp(-[v]_x) — confirm in geometry.py.
    R0 = geom.quat_to_rot(quaternions[0])
    q_minus1 = geom.rot_to_quat(R0 @ geom.axis_angle_to_rot(-dr0).T)
    q_minus2 = geom.rot_to_quat(
        geom.quat_to_rot(q_minus1) @ geom.axis_angle_to_rot(-dr0).T
    )

    # Append: p_N = 2*p_{N-1} - p_{N-2}; rotation stepped forward by the last differential.
    p_last = 2 * positions[-1] - positions[-2]
    q_last = geom.rot_to_quat(
        geom.quat_to_rot(quaternions[-1]) @ geom.axis_angle_to_rot(dr_last).T
    )

    positions_ext = np.vstack(
        (p_minus2.reshape(1, 3), p_minus1.reshape(1, 3), positions, p_last.reshape(1, 3))
    )
    quaternions_ext = np.vstack(
        (
            q_minus2.reshape(1, 4),
            q_minus1.reshape(1, 4),
            quaternions,
            q_last.reshape(1, 4),
        )
    )
    return positions_ext, quaternions_ext
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def apply_length_preserving_padding(
    positions: np.ndarray,
    quaternions: np.ndarray,
) -> tuple:
    """
    Alias of extrapolate_boundary_poses, retained for API compatibility.

    Returns the (positions_ext, quaternions_ext) tuple produced by that function.
    """
    padded = extrapolate_boundary_poses(positions, quaternions)
    return padded
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Generative models for DHB-Token trajectory generation.
|
|
3
|
+
|
|
4
|
+
This module provides variational flow matching (VFM) and rectified flow matching (V-RFM)
|
|
5
|
+
for multi-modal trajectory generation in the DHB invariant latent space.
|
|
6
|
+
|
|
7
|
+
Key components:
|
|
8
|
+
- FlowMatcher: Base flow matching model for continuous latent generation
|
|
9
|
+
- VariationalFlowMatcher: VFM with latent conditioning for multi-modal generation
|
|
10
|
+
- VFMTokenGenerator: End-to-end integration with DHB tokenizers
|
|
11
|
+
|
|
12
|
+
Example usage:
|
|
13
|
+
>>> from dhb_xr.generative import VariationalFlowMatcher, VFMTokenGenerator
|
|
14
|
+
>>> from dhb_xr.tokenization import DHBTokenizer
|
|
15
|
+
>>>
|
|
16
|
+
>>> # Create tokenizer and flow matcher
|
|
17
|
+
>>> tokenizer = DHBTokenizer(invariant_dim=8, latent_dim=32, codebook_size=512)
|
|
18
|
+
>>> flow_matcher = VariationalFlowMatcher(latent_dim=32, hidden_dim=128)
|
|
19
|
+
>>>
|
|
20
|
+
>>> # Create end-to-end generator
|
|
21
|
+
>>> generator = VFMTokenGenerator(tokenizer, flow_matcher)
|
|
22
|
+
>>>
|
|
23
|
+
>>> # Generate multi-modal trajectories
|
|
24
|
+
>>> invariants = generator.generate_multimodal(prefix_invariants, num_modes=4)
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from .flow_matching import (
|
|
28
|
+
SinusoidalTimeEmbedding,
|
|
29
|
+
VelocityNetwork,
|
|
30
|
+
FlowMatcher,
|
|
31
|
+
VariationalFlowMatcher,
|
|
32
|
+
)
|
|
33
|
+
from .sampling import euler_solve, rk4_solve, ode_solve
|
|
34
|
+
from .latent_encoder import (
|
|
35
|
+
LatentEncoder,
|
|
36
|
+
CategoricalLatentEncoder,
|
|
37
|
+
HybridLatentEncoder,
|
|
38
|
+
)
|
|
39
|
+
from .vfm_tokenizer import (
|
|
40
|
+
VFMTokenGenerator,
|
|
41
|
+
ConditionalVFMGenerator,
|
|
42
|
+
)
|
|
43
|
+
from .training import (
|
|
44
|
+
InvariantDataset,
|
|
45
|
+
train_vfm_tokenizer,
|
|
46
|
+
evaluate_model,
|
|
47
|
+
compute_reconstruction_error,
|
|
48
|
+
compute_generation_diversity,
|
|
49
|
+
linear_kl_schedule,
|
|
50
|
+
cyclical_kl_schedule,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
# Public API of dhb_xr.generative: the names imported above from the submodules.
# ``from dhb_xr.generative import *`` binds exactly these names.
__all__ = [
    # Flow matching
    "SinusoidalTimeEmbedding",
    "VelocityNetwork",
    "FlowMatcher",
    "VariationalFlowMatcher",
    # Latent encoders
    "LatentEncoder",
    "CategoricalLatentEncoder",
    "HybridLatentEncoder",
    # VFM Token generators
    "VFMTokenGenerator",
    "ConditionalVFMGenerator",
    # Training utilities
    "InvariantDataset",
    "train_vfm_tokenizer",
    "evaluate_model",
    "compute_reconstruction_error",
    "compute_generation_diversity",
    "linear_kl_schedule",
    "cyclical_kl_schedule",
    # ODE solvers
    "euler_solve",
    "rk4_solve",
    "ode_solve",
]
|