numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,2235 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Point Cloud Alignment for Haswell.
|
|
3
|
+
* @file include/numkong/mesh/haswell.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/mesh.h
|
|
8
|
+
*
|
|
9
|
+
* @section haswell_mesh_instructions Key AVX2 Mesh Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput Ports
|
|
12
|
+
* _mm256_fmadd_ps VFMADD (YMM, YMM, YMM) 5cy 0.5/cy p01
|
|
13
|
+
* _mm256_hadd_ps VHADDPS (YMM, YMM, YMM) 7cy 0.5/cy p01+p5
|
|
14
|
+
* _mm256_permute2f128_ps VPERM2F128 (YMM, YMM, YMM, I8) 3cy 1/cy p5
|
|
15
|
+
* _mm256_extractf128_ps VEXTRACTF128 (XMM, YMM, I8) 3cy 1/cy p5
|
|
16
|
+
* _mm256_i32gather_ps VGATHERDPS (YMM, M, YMM, YMM) 12cy 5/cy p0+p23
|
|
17
|
+
*
|
|
18
|
+
* Point cloud operations (centroid, covariance, Kabsch alignment) use gather instructions for
|
|
19
|
+
* stride-3 xyz deinterleaving of f32 data (f64 data is loaded with scalar reads). Multiple FMA accumulators hide the 5-cycle FMA latency. VHADDPS
|
|
20
|
+
* interleaves results across lanes, requiring additional shuffles for final scalar reduction.
|
|
21
|
+
*/
|
|
22
|
+
#ifndef NK_MESH_HASWELL_H
|
|
23
|
+
#define NK_MESH_HASWELL_H
|
|
24
|
+
|
|
25
|
+
#if NK_TARGET_X86_
|
|
26
|
+
#if NK_TARGET_HASWELL
|
|
27
|
+
|
|
28
|
+
#include "numkong/types.h"
|
|
29
|
+
#include "numkong/dot/haswell.h"
|
|
30
|
+
#include "numkong/mesh/serial.h"
|
|
31
|
+
#include "numkong/reduce/haswell.h" // `nk_reduce_add_f32x8_haswell_`
|
|
32
|
+
#include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`, `nk_f64_sqrt_haswell`
|
|
33
|
+
|
|
34
|
+
#if defined(__cplusplus)
|
|
35
|
+
extern "C" {
|
|
36
|
+
#endif
|
|
37
|
+
|
|
38
|
+
#if defined(__clang__)
|
|
39
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
|
|
40
|
+
#elif defined(__GNUC__)
|
|
41
|
+
#pragma GCC push_options
|
|
42
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
/* Deinterleave 24 floats (8 xyz triplets) into separate x, y, z vectors.
|
|
46
|
+
* Uses AVX2 gather instructions for clean stride-3 access.
|
|
47
|
+
*
|
|
48
|
+
* Input: 24 contiguous floats [x0,y0,z0, x1,y1,z1, ..., x7,y7,z7]
|
|
49
|
+
* Output: x[8], y[8], z[8] vectors
|
|
50
|
+
*/
|
|
51
|
+
NK_INTERNAL void nk_deinterleave_f32x8_haswell_(nk_f32_t const *ptr, __m256 *x_out, __m256 *y_out, __m256 *z_out) {
|
|
52
|
+
// Gather indices: 0, 3, 6, 9, 12, 15, 18, 21 (stride 3)
|
|
53
|
+
__m256i idx = _mm256_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21);
|
|
54
|
+
*x_out = _mm256_i32gather_ps(ptr + 0, idx, 4);
|
|
55
|
+
*y_out = _mm256_i32gather_ps(ptr + 1, idx, 4);
|
|
56
|
+
*z_out = _mm256_i32gather_ps(ptr + 2, idx, 4);
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
/* Deinterleave 12 f64 values (4 xyz triplets) into separate x, y, z vectors.
|
|
60
|
+
* Uses scalar extraction for simplicity as AVX2 lacks efficient stride-3 gather for f64.
|
|
61
|
+
*
|
|
62
|
+
* Input: 12 contiguous f64 [x0,y0,z0, x1,y1,z1, x2,y2,z2, x3,y3,z3]
|
|
63
|
+
* Output: x[4], y[4], z[4] vectors
|
|
64
|
+
*/
|
|
65
|
+
NK_INTERNAL void nk_deinterleave_f64x4_haswell_(nk_f64_t const *ptr, __m256d *x_out, __m256d *y_out, __m256d *z_out) {
|
|
66
|
+
nk_f64_t x0 = ptr[0], x1 = ptr[3], x2 = ptr[6], x3 = ptr[9];
|
|
67
|
+
nk_f64_t y0 = ptr[1], y1 = ptr[4], y2 = ptr[7], y3 = ptr[10];
|
|
68
|
+
nk_f64_t z0 = ptr[2], z1 = ptr[5], z2 = ptr[8], z3 = ptr[11];
|
|
69
|
+
|
|
70
|
+
*x_out = _mm256_setr_pd(x0, x1, x2, x3);
|
|
71
|
+
*y_out = _mm256_setr_pd(y0, y1, y2, y3);
|
|
72
|
+
*z_out = _mm256_setr_pd(z0, z1, z2, z3);
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
/* Horizontal reduction helpers now live in `numkong/reduce/haswell.h`:
|
|
76
|
+
* - nk_reduce_add_f32x8_haswell_
|
|
77
|
+
* - nk_reduce_add_f64x4_haswell_
|
|
78
|
+
*/
|
|
79
|
+
|
|
80
|
+
/**
 *  @brief  Horizontally sums the four f64 lanes of a vector using scalar compensated steps.
 *
 *  Lanes are fed to `nk_accumulate_sum_f64_` in order 0, 1, 2, 3. That helper keeps a separate
 *  compensation term, which is folded into the running sum at the end.
 *
 *  @param values_f64x4  Vector whose lanes are summed.
 *  @return              Sum of the four lanes plus the accumulated compensation.
 */
NK_INTERNAL nk_f64_t nk_reduce_stable_f64x4_haswell_(__m256d values_f64x4) {
    nk_b256_vec_t lanes;
    lanes.ymm_pd = values_f64x4;

    nk_f64_t total = 0.0;
    nk_f64_t total_error = 0.0;
    for (int lane = 0; lane != 4; ++lane) {
        nk_accumulate_sum_f64_(&total, &total_error, lanes.f64s[lane]);
    }
    return total + total_error;
}
|
|
90
|
+
|
|
91
|
+
/**
 *  @brief  Haswell entry point for deriving the rotation matrix from SVD factors.
 *
 *  There is no AVX2-specific implementation. The arguments are forwarded unchanged to
 *  `nk_rotation_from_svd_f64_serial_` (presumably defined in `numkong/mesh/serial.h`).
 *
 *  @param svd_u     SVD factor, presumably `U`; layout as expected by the serial routine.
 *  @param svd_v     SVD factor, presumably `V`; layout as expected by the serial routine.
 *  @param rotation  Output buffer written by the serial routine.
 */
NK_INTERNAL void nk_rotation_from_svd_f64_haswell_(nk_f64_t const *svd_u, nk_f64_t const *svd_v, nk_f64_t *rotation) {
    nk_f64_t const *const left_factor = svd_u;
    nk_f64_t const *const right_factor = svd_v;
    nk_rotation_from_svd_f64_serial_(left_factor, right_factor, rotation);
}
|
|
94
|
+
|
|
95
|
+
/**
 *  @brief  Adds `values^2` lane-wise to a running sum while tracking the rounding error.
 *
 *  Two error-free transformations are combined:
 *  - the product error `values * values - square` is recovered exactly with a fused multiply-subtract;
 *  - the addition error of `sum + square` is recovered with Knuth's TwoSum.
 *  Both errors are added to `*compensation_f64x4`, and `*sum_f64x4` receives the rounded sum.
 *
 *  @param[in,out] sum_f64x4           Running per-lane sum.
 *  @param[in,out] compensation_f64x4  Running per-lane error term.
 *  @param[in]     values_f64x4        Values to square and accumulate.
 */
NK_INTERNAL void nk_accumulate_square_f64x4_haswell_(__m256d *sum_f64x4, __m256d *compensation_f64x4,
                                                     __m256d values_f64x4) {
    // TwoProduct: square + square_error == values * values exactly.
    __m256d const square_f64x4 = _mm256_mul_pd(values_f64x4, values_f64x4);
    __m256d const square_error_f64x4 = _mm256_fmsub_pd(values_f64x4, values_f64x4, square_f64x4);

    // TwoSum: new_sum + twosum_error == old_sum + square exactly.
    __m256d const old_sum_f64x4 = *sum_f64x4;
    __m256d const new_sum_f64x4 = _mm256_add_pd(old_sum_f64x4, square_f64x4);
    __m256d const addend_seen_f64x4 = _mm256_sub_pd(new_sum_f64x4, old_sum_f64x4);
    __m256d const old_sum_seen_f64x4 = _mm256_sub_pd(new_sum_f64x4, addend_seen_f64x4);
    __m256d const old_sum_lost_f64x4 = _mm256_sub_pd(old_sum_f64x4, old_sum_seen_f64x4);
    __m256d const addend_lost_f64x4 = _mm256_sub_pd(square_f64x4, addend_seen_f64x4);
    __m256d const twosum_error_f64x4 = _mm256_add_pd(old_sum_lost_f64x4, addend_lost_f64x4);

    *sum_f64x4 = new_sum_f64x4;
    *compensation_f64x4 = _mm256_add_pd(*compensation_f64x4, _mm256_add_pd(twosum_error_f64x4, square_error_f64x4));
}
|
|
107
|
+
|
|
108
|
+
/* Compute sum of squared distances after applying rotation (and optional scale).
|
|
109
|
+
* Used by kabsch (scale=1.0) and umeyama (scale=computed_scale).
|
|
110
|
+
* Returns sum_squared, caller computes sqrt(sum_squared / n).
|
|
111
|
+
*/
|
|
112
|
+
NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_haswell_(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
|
|
113
|
+
nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
|
|
114
|
+
nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
|
|
115
|
+
nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
|
|
116
|
+
nk_f64_t centroid_b_z) {
|
|
117
|
+
__m256d scaled_rotation_x_x_f64x4 = _mm256_set1_pd(scale * r[0]);
|
|
118
|
+
__m256d scaled_rotation_x_y_f64x4 = _mm256_set1_pd(scale * r[1]);
|
|
119
|
+
__m256d scaled_rotation_x_z_f64x4 = _mm256_set1_pd(scale * r[2]);
|
|
120
|
+
__m256d scaled_rotation_y_x_f64x4 = _mm256_set1_pd(scale * r[3]);
|
|
121
|
+
__m256d scaled_rotation_y_y_f64x4 = _mm256_set1_pd(scale * r[4]);
|
|
122
|
+
__m256d scaled_rotation_y_z_f64x4 = _mm256_set1_pd(scale * r[5]);
|
|
123
|
+
__m256d scaled_rotation_z_x_f64x4 = _mm256_set1_pd(scale * r[6]);
|
|
124
|
+
__m256d scaled_rotation_z_y_f64x4 = _mm256_set1_pd(scale * r[7]);
|
|
125
|
+
__m256d scaled_rotation_z_z_f64x4 = _mm256_set1_pd(scale * r[8]);
|
|
126
|
+
__m256d centroid_a_x_f64x4 = _mm256_set1_pd(centroid_a_x), centroid_a_y_f64x4 = _mm256_set1_pd(centroid_a_y);
|
|
127
|
+
__m256d centroid_a_z_f64x4 = _mm256_set1_pd(centroid_a_z), centroid_b_x_f64x4 = _mm256_set1_pd(centroid_b_x);
|
|
128
|
+
__m256d centroid_b_y_f64x4 = _mm256_set1_pd(centroid_b_y), centroid_b_z_f64x4 = _mm256_set1_pd(centroid_b_z);
|
|
129
|
+
__m256d sum_squared_f64x4 = _mm256_setzero_pd();
|
|
130
|
+
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
131
|
+
nk_size_t index = 0;
|
|
132
|
+
|
|
133
|
+
for (; index + 8 <= n; index += 8) {
|
|
134
|
+
nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
|
|
135
|
+
nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
136
|
+
|
|
137
|
+
__m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
|
|
138
|
+
__m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
|
|
139
|
+
__m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
|
|
140
|
+
__m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
|
|
141
|
+
__m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
|
|
142
|
+
__m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
|
|
143
|
+
__m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
|
|
144
|
+
__m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
|
|
145
|
+
__m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
|
|
146
|
+
__m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
|
|
147
|
+
__m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
|
|
148
|
+
__m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
|
|
149
|
+
|
|
150
|
+
__m256d centered_a_x_lower_f64x4 = _mm256_sub_pd(a_x_lower_f64x4, centroid_a_x_f64x4);
|
|
151
|
+
__m256d centered_a_x_upper_f64x4 = _mm256_sub_pd(a_x_upper_f64x4, centroid_a_x_f64x4);
|
|
152
|
+
__m256d centered_a_y_lower_f64x4 = _mm256_sub_pd(a_y_lower_f64x4, centroid_a_y_f64x4);
|
|
153
|
+
__m256d centered_a_y_upper_f64x4 = _mm256_sub_pd(a_y_upper_f64x4, centroid_a_y_f64x4);
|
|
154
|
+
__m256d centered_a_z_lower_f64x4 = _mm256_sub_pd(a_z_lower_f64x4, centroid_a_z_f64x4);
|
|
155
|
+
__m256d centered_a_z_upper_f64x4 = _mm256_sub_pd(a_z_upper_f64x4, centroid_a_z_f64x4);
|
|
156
|
+
__m256d centered_b_x_lower_f64x4 = _mm256_sub_pd(b_x_lower_f64x4, centroid_b_x_f64x4);
|
|
157
|
+
__m256d centered_b_x_upper_f64x4 = _mm256_sub_pd(b_x_upper_f64x4, centroid_b_x_f64x4);
|
|
158
|
+
__m256d centered_b_y_lower_f64x4 = _mm256_sub_pd(b_y_lower_f64x4, centroid_b_y_f64x4);
|
|
159
|
+
__m256d centered_b_y_upper_f64x4 = _mm256_sub_pd(b_y_upper_f64x4, centroid_b_y_f64x4);
|
|
160
|
+
__m256d centered_b_z_lower_f64x4 = _mm256_sub_pd(b_z_lower_f64x4, centroid_b_z_f64x4);
|
|
161
|
+
__m256d centered_b_z_upper_f64x4 = _mm256_sub_pd(b_z_upper_f64x4, centroid_b_z_f64x4);
|
|
162
|
+
|
|
163
|
+
__m256d rotated_a_x_lower_f64x4 = _mm256_fmadd_pd(
|
|
164
|
+
scaled_rotation_x_z_f64x4, centered_a_z_lower_f64x4,
|
|
165
|
+
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_lower_f64x4,
|
|
166
|
+
_mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_lower_f64x4)));
|
|
167
|
+
__m256d rotated_a_x_upper_f64x4 = _mm256_fmadd_pd(
|
|
168
|
+
scaled_rotation_x_z_f64x4, centered_a_z_upper_f64x4,
|
|
169
|
+
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_upper_f64x4,
|
|
170
|
+
_mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_upper_f64x4)));
|
|
171
|
+
__m256d rotated_a_y_lower_f64x4 = _mm256_fmadd_pd(
|
|
172
|
+
scaled_rotation_y_z_f64x4, centered_a_z_lower_f64x4,
|
|
173
|
+
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_lower_f64x4,
|
|
174
|
+
_mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_lower_f64x4)));
|
|
175
|
+
__m256d rotated_a_y_upper_f64x4 = _mm256_fmadd_pd(
|
|
176
|
+
scaled_rotation_y_z_f64x4, centered_a_z_upper_f64x4,
|
|
177
|
+
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_upper_f64x4,
|
|
178
|
+
_mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_upper_f64x4)));
|
|
179
|
+
__m256d rotated_a_z_lower_f64x4 = _mm256_fmadd_pd(
|
|
180
|
+
scaled_rotation_z_z_f64x4, centered_a_z_lower_f64x4,
|
|
181
|
+
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_lower_f64x4,
|
|
182
|
+
_mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_lower_f64x4)));
|
|
183
|
+
__m256d rotated_a_z_upper_f64x4 = _mm256_fmadd_pd(
|
|
184
|
+
scaled_rotation_z_z_f64x4, centered_a_z_upper_f64x4,
|
|
185
|
+
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_upper_f64x4,
|
|
186
|
+
_mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_upper_f64x4)));
|
|
187
|
+
|
|
188
|
+
__m256d delta_x_lower_f64x4 = _mm256_sub_pd(rotated_a_x_lower_f64x4, centered_b_x_lower_f64x4);
|
|
189
|
+
__m256d delta_x_upper_f64x4 = _mm256_sub_pd(rotated_a_x_upper_f64x4, centered_b_x_upper_f64x4);
|
|
190
|
+
__m256d delta_y_lower_f64x4 = _mm256_sub_pd(rotated_a_y_lower_f64x4, centered_b_y_lower_f64x4);
|
|
191
|
+
__m256d delta_y_upper_f64x4 = _mm256_sub_pd(rotated_a_y_upper_f64x4, centered_b_y_upper_f64x4);
|
|
192
|
+
__m256d delta_z_lower_f64x4 = _mm256_sub_pd(rotated_a_z_lower_f64x4, centered_b_z_lower_f64x4);
|
|
193
|
+
__m256d delta_z_upper_f64x4 = _mm256_sub_pd(rotated_a_z_upper_f64x4, centered_b_z_upper_f64x4);
|
|
194
|
+
|
|
195
|
+
__m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_lower_f64x4, delta_x_lower_f64x4),
|
|
196
|
+
_mm256_mul_pd(delta_x_upper_f64x4, delta_x_upper_f64x4));
|
|
197
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_lower_f64x4, delta_y_lower_f64x4, batch_sum_squared_f64x4);
|
|
198
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_upper_f64x4, delta_y_upper_f64x4, batch_sum_squared_f64x4);
|
|
199
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_lower_f64x4, delta_z_lower_f64x4, batch_sum_squared_f64x4);
|
|
200
|
+
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_upper_f64x4, delta_z_upper_f64x4, batch_sum_squared_f64x4);
|
|
201
|
+
sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);
|
|
205
|
+
for (; index < n; ++index) {
|
|
206
|
+
nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x;
|
|
207
|
+
nk_f64_t centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y;
|
|
208
|
+
nk_f64_t centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
|
|
209
|
+
nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x;
|
|
210
|
+
nk_f64_t centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y;
|
|
211
|
+
nk_f64_t centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
|
|
212
|
+
nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z);
|
|
213
|
+
nk_f64_t rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z);
|
|
214
|
+
nk_f64_t rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
|
|
215
|
+
nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
|
|
216
|
+
delta_z = rotated_a_z - centered_b_z;
|
|
217
|
+
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
return sum_squared;
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
/* Compute sum of squared distances for f64 after applying rotation (and optional scale).
|
|
224
|
+
* Rotation matrix, scale and data are all f64 for full precision.
|
|
225
|
+
*/
|
|
226
|
+
NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_haswell_(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
|
|
227
|
+
nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
|
|
228
|
+
nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
|
|
229
|
+
nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
|
|
230
|
+
nk_f64_t centroid_b_z) {
|
|
231
|
+
// Broadcast scaled rotation matrix elements
|
|
232
|
+
__m256d scaled_rotation_x_x_f64x4 = _mm256_set1_pd(scale * r[0]);
|
|
233
|
+
__m256d scaled_rotation_x_y_f64x4 = _mm256_set1_pd(scale * r[1]);
|
|
234
|
+
__m256d scaled_rotation_x_z_f64x4 = _mm256_set1_pd(scale * r[2]);
|
|
235
|
+
__m256d scaled_rotation_y_x_f64x4 = _mm256_set1_pd(scale * r[3]);
|
|
236
|
+
__m256d scaled_rotation_y_y_f64x4 = _mm256_set1_pd(scale * r[4]);
|
|
237
|
+
__m256d scaled_rotation_y_z_f64x4 = _mm256_set1_pd(scale * r[5]);
|
|
238
|
+
__m256d scaled_rotation_z_x_f64x4 = _mm256_set1_pd(scale * r[6]);
|
|
239
|
+
__m256d scaled_rotation_z_y_f64x4 = _mm256_set1_pd(scale * r[7]);
|
|
240
|
+
__m256d scaled_rotation_z_z_f64x4 = _mm256_set1_pd(scale * r[8]);
|
|
241
|
+
|
|
242
|
+
// Broadcast centroids
|
|
243
|
+
__m256d centroid_a_x_f64x4 = _mm256_set1_pd(centroid_a_x);
|
|
244
|
+
__m256d centroid_a_y_f64x4 = _mm256_set1_pd(centroid_a_y);
|
|
245
|
+
__m256d centroid_a_z_f64x4 = _mm256_set1_pd(centroid_a_z);
|
|
246
|
+
__m256d centroid_b_x_f64x4 = _mm256_set1_pd(centroid_b_x);
|
|
247
|
+
__m256d centroid_b_y_f64x4 = _mm256_set1_pd(centroid_b_y);
|
|
248
|
+
__m256d centroid_b_z_f64x4 = _mm256_set1_pd(centroid_b_z);
|
|
249
|
+
|
|
250
|
+
__m256d sum_squared_f64x4 = _mm256_setzero_pd();
|
|
251
|
+
__m256d sum_squared_compensation_f64x4 = _mm256_setzero_pd();
|
|
252
|
+
__m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
|
|
253
|
+
nk_size_t j = 0;
|
|
254
|
+
|
|
255
|
+
for (; j + 4 <= n; j += 4) {
|
|
256
|
+
nk_deinterleave_f64x4_haswell_(a + j * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
|
|
257
|
+
nk_deinterleave_f64x4_haswell_(b + j * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
|
|
258
|
+
|
|
259
|
+
// Center points
|
|
260
|
+
__m256d pa_x_f64x4 = _mm256_sub_pd(a_x_f64x4, centroid_a_x_f64x4);
|
|
261
|
+
__m256d pa_y_f64x4 = _mm256_sub_pd(a_y_f64x4, centroid_a_y_f64x4);
|
|
262
|
+
__m256d pa_z_f64x4 = _mm256_sub_pd(a_z_f64x4, centroid_a_z_f64x4);
|
|
263
|
+
__m256d pb_x_f64x4 = _mm256_sub_pd(b_x_f64x4, centroid_b_x_f64x4);
|
|
264
|
+
__m256d pb_y_f64x4 = _mm256_sub_pd(b_y_f64x4, centroid_b_y_f64x4);
|
|
265
|
+
__m256d pb_z_f64x4 = _mm256_sub_pd(b_z_f64x4, centroid_b_z_f64x4);
|
|
266
|
+
|
|
267
|
+
// Rotate and scale: ra = scale * R * pa
|
|
268
|
+
__m256d ra_x_f64x4 = _mm256_fmadd_pd(scaled_rotation_x_z_f64x4, pa_z_f64x4,
|
|
269
|
+
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4, pa_y_f64x4,
|
|
270
|
+
_mm256_mul_pd(scaled_rotation_x_x_f64x4, pa_x_f64x4)));
|
|
271
|
+
__m256d ra_y_f64x4 = _mm256_fmadd_pd(scaled_rotation_y_z_f64x4, pa_z_f64x4,
|
|
272
|
+
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4, pa_y_f64x4,
|
|
273
|
+
_mm256_mul_pd(scaled_rotation_y_x_f64x4, pa_x_f64x4)));
|
|
274
|
+
__m256d ra_z_f64x4 = _mm256_fmadd_pd(scaled_rotation_z_z_f64x4, pa_z_f64x4,
|
|
275
|
+
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4, pa_y_f64x4,
|
|
276
|
+
_mm256_mul_pd(scaled_rotation_z_x_f64x4, pa_x_f64x4)));
|
|
277
|
+
|
|
278
|
+
// Delta and accumulate
|
|
279
|
+
__m256d delta_x_f64x4 = _mm256_sub_pd(ra_x_f64x4, pb_x_f64x4);
|
|
280
|
+
__m256d delta_y_f64x4 = _mm256_sub_pd(ra_y_f64x4, pb_y_f64x4);
|
|
281
|
+
__m256d delta_z_f64x4 = _mm256_sub_pd(ra_z_f64x4, pb_z_f64x4);
|
|
282
|
+
|
|
283
|
+
nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_x_f64x4);
|
|
284
|
+
nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_y_f64x4);
|
|
285
|
+
nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_z_f64x4);
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
nk_f64_t sum_squared = nk_dot_stable_sum_f64x4_haswell_(sum_squared_f64x4, sum_squared_compensation_f64x4);
|
|
289
|
+
nk_f64_t sum_squared_compensation = 0.0;
|
|
290
|
+
|
|
291
|
+
// Scalar tail
|
|
292
|
+
for (; j < n; ++j) {
|
|
293
|
+
nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x;
|
|
294
|
+
nk_f64_t pa_y = a[j * 3 + 1] - centroid_a_y;
|
|
295
|
+
nk_f64_t pa_z = a[j * 3 + 2] - centroid_a_z;
|
|
296
|
+
nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x;
|
|
297
|
+
nk_f64_t pb_y = b[j * 3 + 1] - centroid_b_y;
|
|
298
|
+
nk_f64_t pb_z = b[j * 3 + 2] - centroid_b_z;
|
|
299
|
+
|
|
300
|
+
nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
|
|
301
|
+
nk_f64_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
|
|
302
|
+
nk_f64_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
303
|
+
|
|
304
|
+
nk_f64_t delta_x = ra_x - pb_x;
|
|
305
|
+
nk_f64_t delta_y = ra_y - pb_y;
|
|
306
|
+
nk_f64_t delta_z = ra_z - pb_z;
|
|
307
|
+
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
|
|
308
|
+
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
|
|
309
|
+
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
return sum_squared + sum_squared_compensation;
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
/**
 *  @brief  RMSD between two f32 point clouds after centering both on their centroids (Haswell: AVX2 + FMA).
 *
 *  No rotation or scaling is fitted: the identity matrix and a unit scale are reported. Points are read as
 *  interleaved `x, y, z` triplets (`a[3 * i + 0..2]`), and all sums are accumulated in f64. The result is
 *
 *      sqrt( mean_i |a_i - b_i|^2 - |centroid(a) - centroid(b)|^2 ),
 *
 *  which is algebraically the RMSD of `(a_i - centroid(a))` against `(b_i - centroid(b))`.
 *
 *  @param[in]  a           First cloud, `3 * n` coordinates.
 *  @param[in]  b           Second cloud, `3 * n` coordinates.
 *  @param[in]  n           Number of points; expected to be positive (`n == 0` divides by zero and yields NaN).
 *  @param[out] a_centroid  Optional (skipped when NULL): the 3 centroid coordinates of `a`.
 *  @param[out] b_centroid  Optional (skipped when NULL): the 3 centroid coordinates of `b`.
 *  @param[out] rotation    Optional (skipped when NULL): set to the 3x3 identity (9 values).
 *  @param[out] scale       Optional (skipped when NULL): set to 1.
 *  @param[out] result      The centered RMSD.
 */
NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                   nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
    // Plain RMSD fits no transform: report the identity rotation and a unit scale.
    if (rotation) {
        rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
        rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
        rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
    }
    if (scale) *scale = 1.0f;

    // Per-axis coordinate sums (for the centroids) and the sum of squared point-wise deltas.
    __m256d sum_a_x_f64x4 = _mm256_setzero_pd(), sum_a_y_f64x4 = _mm256_setzero_pd();
    __m256d sum_a_z_f64x4 = _mm256_setzero_pd(), sum_b_x_f64x4 = _mm256_setzero_pd();
    __m256d sum_b_y_f64x4 = _mm256_setzero_pd(), sum_b_z_f64x4 = _mm256_setzero_pd();
    __m256d sum_squared_f64x4 = _mm256_setzero_pd();
    __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
    nk_size_t index = 0;

    // 8 points per iteration: split the interleaved triplets into per-axis f32x8 vectors, then widen each
    // into two f64x4 halves so that every accumulation below happens in f64.
    for (; index + 8 <= n; index += 8) {
        nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
        nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);

        __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
        __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
        __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
        __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
        __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
        __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
        __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
        __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
        __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
        __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
        __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
        __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));

        sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_lower_f64x4, a_x_upper_f64x4));
        sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_lower_f64x4, a_y_upper_f64x4));
        sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_lower_f64x4, a_z_upper_f64x4));
        sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_lower_f64x4, b_x_upper_f64x4));
        sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_lower_f64x4, b_y_upper_f64x4));
        sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_lower_f64x4, b_z_upper_f64x4));

        __m256d delta_x_lower_f64x4 = _mm256_sub_pd(a_x_lower_f64x4, b_x_lower_f64x4);
        __m256d delta_x_upper_f64x4 = _mm256_sub_pd(a_x_upper_f64x4, b_x_upper_f64x4);
        __m256d delta_y_lower_f64x4 = _mm256_sub_pd(a_y_lower_f64x4, b_y_lower_f64x4);
        __m256d delta_y_upper_f64x4 = _mm256_sub_pd(a_y_upper_f64x4, b_y_upper_f64x4);
        __m256d delta_z_lower_f64x4 = _mm256_sub_pd(a_z_lower_f64x4, b_z_lower_f64x4);
        __m256d delta_z_upper_f64x4 = _mm256_sub_pd(a_z_upper_f64x4, b_z_upper_f64x4);
        __m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_lower_f64x4, delta_x_lower_f64x4),
                                                        _mm256_mul_pd(delta_x_upper_f64x4, delta_x_upper_f64x4));
        batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_lower_f64x4, delta_y_lower_f64x4, batch_sum_squared_f64x4);
        batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_upper_f64x4, delta_y_upper_f64x4, batch_sum_squared_f64x4);
        batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_lower_f64x4, delta_z_lower_f64x4, batch_sum_squared_f64x4);
        batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_upper_f64x4, delta_z_upper_f64x4, batch_sum_squared_f64x4);
        sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
    }

    nk_f64_t total_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
    nk_f64_t total_a_y = nk_reduce_add_f64x4_haswell_(sum_a_y_f64x4);
    nk_f64_t total_a_z = nk_reduce_add_f64x4_haswell_(sum_a_z_f64x4);
    nk_f64_t total_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
    nk_f64_t total_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
    nk_f64_t total_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
    nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);

    // Scalar tail for the remaining n % 8 points.
    for (; index < n; ++index) {
        nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
        nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
        total_a_x += a_x, total_a_y += a_y, total_a_z += a_z;
        total_b_x += b_x, total_b_y += b_y, total_b_z += b_z;
        nk_f64_t delta_x = a_x - b_x, delta_y = a_y - b_y, delta_z = a_z - b_z;
        sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
    }

    nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
    nk_f64_t centroid_a_x = total_a_x * inv_n, centroid_a_y = total_a_y * inv_n, centroid_a_z = total_a_z * inv_n;
    nk_f64_t centroid_b_x = total_b_x * inv_n, centroid_b_y = total_b_y * inv_n, centroid_b_z = total_b_z * inv_n;
    if (a_centroid) {
        a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
        a_centroid[2] = (nk_f32_t)centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
        b_centroid[2] = (nk_f32_t)centroid_b_z;
    }

    nk_f64_t mean_delta_x = centroid_a_x - centroid_b_x, mean_delta_y = centroid_a_y - centroid_b_y,
             mean_delta_z = centroid_a_z - centroid_b_z;
    nk_f64_t mean_delta_squared = mean_delta_x * mean_delta_x + mean_delta_y * mean_delta_y +
                                  mean_delta_z * mean_delta_z;

    // This subtracts two separately rounded quantities. When the clouds differ (almost) only by a translation,
    // the exact difference is ~0 and the computed one can land slightly below zero, where an IEEE square root
    // returns NaN. Clamp that rounding residue to zero; a NaN (e.g. from n == 0) passes through unchanged.
    nk_f64_t centered_mean_squared = sum_squared * inv_n - mean_delta_squared;
    if (centered_mean_squared < 0) centered_mean_squared = 0;
    *result = nk_f64_sqrt_haswell(centered_mean_squared);
}
|
|
401
|
+
|
|
402
|
+
/**
 *  @brief  RMSD between two f64 point clouds after centering both on their centroids (Haswell: AVX2 + FMA).
 *
 *  No rotation or scaling is fitted: the identity matrix and a unit scale are reported. Points are read as
 *  interleaved `x, y, z` triplets (`a[3 * i + 0..2]`). The result is
 *
 *      sqrt( mean_i |a_i - b_i|^2 - |centroid(a) - centroid(b)|^2 ),
 *
 *  which is algebraically the RMSD of `(a_i - centroid(a))` against `(b_i - centroid(b))`.
 *
 *  @param[in]  a           First cloud, `3 * n` coordinates.
 *  @param[in]  b           Second cloud, `3 * n` coordinates.
 *  @param[in]  n           Number of points; expected to be positive (`n == 0` divides by zero and yields NaN).
 *  @param[out] a_centroid  Optional (skipped when NULL): the 3 centroid coordinates of `a`.
 *  @param[out] b_centroid  Optional (skipped when NULL): the 3 centroid coordinates of `b`.
 *  @param[out] rotation    Optional (skipped when NULL): set to the 3x3 identity (9 values).
 *  @param[out] scale       Optional (skipped when NULL): set to 1.
 *  @param[out] result      The centered RMSD.
 */
NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
                                   nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
    /* RMSD uses identity rotation and scale=1.0 */
    if (rotation) {
        rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
        rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
        rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
    }
    if (scale) *scale = 1.0;
    __m256d const zeros_f64x4 = _mm256_setzero_pd();

    // Accumulators for centroids and per-axis squared differences
    __m256d sum_a_x_f64x4 = zeros_f64x4, sum_a_y_f64x4 = zeros_f64x4, sum_a_z_f64x4 = zeros_f64x4;
    __m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
    __m256d sum_squared_x_f64x4 = zeros_f64x4, sum_squared_y_f64x4 = zeros_f64x4, sum_squared_z_f64x4 = zeros_f64x4;

    __m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
    nk_size_t i = 0;

    // Main loop with 2x unrolling: 8 points per iteration, processed as two 4-point halves
    for (; i + 8 <= n; i += 8) {
        // Points i .. i + 3
        nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
        nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);

        sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
        sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
        sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
        sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
        sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
        sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);

        __m256d delta_x_f64x4 = _mm256_sub_pd(a_x_f64x4, b_x_f64x4);
        __m256d delta_y_f64x4 = _mm256_sub_pd(a_y_f64x4, b_y_f64x4);
        __m256d delta_z_f64x4 = _mm256_sub_pd(a_z_f64x4, b_z_f64x4);

        sum_squared_x_f64x4 = _mm256_fmadd_pd(delta_x_f64x4, delta_x_f64x4, sum_squared_x_f64x4);
        sum_squared_y_f64x4 = _mm256_fmadd_pd(delta_y_f64x4, delta_y_f64x4, sum_squared_y_f64x4);
        sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z_f64x4, delta_z_f64x4, sum_squared_z_f64x4);

        // Points i + 4 .. i + 7
        __m256d a_x1_f64x4, a_y1_f64x4, a_z1_f64x4, b_x1_f64x4, b_y1_f64x4, b_z1_f64x4;
        nk_deinterleave_f64x4_haswell_(a + (i + 4) * 3, &a_x1_f64x4, &a_y1_f64x4, &a_z1_f64x4);
        nk_deinterleave_f64x4_haswell_(b + (i + 4) * 3, &b_x1_f64x4, &b_y1_f64x4, &b_z1_f64x4);

        sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x1_f64x4);
        sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y1_f64x4);
        sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z1_f64x4);
        sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x1_f64x4);
        sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y1_f64x4);
        sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z1_f64x4);

        __m256d delta_x1_f64x4 = _mm256_sub_pd(a_x1_f64x4, b_x1_f64x4);
        __m256d delta_y1_f64x4 = _mm256_sub_pd(a_y1_f64x4, b_y1_f64x4);
        __m256d delta_z1_f64x4 = _mm256_sub_pd(a_z1_f64x4, b_z1_f64x4);

        sum_squared_x_f64x4 = _mm256_fmadd_pd(delta_x1_f64x4, delta_x1_f64x4, sum_squared_x_f64x4);
        sum_squared_y_f64x4 = _mm256_fmadd_pd(delta_y1_f64x4, delta_y1_f64x4, sum_squared_y_f64x4);
        sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z1_f64x4, delta_z1_f64x4, sum_squared_z_f64x4);
    }

    // Handle 4-point remainder (runs at most once, since fewer than 8 points are left here)
    for (; i + 4 <= n; i += 4) {
        nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
        nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);

        sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
        sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
        sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
        sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
        sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
        sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);

        __m256d delta_x_f64x4 = _mm256_sub_pd(a_x_f64x4, b_x_f64x4);
        __m256d delta_y_f64x4 = _mm256_sub_pd(a_y_f64x4, b_y_f64x4);
        __m256d delta_z_f64x4 = _mm256_sub_pd(a_z_f64x4, b_z_f64x4);

        sum_squared_x_f64x4 = _mm256_fmadd_pd(delta_x_f64x4, delta_x_f64x4, sum_squared_x_f64x4);
        sum_squared_y_f64x4 = _mm256_fmadd_pd(delta_y_f64x4, delta_y_f64x4, sum_squared_y_f64x4);
        sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z_f64x4, delta_z_f64x4, sum_squared_z_f64x4);
    }

    // Reduce vectors to scalars; each scalar gets its own compensation term for the tail below
    nk_f64_t total_ax = nk_reduce_stable_f64x4_haswell_(sum_a_x_f64x4), total_ax_compensation = 0.0;
    nk_f64_t total_ay = nk_reduce_stable_f64x4_haswell_(sum_a_y_f64x4), total_ay_compensation = 0.0;
    nk_f64_t total_az = nk_reduce_stable_f64x4_haswell_(sum_a_z_f64x4), total_az_compensation = 0.0;
    nk_f64_t total_bx = nk_reduce_stable_f64x4_haswell_(sum_b_x_f64x4), total_bx_compensation = 0.0;
    nk_f64_t total_by = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), total_by_compensation = 0.0;
    nk_f64_t total_bz = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), total_bz_compensation = 0.0;
    nk_f64_t total_sq_x = nk_reduce_stable_f64x4_haswell_(sum_squared_x_f64x4), total_sq_x_compensation = 0.0;
    nk_f64_t total_sq_y = nk_reduce_stable_f64x4_haswell_(sum_squared_y_f64x4), total_sq_y_compensation = 0.0;
    nk_f64_t total_sq_z = nk_reduce_stable_f64x4_haswell_(sum_squared_z_f64x4), total_sq_z_compensation = 0.0;

    // Scalar tail for the last n % 4 points
    for (; i < n; ++i) {
        nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
        nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
        nk_accumulate_sum_f64_(&total_ax, &total_ax_compensation, ax);
        nk_accumulate_sum_f64_(&total_ay, &total_ay_compensation, ay);
        nk_accumulate_sum_f64_(&total_az, &total_az_compensation, az);
        nk_accumulate_sum_f64_(&total_bx, &total_bx_compensation, bx);
        nk_accumulate_sum_f64_(&total_by, &total_by_compensation, by);
        nk_accumulate_sum_f64_(&total_bz, &total_bz_compensation, bz);
        nk_f64_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
        nk_accumulate_square_f64_(&total_sq_x, &total_sq_x_compensation, delta_x);
        nk_accumulate_square_f64_(&total_sq_y, &total_sq_y_compensation, delta_y);
        nk_accumulate_square_f64_(&total_sq_z, &total_sq_z_compensation, delta_z);
    }

    // Fold the compensation terms back into their sums
    total_ax += total_ax_compensation, total_ay += total_ay_compensation, total_az += total_az_compensation;
    total_bx += total_bx_compensation, total_by += total_by_compensation, total_bz += total_bz_compensation;
    total_sq_x += total_sq_x_compensation, total_sq_y += total_sq_y_compensation, total_sq_z += total_sq_z_compensation;

    // Compute centroids
    nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
    nk_f64_t centroid_a_x = total_ax * inv_n;
    nk_f64_t centroid_a_y = total_ay * inv_n;
    nk_f64_t centroid_a_z = total_az * inv_n;
    nk_f64_t centroid_b_x = total_bx * inv_n;
    nk_f64_t centroid_b_y = total_by * inv_n;
    nk_f64_t centroid_b_z = total_bz * inv_n;

    if (a_centroid) {
        a_centroid[0] = centroid_a_x;
        a_centroid[1] = centroid_a_y;
        a_centroid[2] = centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = centroid_b_x;
        b_centroid[1] = centroid_b_y;
        b_centroid[2] = centroid_b_z;
    }

    // Compute centered RMSD
    nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
    nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
    nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
    nk_f64_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
    nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;

    // This subtracts two separately rounded quantities. When the clouds differ (almost) only by a translation,
    // the exact difference is ~0 and the computed one can land slightly below zero, where an IEEE square root
    // returns NaN. Clamp that rounding residue to zero; a NaN (e.g. from n == 0) passes through unchanged.
    nk_f64_t centered_mean_squared = sum_squared * inv_n - mean_diff_sq;
    if (centered_mean_squared < 0) centered_mean_squared = 0;
    *result = nk_f64_sqrt_haswell(centered_mean_squared);
}
|
|
544
|
+
|
|
545
|
+
/**
 *  @brief  Kabsch rigid alignment of two f32 point clouds, reporting the post-alignment RMSD (Haswell: AVX2 + FMA).
 *
 *  Points are read as interleaved `x, y, z` triplets (`a[3 * i + 0..2]`). In one f64 pass the kernel gathers the
 *  per-axis coordinate sums and the raw products `sum_i a_i[row] * b_i[col]`. From these it forms the centered
 *  cross-covariance `H[row][col] = sum_i a_i[row] * b_i[col] - n * centroid_a[row] * centroid_b[col]`.
 *
 *  `H` is decomposed by `nk_svd3x3_f64_`, and the rotation is built as `R = V * U^T`, with both factors indexed
 *  row-major. If `det(R) < 0` (a reflection), the last column of `V` is negated and `R` is rebuilt once.
 *
 *  The result is `sqrt(ssd / n)`, where `ssd` comes from `nk_transformed_ssd_f32_haswell_` with `R`, a unit scale
 *  and both centroids. That is presumably `sum_i |R (a_i - centroid_a) - (b_i - centroid_b)|^2` — TODO confirm.
 *
 *  @param[in]  a           First cloud, `3 * n` coordinates.
 *  @param[in]  b           Second cloud, `3 * n` coordinates.
 *  @param[in]  n           Number of points; expected to be positive (`n == 0` divides by zero).
 *  @param[out] a_centroid  Optional (skipped when NULL): the 3 centroid coordinates of `a`.
 *  @param[out] b_centroid  Optional (skipped when NULL): the 3 centroid coordinates of `b`.
 *  @param[out] rotation    Optional (skipped when NULL): the 9 entries of `R`, narrowed to f32.
 *  @param[out] scale       Optional (skipped when NULL): set to 1, since the fit is rigid.
 *  @param[out] result      RMSD after applying the fitted rotation.
 */
NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                     nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
    if (scale) *scale = 1.0f;

    __m256d const zero_f64x4 = _mm256_setzero_pd();
    __m256d sum_a_x_f64x4 = zero_f64x4, sum_a_y_f64x4 = zero_f64x4, sum_a_z_f64x4 = zero_f64x4;
    __m256d sum_b_x_f64x4 = zero_f64x4, sum_b_y_f64x4 = zero_f64x4, sum_b_z_f64x4 = zero_f64x4;
    __m256d covariance_xx_f64x4 = zero_f64x4, covariance_xy_f64x4 = zero_f64x4, covariance_xz_f64x4 = zero_f64x4;
    __m256d covariance_yx_f64x4 = zero_f64x4, covariance_yy_f64x4 = zero_f64x4, covariance_yz_f64x4 = zero_f64x4;
    __m256d covariance_zx_f64x4 = zero_f64x4, covariance_zy_f64x4 = zero_f64x4, covariance_zz_f64x4 = zero_f64x4;
    __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
    nk_size_t point_index = 0;

    // Vector body: 8 points per step. Each f32x8 axis vector is widened into a low and a high f64x4 half,
    // so the sums and products below are all formed in f64.
    for (; point_index + 8 <= n; point_index += 8) {
        nk_deinterleave_f32x8_haswell_(a + point_index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
        nk_deinterleave_f32x8_haswell_(b + point_index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);

        __m256d const a_x_lo_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
        __m256d const a_x_hi_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
        __m256d const a_y_lo_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
        __m256d const a_y_hi_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
        __m256d const a_z_lo_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
        __m256d const a_z_hi_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
        __m256d const b_x_lo_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
        __m256d const b_x_hi_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
        __m256d const b_y_lo_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
        __m256d const b_y_hi_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
        __m256d const b_z_lo_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
        __m256d const b_z_hi_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));

        // Coordinate sums, for the centroids.
        sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_lo_f64x4, a_x_hi_f64x4));
        sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_lo_f64x4, a_y_hi_f64x4));
        sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_lo_f64x4, a_z_hi_f64x4));
        sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_lo_f64x4, b_x_hi_f64x4));
        sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_lo_f64x4, b_y_hi_f64x4));
        sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_lo_f64x4, b_z_hi_f64x4));

        // Raw (uncentered) outer-product sums a[row] * b[col].
        covariance_xx_f64x4 = _mm256_add_pd(
            covariance_xx_f64x4,
            _mm256_add_pd(_mm256_mul_pd(a_x_lo_f64x4, b_x_lo_f64x4), _mm256_mul_pd(a_x_hi_f64x4, b_x_hi_f64x4)));
        covariance_xy_f64x4 = _mm256_add_pd(
            covariance_xy_f64x4,
            _mm256_add_pd(_mm256_mul_pd(a_x_lo_f64x4, b_y_lo_f64x4), _mm256_mul_pd(a_x_hi_f64x4, b_y_hi_f64x4)));
        covariance_xz_f64x4 = _mm256_add_pd(
            covariance_xz_f64x4,
            _mm256_add_pd(_mm256_mul_pd(a_x_lo_f64x4, b_z_lo_f64x4), _mm256_mul_pd(a_x_hi_f64x4, b_z_hi_f64x4)));
        covariance_yx_f64x4 = _mm256_add_pd(
            covariance_yx_f64x4,
            _mm256_add_pd(_mm256_mul_pd(a_y_lo_f64x4, b_x_lo_f64x4), _mm256_mul_pd(a_y_hi_f64x4, b_x_hi_f64x4)));
        covariance_yy_f64x4 = _mm256_add_pd(
            covariance_yy_f64x4,
            _mm256_add_pd(_mm256_mul_pd(a_y_lo_f64x4, b_y_lo_f64x4), _mm256_mul_pd(a_y_hi_f64x4, b_y_hi_f64x4)));
        covariance_yz_f64x4 = _mm256_add_pd(
            covariance_yz_f64x4,
            _mm256_add_pd(_mm256_mul_pd(a_y_lo_f64x4, b_z_lo_f64x4), _mm256_mul_pd(a_y_hi_f64x4, b_z_hi_f64x4)));
        covariance_zx_f64x4 = _mm256_add_pd(
            covariance_zx_f64x4,
            _mm256_add_pd(_mm256_mul_pd(a_z_lo_f64x4, b_x_lo_f64x4), _mm256_mul_pd(a_z_hi_f64x4, b_x_hi_f64x4)));
        covariance_zy_f64x4 = _mm256_add_pd(
            covariance_zy_f64x4,
            _mm256_add_pd(_mm256_mul_pd(a_z_lo_f64x4, b_y_lo_f64x4), _mm256_mul_pd(a_z_hi_f64x4, b_y_hi_f64x4)));
        covariance_zz_f64x4 = _mm256_add_pd(
            covariance_zz_f64x4,
            _mm256_add_pd(_mm256_mul_pd(a_z_lo_f64x4, b_z_lo_f64x4), _mm256_mul_pd(a_z_hi_f64x4, b_z_hi_f64x4)));
    }

    // Collapse the vector accumulators into scalars indexed by axis (0 = x, 1 = y, 2 = z).
    nk_f64_t sum_a[3] = {nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4), nk_reduce_add_f64x4_haswell_(sum_a_y_f64x4),
                         nk_reduce_add_f64x4_haswell_(sum_a_z_f64x4)};
    nk_f64_t sum_b[3] = {nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4), nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4),
                         nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4)};
    nk_f64_t h[9] = {nk_reduce_add_f64x4_haswell_(covariance_xx_f64x4),
                     nk_reduce_add_f64x4_haswell_(covariance_xy_f64x4),
                     nk_reduce_add_f64x4_haswell_(covariance_xz_f64x4),
                     nk_reduce_add_f64x4_haswell_(covariance_yx_f64x4),
                     nk_reduce_add_f64x4_haswell_(covariance_yy_f64x4),
                     nk_reduce_add_f64x4_haswell_(covariance_yz_f64x4),
                     nk_reduce_add_f64x4_haswell_(covariance_zx_f64x4),
                     nk_reduce_add_f64x4_haswell_(covariance_zy_f64x4),
                     nk_reduce_add_f64x4_haswell_(covariance_zz_f64x4)};

    // Scalar tail: the remaining n % 8 points, accumulated in the same row-major order as `h`.
    for (; point_index < n; ++point_index) {
        nk_f32_t const *point_a = a + point_index * 3;
        nk_f32_t const *point_b = b + point_index * 3;
        for (int row = 0; row != 3; ++row) {
            nk_f64_t const a_coordinate = point_a[row];
            sum_a[row] += a_coordinate;
            sum_b[row] += (nk_f64_t)point_b[row];
            for (int col = 0; col != 3; ++col) h[row * 3 + col] += a_coordinate * (nk_f64_t)point_b[col];
        }
    }

    nk_f64_t const inv_n = 1.0 / (nk_f64_t)n;
    nk_f64_t centroid_a[3], centroid_b[3];
    for (int axis = 0; axis != 3; ++axis) {
        centroid_a[axis] = sum_a[axis] * inv_n;
        centroid_b[axis] = sum_b[axis] * inv_n;
    }
    if (a_centroid) {
        for (int axis = 0; axis != 3; ++axis) a_centroid[axis] = (nk_f32_t)centroid_a[axis];
    }
    if (b_centroid) {
        for (int axis = 0; axis != 3; ++axis) b_centroid[axis] = (nk_f32_t)centroid_b[axis];
    }

    // Center the outer-product sums: H[row][col] -= n * centroid_a[row] * centroid_b[col].
    for (int row = 0; row != 3; ++row)
        for (int col = 0; col != 3; ++col) h[row * 3 + col] -= (nk_f64_t)n * centroid_a[row] * centroid_b[col];

    nk_f64_t svd_u[9], svd_s[9], svd_v[9], r[9];
    nk_svd3x3_f64_(h, svd_u, svd_s, svd_v); // `h` is not read again, so it is handed over directly

    // R[row][col] = sum_k V[row][k] * U[col][k]. If the first R is a reflection, negate V's last column
    // (entries 2, 5, 8) and rebuild R exactly once; the determinant is evaluated only on the first pass.
    for (int pass = 0; pass != 2; ++pass) {
        for (int row = 0; row != 3; ++row)
            for (int col = 0; col != 3; ++col)
                r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] + svd_v[row * 3 + 1] * svd_u[col * 3 + 1] +
                                   svd_v[row * 3 + 2] * svd_u[col * 3 + 2];
        if (pass != 0 || !(nk_det3x3_f64_(r) < 0)) break;
        svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
    }

    if (rotation) {
        for (int entry = 0; entry != 9; ++entry) rotation[entry] = (nk_f32_t)r[entry];
    }
    nk_f64_t const sum_squared = nk_transformed_ssd_f32_haswell_(a, b, n, r, 1.0, centroid_a[0], centroid_a[1],
                                                                 centroid_a[2], centroid_b[0], centroid_b[1],
                                                                 centroid_b[2]);
    *result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
}
|
|
681
|
+
|
|
682
|
+
NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
683
|
+
nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
|
|
684
|
+
__m256d const zeros_f64x4 = _mm256_setzero_pd();
|
|
685
|
+
|
|
686
|
+
// Accumulators for centroids
|
|
687
|
+
__m256d sum_a_x_f64x4 = zeros_f64x4, sum_a_y_f64x4 = zeros_f64x4, sum_a_z_f64x4 = zeros_f64x4;
|
|
688
|
+
__m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
|
|
689
|
+
|
|
690
|
+
// Accumulators for covariance matrix (sum of outer products)
|
|
691
|
+
__m256d cov_xx_f64x4 = zeros_f64x4, cov_xy_f64x4 = zeros_f64x4, cov_xz_f64x4 = zeros_f64x4;
|
|
692
|
+
__m256d cov_yx_f64x4 = zeros_f64x4, cov_yy_f64x4 = zeros_f64x4, cov_yz_f64x4 = zeros_f64x4;
|
|
693
|
+
__m256d cov_zx_f64x4 = zeros_f64x4, cov_zy_f64x4 = zeros_f64x4, cov_zz_f64x4 = zeros_f64x4;
|
|
694
|
+
|
|
695
|
+
nk_size_t i = 0;
|
|
696
|
+
__m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
|
|
697
|
+
|
|
698
|
+
// Fused single-pass
|
|
699
|
+
for (; i + 4 <= n; i += 4) {
|
|
700
|
+
nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
|
|
701
|
+
nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
|
|
702
|
+
|
|
703
|
+
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
|
|
704
|
+
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
|
|
705
|
+
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
|
|
706
|
+
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
|
|
707
|
+
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
|
|
708
|
+
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
|
|
709
|
+
|
|
710
|
+
cov_xx_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_x_f64x4, cov_xx_f64x4);
|
|
711
|
+
cov_xy_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_y_f64x4, cov_xy_f64x4);
|
|
712
|
+
cov_xz_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_z_f64x4, cov_xz_f64x4);
|
|
713
|
+
cov_yx_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_x_f64x4, cov_yx_f64x4);
|
|
714
|
+
cov_yy_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_y_f64x4, cov_yy_f64x4);
|
|
715
|
+
cov_yz_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_z_f64x4, cov_yz_f64x4);
|
|
716
|
+
cov_zx_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_x_f64x4, cov_zx_f64x4);
|
|
717
|
+
cov_zy_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_y_f64x4, cov_zy_f64x4);
|
|
718
|
+
cov_zz_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_z_f64x4, cov_zz_f64x4);
|
|
719
|
+
}
|
|
720
|
+
|
|
721
|
+
// Reduce vector accumulators
|
|
722
|
+
nk_f64_t sum_a_x = nk_reduce_stable_f64x4_haswell_(sum_a_x_f64x4), sum_a_x_compensation = 0.0;
|
|
723
|
+
nk_f64_t sum_a_y = nk_reduce_stable_f64x4_haswell_(sum_a_y_f64x4), sum_a_y_compensation = 0.0;
|
|
724
|
+
nk_f64_t sum_a_z = nk_reduce_stable_f64x4_haswell_(sum_a_z_f64x4), sum_a_z_compensation = 0.0;
|
|
725
|
+
nk_f64_t sum_b_x = nk_reduce_stable_f64x4_haswell_(sum_b_x_f64x4), sum_b_x_compensation = 0.0;
|
|
726
|
+
nk_f64_t sum_b_y = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), sum_b_y_compensation = 0.0;
|
|
727
|
+
nk_f64_t sum_b_z = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), sum_b_z_compensation = 0.0;
|
|
728
|
+
|
|
729
|
+
nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(cov_xx_f64x4), covariance_x_x_compensation = 0.0;
|
|
730
|
+
nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(cov_xy_f64x4), covariance_x_y_compensation = 0.0;
|
|
731
|
+
nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(cov_xz_f64x4), covariance_x_z_compensation = 0.0;
|
|
732
|
+
nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(cov_yx_f64x4), covariance_y_x_compensation = 0.0;
|
|
733
|
+
nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(cov_yy_f64x4), covariance_y_y_compensation = 0.0;
|
|
734
|
+
nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(cov_yz_f64x4), covariance_y_z_compensation = 0.0;
|
|
735
|
+
nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(cov_zx_f64x4), covariance_z_x_compensation = 0.0;
|
|
736
|
+
nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(cov_zy_f64x4), covariance_z_y_compensation = 0.0;
|
|
737
|
+
nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(cov_zz_f64x4), covariance_z_z_compensation = 0.0;
|
|
738
|
+
|
|
739
|
+
// Scalar tail
|
|
740
|
+
for (; i < n; ++i) {
|
|
741
|
+
nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
|
|
742
|
+
nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
|
|
743
|
+
nk_accumulate_sum_f64_(&sum_a_x, &sum_a_x_compensation, ax);
|
|
744
|
+
nk_accumulate_sum_f64_(&sum_a_y, &sum_a_y_compensation, ay);
|
|
745
|
+
nk_accumulate_sum_f64_(&sum_a_z, &sum_a_z_compensation, az);
|
|
746
|
+
nk_accumulate_sum_f64_(&sum_b_x, &sum_b_x_compensation, bx);
|
|
747
|
+
nk_accumulate_sum_f64_(&sum_b_y, &sum_b_y_compensation, by);
|
|
748
|
+
nk_accumulate_sum_f64_(&sum_b_z, &sum_b_z_compensation, bz);
|
|
749
|
+
nk_accumulate_product_f64_(&covariance_x_x, &covariance_x_x_compensation, ax, bx);
|
|
750
|
+
nk_accumulate_product_f64_(&covariance_x_y, &covariance_x_y_compensation, ax, by);
|
|
751
|
+
nk_accumulate_product_f64_(&covariance_x_z, &covariance_x_z_compensation, ax, bz);
|
|
752
|
+
nk_accumulate_product_f64_(&covariance_y_x, &covariance_y_x_compensation, ay, bx);
|
|
753
|
+
nk_accumulate_product_f64_(&covariance_y_y, &covariance_y_y_compensation, ay, by);
|
|
754
|
+
nk_accumulate_product_f64_(&covariance_y_z, &covariance_y_z_compensation, ay, bz);
|
|
755
|
+
nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
|
|
756
|
+
nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
|
|
757
|
+
nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
|
|
758
|
+
}
|
|
759
|
+
|
|
760
|
+
sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
|
|
761
|
+
sum_b_x += sum_b_x_compensation, sum_b_y += sum_b_y_compensation, sum_b_z += sum_b_z_compensation;
|
|
762
|
+
covariance_x_x += covariance_x_x_compensation, covariance_x_y += covariance_x_y_compensation,
|
|
763
|
+
covariance_x_z += covariance_x_z_compensation;
|
|
764
|
+
covariance_y_x += covariance_y_x_compensation, covariance_y_y += covariance_y_y_compensation,
|
|
765
|
+
covariance_y_z += covariance_y_z_compensation;
|
|
766
|
+
covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
|
|
767
|
+
covariance_z_z += covariance_z_z_compensation;
|
|
768
|
+
|
|
769
|
+
// Compute centroids
|
|
770
|
+
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
771
|
+
nk_f64_t centroid_a_x = sum_a_x * inv_n;
|
|
772
|
+
nk_f64_t centroid_a_y = sum_a_y * inv_n;
|
|
773
|
+
nk_f64_t centroid_a_z = sum_a_z * inv_n;
|
|
774
|
+
nk_f64_t centroid_b_x = sum_b_x * inv_n;
|
|
775
|
+
nk_f64_t centroid_b_y = sum_b_y * inv_n;
|
|
776
|
+
nk_f64_t centroid_b_z = sum_b_z * inv_n;
|
|
777
|
+
|
|
778
|
+
if (a_centroid) {
|
|
779
|
+
a_centroid[0] = centroid_a_x;
|
|
780
|
+
a_centroid[1] = centroid_a_y;
|
|
781
|
+
a_centroid[2] = centroid_a_z;
|
|
782
|
+
}
|
|
783
|
+
if (b_centroid) {
|
|
784
|
+
b_centroid[0] = centroid_b_x;
|
|
785
|
+
b_centroid[1] = centroid_b_y;
|
|
786
|
+
b_centroid[2] = centroid_b_z;
|
|
787
|
+
}
|
|
788
|
+
|
|
789
|
+
// Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
|
|
790
|
+
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
791
|
+
covariance_x_y -= n * centroid_a_x * centroid_b_y;
|
|
792
|
+
covariance_x_z -= n * centroid_a_x * centroid_b_z;
|
|
793
|
+
covariance_y_x -= n * centroid_a_y * centroid_b_x;
|
|
794
|
+
covariance_y_y -= n * centroid_a_y * centroid_b_y;
|
|
795
|
+
covariance_y_z -= n * centroid_a_y * centroid_b_z;
|
|
796
|
+
covariance_z_x -= n * centroid_a_z * centroid_b_x;
|
|
797
|
+
covariance_z_y -= n * centroid_a_z * centroid_b_y;
|
|
798
|
+
covariance_z_z -= n * centroid_a_z * centroid_b_z;
|
|
799
|
+
|
|
800
|
+
// Compute SVD and optimal rotation using f64 precision (svd_s is 9-element diagonal matrix)
|
|
801
|
+
nk_f64_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
802
|
+
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
803
|
+
nk_f64_t svd_u[9], svd_s[9], svd_v[9];
|
|
804
|
+
nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
|
|
805
|
+
|
|
806
|
+
nk_f64_t r[9];
|
|
807
|
+
nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
|
|
808
|
+
|
|
809
|
+
// Handle reflection: if det(R) < 0, negate third column of V and recompute R
|
|
810
|
+
if (nk_det3x3_f64_(r) < 0) {
|
|
811
|
+
svd_v[2] = -svd_v[2];
|
|
812
|
+
svd_v[5] = -svd_v[5];
|
|
813
|
+
svd_v[8] = -svd_v[8];
|
|
814
|
+
nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
|
|
815
|
+
}
|
|
816
|
+
|
|
817
|
+
/* Output rotation matrix and scale=1.0 */
|
|
818
|
+
if (rotation) {
|
|
819
|
+
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
820
|
+
}
|
|
821
|
+
if (scale) *scale = 1.0;
|
|
822
|
+
|
|
823
|
+
// Compute RMSD after optimal rotation
|
|
824
|
+
nk_f64_t sum_squared = nk_transformed_ssd_f64_haswell_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y, centroid_a_z,
|
|
825
|
+
centroid_b_x, centroid_b_y, centroid_b_z);
|
|
826
|
+
*result = nk_f64_sqrt_haswell(sum_squared * inv_n);
|
|
827
|
+
}
|
|
828
|
+
|
|
829
|
+
/* Umeyama similarity alignment (rotation + uniform scale) of two f32 point clouds, accumulated in f64.
 *
 * `a` and `b` each hold `n` interleaved xyz triplets. One pass accumulates the per-axis sums of A and B, the raw
 * cross-product matrix H = Σ a·bᵀ (row = axis of A, column = axis of B) and Σ |a|². H is then centered, split by a
 * 3x3 SVD, and turned into the rotation R = V·Uᵀ (reflection-corrected via det(R)) and the scale
 * c = trace(D·S) / (n · var(A)).
 *
 * Outputs: `a_centroid`, `b_centroid`, `rotation` (9 values, row-major) and `scale` are written only when non-NULL,
 * narrowed to f32. `*result` receives sqrt(SSD / n), where SSD is the value returned by nk_transformed_ssd_f32_haswell_.
 *
 * NOTE(review): n == 0, or a cloud A whose points all coincide (var(A) == 0), divides by zero and produces
 * NaN/Inf outputs.
 */
NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                      nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
    __m256d sum_a_x_f64x4 = _mm256_setzero_pd(), sum_a_y_f64x4 = _mm256_setzero_pd();
    __m256d sum_a_z_f64x4 = _mm256_setzero_pd(), sum_b_x_f64x4 = _mm256_setzero_pd();
    __m256d sum_b_y_f64x4 = _mm256_setzero_pd(), sum_b_z_f64x4 = _mm256_setzero_pd();
    __m256d covariance_00_f64x4 = _mm256_setzero_pd(), covariance_01_f64x4 = _mm256_setzero_pd();
    __m256d covariance_02_f64x4 = _mm256_setzero_pd(), covariance_10_f64x4 = _mm256_setzero_pd();
    __m256d covariance_11_f64x4 = _mm256_setzero_pd(), covariance_12_f64x4 = _mm256_setzero_pd();
    __m256d covariance_20_f64x4 = _mm256_setzero_pd(), covariance_21_f64x4 = _mm256_setzero_pd();
    __m256d covariance_22_f64x4 = _mm256_setzero_pd(), variance_a_f64x4 = _mm256_setzero_pd();
    __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
    nk_size_t index = 0;

    // Eight points per iteration. Each f32x8 axis is widened into two f64x4 halves. An f32*f32 product fits
    // exactly in an f64 significand (24 + 24 <= 53 bits), so every `mul` below is exact. Rounding happens only
    // when the two halves are added and when that partial sum is added to the accumulator.
    for (; index + 8 <= n; index += 8) {
        nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
        nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
        __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
        __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
        __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
        __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
        __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
        __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
        __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
        __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
        __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
        __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
        __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
        __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));

        sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_lower_f64x4, a_x_upper_f64x4));
        sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_lower_f64x4, a_y_upper_f64x4));
        sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_lower_f64x4, a_z_upper_f64x4));
        sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_lower_f64x4, b_x_upper_f64x4));
        sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_lower_f64x4, b_y_upper_f64x4));
        sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_lower_f64x4, b_z_upper_f64x4));
        covariance_00_f64x4 = _mm256_add_pd(covariance_00_f64x4,
                                            _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_x_lower_f64x4),
                                                          _mm256_mul_pd(a_x_upper_f64x4, b_x_upper_f64x4)));
        covariance_01_f64x4 = _mm256_add_pd(covariance_01_f64x4,
                                            _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_y_lower_f64x4),
                                                          _mm256_mul_pd(a_x_upper_f64x4, b_y_upper_f64x4)));
        covariance_02_f64x4 = _mm256_add_pd(covariance_02_f64x4,
                                            _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_z_lower_f64x4),
                                                          _mm256_mul_pd(a_x_upper_f64x4, b_z_upper_f64x4)));
        covariance_10_f64x4 = _mm256_add_pd(covariance_10_f64x4,
                                            _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_x_lower_f64x4),
                                                          _mm256_mul_pd(a_y_upper_f64x4, b_x_upper_f64x4)));
        covariance_11_f64x4 = _mm256_add_pd(covariance_11_f64x4,
                                            _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_y_lower_f64x4),
                                                          _mm256_mul_pd(a_y_upper_f64x4, b_y_upper_f64x4)));
        covariance_12_f64x4 = _mm256_add_pd(covariance_12_f64x4,
                                            _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_z_lower_f64x4),
                                                          _mm256_mul_pd(a_y_upper_f64x4, b_z_upper_f64x4)));
        covariance_20_f64x4 = _mm256_add_pd(covariance_20_f64x4,
                                            _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_x_lower_f64x4),
                                                          _mm256_mul_pd(a_z_upper_f64x4, b_x_upper_f64x4)));
        covariance_21_f64x4 = _mm256_add_pd(covariance_21_f64x4,
                                            _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_y_lower_f64x4),
                                                          _mm256_mul_pd(a_z_upper_f64x4, b_y_upper_f64x4)));
        covariance_22_f64x4 = _mm256_add_pd(covariance_22_f64x4,
                                            _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_z_lower_f64x4),
                                                          _mm256_mul_pd(a_z_upper_f64x4, b_z_upper_f64x4)));

        // Squared norms of A, grouped as x² + (y² + z²) per lane before accumulation.
        __m256d a_x_squares_f64x4 = _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, a_x_lower_f64x4),
                                                  _mm256_mul_pd(a_x_upper_f64x4, a_x_upper_f64x4));
        __m256d a_y_squares_f64x4 = _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, a_y_lower_f64x4),
                                                  _mm256_mul_pd(a_y_upper_f64x4, a_y_upper_f64x4));
        __m256d a_z_squares_f64x4 = _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, a_z_lower_f64x4),
                                                  _mm256_mul_pd(a_z_upper_f64x4, a_z_upper_f64x4));
        variance_a_f64x4 = _mm256_add_pd(variance_a_f64x4,
                                         _mm256_add_pd(a_x_squares_f64x4,
                                                       _mm256_add_pd(a_y_squares_f64x4, a_z_squares_f64x4)));
    }

    nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
    nk_f64_t sum_a_y = nk_reduce_add_f64x4_haswell_(sum_a_y_f64x4);
    nk_f64_t sum_a_z = nk_reduce_add_f64x4_haswell_(sum_a_z_f64x4);
    nk_f64_t sum_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
    nk_f64_t sum_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
    nk_f64_t sum_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
    // Row-major raw cross products: h[row * 3 + col] = Σ a[row] · b[col].
    nk_f64_t h[9] = {
        nk_reduce_add_f64x4_haswell_(covariance_00_f64x4), nk_reduce_add_f64x4_haswell_(covariance_01_f64x4),
        nk_reduce_add_f64x4_haswell_(covariance_02_f64x4), nk_reduce_add_f64x4_haswell_(covariance_10_f64x4),
        nk_reduce_add_f64x4_haswell_(covariance_11_f64x4), nk_reduce_add_f64x4_haswell_(covariance_12_f64x4),
        nk_reduce_add_f64x4_haswell_(covariance_20_f64x4), nk_reduce_add_f64x4_haswell_(covariance_21_f64x4),
        nk_reduce_add_f64x4_haswell_(covariance_22_f64x4)};
    nk_f64_t variance_a = nk_reduce_add_f64x4_haswell_(variance_a_f64x4);

    // Scalar tail for the final n % 8 points.
    for (; index < n; ++index) {
        nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
        nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
        sum_a_x += a_x;
        sum_a_y += a_y;
        sum_a_z += a_z;
        sum_b_x += b_x;
        sum_b_y += b_y;
        sum_b_z += b_z;
        h[0] += a_x * b_x;
        h[1] += a_x * b_y;
        h[2] += a_x * b_z;
        h[3] += a_y * b_x;
        h[4] += a_y * b_y;
        h[5] += a_y * b_z;
        h[6] += a_z * b_x;
        h[7] += a_z * b_y;
        h[8] += a_z * b_z;
        variance_a += a_x * a_x + a_y * a_y + a_z * a_z;
    }

    nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
    nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
    nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
    if (a_centroid) {
        a_centroid[0] = (nk_f32_t)centroid_a_x;
        a_centroid[1] = (nk_f32_t)centroid_a_y;
        a_centroid[2] = (nk_f32_t)centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = (nk_f32_t)centroid_b_x;
        b_centroid[1] = (nk_f32_t)centroid_b_y;
        b_centroid[2] = (nk_f32_t)centroid_b_z;
    }

    // Mean squared distance of A from its centroid: E[|a|²] − |mean(A)|².
    variance_a = variance_a * inv_n -
                 (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
    // Center the cross products: H_centered = H − n · centroid_a · centroid_bᵀ.
    h[0] -= (nk_f64_t)n * centroid_a_x * centroid_b_x;
    h[1] -= (nk_f64_t)n * centroid_a_x * centroid_b_y;
    h[2] -= (nk_f64_t)n * centroid_a_x * centroid_b_z;
    h[3] -= (nk_f64_t)n * centroid_a_y * centroid_b_x;
    h[4] -= (nk_f64_t)n * centroid_a_y * centroid_b_y;
    h[5] -= (nk_f64_t)n * centroid_a_y * centroid_b_z;
    h[6] -= (nk_f64_t)n * centroid_a_z * centroid_b_x;
    h[7] -= (nk_f64_t)n * centroid_a_z * centroid_b_y;
    h[8] -= (nk_f64_t)n * centroid_a_z * centroid_b_z;

    // `h` is already a row-major 3x3 buffer and is not read again, so it goes straight into the SVD.
    // Singular values land on svd_s's diagonal: [0], [4], [8].
    nk_f64_t svd_u[9], svd_s[9], svd_v[9], r[9];
    nk_svd3x3_f64_(h, svd_u, svd_s, svd_v);
    nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);

    // det(R) < 0 means R is a reflection. Flip the third column of V, rebuild R, and count the third singular
    // value negatively in the scale's trace (D = diag(1, 1, -1)).
    nk_f64_t det = nk_det3x3_f64_(r);
    nk_f64_t sign_correction = det < 0 ? -1.0 : 1.0;
    if (det < 0) {
        svd_v[2] = -svd_v[2];
        svd_v[5] = -svd_v[5];
        svd_v[8] = -svd_v[8];
        nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
    }

    // c = trace(D·S) / (n · var(A)).
    nk_f64_t applied_scale = (svd_s[0] + svd_s[4] + sign_correction * svd_s[8]) / ((nk_f64_t)n * variance_a);
    if (rotation) {
        for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)r[j];
    }
    if (scale) *scale = (nk_f32_t)applied_scale;

    nk_f64_t sum_squared = nk_transformed_ssd_f32_haswell_(a, b, n, r, applied_scale, centroid_a_x, centroid_a_y,
                                                           centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
}
|
|
980
|
+
|
|
981
|
+
/* Umeyama similarity alignment (rotation + uniform scale) of two f64 point clouds.
 *
 * `a` and `b` each hold `n` interleaved xyz triplets. A single fused pass gathers:
 *   - the per-axis sums of A and B,
 *   - the raw cross-product matrix H = Σ a·bᵀ (row = axis of A, column = axis of B),
 *   - the raw squared-norm sum Σ |a|².
 * Four points per step go through FMA lanes. The remaining points are handled one at a time with compensated
 * (error-tracking) scalar accumulators, which are seeded from the vector reductions.
 *
 * After centering, H is split by a 3x3 SVD. The rotation comes from nk_rotation_from_svd_f64_haswell_, and the
 * scale is c = trace(D·S) / (n · var(A)) with D = diag(1, 1, ±1) chosen by the sign of det(R).
 *
 * Outputs: `a_centroid`, `b_centroid`, `rotation` (9 values) and `scale` are written only when non-NULL.
 * `*result` receives sqrt(SSD / n), where SSD is the value returned by nk_transformed_ssd_f64_haswell_.
 *
 * NOTE(review): n == 0, or var(A) == 0, divides by zero and produces NaN/Inf outputs.
 */
NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
                                      nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
    // Vector accumulators: coordinate sums, the nine cross products, and the squared norms of A.
    __m256d acc_a_x = _mm256_setzero_pd(), acc_a_y = _mm256_setzero_pd(), acc_a_z = _mm256_setzero_pd();
    __m256d acc_b_x = _mm256_setzero_pd(), acc_b_y = _mm256_setzero_pd(), acc_b_z = _mm256_setzero_pd();
    __m256d acc_h_xx = _mm256_setzero_pd(), acc_h_xy = _mm256_setzero_pd(), acc_h_xz = _mm256_setzero_pd();
    __m256d acc_h_yx = _mm256_setzero_pd(), acc_h_yy = _mm256_setzero_pd(), acc_h_yz = _mm256_setzero_pd();
    __m256d acc_h_zx = _mm256_setzero_pd(), acc_h_zy = _mm256_setzero_pd(), acc_h_zz = _mm256_setzero_pd();
    __m256d acc_norm_a = _mm256_setzero_pd();

    __m256d lane_a_x, lane_a_y, lane_a_z, lane_b_x, lane_b_y, lane_b_z;
    nk_size_t point = 0;
    while (point + 4 <= n) {
        nk_deinterleave_f64x4_haswell_(a + point * 3, &lane_a_x, &lane_a_y, &lane_a_z);
        nk_deinterleave_f64x4_haswell_(b + point * 3, &lane_b_x, &lane_b_y, &lane_b_z);

        acc_a_x = _mm256_add_pd(acc_a_x, lane_a_x);
        acc_a_y = _mm256_add_pd(acc_a_y, lane_a_y);
        acc_a_z = _mm256_add_pd(acc_a_z, lane_a_z);
        acc_b_x = _mm256_add_pd(acc_b_x, lane_b_x);
        acc_b_y = _mm256_add_pd(acc_b_y, lane_b_y);
        acc_b_z = _mm256_add_pd(acc_b_z, lane_b_z);

        acc_h_xx = _mm256_fmadd_pd(lane_a_x, lane_b_x, acc_h_xx);
        acc_h_xy = _mm256_fmadd_pd(lane_a_x, lane_b_y, acc_h_xy);
        acc_h_xz = _mm256_fmadd_pd(lane_a_x, lane_b_z, acc_h_xz);
        acc_h_yx = _mm256_fmadd_pd(lane_a_y, lane_b_x, acc_h_yx);
        acc_h_yy = _mm256_fmadd_pd(lane_a_y, lane_b_y, acc_h_yy);
        acc_h_yz = _mm256_fmadd_pd(lane_a_y, lane_b_z, acc_h_yz);
        acc_h_zx = _mm256_fmadd_pd(lane_a_z, lane_b_x, acc_h_zx);
        acc_h_zy = _mm256_fmadd_pd(lane_a_z, lane_b_y, acc_h_zy);
        acc_h_zz = _mm256_fmadd_pd(lane_a_z, lane_b_z, acc_h_zz);

        // All three squares feed one accumulator, so their x, y, z order is kept fixed.
        acc_norm_a = _mm256_fmadd_pd(lane_a_x, lane_a_x, acc_norm_a);
        acc_norm_a = _mm256_fmadd_pd(lane_a_y, lane_a_y, acc_norm_a);
        acc_norm_a = _mm256_fmadd_pd(lane_a_z, lane_a_z, acc_norm_a);

        point += 4;
    }

    // Horizontal reductions seed the scalar (value, compensation) pairs.
    // sums[0..2] = A's x, y, z; sums[3..5] = B's x, y, z.
    nk_f64_t sums[6] = {nk_reduce_stable_f64x4_haswell_(acc_a_x), nk_reduce_stable_f64x4_haswell_(acc_a_y),
                        nk_reduce_stable_f64x4_haswell_(acc_a_z), nk_reduce_stable_f64x4_haswell_(acc_b_x),
                        nk_reduce_stable_f64x4_haswell_(acc_b_y), nk_reduce_stable_f64x4_haswell_(acc_b_z)};
    nk_f64_t sums_compensation[6] = {0};
    // Row-major raw cross products: h[row * 3 + col] = Σ a[row] · b[col].
    nk_f64_t h[9] = {nk_reduce_stable_f64x4_haswell_(acc_h_xx), nk_reduce_stable_f64x4_haswell_(acc_h_xy),
                     nk_reduce_stable_f64x4_haswell_(acc_h_xz), nk_reduce_stable_f64x4_haswell_(acc_h_yx),
                     nk_reduce_stable_f64x4_haswell_(acc_h_yy), nk_reduce_stable_f64x4_haswell_(acc_h_yz),
                     nk_reduce_stable_f64x4_haswell_(acc_h_zx), nk_reduce_stable_f64x4_haswell_(acc_h_zy),
                     nk_reduce_stable_f64x4_haswell_(acc_h_zz)};
    nk_f64_t h_compensation[9] = {0};
    nk_f64_t norm_a = nk_reduce_stable_f64x4_haswell_(acc_norm_a), norm_a_compensation = 0.0;

    // The final n % 4 points go through the compensated scalar accumulators.
    for (; point < n; ++point) {
        nk_f64_t const *point_a = a + point * 3;
        nk_f64_t const *point_b = b + point * 3;
        for (int axis = 0; axis < 3; ++axis) {
            nk_accumulate_sum_f64_(&sums[axis], &sums_compensation[axis], point_a[axis]);
            nk_accumulate_sum_f64_(&sums[3 + axis], &sums_compensation[3 + axis], point_b[axis]);
        }
        for (int row = 0; row < 3; ++row) {
            for (int col = 0; col < 3; ++col) {
                nk_accumulate_product_f64_(&h[row * 3 + col], &h_compensation[row * 3 + col], point_a[row],
                                           point_b[col]);
            }
        }
        for (int axis = 0; axis < 3; ++axis) {
            nk_accumulate_square_f64_(&norm_a, &norm_a_compensation, point_a[axis]);
        }
    }

    // Fold the running error terms back into their sums.
    for (int k = 0; k < 6; ++k) sums[k] += sums_compensation[k];
    for (int k = 0; k < 9; ++k) h[k] += h_compensation[k];
    norm_a += norm_a_compensation;

    nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
    nk_f64_t centroid_a_x = sums[0] * inv_n, centroid_a_y = sums[1] * inv_n, centroid_a_z = sums[2] * inv_n;
    nk_f64_t centroid_b_x = sums[3] * inv_n, centroid_b_y = sums[4] * inv_n, centroid_b_z = sums[5] * inv_n;
    if (a_centroid) {
        a_centroid[0] = centroid_a_x;
        a_centroid[1] = centroid_a_y;
        a_centroid[2] = centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = centroid_b_x;
        b_centroid[1] = centroid_b_y;
        b_centroid[2] = centroid_b_z;
    }

    // Mean squared distance of A from its centroid: E[|a|²] − |mean(A)|².
    nk_f64_t variance_a = norm_a * inv_n -
                          (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);

    // Centered cross-covariance: H − (Σa)(Σb)ᵀ / n.
    nk_f64_t cross_covariance[9];
    for (int row = 0; row < 3; ++row) {
        for (int col = 0; col < 3; ++col) {
            cross_covariance[row * 3 + col] = h[row * 3 + col] - sums[row] * sums[3 + col] * inv_n;
        }
    }

    // Singular values land on svd_s's diagonal: [0], [4], [8].
    nk_f64_t svd_u[9], svd_s[9], svd_v[9], r[9];
    nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
    nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);

    // det(R) < 0 means R is a reflection. D = diag(1, 1, -1) discounts the third singular value in the scale,
    // and the third column of V is flipped before R is rebuilt.
    nk_f64_t det = nk_det3x3_f64_(r);
    nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
    nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
    nk_f64_t c = trace_ds / (n * variance_a);
    if (scale) *scale = c;

    if (det < 0) {
        svd_v[2] = -svd_v[2];
        svd_v[5] = -svd_v[5];
        svd_v[8] = -svd_v[8];
        nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
    }
    if (rotation) {
        for (int k = 0; k < 9; ++k) rotation[k] = r[k];
    }

    nk_f64_t sum_squared = nk_transformed_ssd_f64_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
                                                           centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f64_sqrt_haswell(sum_squared * inv_n);
}
|
|
1128
|
+
|
|
1129
|
+
/* Deinterleave 8 f16 xyz triplets (24 f16 values) and convert to 3 x __m256 f32.
|
|
1130
|
+
* Uses scalar extraction for clean stride-3 access, then F16C conversion.
|
|
1131
|
+
*
|
|
1132
|
+
* Input: 24 contiguous f16 [x0,y0,z0, x1,y1,z1, ..., x7,y7,z7]
|
|
1133
|
+
* Output: x[8], y[8], z[8] vectors in f32
|
|
1134
|
+
*/
|
|
1135
|
+
NK_INTERNAL void nk_deinterleave_f16x8_to_f32x8_haswell_(nk_f16_t const *ptr, __m256 *x_out, __m256 *y_out,
|
|
1136
|
+
__m256 *z_out) {
|
|
1137
|
+
// Extract x, y, z components with stride-3 access
|
|
1138
|
+
nk_b256_vec_t x_vec, y_vec, z_vec;
|
|
1139
|
+
x_vec.f16s[0] = ptr[0], x_vec.f16s[1] = ptr[3], x_vec.f16s[2] = ptr[6], x_vec.f16s[3] = ptr[9];
|
|
1140
|
+
x_vec.f16s[4] = ptr[12], x_vec.f16s[5] = ptr[15], x_vec.f16s[6] = ptr[18], x_vec.f16s[7] = ptr[21];
|
|
1141
|
+
y_vec.f16s[0] = ptr[1], y_vec.f16s[1] = ptr[4], y_vec.f16s[2] = ptr[7], y_vec.f16s[3] = ptr[10];
|
|
1142
|
+
y_vec.f16s[4] = ptr[13], y_vec.f16s[5] = ptr[16], y_vec.f16s[6] = ptr[19], y_vec.f16s[7] = ptr[22];
|
|
1143
|
+
z_vec.f16s[0] = ptr[2], z_vec.f16s[1] = ptr[5], z_vec.f16s[2] = ptr[8], z_vec.f16s[3] = ptr[11];
|
|
1144
|
+
z_vec.f16s[4] = ptr[14], z_vec.f16s[5] = ptr[17], z_vec.f16s[6] = ptr[20], z_vec.f16s[7] = ptr[23];
|
|
1145
|
+
// Convert f16 to f32 using F16C
|
|
1146
|
+
*x_out = _mm256_cvtph_ps(x_vec.xmms[0]);
|
|
1147
|
+
*y_out = _mm256_cvtph_ps(y_vec.xmms[0]);
|
|
1148
|
+
*z_out = _mm256_cvtph_ps(z_vec.xmms[0]);
|
|
1149
|
+
}
|
|
1150
|
+
|
|
1151
|
+
/* Deinterleave 8 bf16 xyz triplets (24 bf16 values) and convert to 3 x __m256 f32.
|
|
1152
|
+
* Uses scalar extraction for clean stride-3 access, then bit-shift conversion.
|
|
1153
|
+
*
|
|
1154
|
+
* Input: 24 contiguous bf16 [x0,y0,z0, x1,y1,z1, ..., x7,y7,z7]
|
|
1155
|
+
* Output: x[8], y[8], z[8] vectors in f32
|
|
1156
|
+
*/
|
|
1157
|
+
NK_INTERNAL void nk_deinterleave_bf16x8_to_f32x8_haswell_(nk_bf16_t const *ptr, __m256 *x_out, __m256 *y_out,
|
|
1158
|
+
__m256 *z_out) {
|
|
1159
|
+
// Extract x, y, z components with stride-3 access
|
|
1160
|
+
nk_b256_vec_t x_vec, y_vec, z_vec;
|
|
1161
|
+
x_vec.bf16s[0] = ptr[0], x_vec.bf16s[1] = ptr[3], x_vec.bf16s[2] = ptr[6], x_vec.bf16s[3] = ptr[9];
|
|
1162
|
+
x_vec.bf16s[4] = ptr[12], x_vec.bf16s[5] = ptr[15], x_vec.bf16s[6] = ptr[18], x_vec.bf16s[7] = ptr[21];
|
|
1163
|
+
y_vec.bf16s[0] = ptr[1], y_vec.bf16s[1] = ptr[4], y_vec.bf16s[2] = ptr[7], y_vec.bf16s[3] = ptr[10];
|
|
1164
|
+
y_vec.bf16s[4] = ptr[13], y_vec.bf16s[5] = ptr[16], y_vec.bf16s[6] = ptr[19], y_vec.bf16s[7] = ptr[22];
|
|
1165
|
+
z_vec.bf16s[0] = ptr[2], z_vec.bf16s[1] = ptr[5], z_vec.bf16s[2] = ptr[8], z_vec.bf16s[3] = ptr[11];
|
|
1166
|
+
z_vec.bf16s[4] = ptr[14], z_vec.bf16s[5] = ptr[17], z_vec.bf16s[6] = ptr[20], z_vec.bf16s[7] = ptr[23];
|
|
1167
|
+
// Convert bf16 to f32 by left-shifting 16 bits
|
|
1168
|
+
*x_out = nk_bf16x8_to_f32x8_haswell_(x_vec.xmms[0]);
|
|
1169
|
+
*y_out = nk_bf16x8_to_f32x8_haswell_(y_vec.xmms[0]);
|
|
1170
|
+
*z_out = nk_bf16x8_to_f32x8_haswell_(z_vec.xmms[0]);
|
|
1171
|
+
}
|
|
1172
|
+
|
|
1173
|
+
/* Compute sum of squared distances for f16 data after applying rotation (and optional scale).
|
|
1174
|
+
* Loads f16 data, converts to f32 during processing.
|
|
1175
|
+
* Note: rotation matrix r is f32 (from SVD), scale and computation done in f32.
|
|
1176
|
+
*/
|
|
1177
|
+
NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_haswell_(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
                                                     nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
                                                     nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
                                                     nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
                                                     nk_f32_t centroid_b_z) {
    /*  For every point i of the interleaved xyz clouds `a` and `b`, forms the residual
     *      scale * R * (a[i] - centroid_a) - (b[i] - centroid_b)
     *  where R is the row-major 3x3 matrix `r`, and returns the sum of squared residual norms (NOT divided by n).
     *  The f16 coordinates are widened to f32 before any arithmetic. */

    // Matrix entries with `scale` already folded in, so the vector body needs no separate scaling multiply.
    __m256 const m00_f32x8 = _mm256_set1_ps(scale * r[0]);
    __m256 const m01_f32x8 = _mm256_set1_ps(scale * r[1]);
    __m256 const m02_f32x8 = _mm256_set1_ps(scale * r[2]);
    __m256 const m10_f32x8 = _mm256_set1_ps(scale * r[3]);
    __m256 const m11_f32x8 = _mm256_set1_ps(scale * r[4]);
    __m256 const m12_f32x8 = _mm256_set1_ps(scale * r[5]);
    __m256 const m20_f32x8 = _mm256_set1_ps(scale * r[6]);
    __m256 const m21_f32x8 = _mm256_set1_ps(scale * r[7]);
    __m256 const m22_f32x8 = _mm256_set1_ps(scale * r[8]);

    __m256 const origin_a_x_f32x8 = _mm256_set1_ps(centroid_a_x);
    __m256 const origin_a_y_f32x8 = _mm256_set1_ps(centroid_a_y);
    __m256 const origin_a_z_f32x8 = _mm256_set1_ps(centroid_a_z);
    __m256 const origin_b_x_f32x8 = _mm256_set1_ps(centroid_b_x);
    __m256 const origin_b_y_f32x8 = _mm256_set1_ps(centroid_b_y);
    __m256 const origin_b_z_f32x8 = _mm256_set1_ps(centroid_b_z);

    __m256 residual_sq_f32x8 = _mm256_setzero_ps();
    nk_size_t point = 0;

    // Vector body: 8 xyz points from each cloud per iteration.
    for (; point + 8 <= n; point += 8) {
        __m256 ax_f32x8, ay_f32x8, az_f32x8, bx_f32x8, by_f32x8, bz_f32x8;
        nk_deinterleave_f16x8_to_f32x8_haswell_(a + point * 3, &ax_f32x8, &ay_f32x8, &az_f32x8);
        nk_deinterleave_f16x8_to_f32x8_haswell_(b + point * 3, &bx_f32x8, &by_f32x8, &bz_f32x8);

        // Express both clouds relative to their own centroids.
        __m256 const pax_f32x8 = _mm256_sub_ps(ax_f32x8, origin_a_x_f32x8);
        __m256 const pay_f32x8 = _mm256_sub_ps(ay_f32x8, origin_a_y_f32x8);
        __m256 const paz_f32x8 = _mm256_sub_ps(az_f32x8, origin_a_z_f32x8);
        __m256 const pbx_f32x8 = _mm256_sub_ps(bx_f32x8, origin_b_x_f32x8);
        __m256 const pby_f32x8 = _mm256_sub_ps(by_f32x8, origin_b_y_f32x8);
        __m256 const pbz_f32x8 = _mm256_sub_ps(bz_f32x8, origin_b_z_f32x8);

        // Matrix-vector product, column by column: start from the x column, then fuse in y, then z.
        __m256 tx_f32x8 = _mm256_mul_ps(m00_f32x8, pax_f32x8);
        __m256 ty_f32x8 = _mm256_mul_ps(m10_f32x8, pax_f32x8);
        __m256 tz_f32x8 = _mm256_mul_ps(m20_f32x8, pax_f32x8);
        tx_f32x8 = _mm256_fmadd_ps(m01_f32x8, pay_f32x8, tx_f32x8);
        ty_f32x8 = _mm256_fmadd_ps(m11_f32x8, pay_f32x8, ty_f32x8);
        tz_f32x8 = _mm256_fmadd_ps(m21_f32x8, pay_f32x8, tz_f32x8);
        tx_f32x8 = _mm256_fmadd_ps(m02_f32x8, paz_f32x8, tx_f32x8);
        ty_f32x8 = _mm256_fmadd_ps(m12_f32x8, paz_f32x8, ty_f32x8);
        tz_f32x8 = _mm256_fmadd_ps(m22_f32x8, paz_f32x8, tz_f32x8);

        // Residuals, squared into a single accumulator in x, y, z order.
        __m256 const dx_f32x8 = _mm256_sub_ps(tx_f32x8, pbx_f32x8);
        __m256 const dy_f32x8 = _mm256_sub_ps(ty_f32x8, pby_f32x8);
        __m256 const dz_f32x8 = _mm256_sub_ps(tz_f32x8, pbz_f32x8);
        residual_sq_f32x8 = _mm256_fmadd_ps(dx_f32x8, dx_f32x8, residual_sq_f32x8);
        residual_sq_f32x8 = _mm256_fmadd_ps(dy_f32x8, dy_f32x8, residual_sq_f32x8);
        residual_sq_f32x8 = _mm256_fmadd_ps(dz_f32x8, dz_f32x8, residual_sq_f32x8);
    }

    nk_f32_t total = nk_reduce_add_f32x8_haswell_(residual_sq_f32x8);

    // Scalar remainder (fewer than 8 points); here `scale` is applied after the matrix-vector product.
    for (; point < n; ++point) {
        nk_f32_t a_xyz[3], b_xyz[3];
        for (int axis = 0; axis < 3; ++axis) {
            nk_f16_to_f32_haswell(&a[point * 3 + axis], &a_xyz[axis]);
            nk_f16_to_f32_haswell(&b[point * 3 + axis], &b_xyz[axis]);
        }

        nk_f32_t const pax = a_xyz[0] - centroid_a_x;
        nk_f32_t const pay = a_xyz[1] - centroid_a_y;
        nk_f32_t const paz = a_xyz[2] - centroid_a_z;
        nk_f32_t const pbx = b_xyz[0] - centroid_b_x;
        nk_f32_t const pby = b_xyz[1] - centroid_b_y;
        nk_f32_t const pbz = b_xyz[2] - centroid_b_z;

        nk_f32_t const tx = scale * (r[0] * pax + r[1] * pay + r[2] * paz);
        nk_f32_t const ty = scale * (r[3] * pax + r[4] * pay + r[5] * paz);
        nk_f32_t const tz = scale * (r[6] * pax + r[7] * pay + r[8] * paz);

        nk_f32_t const dx = tx - pbx;
        nk_f32_t const dy = ty - pby;
        nk_f32_t const dz = tz - pbz;
        total += dx * dx + dy * dy + dz * dz;
    }

    return total;
}
|
|
1269
|
+
|
|
1270
|
+
/* Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
|
|
1271
|
+
* Loads bf16 data, converts to f32 during processing.
|
|
1272
|
+
* Note: rotation matrix r is f32 (from SVD), scale and computation done in f32.
|
|
1273
|
+
*/
|
|
1274
|
+
NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_haswell_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
|
|
1275
|
+
nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
|
|
1276
|
+
nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
|
|
1277
|
+
nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
|
|
1278
|
+
nk_f32_t centroid_b_z) {
|
|
1279
|
+
// Broadcast scaled rotation matrix elements
|
|
1280
|
+
__m256 scaled_rotation_x_x_f32x8 = _mm256_set1_ps(scale * r[0]);
|
|
1281
|
+
__m256 scaled_rotation_x_y_f32x8 = _mm256_set1_ps(scale * r[1]);
|
|
1282
|
+
__m256 scaled_rotation_x_z_f32x8 = _mm256_set1_ps(scale * r[2]);
|
|
1283
|
+
__m256 scaled_rotation_y_x_f32x8 = _mm256_set1_ps(scale * r[3]);
|
|
1284
|
+
__m256 scaled_rotation_y_y_f32x8 = _mm256_set1_ps(scale * r[4]);
|
|
1285
|
+
__m256 scaled_rotation_y_z_f32x8 = _mm256_set1_ps(scale * r[5]);
|
|
1286
|
+
__m256 scaled_rotation_z_x_f32x8 = _mm256_set1_ps(scale * r[6]);
|
|
1287
|
+
__m256 scaled_rotation_z_y_f32x8 = _mm256_set1_ps(scale * r[7]);
|
|
1288
|
+
__m256 scaled_rotation_z_z_f32x8 = _mm256_set1_ps(scale * r[8]);
|
|
1289
|
+
|
|
1290
|
+
// Broadcast centroids
|
|
1291
|
+
__m256 centroid_a_x_f32x8 = _mm256_set1_ps(centroid_a_x);
|
|
1292
|
+
__m256 centroid_a_y_f32x8 = _mm256_set1_ps(centroid_a_y);
|
|
1293
|
+
__m256 centroid_a_z_f32x8 = _mm256_set1_ps(centroid_a_z);
|
|
1294
|
+
__m256 centroid_b_x_f32x8 = _mm256_set1_ps(centroid_b_x);
|
|
1295
|
+
__m256 centroid_b_y_f32x8 = _mm256_set1_ps(centroid_b_y);
|
|
1296
|
+
__m256 centroid_b_z_f32x8 = _mm256_set1_ps(centroid_b_z);
|
|
1297
|
+
|
|
1298
|
+
__m256 sum_squared_f32x8 = _mm256_setzero_ps();
|
|
1299
|
+
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
1300
|
+
nk_size_t j = 0;
|
|
1301
|
+
|
|
1302
|
+
for (; j + 8 <= n; j += 8) {
|
|
1303
|
+
nk_deinterleave_bf16x8_to_f32x8_haswell_(a + j * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
|
|
1304
|
+
nk_deinterleave_bf16x8_to_f32x8_haswell_(b + j * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
1305
|
+
|
|
1306
|
+
// Center points
|
|
1307
|
+
__m256 pa_x_f32x8 = _mm256_sub_ps(a_x_f32x8, centroid_a_x_f32x8);
|
|
1308
|
+
__m256 pa_y_f32x8 = _mm256_sub_ps(a_y_f32x8, centroid_a_y_f32x8);
|
|
1309
|
+
__m256 pa_z_f32x8 = _mm256_sub_ps(a_z_f32x8, centroid_a_z_f32x8);
|
|
1310
|
+
__m256 pb_x_f32x8 = _mm256_sub_ps(b_x_f32x8, centroid_b_x_f32x8);
|
|
1311
|
+
__m256 pb_y_f32x8 = _mm256_sub_ps(b_y_f32x8, centroid_b_y_f32x8);
|
|
1312
|
+
__m256 pb_z_f32x8 = _mm256_sub_ps(b_z_f32x8, centroid_b_z_f32x8);
|
|
1313
|
+
|
|
1314
|
+
// Rotate and scale: ra = scale * R * pa
|
|
1315
|
+
__m256 ra_x_f32x8 = _mm256_fmadd_ps(scaled_rotation_x_z_f32x8, pa_z_f32x8,
|
|
1316
|
+
_mm256_fmadd_ps(scaled_rotation_x_y_f32x8, pa_y_f32x8,
|
|
1317
|
+
_mm256_mul_ps(scaled_rotation_x_x_f32x8, pa_x_f32x8)));
|
|
1318
|
+
__m256 ra_y_f32x8 = _mm256_fmadd_ps(scaled_rotation_y_z_f32x8, pa_z_f32x8,
|
|
1319
|
+
_mm256_fmadd_ps(scaled_rotation_y_y_f32x8, pa_y_f32x8,
|
|
1320
|
+
_mm256_mul_ps(scaled_rotation_y_x_f32x8, pa_x_f32x8)));
|
|
1321
|
+
__m256 ra_z_f32x8 = _mm256_fmadd_ps(scaled_rotation_z_z_f32x8, pa_z_f32x8,
|
|
1322
|
+
_mm256_fmadd_ps(scaled_rotation_z_y_f32x8, pa_y_f32x8,
|
|
1323
|
+
_mm256_mul_ps(scaled_rotation_z_x_f32x8, pa_x_f32x8)));
|
|
1324
|
+
|
|
1325
|
+
// Delta and accumulate
|
|
1326
|
+
__m256 delta_x_f32x8 = _mm256_sub_ps(ra_x_f32x8, pb_x_f32x8);
|
|
1327
|
+
__m256 delta_y_f32x8 = _mm256_sub_ps(ra_y_f32x8, pb_y_f32x8);
|
|
1328
|
+
__m256 delta_z_f32x8 = _mm256_sub_ps(ra_z_f32x8, pb_z_f32x8);
|
|
1329
|
+
|
|
1330
|
+
sum_squared_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_f32x8);
|
|
1331
|
+
sum_squared_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_f32x8);
|
|
1332
|
+
sum_squared_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_f32x8);
|
|
1333
|
+
}
|
|
1334
|
+
|
|
1335
|
+
nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_f32x8);
|
|
1336
|
+
|
|
1337
|
+
// Scalar tail
|
|
1338
|
+
for (; j < n; ++j) {
|
|
1339
|
+
nk_f32_t a_x_f32, a_y_f32, a_z_f32, b_x_f32, b_y_f32, b_z_f32;
|
|
1340
|
+
nk_bf16_to_f32_serial(&a[j * 3 + 0], &a_x_f32);
|
|
1341
|
+
nk_bf16_to_f32_serial(&a[j * 3 + 1], &a_y_f32);
|
|
1342
|
+
nk_bf16_to_f32_serial(&a[j * 3 + 2], &a_z_f32);
|
|
1343
|
+
nk_bf16_to_f32_serial(&b[j * 3 + 0], &b_x_f32);
|
|
1344
|
+
nk_bf16_to_f32_serial(&b[j * 3 + 1], &b_y_f32);
|
|
1345
|
+
nk_bf16_to_f32_serial(&b[j * 3 + 2], &b_z_f32);
|
|
1346
|
+
|
|
1347
|
+
nk_f32_t pa_x = a_x_f32 - centroid_a_x;
|
|
1348
|
+
nk_f32_t pa_y = a_y_f32 - centroid_a_y;
|
|
1349
|
+
nk_f32_t pa_z = a_z_f32 - centroid_a_z;
|
|
1350
|
+
nk_f32_t pb_x = b_x_f32 - centroid_b_x;
|
|
1351
|
+
nk_f32_t pb_y = b_y_f32 - centroid_b_y;
|
|
1352
|
+
nk_f32_t pb_z = b_z_f32 - centroid_b_z;
|
|
1353
|
+
|
|
1354
|
+
nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
|
|
1355
|
+
nk_f32_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
|
|
1356
|
+
nk_f32_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
1357
|
+
|
|
1358
|
+
nk_f32_t delta_x = ra_x - pb_x;
|
|
1359
|
+
nk_f32_t delta_y = ra_y - pb_y;
|
|
1360
|
+
nk_f32_t delta_z = ra_z - pb_z;
|
|
1361
|
+
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
1362
|
+
}
|
|
1363
|
+
|
|
1364
|
+
return sum_squared;
|
|
1365
|
+
}
|
|
1366
|
+
|
|
1367
|
+
NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                   nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    /*  Centered RMSD between two interleaved f16 xyz clouds, computed in f32 with no rotation applied.
     *
     *  With d_i = a_i - b_i and mean(d) = centroid_a - centroid_b, it uses
     *      mean(|d_i - mean(d)|^2) = mean(|d_i|^2) - |mean(d)|^2
     *  so a single pass over the data suffices. Optional outputs are written only when non-NULL: `rotation` gets the
     *  3x3 identity, `scale` gets 1, and `a_centroid` / `b_centroid` get the per-axis means. */
    if (rotation) {
        rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
        rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
        rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
    }
    if (scale) *scale = 1.0f;

    __m256 const zeros_f32x8 = _mm256_setzero_ps();

    // Accumulators for centroids and squared differences (all in f32)
    __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
    __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
    __m256 sum_squared_x_f32x8 = zeros_f32x8, sum_squared_y_f32x8 = zeros_f32x8, sum_squared_z_f32x8 = zeros_f32x8;

    __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
    nk_size_t i = 0;

    // Main loop processing 8 points at a time
    for (; i + 8 <= n; i += 8) {
        nk_deinterleave_f16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
        nk_deinterleave_f16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);

        sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
        sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
        sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
        sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
        sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
        sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);

        __m256 delta_x_f32x8 = _mm256_sub_ps(a_x_f32x8, b_x_f32x8);
        __m256 delta_y_f32x8 = _mm256_sub_ps(a_y_f32x8, b_y_f32x8);
        __m256 delta_z_f32x8 = _mm256_sub_ps(a_z_f32x8, b_z_f32x8);

        sum_squared_x_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_x_f32x8);
        sum_squared_y_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_y_f32x8);
        sum_squared_z_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_z_f32x8);
    }

    // Reduce vectors to scalars
    nk_f32_t total_ax = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
    nk_f32_t total_ay = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
    nk_f32_t total_az = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
    nk_f32_t total_bx = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
    nk_f32_t total_by = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
    nk_f32_t total_bz = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
    nk_f32_t total_sq_x = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8);
    nk_f32_t total_sq_y = nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8);
    nk_f32_t total_sq_z = nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);

    // Scalar tail
    for (; i < n; ++i) {
        nk_f32_t ax, ay, az, bx, by, bz;
        nk_f16_to_f32_haswell(&a[i * 3 + 0], &ax);
        nk_f16_to_f32_haswell(&a[i * 3 + 1], &ay);
        nk_f16_to_f32_haswell(&a[i * 3 + 2], &az);
        nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
        nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
        nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
        total_ax += ax;
        total_ay += ay;
        total_az += az;
        total_bx += bx;
        total_by += by;
        total_bz += bz;
        nk_f32_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
        total_sq_x += delta_x * delta_x;
        total_sq_y += delta_y * delta_y;
        total_sq_z += delta_z * delta_z;
    }

    // Compute centroids
    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = total_ax * inv_n;
    nk_f32_t centroid_a_y = total_ay * inv_n;
    nk_f32_t centroid_a_z = total_az * inv_n;
    nk_f32_t centroid_b_x = total_bx * inv_n;
    nk_f32_t centroid_b_y = total_by * inv_n;
    nk_f32_t centroid_b_z = total_bz * inv_n;

    if (a_centroid) {
        a_centroid[0] = centroid_a_x;
        a_centroid[1] = centroid_a_y;
        a_centroid[2] = centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = centroid_b_x;
        b_centroid[1] = centroid_b_y;
        b_centroid[2] = centroid_b_z;
    }

    // Compute RMSD
    nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
    nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
    nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
    nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
    nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;

    // This subtraction cancels catastrophically when the clouds differ mostly by a translation, and f32 rounding can
    // leave it slightly below zero; clamp so the square root yields 0 instead of NaN. A NaN (e.g. from n == 0) fails
    // the comparison and propagates unchanged.
    nk_f32_t mean_squared_deviation = sum_squared * inv_n - mean_diff_sq;
    if (mean_squared_deviation < 0.0f) mean_squared_deviation = 0.0f;
    *result = nk_f32_sqrt_haswell(mean_squared_deviation);
}
|
|
1469
|
+
|
|
1470
|
+
NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                    nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    /*  Centered RMSD between two interleaved bf16 xyz clouds, computed in f32 with no rotation applied.
     *
     *  With d_i = a_i - b_i and mean(d) = centroid_a - centroid_b, it uses
     *      mean(|d_i - mean(d)|^2) = mean(|d_i|^2) - |mean(d)|^2
     *  so a single pass over the data suffices. Optional outputs are written only when non-NULL: `rotation` gets the
     *  3x3 identity, `scale` gets 1, and `a_centroid` / `b_centroid` get the per-axis means. */
    if (rotation) {
        rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
        rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
        rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
    }
    if (scale) *scale = 1.0f;

    __m256 const zeros_f32x8 = _mm256_setzero_ps();

    // Accumulators for centroids and squared differences (all in f32)
    __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
    __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
    __m256 sum_squared_x_f32x8 = zeros_f32x8, sum_squared_y_f32x8 = zeros_f32x8, sum_squared_z_f32x8 = zeros_f32x8;

    __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
    nk_size_t i = 0;

    // Main loop processing 8 points at a time
    for (; i + 8 <= n; i += 8) {
        nk_deinterleave_bf16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
        nk_deinterleave_bf16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);

        sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
        sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
        sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
        sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
        sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
        sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);

        __m256 delta_x_f32x8 = _mm256_sub_ps(a_x_f32x8, b_x_f32x8);
        __m256 delta_y_f32x8 = _mm256_sub_ps(a_y_f32x8, b_y_f32x8);
        __m256 delta_z_f32x8 = _mm256_sub_ps(a_z_f32x8, b_z_f32x8);

        sum_squared_x_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_x_f32x8);
        sum_squared_y_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_y_f32x8);
        sum_squared_z_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_z_f32x8);
    }

    // Reduce vectors to scalars
    nk_f32_t total_ax = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
    nk_f32_t total_ay = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
    nk_f32_t total_az = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
    nk_f32_t total_bx = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
    nk_f32_t total_by = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
    nk_f32_t total_bz = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
    nk_f32_t total_sq_x = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8);
    nk_f32_t total_sq_y = nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8);
    nk_f32_t total_sq_z = nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);

    // Scalar tail
    for (; i < n; ++i) {
        nk_f32_t ax, ay, az, bx, by, bz;
        nk_bf16_to_f32_serial(&a[i * 3 + 0], &ax);
        nk_bf16_to_f32_serial(&a[i * 3 + 1], &ay);
        nk_bf16_to_f32_serial(&a[i * 3 + 2], &az);
        nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
        nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
        nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
        total_ax += ax;
        total_ay += ay;
        total_az += az;
        total_bx += bx;
        total_by += by;
        total_bz += bz;
        nk_f32_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
        total_sq_x += delta_x * delta_x;
        total_sq_y += delta_y * delta_y;
        total_sq_z += delta_z * delta_z;
    }

    // Compute centroids
    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = total_ax * inv_n;
    nk_f32_t centroid_a_y = total_ay * inv_n;
    nk_f32_t centroid_a_z = total_az * inv_n;
    nk_f32_t centroid_b_x = total_bx * inv_n;
    nk_f32_t centroid_b_y = total_by * inv_n;
    nk_f32_t centroid_b_z = total_bz * inv_n;

    if (a_centroid) {
        a_centroid[0] = centroid_a_x;
        a_centroid[1] = centroid_a_y;
        a_centroid[2] = centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = centroid_b_x;
        b_centroid[1] = centroid_b_y;
        b_centroid[2] = centroid_b_z;
    }

    // Compute RMSD
    nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
    nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
    nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
    nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
    nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;

    // This subtraction cancels catastrophically when the clouds differ mostly by a translation, and f32 rounding can
    // leave it slightly below zero; clamp so the square root yields 0 instead of NaN. A NaN (e.g. from n == 0) fails
    // the comparison and propagates unchanged.
    nk_f32_t mean_squared_deviation = sum_squared * inv_n - mean_diff_sq;
    if (mean_squared_deviation < 0.0f) mean_squared_deviation = 0.0f;
    *result = nk_f32_sqrt_haswell(mean_squared_deviation);
}
|
|
1572
|
+
|
|
1573
|
+
NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                     nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    /*  Kabsch superposition of two interleaved f16 xyz clouds, computed in f32.
     *
     *  A single fused pass accumulates the per-axis coordinate sums and the raw cross-covariance H = sum(a_i * b_i^T).
     *  H is then centered, decomposed by `nk_svd3x3_f32_`, and turned into R = V * U^T. When det(R) < 0, the third
     *  column of V is negated and R is rebuilt. `result` receives sqrt(mean |R * (a_i - centroid_a) - (b_i -
     *  centroid_b)|^2). Optional outputs are written only when non-NULL; `scale` is always reported as 1. */

    // Per-axis coordinate sums for both clouds.
    __m256 sum_ax_f32x8 = _mm256_setzero_ps(), sum_ay_f32x8 = _mm256_setzero_ps(), sum_az_f32x8 = _mm256_setzero_ps();
    __m256 sum_bx_f32x8 = _mm256_setzero_ps(), sum_by_f32x8 = _mm256_setzero_ps(), sum_bz_f32x8 = _mm256_setzero_ps();

    // Raw cross-covariance entries h_<a axis><b axis>, i.e. sums of a_axis * b_axis products.
    __m256 h_xx_f32x8 = _mm256_setzero_ps(), h_xy_f32x8 = _mm256_setzero_ps(), h_xz_f32x8 = _mm256_setzero_ps();
    __m256 h_yx_f32x8 = _mm256_setzero_ps(), h_yy_f32x8 = _mm256_setzero_ps(), h_yz_f32x8 = _mm256_setzero_ps();
    __m256 h_zx_f32x8 = _mm256_setzero_ps(), h_zy_f32x8 = _mm256_setzero_ps(), h_zz_f32x8 = _mm256_setzero_ps();

    nk_size_t point = 0;
    for (; point + 8 <= n; point += 8) {
        __m256 ax_f32x8, ay_f32x8, az_f32x8, bx_f32x8, by_f32x8, bz_f32x8;
        nk_deinterleave_f16x8_to_f32x8_haswell_(a + point * 3, &ax_f32x8, &ay_f32x8, &az_f32x8);
        nk_deinterleave_f16x8_to_f32x8_haswell_(b + point * 3, &bx_f32x8, &by_f32x8, &bz_f32x8);

        sum_ax_f32x8 = _mm256_add_ps(sum_ax_f32x8, ax_f32x8);
        sum_ay_f32x8 = _mm256_add_ps(sum_ay_f32x8, ay_f32x8);
        sum_az_f32x8 = _mm256_add_ps(sum_az_f32x8, az_f32x8);
        sum_bx_f32x8 = _mm256_add_ps(sum_bx_f32x8, bx_f32x8);
        sum_by_f32x8 = _mm256_add_ps(sum_by_f32x8, by_f32x8);
        sum_bz_f32x8 = _mm256_add_ps(sum_bz_f32x8, bz_f32x8);

        h_xx_f32x8 = _mm256_fmadd_ps(ax_f32x8, bx_f32x8, h_xx_f32x8);
        h_xy_f32x8 = _mm256_fmadd_ps(ax_f32x8, by_f32x8, h_xy_f32x8);
        h_xz_f32x8 = _mm256_fmadd_ps(ax_f32x8, bz_f32x8, h_xz_f32x8);
        h_yx_f32x8 = _mm256_fmadd_ps(ay_f32x8, bx_f32x8, h_yx_f32x8);
        h_yy_f32x8 = _mm256_fmadd_ps(ay_f32x8, by_f32x8, h_yy_f32x8);
        h_yz_f32x8 = _mm256_fmadd_ps(ay_f32x8, bz_f32x8, h_yz_f32x8);
        h_zx_f32x8 = _mm256_fmadd_ps(az_f32x8, bx_f32x8, h_zx_f32x8);
        h_zy_f32x8 = _mm256_fmadd_ps(az_f32x8, by_f32x8, h_zy_f32x8);
        h_zz_f32x8 = _mm256_fmadd_ps(az_f32x8, bz_f32x8, h_zz_f32x8);
    }

    // Horizontal reductions; `h` is row-major with rows indexed by the `a` axis and columns by the `b` axis.
    nk_f32_t sum_a[3] = {nk_reduce_add_f32x8_haswell_(sum_ax_f32x8), nk_reduce_add_f32x8_haswell_(sum_ay_f32x8),
                         nk_reduce_add_f32x8_haswell_(sum_az_f32x8)};
    nk_f32_t sum_b[3] = {nk_reduce_add_f32x8_haswell_(sum_bx_f32x8), nk_reduce_add_f32x8_haswell_(sum_by_f32x8),
                         nk_reduce_add_f32x8_haswell_(sum_bz_f32x8)};
    nk_f32_t h[9] = {nk_reduce_add_f32x8_haswell_(h_xx_f32x8), nk_reduce_add_f32x8_haswell_(h_xy_f32x8),
                     nk_reduce_add_f32x8_haswell_(h_xz_f32x8), nk_reduce_add_f32x8_haswell_(h_yx_f32x8),
                     nk_reduce_add_f32x8_haswell_(h_yy_f32x8), nk_reduce_add_f32x8_haswell_(h_yz_f32x8),
                     nk_reduce_add_f32x8_haswell_(h_zx_f32x8), nk_reduce_add_f32x8_haswell_(h_zy_f32x8),
                     nk_reduce_add_f32x8_haswell_(h_zz_f32x8)};

    // Scalar remainder (fewer than 8 points).
    for (; point < n; ++point) {
        nk_f32_t pa[3], pb[3];
        for (int axis = 0; axis < 3; ++axis) {
            nk_f16_to_f32_haswell(&a[point * 3 + axis], &pa[axis]);
            nk_f16_to_f32_haswell(&b[point * 3 + axis], &pb[axis]);
        }
        for (int axis = 0; axis < 3; ++axis) {
            sum_a[axis] += pa[axis];
            sum_b[axis] += pb[axis];
        }
        for (int row = 0; row < 3; ++row)
            for (int col = 0; col < 3; ++col) h[row * 3 + col] += pa[row] * pb[col];
    }

    nk_f32_t const inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a[3], centroid_b[3];
    for (int axis = 0; axis < 3; ++axis) {
        centroid_a[axis] = sum_a[axis] * inv_n;
        centroid_b[axis] = sum_b[axis] * inv_n;
    }
    if (a_centroid)
        for (int axis = 0; axis < 3; ++axis) a_centroid[axis] = centroid_a[axis];
    if (b_centroid)
        for (int axis = 0; axis < 3; ++axis) b_centroid[axis] = centroid_b[axis];

    // Center the covariance: H_centered = H - n * centroid_a * centroid_b^T.
    nk_f32_t const n_f32 = (nk_f32_t)n;
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col) h[row * 3 + col] -= n_f32 * centroid_a[row] * centroid_b[col];

    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(h, svd_u, svd_s, svd_v);

    // R = V * U^T: entry (row, col) is the dot product of row `row` of V with row `col` of U.
    nk_f32_t r[9];
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col)
            r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] + svd_v[row * 3 + 1] * svd_u[col * 3 + 1] +
                               svd_v[row * 3 + 2] * svd_u[col * 3 + 2];

    // A negative determinant means R is a reflection: flip the third column of V and rebuild R once.
    if (nk_det3x3_f32_(r) < 0) {
        for (int row = 0; row < 3; ++row) svd_v[row * 3 + 2] = -svd_v[row * 3 + 2];
        for (int row = 0; row < 3; ++row)
            for (int col = 0; col < 3; ++col)
                r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] + svd_v[row * 3 + 1] * svd_u[col * 3 + 1] +
                                   svd_v[row * 3 + 2] * svd_u[col * 3 + 2];
    }

    if (rotation)
        for (int k = 0; k < 9; ++k) rotation[k] = r[k];
    if (scale) *scale = 1.0f;

    // Second pass over the data: residuals after applying the optimal rotation with unit scale.
    nk_f32_t const sum_squared = nk_transformed_ssd_f16_haswell_(a, b, n, r, 1.0f, centroid_a[0], centroid_a[1],
                                                                 centroid_a[2], centroid_b[0], centroid_b[1],
                                                                 centroid_b[2]);
    *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
}
|
|
1734
|
+
|
|
1735
|
+
NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1736
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1737
|
+
// Fused single-pass: load bf16, convert to f32, compute centroids and covariance
|
|
1738
|
+
__m256 const zeros_f32x8 = _mm256_setzero_ps();
|
|
1739
|
+
|
|
1740
|
+
// Accumulators for centroids (f32)
|
|
1741
|
+
__m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
|
|
1742
|
+
__m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
|
|
1743
|
+
|
|
1744
|
+
// Accumulators for covariance matrix (sum of outer products)
|
|
1745
|
+
__m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
|
|
1746
|
+
__m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
|
|
1747
|
+
__m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
|
|
1748
|
+
|
|
1749
|
+
nk_size_t i = 0;
|
|
1750
|
+
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
1751
|
+
|
|
1752
|
+
for (; i + 8 <= n; i += 8) {
|
|
1753
|
+
nk_deinterleave_bf16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
|
|
1754
|
+
nk_deinterleave_bf16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
1755
|
+
|
|
1756
|
+
// Accumulate centroids
|
|
1757
|
+
sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
|
|
1758
|
+
sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
|
|
1759
|
+
sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
|
|
1760
|
+
sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
|
|
1761
|
+
sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
|
|
1762
|
+
sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
|
|
1763
|
+
|
|
1764
|
+
// Accumulate outer products
|
|
1765
|
+
cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
|
|
1766
|
+
cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
|
|
1767
|
+
cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
|
|
1768
|
+
cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
|
|
1769
|
+
cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
|
|
1770
|
+
cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
|
|
1771
|
+
cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
|
|
1772
|
+
cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
|
|
1773
|
+
cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);
|
|
1774
|
+
}
|
|
1775
|
+
|
|
1776
|
+
// Reduce vector accumulators
|
|
1777
|
+
nk_f32_t sum_a_x = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
|
|
1778
|
+
nk_f32_t sum_a_y = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
|
|
1779
|
+
nk_f32_t sum_a_z = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
|
|
1780
|
+
nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
|
|
1781
|
+
nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
|
|
1782
|
+
nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
|
|
1783
|
+
|
|
1784
|
+
nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
|
|
1785
|
+
nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
|
|
1786
|
+
nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
|
|
1787
|
+
nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
|
|
1788
|
+
nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
|
|
1789
|
+
nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
|
|
1790
|
+
nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
|
|
1791
|
+
nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
|
|
1792
|
+
nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
|
|
1793
|
+
|
|
1794
|
+
// Scalar tail
|
|
1795
|
+
for (; i < n; ++i) {
|
|
1796
|
+
nk_f32_t ax, ay, az, bx, by, bz;
|
|
1797
|
+
nk_bf16_to_f32_serial(&a[i * 3 + 0], &ax);
|
|
1798
|
+
nk_bf16_to_f32_serial(&a[i * 3 + 1], &ay);
|
|
1799
|
+
nk_bf16_to_f32_serial(&a[i * 3 + 2], &az);
|
|
1800
|
+
nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
|
|
1801
|
+
nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
|
|
1802
|
+
nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
|
|
1803
|
+
sum_a_x += ax;
|
|
1804
|
+
sum_a_y += ay;
|
|
1805
|
+
sum_a_z += az;
|
|
1806
|
+
sum_b_x += bx;
|
|
1807
|
+
sum_b_y += by;
|
|
1808
|
+
sum_b_z += bz;
|
|
1809
|
+
covariance_x_x += ax * bx;
|
|
1810
|
+
covariance_x_y += ax * by;
|
|
1811
|
+
covariance_x_z += ax * bz;
|
|
1812
|
+
covariance_y_x += ay * bx;
|
|
1813
|
+
covariance_y_y += ay * by;
|
|
1814
|
+
covariance_y_z += ay * bz;
|
|
1815
|
+
covariance_z_x += az * bx;
|
|
1816
|
+
covariance_z_y += az * by;
|
|
1817
|
+
covariance_z_z += az * bz;
|
|
1818
|
+
}
|
|
1819
|
+
|
|
1820
|
+
// Compute centroids
|
|
1821
|
+
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
1822
|
+
nk_f32_t centroid_a_x = sum_a_x * inv_n;
|
|
1823
|
+
nk_f32_t centroid_a_y = sum_a_y * inv_n;
|
|
1824
|
+
nk_f32_t centroid_a_z = sum_a_z * inv_n;
|
|
1825
|
+
nk_f32_t centroid_b_x = sum_b_x * inv_n;
|
|
1826
|
+
nk_f32_t centroid_b_y = sum_b_y * inv_n;
|
|
1827
|
+
nk_f32_t centroid_b_z = sum_b_z * inv_n;
|
|
1828
|
+
|
|
1829
|
+
if (a_centroid) {
|
|
1830
|
+
a_centroid[0] = centroid_a_x;
|
|
1831
|
+
a_centroid[1] = centroid_a_y;
|
|
1832
|
+
a_centroid[2] = centroid_a_z;
|
|
1833
|
+
}
|
|
1834
|
+
if (b_centroid) {
|
|
1835
|
+
b_centroid[0] = centroid_b_x;
|
|
1836
|
+
b_centroid[1] = centroid_b_y;
|
|
1837
|
+
b_centroid[2] = centroid_b_z;
|
|
1838
|
+
}
|
|
1839
|
+
|
|
1840
|
+
// Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
|
|
1841
|
+
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
1842
|
+
covariance_x_y -= n * centroid_a_x * centroid_b_y;
|
|
1843
|
+
covariance_x_z -= n * centroid_a_x * centroid_b_z;
|
|
1844
|
+
covariance_y_x -= n * centroid_a_y * centroid_b_x;
|
|
1845
|
+
covariance_y_y -= n * centroid_a_y * centroid_b_y;
|
|
1846
|
+
covariance_y_z -= n * centroid_a_y * centroid_b_z;
|
|
1847
|
+
covariance_z_x -= n * centroid_a_z * centroid_b_x;
|
|
1848
|
+
covariance_z_y -= n * centroid_a_z * centroid_b_y;
|
|
1849
|
+
covariance_z_z -= n * centroid_a_z * centroid_b_z;
|
|
1850
|
+
|
|
1851
|
+
// Compute SVD and optimal rotation
|
|
1852
|
+
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
1853
|
+
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
1854
|
+
nk_f32_t svd_u[9], svd_s[9], svd_v[9];
|
|
1855
|
+
nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
|
|
1856
|
+
|
|
1857
|
+
// R = V * Uᵀ
|
|
1858
|
+
nk_f32_t r[9];
|
|
1859
|
+
r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
1860
|
+
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
1861
|
+
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
1862
|
+
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
1863
|
+
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
1864
|
+
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
1865
|
+
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
1866
|
+
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
1867
|
+
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1868
|
+
|
|
1869
|
+
// Handle reflection: if det(R) < 0, negate third column of V and recompute R
|
|
1870
|
+
if (nk_det3x3_f32_(r) < 0) {
|
|
1871
|
+
svd_v[2] = -svd_v[2];
|
|
1872
|
+
svd_v[5] = -svd_v[5];
|
|
1873
|
+
svd_v[8] = -svd_v[8];
|
|
1874
|
+
r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
1875
|
+
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
1876
|
+
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
1877
|
+
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
1878
|
+
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
1879
|
+
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
1880
|
+
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
1881
|
+
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
1882
|
+
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1883
|
+
}
|
|
1884
|
+
|
|
1885
|
+
/* Output rotation matrix and scale=1.0 */
|
|
1886
|
+
if (rotation) {
|
|
1887
|
+
for (int j = 0; j < 9; ++j) rotation[j] = r[j];
|
|
1888
|
+
}
|
|
1889
|
+
if (scale) *scale = 1.0f;
|
|
1890
|
+
|
|
1891
|
+
// Compute RMSD after optimal rotation
|
|
1892
|
+
nk_f32_t sum_squared = nk_transformed_ssd_bf16_haswell_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
|
|
1893
|
+
centroid_b_x, centroid_b_y, centroid_b_z);
|
|
1894
|
+
*result = nk_f32_sqrt_haswell(sum_squared * inv_n);
|
|
1895
|
+
}
|
|
1896
|
+
|
|
1897
|
+
/**
 * Umeyama similarity alignment (rotation + uniform scale) of two interleaved xyz point clouds stored as f16.
 *
 * A single fused pass accumulates everything in f32:
 *  - loads 8 points of `a` and `b` per iteration and widens them to f32;
 *  - accumulates the coordinate sums, the raw (uncentered) cross products Σ aᵢ·bⱼ, and Σ |a|²;
 *  - handles the remaining `n % 8` points with a scalar loop.
 *
 * After the pass, the moments are centered analytically. The 3x3 cross-covariance is then decomposed by
 * `nk_svd3x3_f32_`, and the rotation R = V·Uᵀ (with reflection correction), the scale `c`, and the RMSD are derived.
 *
 * @param a          Source points: `n` consecutive (x, y, z) f16 triplets.
 * @param b          Target points: `n` consecutive (x, y, z) f16 triplets.
 * @param n          Number of points. NOTE(review): n == 0 makes `inv_n` infinite and the outputs NaN;
 *                   presumably callers guarantee n > 0 — confirm.
 * @param a_centroid Optional (may be NULL): receives the 3 mean coordinates of `a`.
 * @param b_centroid Optional (may be NULL): receives the 3 mean coordinates of `b`.
 * @param rotation   Optional (may be NULL): receives the 9 entries of R, laid out as `r[row * 3 + col]`.
 * @param scale      Optional (may be NULL): receives the similarity scale factor `c`.
 * @param result     Receives sqrt(ssd / n). `ssd` is produced by `nk_transformed_ssd_f16_haswell_` from R, c and
 *                   both centroids (presumably Σ |c·R·(aᵢ − ā) − (bᵢ − b̄)|² — confirm).
 */
NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                      nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    // Fused single-pass: load f16, convert to f32, compute centroids, covariance, and variance
    __m256 const zeros_f32x8 = _mm256_setzero_ps();

    __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
    __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
    __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
    __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
    __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
    __m256 variance_a_f32x8 = zeros_f32x8;

    nk_size_t i = 0;
    __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;

    for (; i + 8 <= n; i += 8) {
        // Split 8 interleaved xyz points (24 halves) into per-axis f32 lanes
        nk_deinterleave_f16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
        nk_deinterleave_f16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);

        // Accumulate centroids
        sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
        sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
        sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
        sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
        sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
        sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);

        // Accumulate outer products
        cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
        cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
        cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
        cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
        cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
        cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
        cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
        cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
        cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);

        // Accumulate Σ |a|² (uncentered; centered below)
        variance_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, variance_a_f32x8);
        variance_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, variance_a_f32x8);
        variance_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, variance_a_f32x8);
    }

    // Reduce vector accumulators
    nk_f32_t sum_a_x = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
    nk_f32_t sum_a_y = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
    nk_f32_t sum_a_z = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
    nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
    nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
    nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
    nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
    nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
    nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
    nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
    nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
    nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
    nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
    nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
    nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
    nk_f32_t variance_a_sum = nk_reduce_add_f32x8_haswell_(variance_a_f32x8);

    // Scalar tail: the remaining n % 8 points
    for (; i < n; ++i) {
        nk_f32_t ax, ay, az, bx, by, bz;
        nk_f16_to_f32_haswell(&a[i * 3 + 0], &ax);
        nk_f16_to_f32_haswell(&a[i * 3 + 1], &ay);
        nk_f16_to_f32_haswell(&a[i * 3 + 2], &az);
        nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
        nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
        nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
        sum_a_x += ax;
        sum_a_y += ay;
        sum_a_z += az;
        sum_b_x += bx;
        sum_b_y += by;
        sum_b_z += bz;
        covariance_x_x += ax * bx;
        covariance_x_y += ax * by;
        covariance_x_z += ax * bz;
        covariance_y_x += ay * bx;
        covariance_y_y += ay * by;
        covariance_y_z += ay * bz;
        covariance_z_x += az * bx;
        covariance_z_y += az * by;
        covariance_z_z += az * bz;
        variance_a_sum += ax * ax + ay * ay + az * az;
    }

    // Compute centroids
    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
    nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;

    if (a_centroid) {
        a_centroid[0] = centroid_a_x;
        a_centroid[1] = centroid_a_y;
        a_centroid[2] = centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = centroid_b_x;
        b_centroid[1] = centroid_b_y;
        b_centroid[2] = centroid_b_z;
    }

    // Per-point variance of A: E[|a|²] − |ā|². The single-pass formula is prone to cancellation, so it can come
    // out as 0 or slightly negative when the points of `a` (nearly) coincide; the scale below guards for that.
    nk_f32_t variance_a = variance_a_sum * inv_n -
                          (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);

    // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
    covariance_x_x -= n * centroid_a_x * centroid_b_x;
    covariance_x_y -= n * centroid_a_x * centroid_b_y;
    covariance_x_z -= n * centroid_a_x * centroid_b_z;
    covariance_y_x -= n * centroid_a_y * centroid_b_x;
    covariance_y_y -= n * centroid_a_y * centroid_b_y;
    covariance_y_z -= n * centroid_a_y * centroid_b_z;
    covariance_z_x -= n * centroid_a_z * centroid_b_x;
    covariance_z_y -= n * centroid_a_z * centroid_b_y;
    covariance_z_z -= n * centroid_a_z * centroid_b_z;

    nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
                                    covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};

    // SVD; singular values are read from the diagonal slots 0, 4 and 8 of svd_s
    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);

    // R = V * Uᵀ, i.e. r[row][col] = Σ_k v[row][k] * u[col][k]
    nk_f32_t r[9];
    for (int row = 0; row < 3; ++row) {
        for (int col = 0; col < 3; ++col) {
            r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] + svd_v[row * 3 + 1] * svd_u[col * 3 + 1] +
                               svd_v[row * 3 + 2] * svd_u[col * 3 + 2];
        }
    }

    // Scale factor: c = trace(D × S) / (n × variance(a)), with D = diag(1, 1, sign(det R))
    nk_f32_t det = nk_det3x3_f32_(r);
    nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
    nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
    // Without this guard a non-positive (or NaN) variance yields 0/0, ±inf or a sign-flipped scale. When A has no
    // spread the fit is independent of c, so report the neutral scale 1 instead.
    nk_f32_t c = variance_a > 0.0f ? trace_ds / (n * variance_a) : 1.0f;
    if (scale) *scale = c;

    // Handle reflection: negate the third column of V and rebuild R so that det(R) = +1
    if (det < 0) {
        svd_v[2] = -svd_v[2];
        svd_v[5] = -svd_v[5];
        svd_v[8] = -svd_v[8];
        for (int row = 0; row < 3; ++row) {
            for (int col = 0; col < 3; ++col) {
                r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] +
                                   svd_v[row * 3 + 1] * svd_u[col * 3 + 1] + svd_v[row * 3 + 2] * svd_u[col * 3 + 2];
            }
        }
    }

    /* Output rotation matrix */
    if (rotation) {
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];
    }

    // Compute RMSD with scaling
    nk_f32_t sum_squared = nk_transformed_ssd_f16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
                                                           centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
}
|
|
2059
|
+
|
|
2060
|
+
/**
 * Umeyama similarity alignment (rotation + uniform scale) of two interleaved xyz point clouds stored as bf16.
 *
 * A single fused pass accumulates everything in f32:
 *  - loads 8 points of `a` and `b` per iteration and widens them to f32;
 *  - accumulates the coordinate sums, the raw (uncentered) cross products Σ aᵢ·bⱼ, and Σ |a|²;
 *  - handles the remaining `n % 8` points with a scalar loop.
 *
 * After the pass, the moments are centered analytically. The 3x3 cross-covariance is then decomposed by
 * `nk_svd3x3_f32_`, and the rotation R = V·Uᵀ (with reflection correction), the scale `c`, and the RMSD are derived.
 *
 * @param a          Source points: `n` consecutive (x, y, z) bf16 triplets.
 * @param b          Target points: `n` consecutive (x, y, z) bf16 triplets.
 * @param n          Number of points. NOTE(review): n == 0 makes `inv_n` infinite and the outputs NaN;
 *                   presumably callers guarantee n > 0 — confirm.
 * @param a_centroid Optional (may be NULL): receives the 3 mean coordinates of `a`.
 * @param b_centroid Optional (may be NULL): receives the 3 mean coordinates of `b`.
 * @param rotation   Optional (may be NULL): receives the 9 entries of R, laid out as `r[row * 3 + col]`.
 * @param scale      Optional (may be NULL): receives the similarity scale factor `c`.
 * @param result     Receives sqrt(ssd / n). `ssd` is produced by `nk_transformed_ssd_bf16_haswell_` from R, c and
 *                   both centroids (presumably Σ |c·R·(aᵢ − ā) − (bᵢ − b̄)|² — confirm).
 */
NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                       nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    // Fused single-pass: load bf16, convert to f32, compute centroids, covariance, and variance
    __m256 const zeros_f32x8 = _mm256_setzero_ps();

    __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
    __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
    __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
    __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
    __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
    __m256 variance_a_f32x8 = zeros_f32x8;

    nk_size_t i = 0;
    __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;

    for (; i + 8 <= n; i += 8) {
        // Split 8 interleaved xyz points (24 bf16 values) into per-axis f32 lanes
        nk_deinterleave_bf16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
        nk_deinterleave_bf16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);

        // Accumulate centroids
        sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
        sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
        sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
        sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
        sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
        sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);

        // Accumulate outer products
        cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
        cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
        cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
        cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
        cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
        cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
        cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
        cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
        cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);

        // Accumulate Σ |a|² (uncentered; centered below)
        variance_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, variance_a_f32x8);
        variance_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, variance_a_f32x8);
        variance_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, variance_a_f32x8);
    }

    // Reduce vector accumulators
    nk_f32_t sum_a_x = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
    nk_f32_t sum_a_y = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
    nk_f32_t sum_a_z = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
    nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
    nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
    nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
    nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
    nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
    nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
    nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
    nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
    nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
    nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
    nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
    nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
    nk_f32_t variance_a_sum = nk_reduce_add_f32x8_haswell_(variance_a_f32x8);

    // Scalar tail: the remaining n % 8 points
    for (; i < n; ++i) {
        nk_f32_t ax, ay, az, bx, by, bz;
        nk_bf16_to_f32_serial(&a[i * 3 + 0], &ax);
        nk_bf16_to_f32_serial(&a[i * 3 + 1], &ay);
        nk_bf16_to_f32_serial(&a[i * 3 + 2], &az);
        nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
        nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
        nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
        sum_a_x += ax;
        sum_a_y += ay;
        sum_a_z += az;
        sum_b_x += bx;
        sum_b_y += by;
        sum_b_z += bz;
        covariance_x_x += ax * bx;
        covariance_x_y += ax * by;
        covariance_x_z += ax * bz;
        covariance_y_x += ay * bx;
        covariance_y_y += ay * by;
        covariance_y_z += ay * bz;
        covariance_z_x += az * bx;
        covariance_z_y += az * by;
        covariance_z_z += az * bz;
        variance_a_sum += ax * ax + ay * ay + az * az;
    }

    // Compute centroids
    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
    nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;

    if (a_centroid) {
        a_centroid[0] = centroid_a_x;
        a_centroid[1] = centroid_a_y;
        a_centroid[2] = centroid_a_z;
    }
    if (b_centroid) {
        b_centroid[0] = centroid_b_x;
        b_centroid[1] = centroid_b_y;
        b_centroid[2] = centroid_b_z;
    }

    // Per-point variance of A: E[|a|²] − |ā|². The single-pass formula is prone to cancellation, so it can come
    // out as 0 or slightly negative when the points of `a` (nearly) coincide; the scale below guards for that.
    nk_f32_t variance_a = variance_a_sum * inv_n -
                          (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);

    // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
    covariance_x_x -= n * centroid_a_x * centroid_b_x;
    covariance_x_y -= n * centroid_a_x * centroid_b_y;
    covariance_x_z -= n * centroid_a_x * centroid_b_z;
    covariance_y_x -= n * centroid_a_y * centroid_b_x;
    covariance_y_y -= n * centroid_a_y * centroid_b_y;
    covariance_y_z -= n * centroid_a_y * centroid_b_z;
    covariance_z_x -= n * centroid_a_z * centroid_b_x;
    covariance_z_y -= n * centroid_a_z * centroid_b_y;
    covariance_z_z -= n * centroid_a_z * centroid_b_z;

    nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
                                    covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};

    // SVD; singular values are read from the diagonal slots 0, 4 and 8 of svd_s
    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);

    // R = V * Uᵀ, i.e. r[row][col] = Σ_k v[row][k] * u[col][k]
    nk_f32_t r[9];
    for (int row = 0; row < 3; ++row) {
        for (int col = 0; col < 3; ++col) {
            r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] + svd_v[row * 3 + 1] * svd_u[col * 3 + 1] +
                               svd_v[row * 3 + 2] * svd_u[col * 3 + 2];
        }
    }

    // Scale factor: c = trace(D × S) / (n × variance(a)), with D = diag(1, 1, sign(det R))
    nk_f32_t det = nk_det3x3_f32_(r);
    nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
    nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
    // Without this guard a non-positive (or NaN) variance yields 0/0, ±inf or a sign-flipped scale. When A has no
    // spread the fit is independent of c, so report the neutral scale 1 instead.
    nk_f32_t c = variance_a > 0.0f ? trace_ds / (n * variance_a) : 1.0f;
    if (scale) *scale = c;

    // Handle reflection: negate the third column of V and rebuild R so that det(R) = +1
    if (det < 0) {
        svd_v[2] = -svd_v[2];
        svd_v[5] = -svd_v[5];
        svd_v[8] = -svd_v[8];
        for (int row = 0; row < 3; ++row) {
            for (int col = 0; col < 3; ++col) {
                r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] +
                                   svd_v[row * 3 + 1] * svd_u[col * 3 + 1] + svd_v[row * 3 + 2] * svd_u[col * 3 + 2];
            }
        }
    }

    /* Output rotation matrix */
    if (rotation) {
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];
    }

    // Compute RMSD with scaling
    nk_f32_t sum_squared = nk_transformed_ssd_bf16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
                                                            centroid_b_x, centroid_b_y, centroid_b_z);
    *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
}
|
|
2222
|
+
|
|
2223
|
+
#if defined(__clang__)
|
|
2224
|
+
#pragma clang attribute pop
|
|
2225
|
+
#elif defined(__GNUC__)
|
|
2226
|
+
#pragma GCC pop_options
|
|
2227
|
+
#endif
|
|
2228
|
+
|
|
2229
|
+
#if defined(__cplusplus)
|
|
2230
|
+
} // extern "C"
|
|
2231
|
+
#endif
|
|
2232
|
+
|
|
2233
|
+
#endif // NK_TARGET_HASWELL
|
|
2234
|
+
#endif // NK_TARGET_X86_
|
|
2235
|
+
#endif // NK_MESH_HASWELL_H
|