splatreg 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- splatreg/__init__.py +37 -0
- splatreg/align.py +381 -0
- splatreg/align_features.py +1849 -0
- splatreg/api.py +736 -0
- splatreg/core/__init__.py +3 -0
- splatreg/core/lie.py +221 -0
- splatreg/core/types.py +95 -0
- splatreg/fuse.py +294 -0
- splatreg/geometry/__init__.py +10 -0
- splatreg/geometry/gaussian_sdf.py +370 -0
- splatreg/io.py +424 -0
- splatreg/py.typed +0 -0
- splatreg/quality.py +332 -0
- splatreg/residuals/__init__.py +31 -0
- splatreg/residuals/base.py +49 -0
- splatreg/residuals/icp.py +192 -0
- splatreg/residuals/photometric.py +337 -0
- splatreg/residuals/prior.py +117 -0
- splatreg/residuals/sdf.py +237 -0
- splatreg/solvers/__init__.py +10 -0
- splatreg/solvers/_backend_common.py +84 -0
- splatreg/solvers/base.py +21 -0
- splatreg/solvers/lm.py +360 -0
- splatreg/solvers/pypose_backend.py +113 -0
- splatreg/solvers/theseus_backend.py +143 -0
- splatreg/testing.py +75 -0
- splatreg/track.py +211 -0
- splatreg-1.0.0.dist-info/METADATA +218 -0
- splatreg-1.0.0.dist-info/RECORD +32 -0
- splatreg-1.0.0.dist-info/WHEEL +5 -0
- splatreg-1.0.0.dist-info/licenses/LICENSE +29 -0
- splatreg-1.0.0.dist-info/top_level.txt +1 -0
splatreg/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""splatreg — composable geometry-first SE(3)/Sim(3) registration for 3D Gaussian Splatting.

*gsplat renders your Gaussians; splatreg registers against them.*

Public surface:

    register(target, source, residuals=[...], transform="sim3", backend="builtin") -> RegisterResult
    merge([a, b, ...], ref=0) -> Gaussians
    Tracker(target, residuals=[...]).track(frame) -> RegisterResult
    Residual, Solver (extension points)
