numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,762 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief C++ bindings for mesh-distance kernels.
|
|
3
|
+
* @file include/numkong/mesh.hpp
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 5, 2026
|
|
6
|
+
*/
|
|
7
|
+
#ifndef NK_MESH_HPP
|
|
8
|
+
#define NK_MESH_HPP
|
|
9
|
+
|
|
10
|
+
#include <cstdint>
|
|
11
|
+
#include <type_traits>
|
|
12
|
+
#include <utility>
|
|
13
|
+
|
|
14
|
+
#include "numkong/mesh.h"
|
|
15
|
+
|
|
16
|
+
#include "numkong/types.hpp"
|
|
17
|
+
|
|
18
|
+
namespace ashvardanian::numkong {
|
|
19
|
+
|
|
20
|
+
#pragma region - SVD Helpers for Scalar Fallbacks
|
|
21
|
+
|
|
22
|
+
/**
 *  @brief Determinant of a 3x3 matrix stored as 9 contiguous values.
 *
 *  Cofactor expansion along the first stored row. The determinant is invariant under transposition,
 *  so the result is the same whether `m` is interpreted as row-major or column-major.
 *  @param[in] m Pointer to 9 matrix entries.
 *  @return The determinant.
 */
template <typename scalar_type_>
scalar_type_ det3x3_(scalar_type_ const *m) {
    scalar_type_ minor_00 = m[4] * m[8] - m[5] * m[7];
    scalar_type_ minor_01 = m[3] * m[8] - m[5] * m[6];
    scalar_type_ minor_02 = m[3] * m[7] - m[4] * m[6];
    return m[0] * minor_00 - m[1] * minor_01 + m[2] * minor_02;
}
|
|
27
|
+
|
|
28
|
+
/**
 *  @brief Exchanges `*x` and `*y` when `c` is true; leaves both unchanged otherwise.
 *
 *  Written as two selects rather than a branch, mirroring the branch-minimal style of the SVD helpers.
 */
template <typename scalar_type_>
void conditional_swap_(bool c, scalar_type_ *x, scalar_type_ *y) {
    scalar_type_ old_x = *x, old_y = *y;
    *x = c ? old_y : old_x;
    *y = c ? old_x : old_y;
}
|
|
35
|
+
|
|
36
|
+
/**
 *  @brief When `c` is true, sets `*x = *y` and `*y = -*x` (using the old `*x`); otherwise no change.
 *
 *  Swapping two columns flips a determinant's sign and negating one flips it back, so applying this
 *  to a column pair preserves the determinant. Negation is spelled `0 - v`, which maps +0 to +0.
 */
template <typename scalar_type_>
void conditional_negating_swap_(bool c, scalar_type_ *x, scalar_type_ *y) {
    scalar_type_ old_x = *x, old_y = *y;
    scalar_type_ negated_old_x = scalar_type_(0.0) - old_x;
    *x = c ? old_y : old_x;
    *y = c ? negated_old_x : old_y;
}
|
|
43
|
+
|
|
44
|
+
/**
 *  @brief Approximate Givens rotation for one Jacobi step, returned as half-angle cosine/sine.
 *
 *  Candidate half-angle pair is (2 * (a11 - a22), a12). It is accepted (after normalization) only when
 *  gamma * sh^2 < ch^2, i.e. |sh / ch| < 1 / sqrt(gamma) = sqrt2 - 1 = tan(pi/8); otherwise the fixed
 *  pi/8 rotation (cos(pi/8), sin(pi/8)) is used instead.
 *  @param[in] a11,a12,a22 Entries of the symmetric 2x2 sub-problem.
 *  @param[out] cos_half Half-angle cosine.
 *  @param[out] sin_half Half-angle sine.
 */
template <typename scalar_type_>
void approximate_givens_quaternion_(scalar_type_ a11, scalar_type_ a12, scalar_type_ a22, scalar_type_ *cos_half,
                                    scalar_type_ *sin_half) {
    constexpr scalar_type_ gamma_k = scalar_type_(5.828427124746190);  // sqrt8 + 3 = 3 + 2 * sqrt2 = (1 + sqrt2)^2
    constexpr scalar_type_ cstar_k = scalar_type_(0.9238795325112867); // cos(pi/8)
    constexpr scalar_type_ sstar_k = scalar_type_(0.3826834323650898); // sin(pi/8)

    // Unnormalized half-angle candidates.
    scalar_type_ ch = scalar_type_(2.0) * (a11 - a22);
    scalar_type_ sh = a12;

    bool const within_range = gamma_k * sh * sh < ch * ch;
    if (!within_range) {
        *cos_half = cstar_k;
        *sin_half = sstar_k;
        return;
    }

    // Scale (ch, sh) to unit length; ch^2 > 0 is guaranteed on this path.
    scalar_type_ inverse_length = (ch * ch + sh * sh).rsqrt();
    *cos_half = inverse_length * ch;
    *sin_half = inverse_length * sh;
}
|
|
59
|
+
|
|
60
|
+
/**
 *  @brief One Jacobi conjugation step on a symmetric 3x3 matrix, folding the rotation into a quaternion.
 *
 *  Calls `approximate_givens_quaternion_` on the (s11, s21, s22) block, applies the resulting plane
 *  rotation to the matrix entries, right-multiplies the accumulated quaternion by that rotation, and
 *  finally relabels the entries cyclically so the next call acts on a different coordinate pair.
 *
 *  @param[in] idx_x,idx_y,idx_z Indices (0..2) into the vector part of `quat`; `idx_z` is the rotation axis.
 *  @param[in,out] s11,s21,s22,s31,s32,s33 The six distinct entries of the symmetric matrix, updated in place.
 *  @param[in,out] quat Accumulated rotation stored as (x, y, z, w), scalar part in `quat[3]`.
 */
template <typename scalar_type_>
void jacobi_conjugation_(int idx_x, int idx_y, int idx_z, scalar_type_ *s11, scalar_type_ *s21, scalar_type_ *s22,
                         scalar_type_ *s31, scalar_type_ *s32, scalar_type_ *s33, scalar_type_ *quat) {

    scalar_type_ cos_half, sin_half;
    approximate_givens_quaternion_(*s11, *s21, *s22, &cos_half, &sin_half);
    // Full-angle cos/sin from the half-angle pair via double-angle formulas; dividing by `scale`
    // keeps the result a unit rotation even if (cos_half, sin_half) is not exactly normalized.
    scalar_type_ scale = cos_half * cos_half + sin_half * sin_half;
    scalar_type_ cos_theta = (cos_half * cos_half - sin_half * sin_half) / scale;
    scalar_type_ sin_theta = (scalar_type_(2.0) * sin_half * cos_half) / scale;
    // Snapshot inputs: every update below reads only pre-rotation values.
    scalar_type_ s11_old = *s11, s21_old = *s21, s22_old = *s22;
    scalar_type_ s31_old = *s31, s32_old = *s32, s33_old = *s33;

    // S <- Qt * S * Q with Q = [cos -sin; sin cos] acting on the first two coordinates.
    // The third diagonal entry is untouched; (s31, s32) rotate like a 2-vector.
    *s11 = cos_theta * (cos_theta * s11_old + sin_theta * s21_old) +
           sin_theta * (cos_theta * s21_old + sin_theta * s22_old);
    *s21 = cos_theta * ((scalar_type_(0.0) - sin_theta) * s11_old + cos_theta * s21_old) +
           sin_theta * ((scalar_type_(0.0) - sin_theta) * s21_old + cos_theta * s22_old);
    *s22 = (scalar_type_(0.0) - sin_theta) * ((scalar_type_(0.0) - sin_theta) * s11_old + cos_theta * s21_old) +
           cos_theta * ((scalar_type_(0.0) - sin_theta) * s21_old + cos_theta * s22_old);
    *s31 = cos_theta * s31_old + sin_theta * s32_old;
    *s32 = (scalar_type_(0.0) - sin_theta) * s31_old + cos_theta * s32_old;
    *s33 = s33_old;

    // Update quaternion accumulator: quat <- quat * (w = cos_half, sin_half along axis idx_z).
    // quat_temp holds the unscaled vector part times sin_half; sin_half is reused for sin_half * w.
    scalar_type_ quat_temp[3];
    quat_temp[0] = quat[0] * sin_half;
    quat_temp[1] = quat[1] * sin_half;
    quat_temp[2] = quat[2] * sin_half;
    sin_half = sin_half * quat[3];
    quat[0] = quat[0] * cos_half;
    quat[1] = quat[1] * cos_half;
    quat[2] = quat[2] * cos_half;
    quat[3] = quat[3] * cos_half;
    quat[idx_z] = quat[idx_z] + sin_half;
    quat[3] = quat[3] - quat_temp[idx_z];
    quat[idx_x] = quat[idx_x] + quat_temp[idx_y];
    quat[idx_y] = quat[idx_y] - quat_temp[idx_x];
    // Cyclic permutation of matrix elements: new coordinates (1, 2, 3) are the old (2, 3, 1).
    s11_old = *s22, s21_old = *s32, s22_old = *s33, s31_old = *s21, s32_old = *s31, s33_old = *s11;
    *s11 = s11_old, *s21 = s21_old, *s22 = s22_old, *s31 = s31_old, *s32 = s32_old, *s33 = s33_old;
}
|
|
101
|
+
|
|
102
|
+
/**
 *  @brief Builds the row-major 3x3 rotation matrix of a unit quaternion.
 *
 *  @param[in] quat Quaternion stored as (x, y, z, w): vector part first, scalar part in `quat[3]`.
 *  @param[out] matrix Nine row-major entries.
 */
template <typename scalar_type_>
void quaternion_to_mat3x3_(scalar_type_ const *quat, scalar_type_ *matrix) {
    scalar_type_ qx = quat[0], qy = quat[1], qz = quat[2], qw = quat[3];

    // Pairwise products shared by several entries.
    scalar_type_ xx = qx * qx, yy = qy * qy, zz = qz * qz;
    scalar_type_ xy = qx * qy, xz = qx * qz, yz = qy * qz;
    scalar_type_ wx = qw * qx, wy = qw * qy, wz = qw * qz;

    // Row 0
    matrix[0] = scalar_type_(1.0) - scalar_type_(2.0) * (yy + zz);
    matrix[1] = scalar_type_(2.0) * (xy - wz);
    matrix[2] = scalar_type_(2.0) * (xz + wy);
    // Row 1
    matrix[3] = scalar_type_(2.0) * (xy + wz);
    matrix[4] = scalar_type_(1.0) - scalar_type_(2.0) * (xx + zz);
    matrix[5] = scalar_type_(2.0) * (yz - wx);
    // Row 2
    matrix[6] = scalar_type_(2.0) * (xz - wy);
    matrix[7] = scalar_type_(2.0) * (yz + wx);
    matrix[8] = scalar_type_(1.0) - scalar_type_(2.0) * (xx + yy);
}
|
|
119
|
+
|
|
120
|
+
/**
 *  @brief Jacobi eigenanalysis of a symmetric 3x3 matrix, accumulating the rotation as a quaternion.
 *
 *  Runs a fixed 16 sweeps of three `jacobi_conjugation_` calls each, then normalizes the quaternion.
 *  @param[in,out] s11,s21,s22,s31,s32,s33 The six distinct entries of the symmetric matrix, overwritten
 *                 in place by the conjugation sweeps.
 *  @param[out] quat Accumulated rotation stored as (x, y, z, w), scaled to unit length on return.
 */
template <typename scalar_type_>
void jacobi_eigenanalysis_(scalar_type_ *s11, scalar_type_ *s21, scalar_type_ *s22, scalar_type_ *s31,
                           scalar_type_ *s32, scalar_type_ *s33, scalar_type_ *quat) {
    // Identity rotation: zero vector part, unit scalar part in the last slot.
    for (int component = 0; component < 3; ++component) quat[component] = scalar_type_(0.0);
    quat[3] = scalar_type_(1.0);

    constexpr unsigned sweeps_k = 16;
    for (unsigned sweep = 0; sweep < sweeps_k; ++sweep) {
        // Each call cyclically relabels the matrix entries, so the (x, y, z) index triple is rotated
        // in step to keep the quaternion axis aligned with the current coordinate pair.
        jacobi_conjugation_(0, 1, 2, s11, s21, s22, s31, s32, s33, quat);
        jacobi_conjugation_(1, 2, 0, s11, s21, s22, s31, s32, s33, quat);
        jacobi_conjugation_(2, 0, 1, s11, s21, s22, s31, s32, s33, quat);
    }

    // Scale the accumulated quaternion to unit length.
    scalar_type_ inverse_length = (quat[0] * quat[0] + quat[1] * quat[1] + quat[2] * quat[2] + quat[3] * quat[3]).rsqrt();
    for (int component = 0; component < 4; ++component) quat[component] = quat[component] * inverse_length;
}
|
|
140
|
+
|
|
141
|
+
/**
 *  @brief Half-angle Givens rotation that annihilates `a2` against the pivot `a1` (McAdams QR step).
 *
 *  Rotating (a1, a2) by theta, where cos(theta) = 1 - 2 * sin_half^2 and sin(theta) = 2 * cos_half * sin_half,
 *  maps the second component to zero.
 *  @param[in] a1 Pivot entry on the diagonal.
 *  @param[in] a2 Entry below the diagonal to annihilate.
 *  @param[out] cos_half Half-angle cosine.
 *  @param[out] sin_half Half-angle sine.
 */
template <typename scalar_type_>
void qr_givens_quaternion_(scalar_type_ a1, scalar_type_ a2, scalar_type_ *cos_half, scalar_type_ *sin_half) {
    constexpr scalar_type_ epsilon_k = scalar_type_(1e-12);

    // Exact norm of (a1, a2). The threshold below is applied to `rho` itself, not to its square:
    // comparing a1^2 + a2^2 against epsilon would zero `rho` for any |a| below ~1e-6 and skip the rotation,
    // leaving small-magnitude columns un-annihilated. sqrt(0) == 0 also avoids the 0 * inf of x * rsqrt(x).
    scalar_type_ rho = (a1 * a1 + a2 * a2).sqrt();

    scalar_type_ abs_a1 = a1 < scalar_type_(0.0) ? (scalar_type_(0.0) - a1) : a1;
    scalar_type_ max_rho = rho > epsilon_k ? rho : epsilon_k;
    scalar_type_ unswapped_sin = rho > epsilon_k ? a2 : scalar_type_(0.0);
    scalar_type_ unswapped_cos = abs_a1 + max_rho;

    // Negative pivot: exchange the two components, as in the McAdams construction.
    bool const pivot_negative = a1 < scalar_type_(0.0);
    scalar_type_ ch = pivot_negative ? unswapped_sin : unswapped_cos;
    scalar_type_ sh = pivot_negative ? unswapped_cos : unswapped_sin;

    scalar_type_ inverse_length = (ch * ch + sh * sh).rsqrt();
    *cos_half = ch * inverse_length;
    *sin_half = sh * inverse_length;
}
|
|
159
|
+
|
|
160
|
+
/**
 *  @brief Reorders the columns of `b` (and matching columns of `v`) by descending squared column norm of `b`.
 *
 *  Uses the 3-element compare-exchange network (1,2), (1,3), (2,3). Each exchange also negates one of the
 *  two columns, which keeps det(b) and det(v) unchanged.
 *  @param[in,out] b Row-major 3x3 matrix whose column norms define the order.
 *  @param[in,out] v Row-major 3x3 matrix whose columns are permuted in lockstep with `b`.
 */
template <typename scalar_type_>
void sort_singular_values_(scalar_type_ *b, scalar_type_ *v) {
    scalar_type_ rho1 = b[0] * b[0] + b[3] * b[3] + b[6] * b[6];
    scalar_type_ rho2 = b[1] * b[1] + b[4] * b[4] + b[7] * b[7];
    scalar_type_ rho3 = b[2] * b[2] + b[5] * b[5] + b[8] * b[8];

    // Conditionally exchange columns `lhs` and `rhs` of both matrices, row by row.
    auto swap_columns = [b, v](bool swap, int lhs, int rhs) {
        for (int row = 0; row < 9; row += 3) {
            conditional_negating_swap_(swap, &b[row + lhs], &b[row + rhs]);
            conditional_negating_swap_(swap, &v[row + lhs], &v[row + rhs]);
        }
    };

    bool const swap_first_second = rho1 < rho2;
    swap_columns(swap_first_second, 0, 1);
    conditional_swap_(swap_first_second, &rho1, &rho2);

    bool const swap_first_third = rho1 < rho3;
    swap_columns(swap_first_third, 0, 2);
    conditional_swap_(swap_first_third, &rho1, &rho3);

    // Largest norm is now in column 0; one more exchange orders the remaining two.
    bool const swap_second_third = rho2 < rho3;
    swap_columns(swap_second_third, 1, 2);
}
|
|
192
|
+
|
|
193
|
+
/**
 *  @brief QR decomposition of a row-major 3x3 matrix via three Givens rotations (McAdams scheme).
 *
 *  The rotations target entries (1,0), (2,0) and (2,1) in that order. `q` is assembled in closed form as
 *  Q1 * Q2 * Q3 from the three half-angle pairs. The targeted entries r[3], r[6], r[7] are computed by the
 *  rotation formulas rather than forced to zero, so they hold whatever residual the rotations leave.
 *  @param[in] input Row-major 3x3 matrix (9 values); not modified.
 *  @param[out] q Row-major 3x3 orthogonal factor (9 values).
 *  @param[out] r Row-major 3x3 upper-triangular factor (9 values).
 */
template <typename scalar_type_>
void qr_decomposition_(scalar_type_ const *input, scalar_type_ *q, scalar_type_ *r) {
    scalar_type_ cos_half_1, sin_half_1;
    scalar_type_ cos_half_2, sin_half_2;
    scalar_type_ cos_half_3, sin_half_3;
    scalar_type_ cos_theta, sin_theta;
    scalar_type_ rotation_temp[9], matrix_temp[9];
    // First Givens rotation (zero input[3]); mixes rows 0 and 1.
    // Full-angle cos/sin from the half-angle pair: cos = 1 - 2 sin_half^2, sin = 2 cos_half sin_half.
    qr_givens_quaternion_(input[0], input[3], &cos_half_1, &sin_half_1);
    cos_theta = scalar_type_(1.0) - scalar_type_(2.0) * sin_half_1 * sin_half_1;
    sin_theta = scalar_type_(2.0) * cos_half_1 * sin_half_1;
    rotation_temp[0] = cos_theta * input[0] + sin_theta * input[3];
    rotation_temp[1] = cos_theta * input[1] + sin_theta * input[4];
    rotation_temp[2] = cos_theta * input[2] + sin_theta * input[5];
    rotation_temp[3] = (scalar_type_(0.0) - sin_theta) * input[0] + cos_theta * input[3];
    rotation_temp[4] = (scalar_type_(0.0) - sin_theta) * input[1] + cos_theta * input[4];
    rotation_temp[5] = (scalar_type_(0.0) - sin_theta) * input[2] + cos_theta * input[5];
    rotation_temp[6] = input[6];
    rotation_temp[7] = input[7];
    rotation_temp[8] = input[8];
    // Second Givens rotation (zero rotation_temp[6]); mixes rows 0 and 2.
    qr_givens_quaternion_(rotation_temp[0], rotation_temp[6], &cos_half_2, &sin_half_2);
    cos_theta = scalar_type_(1.0) - scalar_type_(2.0) * sin_half_2 * sin_half_2;
    sin_theta = scalar_type_(2.0) * cos_half_2 * sin_half_2;
    matrix_temp[0] = cos_theta * rotation_temp[0] + sin_theta * rotation_temp[6];
    matrix_temp[1] = cos_theta * rotation_temp[1] + sin_theta * rotation_temp[7];
    matrix_temp[2] = cos_theta * rotation_temp[2] + sin_theta * rotation_temp[8];
    matrix_temp[3] = rotation_temp[3];
    matrix_temp[4] = rotation_temp[4];
    matrix_temp[5] = rotation_temp[5];
    matrix_temp[6] = (scalar_type_(0.0) - sin_theta) * rotation_temp[0] + cos_theta * rotation_temp[6];
    matrix_temp[7] = (scalar_type_(0.0) - sin_theta) * rotation_temp[1] + cos_theta * rotation_temp[7];
    matrix_temp[8] = (scalar_type_(0.0) - sin_theta) * rotation_temp[2] + cos_theta * rotation_temp[8];
    // Third Givens rotation (zero matrix_temp[7]); mixes rows 1 and 2, row 0 is final.
    qr_givens_quaternion_(matrix_temp[4], matrix_temp[7], &cos_half_3, &sin_half_3);
    cos_theta = scalar_type_(1.0) - scalar_type_(2.0) * sin_half_3 * sin_half_3;
    sin_theta = scalar_type_(2.0) * cos_half_3 * sin_half_3;
    r[0] = matrix_temp[0];
    r[1] = matrix_temp[1];
    r[2] = matrix_temp[2];
    r[3] = cos_theta * matrix_temp[3] + sin_theta * matrix_temp[6];
    r[4] = cos_theta * matrix_temp[4] + sin_theta * matrix_temp[7];
    r[5] = cos_theta * matrix_temp[5] + sin_theta * matrix_temp[8];
    r[6] = (scalar_type_(0.0) - sin_theta) * matrix_temp[3] + cos_theta * matrix_temp[6];
    r[7] = (scalar_type_(0.0) - sin_theta) * matrix_temp[4] + cos_theta * matrix_temp[7];
    r[8] = (scalar_type_(0.0) - sin_theta) * matrix_temp[5] + cos_theta * matrix_temp[8];
    // Construct Q = Q1 * Q2 * Q3 (closed-form expressions)
    // Note: (-1 + 2 * sin_half_k^2) equals -cos(theta_k) for each rotation k.
    scalar_type_ sin_half_1_sq = sin_half_1 * sin_half_1;
    scalar_type_ sin_half_2_sq = sin_half_2 * sin_half_2;
    scalar_type_ sin_half_3_sq = sin_half_3 * sin_half_3;
    q[0] = (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_1_sq) *
           (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_2_sq);
    q[1] = scalar_type_(4.0) * cos_half_2 * cos_half_3 * (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_1_sq) *
               sin_half_2 * sin_half_3 +
           scalar_type_(2.0) * cos_half_1 * sin_half_1 * (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_3_sq);
    q[2] = scalar_type_(4.0) * cos_half_1 * cos_half_3 * sin_half_1 * sin_half_3 -
           scalar_type_(2.0) * cos_half_2 * (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_1_sq) * sin_half_2 *
               (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_3_sq);
    q[3] = scalar_type_(2.0) * cos_half_1 * sin_half_1 * (scalar_type_(1.0) - scalar_type_(2.0) * sin_half_2_sq);
    q[4] = scalar_type_(-8.0) * cos_half_1 * cos_half_2 * cos_half_3 * sin_half_1 * sin_half_2 * sin_half_3 +
           (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_1_sq) *
               (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_3_sq);
    q[5] = scalar_type_(-2.0) * cos_half_3 * sin_half_3 +
           scalar_type_(4.0) * sin_half_1 *
               (cos_half_3 * sin_half_1 * sin_half_3 +
                cos_half_1 * cos_half_2 * sin_half_2 * (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_3_sq));
    q[6] = scalar_type_(2.0) * cos_half_2 * sin_half_2;
    q[7] = scalar_type_(2.0) * cos_half_3 * (scalar_type_(1.0) - scalar_type_(2.0) * sin_half_2_sq) * sin_half_3;
    q[8] = (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_2_sq) *
           (scalar_type_(-1.0) + scalar_type_(2.0) * sin_half_3_sq);
}
|
|
265
|
+
|
|
266
|
+
/**
 *  @brief 3x3 SVD, A = U * S * Vt, following the McAdams et al. minimal-branching scheme.
 *
 *  All matrices are row-major arrays of 9 values.
 *  @param[in] a Input matrix A.
 *  @param[out] svd_u U, taken as the Q factor of a QR decomposition of B = A * V (after column sorting).
 *  @param[out] svd_s Diagonal S holding the column norms of the sorted B in descending order; off-diagonals are 0.
 *  @param[out] svd_v V, from Jacobi eigenanalysis of At * A, with columns reordered alongside B.
 *
 *  NOTE(review): the R factor of the QR step is discarded and svd_s holds only non-negative norms, so any
 *  sign information for reflections (det(A) < 0) is not visible in svd_s — confirm callers handle that case.
 */
template <typename scalar_type_>
void svd3x3_(scalar_type_ const *a, scalar_type_ *svd_u, scalar_type_ *svd_s, scalar_type_ *svd_v) {
    // Gram matrix At * A: compute the upper triangle and mirror it.
    scalar_type_ ata[9];
    for (int row = 0; row < 3; ++row) {
        for (int col = row; col < 3; ++col) {
            scalar_type_ const dot = a[row] * a[col] + a[3 + row] * a[3 + col] + a[6 + row] * a[6 + col];
            ata[row * 3 + col] = dot;
            ata[col * 3 + row] = dot;
        }
    }

    // Jacobi eigenanalysis of At * A yields the rotation V as a quaternion.
    scalar_type_ quat[4];
    jacobi_eigenanalysis_(&ata[0], &ata[1], &ata[4], &ata[2], &ata[5], &ata[8], quat);
    quaternion_to_mat3x3_(quat, svd_v);

    // B = A * V
    scalar_type_ product[9];
    for (int row = 0; row < 3; ++row) {
        scalar_type_ const *a_row = a + row * 3;
        for (int col = 0; col < 3; ++col)
            product[row * 3 + col] = a_row[0] * svd_v[col] + a_row[1] * svd_v[3 + col] + a_row[2] * svd_v[6 + col];
    }

    // Order columns of B (and V) by descending norm.
    sort_singular_values_(product, svd_v);

    // Squared column norms of the sorted B; QR below only reads `product`.
    scalar_type_ column_norm_sq[3];
    for (int col = 0; col < 3; ++col)
        column_norm_sq[col] =
            product[col] * product[col] + product[3 + col] * product[3 + col] + product[6 + col] * product[6 + col];

    // QR decomposition: B = U * R; only Q is kept (as U).
    scalar_type_ qr_r[9];
    qr_decomposition_(product, svd_u, qr_r);

    // S is diagonal.
    for (int entry = 0; entry < 9; ++entry) svd_s[entry] = scalar_type_(0.0);
    svd_s[0] = column_norm_sq[0].sqrt();
    svd_s[4] = column_norm_sq[1].sqrt();
    svd_s[8] = column_norm_sq[2].sqrt();
}
|
|
315
|
+
|
|
316
|
+
#pragma endregion - SVD Helpers for Scalar Fallbacks
|
|
317
|
+
|
|
318
|
+
#pragma region - Mesh Alignment Kernels
|
|
319
|
+
|
|
320
|
+
/**
|
|
321
|
+
* @brief Root Mean Square Deviation between two 3D point clouds (no alignment)
|
|
322
|
+
* @param[in] a,b Point clouds [d x 3] interleaved (x0,y0,z0, x1,y1,z1, ...)
|
|
323
|
+
* @param[in] d Number of 3D points
|
|
324
|
+
* @param[out] a_centroid,b_centroid Centroids (3 values each), can be nullptr
|
|
325
|
+
* @param[out] rotation 3x3 rotation matrix (9 values), always identity, can be nullptr
|
|
326
|
+
* @param[out] scale Scale factor, always 1.0, can be nullptr
|
|
327
|
+
* @param[out] metric Output RMSD value
|
|
328
|
+
*
|
|
329
|
+
* @tparam in_type_ Input point type (f32_t, f64_t, f16_t, bf16_t)
|
|
330
|
+
* @tparam transform_type_ Type of centroids, rotation, and scale outputs
|
|
331
|
+
* @tparam metric_type_ Type of the scalar fit metric output
|
|
332
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
333
|
+
*/
|
|
334
|
+
template <typename in_type_, typename transform_type_ = typename in_type_::mesh_transform_t,
|
|
335
|
+
typename metric_type_ = typename in_type_::mesh_metric_t,
|
|
336
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
337
|
+
void rmsd( //
|
|
338
|
+
in_type_ const *a, in_type_ const *b, std::size_t n, //
|
|
339
|
+
transform_type_ *a_centroid, transform_type_ *b_centroid, transform_type_ *rotation, transform_type_ *scale,
|
|
340
|
+
metric_type_ *metric) noexcept {
|
|
341
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
342
|
+
std::is_same_v<transform_type_, typename in_type_::mesh_transform_t> &&
|
|
343
|
+
std::is_same_v<metric_type_, typename in_type_::mesh_metric_t>;
|
|
344
|
+
|
|
345
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd)
|
|
346
|
+
nk_rmsd_f64(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
347
|
+
&metric->raw_);
|
|
348
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
|
|
349
|
+
nk_rmsd_f32(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
350
|
+
&metric->raw_);
|
|
351
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
|
|
352
|
+
nk_rmsd_f16(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
353
|
+
&metric->raw_);
|
|
354
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
|
|
355
|
+
nk_rmsd_bf16(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_,
|
|
356
|
+
scale ? &scale->raw_ : nullptr, &metric->raw_);
|
|
357
|
+
// Scalar fallback
|
|
358
|
+
else {
|
|
359
|
+
// Step 1: Compute centroids
|
|
360
|
+
metric_type_ sum_a_x {}, sum_a_y {}, sum_a_z {};
|
|
361
|
+
metric_type_ sum_b_x {}, sum_b_y {}, sum_b_z {};
|
|
362
|
+
metric_type_ val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z;
|
|
363
|
+
|
|
364
|
+
for (std::size_t i = 0; i < n; i++) {
|
|
365
|
+
val_a_x = metric_type_(a[i * 3 + 0]);
|
|
366
|
+
val_a_y = metric_type_(a[i * 3 + 1]);
|
|
367
|
+
val_a_z = metric_type_(a[i * 3 + 2]);
|
|
368
|
+
val_b_x = metric_type_(b[i * 3 + 0]);
|
|
369
|
+
val_b_y = metric_type_(b[i * 3 + 1]);
|
|
370
|
+
val_b_z = metric_type_(b[i * 3 + 2]);
|
|
371
|
+
sum_a_x = sum_a_x + val_a_x;
|
|
372
|
+
sum_a_y = sum_a_y + val_a_y;
|
|
373
|
+
sum_a_z = sum_a_z + val_a_z;
|
|
374
|
+
sum_b_x = sum_b_x + val_b_x;
|
|
375
|
+
sum_b_y = sum_b_y + val_b_y;
|
|
376
|
+
sum_b_z = sum_b_z + val_b_z;
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
metric_type_ inv_n = metric_type_(1.0) / metric_type_(static_cast<double>(n));
|
|
380
|
+
metric_type_ centroid_a_x = sum_a_x * inv_n;
|
|
381
|
+
metric_type_ centroid_a_y = sum_a_y * inv_n;
|
|
382
|
+
metric_type_ centroid_a_z = sum_a_z * inv_n;
|
|
383
|
+
metric_type_ centroid_b_x = sum_b_x * inv_n;
|
|
384
|
+
metric_type_ centroid_b_y = sum_b_y * inv_n;
|
|
385
|
+
metric_type_ centroid_b_z = sum_b_z * inv_n;
|
|
386
|
+
|
|
387
|
+
// Step 2: Store centroids if requested
|
|
388
|
+
if (a_centroid)
|
|
389
|
+
a_centroid[0] = transform_type_(centroid_a_x), a_centroid[1] = transform_type_(centroid_a_y),
|
|
390
|
+
a_centroid[2] = transform_type_(centroid_a_z);
|
|
391
|
+
if (b_centroid)
|
|
392
|
+
b_centroid[0] = transform_type_(centroid_b_x), b_centroid[1] = transform_type_(centroid_b_y),
|
|
393
|
+
b_centroid[2] = transform_type_(centroid_b_z);
|
|
394
|
+
|
|
395
|
+
// Step 3: RMSD uses identity rotation and scale=1.0
|
|
396
|
+
if (rotation) {
|
|
397
|
+
rotation[0] = transform_type_(1.0);
|
|
398
|
+
rotation[1] = transform_type_(0.0);
|
|
399
|
+
rotation[2] = transform_type_(0.0);
|
|
400
|
+
rotation[3] = transform_type_(0.0);
|
|
401
|
+
rotation[4] = transform_type_(1.0);
|
|
402
|
+
rotation[5] = transform_type_(0.0);
|
|
403
|
+
rotation[6] = transform_type_(0.0);
|
|
404
|
+
rotation[7] = transform_type_(0.0);
|
|
405
|
+
rotation[8] = transform_type_(1.0);
|
|
406
|
+
}
|
|
407
|
+
if (scale) *scale = transform_type_(1.0);
|
|
408
|
+
|
|
409
|
+
// Step 4: Compute RMSD between centered point clouds
|
|
410
|
+
metric_type_ sum_squared {};
|
|
411
|
+
for (std::size_t i = 0; i < n; i++) {
|
|
412
|
+
val_a_x = metric_type_(a[i * 3 + 0]);
|
|
413
|
+
val_a_y = metric_type_(a[i * 3 + 1]);
|
|
414
|
+
val_a_z = metric_type_(a[i * 3 + 2]);
|
|
415
|
+
val_b_x = metric_type_(b[i * 3 + 0]);
|
|
416
|
+
val_b_y = metric_type_(b[i * 3 + 1]);
|
|
417
|
+
val_b_z = metric_type_(b[i * 3 + 2]);
|
|
418
|
+
metric_type_ dx = (val_a_x - centroid_a_x) - (val_b_x - centroid_b_x);
|
|
419
|
+
metric_type_ dy = (val_a_y - centroid_a_y) - (val_b_y - centroid_b_y);
|
|
420
|
+
metric_type_ dz = (val_a_z - centroid_a_z) - (val_b_z - centroid_b_z);
|
|
421
|
+
sum_squared = sum_squared + dx * dx + dy * dy + dz * dz;
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
*metric = (sum_squared * inv_n).sqrt();
|
|
425
|
+
}
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
/**
|
|
429
|
+
* @brief Kabsch algorithm: min ‖P − R × Q‖² over rotation R ∈ SO(3)
|
|
430
|
+
* @param[in] a,b Point clouds [n x 3] interleaved (source and target)
|
|
431
|
+
* @param[in] n Number of 3D points
|
|
432
|
+
* @param[out] a_centroid,b_centroid Centroids (3 values each), can be nullptr
|
|
433
|
+
* @param[out] rotation 3x3 rotation matrix (9 values, row-major), can be nullptr
|
|
434
|
+
* @param[out] scale Scale factor, always 1.0 for Kabsch, can be nullptr
|
|
435
|
+
* @param[out] metric Output RMSD after optimal rotation
|
|
436
|
+
*
|
|
437
|
+
* @tparam in_type_ Input point type (f32_t, f64_t, f16_t, bf16_t)
|
|
438
|
+
* @tparam transform_type_ Type of centroids, rotation, and scale outputs
|
|
439
|
+
* @tparam metric_type_ Type of the scalar fit metric output
|
|
440
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
441
|
+
*/
|
|
442
|
+
template <typename in_type_, typename transform_type_ = typename in_type_::mesh_transform_t,
|
|
443
|
+
typename metric_type_ = typename in_type_::mesh_metric_t,
|
|
444
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
445
|
+
void kabsch( //
|
|
446
|
+
in_type_ const *a, in_type_ const *b, std::size_t n, //
|
|
447
|
+
transform_type_ *a_centroid, transform_type_ *b_centroid, transform_type_ *rotation, transform_type_ *scale,
|
|
448
|
+
metric_type_ *metric) noexcept {
|
|
449
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
450
|
+
std::is_same_v<transform_type_, typename in_type_::mesh_transform_t> &&
|
|
451
|
+
std::is_same_v<metric_type_, typename in_type_::mesh_metric_t>;
|
|
452
|
+
|
|
453
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd)
|
|
454
|
+
nk_kabsch_f64(&a->raw_, &b->raw_, n, a_centroid ? &a_centroid->raw_ : nullptr, &b_centroid->raw_,
|
|
455
|
+
&rotation->raw_, &scale->raw_, &metric->raw_);
|
|
456
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
|
|
457
|
+
nk_kabsch_f32(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
458
|
+
&metric->raw_);
|
|
459
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
|
|
460
|
+
nk_kabsch_f16(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
461
|
+
&metric->raw_);
|
|
462
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
|
|
463
|
+
nk_kabsch_bf16(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
464
|
+
&metric->raw_);
|
|
465
|
+
// Scalar fallback
|
|
466
|
+
else {
|
|
467
|
+
// Step 1: Compute centroids
|
|
468
|
+
metric_type_ sum_a_x {}, sum_a_y {}, sum_a_z {};
|
|
469
|
+
metric_type_ sum_b_x {}, sum_b_y {}, sum_b_z {};
|
|
470
|
+
metric_type_ val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z;
|
|
471
|
+
|
|
472
|
+
for (std::size_t i = 0; i < n; i++) {
|
|
473
|
+
val_a_x = metric_type_(a[i * 3 + 0]);
|
|
474
|
+
val_a_y = metric_type_(a[i * 3 + 1]);
|
|
475
|
+
val_a_z = metric_type_(a[i * 3 + 2]);
|
|
476
|
+
val_b_x = metric_type_(b[i * 3 + 0]);
|
|
477
|
+
val_b_y = metric_type_(b[i * 3 + 1]);
|
|
478
|
+
val_b_z = metric_type_(b[i * 3 + 2]);
|
|
479
|
+
sum_a_x = sum_a_x + val_a_x;
|
|
480
|
+
sum_a_y = sum_a_y + val_a_y;
|
|
481
|
+
sum_a_z = sum_a_z + val_a_z;
|
|
482
|
+
sum_b_x = sum_b_x + val_b_x;
|
|
483
|
+
sum_b_y = sum_b_y + val_b_y;
|
|
484
|
+
sum_b_z = sum_b_z + val_b_z;
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
metric_type_ inv_n = metric_type_(1.0) / metric_type_(static_cast<double>(n));
|
|
488
|
+
metric_type_ centroid_a_x = sum_a_x * inv_n;
|
|
489
|
+
metric_type_ centroid_a_y = sum_a_y * inv_n;
|
|
490
|
+
metric_type_ centroid_a_z = sum_a_z * inv_n;
|
|
491
|
+
metric_type_ centroid_b_x = sum_b_x * inv_n;
|
|
492
|
+
metric_type_ centroid_b_y = sum_b_y * inv_n;
|
|
493
|
+
metric_type_ centroid_b_z = sum_b_z * inv_n;
|
|
494
|
+
|
|
495
|
+
if (a_centroid)
|
|
496
|
+
a_centroid[0] = transform_type_(centroid_a_x), a_centroid[1] = transform_type_(centroid_a_y),
|
|
497
|
+
a_centroid[2] = transform_type_(centroid_a_z);
|
|
498
|
+
|
|
499
|
+
if (b_centroid)
|
|
500
|
+
b_centroid[0] = transform_type_(centroid_b_x), b_centroid[1] = transform_type_(centroid_b_y),
|
|
501
|
+
b_centroid[2] = transform_type_(centroid_b_z);
|
|
502
|
+
|
|
503
|
+
// Step 2: Build 3x3 covariance matrix H = (A - A_bar)^T x (B - B_bar)
|
|
504
|
+
metric_type_ cross_covariance[9] = {};
|
|
505
|
+
for (std::size_t i = 0; i < n; i++) {
|
|
506
|
+
val_a_x = metric_type_(a[i * 3 + 0]) - centroid_a_x;
|
|
507
|
+
val_a_y = metric_type_(a[i * 3 + 1]) - centroid_a_y;
|
|
508
|
+
val_a_z = metric_type_(a[i * 3 + 2]) - centroid_a_z;
|
|
509
|
+
val_b_x = metric_type_(b[i * 3 + 0]) - centroid_b_x;
|
|
510
|
+
val_b_y = metric_type_(b[i * 3 + 1]) - centroid_b_y;
|
|
511
|
+
val_b_z = metric_type_(b[i * 3 + 2]) - centroid_b_z;
|
|
512
|
+
cross_covariance[0] = cross_covariance[0] + val_a_x * val_b_x;
|
|
513
|
+
cross_covariance[1] = cross_covariance[1] + val_a_x * val_b_y;
|
|
514
|
+
cross_covariance[2] = cross_covariance[2] + val_a_x * val_b_z;
|
|
515
|
+
cross_covariance[3] = cross_covariance[3] + val_a_y * val_b_x;
|
|
516
|
+
cross_covariance[4] = cross_covariance[4] + val_a_y * val_b_y;
|
|
517
|
+
cross_covariance[5] = cross_covariance[5] + val_a_y * val_b_z;
|
|
518
|
+
cross_covariance[6] = cross_covariance[6] + val_a_z * val_b_x;
|
|
519
|
+
cross_covariance[7] = cross_covariance[7] + val_a_z * val_b_y;
|
|
520
|
+
cross_covariance[8] = cross_covariance[8] + val_a_z * val_b_z;
|
|
521
|
+
}
|
|
522
|
+
|
|
523
|
+
// Step 3: SVD of H = U * S * Vt
|
|
524
|
+
metric_type_ svd_u[9], svd_s[9], svd_v[9];
|
|
525
|
+
svd3x3_(cross_covariance, svd_u, svd_s, svd_v);
|
|
526
|
+
|
|
527
|
+
// Step 4: R = V * Ut
|
|
528
|
+
metric_type_ rotation_matrix[9];
|
|
529
|
+
rotation_matrix[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
530
|
+
rotation_matrix[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
531
|
+
rotation_matrix[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
532
|
+
rotation_matrix[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
533
|
+
rotation_matrix[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
534
|
+
rotation_matrix[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
535
|
+
rotation_matrix[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
536
|
+
rotation_matrix[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
537
|
+
rotation_matrix[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
538
|
+
|
|
539
|
+
// Handle reflection: if det(R) < 0, negate third column of V and recompute R
|
|
540
|
+
metric_type_ rotation_det = det3x3_(rotation_matrix);
|
|
541
|
+
if (rotation_det < metric_type_(0.0)) {
|
|
542
|
+
svd_v[2] = metric_type_(0.0) - svd_v[2];
|
|
543
|
+
svd_v[5] = metric_type_(0.0) - svd_v[5];
|
|
544
|
+
svd_v[8] = metric_type_(0.0) - svd_v[8];
|
|
545
|
+
rotation_matrix[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
546
|
+
rotation_matrix[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
547
|
+
rotation_matrix[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
548
|
+
rotation_matrix[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
549
|
+
rotation_matrix[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
550
|
+
rotation_matrix[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
551
|
+
rotation_matrix[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
552
|
+
rotation_matrix[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
553
|
+
rotation_matrix[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
554
|
+
}
|
|
555
|
+
|
|
556
|
+
// Output rotation matrix and scale=1.0
|
|
557
|
+
if (rotation) {
|
|
558
|
+
for (unsigned int j = 0; j < 9; j++) rotation[j] = transform_type_(rotation_matrix[j]);
|
|
559
|
+
}
|
|
560
|
+
if (scale) *scale = transform_type_(1.0);
|
|
561
|
+
|
|
562
|
+
// Step 5: Compute RMSD after rotation
|
|
563
|
+
metric_type_ sum_squared {};
|
|
564
|
+
for (std::size_t i = 0; i < n; i++) {
|
|
565
|
+
metric_type_ point_a[3], point_b[3], rotated_point_a[3];
|
|
566
|
+
point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x;
|
|
567
|
+
point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y;
|
|
568
|
+
point_a[2] = metric_type_(a[i * 3 + 2]) - centroid_a_z;
|
|
569
|
+
point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x;
|
|
570
|
+
point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y;
|
|
571
|
+
point_b[2] = metric_type_(b[i * 3 + 2]) - centroid_b_z;
|
|
572
|
+
rotated_point_a[0] = rotation_matrix[0] * point_a[0] + rotation_matrix[1] * point_a[1] +
|
|
573
|
+
rotation_matrix[2] * point_a[2];
|
|
574
|
+
rotated_point_a[1] = rotation_matrix[3] * point_a[0] + rotation_matrix[4] * point_a[1] +
|
|
575
|
+
rotation_matrix[5] * point_a[2];
|
|
576
|
+
rotated_point_a[2] = rotation_matrix[6] * point_a[0] + rotation_matrix[7] * point_a[1] +
|
|
577
|
+
rotation_matrix[8] * point_a[2];
|
|
578
|
+
metric_type_ dx = rotated_point_a[0] - point_b[0];
|
|
579
|
+
metric_type_ dy = rotated_point_a[1] - point_b[1];
|
|
580
|
+
metric_type_ dz = rotated_point_a[2] - point_b[2];
|
|
581
|
+
sum_squared = sum_squared + dx * dx + dy * dy + dz * dz;
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
*metric = (sum_squared * inv_n).sqrt();
|
|
585
|
+
}
|
|
586
|
+
}
|
|
587
|
+
|
|
588
|
+
/**
|
|
589
|
+
* @brief Umeyama algorithm: min ‖P − s × R × Q‖² over R ∈ SO(3), s ∈ ℝ⁺
|
|
590
|
+
* @param[in] a,b Point clouds [n x 3] interleaved (source and target)
|
|
591
|
+
* @param[in] n Number of 3D points
|
|
592
|
+
* @param[out] a_centroid,b_centroid Centroids (3 values each), can be nullptr
|
|
593
|
+
* @param[out] rotation 3x3 rotation matrix (9 values, row-major), can be nullptr
|
|
594
|
+
* @param[out] scale Uniform scale factor, can be nullptr
|
|
595
|
+
* @param[out] metric Output RMSD after optimal transformation
|
|
596
|
+
*
|
|
597
|
+
* @tparam in_type_ Input point type (f32_t, f64_t, f16_t, bf16_t)
|
|
598
|
+
* @tparam transform_type_ Type of centroids, rotation, and scale outputs
|
|
599
|
+
* @tparam metric_type_ Type of the scalar fit metric output
|
|
600
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
601
|
+
*/
|
|
602
|
+
template <typename in_type_, typename transform_type_ = typename in_type_::mesh_transform_t,
|
|
603
|
+
typename metric_type_ = typename in_type_::mesh_metric_t, allow_simd_t allow_simd_ = prefer_simd_k>
|
|
604
|
+
void umeyama(in_type_ const *a, in_type_ const *b, std::size_t n, transform_type_ *a_centroid,
|
|
605
|
+
transform_type_ *b_centroid, transform_type_ *rotation, transform_type_ *scale,
|
|
606
|
+
metric_type_ *metric) noexcept {
|
|
607
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k &&
|
|
608
|
+
std::is_same_v<transform_type_, typename in_type_::mesh_transform_t> &&
|
|
609
|
+
std::is_same_v<metric_type_, typename in_type_::mesh_metric_t>;
|
|
610
|
+
|
|
611
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd)
|
|
612
|
+
nk_umeyama_f64(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
613
|
+
&metric->raw_);
|
|
614
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
|
|
615
|
+
nk_umeyama_f32(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
616
|
+
&metric->raw_);
|
|
617
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
|
|
618
|
+
nk_umeyama_f16(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
619
|
+
&metric->raw_);
|
|
620
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
|
|
621
|
+
nk_umeyama_bf16(&a->raw_, &b->raw_, n, &a_centroid->raw_, &b_centroid->raw_, &rotation->raw_, &scale->raw_,
|
|
622
|
+
&metric->raw_);
|
|
623
|
+
// Scalar fallback
|
|
624
|
+
else {
|
|
625
|
+
// Step 1: Compute centroids
|
|
626
|
+
metric_type_ sum_a_x {}, sum_a_y {}, sum_a_z {};
|
|
627
|
+
metric_type_ sum_b_x {}, sum_b_y {}, sum_b_z {};
|
|
628
|
+
metric_type_ val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z;
|
|
629
|
+
|
|
630
|
+
for (std::size_t i = 0; i < n; i++) {
|
|
631
|
+
val_a_x = metric_type_(a[i * 3 + 0]);
|
|
632
|
+
val_a_y = metric_type_(a[i * 3 + 1]);
|
|
633
|
+
val_a_z = metric_type_(a[i * 3 + 2]);
|
|
634
|
+
val_b_x = metric_type_(b[i * 3 + 0]);
|
|
635
|
+
val_b_y = metric_type_(b[i * 3 + 1]);
|
|
636
|
+
val_b_z = metric_type_(b[i * 3 + 2]);
|
|
637
|
+
sum_a_x = sum_a_x + val_a_x;
|
|
638
|
+
sum_a_y = sum_a_y + val_a_y;
|
|
639
|
+
sum_a_z = sum_a_z + val_a_z;
|
|
640
|
+
sum_b_x = sum_b_x + val_b_x;
|
|
641
|
+
sum_b_y = sum_b_y + val_b_y;
|
|
642
|
+
sum_b_z = sum_b_z + val_b_z;
|
|
643
|
+
}
|
|
644
|
+
|
|
645
|
+
metric_type_ inv_n = metric_type_(1.0) / metric_type_(static_cast<double>(n));
|
|
646
|
+
metric_type_ centroid_a_x = sum_a_x * inv_n;
|
|
647
|
+
metric_type_ centroid_a_y = sum_a_y * inv_n;
|
|
648
|
+
metric_type_ centroid_a_z = sum_a_z * inv_n;
|
|
649
|
+
metric_type_ centroid_b_x = sum_b_x * inv_n;
|
|
650
|
+
metric_type_ centroid_b_y = sum_b_y * inv_n;
|
|
651
|
+
metric_type_ centroid_b_z = sum_b_z * inv_n;
|
|
652
|
+
|
|
653
|
+
if (a_centroid) {
|
|
654
|
+
a_centroid[0] = transform_type_(centroid_a_x);
|
|
655
|
+
a_centroid[1] = transform_type_(centroid_a_y);
|
|
656
|
+
a_centroid[2] = transform_type_(centroid_a_z);
|
|
657
|
+
}
|
|
658
|
+
if (b_centroid) {
|
|
659
|
+
b_centroid[0] = transform_type_(centroid_b_x);
|
|
660
|
+
b_centroid[1] = transform_type_(centroid_b_y);
|
|
661
|
+
b_centroid[2] = transform_type_(centroid_b_z);
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
// Step 2: Build covariance matrix H and compute variance of A
|
|
665
|
+
metric_type_ cross_covariance[9] = {};
|
|
666
|
+
metric_type_ variance_a {};
|
|
667
|
+
for (std::size_t i = 0; i < n; i++) {
|
|
668
|
+
val_a_x = metric_type_(a[i * 3 + 0]) - centroid_a_x;
|
|
669
|
+
val_a_y = metric_type_(a[i * 3 + 1]) - centroid_a_y;
|
|
670
|
+
val_a_z = metric_type_(a[i * 3 + 2]) - centroid_a_z;
|
|
671
|
+
val_b_x = metric_type_(b[i * 3 + 0]) - centroid_b_x;
|
|
672
|
+
val_b_y = metric_type_(b[i * 3 + 1]) - centroid_b_y;
|
|
673
|
+
val_b_z = metric_type_(b[i * 3 + 2]) - centroid_b_z;
|
|
674
|
+
variance_a = variance_a + val_a_x * val_a_x + val_a_y * val_a_y + val_a_z * val_a_z;
|
|
675
|
+
cross_covariance[0] = cross_covariance[0] + val_a_x * val_b_x;
|
|
676
|
+
cross_covariance[1] = cross_covariance[1] + val_a_x * val_b_y;
|
|
677
|
+
cross_covariance[2] = cross_covariance[2] + val_a_x * val_b_z;
|
|
678
|
+
cross_covariance[3] = cross_covariance[3] + val_a_y * val_b_x;
|
|
679
|
+
cross_covariance[4] = cross_covariance[4] + val_a_y * val_b_y;
|
|
680
|
+
cross_covariance[5] = cross_covariance[5] + val_a_y * val_b_z;
|
|
681
|
+
cross_covariance[6] = cross_covariance[6] + val_a_z * val_b_x;
|
|
682
|
+
cross_covariance[7] = cross_covariance[7] + val_a_z * val_b_y;
|
|
683
|
+
cross_covariance[8] = cross_covariance[8] + val_a_z * val_b_z;
|
|
684
|
+
}
|
|
685
|
+
variance_a = variance_a * inv_n;
|
|
686
|
+
|
|
687
|
+
// Step 3: SVD of H = U * S * Vt
|
|
688
|
+
metric_type_ svd_u[9], svd_s[9], svd_v[9];
|
|
689
|
+
svd3x3_(cross_covariance, svd_u, svd_s, svd_v);
|
|
690
|
+
|
|
691
|
+
// Step 4: R = V * Ut
|
|
692
|
+
metric_type_ rotation_matrix[9];
|
|
693
|
+
rotation_matrix[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
694
|
+
rotation_matrix[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
695
|
+
rotation_matrix[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
696
|
+
rotation_matrix[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
697
|
+
rotation_matrix[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
698
|
+
rotation_matrix[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
699
|
+
rotation_matrix[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
700
|
+
rotation_matrix[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
701
|
+
rotation_matrix[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
702
|
+
|
|
703
|
+
// Handle reflection and compute scale: c = trace(D*S) / variance_a
|
|
704
|
+
// D = diag(1, 1, det(R)), svd_s contains singular values on diagonal
|
|
705
|
+
metric_type_ rotation_det = det3x3_(rotation_matrix);
|
|
706
|
+
metric_type_ sign_det = rotation_det < metric_type_(0.0) ? metric_type_(-1.0) : metric_type_(1.0);
|
|
707
|
+
metric_type_ trace_scaled_s = svd_s[0] + svd_s[4] + sign_det * svd_s[8];
|
|
708
|
+
metric_type_ scale_factor = trace_scaled_s / (metric_type_(static_cast<double>(n)) * variance_a);
|
|
709
|
+
|
|
710
|
+
if (scale) *scale = transform_type_(scale_factor);
|
|
711
|
+
|
|
712
|
+
if (rotation_det < metric_type_(0.0)) {
|
|
713
|
+
svd_v[2] = metric_type_(0.0) - svd_v[2];
|
|
714
|
+
svd_v[5] = metric_type_(0.0) - svd_v[5];
|
|
715
|
+
svd_v[8] = metric_type_(0.0) - svd_v[8];
|
|
716
|
+
rotation_matrix[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
717
|
+
rotation_matrix[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
718
|
+
rotation_matrix[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
719
|
+
rotation_matrix[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
720
|
+
rotation_matrix[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
721
|
+
rotation_matrix[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
722
|
+
rotation_matrix[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
723
|
+
rotation_matrix[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
724
|
+
rotation_matrix[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
725
|
+
}
|
|
726
|
+
|
|
727
|
+
// Output rotation matrix
|
|
728
|
+
if (rotation) {
|
|
729
|
+
for (unsigned int j = 0; j < 9; j++) rotation[j] = transform_type_(rotation_matrix[j]);
|
|
730
|
+
}
|
|
731
|
+
|
|
732
|
+
// Step 5: Compute RMSD after similarity transform: ||c * R * a - b||
|
|
733
|
+
metric_type_ sum_squared {};
|
|
734
|
+
for (std::size_t i = 0; i < n; i++) {
|
|
735
|
+
metric_type_ point_a[3], point_b[3], rotated_point_a[3];
|
|
736
|
+
point_a[0] = metric_type_(a[i * 3 + 0]) - centroid_a_x;
|
|
737
|
+
point_a[1] = metric_type_(a[i * 3 + 1]) - centroid_a_y;
|
|
738
|
+
point_a[2] = metric_type_(a[i * 3 + 2]) - centroid_a_z;
|
|
739
|
+
point_b[0] = metric_type_(b[i * 3 + 0]) - centroid_b_x;
|
|
740
|
+
point_b[1] = metric_type_(b[i * 3 + 1]) - centroid_b_y;
|
|
741
|
+
point_b[2] = metric_type_(b[i * 3 + 2]) - centroid_b_z;
|
|
742
|
+
rotated_point_a[0] = scale_factor * (rotation_matrix[0] * point_a[0] + rotation_matrix[1] * point_a[1] +
|
|
743
|
+
rotation_matrix[2] * point_a[2]);
|
|
744
|
+
rotated_point_a[1] = scale_factor * (rotation_matrix[3] * point_a[0] + rotation_matrix[4] * point_a[1] +
|
|
745
|
+
rotation_matrix[5] * point_a[2]);
|
|
746
|
+
rotated_point_a[2] = scale_factor * (rotation_matrix[6] * point_a[0] + rotation_matrix[7] * point_a[1] +
|
|
747
|
+
rotation_matrix[8] * point_a[2]);
|
|
748
|
+
metric_type_ dx = rotated_point_a[0] - point_b[0];
|
|
749
|
+
metric_type_ dy = rotated_point_a[1] - point_b[1];
|
|
750
|
+
metric_type_ dz = rotated_point_a[2] - point_b[2];
|
|
751
|
+
sum_squared = sum_squared + dx * dx + dy * dy + dz * dz;
|
|
752
|
+
}
|
|
753
|
+
|
|
754
|
+
*metric = (sum_squared * inv_n).sqrt();
|
|
755
|
+
}
|
|
756
|
+
}
|
|
757
|
+
|
|
758
|
+
#pragma endregion - Mesh Alignment Kernels
|
|
759
|
+
|
|
760
|
+
} // namespace ashvardanian::numkong
|
|
761
|
+
|
|
762
|
+
#endif // NK_MESH_HPP
|