numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
package/include/numkong/mesh/neonbfdot.h
@@ -0,0 +1,842 @@
/**
 *  @brief SIMD-accelerated Point Cloud Alignment for NEON BF16.
 *  @file include/numkong/mesh/neonbfdot.h
 *  @author Ash Vardanian
 *  @date December 27, 2025
 *
 *  @sa include/numkong/mesh.h
 *
 *  @section mesh_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
 *
 *      Intrinsic       Instruction                 Latency     Throughput
 *                                                              A76     M4+/V1+/Oryon
 *      vld3_u16        LD3 (V.4H x 3)              6cy         1/cy    2/cy
 *      vshll_n_u16     USHLL (V.4S, V.4H, #16)     2cy         2/cy    4/cy
 *      vfmaq_f32       FMLA (V.4S, V.4S, V.4S)     4cy         2/cy    4/cy
 *      vaddq_f32       FADD (V.4S, V.4S, V.4S)     2cy         2/cy    4/cy
 *      vsubq_f32       FSUB (V.4S, V.4S, V.4S)     2cy         2/cy    4/cy
 *      vmulq_f32       FMUL (V.4S, V.4S, V.4S)     3cy         2/cy    4/cy
 *      vdupq_n_f32     DUP (V.4S, scalar)          2cy         2/cy    4/cy
 *      vaddvq_f32      FADDP+FADDP (V.4S)          4cy         1/cy    2/cy
 *
 *  The ARMv8.6-BF16 extension enables BF16 storage with F32 computation for 3D mesh alignment
 *  operations. BF16's wider exponent range (matching F32) prevents overflow in geometric calculations
 *  while halving memory bandwidth compared to F32.
 *
 *  For point cloud registration (RMSD, Kabsch, Umeyama), BF16 data is loaded using VLD3 de-interleave
 *  operations, converted to F32 via bit-shift widening, then processed with F32 FMA chains. The 2x
 *  unrolling with dual accumulators hides the 4-cycle FMA latency, achieving near-peak throughput
 *  for covariance matrix and centroid computations.
 */
#ifndef NK_MESH_NEONBFDOT_H
#define NK_MESH_NEONBFDOT_H

#if NK_TARGET_ARM_
#if NK_TARGET_NEONBFDOT

#include "numkong/types.h"
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`

#if defined(__cplusplus)
extern "C" {
#endif

#if defined(__clang__)
#pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("arch=armv8.6-a+simd+bf16")
#endif

/*  Load 4 bf16 xyz points (12 bf16 values) → 3x float32x4_t.
 *  Uses vld3_u16 to de-interleave xyz triplets, then converts bf16 to f32.
 *
 *  Input: 12 contiguous bf16 [x0,y0,z0, x1,y1,z1, x2,y2,z2, x3,y3,z3]
 *  Output: x[4], y[4], z[4] vectors in f32
 */
NK_INTERNAL void nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(nk_bf16_t const *ptr, float32x4_t *x_out,
                                                            float32x4_t *y_out, float32x4_t *z_out) {
    // Load 12 bf16 values and de-interleave into x, y, z components
    uint16x4x3_t xyz = vld3_u16((uint16_t const *)ptr);
    // Convert bf16 to f32 by widening each 16-bit value to 32 bits and shifting
    // left by 16, which places the bf16 bits in the high half of the f32 pattern
    uint32x4_t x_u32 = vshll_n_u16(xyz.val[0], 16);
    uint32x4_t y_u32 = vshll_n_u16(xyz.val[1], 16);
    uint32x4_t z_u32 = vshll_n_u16(xyz.val[2], 16);
    *x_out = vreinterpretq_f32_u32(x_u32);
    *y_out = vreinterpretq_f32_u32(y_u32);
    *z_out = vreinterpretq_f32_u32(z_u32);
}
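
/*  A scalar reference for the widening trick above - a minimal sketch, not used
 *  by the kernels below. A bf16 value shares the f32 sign/exponent layout, so
 *  shifting its 16 bits into the high half of a 32-bit word reproduces the
 *  exact f32 value; `memcpy` (from <string.h>) is the portable type-pun:
 *
 *      static inline float bf16_bits_to_f32_scalar(unsigned short bits) {
 *          unsigned int widened = (unsigned int)bits << 16; // bf16 -> f32 high half
 *          float value;
 *          memcpy(&value, &widened, sizeof value);
 *          return value;
 *      }
 */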

NK_INTERNAL void nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(nk_bf16_t const *ptr, nk_size_t n_points,
                                                                  float32x4_t *x_out, float32x4_t *y_out,
                                                                  float32x4_t *z_out) {
    nk_u16_t buf[12] = {0};
    nk_u16_t const *src = (nk_u16_t const *)ptr;
    for (nk_size_t k = 0; k < n_points * 3; ++k) buf[k] = src[k];
    nk_deinterleave_bf16x4_to_f32x4_neonbfdot_((nk_bf16_t const *)buf, x_out, y_out, z_out);
}
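
/*  Lanes past `n_points` come back as zeros rather than masked garbage. Zeros
 *  are harmless for plain sums and products, but once the data is centered they
 *  turn into `-centroid`, so callers that center must mask the centered values -
 *  see the `vbslq_f32` masking in the partial tail of the SSD kernel below.
 */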

/*  Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
 *  Loads bf16 data, converts to f32 during processing.
 *  Note: rotation matrix r is f32 (from SVD); scale and all computation are done in f32.
 */
NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_neonbfdot_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
                                                        nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
                                                        nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
                                                        nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
                                                        nk_f32_t centroid_b_z) {
    // Broadcast scaled rotation matrix elements
    float32x4_t scaled_rotation_x_x_f32x4 = vdupq_n_f32(scale * r[0]);
    float32x4_t scaled_rotation_x_y_f32x4 = vdupq_n_f32(scale * r[1]);
    float32x4_t scaled_rotation_x_z_f32x4 = vdupq_n_f32(scale * r[2]);
    float32x4_t scaled_rotation_y_x_f32x4 = vdupq_n_f32(scale * r[3]);
    float32x4_t scaled_rotation_y_y_f32x4 = vdupq_n_f32(scale * r[4]);
    float32x4_t scaled_rotation_y_z_f32x4 = vdupq_n_f32(scale * r[5]);
    float32x4_t scaled_rotation_z_x_f32x4 = vdupq_n_f32(scale * r[6]);
    float32x4_t scaled_rotation_z_y_f32x4 = vdupq_n_f32(scale * r[7]);
    float32x4_t scaled_rotation_z_z_f32x4 = vdupq_n_f32(scale * r[8]);

    // Broadcast centroids
    float32x4_t centroid_a_x_f32x4 = vdupq_n_f32(centroid_a_x);
    float32x4_t centroid_a_y_f32x4 = vdupq_n_f32(centroid_a_y);
    float32x4_t centroid_a_z_f32x4 = vdupq_n_f32(centroid_a_z);
    float32x4_t centroid_b_x_f32x4 = vdupq_n_f32(centroid_b_x);
    float32x4_t centroid_b_y_f32x4 = vdupq_n_f32(centroid_b_y);
    float32x4_t centroid_b_z_f32x4 = vdupq_n_f32(centroid_b_z);

    // Two independent accumulators to hide FMA latency
    float32x4_t sum_squared_a_f32x4 = vdupq_n_f32(0);
    float32x4_t sum_squared_b_f32x4 = vdupq_n_f32(0);
    nk_size_t j = 0;

    // Main loop: process 8 points per iteration (2x unrolled)
    for (; j + 8 <= n; j += 8) {
        // First batch of 4 points
        float32x4_t a1_x_f32x4, a1_y_f32x4, a1_z_f32x4, b1_x_f32x4, b1_y_f32x4, b1_z_f32x4;
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + j * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + j * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);

        // Second batch of 4 points
        float32x4_t a2_x_f32x4, a2_y_f32x4, a2_z_f32x4, b2_x_f32x4, b2_y_f32x4, b2_z_f32x4;
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + (j + 4) * 3, &a2_x_f32x4, &a2_y_f32x4, &a2_z_f32x4);
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + (j + 4) * 3, &b2_x_f32x4, &b2_y_f32x4, &b2_z_f32x4);

        // Center first batch
        float32x4_t pa1_x_f32x4 = vsubq_f32(a1_x_f32x4, centroid_a_x_f32x4);
        float32x4_t pa1_y_f32x4 = vsubq_f32(a1_y_f32x4, centroid_a_y_f32x4);
        float32x4_t pa1_z_f32x4 = vsubq_f32(a1_z_f32x4, centroid_a_z_f32x4);
        float32x4_t pb1_x_f32x4 = vsubq_f32(b1_x_f32x4, centroid_b_x_f32x4);
        float32x4_t pb1_y_f32x4 = vsubq_f32(b1_y_f32x4, centroid_b_y_f32x4);
        float32x4_t pb1_z_f32x4 = vsubq_f32(b1_z_f32x4, centroid_b_z_f32x4);

        // Center second batch
        float32x4_t pa2_x_f32x4 = vsubq_f32(a2_x_f32x4, centroid_a_x_f32x4);
        float32x4_t pa2_y_f32x4 = vsubq_f32(a2_y_f32x4, centroid_a_y_f32x4);
        float32x4_t pa2_z_f32x4 = vsubq_f32(a2_z_f32x4, centroid_a_z_f32x4);
        float32x4_t pb2_x_f32x4 = vsubq_f32(b2_x_f32x4, centroid_b_x_f32x4);
        float32x4_t pb2_y_f32x4 = vsubq_f32(b2_y_f32x4, centroid_b_y_f32x4);
        float32x4_t pb2_z_f32x4 = vsubq_f32(b2_z_f32x4, centroid_b_z_f32x4);

        // Rotate and scale first batch: ra1 = scale * R * pa1
        float32x4_t ra1_x_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa1_x_f32x4), scaled_rotation_x_y_f32x4, pa1_y_f32x4),
            scaled_rotation_x_z_f32x4, pa1_z_f32x4);
        float32x4_t ra1_y_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa1_x_f32x4), scaled_rotation_y_y_f32x4, pa1_y_f32x4),
            scaled_rotation_y_z_f32x4, pa1_z_f32x4);
        float32x4_t ra1_z_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa1_x_f32x4), scaled_rotation_z_y_f32x4, pa1_y_f32x4),
            scaled_rotation_z_z_f32x4, pa1_z_f32x4);

        // Rotate and scale second batch: ra2 = scale * R * pa2
        float32x4_t ra2_x_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa2_x_f32x4), scaled_rotation_x_y_f32x4, pa2_y_f32x4),
            scaled_rotation_x_z_f32x4, pa2_z_f32x4);
        float32x4_t ra2_y_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa2_x_f32x4), scaled_rotation_y_y_f32x4, pa2_y_f32x4),
            scaled_rotation_y_z_f32x4, pa2_z_f32x4);
        float32x4_t ra2_z_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa2_x_f32x4), scaled_rotation_z_y_f32x4, pa2_y_f32x4),
            scaled_rotation_z_z_f32x4, pa2_z_f32x4);

        // Deltas
        float32x4_t delta1_x_f32x4 = vsubq_f32(ra1_x_f32x4, pb1_x_f32x4);
        float32x4_t delta1_y_f32x4 = vsubq_f32(ra1_y_f32x4, pb1_y_f32x4);
        float32x4_t delta1_z_f32x4 = vsubq_f32(ra1_z_f32x4, pb1_z_f32x4);
        float32x4_t delta2_x_f32x4 = vsubq_f32(ra2_x_f32x4, pb2_x_f32x4);
        float32x4_t delta2_y_f32x4 = vsubq_f32(ra2_y_f32x4, pb2_y_f32x4);
        float32x4_t delta2_z_f32x4 = vsubq_f32(ra2_z_f32x4, pb2_z_f32x4);

        // Accumulate to independent accumulators
        sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta1_x_f32x4, delta1_x_f32x4);
        sum_squared_b_f32x4 = vfmaq_f32(sum_squared_b_f32x4, delta2_x_f32x4, delta2_x_f32x4);
        sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta1_y_f32x4, delta1_y_f32x4);
        sum_squared_b_f32x4 = vfmaq_f32(sum_squared_b_f32x4, delta2_y_f32x4, delta2_y_f32x4);
        sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta1_z_f32x4, delta1_z_f32x4);
        sum_squared_b_f32x4 = vfmaq_f32(sum_squared_b_f32x4, delta2_z_f32x4, delta2_z_f32x4);
    }

    // Handle remaining 4 points
    if (j + 4 <= n) {
        float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + j * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + j * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);

        float32x4_t pa_x_f32x4 = vsubq_f32(a_x_f32x4, centroid_a_x_f32x4);
        float32x4_t pa_y_f32x4 = vsubq_f32(a_y_f32x4, centroid_a_y_f32x4);
        float32x4_t pa_z_f32x4 = vsubq_f32(a_z_f32x4, centroid_a_z_f32x4);
        float32x4_t pb_x_f32x4 = vsubq_f32(b_x_f32x4, centroid_b_x_f32x4);
        float32x4_t pb_y_f32x4 = vsubq_f32(b_y_f32x4, centroid_b_y_f32x4);
        float32x4_t pb_z_f32x4 = vsubq_f32(b_z_f32x4, centroid_b_z_f32x4);

        float32x4_t ra_x_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa_x_f32x4), scaled_rotation_x_y_f32x4, pa_y_f32x4),
            scaled_rotation_x_z_f32x4, pa_z_f32x4);
        float32x4_t ra_y_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa_x_f32x4), scaled_rotation_y_y_f32x4, pa_y_f32x4),
            scaled_rotation_y_z_f32x4, pa_z_f32x4);
        float32x4_t ra_z_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa_x_f32x4), scaled_rotation_z_y_f32x4, pa_y_f32x4),
            scaled_rotation_z_z_f32x4, pa_z_f32x4);

        float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
        float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
        float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);

        sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_x_f32x4, delta_x_f32x4);
        sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_y_f32x4, delta_y_f32x4);
        sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_z_f32x4, delta_z_f32x4);
        j += 4;
    }

    // Partial tail: handle remaining 1-3 points with vectorized partial deinterleave
    if (j < n) {
        float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
        nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + j * 3, n - j, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
        nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + j * 3, n - j, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);

        // Same centering as the body
        float32x4_t pa_x_f32x4 = vsubq_f32(a_x_f32x4, centroid_a_x_f32x4);
        float32x4_t pa_y_f32x4 = vsubq_f32(a_y_f32x4, centroid_a_y_f32x4);
        float32x4_t pa_z_f32x4 = vsubq_f32(a_z_f32x4, centroid_a_z_f32x4);
        float32x4_t pb_x_f32x4 = vsubq_f32(b_x_f32x4, centroid_b_x_f32x4);
        float32x4_t pb_y_f32x4 = vsubq_f32(b_y_f32x4, centroid_b_y_f32x4);
        float32x4_t pb_z_f32x4 = vsubq_f32(b_z_f32x4, centroid_b_z_f32x4);

        // Mask invalid lanes to zero AFTER centering: the zero-filled padding lanes
        // center to `-centroid`, which would otherwise add the translation magnitude
        // to the sum for every padded lane
        uint32x4_t lane_u32x4 = {0, 1, 2, 3};
        uint32x4_t valid_u32x4 = vcltq_u32(lane_u32x4, vdupq_n_u32((uint32_t)(n - j)));
        float32x4_t zero_f32x4 = vdupq_n_f32(0);
        pa_x_f32x4 = vbslq_f32(valid_u32x4, pa_x_f32x4, zero_f32x4);
        pa_y_f32x4 = vbslq_f32(valid_u32x4, pa_y_f32x4, zero_f32x4);
        pa_z_f32x4 = vbslq_f32(valid_u32x4, pa_z_f32x4, zero_f32x4);
        pb_x_f32x4 = vbslq_f32(valid_u32x4, pb_x_f32x4, zero_f32x4);
        pb_y_f32x4 = vbslq_f32(valid_u32x4, pb_y_f32x4, zero_f32x4);
        pb_z_f32x4 = vbslq_f32(valid_u32x4, pb_z_f32x4, zero_f32x4);

        // Same rotation + delta + FMA as the body
        float32x4_t ra_x_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa_x_f32x4), scaled_rotation_x_y_f32x4, pa_y_f32x4),
            scaled_rotation_x_z_f32x4, pa_z_f32x4);
        float32x4_t ra_y_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa_x_f32x4), scaled_rotation_y_y_f32x4, pa_y_f32x4),
            scaled_rotation_y_z_f32x4, pa_z_f32x4);
        float32x4_t ra_z_f32x4 = vfmaq_f32(
            vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa_x_f32x4), scaled_rotation_z_y_f32x4, pa_y_f32x4),
            scaled_rotation_z_z_f32x4, pa_z_f32x4);

        float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
        float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
        float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);

        sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_x_f32x4, delta_x_f32x4);
        sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_y_f32x4, delta_y_f32x4);
        sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_z_f32x4, delta_z_f32x4);
    }

    // Combine accumulators and reduce
    float32x4_t sum_squared_f32x4 = vaddq_f32(sum_squared_a_f32x4, sum_squared_b_f32x4);
    nk_f32_t sum_squared = vaddvq_f32(sum_squared_f32x4);

    return sum_squared;
}
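
/*  The `_a`/`_b` accumulator pair above keeps two independent FMA dependency
 *  chains in flight, so the 4-cycle FMLA latency of one batch overlaps with
 *  the loads, centering, and rotation math of the other.
 */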

NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                      nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    /* RMSD uses identity rotation and scale=1.0 */
    if (rotation) {
        rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
        rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
        rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
    }
    if (scale) *scale = 1.0f;

    float32x4_t const zeros_f32x4 = vdupq_n_f32(0);

    // Accumulators for centroids and squared differences
    float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
    float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
    float32x4_t sum_squared_x_f32x4 = zeros_f32x4, sum_squared_y_f32x4 = zeros_f32x4, sum_squared_z_f32x4 = zeros_f32x4;

    float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
    nk_size_t i = 0;

    // Main loop processing 4 points at a time
    for (; i + 4 <= n; i += 4) {
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);

        sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
        sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
        sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
        sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
        sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
        sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);

        float32x4_t delta_x_f32x4 = vsubq_f32(a_x_f32x4, b_x_f32x4);
        float32x4_t delta_y_f32x4 = vsubq_f32(a_y_f32x4, b_y_f32x4);
        float32x4_t delta_z_f32x4 = vsubq_f32(a_z_f32x4, b_z_f32x4);

        sum_squared_x_f32x4 = vfmaq_f32(sum_squared_x_f32x4, delta_x_f32x4, delta_x_f32x4);
        sum_squared_y_f32x4 = vfmaq_f32(sum_squared_y_f32x4, delta_y_f32x4, delta_y_f32x4);
        sum_squared_z_f32x4 = vfmaq_f32(sum_squared_z_f32x4, delta_z_f32x4, delta_z_f32x4);
    }

    // Partial tail: handle remaining 1-3 points with vectorized partial deinterleave;
    // zero-padded lanes contribute zero to both the sums and the deltas
    if (i < n) {
        float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
        nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
        nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);

        sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
        sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
        sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
        sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
        sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
        sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);

        float32x4_t delta_x_f32x4 = vsubq_f32(a_x_f32x4, b_x_f32x4);
        float32x4_t delta_y_f32x4 = vsubq_f32(a_y_f32x4, b_y_f32x4);
        float32x4_t delta_z_f32x4 = vsubq_f32(a_z_f32x4, b_z_f32x4);

        sum_squared_x_f32x4 = vfmaq_f32(sum_squared_x_f32x4, delta_x_f32x4, delta_x_f32x4);
        sum_squared_y_f32x4 = vfmaq_f32(sum_squared_y_f32x4, delta_y_f32x4, delta_y_f32x4);
        sum_squared_z_f32x4 = vfmaq_f32(sum_squared_z_f32x4, delta_z_f32x4, delta_z_f32x4);
    }

    // Reduce vectors to scalars
    nk_f32_t total_ax = vaddvq_f32(sum_a_x_f32x4);
    nk_f32_t total_ay = vaddvq_f32(sum_a_y_f32x4);
    nk_f32_t total_az = vaddvq_f32(sum_a_z_f32x4);
    nk_f32_t total_bx = vaddvq_f32(sum_b_x_f32x4);
    nk_f32_t total_by = vaddvq_f32(sum_b_y_f32x4);
    nk_f32_t total_bz = vaddvq_f32(sum_b_z_f32x4);
    nk_f32_t total_squared_x = vaddvq_f32(sum_squared_x_f32x4);
    nk_f32_t total_squared_y = vaddvq_f32(sum_squared_y_f32x4);
    nk_f32_t total_squared_z = vaddvq_f32(sum_squared_z_f32x4);

    // Compute centroids
    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = total_ax * inv_n;
    nk_f32_t centroid_a_y = total_ay * inv_n;
    nk_f32_t centroid_a_z = total_az * inv_n;
    nk_f32_t centroid_b_x = total_bx * inv_n;
    nk_f32_t centroid_b_y = total_by * inv_n;
    nk_f32_t centroid_b_z = total_bz * inv_n;

    if (a_centroid) {
        a_centroid[0] = centroid_a_x;
        a_centroid[1] = centroid_a_y;
        a_centroid[2] = centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = centroid_b_x;
        b_centroid[1] = centroid_b_y;
        b_centroid[2] = centroid_b_z;
    }

    // Compute RMSD after translation alignment, using the identity
    // Σ|(aᵢ-c_a)-(bᵢ-c_b)|² = Σ|aᵢ-bᵢ|² - n·|c_a-c_b|²
    nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
    nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
    nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
    nk_f32_t sum_squared = total_squared_x + total_squared_y + total_squared_z;
    nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;

    *result = nk_f32_sqrt_neon(sum_squared * inv_n - mean_diff_sq);
}
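
/*  Usage sketch, assuming interleaved xyz input with `n` points of 3 bf16 each;
 *  the centroid, rotation, and scale outputs may be NULL when not needed:
 *
 *      nk_f32_t rmsd, a_centroid[3], b_centroid[3];
 *      nk_rmsd_bf16_neonbfdot(a, b, n, a_centroid, b_centroid, NULL, NULL, &rmsd);
 */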

NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                        nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    float32x4_t const zeros_f32x4 = vdupq_n_f32(0);

    /* 2x unrolling with dual accumulators to hide FMA latency. */
    float32x4_t sum_a_x_a_f32x4 = zeros_f32x4, sum_a_y_a_f32x4 = zeros_f32x4, sum_a_z_a_f32x4 = zeros_f32x4;
    float32x4_t sum_b_x_a_f32x4 = zeros_f32x4, sum_b_y_a_f32x4 = zeros_f32x4, sum_b_z_a_f32x4 = zeros_f32x4;
    float32x4_t sum_a_x_b_f32x4 = zeros_f32x4, sum_a_y_b_f32x4 = zeros_f32x4, sum_a_z_b_f32x4 = zeros_f32x4;
    float32x4_t sum_b_x_b_f32x4 = zeros_f32x4, sum_b_y_b_f32x4 = zeros_f32x4, sum_b_z_b_f32x4 = zeros_f32x4;

    float32x4_t cov_xx_a_f32x4 = zeros_f32x4, cov_xy_a_f32x4 = zeros_f32x4, cov_xz_a_f32x4 = zeros_f32x4;
    float32x4_t cov_yx_a_f32x4 = zeros_f32x4, cov_yy_a_f32x4 = zeros_f32x4, cov_yz_a_f32x4 = zeros_f32x4;
    float32x4_t cov_zx_a_f32x4 = zeros_f32x4, cov_zy_a_f32x4 = zeros_f32x4, cov_zz_a_f32x4 = zeros_f32x4;
    float32x4_t cov_xx_b_f32x4 = zeros_f32x4, cov_xy_b_f32x4 = zeros_f32x4, cov_xz_b_f32x4 = zeros_f32x4;
    float32x4_t cov_yx_b_f32x4 = zeros_f32x4, cov_yy_b_f32x4 = zeros_f32x4, cov_yz_b_f32x4 = zeros_f32x4;
    float32x4_t cov_zx_b_f32x4 = zeros_f32x4, cov_zy_b_f32x4 = zeros_f32x4, cov_zz_b_f32x4 = zeros_f32x4;

    nk_size_t i = 0;
    float32x4_t a1_x_f32x4, a1_y_f32x4, a1_z_f32x4, b1_x_f32x4, b1_y_f32x4, b1_z_f32x4;
    float32x4_t a2_x_f32x4, a2_y_f32x4, a2_z_f32x4, b2_x_f32x4, b2_y_f32x4, b2_z_f32x4;

    // Main loop: 8 points per iteration (2x unrolled)
    for (; i + 8 <= n; i += 8) {
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + (i + 4) * 3, &a2_x_f32x4, &a2_y_f32x4, &a2_z_f32x4);
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + (i + 4) * 3, &b2_x_f32x4, &b2_y_f32x4, &b2_z_f32x4);

        // Interleaved accumulation to hide FMA latency
        sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
        sum_a_x_b_f32x4 = vaddq_f32(sum_a_x_b_f32x4, a2_x_f32x4);
        sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
        sum_a_y_b_f32x4 = vaddq_f32(sum_a_y_b_f32x4, a2_y_f32x4);
        sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
        sum_a_z_b_f32x4 = vaddq_f32(sum_a_z_b_f32x4, a2_z_f32x4);
        sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
        sum_b_x_b_f32x4 = vaddq_f32(sum_b_x_b_f32x4, b2_x_f32x4);
        sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
        sum_b_y_b_f32x4 = vaddq_f32(sum_b_y_b_f32x4, b2_y_f32x4);
        sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
        sum_b_z_b_f32x4 = vaddq_f32(sum_b_z_b_f32x4, b2_z_f32x4);

        cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
        cov_xx_b_f32x4 = vfmaq_f32(cov_xx_b_f32x4, a2_x_f32x4, b2_x_f32x4);
        cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
        cov_xy_b_f32x4 = vfmaq_f32(cov_xy_b_f32x4, a2_x_f32x4, b2_y_f32x4);
        cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
        cov_xz_b_f32x4 = vfmaq_f32(cov_xz_b_f32x4, a2_x_f32x4, b2_z_f32x4);
        cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
        cov_yx_b_f32x4 = vfmaq_f32(cov_yx_b_f32x4, a2_y_f32x4, b2_x_f32x4);
        cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
        cov_yy_b_f32x4 = vfmaq_f32(cov_yy_b_f32x4, a2_y_f32x4, b2_y_f32x4);
        cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
        cov_yz_b_f32x4 = vfmaq_f32(cov_yz_b_f32x4, a2_y_f32x4, b2_z_f32x4);
        cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
        cov_zx_b_f32x4 = vfmaq_f32(cov_zx_b_f32x4, a2_z_f32x4, b2_x_f32x4);
        cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
        cov_zy_b_f32x4 = vfmaq_f32(cov_zy_b_f32x4, a2_z_f32x4, b2_y_f32x4);
        cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
        cov_zz_b_f32x4 = vfmaq_f32(cov_zz_b_f32x4, a2_z_f32x4, b2_z_f32x4);
    }

    // 4-point tail
    for (; i + 4 <= n; i += 4) {
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
        sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
        sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
        sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
        sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
        sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
        sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
        cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
        cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
        cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
        cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
        cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
        cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
        cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
        cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
        cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
    }

    // Partial tail: handle remaining 1-3 points with vectorized partial deinterleave
    if (i < n) {
        nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
        nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
        sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
        sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
        sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
        sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
        sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
        sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
        cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
        cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
        cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
        cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
        cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
        cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
        cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
        cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
        cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
    }

    // Combine dual accumulators
    float32x4_t sum_a_x_f32x4 = vaddq_f32(sum_a_x_a_f32x4, sum_a_x_b_f32x4);
    float32x4_t sum_a_y_f32x4 = vaddq_f32(sum_a_y_a_f32x4, sum_a_y_b_f32x4);
    float32x4_t sum_a_z_f32x4 = vaddq_f32(sum_a_z_a_f32x4, sum_a_z_b_f32x4);
    float32x4_t sum_b_x_f32x4 = vaddq_f32(sum_b_x_a_f32x4, sum_b_x_b_f32x4);
    float32x4_t sum_b_y_f32x4 = vaddq_f32(sum_b_y_a_f32x4, sum_b_y_b_f32x4);
    float32x4_t sum_b_z_f32x4 = vaddq_f32(sum_b_z_a_f32x4, sum_b_z_b_f32x4);
    float32x4_t cov_xx_f32x4 = vaddq_f32(cov_xx_a_f32x4, cov_xx_b_f32x4);
    float32x4_t cov_xy_f32x4 = vaddq_f32(cov_xy_a_f32x4, cov_xy_b_f32x4);
    float32x4_t cov_xz_f32x4 = vaddq_f32(cov_xz_a_f32x4, cov_xz_b_f32x4);
    float32x4_t cov_yx_f32x4 = vaddq_f32(cov_yx_a_f32x4, cov_yx_b_f32x4);
    float32x4_t cov_yy_f32x4 = vaddq_f32(cov_yy_a_f32x4, cov_yy_b_f32x4);
    float32x4_t cov_yz_f32x4 = vaddq_f32(cov_yz_a_f32x4, cov_yz_b_f32x4);
    float32x4_t cov_zx_f32x4 = vaddq_f32(cov_zx_a_f32x4, cov_zx_b_f32x4);
    float32x4_t cov_zy_f32x4 = vaddq_f32(cov_zy_a_f32x4, cov_zy_b_f32x4);
    float32x4_t cov_zz_f32x4 = vaddq_f32(cov_zz_a_f32x4, cov_zz_b_f32x4);

    // Reduce vector accumulators
    nk_f32_t sum_a_x = vaddvq_f32(sum_a_x_f32x4);
    nk_f32_t sum_a_y = vaddvq_f32(sum_a_y_f32x4);
    nk_f32_t sum_a_z = vaddvq_f32(sum_a_z_f32x4);
    nk_f32_t sum_b_x = vaddvq_f32(sum_b_x_f32x4);
    nk_f32_t sum_b_y = vaddvq_f32(sum_b_y_f32x4);
    nk_f32_t sum_b_z = vaddvq_f32(sum_b_z_f32x4);

    nk_f32_t covariance_x_x = vaddvq_f32(cov_xx_f32x4);
    nk_f32_t covariance_x_y = vaddvq_f32(cov_xy_f32x4);
    nk_f32_t covariance_x_z = vaddvq_f32(cov_xz_f32x4);
    nk_f32_t covariance_y_x = vaddvq_f32(cov_yx_f32x4);
    nk_f32_t covariance_y_y = vaddvq_f32(cov_yy_f32x4);
    nk_f32_t covariance_y_z = vaddvq_f32(cov_yz_f32x4);
    nk_f32_t covariance_z_x = vaddvq_f32(cov_zx_f32x4);
    nk_f32_t covariance_z_y = vaddvq_f32(cov_zy_f32x4);
    nk_f32_t covariance_z_z = vaddvq_f32(cov_zz_f32x4);

    // Compute centroids
    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = sum_a_x * inv_n;
    nk_f32_t centroid_a_y = sum_a_y * inv_n;
    nk_f32_t centroid_a_z = sum_a_z * inv_n;
    nk_f32_t centroid_b_x = sum_b_x * inv_n;
    nk_f32_t centroid_b_y = sum_b_y * inv_n;
    nk_f32_t centroid_b_z = sum_b_z * inv_n;

    if (a_centroid) {
        a_centroid[0] = centroid_a_x;
        a_centroid[1] = centroid_a_y;
        a_centroid[2] = centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = centroid_b_x;
        b_centroid[1] = centroid_b_y;
        b_centroid[2] = centroid_b_z;
    }

    // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
    covariance_x_x -= n * centroid_a_x * centroid_b_x;
    covariance_x_y -= n * centroid_a_x * centroid_b_y;
    covariance_x_z -= n * centroid_a_x * centroid_b_z;
    covariance_y_x -= n * centroid_a_y * centroid_b_x;
    covariance_y_y -= n * centroid_a_y * centroid_b_y;
    covariance_y_z -= n * centroid_a_y * centroid_b_z;
    covariance_z_x -= n * centroid_a_z * centroid_b_x;
    covariance_z_y -= n * centroid_a_z * centroid_b_y;
    covariance_z_z -= n * centroid_a_z * centroid_b_z;

    // Compute SVD and optimal rotation
    nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
                                    covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);

    // R = V * Uᵀ
    nk_f32_t r[9];
    r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
    r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
    r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
    r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
    r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
    r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
    r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
    r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
    r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];

    // Handle reflection: if det(R) < 0, negate third column of V and recompute R
    if (nk_det3x3_f32_(r) < 0) {
        svd_v[2] = -svd_v[2];
        svd_v[5] = -svd_v[5];
        svd_v[8] = -svd_v[8];
        r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
        r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
        r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
        r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
        r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
        r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
        r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
        r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
        r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
    }

    /* Output rotation matrix and scale=1.0 */
    if (rotation) {
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];
    }
    if (scale) *scale = 1.0f;

    // Compute RMSD after optimal rotation
    nk_f32_t sum_squared = nk_transformed_ssd_bf16_neonbfdot_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y,
                                                              centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f32_sqrt_neon(sum_squared * inv_n);
}
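
/*  The algebra above, for reference: with centroids c_a and c_b,
 *      H = Σᵢ (aᵢ - c_a)(bᵢ - c_b)ᵀ = Σᵢ aᵢbᵢᵀ - n·c_a·c_bᵀ,
 *  which lets the covariance be accumulated over raw points in a single pass.
 *  SVD gives H = U·S·Vᵀ and the optimal rotation R = V·Uᵀ; det(R) < 0 means the
 *  best orthogonal map is a reflection, corrected by negating V's third column.
 */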
|
|
582
|
+
|
|
583
|
+
NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
584
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
585
|
+
float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
|
|
586
|
+
|
|
587
|
+
/* 2x unrolling with dual accumulators to hide FMA latency. */
|
|
588
|
+
float32x4_t sum_a_x_a_f32x4 = zeros_f32x4, sum_a_y_a_f32x4 = zeros_f32x4, sum_a_z_a_f32x4 = zeros_f32x4;
|
|
589
|
+
float32x4_t sum_b_x_a_f32x4 = zeros_f32x4, sum_b_y_a_f32x4 = zeros_f32x4, sum_b_z_a_f32x4 = zeros_f32x4;
|
|
590
|
+
float32x4_t sum_a_x_b_f32x4 = zeros_f32x4, sum_a_y_b_f32x4 = zeros_f32x4, sum_a_z_b_f32x4 = zeros_f32x4;
|
|
591
|
+
float32x4_t sum_b_x_b_f32x4 = zeros_f32x4, sum_b_y_b_f32x4 = zeros_f32x4, sum_b_z_b_f32x4 = zeros_f32x4;
|
|
592
|
+
|
|
593
|
+
float32x4_t cov_xx_a_f32x4 = zeros_f32x4, cov_xy_a_f32x4 = zeros_f32x4, cov_xz_a_f32x4 = zeros_f32x4;
|
|
594
|
+
float32x4_t cov_yx_a_f32x4 = zeros_f32x4, cov_yy_a_f32x4 = zeros_f32x4, cov_yz_a_f32x4 = zeros_f32x4;
|
|
595
|
+
float32x4_t cov_zx_a_f32x4 = zeros_f32x4, cov_zy_a_f32x4 = zeros_f32x4, cov_zz_a_f32x4 = zeros_f32x4;
|
|
596
|
+
float32x4_t cov_xx_b_f32x4 = zeros_f32x4, cov_xy_b_f32x4 = zeros_f32x4, cov_xz_b_f32x4 = zeros_f32x4;
|
|
597
|
+
float32x4_t cov_yx_b_f32x4 = zeros_f32x4, cov_yy_b_f32x4 = zeros_f32x4, cov_yz_b_f32x4 = zeros_f32x4;
|
|
598
|
+
float32x4_t cov_zx_b_f32x4 = zeros_f32x4, cov_zy_b_f32x4 = zeros_f32x4, cov_zz_b_f32x4 = zeros_f32x4;
|
|
599
|
+
|
|
600
|
+
// Variance of A accumulators
|
|
601
|
+
float32x4_t variance_a_a_f32x4 = zeros_f32x4;
|
|
602
|
+
float32x4_t variance_a_b_f32x4 = zeros_f32x4;
|
|
603
|
+
|
|
604
|
+
nk_size_t i = 0;
|
|
605
|
+
float32x4_t a1_x_f32x4, a1_y_f32x4, a1_z_f32x4, b1_x_f32x4, b1_y_f32x4, b1_z_f32x4;
|
|
606
|
+
float32x4_t a2_x_f32x4, a2_y_f32x4, a2_z_f32x4, b2_x_f32x4, b2_y_f32x4, b2_z_f32x4;
|
|
607
|
+
|
|
608
|
+
// Main loop: 8 points per iteration (2x unrolled)
|
|
609
|
+
for (; i + 8 <= n; i += 8) {
|
|
610
|
+
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
|
|
611
|
+
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
|
|
612
|
+
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + (i + 4) * 3, &a2_x_f32x4, &a2_y_f32x4, &a2_z_f32x4);
|
|
613
|
+
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + (i + 4) * 3, &b2_x_f32x4, &b2_y_f32x4, &b2_z_f32x4);
|
|
614
|
+
|
|
615
|
+
// Interleaved accumulation to hide FMA latency
|
|
616
|
+
sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
|
|
617
|
+
sum_a_x_b_f32x4 = vaddq_f32(sum_a_x_b_f32x4, a2_x_f32x4);
|
|
618
|
+
sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
|
|
619
|
+
sum_a_y_b_f32x4 = vaddq_f32(sum_a_y_b_f32x4, a2_y_f32x4);
|
|
620
|
+
sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
|
|
621
|
+
sum_a_z_b_f32x4 = vaddq_f32(sum_a_z_b_f32x4, a2_z_f32x4);
|
|
622
|
+
sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
|
|
623
|
+
sum_b_x_b_f32x4 = vaddq_f32(sum_b_x_b_f32x4, b2_x_f32x4);
|
|
624
|
+
sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
|
|
625
|
+
sum_b_y_b_f32x4 = vaddq_f32(sum_b_y_b_f32x4, b2_y_f32x4);
|
|
626
|
+
sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
|
|
627
|
+
sum_b_z_b_f32x4 = vaddq_f32(sum_b_z_b_f32x4, b2_z_f32x4);
|
|
628
|
+
|
|
629
|
+
// Covariance matrix
|
|
630
|
+
cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
|
|
631
|
+
cov_xx_b_f32x4 = vfmaq_f32(cov_xx_b_f32x4, a2_x_f32x4, b2_x_f32x4);
|
|
632
|
+
cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
|
|
633
|
+
cov_xy_b_f32x4 = vfmaq_f32(cov_xy_b_f32x4, a2_x_f32x4, b2_y_f32x4);
|
|
634
|
+
cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
|
|
635
|
+
cov_xz_b_f32x4 = vfmaq_f32(cov_xz_b_f32x4, a2_x_f32x4, b2_z_f32x4);
|
|
636
|
+
cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
|
|
637
|
+
cov_yx_b_f32x4 = vfmaq_f32(cov_yx_b_f32x4, a2_y_f32x4, b2_x_f32x4);
|
|
638
|
+
cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
|
|
639
|
+
cov_yy_b_f32x4 = vfmaq_f32(cov_yy_b_f32x4, a2_y_f32x4, b2_y_f32x4);
|
|
640
|
+
cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
|
|
641
|
+
cov_yz_b_f32x4 = vfmaq_f32(cov_yz_b_f32x4, a2_y_f32x4, b2_z_f32x4);
|
|
642
|
+
cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
|
|
643
|
+
cov_zx_b_f32x4 = vfmaq_f32(cov_zx_b_f32x4, a2_z_f32x4, b2_x_f32x4);
|
|
644
|
+
cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
|
|
645
|
+
cov_zy_b_f32x4 = vfmaq_f32(cov_zy_b_f32x4, a2_z_f32x4, b2_y_f32x4);
|
|
646
|
+
cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
|
|
647
|
+
cov_zz_b_f32x4 = vfmaq_f32(cov_zz_b_f32x4, a2_z_f32x4, b2_z_f32x4);
|
|
648
|
+
|
|
649
|
+
// Variance of A
|
|
650
|
+
variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_x_f32x4, a1_x_f32x4);
|
|
651
|
+
variance_a_b_f32x4 = vfmaq_f32(variance_a_b_f32x4, a2_x_f32x4, a2_x_f32x4);
|
|
652
|
+
variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_y_f32x4, a1_y_f32x4);
|
|
653
|
+
variance_a_b_f32x4 = vfmaq_f32(variance_a_b_f32x4, a2_y_f32x4, a2_y_f32x4);
|
|
654
|
+
variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_z_f32x4, a1_z_f32x4);
|
|
655
|
+
variance_a_b_f32x4 = vfmaq_f32(variance_a_b_f32x4, a2_z_f32x4, a2_z_f32x4);
|
|
656
|
+
}

    // 4-point tail
    for (; i + 4 <= n; i += 4) {
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
        nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
        sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
        sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
        sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
        sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
        sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
        sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
        cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
        cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
        cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
        cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
        cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
        cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
        cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
        cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
        cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
        variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_x_f32x4, a1_x_f32x4);
        variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_y_f32x4, a1_y_f32x4);
        variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_z_f32x4, a1_z_f32x4);
    }

    // Partial tail: handle remaining 1-3 points with vectorized partial deinterleave
    if (i < n) {
        nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
        nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
        sum_a_x_a_f32x4 = vaddq_f32(sum_a_x_a_f32x4, a1_x_f32x4);
        sum_a_y_a_f32x4 = vaddq_f32(sum_a_y_a_f32x4, a1_y_f32x4);
        sum_a_z_a_f32x4 = vaddq_f32(sum_a_z_a_f32x4, a1_z_f32x4);
        sum_b_x_a_f32x4 = vaddq_f32(sum_b_x_a_f32x4, b1_x_f32x4);
        sum_b_y_a_f32x4 = vaddq_f32(sum_b_y_a_f32x4, b1_y_f32x4);
        sum_b_z_a_f32x4 = vaddq_f32(sum_b_z_a_f32x4, b1_z_f32x4);
        cov_xx_a_f32x4 = vfmaq_f32(cov_xx_a_f32x4, a1_x_f32x4, b1_x_f32x4);
        cov_xy_a_f32x4 = vfmaq_f32(cov_xy_a_f32x4, a1_x_f32x4, b1_y_f32x4);
        cov_xz_a_f32x4 = vfmaq_f32(cov_xz_a_f32x4, a1_x_f32x4, b1_z_f32x4);
        cov_yx_a_f32x4 = vfmaq_f32(cov_yx_a_f32x4, a1_y_f32x4, b1_x_f32x4);
        cov_yy_a_f32x4 = vfmaq_f32(cov_yy_a_f32x4, a1_y_f32x4, b1_y_f32x4);
        cov_yz_a_f32x4 = vfmaq_f32(cov_yz_a_f32x4, a1_y_f32x4, b1_z_f32x4);
        cov_zx_a_f32x4 = vfmaq_f32(cov_zx_a_f32x4, a1_z_f32x4, b1_x_f32x4);
        cov_zy_a_f32x4 = vfmaq_f32(cov_zy_a_f32x4, a1_z_f32x4, b1_y_f32x4);
        cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
        variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_x_f32x4, a1_x_f32x4);
        variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_y_f32x4, a1_y_f32x4);
        variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_z_f32x4, a1_z_f32x4);
    }
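    // Both tails fold into the _a accumulator chain only. Correctness of the
    // partial tail relies on the helper zero-filling the lanes past n - i, so
    // unused lanes contribute nothing to the sums, covariances, or variance.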

    // Combine dual accumulators
    float32x4_t sum_a_x_f32x4 = vaddq_f32(sum_a_x_a_f32x4, sum_a_x_b_f32x4);
    float32x4_t sum_a_y_f32x4 = vaddq_f32(sum_a_y_a_f32x4, sum_a_y_b_f32x4);
    float32x4_t sum_a_z_f32x4 = vaddq_f32(sum_a_z_a_f32x4, sum_a_z_b_f32x4);
    float32x4_t sum_b_x_f32x4 = vaddq_f32(sum_b_x_a_f32x4, sum_b_x_b_f32x4);
    float32x4_t sum_b_y_f32x4 = vaddq_f32(sum_b_y_a_f32x4, sum_b_y_b_f32x4);
    float32x4_t sum_b_z_f32x4 = vaddq_f32(sum_b_z_a_f32x4, sum_b_z_b_f32x4);
    float32x4_t cov_xx_f32x4 = vaddq_f32(cov_xx_a_f32x4, cov_xx_b_f32x4);
    float32x4_t cov_xy_f32x4 = vaddq_f32(cov_xy_a_f32x4, cov_xy_b_f32x4);
    float32x4_t cov_xz_f32x4 = vaddq_f32(cov_xz_a_f32x4, cov_xz_b_f32x4);
    float32x4_t cov_yx_f32x4 = vaddq_f32(cov_yx_a_f32x4, cov_yx_b_f32x4);
    float32x4_t cov_yy_f32x4 = vaddq_f32(cov_yy_a_f32x4, cov_yy_b_f32x4);
    float32x4_t cov_yz_f32x4 = vaddq_f32(cov_yz_a_f32x4, cov_yz_b_f32x4);
    float32x4_t cov_zx_f32x4 = vaddq_f32(cov_zx_a_f32x4, cov_zx_b_f32x4);
    float32x4_t cov_zy_f32x4 = vaddq_f32(cov_zy_a_f32x4, cov_zy_b_f32x4);
    float32x4_t cov_zz_f32x4 = vaddq_f32(cov_zz_a_f32x4, cov_zz_b_f32x4);
    float32x4_t variance_a_f32x4 = vaddq_f32(variance_a_a_f32x4, variance_a_b_f32x4);

    // Reduce vector accumulators
    nk_f32_t sum_a_x = vaddvq_f32(sum_a_x_f32x4);
    nk_f32_t sum_a_y = vaddvq_f32(sum_a_y_f32x4);
    nk_f32_t sum_a_z = vaddvq_f32(sum_a_z_f32x4);
    nk_f32_t sum_b_x = vaddvq_f32(sum_b_x_f32x4);
    nk_f32_t sum_b_y = vaddvq_f32(sum_b_y_f32x4);
    nk_f32_t sum_b_z = vaddvq_f32(sum_b_z_f32x4);

    nk_f32_t covariance_x_x = vaddvq_f32(cov_xx_f32x4);
    nk_f32_t covariance_x_y = vaddvq_f32(cov_xy_f32x4);
    nk_f32_t covariance_x_z = vaddvq_f32(cov_xz_f32x4);
    nk_f32_t covariance_y_x = vaddvq_f32(cov_yx_f32x4);
    nk_f32_t covariance_y_y = vaddvq_f32(cov_yy_f32x4);
    nk_f32_t covariance_y_z = vaddvq_f32(cov_yz_f32x4);
    nk_f32_t covariance_z_x = vaddvq_f32(cov_zx_f32x4);
    nk_f32_t covariance_z_y = vaddvq_f32(cov_zy_f32x4);
    nk_f32_t covariance_z_z = vaddvq_f32(cov_zz_f32x4);
    nk_f32_t variance_a_sum = vaddvq_f32(variance_a_f32x4);

    // Compute centroids
    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = sum_a_x * inv_n;
    nk_f32_t centroid_a_y = sum_a_y * inv_n;
    nk_f32_t centroid_a_z = sum_a_z * inv_n;
    nk_f32_t centroid_b_x = sum_b_x * inv_n;
    nk_f32_t centroid_b_y = sum_b_y * inv_n;
    nk_f32_t centroid_b_z = sum_b_z * inv_n;

    if (a_centroid) {
        a_centroid[0] = centroid_a_x;
        a_centroid[1] = centroid_a_y;
        a_centroid[2] = centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = centroid_b_x;
        b_centroid[1] = centroid_b_y;
        b_centroid[2] = centroid_b_z;
    }

    // Compute centered variance of A
    nk_f32_t variance_a = variance_a_sum * inv_n -
                          (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
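    // Single-pass shift identity: Σᵢ ‖aᵢ − ā‖² = Σᵢ ‖aᵢ‖² − n·‖ā‖², so the
    // centered variance falls out of the raw accumulator without a second pass.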

    // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
    covariance_x_x -= n * centroid_a_x * centroid_b_x;
    covariance_x_y -= n * centroid_a_x * centroid_b_y;
    covariance_x_z -= n * centroid_a_x * centroid_b_z;
    covariance_y_x -= n * centroid_a_y * centroid_b_x;
    covariance_y_y -= n * centroid_a_y * centroid_b_y;
    covariance_y_z -= n * centroid_a_y * centroid_b_z;
    covariance_z_x -= n * centroid_a_z * centroid_b_x;
    covariance_z_y -= n * centroid_a_z * centroid_b_y;
    covariance_z_z -= n * centroid_a_z * centroid_b_z;
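    // Same identity in matrix form: Σᵢ (aᵢ − ā)(bᵢ − b̄)ᵀ = Σᵢ aᵢbᵢᵀ − n·ā·b̄ᵀ,
    // so the raw products accumulated in the loops only need this rank-1 fixup.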

    // Compute SVD
    nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
                                    covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);

    // R = V * Uᵀ
    nk_f32_t r[9];
    r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
    r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
    r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
    r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
    r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
    r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
    r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
    r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
    r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
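    // Kabsch step: with H = U·S·Vᵀ, R = V·Uᵀ minimizes Σᵢ ‖R·(aᵢ − ā) − (bᵢ − b̄)‖²
    // over orthogonal R (up to the reflection fix below). The matrices are
    // row-major, so (V·Uᵀ)[i][j] is row i of V dotted with row j of U, as unrolled above.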

    // Handle reflection and compute scale: c = trace(D × S) / variance(a)
    // D = diag(1, 1, det(R)), svd_s contains proper positive singular values on diagonal
    nk_f32_t rotation_det = nk_det3x3_f32_(r);
    nk_f32_t sign_det = rotation_det < 0 ? -1.0f : 1.0f;
    nk_f32_t trace_scaled_s = svd_s[0] + svd_s[4] + sign_det * svd_s[8];
    nk_f32_t c = trace_scaled_s / ((nk_f32_t)n * variance_a);
    if (scale) *scale = c;
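    // Umeyama-style scale: c = trace(D·S) / Σᵢ ‖aᵢ − ā‖²; variance_a was already
    // divided by n, so multiplying it back by n restores the raw sum of squares.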

    if (rotation_det < 0) {
        svd_v[2] = -svd_v[2];
        svd_v[5] = -svd_v[5];
        svd_v[8] = -svd_v[8];
        r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
        r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
        r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
        r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
        r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
        r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
        r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
        r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
        r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
    }
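    // Negating the last column of V (indices 2, 5, 8 in row-major order) applies
    // D = diag(1, 1, -1), turning V·Uᵀ from a reflection into a proper rotation
    // with det(R) = +1, matching the sign already folded into trace_scaled_s.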

    /* Output rotation matrix */
    if (rotation) {
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];
    }

    // Compute RMSD after similarity transform: ‖c × R × a - b‖
    nk_f32_t sum_squared = nk_transformed_ssd_bf16_neonbfdot_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
                                                              centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f32_sqrt_neon(sum_squared * inv_n);
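    // RMSD = √( (1/n) · Σᵢ ‖c·R·(aᵢ − ā) + b̄ − bᵢ‖² ); the SSD helper is assumed
    // to re-read the raw bf16 points and apply the full similarity transform,
    // since it receives both centroids alongside r and c.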
}

#if defined(__clang__)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif

#if defined(__cplusplus)
} // extern "C"
#endif

#endif // NK_TARGET_NEONBFDOT
#endif // NK_TARGET_ARM_
#endif // NK_MESH_NEONBFDOT_H