"""

from .core.types import Gaussians, Frame, RegisterResult, LinearizedProblem, SE3Update
from .residuals.base import Residual
from .solvers.base import Solver
from .quality import QualityConfig, resolve_quality

# The high-level pipeline lives in ``splatreg.api``. If that module cannot be imported (e.g. a
# partial source checkout), the core types above stay usable and ``register``, ``merge`` and
# ``Tracker`` are exported as ``None``.
# NOTE: this also masks ImportErrors raised *inside* ``splatreg.api`` (e.g. a missing dependency);
# if ``splatreg.register`` is unexpectedly ``None``, ``import splatreg.api`` directly to see the
# underlying error.
try:
    from .api import register, merge, Tracker  # noqa: F401
except ImportError:
    register = merge = Tracker = None  # type: ignore

# Must match the released distribution version (splatreg-1.0.0).
__version__ = "1.0.0"
__all__ = [
    "register",
    "merge",
    "Tracker",
    "Residual",
    "Solver",
    "QualityConfig",
    "resolve_quality",
    "Gaussians",
    "Frame",
    "RegisterResult",
    "LinearizedProblem",
    "SE3Update",
]
|
splatreg/align.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
1
|
+
"""Global coarse-init aligner — the basin finder that runs BEFORE the fine LM refine.
|
|
2
|
+
|
|
3
|
+
``register`` solves splat-to-splat alignment as *coarse global init → fine multi-residual
|
|
4
|
+
LM*. This module is the coarse half: given two Gaussian splats it returns a 4x4 transform
|
|
5
|
+
(Sim(3) by default, SE(3) optional) that lands ``source`` inside the convergence basin of
|
|
6
|
+
``target`` — close enough (typically within ~10-15deg / a few % scale) that the LM finishes
|
|
7
|
+
the job. It is deliberately approximate: the goal is a *good init*, not a precise pose.
|
|
8
|
+
|
|
9
|
+
Algorithm (ported from the A/B-bench metric-side pred->GT global aligner —
|
|
10
|
+
``project_ab_bench_fscore_alignment``: super-Fibonacci SO(3) candidate sweep + GPU-batched
|
|
11
|
+
trimmed ICP; defaults 1024 rotations (raised from the bench's 256) / 40 ICP iters / 12288 target points):
|
|
12
|
+
|
|
13
|
+
1. Centre both clouds on their centroids (``Gaussians.means``).
|
|
14
|
+
2. Estimate scale as the ratio of the two clouds' RMS radius about their centroids
|
|
15
|
+
(Sim(3) only; SE(3) fixes scale to 1).
|
|
16
|
+
3. Seed SO(3) with a deterministic near-uniform super-Fibonacci grid (Alexa, CVPR 2022)
   plus a handful of PCA principal-axis sign-flip candidates. The default 1024-seed grid
   (~16deg covering) was tuned so that at least one seed lands in the global basin even for
   near-featureless / symmetric clouds; the coarser ~26deg (256-seed) covering was not enough.
|
|
19
|
+
4. Run *all* seeds through one GPU-batched trimmed point-to-point ICP (batched
|
|
20
|
+
nearest-neighbour via ``torch.cdist`` + a batched closed-form Umeyama step each iter,
|
|
21
|
+
outlier-trimmed). Score each converged seed by the trimmed symmetric Chamfer between the
|
|
22
|
+
transformed source and the target.
|
|
23
|
+
5. Keep the lowest-Chamfer seed (near-ties broken by how close its centroid lands to the target
   centroid) and recover the exact closed-form similarity that maps the subsampled source onto
   that winner; return it as a 4x4 matrix.
|
|
25
|
+
|
|
26
|
+
**Fully on-device (GPU-native).** Everything — the super-Fibonacci/PCA seeds, the batched ICP
|
|
27
|
+
sweep, and the final closed-form Umeyama recovery — runs in torch on ``source``'s device; there
|
|
28
|
+
is no ``.cpu()`` / numpy round-trip in the compute path (the final recovery is done in float64
|
|
29
|
+
on-device for SVD precision). Self-contained: torch only, no gsplat / pytorch3d / scipy / numpy
|
|
30
|
+
/ SLAM imports. Deterministic — no RNG (closed-form seeds, strided subsample, first-index
|
|
31
|
+
tie-break) — so it is reproducible.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from __future__ import annotations
|
|
35
|
+
|
|
36
|
+
import math
|
|
37
|
+
|
|
38
|
+
import torch
|
|
39
|
+
|
|
40
|
+
from .core.types import Gaussians
|
|
41
|
+
|
|
42
|
+
# All defaults below were tuned on the A/B bench for zero failures on its hardest case: a
# near-featureless sphere under an arbitrary, uniformly sampled SO(3) transform.

# Super-Fibonacci SO(3) seed count (~16deg covering); a few PCA seeds are added on top.
#
# Why 1024 and not the original 256: on a near-isotropic sphere shell the PCA axes are degenerate
# (eigenvalue spread ~1.06), so the PCA sign-flip seeds contribute little, and a 256-seed (~26deg)
# covering was too coarse for any point-to-point-ICP seed to fall into the small but real correct
# basin. Every seed plateaued at Chamfer ~7 mm: a genuinely WRONG pose, not a symmetry ambiguity.
# A 1024-seed (~16deg) covering places a seed inside that basin. This took the trimmed-ICP score
# from 0.0056 to 0.00003 and the Chamfer below 0.06 mm at every symmetric cell. It stays
# GPU-affordable because the batched ICP sweep is chunked at _SEED_BATCH, and the NOISE/OUTLIERS
# cases were verified to stay at 9/9.
DEFAULT_N_ROTATIONS = 1_024
DEFAULT_ICP_ITERS = 40  # trimmed-ICP iterations run from every seed
DEFAULT_N_POINTS = 12_288  # strided target subsample used by the fit (denser -> lower error floor)
_SUB_SOURCE = 4_096  # strided source subsample that drives the fit
_ICP_TRIM_KEEP = 0.85  # fraction of closest correspondences kept each ICP iteration (outlier reject)
_SEED_BATCH = 64  # seeds per chunk in the batched ICP (bounds peak memory)
# Super-Fibonacci constant psi (Alexa, CVPR 2022): the real root of psi**4 = psi + 4. It is paired
# with phi = sqrt(2) as the two irrational angular step rates of the quaternion spiral.
_PSI = 1.533751168755204288118041
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# ── super-Fibonacci SO(3) grid (on-device torch) ──────────────────────────────────────
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _superfib_quats(n: int, device, dtype) -> torch.Tensor:
    """Super-Fibonacci quaternions ``(n, 4)`` in ``(x, y, z, w)`` order, built directly on ``device``.

    Sample ``i`` uses the half-offset index ``s = i + 0.5``. Its ``(x, y)`` pair lies on a circle of
    radius ``sqrt(s / n)`` at angle ``2*pi*s / sqrt(2)``. Its ``(z, w)`` pair lies on a circle of
    radius ``sqrt(1 - s / n)`` at angle ``2*pi*s / _PSI``. The two radii square-sum to 1, so every
    sample is a unit quaternion (Alexa, CVPR 2022).
    """
    offsets = torch.arange(int(n), device=device, dtype=dtype) + 0.5
    frac = offsets / float(n)
    inner_radius = torch.sqrt(frac)
    # clamp guards against tiny negative round-off before the sqrt
    outer_radius = torch.sqrt((1.0 - frac).clamp_min(0.0))
    angle_a = (2.0 * math.pi / math.sqrt(2.0)) * offsets
    angle_b = (2.0 * math.pi / _PSI) * offsets
    columns = (
        inner_radius * torch.sin(angle_a),
        inner_radius * torch.cos(angle_a),
        outer_radius * torch.sin(angle_b),
        outer_radius * torch.cos(angle_b),
    )
    return torch.stack(columns, dim=1)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _quats_to_R(q: torch.Tensor) -> torch.Tensor:
    """Convert quaternions ``(n, 4)`` in ``(x, y, z, w)`` order to rotation matrices ``(n, 3, 3)``.

    The quaternions are renormalised first, with norms clamped at 1e-12 to avoid division by zero.
    Non-unit inputs are therefore accepted.
    """
    unit = q / q.norm(dim=1, keepdim=True).clamp_min(1e-12)
    x, y, z, w = unit[:, 0], unit[:, 1], unit[:, 2], unit[:, 3]
    row0 = torch.stack((1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)), dim=-1)
    row1 = torch.stack((2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)), dim=-1)
    row2 = torch.stack((2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)), dim=-1)
    # stacking the (n, 3) rows on dim -2 gives R[:, i, j] == row_i[:, j]
    return torch.stack((row0, row1, row2), dim=-2)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def super_fibonacci_so3(n: int, device=None, dtype=torch.float64) -> torch.Tensor:
    """Return ``n`` deterministic, near-uniformly spread rotations as an ``(n, 3, 3)`` tensor.

    Samples a super-Fibonacci spiral on the unit-quaternion 3-sphere (Alexa, CVPR 2022) and
    converts it to rotation matrices, entirely on ``device``. ``n == 256`` gives roughly a 26deg
    covering of SO(3).
    """
    quats = _superfib_quats(int(n), device, dtype)
    rotations = _quats_to_R(quats)
    return rotations
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# ── PCA sign-flip seeds (on-device torch) ─────────────────────────────────────────────
|
|
105
|
+
|
|
106
|
+
# Isotropy threshold for degenerate-PCA detection (symmetric-object note).
#
# Intended to be compared against ``_pca_eigenvalue_spread``, which returns the ratio of the
# largest to the smallest singular value of the centred cloud. That is the square root of the
# covariance-eigenvalue ratio; the asymmetric test object scores ~2.5 and a sphere ~1.05. Below
# this value the cloud is near-isotropic (sphere-like) and its PCA axes are arbitrary / unstable.
#
# NOTE: no code in this module reads this constant; it only documents the degenerate-PCA regime.
# The PCA seeds are deliberately kept even for near-isotropic clouds. On the N=800 test sphere,
# individual seeds still differ in Chamfer score by ~9 mm, which the trimmed ICP can distinguish.
# Dropping the PCA seeds was measured to degrade the symmetric case with the 256-seed Fibonacci
# grid that was the default at the time (see DEFAULT_N_ROTATIONS for the current density).
_PCA_ISOTROPY_THRESH = 2.0
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _pca_axes(pts_centered: torch.Tensor) -> torch.Tensor:
    """Principal axes of an already-centred cloud, as columns, in descending-variance order.

    The columns are the right singular vectors of ``pts_centered``. Each is defined only up to
    sign, so the returned 3x3 orthogonal matrix may be a rotation (det +1) or a reflection
    (det -1).
    """
    return torch.linalg.svd(pts_centered, full_matrices=False).Vh.transpose(-2, -1)


