numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,12 +8,12 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section haswell_mesh_instructions Key AVX2 Mesh Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm256_fmadd_ps
|
|
13
|
-
* _mm256_hadd_ps
|
|
14
|
-
* _mm256_permute2f128_ps
|
|
15
|
-
* _mm256_extractf128_ps
|
|
16
|
-
* _mm256_i32gather_ps
|
|
11
|
+
* Intrinsic Instruction Haswell Genoa
|
|
12
|
+
* _mm256_fmadd_ps VFMADD (YMM, YMM, YMM) 5cy @ p01 4cy @ p01
|
|
13
|
+
* _mm256_hadd_ps VHADDPS (YMM, YMM, YMM) 7cy @ p1+p5 4cy @ p123+p23+p23
|
|
14
|
+
* _mm256_permute2f128_ps VPERM2F128 (YMM, YMM, YMM, I8) 3cy @ p5 2cy @ p12
|
|
15
|
+
* _mm256_extractf128_ps VEXTRACTF128 (XMM, YMM, I8) 3cy @ p5 1cy @ p0123
|
|
16
|
+
* _mm256_i32gather_ps VGATHERDPS (YMM, M, YMM, YMM) 22cy (34 uops) 19cy (17 uops)
|
|
17
17
|
*
|
|
18
18
|
* Point cloud operations (centroid, covariance, Kabsch alignment) use gather instructions for
|
|
19
19
|
* stride-3 xyz deinterleaving. Multiple FMA accumulators hide the 5-cycle FMA latency. VHADDPS
|
|
@@ -50,10 +50,10 @@ extern "C" {
|
|
|
50
50
|
*/
|
|
51
51
|
NK_INTERNAL void nk_deinterleave_f32x8_haswell_(nk_f32_t const *ptr, __m256 *x_out, __m256 *y_out, __m256 *z_out) {
|
|
52
52
|
// Gather indices: 0, 3, 6, 9, 12, 15, 18, 21 (stride 3)
|
|
53
|
-
__m256i
|
|
54
|
-
*x_out = _mm256_i32gather_ps(ptr + 0,
|
|
55
|
-
*y_out = _mm256_i32gather_ps(ptr + 1,
|
|
56
|
-
*z_out = _mm256_i32gather_ps(ptr + 2,
|
|
53
|
+
__m256i idx_i32x8 = _mm256_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21);
|
|
54
|
+
*x_out = _mm256_i32gather_ps(ptr + 0, idx_i32x8, 4);
|
|
55
|
+
*y_out = _mm256_i32gather_ps(ptr + 1, idx_i32x8, 4);
|
|
56
|
+
*z_out = _mm256_i32gather_ps(ptr + 2, idx_i32x8, 4);
|
|
57
57
|
}
|
|
58
58
|
|
|
59
59
|
/* Deinterleave 12 f64 values (4 xyz triplets) into separate x, y, z vectors.
|
|
@@ -134,84 +134,84 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_haswell_(nk_f32_t const *a, nk_f32_t
|
|
|
134
134
|
nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
|
|
135
135
|
nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
136
136
|
|
|
137
|
-
__m256d
|
|
138
|
-
__m256d
|
|
139
|
-
__m256d
|
|
140
|
-
__m256d
|
|
141
|
-
__m256d
|
|
142
|
-
__m256d
|
|
143
|
-
__m256d
|
|
144
|
-
__m256d
|
|
145
|
-
__m256d
|
|
146
|
-
__m256d
|
|
147
|
-
__m256d
|
|
148
|
-
__m256d
|
|
149
|
-
|
|
150
|
-
__m256d
|
|
151
|
-
__m256d
|
|
152
|
-
__m256d
|
|
153
|
-
__m256d
|
|
154
|
-
__m256d
|
|
155
|
-
__m256d
|
|
156
|
-
__m256d
|
|
157
|
-
__m256d
|
|
158
|
-
__m256d
|
|
159
|
-
__m256d
|
|
160
|
-
__m256d
|
|
161
|
-
__m256d
|
|
162
|
-
|
|
163
|
-
__m256d
|
|
164
|
-
scaled_rotation_x_z_f64x4,
|
|
165
|
-
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4,
|
|
166
|
-
_mm256_mul_pd(scaled_rotation_x_x_f64x4,
|
|
167
|
-
__m256d
|
|
168
|
-
scaled_rotation_x_z_f64x4,
|
|
169
|
-
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4,
|
|
170
|
-
_mm256_mul_pd(scaled_rotation_x_x_f64x4,
|
|
171
|
-
__m256d
|
|
172
|
-
scaled_rotation_y_z_f64x4,
|
|
173
|
-
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4,
|
|
174
|
-
_mm256_mul_pd(scaled_rotation_y_x_f64x4,
|
|
175
|
-
__m256d
|
|
176
|
-
scaled_rotation_y_z_f64x4,
|
|
177
|
-
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4,
|
|
178
|
-
_mm256_mul_pd(scaled_rotation_y_x_f64x4,
|
|
179
|
-
__m256d
|
|
180
|
-
scaled_rotation_z_z_f64x4,
|
|
181
|
-
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4,
|
|
182
|
-
_mm256_mul_pd(scaled_rotation_z_x_f64x4,
|
|
183
|
-
__m256d
|
|
184
|
-
scaled_rotation_z_z_f64x4,
|
|
185
|
-
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4,
|
|
186
|
-
_mm256_mul_pd(scaled_rotation_z_x_f64x4,
|
|
187
|
-
|
|
188
|
-
__m256d
|
|
189
|
-
__m256d
|
|
190
|
-
__m256d
|
|
191
|
-
__m256d
|
|
192
|
-
__m256d
|
|
193
|
-
__m256d
|
|
194
|
-
|
|
195
|
-
__m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(
|
|
196
|
-
_mm256_mul_pd(
|
|
197
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(
|
|
198
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(
|
|
199
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(
|
|
200
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(
|
|
137
|
+
__m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
|
|
138
|
+
__m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
|
|
139
|
+
__m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
|
|
140
|
+
__m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
|
|
141
|
+
__m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
|
|
142
|
+
__m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
|
|
143
|
+
__m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
|
|
144
|
+
__m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
|
|
145
|
+
__m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
|
|
146
|
+
__m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
|
|
147
|
+
__m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
|
|
148
|
+
__m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
|
|
149
|
+
|
|
150
|
+
__m256d centered_a_x_low_f64x4 = _mm256_sub_pd(a_x_low_f64x4, centroid_a_x_f64x4);
|
|
151
|
+
__m256d centered_a_x_high_f64x4 = _mm256_sub_pd(a_x_high_f64x4, centroid_a_x_f64x4);
|
|
152
|
+
__m256d centered_a_y_low_f64x4 = _mm256_sub_pd(a_y_low_f64x4, centroid_a_y_f64x4);
|
|
153
|
+
__m256d centered_a_y_high_f64x4 = _mm256_sub_pd(a_y_high_f64x4, centroid_a_y_f64x4);
|
|
154
|
+
__m256d centered_a_z_low_f64x4 = _mm256_sub_pd(a_z_low_f64x4, centroid_a_z_f64x4);
|
|
155
|
+
__m256d centered_a_z_high_f64x4 = _mm256_sub_pd(a_z_high_f64x4, centroid_a_z_f64x4);
|
|
156
|
+
__m256d centered_b_x_low_f64x4 = _mm256_sub_pd(b_x_low_f64x4, centroid_b_x_f64x4);
|
|
157
|
+
__m256d centered_b_x_high_f64x4 = _mm256_sub_pd(b_x_high_f64x4, centroid_b_x_f64x4);
|
|
158
|
+
__m256d centered_b_y_low_f64x4 = _mm256_sub_pd(b_y_low_f64x4, centroid_b_y_f64x4);
|
|
159
|
+
__m256d centered_b_y_high_f64x4 = _mm256_sub_pd(b_y_high_f64x4, centroid_b_y_f64x4);
|
|
160
|
+
__m256d centered_b_z_low_f64x4 = _mm256_sub_pd(b_z_low_f64x4, centroid_b_z_f64x4);
|
|
161
|
+
__m256d centered_b_z_high_f64x4 = _mm256_sub_pd(b_z_high_f64x4, centroid_b_z_f64x4);
|
|
162
|
+
|
|
163
|
+
__m256d rotated_a_x_low_f64x4 = _mm256_fmadd_pd(
|
|
164
|
+
scaled_rotation_x_z_f64x4, centered_a_z_low_f64x4,
|
|
165
|
+
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_low_f64x4,
|
|
166
|
+
_mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_low_f64x4)));
|
|
167
|
+
__m256d rotated_a_x_high_f64x4 = _mm256_fmadd_pd(
|
|
168
|
+
scaled_rotation_x_z_f64x4, centered_a_z_high_f64x4,
|
|
169
|
+
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_high_f64x4,
|
|
170
|
+
_mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_high_f64x4)));
|
|
171
|
+
__m256d rotated_a_y_low_f64x4 = _mm256_fmadd_pd(
|
|
172
|
+
scaled_rotation_y_z_f64x4, centered_a_z_low_f64x4,
|
|
173
|
+
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_low_f64x4,
|
|
174
|
+
_mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_low_f64x4)));
|
|
175
|
+
__m256d rotated_a_y_high_f64x4 = _mm256_fmadd_pd(
|
|
176
|
+
scaled_rotation_y_z_f64x4, centered_a_z_high_f64x4,
|
|
177
|
+
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_high_f64x4,
|
|
178
|
+
_mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_high_f64x4)));
|
|
179
|
+
__m256d rotated_a_z_low_f64x4 = _mm256_fmadd_pd(
|
|
180
|
+
scaled_rotation_z_z_f64x4, centered_a_z_low_f64x4,
|
|
181
|
+
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_low_f64x4,
|
|
182
|
+
_mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_low_f64x4)));
|
|
183
|
+
__m256d rotated_a_z_high_f64x4 = _mm256_fmadd_pd(
|
|
184
|
+
scaled_rotation_z_z_f64x4, centered_a_z_high_f64x4,
|
|
185
|
+
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_high_f64x4,
|
|
186
|
+
_mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_high_f64x4)));
|
|
187
|
+
|
|
188
|
+
__m256d delta_x_low_f64x4 = _mm256_sub_pd(rotated_a_x_low_f64x4, centered_b_x_low_f64x4);
|
|
189
|
+
__m256d delta_x_high_f64x4 = _mm256_sub_pd(rotated_a_x_high_f64x4, centered_b_x_high_f64x4);
|
|
190
|
+
__m256d delta_y_low_f64x4 = _mm256_sub_pd(rotated_a_y_low_f64x4, centered_b_y_low_f64x4);
|
|
191
|
+
__m256d delta_y_high_f64x4 = _mm256_sub_pd(rotated_a_y_high_f64x4, centered_b_y_high_f64x4);
|
|
192
|
+
__m256d delta_z_low_f64x4 = _mm256_sub_pd(rotated_a_z_low_f64x4, centered_b_z_low_f64x4);
|
|
193
|
+
__m256d delta_z_high_f64x4 = _mm256_sub_pd(rotated_a_z_high_f64x4, centered_b_z_high_f64x4);
|
|
194
|
+
|
|
195
|
+
__m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_low_f64x4, delta_x_low_f64x4),
|
|
196
|
+
_mm256_mul_pd(delta_x_high_f64x4, delta_x_high_f64x4));
|
|
197
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_low_f64x4, delta_y_low_f64x4, batch_sum_squared_f64x4);
|
|
198
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_high_f64x4, delta_y_high_f64x4, batch_sum_squared_f64x4);
|
|
199
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_low_f64x4, delta_z_low_f64x4, batch_sum_squared_f64x4);
|
|
200
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_high_f64x4, delta_z_high_f64x4, batch_sum_squared_f64x4);
|
|
201
201
|
sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
|
|
202
202
|
}
|
|
203
203
|
|
|
204
204
|
nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);
|
|
205
205
|
for (; index < n; ++index) {
|
|
206
|
-
nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z)
|
|
213
|
-
|
|
214
|
-
|
|
206
|
+
nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x,
|
|
207
|
+
centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y,
|
|
208
|
+
centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
|
|
209
|
+
nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x,
|
|
210
|
+
centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y,
|
|
211
|
+
centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
|
|
212
|
+
nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z),
|
|
213
|
+
rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z),
|
|
214
|
+
rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
|
|
215
215
|
nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
|
|
216
216
|
delta_z = rotated_a_z - centered_b_z;
|
|
217
217
|
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
@@ -290,20 +290,15 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_haswell_(nk_f64_t const *a, nk_f64_t
|
|
|
290
290
|
|
|
291
291
|
// Scalar tail
|
|
292
292
|
for (; j < n; ++j) {
|
|
293
|
-
nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x
|
|
294
|
-
|
|
295
|
-
nk_f64_t
|
|
296
|
-
|
|
297
|
-
nk_f64_t
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
nk_f64_t
|
|
302
|
-
nk_f64_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
303
|
-
|
|
304
|
-
nk_f64_t delta_x = ra_x - pb_x;
|
|
305
|
-
nk_f64_t delta_y = ra_y - pb_y;
|
|
306
|
-
nk_f64_t delta_z = ra_z - pb_z;
|
|
293
|
+
nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x, pa_y = a[j * 3 + 1] - centroid_a_y,
|
|
294
|
+
pa_z = a[j * 3 + 2] - centroid_a_z;
|
|
295
|
+
nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x, pb_y = b[j * 3 + 1] - centroid_b_y,
|
|
296
|
+
pb_z = b[j * 3 + 2] - centroid_b_z;
|
|
297
|
+
nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
|
|
298
|
+
ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
|
|
299
|
+
ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
300
|
+
|
|
301
|
+
nk_f64_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
|
|
307
302
|
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
|
|
308
303
|
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
|
|
309
304
|
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
|
|
@@ -330,38 +325,38 @@ NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
330
325
|
nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
|
|
331
326
|
nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
332
327
|
|
|
333
|
-
__m256d
|
|
334
|
-
__m256d
|
|
335
|
-
__m256d
|
|
336
|
-
__m256d
|
|
337
|
-
__m256d
|
|
338
|
-
__m256d
|
|
339
|
-
__m256d
|
|
340
|
-
__m256d
|
|
341
|
-
__m256d
|
|
342
|
-
__m256d
|
|
343
|
-
__m256d
|
|
344
|
-
__m256d
|
|
345
|
-
|
|
346
|
-
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(
|
|
347
|
-
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(
|
|
348
|
-
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(
|
|
349
|
-
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(
|
|
350
|
-
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(
|
|
351
|
-
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(
|
|
352
|
-
|
|
353
|
-
__m256d
|
|
354
|
-
__m256d
|
|
355
|
-
__m256d
|
|
356
|
-
__m256d
|
|
357
|
-
__m256d
|
|
358
|
-
__m256d
|
|
359
|
-
__m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(
|
|
360
|
-
_mm256_mul_pd(
|
|
361
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(
|
|
362
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(
|
|
363
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(
|
|
364
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(
|
|
328
|
+
__m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
|
|
329
|
+
__m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
|
|
330
|
+
__m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
|
|
331
|
+
__m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
|
|
332
|
+
__m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
|
|
333
|
+
__m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
|
|
334
|
+
__m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
|
|
335
|
+
__m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
|
|
336
|
+
__m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
|
|
337
|
+
__m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
|
|
338
|
+
__m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
|
|
339
|
+
__m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
|
|
340
|
+
|
|
341
|
+
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_low_f64x4, a_x_high_f64x4));
|
|
342
|
+
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_low_f64x4, a_y_high_f64x4));
|
|
343
|
+
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_low_f64x4, a_z_high_f64x4));
|
|
344
|
+
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_low_f64x4, b_x_high_f64x4));
|
|
345
|
+
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_low_f64x4, b_y_high_f64x4));
|
|
346
|
+
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_low_f64x4, b_z_high_f64x4));
|
|
347
|
+
|
|
348
|
+
__m256d delta_x_low_f64x4 = _mm256_sub_pd(a_x_low_f64x4, b_x_low_f64x4);
|
|
349
|
+
__m256d delta_x_high_f64x4 = _mm256_sub_pd(a_x_high_f64x4, b_x_high_f64x4);
|
|
350
|
+
__m256d delta_y_low_f64x4 = _mm256_sub_pd(a_y_low_f64x4, b_y_low_f64x4);
|
|
351
|
+
__m256d delta_y_high_f64x4 = _mm256_sub_pd(a_y_high_f64x4, b_y_high_f64x4);
|
|
352
|
+
__m256d delta_z_low_f64x4 = _mm256_sub_pd(a_z_low_f64x4, b_z_low_f64x4);
|
|
353
|
+
__m256d delta_z_high_f64x4 = _mm256_sub_pd(a_z_high_f64x4, b_z_high_f64x4);
|
|
354
|
+
__m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_low_f64x4, delta_x_low_f64x4),
|
|
355
|
+
_mm256_mul_pd(delta_x_high_f64x4, delta_x_high_f64x4));
|
|
356
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_low_f64x4, delta_y_low_f64x4, batch_sum_squared_f64x4);
|
|
357
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_high_f64x4, delta_y_high_f64x4, batch_sum_squared_f64x4);
|
|
358
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_low_f64x4, delta_z_low_f64x4, batch_sum_squared_f64x4);
|
|
359
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_high_f64x4, delta_z_high_f64x4, batch_sum_squared_f64x4);
|
|
365
360
|
sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
|
|
366
361
|
}
|
|
367
362
|
|
|
@@ -401,12 +396,10 @@ NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
401
396
|
|
|
402
397
|
NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
403
398
|
nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
|
|
404
|
-
|
|
405
|
-
if (rotation)
|
|
406
|
-
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0
|
|
407
|
-
rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
|
|
399
|
+
// RMSD uses identity rotation and scale=1.0
|
|
400
|
+
if (rotation)
|
|
401
|
+
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
408
402
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
409
|
-
}
|
|
410
403
|
if (scale) *scale = 1.0;
|
|
411
404
|
__m256d const zeros_f64x4 = _mm256_setzero_pd();
|
|
412
405
|
|
|
@@ -521,16 +514,8 @@ NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
521
514
|
nk_f64_t centroid_b_y = total_by * inv_n;
|
|
522
515
|
nk_f64_t centroid_b_z = total_bz * inv_n;
|
|
523
516
|
|
|
524
|
-
if (a_centroid)
|
|
525
|
-
|
|
526
|
-
a_centroid[1] = centroid_a_y;
|
|
527
|
-
a_centroid[2] = centroid_a_z;
|
|
528
|
-
}
|
|
529
|
-
if (b_centroid) {
|
|
530
|
-
b_centroid[0] = centroid_b_x;
|
|
531
|
-
b_centroid[1] = centroid_b_y;
|
|
532
|
-
b_centroid[2] = centroid_b_z;
|
|
533
|
-
}
|
|
517
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
518
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
534
519
|
|
|
535
520
|
// Compute RMSD
|
|
536
521
|
nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
@@ -559,53 +544,53 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
|
|
|
559
544
|
for (; index + 8 <= n; index += 8) {
|
|
560
545
|
nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
|
|
561
546
|
nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
562
|
-
__m256d
|
|
563
|
-
__m256d
|
|
564
|
-
__m256d
|
|
565
|
-
__m256d
|
|
566
|
-
__m256d
|
|
567
|
-
__m256d
|
|
568
|
-
__m256d
|
|
569
|
-
__m256d
|
|
570
|
-
__m256d
|
|
571
|
-
__m256d
|
|
572
|
-
__m256d
|
|
573
|
-
__m256d
|
|
574
|
-
|
|
575
|
-
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(
|
|
576
|
-
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(
|
|
577
|
-
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(
|
|
578
|
-
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(
|
|
579
|
-
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(
|
|
580
|
-
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(
|
|
581
|
-
|
|
582
|
-
covariance_00_f64x4 = _mm256_add_pd(
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
covariance_01_f64x4 = _mm256_add_pd(
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
covariance_02_f64x4 = _mm256_add_pd(
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
covariance_10_f64x4 = _mm256_add_pd(
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
covariance_11_f64x4 = _mm256_add_pd(
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
covariance_12_f64x4 = _mm256_add_pd(
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
covariance_20_f64x4 = _mm256_add_pd(
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
covariance_21_f64x4 = _mm256_add_pd(
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
covariance_22_f64x4 = _mm256_add_pd(
|
|
607
|
-
|
|
608
|
-
|
|
547
|
+
__m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
|
|
548
|
+
__m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
|
|
549
|
+
__m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
|
|
550
|
+
__m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
|
|
551
|
+
__m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
|
|
552
|
+
__m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
|
|
553
|
+
__m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
|
|
554
|
+
__m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
|
|
555
|
+
__m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
|
|
556
|
+
__m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
|
|
557
|
+
__m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
|
|
558
|
+
__m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
|
|
559
|
+
|
|
560
|
+
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_low_f64x4, a_x_high_f64x4));
|
|
561
|
+
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_low_f64x4, a_y_high_f64x4));
|
|
562
|
+
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_low_f64x4, a_z_high_f64x4));
|
|
563
|
+
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_low_f64x4, b_x_high_f64x4));
|
|
564
|
+
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_low_f64x4, b_y_high_f64x4));
|
|
565
|
+
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_low_f64x4, b_z_high_f64x4));
|
|
566
|
+
|
|
567
|
+
covariance_00_f64x4 = _mm256_add_pd(
|
|
568
|
+
covariance_00_f64x4,
|
|
569
|
+
_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_x_high_f64x4)));
|
|
570
|
+
covariance_01_f64x4 = _mm256_add_pd(
|
|
571
|
+
covariance_01_f64x4,
|
|
572
|
+
_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_y_high_f64x4)));
|
|
573
|
+
covariance_02_f64x4 = _mm256_add_pd(
|
|
574
|
+
covariance_02_f64x4,
|
|
575
|
+
_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_z_high_f64x4)));
|
|
576
|
+
covariance_10_f64x4 = _mm256_add_pd(
|
|
577
|
+
covariance_10_f64x4,
|
|
578
|
+
_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_x_high_f64x4)));
|
|
579
|
+
covariance_11_f64x4 = _mm256_add_pd(
|
|
580
|
+
covariance_11_f64x4,
|
|
581
|
+
_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_y_high_f64x4)));
|
|
582
|
+
covariance_12_f64x4 = _mm256_add_pd(
|
|
583
|
+
covariance_12_f64x4,
|
|
584
|
+
_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_z_high_f64x4)));
|
|
585
|
+
covariance_20_f64x4 = _mm256_add_pd(
|
|
586
|
+
covariance_20_f64x4,
|
|
587
|
+
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_x_high_f64x4)));
|
|
588
|
+
covariance_21_f64x4 = _mm256_add_pd(
|
|
589
|
+
covariance_21_f64x4,
|
|
590
|
+
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_y_high_f64x4)));
|
|
591
|
+
covariance_22_f64x4 = _mm256_add_pd(
|
|
592
|
+
covariance_22_f64x4,
|
|
593
|
+
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_z_high_f64x4)));
|
|
609
594
|
}
|
|
610
595
|
|
|
611
596
|
nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
|
|
@@ -775,27 +760,19 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
775
760
|
nk_f64_t centroid_b_y = sum_b_y * inv_n;
|
|
776
761
|
nk_f64_t centroid_b_z = sum_b_z * inv_n;
|
|
777
762
|
|
|
778
|
-
if (a_centroid)
|
|
779
|
-
|
|
780
|
-
a_centroid[1] = centroid_a_y;
|
|
781
|
-
a_centroid[2] = centroid_a_z;
|
|
782
|
-
}
|
|
783
|
-
if (b_centroid) {
|
|
784
|
-
b_centroid[0] = centroid_b_x;
|
|
785
|
-
b_centroid[1] = centroid_b_y;
|
|
786
|
-
b_centroid[2] = centroid_b_z;
|
|
787
|
-
}
|
|
763
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
764
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
788
765
|
|
|
789
766
|
// Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
|
|
790
|
-
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
791
|
-
covariance_x_y -= n * centroid_a_x * centroid_b_y;
|
|
792
|
-
covariance_x_z -= n * centroid_a_x * centroid_b_z;
|
|
793
|
-
covariance_y_x -= n * centroid_a_y * centroid_b_x;
|
|
794
|
-
covariance_y_y -= n * centroid_a_y * centroid_b_y;
|
|
795
|
-
covariance_y_z -= n * centroid_a_y * centroid_b_z;
|
|
796
|
-
covariance_z_x -= n * centroid_a_z * centroid_b_x;
|
|
797
|
-
covariance_z_y -= n * centroid_a_z * centroid_b_y;
|
|
798
|
-
covariance_z_z -= n * centroid_a_z * centroid_b_z;
|
|
767
|
+
covariance_x_x -= (nk_f64_t)n * centroid_a_x * centroid_b_x;
|
|
768
|
+
covariance_x_y -= (nk_f64_t)n * centroid_a_x * centroid_b_y;
|
|
769
|
+
covariance_x_z -= (nk_f64_t)n * centroid_a_x * centroid_b_z;
|
|
770
|
+
covariance_y_x -= (nk_f64_t)n * centroid_a_y * centroid_b_x;
|
|
771
|
+
covariance_y_y -= (nk_f64_t)n * centroid_a_y * centroid_b_y;
|
|
772
|
+
covariance_y_z -= (nk_f64_t)n * centroid_a_y * centroid_b_z;
|
|
773
|
+
covariance_z_x -= (nk_f64_t)n * centroid_a_z * centroid_b_x;
|
|
774
|
+
covariance_z_y -= (nk_f64_t)n * centroid_a_z * centroid_b_y;
|
|
775
|
+
covariance_z_z -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
|
|
799
776
|
|
|
800
777
|
// Compute SVD and optimal rotation using f64 precision (svd_s is 9-element diagonal matrix)
|
|
801
778
|
nk_f64_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
@@ -808,16 +785,13 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
808
785
|
|
|
809
786
|
// Handle reflection: if det(R) < 0, negate third column of V and recompute R
|
|
810
787
|
if (nk_det3x3_f64_(r) < 0) {
|
|
811
|
-
svd_v[2] = -svd_v[2];
|
|
812
|
-
svd_v[5] = -svd_v[5];
|
|
813
|
-
svd_v[8] = -svd_v[8];
|
|
788
|
+
svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
|
|
814
789
|
nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
|
|
815
790
|
}
|
|
816
791
|
|
|
817
|
-
|
|
818
|
-
if (rotation)
|
|
792
|
+
// Output rotation matrix and scale=1.0
|
|
793
|
+
if (rotation)
|
|
819
794
|
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
820
|
-
}
|
|
821
795
|
if (scale) *scale = 1.0;
|
|
822
796
|
|
|
823
797
|
// Compute RMSD after optimal rotation
|
|
@@ -842,60 +816,60 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
|
|
|
842
816
|
for (; index + 8 <= n; index += 8) {
|
|
843
817
|
nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
|
|
844
818
|
nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
845
|
-
__m256d
|
|
846
|
-
__m256d
|
|
847
|
-
__m256d
|
|
848
|
-
__m256d
|
|
849
|
-
__m256d
|
|
850
|
-
__m256d
|
|
851
|
-
__m256d
|
|
852
|
-
__m256d
|
|
853
|
-
__m256d
|
|
854
|
-
__m256d
|
|
855
|
-
__m256d
|
|
856
|
-
__m256d
|
|
857
|
-
|
|
858
|
-
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(
|
|
859
|
-
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(
|
|
860
|
-
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(
|
|
861
|
-
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(
|
|
862
|
-
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(
|
|
863
|
-
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(
|
|
864
|
-
covariance_00_f64x4 = _mm256_add_pd(
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
covariance_01_f64x4 = _mm256_add_pd(
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
covariance_02_f64x4 = _mm256_add_pd(
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
covariance_10_f64x4 = _mm256_add_pd(
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
covariance_11_f64x4 = _mm256_add_pd(
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
covariance_12_f64x4 = _mm256_add_pd(
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
covariance_20_f64x4 = _mm256_add_pd(
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
covariance_21_f64x4 = _mm256_add_pd(
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
covariance_22_f64x4 = _mm256_add_pd(
|
|
889
|
-
|
|
890
|
-
|
|
819
|
+
__m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
|
|
820
|
+
__m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
|
|
821
|
+
__m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
|
|
822
|
+
__m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
|
|
823
|
+
__m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
|
|
824
|
+
__m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
|
|
825
|
+
__m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
|
|
826
|
+
__m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
|
|
827
|
+
__m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
|
|
828
|
+
__m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
|
|
829
|
+
__m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
|
|
830
|
+
__m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
|
|
831
|
+
|
|
832
|
+
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_low_f64x4, a_x_high_f64x4));
|
|
833
|
+
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_low_f64x4, a_y_high_f64x4));
|
|
834
|
+
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_low_f64x4, a_z_high_f64x4));
|
|
835
|
+
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_low_f64x4, b_x_high_f64x4));
|
|
836
|
+
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_low_f64x4, b_y_high_f64x4));
|
|
837
|
+
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_low_f64x4, b_z_high_f64x4));
|
|
838
|
+
covariance_00_f64x4 = _mm256_add_pd(
|
|
839
|
+
covariance_00_f64x4,
|
|
840
|
+
_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_x_high_f64x4)));
|
|
841
|
+
covariance_01_f64x4 = _mm256_add_pd(
|
|
842
|
+
covariance_01_f64x4,
|
|
843
|
+
_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_y_high_f64x4)));
|
|
844
|
+
covariance_02_f64x4 = _mm256_add_pd(
|
|
845
|
+
covariance_02_f64x4,
|
|
846
|
+
_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_z_high_f64x4)));
|
|
847
|
+
covariance_10_f64x4 = _mm256_add_pd(
|
|
848
|
+
covariance_10_f64x4,
|
|
849
|
+
_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_x_high_f64x4)));
|
|
850
|
+
covariance_11_f64x4 = _mm256_add_pd(
|
|
851
|
+
covariance_11_f64x4,
|
|
852
|
+
_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_y_high_f64x4)));
|
|
853
|
+
covariance_12_f64x4 = _mm256_add_pd(
|
|
854
|
+
covariance_12_f64x4,
|
|
855
|
+
_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_z_high_f64x4)));
|
|
856
|
+
covariance_20_f64x4 = _mm256_add_pd(
|
|
857
|
+
covariance_20_f64x4,
|
|
858
|
+
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_x_high_f64x4)));
|
|
859
|
+
covariance_21_f64x4 = _mm256_add_pd(
|
|
860
|
+
covariance_21_f64x4,
|
|
861
|
+
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_y_high_f64x4)));
|
|
862
|
+
covariance_22_f64x4 = _mm256_add_pd(
|
|
863
|
+
covariance_22_f64x4,
|
|
864
|
+
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_z_high_f64x4)));
|
|
891
865
|
variance_a_f64x4 = _mm256_add_pd(
|
|
892
866
|
variance_a_f64x4,
|
|
893
|
-
_mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(
|
|
894
|
-
_mm256_mul_pd(
|
|
895
|
-
_mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(
|
|
896
|
-
_mm256_mul_pd(
|
|
897
|
-
_mm256_add_pd(_mm256_mul_pd(
|
|
898
|
-
_mm256_mul_pd(
|
|
867
|
+
_mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, a_x_low_f64x4),
|
|
868
|
+
_mm256_mul_pd(a_x_high_f64x4, a_x_high_f64x4)),
|
|
869
|
+
_mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, a_y_low_f64x4),
|
|
870
|
+
_mm256_mul_pd(a_y_high_f64x4, a_y_high_f64x4)),
|
|
871
|
+
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, a_z_low_f64x4),
|
|
872
|
+
_mm256_mul_pd(a_z_high_f64x4, a_z_high_f64x4)))));
|
|
899
873
|
}
|
|
900
874
|
|
|
901
875
|
nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
|
|
@@ -1106,7 +1080,7 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1106
1080
|
nk_f64_t det = nk_det3x3_f64_(r);
|
|
1107
1081
|
nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
|
|
1108
1082
|
nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
|
|
1109
|
-
nk_f64_t c = trace_ds / (n * variance_a);
|
|
1083
|
+
nk_f64_t c = trace_ds / ((nk_f64_t)n * variance_a);
|
|
1110
1084
|
if (scale) *scale = c;
|
|
1111
1085
|
|
|
1112
1086
|
// Handle reflection
|
|
@@ -1115,10 +1089,9 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1115
1089
|
nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
|
|
1116
1090
|
}
|
|
1117
1091
|
|
|
1118
|
-
|
|
1119
|
-
if (rotation)
|
|
1092
|
+
// Output rotation matrix
|
|
1093
|
+
if (rotation)
|
|
1120
1094
|
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
1121
|
-
}
|
|
1122
1095
|
|
|
1123
1096
|
// Compute RMSD with scaling
|
|
1124
1097
|
nk_f64_t sum_squared = nk_transformed_ssd_f64_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
|
|
@@ -1247,20 +1220,13 @@ NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_haswell_(nk_f16_t const *a, nk_f16_t
|
|
|
1247
1220
|
nk_f16_to_f32_haswell(&b[j * 3 + 1], &b_y_f32);
|
|
1248
1221
|
nk_f16_to_f32_haswell(&b[j * 3 + 2], &b_z_f32);
|
|
1249
1222
|
|
|
1250
|
-
nk_f32_t pa_x = a_x_f32 - centroid_a_x;
|
|
1251
|
-
nk_f32_t
|
|
1252
|
-
nk_f32_t
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
nk_f32_t pb_z = b_z_f32 - centroid_b_z;
|
|
1223
|
+
nk_f32_t pa_x = a_x_f32 - centroid_a_x, pa_y = a_y_f32 - centroid_a_y, pa_z = a_z_f32 - centroid_a_z;
|
|
1224
|
+
nk_f32_t pb_x = b_x_f32 - centroid_b_x, pb_y = b_y_f32 - centroid_b_y, pb_z = b_z_f32 - centroid_b_z;
|
|
1225
|
+
nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
|
|
1226
|
+
ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
|
|
1227
|
+
ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
1256
1228
|
|
|
1257
|
-
nk_f32_t
|
|
1258
|
-
nk_f32_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
|
|
1259
|
-
nk_f32_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
1260
|
-
|
|
1261
|
-
nk_f32_t delta_x = ra_x - pb_x;
|
|
1262
|
-
nk_f32_t delta_y = ra_y - pb_y;
|
|
1263
|
-
nk_f32_t delta_z = ra_z - pb_z;
|
|
1229
|
+
nk_f32_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
|
|
1264
1230
|
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
1265
1231
|
}
|
|
1266
1232
|
|
|
@@ -1344,20 +1310,13 @@ NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_haswell_(nk_bf16_t const *a, nk_bf1
|
|
|
1344
1310
|
nk_bf16_to_f32_serial(&b[j * 3 + 1], &b_y_f32);
|
|
1345
1311
|
nk_bf16_to_f32_serial(&b[j * 3 + 2], &b_z_f32);
|
|
1346
1312
|
|
|
1347
|
-
nk_f32_t pa_x = a_x_f32 - centroid_a_x;
|
|
1348
|
-
nk_f32_t
|
|
1349
|
-
nk_f32_t
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
nk_f32_t pb_z = b_z_f32 - centroid_b_z;
|
|
1353
|
-
|
|
1354
|
-
nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
|
|
1355
|
-
nk_f32_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
|
|
1356
|
-
nk_f32_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
1313
|
+
nk_f32_t pa_x = a_x_f32 - centroid_a_x, pa_y = a_y_f32 - centroid_a_y, pa_z = a_z_f32 - centroid_a_z;
|
|
1314
|
+
nk_f32_t pb_x = b_x_f32 - centroid_b_x, pb_y = b_y_f32 - centroid_b_y, pb_z = b_z_f32 - centroid_b_z;
|
|
1315
|
+
nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
|
|
1316
|
+
ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
|
|
1317
|
+
ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
1357
1318
|
|
|
1358
|
-
nk_f32_t delta_x = ra_x - pb_x;
|
|
1359
|
-
nk_f32_t delta_y = ra_y - pb_y;
|
|
1360
|
-
nk_f32_t delta_z = ra_z - pb_z;
|
|
1319
|
+
nk_f32_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
|
|
1361
1320
|
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
1362
1321
|
}
|
|
1363
1322
|
|
|
@@ -1366,12 +1325,10 @@ NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_haswell_(nk_bf16_t const *a, nk_bf1
|
|
|
1366
1325
|
|
|
1367
1326
|
NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1368
1327
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1369
|
-
|
|
1370
|
-
if (rotation)
|
|
1371
|
-
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0
|
|
1372
|
-
rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
|
|
1328
|
+
// RMSD uses identity rotation and scale=1.0
|
|
1329
|
+
if (rotation)
|
|
1330
|
+
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1373
1331
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
1374
|
-
}
|
|
1375
1332
|
if (scale) *scale = 1.0f;
|
|
1376
1333
|
|
|
1377
1334
|
__m256 const zeros_f32x8 = _mm256_setzero_ps();
|
|
@@ -1446,16 +1403,8 @@ NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size
|
|
|
1446
1403
|
nk_f32_t centroid_b_y = total_by * inv_n;
|
|
1447
1404
|
nk_f32_t centroid_b_z = total_bz * inv_n;
|
|
1448
1405
|
|
|
1449
|
-
if (a_centroid)
|
|
1450
|
-
|
|
1451
|
-
a_centroid[1] = centroid_a_y;
|
|
1452
|
-
a_centroid[2] = centroid_a_z;
|
|
1453
|
-
}
|
|
1454
|
-
if (b_centroid) {
|
|
1455
|
-
b_centroid[0] = centroid_b_x;
|
|
1456
|
-
b_centroid[1] = centroid_b_y;
|
|
1457
|
-
b_centroid[2] = centroid_b_z;
|
|
1458
|
-
}
|
|
1406
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1407
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1459
1408
|
|
|
1460
1409
|
// Compute RMSD
|
|
1461
1410
|
nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
@@ -1469,12 +1418,10 @@ NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size
|
|
|
1469
1418
|
|
|
1470
1419
|
NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1471
1420
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1472
|
-
|
|
1473
|
-
if (rotation)
|
|
1474
|
-
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0
|
|
1475
|
-
rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
|
|
1421
|
+
// RMSD uses identity rotation and scale=1.0
|
|
1422
|
+
if (rotation)
|
|
1423
|
+
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1476
1424
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
1477
|
-
}
|
|
1478
1425
|
if (scale) *scale = 1.0f;
|
|
1479
1426
|
|
|
1480
1427
|
__m256 const zeros_f32x8 = _mm256_setzero_ps();
|
|
@@ -1549,16 +1496,8 @@ NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_s
|
|
|
1549
1496
|
nk_f32_t centroid_b_y = total_by * inv_n;
|
|
1550
1497
|
nk_f32_t centroid_b_z = total_bz * inv_n;
|
|
1551
1498
|
|
|
1552
|
-
if (a_centroid)
|
|
1553
|
-
|
|
1554
|
-
a_centroid[1] = centroid_a_y;
|
|
1555
|
-
a_centroid[2] = centroid_a_z;
|
|
1556
|
-
}
|
|
1557
|
-
if (b_centroid) {
|
|
1558
|
-
b_centroid[0] = centroid_b_x;
|
|
1559
|
-
b_centroid[1] = centroid_b_y;
|
|
1560
|
-
b_centroid[2] = centroid_b_z;
|
|
1561
|
-
}
|
|
1499
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1500
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1562
1501
|
|
|
1563
1502
|
// Compute RMSD
|
|
1564
1503
|
nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
@@ -1638,21 +1577,11 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
|
|
|
1638
1577
|
nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
|
|
1639
1578
|
nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
|
|
1640
1579
|
nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
|
|
1641
|
-
sum_a_x += ax;
|
|
1642
|
-
|
|
1643
|
-
|
|
1644
|
-
|
|
1645
|
-
|
|
1646
|
-
sum_b_z += bz;
|
|
1647
|
-
covariance_x_x += ax * bx;
|
|
1648
|
-
covariance_x_y += ax * by;
|
|
1649
|
-
covariance_x_z += ax * bz;
|
|
1650
|
-
covariance_y_x += ay * bx;
|
|
1651
|
-
covariance_y_y += ay * by;
|
|
1652
|
-
covariance_y_z += ay * bz;
|
|
1653
|
-
covariance_z_x += az * bx;
|
|
1654
|
-
covariance_z_y += az * by;
|
|
1655
|
-
covariance_z_z += az * bz;
|
|
1580
|
+
sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
|
|
1581
|
+
sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
|
|
1582
|
+
covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
|
|
1583
|
+
covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
|
|
1584
|
+
covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
|
|
1656
1585
|
}
|
|
1657
1586
|
|
|
1658
1587
|
// Compute centroids
|
|
@@ -1664,27 +1593,19 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
|
|
|
1664
1593
|
nk_f32_t centroid_b_y = sum_b_y * inv_n;
|
|
1665
1594
|
nk_f32_t centroid_b_z = sum_b_z * inv_n;
|
|
1666
1595
|
|
|
1667
|
-
if (a_centroid)
|
|
1668
|
-
|
|
1669
|
-
a_centroid[1] = centroid_a_y;
|
|
1670
|
-
a_centroid[2] = centroid_a_z;
|
|
1671
|
-
}
|
|
1672
|
-
if (b_centroid) {
|
|
1673
|
-
b_centroid[0] = centroid_b_x;
|
|
1674
|
-
b_centroid[1] = centroid_b_y;
|
|
1675
|
-
b_centroid[2] = centroid_b_z;
|
|
1676
|
-
}
|
|
1596
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1597
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1677
1598
|
|
|
1678
1599
|
// Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
|
|
1679
|
-
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
1680
|
-
covariance_x_y -= n * centroid_a_x * centroid_b_y;
|
|
1681
|
-
covariance_x_z -= n * centroid_a_x * centroid_b_z;
|
|
1682
|
-
covariance_y_x -= n * centroid_a_y * centroid_b_x;
|
|
1683
|
-
covariance_y_y -= n * centroid_a_y * centroid_b_y;
|
|
1684
|
-
covariance_y_z -= n * centroid_a_y * centroid_b_z;
|
|
1685
|
-
covariance_z_x -= n * centroid_a_z * centroid_b_x;
|
|
1686
|
-
covariance_z_y -= n * centroid_a_z * centroid_b_y;
|
|
1687
|
-
covariance_z_z -= n * centroid_a_z * centroid_b_z;
|
|
1600
|
+
covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
1601
|
+
covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
1602
|
+
covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
|
|
1603
|
+
covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
|
|
1604
|
+
covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
|
|
1605
|
+
covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
|
|
1606
|
+
covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
|
|
1607
|
+
covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
1608
|
+
covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
1688
1609
|
|
|
1689
1610
|
// Compute SVD and optimal rotation
|
|
1690
1611
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
@@ -1706,9 +1627,7 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
|
|
|
1706
1627
|
|
|
1707
1628
|
// Handle reflection: if det(R) < 0, negate third column of V and recompute R
|
|
1708
1629
|
if (nk_det3x3_f32_(r) < 0) {
|
|
1709
|
-
svd_v[2] = -svd_v[2];
|
|
1710
|
-
svd_v[5] = -svd_v[5];
|
|
1711
|
-
svd_v[8] = -svd_v[8];
|
|
1630
|
+
svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
|
|
1712
1631
|
r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
1713
1632
|
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
1714
1633
|
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
@@ -1720,10 +1639,9 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
|
|
|
1720
1639
|
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1721
1640
|
}
|
|
1722
1641
|
|
|
1723
|
-
|
|
1724
|
-
if (rotation)
|
|
1642
|
+
// Output rotation matrix and scale=1.0
|
|
1643
|
+
if (rotation)
|
|
1725
1644
|
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
1726
|
-
}
|
|
1727
1645
|
if (scale) *scale = 1.0f;
|
|
1728
1646
|
|
|
1729
1647
|
// Compute RMSD after optimal rotation
|
|
@@ -1800,21 +1718,11 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
1800
1718
|
nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
|
|
1801
1719
|
nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
|
|
1802
1720
|
nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
|
|
1803
|
-
sum_a_x += ax;
|
|
1804
|
-
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
sum_b_z += bz;
|
|
1809
|
-
covariance_x_x += ax * bx;
|
|
1810
|
-
covariance_x_y += ax * by;
|
|
1811
|
-
covariance_x_z += ax * bz;
|
|
1812
|
-
covariance_y_x += ay * bx;
|
|
1813
|
-
covariance_y_y += ay * by;
|
|
1814
|
-
covariance_y_z += ay * bz;
|
|
1815
|
-
covariance_z_x += az * bx;
|
|
1816
|
-
covariance_z_y += az * by;
|
|
1817
|
-
covariance_z_z += az * bz;
|
|
1721
|
+
sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
|
|
1722
|
+
sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
|
|
1723
|
+
covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
|
|
1724
|
+
covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
|
|
1725
|
+
covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
|
|
1818
1726
|
}
|
|
1819
1727
|
|
|
1820
1728
|
// Compute centroids
|
|
@@ -1826,27 +1734,19 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
1826
1734
|
nk_f32_t centroid_b_y = sum_b_y * inv_n;
|
|
1827
1735
|
nk_f32_t centroid_b_z = sum_b_z * inv_n;
|
|
1828
1736
|
|
|
1829
|
-
if (a_centroid)
|
|
1830
|
-
|
|
1831
|
-
a_centroid[1] = centroid_a_y;
|
|
1832
|
-
a_centroid[2] = centroid_a_z;
|
|
1833
|
-
}
|
|
1834
|
-
if (b_centroid) {
|
|
1835
|
-
b_centroid[0] = centroid_b_x;
|
|
1836
|
-
b_centroid[1] = centroid_b_y;
|
|
1837
|
-
b_centroid[2] = centroid_b_z;
|
|
1838
|
-
}
|
|
1737
|
+
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1738
|
+
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1839
1739
|
|
|
1840
1740
|
// Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
|
|
1841
|
-
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
1842
|
-
covariance_x_y -= n * centroid_a_x * centroid_b_y;
|
|
1843
|
-
covariance_x_z -= n * centroid_a_x * centroid_b_z;
|
|
1844
|
-
covariance_y_x -= n * centroid_a_y * centroid_b_x;
|
|
1845
|
-
covariance_y_y -= n * centroid_a_y * centroid_b_y;
|
|
1846
|
-
covariance_y_z -= n * centroid_a_y * centroid_b_z;
|
|
1847
|
-
covariance_z_x -= n * centroid_a_z * centroid_b_x;
|
|
1848
|
-
covariance_z_y -= n * centroid_a_z * centroid_b_y;
|
|
1849
|
-
covariance_z_z -= n * centroid_a_z * centroid_b_z;
|
|
1741
|
+
covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
1742
|
+
covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
1743
|
+
covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
|
|
1744
|
+
covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
|
|
1745
|
+
covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
|
|
1746
|
+
covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
|
|
1747
|
+
covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
|
|
1748
|
+
covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
1749
|
+
covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
1850
1750
|
|
|
1851
1751
|
// Compute SVD and optimal rotation
|
|
1852
1752
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
@@ -1868,9 +1768,7 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
1868
1768
|
|
|
1869
1769
|
// Handle reflection: if det(R) < 0, negate third column of V and recompute R
|
|
1870
1770
|
if (nk_det3x3_f32_(r) < 0) {
|
|
1871
|
-
svd_v[2] = -svd_v[2];
|
|
1872
|
-
svd_v[5] = -svd_v[5];
|
|
1873
|
-
svd_v[8] = -svd_v[8];
|
|
1771
|
+
svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
|
|
1874
1772
|
r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
1875
1773
|
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
1876
1774
|
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
@@ -1882,10 +1780,9 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
1882
1780
|
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1883
1781
|
}
|
|
1884
1782
|
|
|
1885
|
-
|
|
1886
|
-
if (rotation)
|
|
1783
|
+
// Output rotation matrix and scale=1.0
|
|
1784
|
+
if (rotation)
|
|
1887
1785
|
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
1888
|
-
}
|
|
1889
1786
|
if (scale) *scale = 1.0f;
|
|
1890
1787
|
|
|
1891
1788
|
// Compute RMSD after optimal rotation
|
|
@@ -1965,21 +1862,11 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
|
|
|
1965
1862
|
nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
|
|
1966
1863
|
nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
|
|
1967
1864
|
nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
|
|
1968
|
-
sum_a_x += ax;
|
|
1969
|
-
|
|
1970
|
-
|
|
1971
|
-
|
|
1972
|
-
|
|
1973
|
-
sum_b_z += bz;
|
|
1974
|
-
covariance_x_x += ax * bx;
|
|
1975
|
-
covariance_x_y += ax * by;
|
|
1976
|
-
covariance_x_z += ax * bz;
|
|
1977
|
-
covariance_y_x += ay * bx;
|
|
1978
|
-
covariance_y_y += ay * by;
|
|
1979
|
-
covariance_y_z += ay * bz;
|
|
1980
|
-
covariance_z_x += az * bx;
|
|
1981
|
-
covariance_z_y += az * by;
|
|
1982
|
-
covariance_z_z += az * bz;
|
|
1865
|
+
sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
|
|
1866
|
+
sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
|
|
1867
|
+
covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
|
|
1868
|
+
covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
|
|
1869
|
+
covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
|
|
1983
1870
|
variance_a_sum += ax * ax + ay * ay + az * az;
|
|
1984
1871
|
}
|
|
1985
1872
|
|
|
@@ -1996,15 +1883,15 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
|
|
|
1996
1883
|
(centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
|
|
1997
1884
|
|
|
1998
1885
|
// Apply centering correction to covariance matrix
|
|
1999
|
-
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
2000
|
-
covariance_x_y -= n * centroid_a_x * centroid_b_y;
|
|
2001
|
-
covariance_x_z -= n * centroid_a_x * centroid_b_z;
|
|
2002
|
-
covariance_y_x -= n * centroid_a_y * centroid_b_x;
|
|
2003
|
-
covariance_y_y -= n * centroid_a_y * centroid_b_y;
|
|
2004
|
-
covariance_y_z -= n * centroid_a_y * centroid_b_z;
|
|
2005
|
-
covariance_z_x -= n * centroid_a_z * centroid_b_x;
|
|
2006
|
-
covariance_z_y -= n * centroid_a_z * centroid_b_y;
|
|
2007
|
-
covariance_z_z -= n * centroid_a_z * centroid_b_z;
|
|
1886
|
+
covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
1887
|
+
covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
1888
|
+
covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
|
|
1889
|
+
covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
|
|
1890
|
+
covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
|
|
1891
|
+
covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
|
|
1892
|
+
covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
|
|
1893
|
+
covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
1894
|
+
covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
2008
1895
|
|
|
2009
1896
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
2010
1897
|
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
@@ -2029,7 +1916,7 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
|
|
|
2029
1916
|
nk_f32_t det = nk_det3x3_f32_(r);
|
|
2030
1917
|
nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
|
|
2031
1918
|
nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
|
|
2032
|
-
nk_f32_t c = trace_ds / (n * variance_a);
|
|
1919
|
+
nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
|
|
2033
1920
|
if (scale) *scale = c;
|
|
2034
1921
|
|
|
2035
1922
|
// Handle reflection
|
|
@@ -2046,10 +1933,9 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
|
|
|
2046
1933
|
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
2047
1934
|
}
|
|
2048
1935
|
|
|
2049
|
-
|
|
2050
|
-
if (rotation)
|
|
1936
|
+
// Output rotation matrix
|
|
1937
|
+
if (rotation)
|
|
2051
1938
|
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
2052
|
-
}
|
|
2053
1939
|
|
|
2054
1940
|
// Compute RMSD with scaling
|
|
2055
1941
|
nk_f32_t sum_squared = nk_transformed_ssd_f16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
|
|
@@ -2128,21 +2014,11 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
2128
2014
|
nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
|
|
2129
2015
|
nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
|
|
2130
2016
|
nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
|
|
2131
|
-
sum_a_x += ax;
|
|
2132
|
-
|
|
2133
|
-
|
|
2134
|
-
|
|
2135
|
-
|
|
2136
|
-
sum_b_z += bz;
|
|
2137
|
-
covariance_x_x += ax * bx;
|
|
2138
|
-
covariance_x_y += ax * by;
|
|
2139
|
-
covariance_x_z += ax * bz;
|
|
2140
|
-
covariance_y_x += ay * bx;
|
|
2141
|
-
covariance_y_y += ay * by;
|
|
2142
|
-
covariance_y_z += ay * bz;
|
|
2143
|
-
covariance_z_x += az * bx;
|
|
2144
|
-
covariance_z_y += az * by;
|
|
2145
|
-
covariance_z_z += az * bz;
|
|
2017
|
+
sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
|
|
2018
|
+
sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
|
|
2019
|
+
covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
|
|
2020
|
+
covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
|
|
2021
|
+
covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
|
|
2146
2022
|
variance_a_sum += ax * ax + ay * ay + az * az;
|
|
2147
2023
|
}
|
|
2148
2024
|
|
|
@@ -2159,15 +2035,15 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
2159
2035
|
(centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
|
|
2160
2036
|
|
|
2161
2037
|
// Apply centering correction to covariance matrix
|
|
2162
|
-
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
2163
|
-
covariance_x_y -= n * centroid_a_x * centroid_b_y;
|
|
2164
|
-
covariance_x_z -= n * centroid_a_x * centroid_b_z;
|
|
2165
|
-
covariance_y_x -= n * centroid_a_y * centroid_b_x;
|
|
2166
|
-
covariance_y_y -= n * centroid_a_y * centroid_b_y;
|
|
2167
|
-
covariance_y_z -= n * centroid_a_y * centroid_b_z;
|
|
2168
|
-
covariance_z_x -= n * centroid_a_z * centroid_b_x;
|
|
2169
|
-
covariance_z_y -= n * centroid_a_z * centroid_b_y;
|
|
2170
|
-
covariance_z_z -= n * centroid_a_z * centroid_b_z;
|
|
2038
|
+
covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
2039
|
+
covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
2040
|
+
covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
|
|
2041
|
+
covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
|
|
2042
|
+
covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
|
|
2043
|
+
covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
|
|
2044
|
+
covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
|
|
2045
|
+
covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
2046
|
+
covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
2171
2047
|
|
|
2172
2048
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
2173
2049
|
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
@@ -2192,7 +2068,7 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
2192
2068
|
nk_f32_t det = nk_det3x3_f32_(r);
|
|
2193
2069
|
nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
|
|
2194
2070
|
nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
|
|
2195
|
-
nk_f32_t c = trace_ds / (n * variance_a);
|
|
2071
|
+
nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
|
|
2196
2072
|
if (scale) *scale = c;
|
|
2197
2073
|
|
|
2198
2074
|
// Handle reflection
|
|
@@ -2209,10 +2085,9 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
2209
2085
|
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
2210
2086
|
}
|
|
2211
2087
|
|
|
2212
|
-
|
|
2213
|
-
if (rotation)
|
|
2088
|
+
// Output rotation matrix
|
|
2089
|
+
if (rotation)
|
|
2214
2090
|
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
2215
|
-
}
|
|
2216
2091
|
|
|
2217
2092
|
// Compute RMSD with scaling
|
|
2218
2093
|
nk_f32_t sum_squared = nk_transformed_ssd_bf16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
|