numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
@@ -8,11 +8,11 @@
  *
  * @section skylake_mesh_instructions Key AVX-512 Mesh Instructions
  *
- * Intrinsic
- * _mm512_fmadd_ps
- * _mm512_permutexvar_ps
- * _mm512_permutex2var_ps
- * _mm512_extractf32x8_ps
+ * Intrinsic               Instruction                     Skylake-X    Genoa
+ * _mm512_fmadd_ps         VFMADD132PS (ZMM, ZMM, ZMM)     4cy @ p05    4cy @ p01
+ * _mm512_permutexvar_ps   VPERMPS (ZMM, ZMM, ZMM)         3cy @ p5     4cy @ p12
+ * _mm512_permutex2var_ps  VPERMT2PS (ZMM, ZMM, ZMM)       3cy @ p5     4cy @ p12
+ * _mm512_extractf32x8_ps  VEXTRACTF32X8 (YMM, ZMM, I8)    3cy @ p5     1cy @ p0123
  *
  * Point cloud operations use VPERMT2PS for stride-3 deinterleaving of xyz coordinates, avoiding
  * expensive gather instructions. This achieves ~1.8x speedup over scalar deinterleaving. Dual FMA
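Editor's note: the stride-3 deinterleave referenced above takes points stored as x0,y0,z0,x1,y1,z1,… and splits them into three planar arrays; the VPERMT2PS-based kernels later in this file do exactly that, 16 points per step. A minimal scalar reference of the same transform (illustrative sketch only, not code from the package; the function name is made up):

#include <stddef.h>

/* Scalar stride-3 deinterleave: the baseline that the AVX-512 permute path
 * reportedly beats by ~1.8x. Reads 3*count floats, writes three planar arrays. */
static void deinterleave_xyz_scalar(float const *points, size_t count,
                                    float *xs, float *ys, float *zs) {
    for (size_t i = 0; i < count; ++i) {
        xs[i] = points[i * 3 + 0];
        ys[i] = points[i * 3 + 1];
        zs[i] = points[i * 3 + 2];
    }
}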
@@ -28,6 +28,7 @@
 #include "numkong/dot/skylake.h"
 #include "numkong/mesh/serial.h"
 #include "numkong/spatial/haswell.h"
+#include "numkong/cast/skylake.h"
 
 #if defined(__cplusplus)
 extern "C" {
@@ -112,6 +113,115 @@ NK_INTERNAL void nk_deinterleave_f64x8_skylake_(
     *z_f64x8_out = _mm512_permutex2var_pd(z01_f64x8, idx_z_2_i64x8, reg2_f64x8);
 }
 
+/* Deinterleave 16 f16 3D points from xyz,xyz,xyz... to separate x,y,z vectors in f32.
+ * Input: 48 consecutive f16 values (16 points * 3 coordinates)
+ * Output: Three __m512 vectors containing the x, y, z coordinates separately (as f32).
+ */
+NK_INTERNAL void nk_deinterleave_f16x16_to_f32x16_skylake_( //
+    nk_f16_t const *ptr, __m512 *x_f32x16_out, __m512 *y_f32x16_out, __m512 *z_f32x16_out) { //
+    __m512 reg0_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)(ptr)));
+    __m512 reg1_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)(ptr + 16)));
+    __m512 reg2_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)(ptr + 32)));
+
+    __m512i idx_x_01_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
+    __m512i idx_x_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 23, 26, 29);
+    __m512 x01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_x_01_i32x16, reg1_f32x16);
+    *x_f32x16_out = _mm512_permutex2var_ps(x01_f32x16, idx_x_2_i32x16, reg2_f32x16);
+
+    __m512i idx_y_01_i32x16 = _mm512_setr_epi32(1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 0, 0, 0, 0, 0);
+    __m512i idx_y_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 18, 21, 24, 27, 30);
+    __m512 y01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_y_01_i32x16, reg1_f32x16);
+    *y_f32x16_out = _mm512_permutex2var_ps(y01_f32x16, idx_y_2_i32x16, reg2_f32x16);
+
+    __m512i idx_z_01_i32x16 = _mm512_setr_epi32(2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
+    __m512i idx_z_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 19, 22, 25, 28, 31);
+    __m512 z01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_z_01_i32x16, reg1_f32x16);
+    *z_f32x16_out = _mm512_permutex2var_ps(z01_f32x16, idx_z_2_i32x16, reg2_f32x16);
+}
+
+/* Deinterleave 16 bf16 3D points from xyz,xyz,xyz... to separate x,y,z vectors in f32.
+ * Input: 48 consecutive bf16 values (16 points * 3 coordinates)
+ * Output: Three __m512 vectors containing the x, y, z coordinates separately (as f32).
+ */
+NK_INTERNAL void nk_deinterleave_bf16x16_to_f32x16_skylake_( //
+    nk_bf16_t const *ptr, __m512 *x_f32x16_out, __m512 *y_f32x16_out, __m512 *z_f32x16_out) { //
+    __m512 reg0_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_loadu_si256((__m256i const *)(ptr)));
+    __m512 reg1_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_loadu_si256((__m256i const *)(ptr + 16)));
+    __m512 reg2_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_loadu_si256((__m256i const *)(ptr + 32)));
+
+    __m512i idx_x_01_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
+    __m512i idx_x_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 23, 26, 29);
+    __m512 x01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_x_01_i32x16, reg1_f32x16);
+    *x_f32x16_out = _mm512_permutex2var_ps(x01_f32x16, idx_x_2_i32x16, reg2_f32x16);
+
+    __m512i idx_y_01_i32x16 = _mm512_setr_epi32(1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 0, 0, 0, 0, 0);
+    __m512i idx_y_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 18, 21, 24, 27, 30);
+    __m512 y01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_y_01_i32x16, reg1_f32x16);
+    *y_f32x16_out = _mm512_permutex2var_ps(y01_f32x16, idx_y_2_i32x16, reg2_f32x16);
+
+    __m512i idx_z_01_i32x16 = _mm512_setr_epi32(2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
+    __m512i idx_z_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 19, 22, 25, 28, 31);
+    __m512 z01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_z_01_i32x16, reg1_f32x16);
+    *z_f32x16_out = _mm512_permutex2var_ps(z01_f32x16, idx_z_2_i32x16, reg2_f32x16);
+}
+
+/* Masked-tail deinterleave for f16: loads up to 16 xyz points using AVX-512 masked loads,
+ * converts f16→f32, and deinterleaves into separate x,y,z vectors.
+ * Unused lanes are zero. Uses the same permutex2var shuffle as the full-width version.
+ */
+NK_INTERNAL void nk_deinterleave_f16_tail_to_f32x16_skylake_( //
+    nk_f16_t const *ptr, nk_size_t count, __m512 *x_f32x16_out, __m512 *y_f32x16_out, __m512 *z_f32x16_out) { //
+    nk_size_t total = count * 3;
+    __mmask16 mask0_i16x16 = (__mmask16)_bzhi_u32(0xFFFF, total >= 16 ? 16 : total);
+    __mmask16 mask1_i16x16 = total > 16 ? (__mmask16)_bzhi_u32(0xFFFF, total >= 32 ? 16 : total - 16) : 0;
+    __mmask16 mask2_i16x16 = total > 32 ? (__mmask16)_bzhi_u32(0xFFFF, total - 32) : 0;
+    __m512 reg0_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask0_i16x16, ptr));
+    __m512 reg1_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask1_i16x16, ptr + 16));
+    __m512 reg2_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask2_i16x16, ptr + 32));
+
+    __m512i idx_x_01_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
+    __m512i idx_x_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 23, 26, 29);
+    *x_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_x_01_i32x16, reg1_f32x16),
+                                           idx_x_2_i32x16, reg2_f32x16);
+
+    __m512i idx_y_01_i32x16 = _mm512_setr_epi32(1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 0, 0, 0, 0, 0);
+    __m512i idx_y_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 18, 21, 24, 27, 30);
+    *y_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_y_01_i32x16, reg1_f32x16),
+                                           idx_y_2_i32x16, reg2_f32x16);
+
+    __m512i idx_z_01_i32x16 = _mm512_setr_epi32(2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
+    __m512i idx_z_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 19, 22, 25, 28, 31);
+    *z_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_z_01_i32x16, reg1_f32x16),
+                                           idx_z_2_i32x16, reg2_f32x16);
+}
+
+/* Masked-tail deinterleave for bf16: same as f16 but with bf16→f32 conversion. */
+NK_INTERNAL void nk_deinterleave_bf16_tail_to_f32x16_skylake_( //
+    nk_bf16_t const *ptr, nk_size_t count, __m512 *x_f32x16_out, __m512 *y_f32x16_out, __m512 *z_f32x16_out) { //
+    nk_size_t total = count * 3;
+    __mmask16 mask0_i16x16 = (__mmask16)_bzhi_u32(0xFFFF, total >= 16 ? 16 : total);
+    __mmask16 mask1_i16x16 = total > 16 ? (__mmask16)_bzhi_u32(0xFFFF, total >= 32 ? 16 : total - 16) : 0;
+    __mmask16 mask2_i16x16 = total > 32 ? (__mmask16)_bzhi_u32(0xFFFF, total - 32) : 0;
+    __m512 reg0_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask0_i16x16, ptr));
+    __m512 reg1_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask1_i16x16, ptr + 16));
+    __m512 reg2_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask2_i16x16, ptr + 32));
+
+    __m512i idx_x_01_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
+    __m512i idx_x_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 23, 26, 29);
+    *x_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_x_01_i32x16, reg1_f32x16),
+                                           idx_x_2_i32x16, reg2_f32x16);
+
+    __m512i idx_y_01_i32x16 = _mm512_setr_epi32(1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 0, 0, 0, 0, 0);
+    __m512i idx_y_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 18, 21, 24, 27, 30);
+    *y_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_y_01_i32x16, reg1_f32x16),
+                                           idx_y_2_i32x16, reg2_f32x16);
+
+    __m512i idx_z_01_i32x16 = _mm512_setr_epi32(2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
+    __m512i idx_z_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 19, 22, 25, 28, 31);
+    *z_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_z_01_i32x16, reg1_f32x16),
+                                           idx_z_2_i32x16, reg2_f32x16);
+}
+
 NK_INTERNAL nk_f64_t nk_reduce_stable_f64x8_skylake_(__m512d values_f64x8) {
     nk_b512_vec_t values;
     values.zmm_pd = values_f64x8;
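Editor's note: the masked-tail helpers above cover a remainder of `count` points (`count * 3` half-precision lanes) with three 16-lane masked loads, where `_bzhi_u32(0xFFFF, k)` keeps the low `k` bits. The standalone model below (plain C, no intrinsics; `low_bits` merely stands in for `_bzhi_u32`) prints which lanes each of the three masks enables — an illustrative sketch, not package code:

#include <stdint.h>
#include <stdio.h>

/* Keep the k lowest bits of a 16-bit mask, like _bzhi_u32(0xFFFF, k) does. */
static uint16_t low_bits(uint32_t k) { return (uint16_t)(k >= 16 ? 0xFFFFu : ((1u << k) - 1u)); }

int main(void) {
    for (uint32_t count = 1; count <= 16; ++count) {       /* tail of 1..16 xyz points */
        uint32_t total = count * 3;                         /* lanes holding real data  */
        uint16_t mask0 = low_bits(total >= 16 ? 16 : total);
        uint16_t mask1 = total > 16 ? low_bits(total >= 32 ? 16 : total - 16) : 0;
        uint16_t mask2 = total > 32 ? low_bits(total - 32) : 0;
        printf("count=%2u lanes=%2u masks=%04X %04X %04X\n", count, total, mask0, mask1, mask2);
    }
    return 0;
}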
@@ -166,84 +276,84 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_skylake_(nk_f32_t const *a, nk_f32_t
     for (; index + 16 <= n; index += 16) {
         nk_deinterleave_f32x16_skylake_(a + index * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16),
             nk_deinterleave_f32x16_skylake_(b + index * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
-        … (64 removed lines, clipped in this diff view)
+        __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
+        __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
+        __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
+        __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
+        __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
+        __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
+        __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
+        __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
+        __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
+        __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
+        __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
+        __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
+
+        __m512d centered_a_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, centroid_a_x_f64x8);
+        __m512d centered_a_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, centroid_a_x_f64x8);
+        __m512d centered_a_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, centroid_a_y_f64x8);
+        __m512d centered_a_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, centroid_a_y_f64x8);
+        __m512d centered_a_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, centroid_a_z_f64x8);
+        __m512d centered_a_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, centroid_a_z_f64x8);
+        __m512d centered_b_x_low_f64x8 = _mm512_sub_pd(b_x_low_f64x8, centroid_b_x_f64x8);
+        __m512d centered_b_x_high_f64x8 = _mm512_sub_pd(b_x_high_f64x8, centroid_b_x_f64x8);
+        __m512d centered_b_y_low_f64x8 = _mm512_sub_pd(b_y_low_f64x8, centroid_b_y_f64x8);
+        __m512d centered_b_y_high_f64x8 = _mm512_sub_pd(b_y_high_f64x8, centroid_b_y_f64x8);
+        __m512d centered_b_z_low_f64x8 = _mm512_sub_pd(b_z_low_f64x8, centroid_b_z_f64x8);
+        __m512d centered_b_z_high_f64x8 = _mm512_sub_pd(b_z_high_f64x8, centroid_b_z_f64x8);
+
+        __m512d rotated_a_x_low_f64x8 = _mm512_fmadd_pd(
+            scaled_rotation_x_z_f64x8, centered_a_z_low_f64x8,
+            _mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_low_f64x8,
+                            _mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_low_f64x8)));
+        __m512d rotated_a_x_high_f64x8 = _mm512_fmadd_pd(
+            scaled_rotation_x_z_f64x8, centered_a_z_high_f64x8,
+            _mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_high_f64x8,
+                            _mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_high_f64x8)));
+        __m512d rotated_a_y_low_f64x8 = _mm512_fmadd_pd(
+            scaled_rotation_y_z_f64x8, centered_a_z_low_f64x8,
+            _mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_low_f64x8,
+                            _mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_low_f64x8)));
+        __m512d rotated_a_y_high_f64x8 = _mm512_fmadd_pd(
+            scaled_rotation_y_z_f64x8, centered_a_z_high_f64x8,
+            _mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_high_f64x8,
+                            _mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_high_f64x8)));
+        __m512d rotated_a_z_low_f64x8 = _mm512_fmadd_pd(
+            scaled_rotation_z_z_f64x8, centered_a_z_low_f64x8,
+            _mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_low_f64x8,
+                            _mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_low_f64x8)));
+        __m512d rotated_a_z_high_f64x8 = _mm512_fmadd_pd(
+            scaled_rotation_z_z_f64x8, centered_a_z_high_f64x8,
+            _mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_high_f64x8,
+                            _mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_high_f64x8)));
+
+        __m512d delta_x_low_f64x8 = _mm512_sub_pd(rotated_a_x_low_f64x8, centered_b_x_low_f64x8);
+        __m512d delta_x_high_f64x8 = _mm512_sub_pd(rotated_a_x_high_f64x8, centered_b_x_high_f64x8);
+        __m512d delta_y_low_f64x8 = _mm512_sub_pd(rotated_a_y_low_f64x8, centered_b_y_low_f64x8);
+        __m512d delta_y_high_f64x8 = _mm512_sub_pd(rotated_a_y_high_f64x8, centered_b_y_high_f64x8);
+        __m512d delta_z_low_f64x8 = _mm512_sub_pd(rotated_a_z_low_f64x8, centered_b_z_low_f64x8);
+        __m512d delta_z_high_f64x8 = _mm512_sub_pd(rotated_a_z_high_f64x8, centered_b_z_high_f64x8);
+
+        __m512d batch_sum_squared_f64x8 = _mm512_add_pd(_mm512_mul_pd(delta_x_low_f64x8, delta_x_low_f64x8),
+                                                        _mm512_mul_pd(delta_x_high_f64x8, delta_x_high_f64x8));
+        batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, batch_sum_squared_f64x8);
+        batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, batch_sum_squared_f64x8);
+        batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, batch_sum_squared_f64x8);
+        batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, batch_sum_squared_f64x8);
         sum_squared_f64x8 = _mm512_add_pd(sum_squared_f64x8, batch_sum_squared_f64x8);
     }
 
     nk_f64_t sum_squared = _mm512_reduce_add_pd(sum_squared_f64x8);
     for (; index < n; ++index) {
-        nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x
-        nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x
-        nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z)
-        … (6 more removed lines, clipped in this diff view)
+        nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x,
+                 centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y,
+                 centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
+        nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x,
+                 centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y,
+                 centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
+        nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z),
+                 rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z),
+                 rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
         nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
                  delta_z = rotated_a_z - centered_b_z;
         sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
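Editor's note: further down in this file, nk_rmsd_f32_skylake is rewritten as a fused single pass built on the shift identity Σ‖(aᵢ−ā)−(bᵢ−b̄)‖² = Σ‖aᵢ−bᵢ‖² − n·‖ā−b̄‖², so centroids and squared differences can be accumulated in the same loop. The scalar check below (1-D, arbitrary data) only verifies that identity; it is an illustrative sketch, not package code:

#include <math.h>
#include <stdio.h>

int main(void) {
    double a[] = {1.0, 2.5, -0.5, 3.0, 4.0, 1.5};
    double b[] = {0.5, 2.0, 0.0, 2.0, 4.5, 1.0};
    int n = 6;
    double sum_d = 0, sum_d2 = 0;
    for (int i = 0; i < n; ++i) {
        double d = a[i] - b[i];              /* per-point difference           */
        sum_d += d, sum_d2 += d * d;         /* single pass: mean and squares  */
    }
    double mean_d = sum_d / n;
    double two_pass = 0;                     /* reference: center, then square */
    for (int i = 0; i < n; ++i) {
        double d = (a[i] - b[i]) - mean_d;
        two_pass += d * d;
    }
    printf("fused=%.12f two-pass=%.12f\n", sum_d2 - n * mean_d * mean_d, two_pass);
    printf("rmsd=%.12f\n", sqrt(sum_d2 / n - mean_d * mean_d));
    return 0;
}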
@@ -322,20 +432,16 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_skylake_(nk_f64_t const *a, nk_f64_t
 
     // Scalar tail
     for (; j < n; ++j) {
-        nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x
-        … (10 more removed lines, clipped in this diff view)
-        nk_f64_t delta_x = ra_x - pb_x;
-        nk_f64_t delta_y = ra_y - pb_y;
-        nk_f64_t delta_z = ra_z - pb_z;
+        nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x, pa_y = a[j * 3 + 1] - centroid_a_y,
+                 pa_z = a[j * 3 + 2] - centroid_a_z;
+        nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x, pb_y = b[j * 3 + 1] - centroid_b_y,
+                 pb_z = b[j * 3 + 2] - centroid_b_z;
+
+        nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
+                 ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
+                 ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
+
+        nk_f64_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
         nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
         nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
         nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
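Editor's note: the scalar tail above funnels every squared component through nk_accumulate_square_f64_ together with a compensation term; that helper's body is outside this hunk. For orientation only, a Kahan-style compensated accumulator of that shape could look like the sketch below — a hypothetical illustration, not the actual numkong implementation:

/* Hypothetical compensated accumulator: adds value^2 to *sum while carrying the
 * rounding error in *compensation, so many small squares are not absorbed by a
 * large running sum (classic Kahan summation). */
static void accumulate_square_f64_sketch(double *sum, double *compensation, double value) {
    double squared = value * value;
    double adjusted = squared - *compensation;   /* re-inject previously lost bits   */
    double next = *sum + adjusted;               /* low-order bits may be lost here  */
    *compensation = (next - *sum) - adjusted;    /* capture what was lost this step  */
    *sum = next;
}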
@@ -344,139 +450,526 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_skylake_(nk_f64_t const *a, nk_f64_t
     return sum_squared + sum_squared_compensation;
 }
 
-… (12 removed lines, clipped in this diff view)
+/* Compute sum of squared distances for f16 data after applying rotation (and optional scale).
+ * Loads f16, converts to f32 for computation. Rotation matrix, scale, and centroids are f32.
+ */
+NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_skylake_(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
+                                                     nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
+                                                     nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
+                                                     nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
+                                                     nk_f32_t centroid_b_z) {
+    __m512 scaled_rotation_x_x_f32x16 = _mm512_set1_ps(scale * r[0]);
+    __m512 scaled_rotation_x_y_f32x16 = _mm512_set1_ps(scale * r[1]);
+    __m512 scaled_rotation_x_z_f32x16 = _mm512_set1_ps(scale * r[2]);
+    __m512 scaled_rotation_y_x_f32x16 = _mm512_set1_ps(scale * r[3]);
+    __m512 scaled_rotation_y_y_f32x16 = _mm512_set1_ps(scale * r[4]);
+    __m512 scaled_rotation_y_z_f32x16 = _mm512_set1_ps(scale * r[5]);
+    __m512 scaled_rotation_z_x_f32x16 = _mm512_set1_ps(scale * r[6]);
+    __m512 scaled_rotation_z_y_f32x16 = _mm512_set1_ps(scale * r[7]);
+    __m512 scaled_rotation_z_z_f32x16 = _mm512_set1_ps(scale * r[8]);
+
+    __m512 centroid_a_x_f32x16 = _mm512_set1_ps(centroid_a_x);
+    __m512 centroid_a_y_f32x16 = _mm512_set1_ps(centroid_a_y);
+    __m512 centroid_a_z_f32x16 = _mm512_set1_ps(centroid_a_z);
+    __m512 centroid_b_x_f32x16 = _mm512_set1_ps(centroid_b_x);
+    __m512 centroid_b_y_f32x16 = _mm512_set1_ps(centroid_b_y);
+    __m512 centroid_b_z_f32x16 = _mm512_set1_ps(centroid_b_z);
+
+    __m512 sum_squared_f32x16 = _mm512_setzero_ps();
     __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
-    nk_size_t
+    nk_size_t j = 0;
 
-    for (;
-    … (27 more removed lines, clipped in this diff view)
-        covariance_02_f64x8 = _mm512_add_pd(covariance_02_f64x8,
-                                            _mm512_add_pd(_mm512_mul_pd(a_x_lower_f64x8, b_z_lower_f64x8),
-                                                          _mm512_mul_pd(a_x_upper_f64x8, b_z_upper_f64x8)));
-        covariance_10_f64x8 = _mm512_add_pd(covariance_10_f64x8,
-                                            _mm512_add_pd(_mm512_mul_pd(a_y_lower_f64x8, b_x_lower_f64x8),
-                                                          _mm512_mul_pd(a_y_upper_f64x8, b_x_upper_f64x8))),
-        covariance_11_f64x8 = _mm512_add_pd(covariance_11_f64x8,
-                                            _mm512_add_pd(_mm512_mul_pd(a_y_lower_f64x8, b_y_lower_f64x8),
-                                                          _mm512_mul_pd(a_y_upper_f64x8, b_y_upper_f64x8))),
-        covariance_12_f64x8 = _mm512_add_pd(covariance_12_f64x8,
-                                            _mm512_add_pd(_mm512_mul_pd(a_y_lower_f64x8, b_z_lower_f64x8),
-                                                          _mm512_mul_pd(a_y_upper_f64x8, b_z_upper_f64x8)));
-        covariance_20_f64x8 = _mm512_add_pd(covariance_20_f64x8,
-                                            _mm512_add_pd(_mm512_mul_pd(a_z_lower_f64x8, b_x_lower_f64x8),
-                                                          _mm512_mul_pd(a_z_upper_f64x8, b_x_upper_f64x8))),
-        covariance_21_f64x8 = _mm512_add_pd(covariance_21_f64x8,
-                                            _mm512_add_pd(_mm512_mul_pd(a_z_lower_f64x8, b_y_lower_f64x8),
-                                                          _mm512_mul_pd(a_z_upper_f64x8, b_y_upper_f64x8))),
-        covariance_22_f64x8 = _mm512_add_pd(covariance_22_f64x8,
-                                            _mm512_add_pd(_mm512_mul_pd(a_z_lower_f64x8, b_z_lower_f64x8),
-                                                          _mm512_mul_pd(a_z_upper_f64x8, b_z_upper_f64x8)));
+    for (; j + 16 <= n; j += 16) {
+        nk_deinterleave_f16x16_to_f32x16_skylake_(a + j * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+        nk_deinterleave_f16x16_to_f32x16_skylake_(b + j * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+        __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
+        __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
+        __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
+        __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
+        __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
+        __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
+
+        __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
+        __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
+        __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
+
+        __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
+        __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
+        __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
+
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
     }
 
-    … (13 removed lines, clipped in this diff view)
+    // Tail: deinterleave remaining points into zero-initialized vectors
+    if (j < n) {
+        nk_size_t tail = n - j;
+        nk_deinterleave_f16_tail_to_f32x16_skylake_(a + j * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+        nk_deinterleave_f16_tail_to_f32x16_skylake_(b + j * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+        __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
+        __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
+        __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
+        __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
+        __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
+        __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
+
+        __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
+        __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
+        __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
+
+        __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
+        __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
+        __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
+
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
+    }
 
-    … (8 removed lines, clipped in this diff view)
+    return _mm512_reduce_add_ps(sum_squared_f32x16);
+}
+
+/* Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
+ * Loads bf16, converts to f32 for computation. Rotation matrix, scale, and centroids are f32.
+ */
+NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_skylake_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
+                                                      nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
+                                                      nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
+                                                      nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
+                                                      nk_f32_t centroid_b_z) {
+    __m512 scaled_rotation_x_x_f32x16 = _mm512_set1_ps(scale * r[0]);
+    __m512 scaled_rotation_x_y_f32x16 = _mm512_set1_ps(scale * r[1]);
+    __m512 scaled_rotation_x_z_f32x16 = _mm512_set1_ps(scale * r[2]);
+    __m512 scaled_rotation_y_x_f32x16 = _mm512_set1_ps(scale * r[3]);
+    __m512 scaled_rotation_y_y_f32x16 = _mm512_set1_ps(scale * r[4]);
+    __m512 scaled_rotation_y_z_f32x16 = _mm512_set1_ps(scale * r[5]);
+    __m512 scaled_rotation_z_x_f32x16 = _mm512_set1_ps(scale * r[6]);
+    __m512 scaled_rotation_z_y_f32x16 = _mm512_set1_ps(scale * r[7]);
+    __m512 scaled_rotation_z_z_f32x16 = _mm512_set1_ps(scale * r[8]);
+
+    __m512 centroid_a_x_f32x16 = _mm512_set1_ps(centroid_a_x);
+    __m512 centroid_a_y_f32x16 = _mm512_set1_ps(centroid_a_y);
+    __m512 centroid_a_z_f32x16 = _mm512_set1_ps(centroid_a_z);
+    __m512 centroid_b_x_f32x16 = _mm512_set1_ps(centroid_b_x);
+    __m512 centroid_b_y_f32x16 = _mm512_set1_ps(centroid_b_y);
+    __m512 centroid_b_z_f32x16 = _mm512_set1_ps(centroid_b_z);
+
+    __m512 sum_squared_f32x16 = _mm512_setzero_ps();
+    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+    nk_size_t j = 0;
+
+    for (; j + 16 <= n; j += 16) {
+        nk_deinterleave_bf16x16_to_f32x16_skylake_(a + j * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+        nk_deinterleave_bf16x16_to_f32x16_skylake_(b + j * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+        __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
+        __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
+        __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
+        __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
+        __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
+        __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
+
+        __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
+        __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
+        __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
+
+        __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
+        __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
+        __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
+
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
+    }
+
+    // Tail: deinterleave remaining points into zero-initialized vectors
+    if (j < n) {
+        nk_size_t tail = n - j;
+        nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + j * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+        nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + j * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+        __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
+        __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
+        __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
+        __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
+        __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
+        __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
+
+        __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
+        __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
+        __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
+                                             _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
+                                                             _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
+
+        __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
+        __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
+        __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
+
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
+        sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
     }
 
-
-    *centroid_a_x = sum_a_x * inv_n, *centroid_a_y = sum_a_y * inv_n, *centroid_a_z = sum_a_z * inv_n;
-    *centroid_b_x = sum_b_x * inv_n, *centroid_b_y = sum_b_y * inv_n, *centroid_b_z = sum_b_z * inv_n;
-    cross_covariance_f64[0] = covariance_00 - n_f64 * (*centroid_a_x) * (*centroid_b_x),
-    cross_covariance_f64[1] = covariance_01 - n_f64 * (*centroid_a_x) * (*centroid_b_y),
-    cross_covariance_f64[2] = covariance_02 - n_f64 * (*centroid_a_x) * (*centroid_b_z);
-    cross_covariance_f64[3] = covariance_10 - n_f64 * (*centroid_a_y) * (*centroid_b_x),
-    cross_covariance_f64[4] = covariance_11 - n_f64 * (*centroid_a_y) * (*centroid_b_y),
-    cross_covariance_f64[5] = covariance_12 - n_f64 * (*centroid_a_y) * (*centroid_b_z);
-    cross_covariance_f64[6] = covariance_20 - n_f64 * (*centroid_a_z) * (*centroid_b_x),
-    cross_covariance_f64[7] = covariance_21 - n_f64 * (*centroid_a_z) * (*centroid_b_y),
-    cross_covariance_f64[8] = covariance_22 - n_f64 * (*centroid_a_z) * (*centroid_b_z);
+    return _mm512_reduce_add_ps(sum_squared_f32x16);
 }
 
 NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                    nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
-    nk_f64_t identity[9] = {1, 0, 0, 0, 1, 0, 0, 0, 1};
-    nk_f64_t centroid_a_x, centroid_a_y, centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z;
-    nk_f64_t cross_covariance_f64[9];
     if (rotation)
         rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
             rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
     if (scale) *scale = 1.0f;
-
-
+
+    // Fused single-pass: centroids + squared differences in f64, using the identity:
+    // RMSD = √(E[(a-b)²] - (ā - b̄)²)
+    __m512d const zeros_f64x8 = _mm512_setzero_pd();
+    __m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
+    __m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
+    __m512d sum_squared_x_f64x8 = zeros_f64x8, sum_squared_y_f64x8 = zeros_f64x8, sum_squared_z_f64x8 = zeros_f64x8;
+    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+    nk_size_t i = 0;
+
+    // Main loop with 2x unrolling (32 points per iteration)
+    for (; i + 32 <= n; i += 32) {
+        // Iteration 0: points i..i+15
+        nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+        nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+        __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
+        __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
+        __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
+        __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
+        __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
+        __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
+        __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
+        __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
+        __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
+        __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
+        __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
+        __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
+
+        sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
+        sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
+        sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
+        sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
+        sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
+        sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
+
+        __m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
+        __m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
+        __m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
+        __m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
+        __m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
+        __m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
+        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
+        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
+        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
+        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
+        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
+        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
+
+        // Iteration 1: points i+16..i+31
+        nk_deinterleave_f32x16_skylake_(a + (i + 16) * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+        nk_deinterleave_f32x16_skylake_(b + (i + 16) * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+        a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
+        a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
+        a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
+        a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
+        a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
+        a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
+        b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
+        b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
+        b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
+        b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
+        b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
+        b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
+
+        sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
+        sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
+        sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
+        sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
+        sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
+        sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
+
+        delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
+        delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
+        delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
+        delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
+        delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
+        delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
+        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
+        sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
+        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
+        sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
+        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
+        sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
+    }
+
+    // Handle 16-point remainder
+    for (; i + 16 <= n; i += 16) {
+        nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+        nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+        __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
+        __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
+
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
741
|
+
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
742
|
+
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
743
|
+
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
744
|
+
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
745
|
+
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
746
|
+
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
747
|
+
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
748
|
+
|
|
749
|
+
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
750
|
+
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
751
|
+
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
752
|
+
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
753
|
+
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
754
|
+
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
755
|
+
|
|
756
|
+
__m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
757
|
+
__m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
758
|
+
__m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
759
|
+
__m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
|
|
760
|
+
__m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
|
|
761
|
+
__m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
|
|
762
|
+
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
|
|
763
|
+
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
|
|
764
|
+
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
|
|
765
|
+
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
|
|
766
|
+
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
|
|
767
|
+
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
|
|
768
|
+
}
|
|
769
|
+
|
|
770
|
+
// Tail: use masked gather for remaining < 16 points
|
|
771
|
+
if (i < n) {
|
|
772
|
+
nk_size_t tail = n - i;
|
|
773
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
|
|
774
|
+
__m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
|
|
775
|
+
__m512 zeros_f32x16 = _mm512_setzero_ps();
|
|
776
|
+
nk_f32_t const *a_tail = a + i * 3;
|
|
777
|
+
nk_f32_t const *b_tail = b + i * 3;
|
|
778
|
+
|
|
779
|
+
a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
|
|
780
|
+
a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
|
|
781
|
+
a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
|
|
782
|
+
b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
|
|
783
|
+
b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
|
|
784
|
+
b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
|
|
785
|
+
|
|
786
|
+
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
787
|
+
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
788
|
+
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
789
|
+
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
790
|
+
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
791
|
+
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
792
|
+
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
793
|
+
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
794
|
+
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
795
|
+
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
796
|
+
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
797
|
+
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
798
|
+
|
|
799
|
+
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
800
|
+
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
801
|
+
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
802
|
+
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
803
|
+
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
804
|
+
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
805
|
+
|
|
806
|
+
__m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
807
|
+
__m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
808
|
+
__m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
809
|
+
__m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
|
|
810
|
+
__m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
|
|
811
|
+
__m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
|
|
812
|
+
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
|
|
813
|
+
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
|
|
814
|
+
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
|
|
815
|
+
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
|
|
816
|
+
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
|
|
817
|
+
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
|
|
818
|
+
}
|
|
819
|
+
|
|
820
|
+
// Reduce and compute centroids
|
|
821
|
+
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
822
|
+
nk_f64_t total_ax = _mm512_reduce_add_pd(sum_a_x_f64x8);
|
|
823
|
+
nk_f64_t total_ay = _mm512_reduce_add_pd(sum_a_y_f64x8);
|
|
824
|
+
nk_f64_t total_az = _mm512_reduce_add_pd(sum_a_z_f64x8);
|
|
825
|
+
nk_f64_t total_bx = _mm512_reduce_add_pd(sum_b_x_f64x8);
|
|
826
|
+
nk_f64_t total_by = _mm512_reduce_add_pd(sum_b_y_f64x8);
|
|
827
|
+
nk_f64_t total_bz = _mm512_reduce_add_pd(sum_b_z_f64x8);
|
|
828
|
+
nk_f64_t total_sq_x = _mm512_reduce_add_pd(sum_squared_x_f64x8);
|
|
829
|
+
nk_f64_t total_sq_y = _mm512_reduce_add_pd(sum_squared_y_f64x8);
|
|
830
|
+
nk_f64_t total_sq_z = _mm512_reduce_add_pd(sum_squared_z_f64x8);
|
|
831
|
+
|
|
832
|
+
nk_f64_t centroid_a_x = total_ax * inv_n, centroid_a_y = total_ay * inv_n, centroid_a_z = total_az * inv_n;
|
|
833
|
+
nk_f64_t centroid_b_x = total_bx * inv_n, centroid_b_y = total_by * inv_n, centroid_b_z = total_bz * inv_n;
|
|
462
834
|
if (a_centroid)
|
|
463
835
|
a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
|
|
464
836
|
a_centroid[2] = (nk_f32_t)centroid_a_z;
|
|
465
837
|
if (b_centroid)
|
|
466
838
|
b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
|
|
467
839
|
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
840
|
+
|
|
841
|
+
nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x, mean_diff_y = centroid_a_y - centroid_b_y,
|
|
842
|
+
mean_diff_z = centroid_a_z - centroid_b_z;
|
|
843
|
+
nk_f64_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
|
|
844
|
+
nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
845
|
+
*result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
|
|
472
846
|
}
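The `nk_rmsd_f32_skylake` hunk above never revisits the points once the centroids are known: it relies on the identity that the sum of squared deviations between the two centered point sets equals the sum of squared raw deltas minus n times the squared distance between the centroids, which is why the closing statement evaluates `sum_squared * inv_n - mean_diff_sq`. A minimal scalar sketch of the same arithmetic, assuming plain `double` math and a hypothetical helper name that is not part of the library:

```c
#include <math.h>
#include <stddef.h>

/* Hypothetical scalar reference for the one-pass centered RMSD used above:
 *   sum_i ||(a_i - mean_a) - (b_i - mean_b)||^2
 *     = sum_i ||a_i - b_i||^2 - n * ||mean_a - mean_b||^2 */
static double rmsd3_one_pass_reference(double const *a, double const *b, size_t n) {
    double sum_sq = 0, sum_ax = 0, sum_ay = 0, sum_az = 0, sum_bx = 0, sum_by = 0, sum_bz = 0;
    for (size_t i = 0; i < n; ++i) {
        double dx = a[i * 3 + 0] - b[i * 3 + 0]; // raw, uncentered deltas
        double dy = a[i * 3 + 1] - b[i * 3 + 1];
        double dz = a[i * 3 + 2] - b[i * 3 + 2];
        sum_sq += dx * dx + dy * dy + dz * dz;
        sum_ax += a[i * 3 + 0], sum_ay += a[i * 3 + 1], sum_az += a[i * 3 + 2];
        sum_bx += b[i * 3 + 0], sum_by += b[i * 3 + 1], sum_bz += b[i * 3 + 2];
    }
    double inv_n = 1.0 / (double)n;
    double mx = (sum_ax - sum_bx) * inv_n, my = (sum_ay - sum_by) * inv_n, mz = (sum_az - sum_bz) * inv_n;
    return sqrt(sum_sq * inv_n - (mx * mx + my * my + mz * mz)); // subtract the centroid offset once
}
```

The SIMD version differs only in how the sums are produced (deinterleaved lanes, f64 accumulation, masked gathers for the sub-16 tail); the closing arithmetic is the same.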
|
|
473
847
|
|
|
474
848
|
NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
475
849
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
850
|
+
// Fused single-pass: centroids + covariance in f64
|
|
851
|
+
__m512d const zeros_f64x8 = _mm512_setzero_pd();
|
|
852
|
+
__m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
|
|
853
|
+
__m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
|
|
854
|
+
__m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
|
|
855
|
+
__m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
|
|
856
|
+
__m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
|
|
857
|
+
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
858
|
+
nk_size_t i = 0;
|
|
859
|
+
|
|
860
|
+
for (; i + 16 <= n; i += 16) {
|
|
861
|
+
nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
862
|
+
nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
863
|
+
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
864
|
+
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
865
|
+
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
866
|
+
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
867
|
+
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
868
|
+
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
869
|
+
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
870
|
+
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
871
|
+
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
872
|
+
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
873
|
+
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
874
|
+
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
875
|
+
|
|
876
|
+
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
877
|
+
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
878
|
+
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
879
|
+
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
880
|
+
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
881
|
+
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
882
|
+
|
|
883
|
+
cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
|
|
884
|
+
_mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
|
|
885
|
+
cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
|
|
886
|
+
_mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
|
|
887
|
+
cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
|
|
888
|
+
_mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
|
|
889
|
+
cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
|
|
890
|
+
_mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
|
|
891
|
+
cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
|
|
892
|
+
_mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
|
|
893
|
+
cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
|
|
894
|
+
_mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
|
|
895
|
+
cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
|
|
896
|
+
_mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
|
|
897
|
+
cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
|
|
898
|
+
_mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
|
|
899
|
+
cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
|
|
900
|
+
_mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
|
|
901
|
+
}
|
|
902
|
+
|
|
903
|
+
// Tail: use masked gather for remaining < 16 points
|
|
904
|
+
if (i < n) {
|
|
905
|
+
nk_size_t tail = n - i;
|
|
906
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
|
|
907
|
+
__m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
|
|
908
|
+
__m512 zeros_f32x16 = _mm512_setzero_ps();
|
|
909
|
+
nk_f32_t const *a_tail = a + i * 3;
|
|
910
|
+
nk_f32_t const *b_tail = b + i * 3;
|
|
911
|
+
|
|
912
|
+
a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
|
|
913
|
+
a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
|
|
914
|
+
a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
|
|
915
|
+
b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
|
|
916
|
+
b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
|
|
917
|
+
b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
|
|
918
|
+
|
|
919
|
+
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
920
|
+
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
921
|
+
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
922
|
+
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
923
|
+
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
924
|
+
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
925
|
+
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
926
|
+
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
927
|
+
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
928
|
+
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
929
|
+
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
930
|
+
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
931
|
+
|
|
932
|
+
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
933
|
+
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
934
|
+
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
935
|
+
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
936
|
+
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
937
|
+
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
938
|
+
|
|
939
|
+
cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
|
|
940
|
+
_mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
|
|
941
|
+
cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
|
|
942
|
+
_mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
|
|
943
|
+
cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
|
|
944
|
+
_mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
|
|
945
|
+
cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
|
|
946
|
+
_mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
|
|
947
|
+
cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
|
|
948
|
+
_mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
|
|
949
|
+
cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
|
|
950
|
+
_mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
|
|
951
|
+
cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
|
|
952
|
+
_mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
|
|
953
|
+
cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
|
|
954
|
+
_mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
|
|
955
|
+
cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
|
|
956
|
+
_mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
|
|
957
|
+
}
|
|
958
|
+
|
|
959
|
+
nk_f64_t sum_a_x = _mm512_reduce_add_pd(sum_a_x_f64x8), sum_a_y = _mm512_reduce_add_pd(sum_a_y_f64x8),
|
|
960
|
+
sum_a_z = _mm512_reduce_add_pd(sum_a_z_f64x8);
|
|
961
|
+
nk_f64_t sum_b_x = _mm512_reduce_add_pd(sum_b_x_f64x8), sum_b_y = _mm512_reduce_add_pd(sum_b_y_f64x8),
|
|
962
|
+
sum_b_z = _mm512_reduce_add_pd(sum_b_z_f64x8);
|
|
963
|
+
nk_f64_t covariance_x_x = _mm512_reduce_add_pd(cov_xx_f64x8), covariance_x_y = _mm512_reduce_add_pd(cov_xy_f64x8),
|
|
964
|
+
covariance_x_z = _mm512_reduce_add_pd(cov_xz_f64x8);
|
|
965
|
+
nk_f64_t covariance_y_x = _mm512_reduce_add_pd(cov_yx_f64x8), covariance_y_y = _mm512_reduce_add_pd(cov_yy_f64x8),
|
|
966
|
+
covariance_y_z = _mm512_reduce_add_pd(cov_yz_f64x8);
|
|
967
|
+
nk_f64_t covariance_z_x = _mm512_reduce_add_pd(cov_zx_f64x8), covariance_z_y = _mm512_reduce_add_pd(cov_zy_f64x8),
|
|
968
|
+
covariance_z_z = _mm512_reduce_add_pd(cov_zz_f64x8);
|
|
969
|
+
|
|
970
|
+
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
971
|
+
nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
972
|
+
nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
480
973
|
if (a_centroid)
|
|
481
974
|
a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
|
|
482
975
|
a_centroid[2] = (nk_f32_t)centroid_a_z;
|
|
@@ -485,51 +978,40 @@ NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_si
|
|
|
485
978
|
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
486
979
|
if (scale) *scale = 1.0f;
|
|
487
980
|
|
|
488
|
-
nk_f64_t
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
981
|
+
nk_f64_t n_f64 = (nk_f64_t)n;
|
|
982
|
+
nk_f64_t cross_covariance[9];
|
|
983
|
+
cross_covariance[0] = covariance_x_x - n_f64 * centroid_a_x * centroid_b_x;
|
|
984
|
+
cross_covariance[1] = covariance_x_y - n_f64 * centroid_a_x * centroid_b_y;
|
|
985
|
+
cross_covariance[2] = covariance_x_z - n_f64 * centroid_a_x * centroid_b_z;
|
|
986
|
+
cross_covariance[3] = covariance_y_x - n_f64 * centroid_a_y * centroid_b_x;
|
|
987
|
+
cross_covariance[4] = covariance_y_y - n_f64 * centroid_a_y * centroid_b_y;
|
|
988
|
+
cross_covariance[5] = covariance_y_z - n_f64 * centroid_a_y * centroid_b_z;
|
|
989
|
+
cross_covariance[6] = covariance_z_x - n_f64 * centroid_a_z * centroid_b_x;
|
|
990
|
+
cross_covariance[7] = covariance_z_y - n_f64 * centroid_a_z * centroid_b_y;
|
|
991
|
+
cross_covariance[8] = covariance_z_z - n_f64 * centroid_a_z * centroid_b_z;
|
|
992
|
+
|
|
993
|
+
nk_f64_t svd_u[9], svd_s[9], svd_v[9];
|
|
994
|
+
nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
|
|
995
|
+
nk_f64_t r[9];
|
|
996
|
+
nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
|
|
499
997
|
if (nk_det3x3_f64_(r) < 0) {
|
|
500
998
|
svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
|
|
501
|
-
|
|
502
|
-
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
503
|
-
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
504
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
505
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
506
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
507
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
508
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
509
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
999
|
+
nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
|
|
510
1000
|
}
|
|
511
1001
|
if (rotation)
|
|
512
|
-
for (int
|
|
1002
|
+
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)r[j];
|
|
513
1003
|
*result = nk_f64_sqrt_haswell(nk_transformed_ssd_f32_skylake_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y,
|
|
514
1004
|
centroid_a_z, centroid_b_x, centroid_b_y,
|
|
515
1005
|
centroid_b_z) /
|
|
516
|
-
|
|
1006
|
+
n_f64);
|
|
517
1007
|
}
|
|
518
1008
|
|
|
519
1009
|
NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
520
1010
|
nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
|
|
521
1011
|
// RMSD uses identity rotation and scale=1.0.
|
|
522
|
-
if (rotation)
|
|
523
|
-
rotation[0] = 1
|
|
524
|
-
rotation[
|
|
525
|
-
rotation[2] = 0;
|
|
526
|
-
rotation[3] = 0;
|
|
527
|
-
rotation[4] = 1;
|
|
528
|
-
rotation[5] = 0;
|
|
529
|
-
rotation[6] = 0;
|
|
530
|
-
rotation[7] = 0;
|
|
531
|
-
rotation[8] = 1;
|
|
532
|
-
}
|
|
1012
|
+
if (rotation)
|
|
1013
|
+
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1014
|
+
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
533
1015
|
if (scale) *scale = 1.0;
|
|
534
1016
|
// Optimized fused single-pass implementation for f64.
|
|
535
1017
|
// Computes centroids and squared differences in one pass using the identity:
|
|
@@ -633,6 +1115,7 @@ NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
633
1115
|
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_f64x8, delta_x_f64x8, sum_squared_x_f64x8);
|
|
634
1116
|
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_f64x8, delta_y_f64x8, sum_squared_y_f64x8);
|
|
635
1117
|
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_f64x8, delta_z_f64x8, sum_squared_z_f64x8);
|
|
1118
|
+
i = n;
|
|
636
1119
|
}
|
|
637
1120
|
|
|
638
1121
|
// Reduce and compute centroids.
|
|
@@ -759,6 +1242,7 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
759
1242
|
cov_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, cov_zx_f64x8),
|
|
760
1243
|
cov_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, cov_zy_f64x8),
|
|
761
1244
|
cov_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, cov_zz_f64x8);
|
|
1245
|
+
i = n;
|
|
762
1246
|
}
|
|
763
1247
|
|
|
764
1248
|
// Reduce centroids and covariance.
|
|
@@ -840,9 +1324,8 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
840
1324
|
}
|
|
841
1325
|
|
|
842
1326
|
// Output rotation matrix and scale=1.0.
|
|
843
|
-
if (rotation)
|
|
1327
|
+
if (rotation)
|
|
844
1328
|
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)r[j];
|
|
845
|
-
}
|
|
846
1329
|
if (scale) *scale = 1.0;
|
|
847
1330
|
|
|
848
1331
|
// Compute RMSD after optimal rotation
|
|
@@ -851,51 +1334,153 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
851
1334
|
*result = nk_f64_sqrt_haswell(sum_squared * inv_n);
|
|
852
1335
|
}
|
|
853
1336
|
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
__m512d
|
|
861
|
-
|
|
862
|
-
|
|
1337
|
+
NK_PUBLIC void nk_umeyama_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1338
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
|
|
1339
|
+
// Fused single-pass: centroids + covariance + variance of A, all in f64
|
|
1340
|
+
__m512d const zeros_f64x8 = _mm512_setzero_pd();
|
|
1341
|
+
__m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
|
|
1342
|
+
__m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
|
|
1343
|
+
__m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
|
|
1344
|
+
__m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
|
|
1345
|
+
__m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
|
|
1346
|
+
__m512d variance_a_f64x8 = zeros_f64x8;
|
|
1347
|
+
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
1348
|
+
nk_size_t i = 0;
|
|
863
1349
|
|
|
864
|
-
for (;
|
|
865
|
-
nk_deinterleave_f32x16_skylake_(a +
|
|
866
|
-
|
|
867
|
-
__m512d
|
|
868
|
-
__m512d
|
|
869
|
-
__m512d
|
|
870
|
-
__m512d
|
|
871
|
-
__m512d
|
|
872
|
-
__m512d
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
1350
|
+
for (; i + 16 <= n; i += 16) {
|
|
1351
|
+
nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
1352
|
+
nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
1353
|
+
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
1354
|
+
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
1355
|
+
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
1356
|
+
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
1357
|
+
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
1358
|
+
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
1359
|
+
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
1360
|
+
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
1361
|
+
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
1362
|
+
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
1363
|
+
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
1364
|
+
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
1365
|
+
|
|
1366
|
+
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
1367
|
+
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
1368
|
+
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
1369
|
+
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
1370
|
+
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
1371
|
+
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
1372
|
+
|
|
1373
|
+
cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
|
|
1374
|
+
_mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
|
|
1375
|
+
cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
|
|
1376
|
+
_mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
|
|
1377
|
+
cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
|
|
1378
|
+
_mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
|
|
1379
|
+
cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
|
|
1380
|
+
_mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
|
|
1381
|
+
cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
|
|
1382
|
+
_mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
|
|
1383
|
+
cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
|
|
1384
|
+
_mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
|
|
1385
|
+
cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
|
|
1386
|
+
_mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
|
|
1387
|
+
cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
|
|
1388
|
+
_mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
|
|
1389
|
+
cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
|
|
1390
|
+
_mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
|
|
1391
|
+
|
|
1392
|
+
variance_a_f64x8 = _mm512_add_pd(
|
|
1393
|
+
variance_a_f64x8,
|
|
1394
|
+
_mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, a_x_low_f64x8), _mm512_mul_pd(a_x_high_f64x8, a_x_high_f64x8)));
|
|
1395
|
+
variance_a_f64x8 = _mm512_add_pd(
|
|
1396
|
+
variance_a_f64x8,
|
|
1397
|
+
_mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, a_y_low_f64x8), _mm512_mul_pd(a_y_high_f64x8, a_y_high_f64x8)));
|
|
1398
|
+
variance_a_f64x8 = _mm512_add_pd(
|
|
1399
|
+
variance_a_f64x8,
|
|
1400
|
+
_mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, a_z_low_f64x8), _mm512_mul_pd(a_z_high_f64x8, a_z_high_f64x8)));
|
|
879
1401
|
}
|
|
880
1402
|
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
1403
|
+
// Tail: use masked gather for remaining < 16 points
|
|
1404
|
+
if (i < n) {
|
|
1405
|
+
nk_size_t tail = n - i;
|
|
1406
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
|
|
1407
|
+
__m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
|
|
1408
|
+
__m512 zeros_f32x16 = _mm512_setzero_ps();
|
|
1409
|
+
nk_f32_t const *a_tail = a + i * 3;
|
|
1410
|
+
nk_f32_t const *b_tail = b + i * 3;
|
|
1411
|
+
|
|
1412
|
+
a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
|
|
1413
|
+
a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
|
|
1414
|
+
a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
|
|
1415
|
+
b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
|
|
1416
|
+
b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
|
|
1417
|
+
b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
|
|
1418
|
+
|
|
1419
|
+
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
1420
|
+
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
1421
|
+
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
1422
|
+
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
1423
|
+
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
1424
|
+
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
1425
|
+
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
1426
|
+
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
1427
|
+
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
1428
|
+
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
1429
|
+
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
1430
|
+
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
1431
|
+
|
|
1432
|
+
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
1433
|
+
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
1434
|
+
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
1435
|
+
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
1436
|
+
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
1437
|
+
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
1438
|
+
|
|
1439
|
+
cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
|
|
1440
|
+
_mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
|
|
1441
|
+
cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
|
|
1442
|
+
_mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
|
|
1443
|
+
cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
|
|
1444
|
+
_mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
|
|
1445
|
+
cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
|
|
1446
|
+
_mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
|
|
1447
|
+
cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
|
|
1448
|
+
_mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
|
|
1449
|
+
cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
|
|
1450
|
+
_mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
|
|
1451
|
+
cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
|
|
1452
|
+
_mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
|
|
1453
|
+
cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
|
|
1454
|
+
_mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
|
|
1455
|
+
cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
|
|
1456
|
+
_mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
|
|
1457
|
+
|
|
1458
|
+
variance_a_f64x8 = _mm512_add_pd(
|
|
1459
|
+
variance_a_f64x8,
|
|
1460
|
+
_mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, a_x_low_f64x8), _mm512_mul_pd(a_x_high_f64x8, a_x_high_f64x8)));
|
|
1461
|
+
variance_a_f64x8 = _mm512_add_pd(
|
|
1462
|
+
variance_a_f64x8,
|
|
1463
|
+
_mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, a_y_low_f64x8), _mm512_mul_pd(a_y_high_f64x8, a_y_high_f64x8)));
|
|
1464
|
+
variance_a_f64x8 = _mm512_add_pd(
|
|
1465
|
+
variance_a_f64x8,
|
|
1466
|
+
_mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, a_z_low_f64x8), _mm512_mul_pd(a_z_high_f64x8, a_z_high_f64x8)));
|
|
885
1467
|
}
|
|
886
1468
|
|
|
887
|
-
nk_f64_t
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
1469
|
+
nk_f64_t sum_a_x = _mm512_reduce_add_pd(sum_a_x_f64x8), sum_a_y = _mm512_reduce_add_pd(sum_a_y_f64x8),
|
|
1470
|
+
sum_a_z = _mm512_reduce_add_pd(sum_a_z_f64x8);
|
|
1471
|
+
nk_f64_t sum_b_x = _mm512_reduce_add_pd(sum_b_x_f64x8), sum_b_y = _mm512_reduce_add_pd(sum_b_y_f64x8),
|
|
1472
|
+
sum_b_z = _mm512_reduce_add_pd(sum_b_z_f64x8);
|
|
1473
|
+
nk_f64_t covariance_x_x = _mm512_reduce_add_pd(cov_xx_f64x8), covariance_x_y = _mm512_reduce_add_pd(cov_xy_f64x8),
|
|
1474
|
+
covariance_x_z = _mm512_reduce_add_pd(cov_xz_f64x8);
|
|
1475
|
+
nk_f64_t covariance_y_x = _mm512_reduce_add_pd(cov_yx_f64x8), covariance_y_y = _mm512_reduce_add_pd(cov_yy_f64x8),
|
|
1476
|
+
covariance_y_z = _mm512_reduce_add_pd(cov_yz_f64x8);
|
|
1477
|
+
nk_f64_t covariance_z_x = _mm512_reduce_add_pd(cov_zx_f64x8), covariance_z_y = _mm512_reduce_add_pd(cov_zy_f64x8),
|
|
1478
|
+
covariance_z_z = _mm512_reduce_add_pd(cov_zz_f64x8);
|
|
1479
|
+
nk_f64_t variance_a_sum = _mm512_reduce_add_pd(variance_a_f64x8);
|
|
891
1480
|
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
nk_f64_t
|
|
895
|
-
nk_f64_t cross_covariance_f64[9];
|
|
896
|
-
nk_centroid_and_cross_covariance_and_variance_f32_skylake_(a, b, n, ¢roid_a_x, ¢roid_a_y, ¢roid_a_z,
|
|
897
|
-
¢roid_b_x, ¢roid_b_y, ¢roid_b_z,
|
|
898
|
-
cross_covariance_f64, &variance_a);
|
|
1481
|
+
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
1482
|
+
nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
1483
|
+
nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
899
1484
|
if (a_centroid)
|
|
900
1485
|
a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
|
|
901
1486
|
a_centroid[2] = (nk_f32_t)centroid_a_z;
|
|
@@ -903,41 +1488,49 @@ NK_PUBLIC void nk_umeyama_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_s
|
|
|
903
1488
|
b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
|
|
904
1489
|
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
905
1490
|
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
910
|
-
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
911
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
912
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
913
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
914
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
915
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
916
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1491
|
+
// Compute centered covariance and variance
|
|
1492
|
+
nk_f64_t variance_a = variance_a_sum * inv_n -
|
|
1493
|
+
(centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
|
|
917
1494
|
|
|
1495
|
+
// Compute centered covariance matrix: Hᵢⱼ = Σ(aᵢ×bⱼ) - Σaᵢ × Σbⱼ / n
|
|
1496
|
+
nk_f64_t n_f64 = (nk_f64_t)n;
|
|
1497
|
+
nk_f64_t cross_covariance[9];
|
|
1498
|
+
cross_covariance[0] = covariance_x_x - n_f64 * centroid_a_x * centroid_b_x;
|
|
1499
|
+
cross_covariance[1] = covariance_x_y - n_f64 * centroid_a_x * centroid_b_y;
|
|
1500
|
+
cross_covariance[2] = covariance_x_z - n_f64 * centroid_a_x * centroid_b_z;
|
|
1501
|
+
cross_covariance[3] = covariance_y_x - n_f64 * centroid_a_y * centroid_b_x;
|
|
1502
|
+
cross_covariance[4] = covariance_y_y - n_f64 * centroid_a_y * centroid_b_y;
|
|
1503
|
+
cross_covariance[5] = covariance_y_z - n_f64 * centroid_a_y * centroid_b_z;
|
|
1504
|
+
cross_covariance[6] = covariance_z_x - n_f64 * centroid_a_z * centroid_b_x;
|
|
1505
|
+
cross_covariance[7] = covariance_z_y - n_f64 * centroid_a_z * centroid_b_y;
|
|
1506
|
+
cross_covariance[8] = covariance_z_z - n_f64 * centroid_a_z * centroid_b_z;
|
|
1507
|
+
|
|
1508
|
+
// SVD using f64 for full precision
|
|
1509
|
+
nk_f64_t svd_u[9], svd_s[9], svd_v[9];
|
|
1510
|
+
nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
|
|
1511
|
+
|
|
1512
|
+
nk_f64_t r[9];
|
|
1513
|
+
nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
|
|
1514
|
+
|
|
1515
|
+
// Scale factor: c = trace(D × S) / (n × variance(a))
|
|
918
1516
|
nk_f64_t det = nk_det3x3_f64_(r);
|
|
919
|
-
nk_f64_t
|
|
920
|
-
nk_f64_t
|
|
1517
|
+
nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
|
|
1518
|
+
nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
|
|
1519
|
+
nk_f64_t applied_scale = trace_ds / ((nk_f64_t)n * variance_a);
|
|
1520
|
+
if (scale) *scale = (nk_f32_t)applied_scale;
|
|
1521
|
+
|
|
1522
|
+
// Handle reflection
|
|
921
1523
|
if (det < 0) {
|
|
922
1524
|
svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
|
|
923
|
-
|
|
924
|
-
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
925
|
-
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
926
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
927
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
928
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
929
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
930
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
931
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1525
|
+
nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
|
|
932
1526
|
}
|
|
933
1527
|
|
|
934
1528
|
if (rotation)
|
|
935
|
-
for (int
|
|
936
|
-
if (scale) *scale = (nk_f32_t)applied_scale;
|
|
1529
|
+
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)r[j];
|
|
937
1530
|
*result = nk_f64_sqrt_haswell(nk_transformed_ssd_f32_skylake_(a, b, n, r, applied_scale, centroid_a_x, centroid_a_y,
|
|
938
1531
|
centroid_a_z, centroid_b_x, centroid_b_y,
|
|
939
1532
|
centroid_b_z) /
|
|
940
|
-
|
|
1533
|
+
n_f64);
|
|
941
1534
|
}
|
|
942
1535
|
|
|
943
1536
|
NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
@@ -1013,6 +1606,7 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1013
1606
|
variance_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, variance_a_f64x8);
|
|
1014
1607
|
variance_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, variance_a_f64x8);
|
|
1015
1608
|
variance_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, variance_a_f64x8);
|
|
1609
|
+
i = n;
|
|
1016
1610
|
}
|
|
1017
1611
|
|
|
1018
1612
|
// Reduce centroids, covariance, and variance.
|
|
@@ -1100,7 +1694,7 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1100
1694
|
nk_f64_t det = nk_det3x3_f64_(r);
|
|
1101
1695
|
nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
|
|
1102
1696
|
nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
|
|
1103
|
-
nk_f64_t c = trace_ds / (n * variance_a);
|
|
1697
|
+
nk_f64_t c = trace_ds / ((nk_f64_t)n * variance_a);
|
|
1104
1698
|
if (scale) *scale = c;
|
|
1105
1699
|
|
|
1106
1700
|
// Handle reflection
|
|
@@ -1110,9 +1704,8 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1110
1704
|
}
|
|
1111
1705
|
|
|
1112
1706
|
// Output rotation matrix.
|
|
1113
|
-
if (rotation)
|
|
1707
|
+
if (rotation)
|
|
1114
1708
|
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)r[j];
|
|
1115
|
-
}
|
|
1116
1709
|
|
|
1117
1710
|
// Compute RMSD with scaling
|
|
1118
1711
|
nk_f64_t sum_squared = nk_transformed_ssd_f64_skylake_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
|
|
@@ -1120,6 +1713,738 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1120
1713
|
*result = nk_f64_sqrt_haswell(sum_squared * inv_n);
|
|
1121
1714
|
}
|
|
1122
1715
|
|
|
1716
|
+
NK_PUBLIC void nk_rmsd_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1717
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1718
|
+
if (rotation)
|
|
1719
|
+
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1720
|
+
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
1721
|
+
if (scale) *scale = 1.0f;
|
|
1722
|
+
|
|
1723
|
+
__m512 const zeros_f32x16 = _mm512_setzero_ps();
|
|
1724
|
+
__m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
|
|
1725
|
+
__m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
|
|
1726
|
+
__m512 sum_squared_x_f32x16 = zeros_f32x16, sum_squared_y_f32x16 = zeros_f32x16;
|
|
1727
|
+
__m512 sum_squared_z_f32x16 = zeros_f32x16;
|
|
1728
|
+
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
1729
|
+
nk_size_t i = 0;
|
|
1730
|
+
|
|
1731
|
+
for (; i + 16 <= n; i += 16) {
|
|
1732
|
+
nk_deinterleave_f16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
1733
|
+
nk_deinterleave_f16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
1734
|
+
|
|
1735
|
+
sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
|
|
1736
|
+
sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
|
|
1737
|
+
sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
|
|
1738
|
+
sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
|
|
1739
|
+
sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
|
|
1740
|
+
sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
|
|
1741
|
+
|
|
1742
|
+
__m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
|
|
1743
|
+
__m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
|
|
1744
|
+
__m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
|
|
1745
|
+
|
|
1746
|
+
sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
|
|
1747
|
+
sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
|
|
1748
|
+
sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
|
|
1749
|
+
}
|
|
1750
|
+
|
|
1751
|
+
// Tail: deinterleave remaining points into zero-initialized vectors
|
|
1752
|
+
if (i < n) {
|
|
1753
|
+
nk_size_t tail = n - i;
|
|
1754
|
+
nk_deinterleave_f16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
1755
|
+
nk_deinterleave_f16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
1756
|
+
|
|
1757
|
+
sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
|
|
1758
|
+
sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
|
|
1759
|
+
sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
|
|
1760
|
+
sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
|
|
1761
|
+
sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
|
|
1762
|
+
sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
|
|
1763
|
+
|
|
1764
|
+
__m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
|
|
1765
|
+
__m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
|
|
1766
|
+
__m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
|
|
1767
|
+
|
|
1768
|
+
sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
|
|
1769
|
+
sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
|
|
1770
|
+
sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
|
|
1771
|
+
}
|
|
1772
|
+
|
|
1773
|
+
nk_f32_t total_ax = _mm512_reduce_add_ps(sum_a_x_f32x16);
|
|
1774
|
+
nk_f32_t total_ay = _mm512_reduce_add_ps(sum_a_y_f32x16);
|
|
1775
|
+
nk_f32_t total_az = _mm512_reduce_add_ps(sum_a_z_f32x16);
|
|
1776
|
+
nk_f32_t total_bx = _mm512_reduce_add_ps(sum_b_x_f32x16);
|
|
1777
|
+
nk_f32_t total_by = _mm512_reduce_add_ps(sum_b_y_f32x16);
|
|
1778
|
+
nk_f32_t total_bz = _mm512_reduce_add_ps(sum_b_z_f32x16);
|
|
1779
|
+
nk_f32_t total_sq_x = _mm512_reduce_add_ps(sum_squared_x_f32x16);
|
|
1780
|
+
nk_f32_t total_sq_y = _mm512_reduce_add_ps(sum_squared_y_f32x16);
|
|
1781
|
+
nk_f32_t total_sq_z = _mm512_reduce_add_ps(sum_squared_z_f32x16);
|
|
1782
|
+
|
|
1783
|
+
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
1784
|
+
nk_f32_t centroid_a_x = total_ax * inv_n;
|
|
1785
|
+
nk_f32_t centroid_a_y = total_ay * inv_n;
|
|
1786
|
+
nk_f32_t centroid_a_z = total_az * inv_n;
|
|
1787
|
+
nk_f32_t centroid_b_x = total_bx * inv_n;
|
|
1788
|
+
nk_f32_t centroid_b_y = total_by * inv_n;
|
|
1789
|
+
nk_f32_t centroid_b_z = total_bz * inv_n;
|
|
1790
|
+
|
|
1791
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1792
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1793
|
+
|
|
1794
|
+
nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
1795
|
+
nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
|
|
1796
|
+
nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
|
|
1797
|
+
nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
|
|
1798
|
+
nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
1799
|
+
|
|
1800
|
+
*result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
|
|
1801
|
+
}
|
|
1802
|
+
|
|
1803
|
+
NK_PUBLIC void nk_rmsd_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1804
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1805
|
+
if (rotation)
|
|
1806
|
+
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1807
|
+
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
1808
|
+
if (scale) *scale = 1.0f;
|
|
1809
|
+
|
|
1810
|
+
__m512 const zeros_f32x16 = _mm512_setzero_ps();
|
|
1811
|
+
__m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
|
|
1812
|
+
__m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
|
|
1813
|
+
__m512 sum_squared_x_f32x16 = zeros_f32x16, sum_squared_y_f32x16 = zeros_f32x16;
|
|
1814
|
+
__m512 sum_squared_z_f32x16 = zeros_f32x16;
|
|
1815
|
+
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
1816
|
+
nk_size_t i = 0;
|
|
1817
|
+
|
|
1818
|
+
for (; i + 16 <= n; i += 16) {
|
|
1819
|
+
        nk_deinterleave_bf16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_bf16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
        __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
        __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);

        sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
        sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
        sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
    }

    // Tail: deinterleave remaining points into zero-initialized vectors
    if (i < n) {
        nk_size_t tail = n - i;
        nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
        __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
        __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);

        sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
        sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
        sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
    }

    nk_f32_t total_ax = _mm512_reduce_add_ps(sum_a_x_f32x16);
    nk_f32_t total_ay = _mm512_reduce_add_ps(sum_a_y_f32x16);
    nk_f32_t total_az = _mm512_reduce_add_ps(sum_a_z_f32x16);
    nk_f32_t total_bx = _mm512_reduce_add_ps(sum_b_x_f32x16);
    nk_f32_t total_by = _mm512_reduce_add_ps(sum_b_y_f32x16);
    nk_f32_t total_bz = _mm512_reduce_add_ps(sum_b_z_f32x16);
    nk_f32_t total_sq_x = _mm512_reduce_add_ps(sum_squared_x_f32x16);
    nk_f32_t total_sq_y = _mm512_reduce_add_ps(sum_squared_y_f32x16);
    nk_f32_t total_sq_z = _mm512_reduce_add_ps(sum_squared_z_f32x16);

    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = total_ax * inv_n;
    nk_f32_t centroid_a_y = total_ay * inv_n;
    nk_f32_t centroid_a_z = total_az * inv_n;
    nk_f32_t centroid_b_x = total_bx * inv_n;
    nk_f32_t centroid_b_y = total_by * inv_n;
    nk_f32_t centroid_b_z = total_bz * inv_n;

    if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
    if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

    nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
    nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
    nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
    nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
    nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;

    *result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
}

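The reduction above computes a translation-invariant RMSD: the mean squared point-to-point distance minus the squared offset between the two centroids, so a pure translation between the clouds contributes nothing. A minimal scalar restatement of the same formula, assuming plain f32 input and a hypothetical helper name (not part of the package API):

```c
#include <math.h>
#include <stddef.h>

/* Reference for the reduction: rmsd^2 = mean(|a_i - b_i|^2) - |mean(a) - mean(b)|^2,
 * with points stored as interleaved x,y,z triplets. */
static float rmsd_translation_invariant_f32(float const *a, float const *b, size_t n) {
    double sq = 0, ax = 0, ay = 0, az = 0, bx = 0, by = 0, bz = 0;
    for (size_t i = 0; i < n; ++i) {
        double dx = a[i * 3 + 0] - b[i * 3 + 0];
        double dy = a[i * 3 + 1] - b[i * 3 + 1];
        double dz = a[i * 3 + 2] - b[i * 3 + 2];
        sq += dx * dx + dy * dy + dz * dz;
        ax += a[i * 3 + 0], ay += a[i * 3 + 1], az += a[i * 3 + 2];
        bx += b[i * 3 + 0], by += b[i * 3 + 1], bz += b[i * 3 + 2];
    }
    double inv_n = 1.0 / (double)n;
    double mx = (ax - bx) * inv_n, my = (ay - by) * inv_n, mz = (az - bz) * inv_n;
    return (float)sqrt(sq * inv_n - (mx * mx + my * my + mz * mz));
}
```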
NK_PUBLIC void nk_kabsch_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                     nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    __m512 const zeros_f32x16 = _mm512_setzero_ps();

    __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
    __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
    __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
    __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
    __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;

    nk_size_t i = 0;
    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;

    for (; i + 16 <= n; i += 16) {
        nk_deinterleave_f16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_f16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
        cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
        cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
        cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
        cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
        cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
        cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
        cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
        cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
    }

    // Tail: deinterleave remaining points into zero-initialized vectors
    if (i < n) {
        nk_size_t tail = n - i;
        nk_deinterleave_f16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_f16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
        cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
        cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
        cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
        cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
        cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
        cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
        cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
        cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
    }

    nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
    nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
    nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
    nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
    nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
    nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
    nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
    nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
    nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
    nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
    nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
    nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
    nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
    nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
    nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);

    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
    nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;

    if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
    if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

    covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
    covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
    covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
    covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
    covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
    covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
    covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
    covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
    covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;

    nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
                                    covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);

    nk_f32_t r[9];
    r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
    r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
    r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
    r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
    r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
    r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
    r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
    r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
    r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];

    if (nk_det3x3_f32_(r) < 0) {
        svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
        r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
        r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
        r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
        r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
        r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
        r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
        r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
        r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
        r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
    }

    if (rotation)
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];
    if (scale) *scale = 1.0f;

    nk_f32_t sum_squared = nk_transformed_ssd_f16_skylake_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
                                                           centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
}

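A possible call pattern for the Kabsch kernel above. The buffer layout (interleaved x,y,z triplets, 3 * n values per cloud) is inferred from the `a + i * 3` indexing; the wrapper function and its name are illustrative only, not part of the package:

```c
// Hypothetical usage sketch: superimpose two f16 point clouds of `count` points each.
void align_example(nk_f16_t const *a, nk_f16_t const *b, nk_size_t count) {
    nk_f32_t centroid_a[3], centroid_b[3]; // optional outputs; the kernel accepts NULL here
    nk_f32_t rotation[9];                  // row-major 3x3 rotation relating the centered clouds
    nk_f32_t scale, rmsd;
    nk_kabsch_f16_skylake(a, b, count, centroid_a, centroid_b, rotation, &scale, &rmsd);
    // For the rigid Kabsch variant the reported scale is always 1.0f;
    // rmsd is the root-mean-square error after the optimal superposition.
    (void)scale;
    (void)rmsd;
}
```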
NK_PUBLIC void nk_kabsch_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                      nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    __m512 const zeros_f32x16 = _mm512_setzero_ps();

    __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
    __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
    __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
    __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
    __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;

    nk_size_t i = 0;
    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;

    for (; i + 16 <= n; i += 16) {
        nk_deinterleave_bf16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_bf16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
        cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
        cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
        cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
        cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
        cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
        cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
        cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
        cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
    }

    // Tail: deinterleave remaining points into zero-initialized vectors
    if (i < n) {
        nk_size_t tail = n - i;
        nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
        cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
        cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
        cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
        cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
        cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
        cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
        cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
        cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
    }

    nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
    nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
    nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
    nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
    nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
    nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
    nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
    nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
    nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
    nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
    nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
    nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
    nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
    nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
    nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);

    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
    nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;

    if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
    if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

    covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
    covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
    covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
    covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
    covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
    covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
    covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
    covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
    covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;

    nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
                                    covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);

    nk_f32_t r[9];
    r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
    r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
    r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
    r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
    r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
    r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
    r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
    r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
    r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];

    if (nk_det3x3_f32_(r) < 0) {
        svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
        r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
        r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
        r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
        r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
        r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
        r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
        r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
        r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
        r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
    }

    if (rotation)
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];
    if (scale) *scale = 1.0f;

    nk_f32_t sum_squared = nk_transformed_ssd_bf16_skylake_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
                                                            centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
}

NK_PUBLIC void nk_umeyama_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                      nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    __m512 const zeros_f32x16 = _mm512_setzero_ps();

    __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
    __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
    __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
    __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
    __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
    __m512 variance_a_f32x16 = zeros_f32x16;

    nk_size_t i = 0;
    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;

    for (; i + 16 <= n; i += 16) {
        nk_deinterleave_f16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_f16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
        cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
        cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
        cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
        cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
        cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
        cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
        cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
        cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);

        variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
        variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
        variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
    }

    // Tail: deinterleave remaining points into zero-initialized vectors
    if (i < n) {
        nk_size_t tail = n - i;
        nk_deinterleave_f16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_f16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
        cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
        cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
        cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
        cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
        cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
        cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
        cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
        cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);

        variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
        variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
        variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
    }

    nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
    nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
    nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
    nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
    nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
    nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
    nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
    nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
    nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
    nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
    nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
    nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
    nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
    nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
    nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
    nk_f32_t variance_a_sum = _mm512_reduce_add_ps(variance_a_f32x16);

    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
    nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;

    if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
    if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

    nk_f32_t variance_a = variance_a_sum * inv_n -
                          (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);

    covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
    covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
    covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
    covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
    covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
    covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
    covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
    covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
    covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;

    nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
                                    covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};

    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);

    nk_f32_t r[9];
    r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
    r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
    r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
    r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
    r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
    r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
    r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
    r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
    r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];

    nk_f32_t det = nk_det3x3_f32_(r);
    nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
    nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
    nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
    if (scale) *scale = c;

    if (det < 0) {
        svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
        r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
        r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
        r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
        r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
        r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
        r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
        r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
        r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
        r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
    }

    if (rotation)
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];

    nk_f32_t sum_squared = nk_transformed_ssd_f16_skylake_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
                                                           centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
}

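For reference, the value written to `*scale` above is Umeyama's closed-form optimal isotropic scale, restated here for the unnormalized accumulators this kernel uses: the SVD factors the summed (not averaged) centered cross-covariance, so the singular values carry an extra factor of n, which the division by n removes. With D = diag(1, 1, d3):

```latex
c \;=\; \frac{\operatorname{tr}(D S)}{n \cdot \operatorname{var}(a)}
  \;=\; \frac{s_1 + s_2 + d_3\, s_3}{n \cdot \operatorname{var}(a)}, \qquad
d_3 \;=\; \operatorname{sign}\!\left(\det\!\left(V U^{\top}\right)\right), \qquad
\operatorname{var}(a) \;=\; \frac{1}{n}\sum_{i} \lVert a_i - \bar{a} \rVert^{2}.
```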
NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                       nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    __m512 const zeros_f32x16 = _mm512_setzero_ps();

    __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
    __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
    __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
    __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
    __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
    __m512 variance_a_f32x16 = zeros_f32x16;

    nk_size_t i = 0;
    __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;

    for (; i + 16 <= n; i += 16) {
        nk_deinterleave_bf16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_bf16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
        cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
        cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
        cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
        cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
        cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
        cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
        cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
        cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);

        variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
        variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
        variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
    }

    // Tail: deinterleave remaining points into zero-initialized vectors
    if (i < n) {
        nk_size_t tail = n - i;
        nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
        nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);

        sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
        sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
        sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
        sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
        sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
        sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);

        cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
        cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
        cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
        cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
        cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
        cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
        cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
        cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
        cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);

        variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
        variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
        variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
    }

    nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
    nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
    nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
    nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
    nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
    nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
    nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
    nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
    nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
    nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
    nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
    nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
    nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
    nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
    nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
    nk_f32_t variance_a_sum = _mm512_reduce_add_ps(variance_a_f32x16);

    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
    nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;

    if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
    if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

    nk_f32_t variance_a = variance_a_sum * inv_n -
                          (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);

    covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
    covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
    covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
    covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
    covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
    covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
    covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
    covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
    covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;

    nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
                                    covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};

    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);

    nk_f32_t r[9];
    r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
    r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
    r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
    r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
    r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
    r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
    r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
    r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
    r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];

    nk_f32_t det = nk_det3x3_f32_(r);
    nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
    nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
    nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
    if (scale) *scale = c;

    if (det < 0) {
        svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
        r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
        r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
        r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
        r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
        r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
        r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
        r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
        r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
        r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
    }

    if (rotation)
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];

    nk_f32_t sum_squared = nk_transformed_ssd_bf16_skylake_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
                                                            centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
}

#if defined(__clang__)
#pragma clang attribute pop
#elif defined(__GNUC__)