def _pca_eigenvalue_spread(pts_centered: torch.Tensor) -> float:
    """Isotropy probe: the largest / smallest singular value of the centred cloud.

    This is the square root of the covariance-eigenvalue ratio. Values near 1.0 indicate a
    near-isotropic (sphere-like) cloud whose PCA axes are arbitrary. The asymmetric test object
    scores ~2.5; a sphere scores ~1.05.
    """
    sv = torch.linalg.svd(pts_centered, full_matrices=False).S
    return float((sv[0] / sv[-1].clamp_min(1e-10)).item())


def _pca_seed_rotations(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    """Identity plus the PCA principal-axis-matching rotations, ``(<=5, 3, 3)`` on ``src``'s device.

    ICP from the identity alone has a small basin. Matching the source's principal axes onto the
    target's covers large rotations about symmetry axes that a uniform grid may straddle.

    Each principal axis is defined only up to sign, so a candidate ``Vt @ S @ Vs.T`` is built for
    every sign pattern ``S = diag(sx, sy, sz)`` that yields a proper rotation. Because
    ``det(R) = det(Vt) * det(S) * det(Vs)``, the sign ``sz`` must absorb the relative handedness
    ``det(Vt) * det(Vs)`` of the two SVD frames. Fixing ``sz = sx * sy`` would force
    ``det(S) = +1`` and silently drop all four PCA seeds whenever the SVD returns frames of
    opposite handedness. The final ``det(R) > 0`` check still discards candidates built from
    non-finite axes.

    Note on symmetric objects: for near-isotropic clouds (spread below ``_PCA_ISOTROPY_THRESH``)
    the PCA axes are arbitrary, but the seeds are kept anyway. On the N=800 test sphere,
    individual rotations still differ in Chamfer score by ~9 mm, so the batched trimmed ICP
    selects among them meaningfully. Removing the PCA seeds for isotropic clouds was measured to
    make the symmetric result worse (8/9 -> 3/9) with the 256-seed Fibonacci grid used at the time.
    """
    dev, dt = src.device, src.dtype
    Vs = _pca_axes(src - src.mean(0))
    Vt = _pca_axes(tgt - tgt.mean(0))
    # Both frames are orthogonal, so this is +1 (same handedness) or -1 (opposite handedness).
    handedness = torch.sign(torch.linalg.det(Vt) * torch.linalg.det(Vs))
    third_sign_fix = torch.ones(3, device=dev, dtype=dt)
    third_sign_fix[2] = handedness
    cands = [torch.eye(3, device=dev, dtype=dt)]
    for sx in (1.0, -1.0):
        for sy in (1.0, -1.0):
            # det(S) == handedness, hence det(R) == +1 for well-formed axes.
            S = torch.diag(torch.tensor([sx, sy, sx * sy], device=dev, dtype=dt) * third_sign_fix)
            R = Vt @ S @ Vs.transpose(-2, -1)
            if torch.linalg.det(R) > 0:  # proper rotation only
                cands.append(R)
    return torch.stack(cands, dim=0)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# ── closed-form similarity (Umeyama 1991, on-device torch) ────────────────────────────
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def _umeyama(src: torch.Tensor, dst: torch.Tensor, with_scale: bool):
    """Closed-form least-squares similarity (Umeyama 1991) with ``dst ≈ s * (src @ R.T) + t``.

    With ``with_scale=False`` the scale is pinned to 1 (a rigid Kabsch fit). Returns ``(s, R, t)``
    as torch tensors on ``src``'s device and dtype. Callers cast to float64 beforehand when SVD
    precision matters.
    """
    count = src.shape[0]
    mu_src, mu_dst = src.mean(0), dst.mean(0)
    src_c = src - mu_src
    dst_c = dst - mu_dst
    sigma = (dst_c.transpose(-2, -1) @ src_c) / count
    U, sing, Vh = torch.linalg.svd(sigma)
    # Flip the last axis when U and Vh disagree in handedness so R is a rotation, not a reflection.
    correction = torch.eye(3, device=src.device, dtype=src.dtype)
    is_reflection = torch.linalg.det(U) * torch.linalg.det(Vh) < 0
    if is_reflection:
        correction[2, 2] = -1.0
    R = U @ correction @ Vh
    if not with_scale:
        s = torch.ones((), device=src.device, dtype=src.dtype)
    else:
        src_variance = src_c.pow(2).sum() / count
        s = (sing * torch.diagonal(correction)).sum() / src_variance.clamp_min(1e-12)
    t = mu_dst - s * (R @ mu_src)
    return s, R, t
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# ── batched torch ICP helpers (already on-device) ─────────────────────────────────────
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _stride_subsample(a: torch.Tensor, k: int) -> torch.Tensor:
    """Deterministic strided subsample of ``a`` to at most ``k`` rows (no RNG).

    Returns ``a`` itself when it already has ``<= k`` rows. Otherwise returns the ``k`` rows at
    evenly spaced positions from the first row to the last, each position rounded to the nearest
    index.

    The positions are generated in float64. ``torch.linspace`` otherwise uses the default dtype
    (float32 unless changed), which cannot represent every integer above 2**24 (~16.7M rows). For
    large clouds the indices would then be duplicated or shifted, and the last one can round past
    the end of ``a``.
    """
    n_rows = a.shape[0]
    if n_rows <= k:
        return a
    positions = torch.linspace(0, n_rows - 1, k, device=a.device, dtype=torch.float64)
    sel = positions.round().to(torch.int64)
    return a[sel]
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _batched_nn(X: torch.Tensor, Y: torch.Tensor):
    """Per-batch nearest neighbour of every ``X[b, i]`` among the points of ``Y[b]``.

    Returns ``(dist_sq, idx)``, both shaped ``(B, Nx)``: the squared distance to, and the index
    of, the nearest point of ``Y[b]``. ``torch.cdist`` is evaluated over slices of the query axis
    so each pairwise ``(B, slice, Ny)`` block stays around 8M entries.
    """
    n_batch, n_query, _ = X.shape
    n_ref = Y.shape[1]
    best_d2 = X.new_empty((n_batch, n_query))
    best_idx = torch.empty((n_batch, n_query), device=X.device, dtype=torch.int64)
    rows_per_slice = max(1, int(8_000_000 // max(n_batch * n_ref, 1)))
    for first in range(0, n_query, rows_per_slice):
        last = min(first + rows_per_slice, n_query)
        nearest, which = torch.cdist(X[:, first:last], Y).min(dim=2)
        best_d2[:, first:last] = nearest * nearest
        best_idx[:, first:last] = which
    return best_d2, best_idx
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _batched_umeyama(X: torch.Tensor, Y: torch.Tensor, w: torch.Tensor, with_scale: bool):
    """Weighted closed-form similarity (Umeyama 1991) fitted independently for each batch item.

    ``w`` is a ``(B, N)`` tensor of non-negative weights (here the ICP trim mask). The fit maps
    ``X[b]`` onto ``Y[b]`` so that ``Y ≈ s * (X @ R.T) + t``. Returns ``s`` shaped ``(B,)``,
    ``R`` shaped ``(B, 3, 3)`` and ``t`` shaped ``(B, 3)``. With ``with_scale=False`` every ``s``
    is 1.
    """
    n_batch = X.shape[0]
    w_norm = (w / w.sum(dim=1, keepdim=True).clamp_min(1e-9)).unsqueeze(-1)
    mean_x = (w_norm * X).sum(dim=1)
    mean_y = (w_norm * Y).sum(dim=1)
    x_c = X - mean_x[:, None, :]
    y_c = Y - mean_y[:, None, :]
    cross = torch.einsum("bni,bnj->bij", w_norm * y_c, x_c)
    U, sing, Vh = torch.linalg.svd(cross)
    # Per-item reflection guard: the last diagonal entry carries the handedness sign.
    correction = torch.eye(3, device=X.device, dtype=X.dtype).repeat(n_batch, 1, 1)
    correction[:, 2, 2] = torch.sign(torch.linalg.det(U) * torch.linalg.det(Vh))
    R = U @ correction @ Vh
    if not with_scale:
        s = torch.ones(n_batch, device=X.device, dtype=X.dtype)
    else:
        spread = (w_norm * x_c.pow(2)).sum(dim=(1, 2)).clamp_min(1e-12)
        s = (sing * torch.diagonal(correction, dim1=1, dim2=2)).sum(dim=1) / spread
    t = mean_y - s[:, None] * torch.einsum("bij,bj->bi", R, mean_x)
    return s, R, t
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _trim_mean_sqrt(dist_sq: torch.Tensor, frac: float) -> torch.Tensor:
|
|
251
|
+
"""Mean of the sqrt of the lowest ``frac`` fraction of squared distances, per batch row."""
|
|
252
|
+
k = max(1, int(round(frac * dist_sq.shape[1])))
|
|
253
|
+
v, _ = torch.sort(dist_sq, dim=1)
|
|
254
|
+
return v[:, :k].sqrt().mean(dim=1)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _batched_trimmed_icp(
    src_sub: torch.Tensor, tgt_sub: torch.Tensor, R_seeds: torch.Tensor, with_scale: bool, icp_iters: int
):
    """Refine every rotation seed with one batched, trimmed point-to-point ICP.

    For each seed the centred source is rotated and, when ``with_scale``, rescaled so its RMS
    radius matches the target's; it is then placed at the target centroid. The result is refined
    by ``icp_iters`` rounds of nearest-neighbour matching plus a weighted Umeyama fit, in which
    only the closest ``_ICP_TRIM_KEEP`` share of correspondences carries weight. Each seed is
    finally scored by a trimmed symmetric Chamfer distance. Seeds are processed ``_SEED_BATCH``
    at a time to bound peak memory.

    Returns:
        ``(aligned, score)``. ``aligned`` is ``(n_seeds, K, 3)``, the refined source per seed;
        ``score`` is ``(n_seeds,)``, where lower is better.
    """
    n_src, n_tgt = src_sub.shape[0], tgt_sub.shape[0]
    src_centroid = src_sub.mean(0)
    tgt_centroid = tgt_sub.mean(0)
    src_centered = src_sub - src_centroid
    tgt_rms = (tgt_sub - tgt_centroid).pow(2).sum(-1).mean().sqrt()
    n_keep = max(3, int(round(_ICP_TRIM_KEEP * n_src)))
    n_seeds = R_seeds.shape[0]

    out_aligned: list[torch.Tensor] = []
    out_score: list[torch.Tensor] = []
    start = 0
    while start < n_seeds:
        rot = R_seeds[start : start + _SEED_BATCH]
        start += _SEED_BATCH
        nb = rot.shape[0]
        cur = torch.bmm(src_centered.unsqueeze(0).expand(nb, n_src, 3), rot.transpose(1, 2))
        if with_scale:
            cur_rms = cur.pow(2).sum(-1).mean(1).sqrt()
            cur = cur * (tgt_rms / cur_rms.clamp_min(1e-9)).view(nb, 1, 1)
        cur = cur + tgt_centroid
        ref = tgt_sub.unsqueeze(0).expand(nb, n_tgt, 3).contiguous()
        for _ in range(int(icp_iters)):
            nn_d2, nn_idx = _batched_nn(cur, ref)
            matched = torch.gather(ref, 1, nn_idx.unsqueeze(-1).expand(nb, n_src, 3))
            # keep only the n_keep closest correspondences (ties at the cut-off are kept too)
            cutoff = torch.kthvalue(nn_d2, n_keep, dim=1, keepdim=True).values
            inlier_w = (nn_d2 <= cutoff).to(cur.dtype)
            step_s, step_R, step_t = _batched_umeyama(cur, matched, inlier_w, with_scale)
            cur = step_s[:, None, None] * torch.bmm(cur, step_R.transpose(1, 2)) + step_t[:, None, :]
        fwd_d2, _ = _batched_nn(cur, ref)
        bwd_d2, _ = _batched_nn(ref, cur)
        forward = _trim_mean_sqrt(fwd_d2, _ICP_TRIM_KEEP)
        backward = _trim_mean_sqrt(bwd_d2, _ICP_TRIM_KEEP)
        out_aligned.append(cur)
        out_score.append(0.5 * (forward + backward))
    return torch.cat(out_aligned, dim=0), torch.cat(out_score, dim=0)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# ── public entry ──────────────────────────────────────────────────────────────────────
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
# Near-best tie-break tolerance for ``global_align``. Seeds scoring within
# ``_SCORE_EPS * max(best_score, target RMS radius)`` of the best score are treated as tied.
_SCORE_EPS = 5e-3


def _select_best_seed(aligned_sub: torch.Tensor, scores: torch.Tensor, tgt_sub: torch.Tensor) -> torch.Tensor:
    """Index of the winning seed: the lowest score, with near-ties broken by centroid proximity.

    Among seeds within the ``_SCORE_EPS`` tolerance of the best score, the winner is the one whose
    ICP-refined source centroid lies closest to the target centroid. For symmetric clouds many
    seeds score almost equally, and a plain first-index argmin can pick a degenerate seed (wrong
    translation). Centroid proximity is a cheap proxy for translation/scale stability that needs
    no per-seed SVD.

    The tolerance is scaled by the target's own RMS radius instead of a fixed floor of 1.0. Chamfer
    scores scale linearly with the scene, so a fixed floor made the tolerance an absolute 5e-3
    scene units whenever scores were below 1. The same scene expressed in metres or in millimetres
    would then be tie-broken differently, breaking the scale equivariance of the Sim(3) search.
    """
    tgt_centroid = tgt_sub.mean(0)
    tgt_rms = (tgt_sub - tgt_centroid).pow(2).sum(-1).mean().sqrt()
    best_score = scores.min()
    tol = _SCORE_EPS * torch.maximum(best_score.abs(), tgt_rms)
    near_idx = (scores <= best_score + tol).nonzero(as_tuple=False).squeeze(1)
    if near_idx.shape[0] <= 1:
        return torch.argmin(scores)
    # aligned_sub[b] is the source subsample after ICP convergence from seed b.
    near_centroids = aligned_sub[near_idx].mean(dim=1)
    centroid_dists = (near_centroids - tgt_centroid).norm(dim=1)
    return near_idx[centroid_dists.argmin()]


@torch.no_grad()
def global_align(
    target: Gaussians,
    source: Gaussians,
    *,
    transform: str = "sim3",
    n_rotations: int = DEFAULT_N_ROTATIONS,
    icp_iters: int = DEFAULT_ICP_ITERS,
    n_points: int = DEFAULT_N_POINTS,
    seed: int = 0,
) -> torch.Tensor:
    """Coarse global init: a 4x4 transform that lands ``source`` in ``target``'s basin.

    Sweeps ``n_rotations`` super-Fibonacci SO(3) candidates plus up to five PCA seeds. Each seed
    is refined by a batched trimmed ICP and scored by a batched trimmed symmetric Chamfer
    distance. The function returns the closed-form similarity for the best seed, with near-ties
    broken by centroid proximity. The result is a *basin-correct init* for the fine LM, not a
    precise pose.

    Point data never leaves ``source``'s device (CPU or CUDA). Only a few scalar values (e.g.
    determinant signs and the near-best seed count) are synchronised to the host for control flow.

    Args:
        target: the reference splat (``source`` aligns onto it). Only ``.means`` is read.
        source: the splat to align. Only ``.means`` is read.
        transform: ``"sim3"`` (default) estimates a scale; ``"se3"`` fixes scale to 1.
        n_rotations: super-Fibonacci SO(3) seed count (more -> finer covering, slower). Must be
            ``>= 0``; 0 uses the PCA seeds only.
        icp_iters: trimmed-ICP iterations per seed.
        n_points: target subsample size for the fit. Must be ``>= 3``.
        seed: accepted for API stability; the method is deterministic (no RNG).

    Returns:
        A ``(4, 4)`` ``torch.Tensor`` on ``source``'s device/dtype, applied as ``T @ [x, y, z, 1]``.
        The top-left block is ``s*R`` (Sim(3)) or ``R`` (SE(3)); the last column is the
        translation. The identity is returned when either cloud has fewer than 3 points, or when
        every seed produced a non-finite score.

    Raises:
        ValueError: on an unknown ``transform``, ``n_points < 3`` or ``n_rotations < 0``.
    """
    del seed  # deterministic; no RNG to seed (kept for API stability)
    if transform not in ("sim3", "se3"):
        raise ValueError(f"transform must be 'sim3' or 'se3', got {transform!r}")
    if int(n_points) < 3:
        raise ValueError(f"n_points must be >= 3, got {n_points!r}")
    if int(n_rotations) < 0:
        raise ValueError(f"n_rotations must be >= 0, got {n_rotations!r}")
    with_scale = transform == "sim3"

    dev = source.means.device
    dtype = source.means.dtype
    src_full = source.means
    tgt_full = target.means.to(device=dev, dtype=dtype)

    if src_full.shape[0] < 3 or tgt_full.shape[0] < 3:
        return torch.eye(4, device=dev, dtype=dtype)

    # Fit in float32 on-device (SVD/cdist stability) on deterministic strided subsamples.
    src_sub = _stride_subsample(src_full, _SUB_SOURCE).to(torch.float32)
    tgt_sub = _stride_subsample(tgt_full, int(n_points)).to(torch.float32)

    # PCA seeds come first, so an exact score tie favours them; the super-Fibonacci grid follows.
    R_pca = _pca_seed_rotations(src_sub, tgt_sub)  # (<=5, 3, 3)
    R_grid = super_fibonacci_so3(int(n_rotations), device=dev, dtype=torch.float32)  # (n, 3, 3)
    R_seeds = torch.cat([R_pca, R_grid], dim=0)

    aligned_sub, scores = _batched_trimmed_icp(src_sub, tgt_sub, R_seeds, with_scale, int(icp_iters))

    # Non-finite scores (degenerate/NaN input) must never win the selection.
    finite = torch.isfinite(scores)
    if not bool(finite.any()):
        return torch.eye(4, device=dev, dtype=dtype)
    scores = torch.where(finite, scores, torch.full_like(scores, float("inf")))
    winner = aligned_sub[_select_best_seed(aligned_sub, scores, tgt_sub)]  # stays on-device

    # Recover the exact closed-form similarity src_sub -> winner in float64 on-device.
    s_tot, R_tot, t_tot = _umeyama(src_sub.double(), winner.double(), with_scale)
    T = torch.eye(4, dtype=dtype, device=dev)
    T[:3, :3] = (s_tot * R_tot).to(dtype=dtype)
    T[:3, 3] = t_tot.to(dtype=dtype)
    return T
|