splatreg 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
splatreg/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ """splatreg — composable geometry-first SE(3)/Sim(3) registration for 3D Gaussian Splatting.
2
+
3
+ *gsplat renders your Gaussians; splatreg registers against them.*
4
+
5
+ Public surface (filled in by the carve):
6
+ register(target, source, residuals=[...], transform="sim3", backend="builtin") -> RegisterResult
7
+ merge([a, b, ...], ref=0) -> Gaussians
8
+ Tracker(target, residuals=[...]).track(frame) -> RegisterResult
9
+ Residual, Solver (extension points)
10
+ """
11
+
12
+ from .core.types import Gaussians, Frame, RegisterResult, LinearizedProblem, SE3Update
13
+ from .residuals.base import Residual
14
+ from .solvers.base import Solver
15
+ from .quality import QualityConfig, resolve_quality
16
+
17
# The high-level pipeline (splatreg.api) is added by the carve; tolerate its absence pre-build.
# When the import fails, the three entry points are bound to None so the names listed in
# ``__all__`` below still resolve.
# NOTE(review): this also masks an ImportError raised *inside* splatreg.api (e.g. a missing
# dependency) — callers would then see ``None`` instead of the root cause; confirm intended.
try:
    from .api import register, merge, Tracker  # noqa: F401
except ImportError:
    register = merge = Tracker = None  # type: ignore

# NOTE(review): the distribution this file ships in is labelled 1.0.0 while this string
# says 0.0.1 — confirm which one is authoritative.
__version__ = "0.0.1"
# Names exported by ``from splatreg import *``.
__all__ = [
    "register",
    "merge",
    "Tracker",
    "Residual",
    "Solver",
    "QualityConfig",
    "resolve_quality",
    "Gaussians",
    "Frame",
    "RegisterResult",
    "LinearizedProblem",
    "SE3Update",
]
splatreg/align.py ADDED
@@ -0,0 +1,381 @@
1
+ """Global coarse-init aligner — the basin finder that runs BEFORE the fine LM refine.
2
+
3
+ ``register`` solves splat-to-splat alignment as *coarse global init → fine multi-residual
4
+ LM*. This module is the coarse half: given two Gaussian splats it returns a 4x4 transform
5
+ (Sim(3) by default, SE(3) optional) that lands ``source`` inside the convergence basin of
6
+ ``target`` — close enough (typically within ~10-15deg / a few % scale) that the LM finishes
7
+ the job. It is deliberately approximate: the goal is a *good init*, not a precise pose.
8
+
9
+ Algorithm (ported from the A/B-bench metric-side pred->GT global aligner —
10
+ ``project_ab_bench_fscore_alignment``: super-Fibonacci SO(3) candidate sweep + GPU-batched
11
+ trimmed ICP; the bench was tuned at 256 rotations / 40 ICP iters / 12288 points, and this module now defaults to 1024 rotations — see ``DEFAULT_N_ROTATIONS``):
12
+
13
+ 1. Centre both clouds on their centroids (``Gaussians.means``).
14
+ 2. Estimate scale as the ratio of the two clouds' RMS radius about their centroids
15
+ (Sim(3) only; SE(3) fixes scale to 1).
16
+ 3. Seed SO(3) with a deterministic near-uniform super-Fibonacci grid (Alexa, CVPR 2022)
17
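+ plus a handful of PCA principal-axis sign-flip candidates. A sufficiently fine covering
18
+ lands one seed in the global basin (the default 1024 seeds give ~16deg; the earlier 256-seed / ~26deg grid proved too coarse on near-isotropic clouds), so even featureless / symmetric clouds recover.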
19
+ 4. Run *all* seeds through one GPU-batched trimmed point-to-point ICP (batched
20
+ nearest-neighbour via ``torch.cdist`` + a batched closed-form Umeyama step each iter,
21
+ outlier-trimmed). Score each converged seed by the trimmed symmetric Chamfer between the
22
+ transformed source and the target.
23
+ 5. Keep the lowest-Chamfer seed and recover the exact closed-form similarity that maps the
24
+ subsampled source onto that winner; return it as a 4x4 matrix.
25
+
26
+ **Fully on-device (GPU-native).** Everything — the super-Fibonacci/PCA seeds, the batched ICP
27
+ sweep, and the final closed-form Umeyama recovery — runs in torch on ``source``'s device; there
28
+ is no ``.cpu()`` / numpy round-trip in the compute path (the final recovery is done in float64
29
+ on-device for SVD precision). Self-contained: torch only, no gsplat / pytorch3d / scipy / numpy
30
+ / SLAM imports. Deterministic — no RNG (closed-form seeds, strided subsample, first-index
31
+ tie-break) — so it is reproducible.
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ import math
37
+
38
+ import torch
39
+
40
+ from .core.types import Gaussians
41
+
42
# Defaults tuned (A/B-bench) for 0 failures on the hardest case — a near-featureless sphere
# under an arbitrary uniform-SO(3) transform.
DEFAULT_N_ROTATIONS = 1024  # super-Fibonacci SO(3) seeds (~16deg covering); + a few PCA seeds.
# Why 1024 (was 256): on a near-isotropic sphere shell the PCA axes are degenerate (eigenvalue
# spread ~1.06) so the sign-flip seeds collapse to a single identity seed, and a 256-seed (~26deg)
# covering is too coarse for any point-to-point-ICP seed to fall into the (tiny but real) correct
# basin — every seed plateaus at Chamfer ~7 mm (a genuinely WRONG pose, not symmetry ambiguity).
# A 1024-seed (~16deg) covering puts a seed inside that basin, driving the trimmed-ICP score from
# 0.0056 to 0.00003 and the Chamfer to <0.06 mm at every symmetric cell. GPU-affordable (the
# batched ICP sweep is chunked at _SEED_BATCH); verified to keep NOISE/OUTLIERS at 9/9.
DEFAULT_ICP_ITERS = 40  # trimmed-ICP iterations per seed
DEFAULT_N_POINTS = 12288  # deterministic strided target subsample for the fit (denser -> lower floor)
_SUB_SOURCE = 4096  # deterministic strided source subsample driving the fit
_ICP_TRIM_KEEP = 0.85  # fraction of best correspondences kept each iter (outlier reject)
_SEED_BATCH = 64  # seeds processed per chunk in the batched ICP (bounds peak memory)
# Second angular-rate constant of the super-Fibonacci spiral; _superfib_quats pairs it with
# phi = sqrt(2). Presumably the psi of Alexa (CVPR 2022), i.e. the positive real root of
# psi**4 = psi + 4 — TODO confirm against the paper.
_PSI = 1.533751168755204288118041
58
+
59
+
60
+ # ── super-Fibonacci SO(3) grid (on-device torch) ──────────────────────────────────────
61
+
62
+
63
def _superfib_quats(n: int, device, dtype) -> torch.Tensor:
    """Return ``n`` super-Fibonacci unit quaternions, shape ``(n, 4)``, ordered ``(x, y, z, w)``.

    Sample ``i`` sits at ``s = i + 0.5``; ``t = s / n`` splits the unit norm between the
    ``(x, y)`` pair (radius ``sqrt(t)``) and the ``(z, w)`` pair (radius ``sqrt(1 - t)``),
    while two incommensurate angular rates spin each pair. The tensor is created on
    ``device`` with ``dtype``.
    """
    two_pi = 2.0 * math.pi
    rate_xy = two_pi / math.sqrt(2.0)
    rate_zw = two_pi / _PSI

    offset_idx = torch.arange(int(n), device=device, dtype=dtype) + 0.5
    frac = offset_idx / float(n)
    radius_xy = torch.sqrt(frac)
    # Clamp guards against a tiny negative from rounding before the square root.
    radius_zw = torch.sqrt((1.0 - frac).clamp_min(0.0))
    ang_xy = rate_xy * offset_idx
    ang_zw = rate_zw * offset_idx

    components = [
        radius_xy * torch.sin(ang_xy),
        radius_xy * torch.cos(ang_xy),
        radius_zw * torch.sin(ang_zw),
        radius_zw * torch.cos(ang_zw),
    ]
    return torch.stack(components, dim=1)
76
+
77
+
78
+ def _quats_to_R(q: torch.Tensor) -> torch.Tensor:
79
+ """Unit quaternions ``(n, 4)`` ``(x, y, z, w)`` -> rotation matrices ``(n, 3, 3)``."""
80
+ q = q / q.norm(dim=1, keepdim=True).clamp_min(1e-12)
81
+ x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
82
+ R = q.new_empty((q.shape[0], 3, 3))
83
+ R[:, 0, 0] = 1 - 2 * (y * y + z * z)
84
+ R[:, 0, 1] = 2 * (x * y - z * w)
85
+ R[:, 0, 2] = 2 * (x * z + y * w)
86
+ R[:, 1, 0] = 2 * (x * y + z * w)
87
+ R[:, 1, 1] = 1 - 2 * (x * x + z * z)
88
+ R[:, 1, 2] = 2 * (y * z - x * w)
89
+ R[:, 2, 0] = 2 * (x * z - y * w)
90
+ R[:, 2, 1] = 2 * (y * z + x * w)
91
+ R[:, 2, 2] = 1 - 2 * (x * x + y * y)
92
+ return R
93
+
94
+
95
def super_fibonacci_so3(n: int, device=None, dtype=torch.float64) -> torch.Tensor:
    """Return a deterministic near-uniform SO(3) covering as ``(n, 3, 3)`` rotation matrices.

    Generates ``n`` super-Fibonacci spiral samples on the quaternion 3-sphere
    (Alexa, CVPR 2022) and converts them to matrices, on ``device`` with ``dtype``.
    The covering is ~26deg at ``n == 256``.
    """
    quats = _superfib_quats(int(n), device, dtype)
    return _quats_to_R(quats)
102
+
103
+
104
+ # ── PCA sign-flip seeds (on-device torch) ─────────────────────────────────────────────
105
+
106
# Eigenvalue-spread threshold for degenerate-PCA detection (symmetric-object note).
# If the ratio of the largest to smallest PCA eigenvalue is below this, the cloud is
# near-isotropic (sphere-like) and the PCA axes are arbitrary/unstable.
# Diagnostic only: no code in this module reads the constant (only the
# _pca_seed_rotations docstring refers to it), and the PCA seeds are kept regardless
# of the spread. Removing them degraded performance on the sphere because PCA seeds
# accidentally provide good centroid-alignment candidates that the Fibonacci grid at
# its earlier 256-seed density would miss, since a sphere is near-rotationally symmetric
# but NOT perfectly so at N=800 — individual seeds still differ in Chamfer score by
# ~9mm, which the trimmed ICP can distinguish.
_PCA_ISOTROPY_THRESH = 2.0
116
+
117
+
118
+ def _pca_axes(pts_centered: torch.Tensor) -> torch.Tensor:
119
+ """Principal axes (as columns), descending variance, from centered points."""
120
+ _, _, Vh = torch.linalg.svd(pts_centered, full_matrices=False)
121
+ return Vh.transpose(-2, -1)
122
+
123
+
124
+ def _pca_eigenvalue_spread(pts_centered: torch.Tensor) -> float:
125
+ """Ratio of largest to smallest singular value of the centred cloud (isotropy probe).
126
+
127
+ Values near 1.0 indicate a near-isotropic (sphere-like) cloud where PCA axes are
128
+ arbitrary. The asymmetric test object scores ~2.5; a sphere ~1.05.
129
+ """
130
+ _, sv, _ = torch.linalg.svd(pts_centered, full_matrices=False)
131
+ return float((sv[0] / sv[-1].clamp_min(1e-10)).item())
132
+
133
+
134
def _pca_seed_rotations(src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    """Identity + PCA principal-axis-match sign flips, ``(5, 3, 3)`` on-device.

    ICP from identity alone has a small basin; PCA-axis seeds cover large symmetry-axis
    rotations a uniform grid may straddle.

    Each candidate is ``Vt @ diag(sx, sy, sz) @ Vs.T`` with ``Vs`` / ``Vt`` the principal
    axes of ``src`` / ``tgt``. The SVD returns each basis with arbitrary handedness, so
    the third sign is chosen as ``sz = sx * sy * sign(det(Vt) * det(Vs))``, which makes
    every candidate a proper rotation (determinant +1). The earlier form used
    ``sz = sx * sy`` and then dropped improper candidates; whenever the two bases had
    opposite handedness that discarded *all four* PCA seeds and left identity only.
    Computing the sign with tensor ops also avoids a per-candidate host sync.

    Note on symmetric objects: for near-isotropic clouds the PCA axes are arbitrary
    (eigenvalue spread < _PCA_ISOTROPY_THRESH). The PCA seeds are kept anyway: on the
    test sphere (N=800) individual rotations still differ in Chamfer score by ~9mm, so
    the batched trimmed ICP still selects among them meaningfully. Removing PCA seeds
    for isotropic clouds was tested (with a 256-seed Fibonacci grid) and reliably made
    the symmetric result WORSE (8/9→3/9).

    Args:
        src: source points ``(N, 3)``; centred internally.
        tgt: target points ``(M, 3)``; centred internally.

    Returns:
        ``(5, 3, 3)`` rotations on ``src``'s device/dtype: identity first, then the four
        sign-flip matches in ``(sx, sy)`` order ``(+,+), (+,-), (-,+), (-,-)``.
    """
    dev, dt = src.device, src.dtype
    Vs = _pca_axes(src - src.mean(0))
    Vt = _pca_axes(tgt - tgt.mean(0))

    # Relative handedness of the two bases: +1 when they agree, -1 when they differ.
    parity = torch.linalg.det(Vt) * torch.linalg.det(Vs)
    one = torch.ones((), device=dev, dtype=dt)
    handed = torch.where(parity < 0, -one, one)

    # Rows are (sx, sy, sx * sy); the third column then absorbs the handedness.
    flips = torch.tensor(
        [[1.0, 1.0, 1.0], [1.0, -1.0, -1.0], [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]],
        device=dev,
        dtype=dt,
    )
    flips[:, 2] = flips[:, 2] * handed

    R = Vt @ torch.diag_embed(flips) @ Vs.transpose(-2, -1)  # (4, 3, 3), all det = +1
    identity = torch.eye(3, device=dev, dtype=dt).unsqueeze(0)
    return torch.cat([identity, R], dim=0)
160
+
161
+
162
+ # ── closed-form similarity (Umeyama 1991, on-device torch) ────────────────────────────
163
+
164
+
165
+ def _umeyama(src: torch.Tensor, dst: torch.Tensor, with_scale: bool):
166
+ """Least-squares similarity ``dst ≈ s * (src @ R.T) + t`` (Umeyama 1991), all in torch.
167
+
168
+ ``with_scale=False`` fixes ``s = 1`` (rigid Kabsch). Returns ``(s, R, t)`` torch tensors on
169
+ ``src``'s device/dtype. Run in float64 (cast by the caller) for SVD precision.
170
+ """
171
+ n = src.shape[0]
172
+ src_mean = src.mean(0)
173
+ dst_mean = dst.mean(0)
174
+ sc = src - src_mean
175
+ dc = dst - dst_mean
176
+ cov = (dc.transpose(-2, -1) @ sc) / n
177
+ U, D, Vh = torch.linalg.svd(cov)
178
+ S = torch.eye(3, device=src.device, dtype=src.dtype)
179
+ if torch.linalg.det(U) * torch.linalg.det(Vh) < 0: # reflection guard
180
+ S[2, 2] = -1.0
181
+ R = U @ S @ Vh
182
+ if with_scale:
183
+ var_src = sc.pow(2).sum() / n
184
+ s = (D * torch.diagonal(S)).sum() / var_src.clamp_min(1e-12)
185
+ else:
186
+ s = torch.ones((), device=src.device, dtype=src.dtype)
187
+ t = dst_mean - s * (R @ src_mean)
188
+ return s, R, t
189
+
190
+
191
+ # ── batched torch ICP helpers (already on-device) ─────────────────────────────────────
192
+
193
+
194
+ def _stride_subsample(a: torch.Tensor, k: int) -> torch.Tensor:
195
+ """Deterministic strided subsample to <= k rows (no RNG)."""
196
+ if a.shape[0] <= k:
197
+ return a
198
+ sel = torch.linspace(0, a.shape[0] - 1, k, device=a.device).round().to(torch.int64)
199
+ return a[sel]
200
+
201
+
202
+ def _batched_nn(X: torch.Tensor, Y: torch.Tensor):
203
+ """Batched squared-distance nearest neighbour: for each ``X[b,i]`` its nearest ``Y[b]``.
204
+
205
+ Chunked ``torch.cdist``. Returns ``(dist_sq (B,Nx), idx (B,Nx))``; chunks over the query
206
+ axis to bound the ``(B,Nx,Ny)`` pairwise tensor.
207
+ """
208
+ B, Nx, _ = X.shape
209
+ Ny = Y.shape[1]
210
+ out_d = X.new_empty((B, Nx))
211
+ out_i = torch.empty((B, Nx), device=X.device, dtype=torch.int64)
212
+ step = max(1, int(8_000_000 // max(B * Ny, 1)))
213
+ for lo in range(0, Nx, step):
214
+ hi = min(lo + step, Nx)
215
+ d = torch.cdist(X[:, lo:hi], Y)
216
+ md, mi = d.min(dim=2)
217
+ out_d[:, lo:hi] = md * md
218
+ out_i[:, lo:hi] = mi
219
+ return out_d, out_i
220
+
221
+
222
+ def _batched_umeyama(X: torch.Tensor, Y: torch.Tensor, w: torch.Tensor, with_scale: bool):
223
+ """Batched weighted closed-form similarity mapping ``X[b] -> Y[b]`` (Umeyama 1991).
224
+
225
+ ``w (B,N)`` are non-negative weights (the trim mask). Returns ``(s (B,), R (B,3,3),
226
+ t (B,3))`` with ``Y ≈ s * (X @ R.T) + t``.
227
+ """
228
+ B = X.shape[0]
229
+ ws = w.sum(dim=1, keepdim=True).clamp_min(1e-9)
230
+ wn = (w / ws).unsqueeze(-1)
231
+ mu_x = (wn * X).sum(dim=1)
232
+ mu_y = (wn * Y).sum(dim=1)
233
+ Xc = X - mu_x.unsqueeze(1)
234
+ Yc = Y - mu_y.unsqueeze(1)
235
+ cov = torch.einsum("bni,bnj->bij", wn * Yc, Xc)
236
+ U, Dsv, Vh = torch.linalg.svd(cov)
237
+ detUV = torch.linalg.det(U) * torch.linalg.det(Vh)
238
+ S = torch.eye(3, device=X.device, dtype=X.dtype).expand(B, 3, 3).clone()
239
+ S[:, 2, 2] = torch.sign(detUV)
240
+ R = U @ S @ Vh
241
+ if with_scale:
242
+ var_x = (wn * Xc.pow(2)).sum(dim=(1, 2)).clamp_min(1e-12)
243
+ s = (Dsv * torch.diagonal(S, dim1=1, dim2=2)).sum(dim=1) / var_x
244
+ else:
245
+ s = torch.ones(B, device=X.device, dtype=X.dtype)
246
+ t = mu_y - s.unsqueeze(1) * torch.einsum("bij,bj->bi", R, mu_x)
247
+ return s, R, t
248
+
249
+
250
+ def _trim_mean_sqrt(dist_sq: torch.Tensor, frac: float) -> torch.Tensor:
251
+ """Mean of the sqrt of the lowest ``frac`` fraction of squared distances, per batch row."""
252
+ k = max(1, int(round(frac * dist_sq.shape[1])))
253
+ v, _ = torch.sort(dist_sq, dim=1)
254
+ return v[:, :k].sqrt().mean(dim=1)
255
+
256
+
257
def _batched_trimmed_icp(
    src_sub: torch.Tensor, tgt_sub: torch.Tensor, R_seeds: torch.Tensor, with_scale: bool, icp_iters: int
):
    """Run every rotation seed through a batched trimmed point-to-point ICP.

    Per seed: rotate the centred source, optionally rescale it to the target's RMS
    radius, move it onto the target centroid, then iterate nearest-neighbour matching
    plus a trimmed weighted Umeyama update. Each seed is scored by the trimmed
    symmetric Chamfer distance to the target. Seeds are processed in chunks of
    ``_SEED_BATCH`` to bound peak memory.

    Args:
        src_sub: source points ``(K, 3)``.
        tgt_sub: target points ``(Kg, 3)``.
        R_seeds: initial rotations ``(B, 3, 3)``.
        with_scale: allow scale in the initial match and in every Umeyama step.
        icp_iters: number of ICP iterations per seed.

    Returns:
        ``(aligned (B, K, 3), score (B,))`` — lower score is better.
    """
    n_src, n_tgt = src_sub.shape[0], tgt_sub.shape[0]
    src_centroid = src_sub.mean(0)
    tgt_centroid = tgt_sub.mean(0)
    src_centered = src_sub - src_centroid
    tgt_rms = (tgt_sub - tgt_centroid).pow(2).sum(-1).mean().sqrt()
    # Number of correspondences kept per iteration (never fewer than 3).
    n_keep = max(3, int(round(_ICP_TRIM_KEEP * n_src)))
    n_iters = int(icp_iters)

    aligned_parts: list[torch.Tensor] = []
    score_parts: list[torch.Tensor] = []
    for start in range(0, R_seeds.shape[0], _SEED_BATCH):
        seeds = R_seeds[start : start + _SEED_BATCH]
        n_seeds = seeds.shape[0]

        # Initial placement: rotate, optional RMS-radius scale match, shift to target centroid.
        moved = torch.bmm(src_centered.unsqueeze(0).expand(n_seeds, n_src, 3), seeds.transpose(1, 2))
        if with_scale:
            moved_rms = moved.pow(2).sum(-1).mean(1).sqrt()
            moved = moved * (tgt_rms / moved_rms.clamp_min(1e-9)).view(n_seeds, 1, 1)
        moved = moved + tgt_centroid
        ref = tgt_sub.unsqueeze(0).expand(n_seeds, n_tgt, 3).contiguous()

        for _ in range(n_iters):
            sq_dist, nn_idx = _batched_nn(moved, ref)
            matched = torch.gather(ref, 1, nn_idx.unsqueeze(-1).expand(n_seeds, n_src, 3))
            # Trim: keep correspondences at or below the n_keep-th smallest distance.
            cutoff = torch.kthvalue(sq_dist, n_keep, dim=1, keepdim=True).values
            keep_mask = (sq_dist <= cutoff).to(moved.dtype)
            step_s, step_R, step_t = _batched_umeyama(moved, matched, keep_mask, with_scale)
            moved = step_s[:, None, None] * torch.bmm(moved, step_R.transpose(1, 2)) + step_t[:, None, :]

        # Trimmed symmetric Chamfer: average of the two directed trimmed means.
        fwd_sq, _ = _batched_nn(moved, ref)
        bwd_sq, _ = _batched_nn(ref, moved)
        chamfer = 0.5 * (_trim_mean_sqrt(fwd_sq, _ICP_TRIM_KEEP) + _trim_mean_sqrt(bwd_sq, _ICP_TRIM_KEEP))

        aligned_parts.append(moved)
        score_parts.append(chamfer)
    return torch.cat(aligned_parts, dim=0), torch.cat(score_parts, dim=0)
294
+
295
+
296
+ # ── public entry ──────────────────────────────────────────────────────────────────────
297
+
298
+
299
@torch.no_grad()
def global_align(
    target: Gaussians,
    source: Gaussians,
    *,
    transform: str = "sim3",
    n_rotations: int = DEFAULT_N_ROTATIONS,
    icp_iters: int = DEFAULT_ICP_ITERS,
    n_points: int = DEFAULT_N_POINTS,
    seed: int = 0,
) -> torch.Tensor:
    """Coarse global init: a 4x4 transform that lands ``source`` in ``target``'s basin.

    Builds ``n_rotations`` super-Fibonacci SO(3) seeds plus the PCA sign-flip seeds, runs
    them all through the batched trimmed ICP, picks the best-scoring seed (trimmed
    symmetric Chamfer, with a centroid-proximity tie-break among near-best seeds) and
    returns the closed-form similarity mapping the subsampled source onto that winner.
    The result is a *basin-correct init* for the fine LM, not a precise pose. All tensors
    are created on ``source``'s device.

    Args:
        target: the reference splat (``source`` aligns onto it). Only ``.means`` is read.
        source: the splat to align. Only ``.means`` is read.
        transform: ``"sim3"`` (default) estimates a scale; ``"se3"`` fixes scale to 1.
        n_rotations: super-Fibonacci SO(3) seed count (more -> finer covering, slower).
        icp_iters: trimmed-ICP iterations per seed.
        n_points: target subsample size for the fit.
        seed: accepted for API stability; unused (the procedure has no RNG).

    Returns:
        A ``(4, 4)`` ``torch.Tensor`` on ``source``'s device/dtype, applied as
        ``T @ [x, y, z, 1]``: top-left block ``s*R`` (Sim(3)) or ``R`` (SE(3)), last
        column the translation. Identity when either cloud has fewer than 3 points.

    Raises:
        ValueError: if ``transform`` is neither ``"sim3"`` nor ``"se3"``.
    """
    del seed  # deterministic; no RNG to seed (documented for API stability)
    if transform not in ("sim3", "se3"):
        raise ValueError(f"transform must be 'sim3' or 'se3', got {transform!r}")
    with_scale = transform == "sim3"

    src_full = source.means
    dev, dtype = src_full.device, src_full.dtype
    tgt_full = target.means.to(device=dev, dtype=dtype)

    # Too few points to define a pose: fall back to the identity transform.
    if src_full.shape[0] < 3 or tgt_full.shape[0] < 3:
        return torch.eye(4, device=dev, dtype=dtype)

    # Fit in float32 on-device; deterministic strided subsample.
    # NOTE(review): the >= 3 guard above is on the full clouds; an ``n_points`` below 3
    # would still hand the PCA seeding fewer than 3 target rows — confirm callers never do.
    src_sub = _stride_subsample(src_full, _SUB_SOURCE).to(torch.float32)
    tgt_sub = _stride_subsample(tgt_full, int(n_points)).to(torch.float32)

    # Seed order matters for ties: PCA sign-flip candidates first, then the grid.
    pca_seeds = _pca_seed_rotations(src_sub, tgt_sub)  # (<=5,3,3)
    grid_seeds = super_fibonacci_so3(int(n_rotations), device=dev, dtype=torch.float32)  # (n,3,3)
    R_seeds = torch.cat([pca_seeds, grid_seeds], dim=0)

    aligned_sub, scores = _batched_trimmed_icp(src_sub, tgt_sub, R_seeds, with_scale, int(icp_iters))

    # Stability tie-break: among seeds whose score is within the tolerance of the best,
    # prefer the one whose aligned centroid is closest to the target centroid. For
    # symmetric clouds many seeds score nearly equally and a plain first-index argmin can
    # pick one with a poor translation; centroid proximity is a cheap tensor-only proxy
    # that avoids refitting Umeyama per seed.
    # NOTE(review): the tolerance is 5e-3 * max(|best|, 1), i.e. an absolute 5e-3 in
    # score units whenever the best score is below 1 — it does not scale with the scene
    # size; confirm this is intended.
    score_eps = 5e-3
    best_score = scores.min()
    tolerance = score_eps * (best_score.abs().clamp_min(1.0))
    contenders = torch.nonzero(scores <= best_score + tolerance, as_tuple=False).squeeze(1)
    if contenders.shape[0] > 1:
        contender_centroids = aligned_sub[contenders].mean(dim=1)  # (n_contenders, 3)
        offsets = (contender_centroids - tgt_sub.mean(0)).norm(dim=1)  # (n_contenders,)
        best_idx = contenders[offsets.argmin()]
    else:
        best_idx = torch.argmin(scores)
    winner = aligned_sub[best_idx]  # stays on-device

    # Recover the closed-form similarity src_sub -> winner in float64 on-device.
    s_tot, R_tot, t_tot = _umeyama(src_sub.double(), winner.double(), with_scale)
    T = torch.eye(4, dtype=dtype, device=dev)
    T[:3, :3] = (s_tot * R_tot).to(dtype=dtype)
    T[:3, 3] = t_tot.to(dtype=dtype)
    return T