numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,16 +8,15 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section mesh_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
*
|
|
13
|
-
*
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
*
|
|
20
|
-
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
|
|
11
|
+
* Intrinsic Instruction A76 M5
|
|
12
|
+
* vld3_u16 LD3 (V.4H x 3) 4cy @ 1p 4cy @ 1p
|
|
13
|
+
* vshll_n_u16 USHLL (V.4S, V.4H, #16) 2cy @ 2p 2cy @ 4p
|
|
14
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
|
|
15
|
+
* vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
|
|
16
|
+
* vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
|
|
17
|
+
* vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
18
|
+
* vdupq_n_f32 DUP (V.4S, scalar) 2cy @ 2p 2cy @ 4p
|
|
19
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 5cy @ 1p 8cy @ 1p
|
|
21
20
|
*
|
|
22
21
|
* The ARMv8.6-BF16 extension enables BF16 storage with F32 computation for 3D mesh alignment
|
|
23
22
|
* operations. BF16's wider exponent range (matching F32) prevents overflow in geometric calculations
|
|
@@ -57,14 +56,14 @@ extern "C" {
|
|
|
57
56
|
NK_INTERNAL void nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(nk_bf16_t const *ptr, float32x4_t *x_out,
|
|
58
57
|
float32x4_t *y_out, float32x4_t *z_out) {
|
|
59
58
|
// Load 12 bf16 values and de-interleave into x, y, z components
|
|
60
|
-
uint16x4x3_t
|
|
59
|
+
uint16x4x3_t xyz_u16x4x3 = vld3_u16((nk_u16_t const *)ptr);
|
|
61
60
|
// Convert bf16 to f32 by zero-extending to lower 16 bits, then shifting left by 16
|
|
62
|
-
uint32x4_t
|
|
63
|
-
uint32x4_t
|
|
64
|
-
uint32x4_t
|
|
65
|
-
*x_out = vreinterpretq_f32_u32(
|
|
66
|
-
*y_out = vreinterpretq_f32_u32(
|
|
67
|
-
*z_out = vreinterpretq_f32_u32(
|
|
61
|
+
uint32x4_t x_u32x4 = vshll_n_u16(xyz_u16x4x3.val[0], 16);
|
|
62
|
+
uint32x4_t y_u32x4 = vshll_n_u16(xyz_u16x4x3.val[1], 16);
|
|
63
|
+
uint32x4_t z_u32x4 = vshll_n_u16(xyz_u16x4x3.val[2], 16);
|
|
64
|
+
*x_out = vreinterpretq_f32_u32(x_u32x4);
|
|
65
|
+
*y_out = vreinterpretq_f32_u32(y_u32x4);
|
|
66
|
+
*z_out = vreinterpretq_f32_u32(z_u32x4);
|
|
68
67
|
}
|
|
69
68
|
|
|
70
69
|
NK_INTERNAL void nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(nk_bf16_t const *ptr, nk_size_t n_points,
|
|
@@ -216,8 +215,9 @@ NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_neonbfdot_(nk_bf16_t const *a, nk_b
|
|
|
216
215
|
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + j * 3, n - j, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
|
|
217
216
|
|
|
218
217
|
// Mask invalid lanes to zero BEFORE centering
|
|
219
|
-
uint32x4_t lane_u32x4 =
|
|
220
|
-
|
|
218
|
+
uint32x4_t lane_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
|
|
219
|
+
vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
|
|
220
|
+
uint32x4_t valid_u32x4 = vcltq_u32(lane_u32x4, vdupq_n_u32((nk_u32_t)(n - j)));
|
|
221
221
|
float32x4_t zero_f32x4 = vdupq_n_f32(0);
|
|
222
222
|
a_x_f32x4 = vbslq_f32(valid_u32x4, a_x_f32x4, zero_f32x4);
|
|
223
223
|
a_y_f32x4 = vbslq_f32(valid_u32x4, a_y_f32x4, zero_f32x4);
|
|
@@ -262,12 +262,10 @@ NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_neonbfdot_(nk_bf16_t const *a, nk_b
|
|
|
262
262
|
|
|
263
263
|
NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
264
264
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
265
|
-
|
|
266
|
-
if (rotation)
|
|
267
|
-
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0
|
|
268
|
-
rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
|
|
265
|
+
// RMSD uses identity rotation and scale=1.0
|
|
266
|
+
if (rotation)
|
|
267
|
+
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
269
268
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
270
|
-
}
|
|
271
269
|
if (scale) *scale = 1.0f;
|
|
272
270
|
|
|
273
271
|
float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
|
|
@@ -343,16 +341,8 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
343
341
|
nk_f32_t centroid_b_y = total_by * inv_n;
|
|
344
342
|
nk_f32_t centroid_b_z = total_bz * inv_n;
|
|
345
343
|
|
|
346
|
-
if (a_centroid)
|
|
347
|
-
|
|
348
|
-
a_centroid[1] = centroid_a_y;
|
|
349
|
-
a_centroid[2] = centroid_a_z;
|
|
350
|
-
}
|
|
351
|
-
if (b_centroid) {
|
|
352
|
-
b_centroid[0] = centroid_b_x;
|
|
353
|
-
b_centroid[1] = centroid_b_y;
|
|
354
|
-
b_centroid[2] = centroid_b_z;
|
|
355
|
-
}
|
|
344
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
345
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
356
346
|
|
|
357
347
|
// Compute RMSD
|
|
358
348
|
nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
@@ -368,7 +358,7 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
368
358
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
369
359
|
float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
|
|
370
360
|
|
|
371
|
-
|
|
361
|
+
// 2x unrolling with dual accumulators to hide FMA latency.
|
|
372
362
|
float32x4_t sum_a_x_a_f32x4 = zeros_f32x4, sum_a_y_a_f32x4 = zeros_f32x4, sum_a_z_a_f32x4 = zeros_f32x4;
|
|
373
363
|
float32x4_t sum_b_x_a_f32x4 = zeros_f32x4, sum_b_y_a_f32x4 = zeros_f32x4, sum_b_z_a_f32x4 = zeros_f32x4;
|
|
374
364
|
float32x4_t sum_a_x_b_f32x4 = zeros_f32x4, sum_a_y_b_f32x4 = zeros_f32x4, sum_a_z_b_f32x4 = zeros_f32x4;
|
|
@@ -512,16 +502,8 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
512
502
|
nk_f32_t centroid_b_y = sum_b_y * inv_n;
|
|
513
503
|
nk_f32_t centroid_b_z = sum_b_z * inv_n;
|
|
514
504
|
|
|
515
|
-
if (a_centroid)
|
|
516
|
-
|
|
517
|
-
a_centroid[1] = centroid_a_y;
|
|
518
|
-
a_centroid[2] = centroid_a_z;
|
|
519
|
-
}
|
|
520
|
-
if (b_centroid) {
|
|
521
|
-
b_centroid[0] = centroid_b_x;
|
|
522
|
-
b_centroid[1] = centroid_b_y;
|
|
523
|
-
b_centroid[2] = centroid_b_z;
|
|
524
|
-
}
|
|
505
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
506
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
525
507
|
|
|
526
508
|
// Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
|
|
527
509
|
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
@@ -554,9 +536,7 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
554
536
|
|
|
555
537
|
// Handle reflection: if det(R) < 0, negate third column of V and recompute R
|
|
556
538
|
if (nk_det3x3_f32_(r) < 0) {
|
|
557
|
-
svd_v[2] = -svd_v[2];
|
|
558
|
-
svd_v[5] = -svd_v[5];
|
|
559
|
-
svd_v[8] = -svd_v[8];
|
|
539
|
+
svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
|
|
560
540
|
r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
561
541
|
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
562
542
|
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
@@ -568,10 +548,9 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
568
548
|
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
569
549
|
}
|
|
570
550
|
|
|
571
|
-
|
|
572
|
-
if (rotation)
|
|
551
|
+
// Output rotation matrix and scale=1.0
|
|
552
|
+
if (rotation)
|
|
573
553
|
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
574
|
-
}
|
|
575
554
|
if (scale) *scale = 1.0f;
|
|
576
555
|
|
|
577
556
|
// Compute RMSD after optimal rotation
|
|
@@ -584,7 +563,7 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
584
563
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
585
564
|
float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
|
|
586
565
|
|
|
587
|
-
|
|
566
|
+
// 2x unrolling with dual accumulators to hide FMA latency.
|
|
588
567
|
float32x4_t sum_a_x_a_f32x4 = zeros_f32x4, sum_a_y_a_f32x4 = zeros_f32x4, sum_a_z_a_f32x4 = zeros_f32x4;
|
|
589
568
|
float32x4_t sum_b_x_a_f32x4 = zeros_f32x4, sum_b_y_a_f32x4 = zeros_f32x4, sum_b_z_a_f32x4 = zeros_f32x4;
|
|
590
569
|
float32x4_t sum_a_x_b_f32x4 = zeros_f32x4, sum_a_y_b_f32x4 = zeros_f32x4, sum_a_z_b_f32x4 = zeros_f32x4;
|
|
@@ -749,16 +728,8 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
749
728
|
nk_f32_t centroid_b_y = sum_b_y * inv_n;
|
|
750
729
|
nk_f32_t centroid_b_z = sum_b_z * inv_n;
|
|
751
730
|
|
|
752
|
-
if (a_centroid)
|
|
753
|
-
|
|
754
|
-
a_centroid[1] = centroid_a_y;
|
|
755
|
-
a_centroid[2] = centroid_a_z;
|
|
756
|
-
}
|
|
757
|
-
if (b_centroid) {
|
|
758
|
-
b_centroid[0] = centroid_b_x;
|
|
759
|
-
b_centroid[1] = centroid_b_y;
|
|
760
|
-
b_centroid[2] = centroid_b_z;
|
|
761
|
-
}
|
|
731
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
732
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
762
733
|
|
|
763
734
|
// Compute centered variance of A
|
|
764
735
|
nk_f32_t variance_a = variance_a_sum * inv_n -
|
|
@@ -802,9 +773,7 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
802
773
|
if (scale) *scale = c;
|
|
803
774
|
|
|
804
775
|
if (rotation_det < 0) {
|
|
805
|
-
svd_v[2] = -svd_v[2];
|
|
806
|
-
svd_v[5] = -svd_v[5];
|
|
807
|
-
svd_v[8] = -svd_v[8];
|
|
776
|
+
svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
|
|
808
777
|
r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
809
778
|
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
810
779
|
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
@@ -816,10 +785,9 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
816
785
|
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
817
786
|
}
|
|
818
787
|
|
|
819
|
-
|
|
820
|
-
if (rotation)
|
|
788
|
+
// Output rotation matrix
|
|
789
|
+
if (rotation)
|
|
821
790
|
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
822
|
-
}
|
|
823
791
|
|
|
824
792
|
// Compute RMSD after similarity transform: ‖c × R × a - b‖
|
|
825
793
|
nk_f32_t sum_squared = nk_transformed_ssd_bf16_neonbfdot_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
|