numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,2235 @@
+ /**
+  * @brief SIMD-accelerated Point Cloud Alignment for Haswell.
+  * @file include/numkong/mesh/haswell.h
+  * @author Ash Vardanian
+  * @date December 27, 2025
+  *
+  * @sa include/numkong/mesh.h
+  *
+  * @section haswell_mesh_instructions Key AVX2 Mesh Instructions
+  *
+  *     Intrinsic               Instruction                       Latency  Throughput  Ports
+  *     _mm256_fmadd_ps         VFMADD (YMM, YMM, YMM)            5cy      0.5/cy      p01
+  *     _mm256_hadd_ps          VHADDPS (YMM, YMM, YMM)           7cy      0.5/cy      p01+p5
+  *     _mm256_permute2f128_ps  VPERM2F128 (YMM, YMM, YMM, I8)    3cy      1/cy        p5
+  *     _mm256_extractf128_ps   VEXTRACTF128 (XMM, YMM, I8)       3cy      1/cy        p5
+  *     _mm256_i32gather_ps     VGATHERDPS (YMM, M, YMM, YMM)     12cy     5/cy        p0+p23
+  *
+  * Point cloud operations (centroid, covariance, Kabsch alignment) use gather instructions for
+  * stride-3 xyz deinterleaving. Multiple FMA accumulators hide the 5-cycle FMA latency. VHADDPS
+  * interleaves results across lanes, requiring additional shuffles for final scalar reduction.
+  */
+ #ifndef NK_MESH_HASWELL_H
+ #define NK_MESH_HASWELL_H
+
+ #if NK_TARGET_X86_
+ #if NK_TARGET_HASWELL
+
+ #include "numkong/types.h"
+ #include "numkong/dot/haswell.h"
+ #include "numkong/mesh/serial.h"
+ #include "numkong/reduce/haswell.h"  // `nk_reduce_add_f32x8_haswell_`
+ #include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`, `nk_f64_sqrt_haswell`
+
+ #if defined(__cplusplus)
+ extern "C" {
+ #endif
+
+ #if defined(__clang__)
+ #pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
+ #elif defined(__GNUC__)
+ #pragma GCC push_options
+ #pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
+ #endif
+
+ /* Deinterleave 24 floats (8 xyz triplets) into separate x, y, z vectors.
+  * Uses AVX2 gather instructions for clean stride-3 access.
+  *
+  * Input: 24 contiguous floats [x0,y0,z0, x1,y1,z1, ..., x7,y7,z7]
+  * Output: x[8], y[8], z[8] vectors
+  */
+ NK_INTERNAL void nk_deinterleave_f32x8_haswell_(nk_f32_t const *ptr, __m256 *x_out, __m256 *y_out, __m256 *z_out) {
+     // Gather indices: 0, 3, 6, 9, 12, 15, 18, 21 (stride 3)
+     __m256i idx = _mm256_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21);
+     *x_out = _mm256_i32gather_ps(ptr + 0, idx, 4);
+     *y_out = _mm256_i32gather_ps(ptr + 1, idx, 4);
+     *z_out = _mm256_i32gather_ps(ptr + 2, idx, 4);
+ }
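+
+ /* For reference, the lane structure the header comment describes: VHADDPS only
+  * adds within 128-bit halves, so one cross-lane step is still needed before the
+  * horizontal sums. A minimal illustrative sketch of such a reduction follows;
+  * the helper this file actually uses, `nk_reduce_add_f32x8_haswell_` from
+  * reduce/haswell.h, may be implemented differently. */
+ NK_INTERNAL nk_f32_t nk_example_reduce_add_f32x8_(__m256 values_f32x8) {
+     __m128 lower_f32x4 = _mm256_castps256_ps128(values_f32x8);   // lanes 0..3, zero-cost cast
+     __m128 upper_f32x4 = _mm256_extractf128_ps(values_f32x8, 1); // lanes 4..7, the cross-lane step
+     __m128 sums_f32x4 = _mm_add_ps(lower_f32x4, upper_f32x4);    // 4 partial sums
+     sums_f32x4 = _mm_hadd_ps(sums_f32x4, sums_f32x4);            // 2 partial sums
+     sums_f32x4 = _mm_hadd_ps(sums_f32x4, sums_f32x4);            // total in every slot
+     return _mm_cvtss_f32(sums_f32x4);
+ }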
+
+ /* Deinterleave 12 f64 values (4 xyz triplets) into separate x, y, z vectors.
+  * Uses scalar loads: AVX2 does offer `_mm256_i32gather_pd`, but Haswell's
+  * first-generation gather is microcoded and no faster than four plain loads,
+  * so the simple form wins for a 4-element stride-3 pattern.
+  *
+  * Input: 12 contiguous f64 [x0,y0,z0, x1,y1,z1, x2,y2,z2, x3,y3,z3]
+  * Output: x[4], y[4], z[4] vectors
+  */
+ NK_INTERNAL void nk_deinterleave_f64x4_haswell_(nk_f64_t const *ptr, __m256d *x_out, __m256d *y_out, __m256d *z_out) {
+     nk_f64_t x0 = ptr[0], x1 = ptr[3], x2 = ptr[6], x3 = ptr[9];
+     nk_f64_t y0 = ptr[1], y1 = ptr[4], y2 = ptr[7], y3 = ptr[10];
+     nk_f64_t z0 = ptr[2], z1 = ptr[5], z2 = ptr[8], z3 = ptr[11];
+
+     *x_out = _mm256_setr_pd(x0, x1, x2, x3);
+     *y_out = _mm256_setr_pd(y0, y1, y2, y3);
+     *z_out = _mm256_setr_pd(z0, z1, z2, z3);
+ }
+
+ /* Horizontal reduction helpers moved to reduce.h:
+  * - nk_reduce_add_f32x8_haswell_
+  * - nk_reduce_add_f64x4_haswell_
+  */
+
+ NK_INTERNAL nk_f64_t nk_reduce_stable_f64x4_haswell_(__m256d values_f64x4) {
+     nk_b256_vec_t values;
+     values.ymm_pd = values_f64x4;
+     nk_f64_t sum = 0.0, compensation = 0.0;
+     nk_accumulate_sum_f64_(&sum, &compensation, values.f64s[0]);
+     nk_accumulate_sum_f64_(&sum, &compensation, values.f64s[1]);
+     nk_accumulate_sum_f64_(&sum, &compensation, values.f64s[2]);
+     nk_accumulate_sum_f64_(&sum, &compensation, values.f64s[3]);
+     return sum + compensation;
+ }
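+
+ /* `nk_accumulate_sum_f64_` is used above as a compensated accumulator; its
+  * definition lives elsewhere in the library. As a point of reference only, a
+  * Neumaier-style scalar version looks like the sketch below (the shipped
+  * helper may differ in detail). */
+ NK_INTERNAL void nk_example_accumulate_sum_f64_(nk_f64_t *sum, nk_f64_t *compensation, nk_f64_t value) {
+     nk_f64_t tentative = *sum + value;
+     nk_f64_t sum_magnitude = *sum < 0 ? -*sum : *sum;
+     nk_f64_t value_magnitude = value < 0 ? -value : value;
+     // The smaller-magnitude addend is the one that lost low-order bits.
+     if (sum_magnitude >= value_magnitude) *compensation += (*sum - tentative) + value;
+     else *compensation += (value - tentative) + *sum;
+     *sum = tentative;
+ }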
+
+ NK_INTERNAL void nk_rotation_from_svd_f64_haswell_(nk_f64_t const *svd_u, nk_f64_t const *svd_v, nk_f64_t *rotation) {
+     nk_rotation_from_svd_f64_serial_(svd_u, svd_v, rotation);
+ }
+
+ NK_INTERNAL void nk_accumulate_square_f64x4_haswell_(__m256d *sum_f64x4, __m256d *compensation_f64x4,
+                                                      __m256d values_f64x4) {
+     __m256d product_f64x4 = _mm256_mul_pd(values_f64x4, values_f64x4);
+     __m256d product_error_f64x4 = _mm256_fmsub_pd(values_f64x4, values_f64x4, product_f64x4);
+     __m256d tentative_sum_f64x4 = _mm256_add_pd(*sum_f64x4, product_f64x4);
+     __m256d virtual_addend_f64x4 = _mm256_sub_pd(tentative_sum_f64x4, *sum_f64x4);
+     __m256d sum_error_f64x4 = _mm256_add_pd(
+         _mm256_sub_pd(*sum_f64x4, _mm256_sub_pd(tentative_sum_f64x4, virtual_addend_f64x4)),
+         _mm256_sub_pd(product_f64x4, virtual_addend_f64x4));
+     *sum_f64x4 = tentative_sum_f64x4;
+     *compensation_f64x4 = _mm256_add_pd(*compensation_f64x4, _mm256_add_pd(sum_error_f64x4, product_error_f64x4));
+ }
+
+ /* Compute sum of squared distances after applying rotation (and optional scale).
+  * Used by kabsch (scale=1.0) and umeyama (scale=computed_scale).
+  * Returns sum_squared, caller computes sqrt(sum_squared / n).
+  */
+ NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_haswell_(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
+                                                      nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
+                                                      nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
+                                                      nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
+                                                      nk_f64_t centroid_b_z) {
+     __m256d scaled_rotation_x_x_f64x4 = _mm256_set1_pd(scale * r[0]);
+     __m256d scaled_rotation_x_y_f64x4 = _mm256_set1_pd(scale * r[1]);
+     __m256d scaled_rotation_x_z_f64x4 = _mm256_set1_pd(scale * r[2]);
+     __m256d scaled_rotation_y_x_f64x4 = _mm256_set1_pd(scale * r[3]);
+     __m256d scaled_rotation_y_y_f64x4 = _mm256_set1_pd(scale * r[4]);
+     __m256d scaled_rotation_y_z_f64x4 = _mm256_set1_pd(scale * r[5]);
+     __m256d scaled_rotation_z_x_f64x4 = _mm256_set1_pd(scale * r[6]);
+     __m256d scaled_rotation_z_y_f64x4 = _mm256_set1_pd(scale * r[7]);
+     __m256d scaled_rotation_z_z_f64x4 = _mm256_set1_pd(scale * r[8]);
+     __m256d centroid_a_x_f64x4 = _mm256_set1_pd(centroid_a_x), centroid_a_y_f64x4 = _mm256_set1_pd(centroid_a_y);
+     __m256d centroid_a_z_f64x4 = _mm256_set1_pd(centroid_a_z), centroid_b_x_f64x4 = _mm256_set1_pd(centroid_b_x);
+     __m256d centroid_b_y_f64x4 = _mm256_set1_pd(centroid_b_y), centroid_b_z_f64x4 = _mm256_set1_pd(centroid_b_z);
+     __m256d sum_squared_f64x4 = _mm256_setzero_pd();
+     __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
+     nk_size_t index = 0;
+
+     for (; index + 8 <= n; index += 8) {
+         nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
+             nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
+
+         __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
+         __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
+         __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
+         __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
+         __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
+         __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
+         __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
+         __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
+         __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
+         __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
+         __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
+         __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
+
+         __m256d centered_a_x_lower_f64x4 = _mm256_sub_pd(a_x_lower_f64x4, centroid_a_x_f64x4);
+         __m256d centered_a_x_upper_f64x4 = _mm256_sub_pd(a_x_upper_f64x4, centroid_a_x_f64x4);
+         __m256d centered_a_y_lower_f64x4 = _mm256_sub_pd(a_y_lower_f64x4, centroid_a_y_f64x4);
+         __m256d centered_a_y_upper_f64x4 = _mm256_sub_pd(a_y_upper_f64x4, centroid_a_y_f64x4);
+         __m256d centered_a_z_lower_f64x4 = _mm256_sub_pd(a_z_lower_f64x4, centroid_a_z_f64x4);
+         __m256d centered_a_z_upper_f64x4 = _mm256_sub_pd(a_z_upper_f64x4, centroid_a_z_f64x4);
+         __m256d centered_b_x_lower_f64x4 = _mm256_sub_pd(b_x_lower_f64x4, centroid_b_x_f64x4);
+         __m256d centered_b_x_upper_f64x4 = _mm256_sub_pd(b_x_upper_f64x4, centroid_b_x_f64x4);
+         __m256d centered_b_y_lower_f64x4 = _mm256_sub_pd(b_y_lower_f64x4, centroid_b_y_f64x4);
+         __m256d centered_b_y_upper_f64x4 = _mm256_sub_pd(b_y_upper_f64x4, centroid_b_y_f64x4);
+         __m256d centered_b_z_lower_f64x4 = _mm256_sub_pd(b_z_lower_f64x4, centroid_b_z_f64x4);
+         __m256d centered_b_z_upper_f64x4 = _mm256_sub_pd(b_z_upper_f64x4, centroid_b_z_f64x4);
+
+         __m256d rotated_a_x_lower_f64x4 = _mm256_fmadd_pd(
+             scaled_rotation_x_z_f64x4, centered_a_z_lower_f64x4,
+             _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_lower_f64x4,
+                             _mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_lower_f64x4)));
+         __m256d rotated_a_x_upper_f64x4 = _mm256_fmadd_pd(
+             scaled_rotation_x_z_f64x4, centered_a_z_upper_f64x4,
+             _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_upper_f64x4,
+                             _mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_upper_f64x4)));
+         __m256d rotated_a_y_lower_f64x4 = _mm256_fmadd_pd(
+             scaled_rotation_y_z_f64x4, centered_a_z_lower_f64x4,
+             _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_lower_f64x4,
+                             _mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_lower_f64x4)));
+         __m256d rotated_a_y_upper_f64x4 = _mm256_fmadd_pd(
+             scaled_rotation_y_z_f64x4, centered_a_z_upper_f64x4,
+             _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_upper_f64x4,
+                             _mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_upper_f64x4)));
+         __m256d rotated_a_z_lower_f64x4 = _mm256_fmadd_pd(
+             scaled_rotation_z_z_f64x4, centered_a_z_lower_f64x4,
+             _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_lower_f64x4,
+                             _mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_lower_f64x4)));
+         __m256d rotated_a_z_upper_f64x4 = _mm256_fmadd_pd(
+             scaled_rotation_z_z_f64x4, centered_a_z_upper_f64x4,
+             _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_upper_f64x4,
+                             _mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_upper_f64x4)));
+
+         __m256d delta_x_lower_f64x4 = _mm256_sub_pd(rotated_a_x_lower_f64x4, centered_b_x_lower_f64x4);
+         __m256d delta_x_upper_f64x4 = _mm256_sub_pd(rotated_a_x_upper_f64x4, centered_b_x_upper_f64x4);
+         __m256d delta_y_lower_f64x4 = _mm256_sub_pd(rotated_a_y_lower_f64x4, centered_b_y_lower_f64x4);
+         __m256d delta_y_upper_f64x4 = _mm256_sub_pd(rotated_a_y_upper_f64x4, centered_b_y_upper_f64x4);
+         __m256d delta_z_lower_f64x4 = _mm256_sub_pd(rotated_a_z_lower_f64x4, centered_b_z_lower_f64x4);
+         __m256d delta_z_upper_f64x4 = _mm256_sub_pd(rotated_a_z_upper_f64x4, centered_b_z_upper_f64x4);
+
+         __m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_lower_f64x4, delta_x_lower_f64x4),
+                                                         _mm256_mul_pd(delta_x_upper_f64x4, delta_x_upper_f64x4));
+         batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_lower_f64x4, delta_y_lower_f64x4, batch_sum_squared_f64x4);
+         batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_upper_f64x4, delta_y_upper_f64x4, batch_sum_squared_f64x4);
+         batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_lower_f64x4, delta_z_lower_f64x4, batch_sum_squared_f64x4);
+         batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_upper_f64x4, delta_z_upper_f64x4, batch_sum_squared_f64x4);
+         sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
+     }
+
+     nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);
+     for (; index < n; ++index) {
+         nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x;
+         nk_f64_t centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y;
+         nk_f64_t centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
+         nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x;
+         nk_f64_t centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y;
+         nk_f64_t centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
+         nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z);
+         nk_f64_t rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z);
+         nk_f64_t rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
+         nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
+                  delta_z = rotated_a_z - centered_b_z;
+         sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
+     }
+
+     return sum_squared;
+ }
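+
+ /* Note on conventions in the helper above: `r` is a row-major 3x3 rotation, so
+  * rotated = scale * R * (a_i - centroid_a) is compared against b_i - centroid_b,
+  * and the return value is the raw sum of squared residuals; dividing by n and
+  * taking the square root (as the callers below do) yields the aligned RMSD. */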
+
+ /* Compute sum of squared distances for f64 after applying rotation (and optional scale).
+  * Rotation matrix, scale and data are all f64 for full precision.
+  */
+ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_haswell_(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
+                                                      nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
+                                                      nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
+                                                      nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
+                                                      nk_f64_t centroid_b_z) {
+     // Broadcast scaled rotation matrix elements
+     __m256d scaled_rotation_x_x_f64x4 = _mm256_set1_pd(scale * r[0]);
+     __m256d scaled_rotation_x_y_f64x4 = _mm256_set1_pd(scale * r[1]);
+     __m256d scaled_rotation_x_z_f64x4 = _mm256_set1_pd(scale * r[2]);
+     __m256d scaled_rotation_y_x_f64x4 = _mm256_set1_pd(scale * r[3]);
+     __m256d scaled_rotation_y_y_f64x4 = _mm256_set1_pd(scale * r[4]);
+     __m256d scaled_rotation_y_z_f64x4 = _mm256_set1_pd(scale * r[5]);
+     __m256d scaled_rotation_z_x_f64x4 = _mm256_set1_pd(scale * r[6]);
+     __m256d scaled_rotation_z_y_f64x4 = _mm256_set1_pd(scale * r[7]);
+     __m256d scaled_rotation_z_z_f64x4 = _mm256_set1_pd(scale * r[8]);
+
+     // Broadcast centroids
+     __m256d centroid_a_x_f64x4 = _mm256_set1_pd(centroid_a_x);
+     __m256d centroid_a_y_f64x4 = _mm256_set1_pd(centroid_a_y);
+     __m256d centroid_a_z_f64x4 = _mm256_set1_pd(centroid_a_z);
+     __m256d centroid_b_x_f64x4 = _mm256_set1_pd(centroid_b_x);
+     __m256d centroid_b_y_f64x4 = _mm256_set1_pd(centroid_b_y);
+     __m256d centroid_b_z_f64x4 = _mm256_set1_pd(centroid_b_z);
+
+     __m256d sum_squared_f64x4 = _mm256_setzero_pd();
+     __m256d sum_squared_compensation_f64x4 = _mm256_setzero_pd();
+     __m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
+     nk_size_t j = 0;
+
+     for (; j + 4 <= n; j += 4) {
+         nk_deinterleave_f64x4_haswell_(a + j * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
+         nk_deinterleave_f64x4_haswell_(b + j * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
+
+         // Center points
+         __m256d pa_x_f64x4 = _mm256_sub_pd(a_x_f64x4, centroid_a_x_f64x4);
+         __m256d pa_y_f64x4 = _mm256_sub_pd(a_y_f64x4, centroid_a_y_f64x4);
+         __m256d pa_z_f64x4 = _mm256_sub_pd(a_z_f64x4, centroid_a_z_f64x4);
+         __m256d pb_x_f64x4 = _mm256_sub_pd(b_x_f64x4, centroid_b_x_f64x4);
+         __m256d pb_y_f64x4 = _mm256_sub_pd(b_y_f64x4, centroid_b_y_f64x4);
+         __m256d pb_z_f64x4 = _mm256_sub_pd(b_z_f64x4, centroid_b_z_f64x4);
+
+         // Rotate and scale: ra = scale * R * pa
+         __m256d ra_x_f64x4 = _mm256_fmadd_pd(scaled_rotation_x_z_f64x4, pa_z_f64x4,
+                                              _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, pa_y_f64x4,
+                                                              _mm256_mul_pd(scaled_rotation_x_x_f64x4, pa_x_f64x4)));
+         __m256d ra_y_f64x4 = _mm256_fmadd_pd(scaled_rotation_y_z_f64x4, pa_z_f64x4,
+                                              _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, pa_y_f64x4,
+                                                              _mm256_mul_pd(scaled_rotation_y_x_f64x4, pa_x_f64x4)));
+         __m256d ra_z_f64x4 = _mm256_fmadd_pd(scaled_rotation_z_z_f64x4, pa_z_f64x4,
+                                              _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, pa_y_f64x4,
+                                                              _mm256_mul_pd(scaled_rotation_z_x_f64x4, pa_x_f64x4)));
+
+         // Delta and accumulate
+         __m256d delta_x_f64x4 = _mm256_sub_pd(ra_x_f64x4, pb_x_f64x4);
+         __m256d delta_y_f64x4 = _mm256_sub_pd(ra_y_f64x4, pb_y_f64x4);
+         __m256d delta_z_f64x4 = _mm256_sub_pd(ra_z_f64x4, pb_z_f64x4);
+
+         nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_x_f64x4);
+         nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_y_f64x4);
+         nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_z_f64x4);
+     }
+
+     nk_f64_t sum_squared = nk_dot_stable_sum_f64x4_haswell_(sum_squared_f64x4, sum_squared_compensation_f64x4);
+     nk_f64_t sum_squared_compensation = 0.0;
+
+     // Scalar tail
+     for (; j < n; ++j) {
+         nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x;
+         nk_f64_t pa_y = a[j * 3 + 1] - centroid_a_y;
+         nk_f64_t pa_z = a[j * 3 + 2] - centroid_a_z;
+         nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x;
+         nk_f64_t pb_y = b[j * 3 + 1] - centroid_b_y;
+         nk_f64_t pb_z = b[j * 3 + 2] - centroid_b_z;
+
+         nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
+         nk_f64_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
+         nk_f64_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
+
+         nk_f64_t delta_x = ra_x - pb_x;
+         nk_f64_t delta_y = ra_y - pb_y;
+         nk_f64_t delta_z = ra_z - pb_z;
+         nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
+         nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
+         nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
+     }
+
+     return sum_squared + sum_squared_compensation;
+ }
+
+ NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+                                    nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
+     if (rotation)
+         rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
+         rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
+     if (scale) *scale = 1.0f;
+
+     __m256d sum_a_x_f64x4 = _mm256_setzero_pd(), sum_a_y_f64x4 = _mm256_setzero_pd();
+     __m256d sum_a_z_f64x4 = _mm256_setzero_pd(), sum_b_x_f64x4 = _mm256_setzero_pd();
+     __m256d sum_b_y_f64x4 = _mm256_setzero_pd(), sum_b_z_f64x4 = _mm256_setzero_pd();
+     __m256d sum_squared_f64x4 = _mm256_setzero_pd();
+     __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
+     nk_size_t index = 0;
+
+     for (; index + 8 <= n; index += 8) {
+         nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
+             nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
+
+         __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
+         __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
+         __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
+         __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
+         __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
+         __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
+         __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
+         __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
+         __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
+         __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
+         __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
+         __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
+
+         sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_lower_f64x4, a_x_upper_f64x4));
+         sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_lower_f64x4, a_y_upper_f64x4));
+         sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_lower_f64x4, a_z_upper_f64x4));
+         sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_lower_f64x4, b_x_upper_f64x4));
+         sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_lower_f64x4, b_y_upper_f64x4));
+         sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_lower_f64x4, b_z_upper_f64x4));
+
+         __m256d delta_x_lower_f64x4 = _mm256_sub_pd(a_x_lower_f64x4, b_x_lower_f64x4);
+         __m256d delta_x_upper_f64x4 = _mm256_sub_pd(a_x_upper_f64x4, b_x_upper_f64x4);
+         __m256d delta_y_lower_f64x4 = _mm256_sub_pd(a_y_lower_f64x4, b_y_lower_f64x4);
+         __m256d delta_y_upper_f64x4 = _mm256_sub_pd(a_y_upper_f64x4, b_y_upper_f64x4);
+         __m256d delta_z_lower_f64x4 = _mm256_sub_pd(a_z_lower_f64x4, b_z_lower_f64x4);
+         __m256d delta_z_upper_f64x4 = _mm256_sub_pd(a_z_upper_f64x4, b_z_upper_f64x4);
+         __m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_lower_f64x4, delta_x_lower_f64x4),
+                                                         _mm256_mul_pd(delta_x_upper_f64x4, delta_x_upper_f64x4));
+         batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_lower_f64x4, delta_y_lower_f64x4, batch_sum_squared_f64x4);
+         batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_upper_f64x4, delta_y_upper_f64x4, batch_sum_squared_f64x4);
+         batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_lower_f64x4, delta_z_lower_f64x4, batch_sum_squared_f64x4);
+         batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_upper_f64x4, delta_z_upper_f64x4, batch_sum_squared_f64x4);
+         sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
+     }
+
+     nk_f64_t total_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
+     nk_f64_t total_a_y = nk_reduce_add_f64x4_haswell_(sum_a_y_f64x4);
+     nk_f64_t total_a_z = nk_reduce_add_f64x4_haswell_(sum_a_z_f64x4);
+     nk_f64_t total_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
+     nk_f64_t total_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
+     nk_f64_t total_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
+     nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);
+
+     for (; index < n; ++index) {
+         nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
+         nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
+         total_a_x += a_x, total_a_y += a_y, total_a_z += a_z;
+         total_b_x += b_x, total_b_y += b_y, total_b_z += b_z;
+         nk_f64_t delta_x = a_x - b_x, delta_y = a_y - b_y, delta_z = a_z - b_z;
+         sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
+     }
+
+     nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
+     nk_f64_t centroid_a_x = total_a_x * inv_n, centroid_a_y = total_a_y * inv_n, centroid_a_z = total_a_z * inv_n;
+     nk_f64_t centroid_b_x = total_b_x * inv_n, centroid_b_y = total_b_y * inv_n, centroid_b_z = total_b_z * inv_n;
+     if (a_centroid)
+         a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
+         a_centroid[2] = (nk_f32_t)centroid_a_z;
+     if (b_centroid)
+         b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
+         b_centroid[2] = (nk_f32_t)centroid_b_z;
+
+     nk_f64_t mean_delta_x = centroid_a_x - centroid_b_x, mean_delta_y = centroid_a_y - centroid_b_y,
+              mean_delta_z = centroid_a_z - centroid_b_z;
+     nk_f64_t mean_delta_squared = mean_delta_x * mean_delta_x + mean_delta_y * mean_delta_y +
+                                   mean_delta_z * mean_delta_z;
+     *result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_delta_squared);
+ }
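+
+ /* Why `sum_squared * inv_n - mean_delta_squared` above: with identity rotation,
+  * the optimal translation superimposes the centroids, so the residual of point
+  * i is d_i - d_mean with d_i = a_i - b_i and d_mean = centroid_a - centroid_b.
+  * Expanding the square,
+  *     (1/n) * sum_i |d_i - d_mean|^2 = (1/n) * sum_i |d_i|^2 - |d_mean|^2,
+  * which is exactly the expression passed to the square root. */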
+
+ NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
+                                    nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
+     /* RMSD uses identity rotation and scale=1.0 */
+     if (rotation) {
+         rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
+         rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
+         rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
+     }
+     if (scale) *scale = 1.0;
+     __m256d const zeros_f64x4 = _mm256_setzero_pd();
+
+     // Accumulators for centroids and squared differences
+     __m256d sum_a_x_f64x4 = zeros_f64x4, sum_a_y_f64x4 = zeros_f64x4, sum_a_z_f64x4 = zeros_f64x4;
+     __m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
+     __m256d sum_squared_x_f64x4 = zeros_f64x4, sum_squared_y_f64x4 = zeros_f64x4, sum_squared_z_f64x4 = zeros_f64x4;
+
+     __m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
+     nk_size_t i = 0;
+
+     // Main loop with 2x unrolling
+     for (; i + 8 <= n; i += 8) {
+         // Iteration 0
+         nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
+         nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
+
+         sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
+         sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
+         sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
+         sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
+         sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
+         sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
+
+         __m256d delta_x_f64x4 = _mm256_sub_pd(a_x_f64x4, b_x_f64x4);
+         __m256d delta_y_f64x4 = _mm256_sub_pd(a_y_f64x4, b_y_f64x4);
+         __m256d delta_z_f64x4 = _mm256_sub_pd(a_z_f64x4, b_z_f64x4);
+
+         sum_squared_x_f64x4 = _mm256_fmadd_pd(delta_x_f64x4, delta_x_f64x4, sum_squared_x_f64x4);
+         sum_squared_y_f64x4 = _mm256_fmadd_pd(delta_y_f64x4, delta_y_f64x4, sum_squared_y_f64x4);
+         sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z_f64x4, delta_z_f64x4, sum_squared_z_f64x4);
+
+         // Iteration 1
+         __m256d a_x1_f64x4, a_y1_f64x4, a_z1_f64x4, b_x1_f64x4, b_y1_f64x4, b_z1_f64x4;
+         nk_deinterleave_f64x4_haswell_(a + (i + 4) * 3, &a_x1_f64x4, &a_y1_f64x4, &a_z1_f64x4);
+         nk_deinterleave_f64x4_haswell_(b + (i + 4) * 3, &b_x1_f64x4, &b_y1_f64x4, &b_z1_f64x4);
+
+         sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x1_f64x4);
+         sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y1_f64x4);
+         sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z1_f64x4);
+         sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x1_f64x4);
+         sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y1_f64x4);
+         sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z1_f64x4);
+
+         __m256d delta_x1_f64x4 = _mm256_sub_pd(a_x1_f64x4, b_x1_f64x4);
+         __m256d delta_y1_f64x4 = _mm256_sub_pd(a_y1_f64x4, b_y1_f64x4);
+         __m256d delta_z1_f64x4 = _mm256_sub_pd(a_z1_f64x4, b_z1_f64x4);
+
+         sum_squared_x_f64x4 = _mm256_fmadd_pd(delta_x1_f64x4, delta_x1_f64x4, sum_squared_x_f64x4);
+         sum_squared_y_f64x4 = _mm256_fmadd_pd(delta_y1_f64x4, delta_y1_f64x4, sum_squared_y_f64x4);
+         sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z1_f64x4, delta_z1_f64x4, sum_squared_z_f64x4);
+     }
+
+     // Handle 4-point remainder
+     for (; i + 4 <= n; i += 4) {
+         nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
+         nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
+
+         sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
+         sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
+         sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
+         sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
+         sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
+         sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
+
+         __m256d delta_x_f64x4 = _mm256_sub_pd(a_x_f64x4, b_x_f64x4);
+         __m256d delta_y_f64x4 = _mm256_sub_pd(a_y_f64x4, b_y_f64x4);
+         __m256d delta_z_f64x4 = _mm256_sub_pd(a_z_f64x4, b_z_f64x4);
+
+         sum_squared_x_f64x4 = _mm256_fmadd_pd(delta_x_f64x4, delta_x_f64x4, sum_squared_x_f64x4);
+         sum_squared_y_f64x4 = _mm256_fmadd_pd(delta_y_f64x4, delta_y_f64x4, sum_squared_y_f64x4);
+         sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z_f64x4, delta_z_f64x4, sum_squared_z_f64x4);
+     }
+
+     // Reduce vectors to scalars
+     nk_f64_t total_ax = nk_reduce_stable_f64x4_haswell_(sum_a_x_f64x4), total_ax_compensation = 0.0;
+     nk_f64_t total_ay = nk_reduce_stable_f64x4_haswell_(sum_a_y_f64x4), total_ay_compensation = 0.0;
+     nk_f64_t total_az = nk_reduce_stable_f64x4_haswell_(sum_a_z_f64x4), total_az_compensation = 0.0;
+     nk_f64_t total_bx = nk_reduce_stable_f64x4_haswell_(sum_b_x_f64x4), total_bx_compensation = 0.0;
+     nk_f64_t total_by = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), total_by_compensation = 0.0;
+     nk_f64_t total_bz = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), total_bz_compensation = 0.0;
+     nk_f64_t total_sq_x = nk_reduce_stable_f64x4_haswell_(sum_squared_x_f64x4), total_sq_x_compensation = 0.0;
+     nk_f64_t total_sq_y = nk_reduce_stable_f64x4_haswell_(sum_squared_y_f64x4), total_sq_y_compensation = 0.0;
+     nk_f64_t total_sq_z = nk_reduce_stable_f64x4_haswell_(sum_squared_z_f64x4), total_sq_z_compensation = 0.0;
+
+     // Scalar tail
+     for (; i < n; ++i) {
+         nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
+         nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
+         nk_accumulate_sum_f64_(&total_ax, &total_ax_compensation, ax);
+         nk_accumulate_sum_f64_(&total_ay, &total_ay_compensation, ay);
+         nk_accumulate_sum_f64_(&total_az, &total_az_compensation, az);
+         nk_accumulate_sum_f64_(&total_bx, &total_bx_compensation, bx);
+         nk_accumulate_sum_f64_(&total_by, &total_by_compensation, by);
+         nk_accumulate_sum_f64_(&total_bz, &total_bz_compensation, bz);
+         nk_f64_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
+         nk_accumulate_square_f64_(&total_sq_x, &total_sq_x_compensation, delta_x);
+         nk_accumulate_square_f64_(&total_sq_y, &total_sq_y_compensation, delta_y);
+         nk_accumulate_square_f64_(&total_sq_z, &total_sq_z_compensation, delta_z);
+     }
+
+     total_ax += total_ax_compensation, total_ay += total_ay_compensation, total_az += total_az_compensation;
+     total_bx += total_bx_compensation, total_by += total_by_compensation, total_bz += total_bz_compensation;
+     total_sq_x += total_sq_x_compensation, total_sq_y += total_sq_y_compensation, total_sq_z += total_sq_z_compensation;
+
+     // Compute centroids
+     nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
+     nk_f64_t centroid_a_x = total_ax * inv_n;
+     nk_f64_t centroid_a_y = total_ay * inv_n;
+     nk_f64_t centroid_a_z = total_az * inv_n;
+     nk_f64_t centroid_b_x = total_bx * inv_n;
+     nk_f64_t centroid_b_y = total_by * inv_n;
+     nk_f64_t centroid_b_z = total_bz * inv_n;
+
+     if (a_centroid) {
+         a_centroid[0] = centroid_a_x;
+         a_centroid[1] = centroid_a_y;
+         a_centroid[2] = centroid_a_z;
+     }
+     if (b_centroid) {
+         b_centroid[0] = centroid_b_x;
+         b_centroid[1] = centroid_b_y;
+         b_centroid[2] = centroid_b_z;
+     }
+
+     // Compute RMSD
+     nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
+     nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
+     nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
+     nk_f64_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
+     nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
+
+     *result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
+ }
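+
+ /* Usage sketch (hypothetical buffers, n = 1024 interleaved xyz triplets; the
+  * centroid, rotation, and scale outputs are optional and skipped here):
+  *
+  *     nk_f64_t cloud_a[3 * 1024], cloud_b[3 * 1024];
+  *     nk_f64_t rmsd;
+  *     nk_rmsd_f64_haswell(cloud_a, cloud_b, 1024, NULL, NULL, NULL, NULL, &rmsd);
+  */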
+
+ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+                                      nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
+     if (scale) *scale = 1.0f;
+     __m256d sum_a_x_f64x4 = _mm256_setzero_pd(), sum_a_y_f64x4 = _mm256_setzero_pd();
+     __m256d sum_a_z_f64x4 = _mm256_setzero_pd(), sum_b_x_f64x4 = _mm256_setzero_pd();
+     __m256d sum_b_y_f64x4 = _mm256_setzero_pd(), sum_b_z_f64x4 = _mm256_setzero_pd();
+     __m256d covariance_00_f64x4 = _mm256_setzero_pd(), covariance_01_f64x4 = _mm256_setzero_pd();
+     __m256d covariance_02_f64x4 = _mm256_setzero_pd(), covariance_10_f64x4 = _mm256_setzero_pd();
+     __m256d covariance_11_f64x4 = _mm256_setzero_pd(), covariance_12_f64x4 = _mm256_setzero_pd();
+     __m256d covariance_20_f64x4 = _mm256_setzero_pd(), covariance_21_f64x4 = _mm256_setzero_pd();
+     __m256d covariance_22_f64x4 = _mm256_setzero_pd();
+     __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
+     nk_size_t index = 0;
+
+     for (; index + 8 <= n; index += 8) {
+         nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
+             nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
+         __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
+         __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
+         __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
+         __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
+         __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
+         __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
+         __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
+         __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
+         __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
+         __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
+         __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
+         __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
+
+         sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_lower_f64x4, a_x_upper_f64x4));
+         sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_lower_f64x4, a_y_upper_f64x4));
+         sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_lower_f64x4, a_z_upper_f64x4));
+         sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_lower_f64x4, b_x_upper_f64x4));
+         sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_lower_f64x4, b_y_upper_f64x4));
+         sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_lower_f64x4, b_z_upper_f64x4));
+
+         covariance_00_f64x4 = _mm256_add_pd(covariance_00_f64x4,
+                                             _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_x_lower_f64x4),
+                                                           _mm256_mul_pd(a_x_upper_f64x4, b_x_upper_f64x4)));
+         covariance_01_f64x4 = _mm256_add_pd(covariance_01_f64x4,
+                                             _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_y_lower_f64x4),
+                                                           _mm256_mul_pd(a_x_upper_f64x4, b_y_upper_f64x4)));
+         covariance_02_f64x4 = _mm256_add_pd(covariance_02_f64x4,
+                                             _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_z_lower_f64x4),
+                                                           _mm256_mul_pd(a_x_upper_f64x4, b_z_upper_f64x4)));
+         covariance_10_f64x4 = _mm256_add_pd(covariance_10_f64x4,
+                                             _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_x_lower_f64x4),
+                                                           _mm256_mul_pd(a_y_upper_f64x4, b_x_upper_f64x4)));
+         covariance_11_f64x4 = _mm256_add_pd(covariance_11_f64x4,
+                                             _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_y_lower_f64x4),
+                                                           _mm256_mul_pd(a_y_upper_f64x4, b_y_upper_f64x4)));
+         covariance_12_f64x4 = _mm256_add_pd(covariance_12_f64x4,
+                                             _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_z_lower_f64x4),
+                                                           _mm256_mul_pd(a_y_upper_f64x4, b_z_upper_f64x4)));
+         covariance_20_f64x4 = _mm256_add_pd(covariance_20_f64x4,
+                                             _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_x_lower_f64x4),
+                                                           _mm256_mul_pd(a_z_upper_f64x4, b_x_upper_f64x4)));
+         covariance_21_f64x4 = _mm256_add_pd(covariance_21_f64x4,
+                                             _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_y_lower_f64x4),
+                                                           _mm256_mul_pd(a_z_upper_f64x4, b_y_upper_f64x4)));
+         covariance_22_f64x4 = _mm256_add_pd(covariance_22_f64x4,
+                                             _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_z_lower_f64x4),
+                                                           _mm256_mul_pd(a_z_upper_f64x4, b_z_upper_f64x4)));
+     }
+
+     nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
+     nk_f64_t sum_a_y = nk_reduce_add_f64x4_haswell_(sum_a_y_f64x4);
+     nk_f64_t sum_a_z = nk_reduce_add_f64x4_haswell_(sum_a_z_f64x4);
+     nk_f64_t sum_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
+     nk_f64_t sum_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
+     nk_f64_t sum_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
+     nk_f64_t h[9] = {
+         nk_reduce_add_f64x4_haswell_(covariance_00_f64x4), nk_reduce_add_f64x4_haswell_(covariance_01_f64x4),
+         nk_reduce_add_f64x4_haswell_(covariance_02_f64x4), nk_reduce_add_f64x4_haswell_(covariance_10_f64x4),
+         nk_reduce_add_f64x4_haswell_(covariance_11_f64x4), nk_reduce_add_f64x4_haswell_(covariance_12_f64x4),
+         nk_reduce_add_f64x4_haswell_(covariance_20_f64x4), nk_reduce_add_f64x4_haswell_(covariance_21_f64x4),
+         nk_reduce_add_f64x4_haswell_(covariance_22_f64x4)};
+
+     for (; index < n; ++index) {
+         nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
+         nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
+         sum_a_x += a_x, sum_a_y += a_y, sum_a_z += a_z;
+         sum_b_x += b_x, sum_b_y += b_y, sum_b_z += b_z;
+         h[0] += a_x * b_x, h[1] += a_x * b_y, h[2] += a_x * b_z;
+         h[3] += a_y * b_x, h[4] += a_y * b_y, h[5] += a_y * b_z;
+         h[6] += a_z * b_x, h[7] += a_z * b_y, h[8] += a_z * b_z;
+     }
+
+     nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
+     nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
+     nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
+     if (a_centroid)
+         a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
+         a_centroid[2] = (nk_f32_t)centroid_a_z;
+     if (b_centroid)
+         b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
+         b_centroid[2] = (nk_f32_t)centroid_b_z;
+
+     h[0] -= (nk_f64_t)n * centroid_a_x * centroid_b_x, h[1] -= (nk_f64_t)n * centroid_a_x * centroid_b_y,
+     h[2] -= (nk_f64_t)n * centroid_a_x * centroid_b_z, h[3] -= (nk_f64_t)n * centroid_a_y * centroid_b_x,
+     h[4] -= (nk_f64_t)n * centroid_a_y * centroid_b_y, h[5] -= (nk_f64_t)n * centroid_a_y * centroid_b_z,
+     h[6] -= (nk_f64_t)n * centroid_a_z * centroid_b_x, h[7] -= (nk_f64_t)n * centroid_a_z * centroid_b_y,
+     h[8] -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
+
+     nk_f64_t cross_covariance[9] = {h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8]};
+     nk_f64_t svd_u[9], svd_s[9], svd_v[9], r[9];
+     nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
+     r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+     r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+     r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+     r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+     r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+     r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+     r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+     r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+     r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+     if (nk_det3x3_f64_(r) < 0) {
+         svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
+         r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+         r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+         r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+         r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+         r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+         r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+         r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+         r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+         r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+     }
+
+     if (rotation)
+         for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)r[j];
+     nk_f64_t sum_squared = nk_transformed_ssd_f32_haswell_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y, centroid_a_z,
+                                                            centroid_b_x, centroid_b_y, centroid_b_z);
+     *result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
+ }
681
+
682
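+ /* The unrolled expressions above compute the Kabsch rotation R = V·Uᵀ for
+  * row-major 3x3 matrices: r[3i+j] = Σ_k V[3i+k]·U[3j+k]. A minimal loop
+  * equivalent under that layout assumption (hypothetical name, illustration
+  * only — this header unrolls it by hand): */
+ NK_INTERNAL void nk_rotation_from_svd_sketch_(nk_f64_t const *svd_u, nk_f64_t const *svd_v, nk_f64_t *r) {
+     for (int i = 0; i != 3; ++i)
+         for (int j = 0; j != 3; ++j) {
+             nk_f64_t sum = 0.0;
+             // Row i of V dotted with row j of U implements (V · Uᵀ)[i][j]
+             for (int k = 0; k != 3; ++k) sum += svd_v[i * 3 + k] * svd_u[j * 3 + k];
+             r[i * 3 + j] = sum;
+         }
+ }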
+ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
683
+ nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
684
+ __m256d const zeros_f64x4 = _mm256_setzero_pd();
685
+
686
+ // Accumulators for centroids
687
+ __m256d sum_a_x_f64x4 = zeros_f64x4, sum_a_y_f64x4 = zeros_f64x4, sum_a_z_f64x4 = zeros_f64x4;
688
+ __m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
689
+
690
+ // Accumulators for covariance matrix (sum of outer products)
691
+ __m256d cov_xx_f64x4 = zeros_f64x4, cov_xy_f64x4 = zeros_f64x4, cov_xz_f64x4 = zeros_f64x4;
692
+ __m256d cov_yx_f64x4 = zeros_f64x4, cov_yy_f64x4 = zeros_f64x4, cov_yz_f64x4 = zeros_f64x4;
693
+ __m256d cov_zx_f64x4 = zeros_f64x4, cov_zy_f64x4 = zeros_f64x4, cov_zz_f64x4 = zeros_f64x4;
694
+
695
+ nk_size_t i = 0;
696
+ __m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
697
+
698
+ // Fused single-pass: centroids and covariance accumulate together
699
+ for (; i + 4 <= n; i += 4) {
700
+ nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
701
+ nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
702
+
703
+ sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
704
+ sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
705
+ sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
706
+ sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
707
+ sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
708
+ sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
709
+
710
+ cov_xx_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_x_f64x4, cov_xx_f64x4);
711
+ cov_xy_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_y_f64x4, cov_xy_f64x4);
712
+ cov_xz_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_z_f64x4, cov_xz_f64x4);
713
+ cov_yx_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_x_f64x4, cov_yx_f64x4);
714
+ cov_yy_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_y_f64x4, cov_yy_f64x4);
715
+ cov_yz_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_z_f64x4, cov_yz_f64x4);
716
+ cov_zx_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_x_f64x4, cov_zx_f64x4);
717
+ cov_zy_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_y_f64x4, cov_zy_f64x4);
718
+ cov_zz_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_z_f64x4, cov_zz_f64x4);
719
+ }
720
+
721
+ // Reduce vector accumulators
722
+ nk_f64_t sum_a_x = nk_reduce_stable_f64x4_haswell_(sum_a_x_f64x4), sum_a_x_compensation = 0.0;
723
+ nk_f64_t sum_a_y = nk_reduce_stable_f64x4_haswell_(sum_a_y_f64x4), sum_a_y_compensation = 0.0;
724
+ nk_f64_t sum_a_z = nk_reduce_stable_f64x4_haswell_(sum_a_z_f64x4), sum_a_z_compensation = 0.0;
725
+ nk_f64_t sum_b_x = nk_reduce_stable_f64x4_haswell_(sum_b_x_f64x4), sum_b_x_compensation = 0.0;
726
+ nk_f64_t sum_b_y = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), sum_b_y_compensation = 0.0;
727
+ nk_f64_t sum_b_z = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), sum_b_z_compensation = 0.0;
728
+
729
+ nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(cov_xx_f64x4), covariance_x_x_compensation = 0.0;
730
+ nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(cov_xy_f64x4), covariance_x_y_compensation = 0.0;
731
+ nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(cov_xz_f64x4), covariance_x_z_compensation = 0.0;
732
+ nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(cov_yx_f64x4), covariance_y_x_compensation = 0.0;
733
+ nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(cov_yy_f64x4), covariance_y_y_compensation = 0.0;
734
+ nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(cov_yz_f64x4), covariance_y_z_compensation = 0.0;
735
+ nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(cov_zx_f64x4), covariance_z_x_compensation = 0.0;
736
+ nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(cov_zy_f64x4), covariance_z_y_compensation = 0.0;
737
+ nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(cov_zz_f64x4), covariance_z_z_compensation = 0.0;
738
+
739
+ // Scalar tail
740
+ for (; i < n; ++i) {
741
+ nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
742
+ nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
743
+ nk_accumulate_sum_f64_(&sum_a_x, &sum_a_x_compensation, ax);
744
+ nk_accumulate_sum_f64_(&sum_a_y, &sum_a_y_compensation, ay);
745
+ nk_accumulate_sum_f64_(&sum_a_z, &sum_a_z_compensation, az);
746
+ nk_accumulate_sum_f64_(&sum_b_x, &sum_b_x_compensation, bx);
747
+ nk_accumulate_sum_f64_(&sum_b_y, &sum_b_y_compensation, by);
748
+ nk_accumulate_sum_f64_(&sum_b_z, &sum_b_z_compensation, bz);
749
+ nk_accumulate_product_f64_(&covariance_x_x, &covariance_x_x_compensation, ax, bx);
750
+ nk_accumulate_product_f64_(&covariance_x_y, &covariance_x_y_compensation, ax, by);
751
+ nk_accumulate_product_f64_(&covariance_x_z, &covariance_x_z_compensation, ax, bz);
752
+ nk_accumulate_product_f64_(&covariance_y_x, &covariance_y_x_compensation, ay, bx);
753
+ nk_accumulate_product_f64_(&covariance_y_y, &covariance_y_y_compensation, ay, by);
754
+ nk_accumulate_product_f64_(&covariance_y_z, &covariance_y_z_compensation, ay, bz);
755
+ nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
756
+ nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
757
+ nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
758
+ }
759
+
760
+ sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
761
+ sum_b_x += sum_b_x_compensation, sum_b_y += sum_b_y_compensation, sum_b_z += sum_b_z_compensation;
762
+ covariance_x_x += covariance_x_x_compensation, covariance_x_y += covariance_x_y_compensation,
763
+ covariance_x_z += covariance_x_z_compensation;
764
+ covariance_y_x += covariance_y_x_compensation, covariance_y_y += covariance_y_y_compensation,
765
+ covariance_y_z += covariance_y_z_compensation;
766
+ covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
767
+ covariance_z_z += covariance_z_z_compensation;
768
+
769
+ // Compute centroids
770
+ nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
771
+ nk_f64_t centroid_a_x = sum_a_x * inv_n;
772
+ nk_f64_t centroid_a_y = sum_a_y * inv_n;
773
+ nk_f64_t centroid_a_z = sum_a_z * inv_n;
774
+ nk_f64_t centroid_b_x = sum_b_x * inv_n;
775
+ nk_f64_t centroid_b_y = sum_b_y * inv_n;
776
+ nk_f64_t centroid_b_z = sum_b_z * inv_n;
777
+
778
+ if (a_centroid) {
779
+ a_centroid[0] = centroid_a_x;
780
+ a_centroid[1] = centroid_a_y;
781
+ a_centroid[2] = centroid_a_z;
782
+ }
783
+ if (b_centroid) {
784
+ b_centroid[0] = centroid_b_x;
785
+ b_centroid[1] = centroid_b_y;
786
+ b_centroid[2] = centroid_b_z;
787
+ }
788
+
789
+ // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
790
+ covariance_x_x -= n * centroid_a_x * centroid_b_x;
791
+ covariance_x_y -= n * centroid_a_x * centroid_b_y;
792
+ covariance_x_z -= n * centroid_a_x * centroid_b_z;
793
+ covariance_y_x -= n * centroid_a_y * centroid_b_x;
794
+ covariance_y_y -= n * centroid_a_y * centroid_b_y;
795
+ covariance_y_z -= n * centroid_a_y * centroid_b_z;
796
+ covariance_z_x -= n * centroid_a_z * centroid_b_x;
797
+ covariance_z_y -= n * centroid_a_z * centroid_b_y;
798
+ covariance_z_z -= n * centroid_a_z * centroid_b_z;
799
+
800
+ // Compute SVD and optimal rotation in f64 precision (svd_s is a 3x3 diagonal matrix stored as 9 elements)
801
+ nk_f64_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
802
+ covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
803
+ nk_f64_t svd_u[9], svd_s[9], svd_v[9];
804
+ nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
805
+
806
+ nk_f64_t r[9];
807
+ nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
808
+
809
+ // Handle reflection: if det(R) < 0, negate third column of V and recompute R
810
+ if (nk_det3x3_f64_(r) < 0) {
811
+ svd_v[2] = -svd_v[2];
812
+ svd_v[5] = -svd_v[5];
813
+ svd_v[8] = -svd_v[8];
814
+ nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
815
+ }
816
+
817
+ /* Output rotation matrix and scale=1.0 */
818
+ if (rotation) {
819
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
820
+ }
821
+ if (scale) *scale = 1.0;
822
+
823
+ // Compute RMSD after optimal rotation
824
+ nk_f64_t sum_squared = nk_transformed_ssd_f64_haswell_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y, centroid_a_z,
825
+ centroid_b_x, centroid_b_y, centroid_b_z);
826
+ *result = nk_f64_sqrt_haswell(sum_squared * inv_n);
827
+ }
828
+
829
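+ /* The scalar tails in the f64 kernels above and below fold every addend into
+  * a (sum, compensation) pair and add the compensation back exactly once after
+  * the loop. The helper bodies are not shown in this diff; a standard Neumaier
+  * update consistent with that usage looks like this (sketch, hypothetical
+  * name): */
+ NK_INTERNAL void nk_accumulate_sum_sketch_(nk_f64_t *sum, nk_f64_t *compensation, nk_f64_t value) {
+     nk_f64_t t = *sum + value;
+     nk_f64_t abs_sum = *sum < 0 ? -*sum : *sum, abs_value = value < 0 ? -value : value;
+     // Recover the rounding error of `*sum + value` from whichever operand dominates
+     *compensation += abs_sum >= abs_value ? ((*sum - t) + value) : ((value - t) + *sum);
+     *sum = t;
+ }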
+ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
830
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
831
+ __m256d sum_a_x_f64x4 = _mm256_setzero_pd(), sum_a_y_f64x4 = _mm256_setzero_pd();
832
+ __m256d sum_a_z_f64x4 = _mm256_setzero_pd(), sum_b_x_f64x4 = _mm256_setzero_pd();
833
+ __m256d sum_b_y_f64x4 = _mm256_setzero_pd(), sum_b_z_f64x4 = _mm256_setzero_pd();
834
+ __m256d covariance_00_f64x4 = _mm256_setzero_pd(), covariance_01_f64x4 = _mm256_setzero_pd();
835
+ __m256d covariance_02_f64x4 = _mm256_setzero_pd(), covariance_10_f64x4 = _mm256_setzero_pd();
836
+ __m256d covariance_11_f64x4 = _mm256_setzero_pd(), covariance_12_f64x4 = _mm256_setzero_pd();
837
+ __m256d covariance_20_f64x4 = _mm256_setzero_pd(), covariance_21_f64x4 = _mm256_setzero_pd();
838
+ __m256d covariance_22_f64x4 = _mm256_setzero_pd(), variance_a_f64x4 = _mm256_setzero_pd();
839
+ __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
840
+ nk_size_t index = 0;
841
+
842
+ for (; index + 8 <= n; index += 8) {
843
+ nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
844
+ nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
845
+ __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
846
+ __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
847
+ __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
848
+ __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
849
+ __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
850
+ __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
851
+ __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
852
+ __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
853
+ __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
854
+ __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
855
+ __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
856
+ __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
857
+
858
+ sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_lower_f64x4, a_x_upper_f64x4));
859
+ sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_lower_f64x4, a_y_upper_f64x4));
860
+ sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_lower_f64x4, a_z_upper_f64x4));
861
+ sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_lower_f64x4, b_x_upper_f64x4));
862
+ sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_lower_f64x4, b_y_upper_f64x4));
863
+ sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_lower_f64x4, b_z_upper_f64x4));
864
+ covariance_00_f64x4 = _mm256_add_pd(covariance_00_f64x4,
865
+ _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_x_lower_f64x4),
866
+ _mm256_mul_pd(a_x_upper_f64x4, b_x_upper_f64x4)));
867
+ covariance_01_f64x4 = _mm256_add_pd(covariance_01_f64x4,
868
+ _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_y_lower_f64x4),
869
+ _mm256_mul_pd(a_x_upper_f64x4, b_y_upper_f64x4)));
870
+ covariance_02_f64x4 = _mm256_add_pd(covariance_02_f64x4,
871
+ _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_z_lower_f64x4),
872
+ _mm256_mul_pd(a_x_upper_f64x4, b_z_upper_f64x4)));
873
+ covariance_10_f64x4 = _mm256_add_pd(covariance_10_f64x4,
874
+ _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_x_lower_f64x4),
875
+ _mm256_mul_pd(a_y_upper_f64x4, b_x_upper_f64x4)));
876
+ covariance_11_f64x4 = _mm256_add_pd(covariance_11_f64x4,
877
+ _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_y_lower_f64x4),
878
+ _mm256_mul_pd(a_y_upper_f64x4, b_y_upper_f64x4)));
879
+ covariance_12_f64x4 = _mm256_add_pd(covariance_12_f64x4,
880
+ _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_z_lower_f64x4),
881
+ _mm256_mul_pd(a_y_upper_f64x4, b_z_upper_f64x4)));
882
+ covariance_20_f64x4 = _mm256_add_pd(covariance_20_f64x4,
883
+ _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_x_lower_f64x4),
884
+ _mm256_mul_pd(a_z_upper_f64x4, b_x_upper_f64x4)));
885
+ covariance_21_f64x4 = _mm256_add_pd(covariance_21_f64x4,
886
+ _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_y_lower_f64x4),
887
+ _mm256_mul_pd(a_z_upper_f64x4, b_y_upper_f64x4)));
888
+ covariance_22_f64x4 = _mm256_add_pd(covariance_22_f64x4,
889
+ _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_z_lower_f64x4),
890
+ _mm256_mul_pd(a_z_upper_f64x4, b_z_upper_f64x4)));
891
+ variance_a_f64x4 = _mm256_add_pd(
892
+ variance_a_f64x4,
893
+ _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, a_x_lower_f64x4),
894
+ _mm256_mul_pd(a_x_upper_f64x4, a_x_upper_f64x4)),
895
+ _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, a_y_lower_f64x4),
896
+ _mm256_mul_pd(a_y_upper_f64x4, a_y_upper_f64x4)),
897
+ _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, a_z_lower_f64x4),
898
+ _mm256_mul_pd(a_z_upper_f64x4, a_z_upper_f64x4)))));
899
+ }
900
+
901
+ nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
902
+ nk_f64_t sum_a_y = nk_reduce_add_f64x4_haswell_(sum_a_y_f64x4);
903
+ nk_f64_t sum_a_z = nk_reduce_add_f64x4_haswell_(sum_a_z_f64x4);
904
+ nk_f64_t sum_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
905
+ nk_f64_t sum_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
906
+ nk_f64_t sum_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
907
+ nk_f64_t h[9] = {
908
+ nk_reduce_add_f64x4_haswell_(covariance_00_f64x4), nk_reduce_add_f64x4_haswell_(covariance_01_f64x4),
909
+ nk_reduce_add_f64x4_haswell_(covariance_02_f64x4), nk_reduce_add_f64x4_haswell_(covariance_10_f64x4),
910
+ nk_reduce_add_f64x4_haswell_(covariance_11_f64x4), nk_reduce_add_f64x4_haswell_(covariance_12_f64x4),
911
+ nk_reduce_add_f64x4_haswell_(covariance_20_f64x4), nk_reduce_add_f64x4_haswell_(covariance_21_f64x4),
912
+ nk_reduce_add_f64x4_haswell_(covariance_22_f64x4)};
913
+ nk_f64_t variance_a = nk_reduce_add_f64x4_haswell_(variance_a_f64x4);
914
+
915
+ // Scalar tail
+ for (; index < n; ++index) {
916
+ nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
917
+ nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
918
+ sum_a_x += a_x, sum_a_y += a_y, sum_a_z += a_z;
919
+ sum_b_x += b_x, sum_b_y += b_y, sum_b_z += b_z;
920
+ h[0] += a_x * b_x, h[1] += a_x * b_y, h[2] += a_x * b_z;
921
+ h[3] += a_y * b_x, h[4] += a_y * b_y, h[5] += a_y * b_z;
922
+ h[6] += a_z * b_x, h[7] += a_z * b_y, h[8] += a_z * b_z;
923
+ variance_a += a_x * a_x + a_y * a_y + a_z * a_z;
924
+ }
925
+
926
+ nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
927
+ nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
928
+ nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
929
+ if (a_centroid)
930
+ a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
931
+ a_centroid[2] = (nk_f32_t)centroid_a_z;
932
+ if (b_centroid)
933
+ b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
934
+ b_centroid[2] = (nk_f32_t)centroid_b_z;
935
+
936
+ variance_a = variance_a * inv_n -
937
+ (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
938
+ h[0] -= (nk_f64_t)n * centroid_a_x * centroid_b_x, h[1] -= (nk_f64_t)n * centroid_a_x * centroid_b_y,
939
+ h[2] -= (nk_f64_t)n * centroid_a_x * centroid_b_z, h[3] -= (nk_f64_t)n * centroid_a_y * centroid_b_x,
940
+ h[4] -= (nk_f64_t)n * centroid_a_y * centroid_b_y, h[5] -= (nk_f64_t)n * centroid_a_y * centroid_b_z,
941
+ h[6] -= (nk_f64_t)n * centroid_a_z * centroid_b_x, h[7] -= (nk_f64_t)n * centroid_a_z * centroid_b_y,
942
+ h[8] -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
943
+
944
+ nk_f64_t cross_covariance[9] = {h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8]};
945
+ nk_f64_t svd_u[9], svd_s[9], svd_v[9], r[9];
946
+ nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
947
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
948
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
949
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
950
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
951
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
952
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
953
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
954
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
955
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
956
+
957
+ nk_f64_t det = nk_det3x3_f64_(r), sign_correction = det < 0 ? -1.0 : 1.0;
958
+ if (det < 0) {
959
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
960
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
961
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
962
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
963
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
964
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
965
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
966
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
967
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
968
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
969
+ }
970
+
971
+ nk_f64_t applied_scale = (svd_s[0] + svd_s[4] + sign_correction * svd_s[8]) / ((nk_f64_t)n * variance_a);
972
+ if (rotation)
973
+ for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)r[j];
974
+ if (scale) *scale = (nk_f32_t)applied_scale;
975
+ *result = nk_f64_sqrt_haswell(nk_transformed_ssd_f32_haswell_(a, b, n, r, applied_scale, centroid_a_x, centroid_a_y,
976
+ centroid_a_z, centroid_b_x, centroid_b_y,
977
+ centroid_b_z) /
978
+ (nk_f64_t)n);
979
+ }
980
+
981
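+ /* The `applied_scale` above and the scale `c` in the f64 variant below both
+  * follow Umeyama (1991): c = trace(D·S) / (n · variance(a)), with
+  * D = diag(1, 1, sign(det)) guarding against reflections. The singular values
+  * in svd_s come from the *unnormalized* centered cross-covariance, which is
+  * where the extra factor n in the denominator comes from. A scalar sketch of
+  * the same formula (hypothetical name): */
+ NK_INTERNAL nk_f64_t nk_umeyama_scale_sketch_(nk_f64_t const *svd_s, nk_f64_t det, nk_size_t n, nk_f64_t variance_a) {
+     nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
+     // svd_s is a row-major 3x3 diagonal matrix: singular values sit at [0], [4], [8]
+     return (svd_s[0] + svd_s[4] + d3 * svd_s[8]) / ((nk_f64_t)n * variance_a);
+ }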
+ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
982
+ nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
983
+ // Fused single-pass: centroids, covariance, and variance of A
984
+ __m256d const zeros_f64x4 = _mm256_setzero_pd();
985
+
986
+ __m256d sum_a_x_f64x4 = zeros_f64x4, sum_a_y_f64x4 = zeros_f64x4, sum_a_z_f64x4 = zeros_f64x4;
987
+ __m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
988
+ __m256d cov_xx_f64x4 = zeros_f64x4, cov_xy_f64x4 = zeros_f64x4, cov_xz_f64x4 = zeros_f64x4;
989
+ __m256d cov_yx_f64x4 = zeros_f64x4, cov_yy_f64x4 = zeros_f64x4, cov_yz_f64x4 = zeros_f64x4;
990
+ __m256d cov_zx_f64x4 = zeros_f64x4, cov_zy_f64x4 = zeros_f64x4, cov_zz_f64x4 = zeros_f64x4;
991
+ __m256d variance_a_f64x4 = zeros_f64x4;
992
+
993
+ nk_size_t i = 0;
994
+ __m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
995
+
996
+ for (; i + 4 <= n; i += 4) {
997
+ nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
998
+ nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
999
+
1000
+ sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4),
1001
+ sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
1002
+ sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
1003
+ sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4),
1004
+ sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
1005
+ sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
1006
+
1007
+ cov_xx_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_x_f64x4, cov_xx_f64x4),
1008
+ cov_xy_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_y_f64x4, cov_xy_f64x4);
1009
+ cov_xz_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_z_f64x4, cov_xz_f64x4);
1010
+ cov_yx_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_x_f64x4, cov_yx_f64x4),
1011
+ cov_yy_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_y_f64x4, cov_yy_f64x4);
1012
+ cov_yz_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_z_f64x4, cov_yz_f64x4);
1013
+ cov_zx_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_x_f64x4, cov_zx_f64x4),
1014
+ cov_zy_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_y_f64x4, cov_zy_f64x4);
1015
+ cov_zz_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_z_f64x4, cov_zz_f64x4);
1016
+ variance_a_f64x4 = _mm256_fmadd_pd(a_x_f64x4, a_x_f64x4, variance_a_f64x4);
1017
+ variance_a_f64x4 = _mm256_fmadd_pd(a_y_f64x4, a_y_f64x4, variance_a_f64x4);
1018
+ variance_a_f64x4 = _mm256_fmadd_pd(a_z_f64x4, a_z_f64x4, variance_a_f64x4);
1019
+ }
1020
+
1021
+ // Reduce vector accumulators
1022
+ nk_f64_t sum_a_x = nk_reduce_stable_f64x4_haswell_(sum_a_x_f64x4), sum_a_x_compensation = 0.0;
1023
+ nk_f64_t sum_a_y = nk_reduce_stable_f64x4_haswell_(sum_a_y_f64x4), sum_a_y_compensation = 0.0;
1024
+ nk_f64_t sum_a_z = nk_reduce_stable_f64x4_haswell_(sum_a_z_f64x4), sum_a_z_compensation = 0.0;
1025
+ nk_f64_t sum_b_x = nk_reduce_stable_f64x4_haswell_(sum_b_x_f64x4), sum_b_x_compensation = 0.0;
1026
+ nk_f64_t sum_b_y = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), sum_b_y_compensation = 0.0;
1027
+ nk_f64_t sum_b_z = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), sum_b_z_compensation = 0.0;
1028
+ nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(cov_xx_f64x4), covariance_x_x_compensation = 0.0;
1029
+ nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(cov_xy_f64x4), covariance_x_y_compensation = 0.0;
1030
+ nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(cov_xz_f64x4), covariance_x_z_compensation = 0.0;
1031
+ nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(cov_yx_f64x4), covariance_y_x_compensation = 0.0;
1032
+ nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(cov_yy_f64x4), covariance_y_y_compensation = 0.0;
1033
+ nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(cov_yz_f64x4), covariance_y_z_compensation = 0.0;
1034
+ nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(cov_zx_f64x4), covariance_z_x_compensation = 0.0;
1035
+ nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(cov_zy_f64x4), covariance_z_y_compensation = 0.0;
1036
+ nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(cov_zz_f64x4), covariance_z_z_compensation = 0.0;
1037
+ nk_f64_t variance_a_sum = nk_reduce_stable_f64x4_haswell_(variance_a_f64x4), variance_a_compensation = 0.0;
1038
+
1039
+ // Scalar tail loop for remaining points
1040
+ for (; i < n; i++) {
1041
+ nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
1042
+ nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
1043
+ nk_accumulate_sum_f64_(&sum_a_x, &sum_a_x_compensation, ax);
1044
+ nk_accumulate_sum_f64_(&sum_a_y, &sum_a_y_compensation, ay);
1045
+ nk_accumulate_sum_f64_(&sum_a_z, &sum_a_z_compensation, az);
1046
+ nk_accumulate_sum_f64_(&sum_b_x, &sum_b_x_compensation, bx);
1047
+ nk_accumulate_sum_f64_(&sum_b_y, &sum_b_y_compensation, by);
1048
+ nk_accumulate_sum_f64_(&sum_b_z, &sum_b_z_compensation, bz);
1049
+ nk_accumulate_product_f64_(&covariance_x_x, &covariance_x_x_compensation, ax, bx);
1050
+ nk_accumulate_product_f64_(&covariance_x_y, &covariance_x_y_compensation, ax, by);
1051
+ nk_accumulate_product_f64_(&covariance_x_z, &covariance_x_z_compensation, ax, bz);
1052
+ nk_accumulate_product_f64_(&covariance_y_x, &covariance_y_x_compensation, ay, bx);
1053
+ nk_accumulate_product_f64_(&covariance_y_y, &covariance_y_y_compensation, ay, by);
1054
+ nk_accumulate_product_f64_(&covariance_y_z, &covariance_y_z_compensation, ay, bz);
1055
+ nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
1056
+ nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
1057
+ nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
1058
+ nk_accumulate_square_f64_(&variance_a_sum, &variance_a_compensation, ax);
1059
+ nk_accumulate_square_f64_(&variance_a_sum, &variance_a_compensation, ay);
1060
+ nk_accumulate_square_f64_(&variance_a_sum, &variance_a_compensation, az);
1061
+ }
1062
+
1063
+ sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
1064
+ sum_b_x += sum_b_x_compensation, sum_b_y += sum_b_y_compensation, sum_b_z += sum_b_z_compensation;
1065
+ covariance_x_x += covariance_x_x_compensation, covariance_x_y += covariance_x_y_compensation,
1066
+ covariance_x_z += covariance_x_z_compensation;
1067
+ covariance_y_x += covariance_y_x_compensation, covariance_y_y += covariance_y_y_compensation,
1068
+ covariance_y_z += covariance_y_z_compensation;
1069
+ covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
1070
+ covariance_z_z += covariance_z_z_compensation;
1071
+ variance_a_sum += variance_a_compensation;
1072
+
1073
+ // Compute centroids
1074
+ nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
1075
+
1076
+ nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
1077
+ nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
1078
+
1079
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1080
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1081
+
1082
+ // Compute centered covariance and variance
1083
+ nk_f64_t variance_a = variance_a_sum * inv_n -
1084
+ (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
1085
+
1086
+ nk_f64_t cross_covariance[9];
1087
+ cross_covariance[0] = covariance_x_x - sum_a_x * sum_b_x * inv_n;
1088
+ cross_covariance[1] = covariance_x_y - sum_a_x * sum_b_y * inv_n;
1089
+ cross_covariance[2] = covariance_x_z - sum_a_x * sum_b_z * inv_n;
1090
+ cross_covariance[3] = covariance_y_x - sum_a_y * sum_b_x * inv_n;
1091
+ cross_covariance[4] = covariance_y_y - sum_a_y * sum_b_y * inv_n;
1092
+ cross_covariance[5] = covariance_y_z - sum_a_y * sum_b_z * inv_n;
1093
+ cross_covariance[6] = covariance_z_x - sum_a_z * sum_b_x * inv_n;
1094
+ cross_covariance[7] = covariance_z_y - sum_a_z * sum_b_y * inv_n;
1095
+ cross_covariance[8] = covariance_z_z - sum_a_z * sum_b_z * inv_n;
1096
+
1097
+ // SVD in f64 for full precision (svd_s is a 3x3 diagonal matrix stored as 9 elements)
1098
+ nk_f64_t svd_u[9], svd_s[9], svd_v[9];
1099
+ nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
1100
+
1101
+ nk_f64_t r[9];
1102
+ nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
1103
+
1104
+ // Umeyama scale factor: c = trace(D·S) / (n × variance(a)), with D = diag(1, 1, sign(det(R)))
1105
+ // svd_s stores the singular values on its diagonal: [0], [4], [8]
1106
+ nk_f64_t det = nk_det3x3_f64_(r);
1107
+ nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
1108
+ nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
1109
+ nk_f64_t c = trace_ds / (n * variance_a);
1110
+ if (scale) *scale = c;
1111
+
1112
+ // Handle reflection
1113
+ if (det < 0) {
1114
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
1115
+ nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
1116
+ }
1117
+
1118
+ /* Output rotation matrix */
1119
+ if (rotation) {
1120
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
1121
+ }
1122
+
1123
+ // Compute RMSD with scaling
1124
+ nk_f64_t sum_squared = nk_transformed_ssd_f64_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
1125
+ centroid_b_x, centroid_b_y, centroid_b_z);
1126
+ *result = nk_f64_sqrt_haswell(sum_squared * inv_n);
1127
+ }
1128
+
1129
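+ /* Both centering forms above rely on the same identity:
+  *   Σᵢ (aᵢ − μ_a)(bᵢ − μ_b)ᵀ = Σᵢ aᵢbᵢᵀ − n·μ_a·μ_bᵀ
+  * Since Σᵢ aᵢ = n·μ_a and Σᵢ bᵢ = n·μ_b, the two cross terms each contribute
+  * −n·μ_aμ_bᵀ, and together with the +n·μ_aμ_bᵀ term a single −n·μ_aμ_bᵀ
+  * survives. Note that n·μ_a·μ_b = sum_a·sum_b/n, which is why the Kabsch
+  * kernels subtract (nk_f64_t)n * centroid products while
+  * nk_umeyama_f64_haswell subtracts sum_a_x * sum_b_x * inv_n — the two
+  * corrections are algebraically equal.
+  */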
+ /* Deinterleave 8 f16 xyz triplets (24 f16 values) and convert to 3 x __m256 f32.
1130
+ * Uses scalar extraction for clean stride-3 access, then F16C conversion.
1131
+ *
1132
+ * Input: 24 contiguous f16 [x0,y0,z0, x1,y1,z1, ..., x7,y7,z7]
1133
+ * Output: x[8], y[8], z[8] vectors in f32
1134
+ */
1135
+ NK_INTERNAL void nk_deinterleave_f16x8_to_f32x8_haswell_(nk_f16_t const *ptr, __m256 *x_out, __m256 *y_out,
1136
+ __m256 *z_out) {
1137
+ // Extract x, y, z components with stride-3 access
1138
+ nk_b256_vec_t x_vec, y_vec, z_vec;
1139
+ x_vec.f16s[0] = ptr[0], x_vec.f16s[1] = ptr[3], x_vec.f16s[2] = ptr[6], x_vec.f16s[3] = ptr[9];
1140
+ x_vec.f16s[4] = ptr[12], x_vec.f16s[5] = ptr[15], x_vec.f16s[6] = ptr[18], x_vec.f16s[7] = ptr[21];
1141
+ y_vec.f16s[0] = ptr[1], y_vec.f16s[1] = ptr[4], y_vec.f16s[2] = ptr[7], y_vec.f16s[3] = ptr[10];
1142
+ y_vec.f16s[4] = ptr[13], y_vec.f16s[5] = ptr[16], y_vec.f16s[6] = ptr[19], y_vec.f16s[7] = ptr[22];
1143
+ z_vec.f16s[0] = ptr[2], z_vec.f16s[1] = ptr[5], z_vec.f16s[2] = ptr[8], z_vec.f16s[3] = ptr[11];
1144
+ z_vec.f16s[4] = ptr[14], z_vec.f16s[5] = ptr[17], z_vec.f16s[6] = ptr[20], z_vec.f16s[7] = ptr[23];
1145
+ // Convert f16 to f32 using F16C
1146
+ *x_out = _mm256_cvtph_ps(x_vec.xmms[0]);
1147
+ *y_out = _mm256_cvtph_ps(y_vec.xmms[0]);
1148
+ *z_out = _mm256_cvtph_ps(z_vec.xmms[0]);
1149
+ }
1150
+
1151
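+ /* Reference semantics shared by the f16 helper above and the bf16 helper
+  * below: output lane i of x/y/z holds point i's coordinate, i.e. ptr[3i],
+  * ptr[3i+1], ptr[3i+2], widened to f32. A scalar f32 equivalent (sketch,
+  * hypothetical name):
+  */
+ NK_INTERNAL void nk_deinterleave_xyz_sketch_(nk_f32_t const *ptr, nk_f32_t *x, nk_f32_t *y, nk_f32_t *z) {
+     // One stride-3 gather per output component, eight points per call
+     for (int i = 0; i != 8; ++i) x[i] = ptr[i * 3 + 0], y[i] = ptr[i * 3 + 1], z[i] = ptr[i * 3 + 2];
+ }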
+ /* Deinterleave 8 bf16 xyz triplets (24 bf16 values) and convert to 3 x __m256 f32.
1152
+ * Uses scalar extraction for clean stride-3 access, then bit-shift conversion.
1153
+ *
1154
+ * Input: 24 contiguous bf16 [x0,y0,z0, x1,y1,z1, ..., x7,y7,z7]
1155
+ * Output: x[8], y[8], z[8] vectors in f32
1156
+ */
1157
+ NK_INTERNAL void nk_deinterleave_bf16x8_to_f32x8_haswell_(nk_bf16_t const *ptr, __m256 *x_out, __m256 *y_out,
1158
+ __m256 *z_out) {
1159
+ // Extract x, y, z components with stride-3 access
1160
+ nk_b256_vec_t x_vec, y_vec, z_vec;
1161
+ x_vec.bf16s[0] = ptr[0], x_vec.bf16s[1] = ptr[3], x_vec.bf16s[2] = ptr[6], x_vec.bf16s[3] = ptr[9];
1162
+ x_vec.bf16s[4] = ptr[12], x_vec.bf16s[5] = ptr[15], x_vec.bf16s[6] = ptr[18], x_vec.bf16s[7] = ptr[21];
1163
+ y_vec.bf16s[0] = ptr[1], y_vec.bf16s[1] = ptr[4], y_vec.bf16s[2] = ptr[7], y_vec.bf16s[3] = ptr[10];
1164
+ y_vec.bf16s[4] = ptr[13], y_vec.bf16s[5] = ptr[16], y_vec.bf16s[6] = ptr[19], y_vec.bf16s[7] = ptr[22];
1165
+ z_vec.bf16s[0] = ptr[2], z_vec.bf16s[1] = ptr[5], z_vec.bf16s[2] = ptr[8], z_vec.bf16s[3] = ptr[11];
1166
+ z_vec.bf16s[4] = ptr[14], z_vec.bf16s[5] = ptr[17], z_vec.bf16s[6] = ptr[20], z_vec.bf16s[7] = ptr[23];
1167
+ // Convert bf16 to f32 by left-shifting 16 bits
1168
+ *x_out = nk_bf16x8_to_f32x8_haswell_(x_vec.xmms[0]);
1169
+ *y_out = nk_bf16x8_to_f32x8_haswell_(y_vec.xmms[0]);
1170
+ *z_out = nk_bf16x8_to_f32x8_haswell_(z_vec.xmms[0]);
1171
+ }
1172
+
1173
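+ /* The "left-shift by 16" conversion mentioned above exploits the fact that
+  * bf16 is the top half of an IEEE f32. The definition of
+  * nk_bf16x8_to_f32x8_haswell_ is not part of this diff; a typical AVX2 form
+  * would be (sketch, hypothetical name):
+  */
+ NK_INTERNAL __m256 nk_bf16x8_to_f32x8_sketch_(__m128i bf16x8) {
+     // Widen 8 u16 lanes to u32, then shift the payload into the f32 sign/exponent/mantissa position
+     return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(bf16x8), 16));
+ }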
+ /* Compute sum of squared distances for f16 data after applying rotation (and optional scale).
1174
+ * Loads f16 data, converts to f32 during processing.
1175
+ * Note: the rotation matrix r comes from the f32 SVD; the scale factor and all arithmetic stay in f32.
1176
+ */
1177
+ NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_haswell_(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
1178
+ nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
1179
+ nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
1180
+ nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
1181
+ nk_f32_t centroid_b_z) {
1182
+ // Broadcast scaled rotation matrix elements
1183
+ __m256 scaled_rotation_x_x_f32x8 = _mm256_set1_ps(scale * r[0]);
1184
+ __m256 scaled_rotation_x_y_f32x8 = _mm256_set1_ps(scale * r[1]);
1185
+ __m256 scaled_rotation_x_z_f32x8 = _mm256_set1_ps(scale * r[2]);
1186
+ __m256 scaled_rotation_y_x_f32x8 = _mm256_set1_ps(scale * r[3]);
1187
+ __m256 scaled_rotation_y_y_f32x8 = _mm256_set1_ps(scale * r[4]);
1188
+ __m256 scaled_rotation_y_z_f32x8 = _mm256_set1_ps(scale * r[5]);
1189
+ __m256 scaled_rotation_z_x_f32x8 = _mm256_set1_ps(scale * r[6]);
1190
+ __m256 scaled_rotation_z_y_f32x8 = _mm256_set1_ps(scale * r[7]);
1191
+ __m256 scaled_rotation_z_z_f32x8 = _mm256_set1_ps(scale * r[8]);
1192
+
1193
+ // Broadcast centroids
1194
+ __m256 centroid_a_x_f32x8 = _mm256_set1_ps(centroid_a_x);
1195
+ __m256 centroid_a_y_f32x8 = _mm256_set1_ps(centroid_a_y);
1196
+ __m256 centroid_a_z_f32x8 = _mm256_set1_ps(centroid_a_z);
1197
+ __m256 centroid_b_x_f32x8 = _mm256_set1_ps(centroid_b_x);
1198
+ __m256 centroid_b_y_f32x8 = _mm256_set1_ps(centroid_b_y);
1199
+ __m256 centroid_b_z_f32x8 = _mm256_set1_ps(centroid_b_z);
1200
+
1201
+ __m256 sum_squared_f32x8 = _mm256_setzero_ps();
1202
+ __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
1203
+ nk_size_t j = 0;
1204
+
1205
+ for (; j + 8 <= n; j += 8) {
1206
+ nk_deinterleave_f16x8_to_f32x8_haswell_(a + j * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
1207
+ nk_deinterleave_f16x8_to_f32x8_haswell_(b + j * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
1208
+
1209
+ // Center points
1210
+ __m256 pa_x_f32x8 = _mm256_sub_ps(a_x_f32x8, centroid_a_x_f32x8);
1211
+ __m256 pa_y_f32x8 = _mm256_sub_ps(a_y_f32x8, centroid_a_y_f32x8);
1212
+ __m256 pa_z_f32x8 = _mm256_sub_ps(a_z_f32x8, centroid_a_z_f32x8);
1213
+ __m256 pb_x_f32x8 = _mm256_sub_ps(b_x_f32x8, centroid_b_x_f32x8);
1214
+ __m256 pb_y_f32x8 = _mm256_sub_ps(b_y_f32x8, centroid_b_y_f32x8);
1215
+ __m256 pb_z_f32x8 = _mm256_sub_ps(b_z_f32x8, centroid_b_z_f32x8);
1216
+
1217
+ // Rotate and scale: ra = scale * R * pa
1218
+ __m256 ra_x_f32x8 = _mm256_fmadd_ps(scaled_rotation_x_z_f32x8, pa_z_f32x8,
1219
+ _mm256_fmadd_ps(scaled_rotation_x_y_f32x8, pa_y_f32x8,
1220
+ _mm256_mul_ps(scaled_rotation_x_x_f32x8, pa_x_f32x8)));
1221
+ __m256 ra_y_f32x8 = _mm256_fmadd_ps(scaled_rotation_y_z_f32x8, pa_z_f32x8,
1222
+ _mm256_fmadd_ps(scaled_rotation_y_y_f32x8, pa_y_f32x8,
1223
+ _mm256_mul_ps(scaled_rotation_y_x_f32x8, pa_x_f32x8)));
1224
+ __m256 ra_z_f32x8 = _mm256_fmadd_ps(scaled_rotation_z_z_f32x8, pa_z_f32x8,
1225
+ _mm256_fmadd_ps(scaled_rotation_z_y_f32x8, pa_y_f32x8,
1226
+ _mm256_mul_ps(scaled_rotation_z_x_f32x8, pa_x_f32x8)));
1227
+
1228
+ // Delta and accumulate
1229
+ __m256 delta_x_f32x8 = _mm256_sub_ps(ra_x_f32x8, pb_x_f32x8);
1230
+ __m256 delta_y_f32x8 = _mm256_sub_ps(ra_y_f32x8, pb_y_f32x8);
1231
+ __m256 delta_z_f32x8 = _mm256_sub_ps(ra_z_f32x8, pb_z_f32x8);
1232
+
1233
+ sum_squared_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_f32x8);
1234
+ sum_squared_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_f32x8);
1235
+ sum_squared_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_f32x8);
1236
+ }
1237
+
1238
+ nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_f32x8);
1239
+
1240
+ // Scalar tail
1241
+ for (; j < n; ++j) {
1242
+ nk_f32_t a_x_f32, a_y_f32, a_z_f32, b_x_f32, b_y_f32, b_z_f32;
1243
+ nk_f16_to_f32_haswell(&a[j * 3 + 0], &a_x_f32);
1244
+ nk_f16_to_f32_haswell(&a[j * 3 + 1], &a_y_f32);
1245
+ nk_f16_to_f32_haswell(&a[j * 3 + 2], &a_z_f32);
1246
+ nk_f16_to_f32_haswell(&b[j * 3 + 0], &b_x_f32);
1247
+ nk_f16_to_f32_haswell(&b[j * 3 + 1], &b_y_f32);
1248
+ nk_f16_to_f32_haswell(&b[j * 3 + 2], &b_z_f32);
1249
+
1250
+ nk_f32_t pa_x = a_x_f32 - centroid_a_x;
1251
+ nk_f32_t pa_y = a_y_f32 - centroid_a_y;
1252
+ nk_f32_t pa_z = a_z_f32 - centroid_a_z;
1253
+ nk_f32_t pb_x = b_x_f32 - centroid_b_x;
1254
+ nk_f32_t pb_y = b_y_f32 - centroid_b_y;
1255
+ nk_f32_t pb_z = b_z_f32 - centroid_b_z;
1256
+
1257
+ nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
1258
+ nk_f32_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
1259
+ nk_f32_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
1260
+
1261
+ nk_f32_t delta_x = ra_x - pb_x;
1262
+ nk_f32_t delta_y = ra_y - pb_y;
1263
+ nk_f32_t delta_z = ra_z - pb_z;
1264
+ sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
1265
+ }
1266
+
1267
+ return sum_squared;
1268
+ }
1269
+
1270
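+ /* Broadcasting `scale * r[k]` in the SSD kernels above and below folds the
+  * similarity transform's scalar into the rotation matrix once at setup, so
+  * each vector iteration computes
+  *   Δᵢ = s·R·(aᵢ − μ_a) − (bᵢ − μ_b),  SSD = Σᵢ |Δᵢ|²
+  * without a per-point multiply by s. The scalar tails apply the
+  * mathematically identical ra = scale · (R · pa) form.
+  */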
+ /* Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
1271
+ * Loads bf16 data, converts to f32 during processing.
1272
+ * Note: the rotation matrix r comes from the f32 SVD; the scale factor and all arithmetic stay in f32.
1273
+ */
1274
+ NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_haswell_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
1275
+ nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
1276
+ nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
1277
+ nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
1278
+ nk_f32_t centroid_b_z) {
1279
+ // Broadcast scaled rotation matrix elements
1280
+ __m256 scaled_rotation_x_x_f32x8 = _mm256_set1_ps(scale * r[0]);
1281
+ __m256 scaled_rotation_x_y_f32x8 = _mm256_set1_ps(scale * r[1]);
1282
+ __m256 scaled_rotation_x_z_f32x8 = _mm256_set1_ps(scale * r[2]);
1283
+ __m256 scaled_rotation_y_x_f32x8 = _mm256_set1_ps(scale * r[3]);
1284
+ __m256 scaled_rotation_y_y_f32x8 = _mm256_set1_ps(scale * r[4]);
1285
+ __m256 scaled_rotation_y_z_f32x8 = _mm256_set1_ps(scale * r[5]);
1286
+ __m256 scaled_rotation_z_x_f32x8 = _mm256_set1_ps(scale * r[6]);
1287
+ __m256 scaled_rotation_z_y_f32x8 = _mm256_set1_ps(scale * r[7]);
1288
+ __m256 scaled_rotation_z_z_f32x8 = _mm256_set1_ps(scale * r[8]);
1289
+
1290
+ // Broadcast centroids
1291
+ __m256 centroid_a_x_f32x8 = _mm256_set1_ps(centroid_a_x);
1292
+ __m256 centroid_a_y_f32x8 = _mm256_set1_ps(centroid_a_y);
1293
+ __m256 centroid_a_z_f32x8 = _mm256_set1_ps(centroid_a_z);
1294
+ __m256 centroid_b_x_f32x8 = _mm256_set1_ps(centroid_b_x);
1295
+ __m256 centroid_b_y_f32x8 = _mm256_set1_ps(centroid_b_y);
1296
+ __m256 centroid_b_z_f32x8 = _mm256_set1_ps(centroid_b_z);
1297
+
1298
+ __m256 sum_squared_f32x8 = _mm256_setzero_ps();
1299
+ __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
1300
+ nk_size_t j = 0;
1301
+
1302
+ for (; j + 8 <= n; j += 8) {
1303
+ nk_deinterleave_bf16x8_to_f32x8_haswell_(a + j * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
1304
+ nk_deinterleave_bf16x8_to_f32x8_haswell_(b + j * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
1305
+
1306
+ // Center points
1307
+ __m256 pa_x_f32x8 = _mm256_sub_ps(a_x_f32x8, centroid_a_x_f32x8);
1308
+ __m256 pa_y_f32x8 = _mm256_sub_ps(a_y_f32x8, centroid_a_y_f32x8);
1309
+ __m256 pa_z_f32x8 = _mm256_sub_ps(a_z_f32x8, centroid_a_z_f32x8);
1310
+ __m256 pb_x_f32x8 = _mm256_sub_ps(b_x_f32x8, centroid_b_x_f32x8);
1311
+ __m256 pb_y_f32x8 = _mm256_sub_ps(b_y_f32x8, centroid_b_y_f32x8);
1312
+ __m256 pb_z_f32x8 = _mm256_sub_ps(b_z_f32x8, centroid_b_z_f32x8);
1313
+
1314
+ // Rotate and scale: ra = scale * R * pa
1315
+ __m256 ra_x_f32x8 = _mm256_fmadd_ps(scaled_rotation_x_z_f32x8, pa_z_f32x8,
1316
+ _mm256_fmadd_ps(scaled_rotation_x_y_f32x8, pa_y_f32x8,
1317
+ _mm256_mul_ps(scaled_rotation_x_x_f32x8, pa_x_f32x8)));
1318
+ __m256 ra_y_f32x8 = _mm256_fmadd_ps(scaled_rotation_y_z_f32x8, pa_z_f32x8,
1319
+ _mm256_fmadd_ps(scaled_rotation_y_y_f32x8, pa_y_f32x8,
1320
+ _mm256_mul_ps(scaled_rotation_y_x_f32x8, pa_x_f32x8)));
1321
+ __m256 ra_z_f32x8 = _mm256_fmadd_ps(scaled_rotation_z_z_f32x8, pa_z_f32x8,
1322
+ _mm256_fmadd_ps(scaled_rotation_z_y_f32x8, pa_y_f32x8,
1323
+ _mm256_mul_ps(scaled_rotation_z_x_f32x8, pa_x_f32x8)));
1324
+
1325
+ // Delta and accumulate
1326
+ __m256 delta_x_f32x8 = _mm256_sub_ps(ra_x_f32x8, pb_x_f32x8);
1327
+ __m256 delta_y_f32x8 = _mm256_sub_ps(ra_y_f32x8, pb_y_f32x8);
1328
+ __m256 delta_z_f32x8 = _mm256_sub_ps(ra_z_f32x8, pb_z_f32x8);
1329
+
1330
+ sum_squared_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_f32x8);
1331
+ sum_squared_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_f32x8);
1332
+ sum_squared_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_f32x8);
1333
+ }
1334
+
1335
+ nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_f32x8);
1336
+
1337
+ // Scalar tail
1338
+ for (; j < n; ++j) {
1339
+ nk_f32_t a_x_f32, a_y_f32, a_z_f32, b_x_f32, b_y_f32, b_z_f32;
1340
+ nk_bf16_to_f32_serial(&a[j * 3 + 0], &a_x_f32);
1341
+ nk_bf16_to_f32_serial(&a[j * 3 + 1], &a_y_f32);
1342
+ nk_bf16_to_f32_serial(&a[j * 3 + 2], &a_z_f32);
1343
+ nk_bf16_to_f32_serial(&b[j * 3 + 0], &b_x_f32);
1344
+ nk_bf16_to_f32_serial(&b[j * 3 + 1], &b_y_f32);
1345
+ nk_bf16_to_f32_serial(&b[j * 3 + 2], &b_z_f32);
1346
+
1347
+ nk_f32_t pa_x = a_x_f32 - centroid_a_x;
1348
+ nk_f32_t pa_y = a_y_f32 - centroid_a_y;
1349
+ nk_f32_t pa_z = a_z_f32 - centroid_a_z;
1350
+ nk_f32_t pb_x = b_x_f32 - centroid_b_x;
1351
+ nk_f32_t pb_y = b_y_f32 - centroid_b_y;
1352
+ nk_f32_t pb_z = b_z_f32 - centroid_b_z;
1353
+
1354
+ nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
1355
+ nk_f32_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
1356
+ nk_f32_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
1357
+
1358
+ nk_f32_t delta_x = ra_x - pb_x;
1359
+ nk_f32_t delta_y = ra_y - pb_y;
1360
+ nk_f32_t delta_z = ra_z - pb_z;
1361
+ sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
1362
+ }
1363
+
1364
+ return sum_squared;
1365
+ }
1366
+
1367
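+ /* The scalar tail above calls nk_bf16_to_f32_serial, the one-lane analogue of
+  * the vector shift. Its body is not in this diff; the standard bit trick,
+  * assuming <string.h> is in scope and nk_bf16_t is a plain 16-bit container
+  * (sketch, hypothetical name):
+  */
+ NK_INTERNAL void nk_bf16_to_f32_sketch_(nk_bf16_t const *src, nk_f32_t *dst) {
+     union { unsigned int u32; nk_f32_t f32; } pun;
+     unsigned short bits;
+     memcpy(&bits, src, sizeof(bits));     // bf16 payload in the low 16 bits
+     pun.u32 = (unsigned int)bits << 16;   // move it into the f32 high half
+     *dst = pun.f32;                       // reinterpret as IEEE f32
+ }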
+ NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1368
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1369
+ /* RMSD uses identity rotation and scale=1.0 */
1370
+ if (rotation) {
1371
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
1372
+ rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
1373
+ rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
1374
+ }
1375
+ if (scale) *scale = 1.0f;
1376
+
1377
+ __m256 const zeros_f32x8 = _mm256_setzero_ps();
1378
+
1379
+ // Accumulators for centroids and squared differences (all in f32)
1380
+ __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
1381
+ __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
1382
+ __m256 sum_squared_x_f32x8 = zeros_f32x8, sum_squared_y_f32x8 = zeros_f32x8, sum_squared_z_f32x8 = zeros_f32x8;
1383
+
1384
+ __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
1385
+ nk_size_t i = 0;
1386
+
1387
+ // Main loop processing 8 points at a time
1388
+ for (; i + 8 <= n; i += 8) {
1389
+ nk_deinterleave_f16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
1390
+ nk_deinterleave_f16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
1391
+
1392
+ sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
1393
+ sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
1394
+ sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
1395
+ sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
1396
+ sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
1397
+ sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
1398
+
1399
+ __m256 delta_x_f32x8 = _mm256_sub_ps(a_x_f32x8, b_x_f32x8);
1400
+ __m256 delta_y_f32x8 = _mm256_sub_ps(a_y_f32x8, b_y_f32x8);
1401
+ __m256 delta_z_f32x8 = _mm256_sub_ps(a_z_f32x8, b_z_f32x8);
1402
+
1403
+ sum_squared_x_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_x_f32x8);
1404
+ sum_squared_y_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_y_f32x8);
1405
+ sum_squared_z_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_z_f32x8);
1406
+ }
1407
+
1408
+ // Reduce vectors to scalars
1409
+ nk_f32_t total_ax = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
1410
+ nk_f32_t total_ay = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
1411
+ nk_f32_t total_az = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
1412
+ nk_f32_t total_bx = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
1413
+ nk_f32_t total_by = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
1414
+ nk_f32_t total_bz = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
1415
+ nk_f32_t total_sq_x = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8);
1416
+ nk_f32_t total_sq_y = nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8);
1417
+ nk_f32_t total_sq_z = nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
1418
+
1419
+ // Scalar tail
1420
+ for (; i < n; ++i) {
1421
+ nk_f32_t ax, ay, az, bx, by, bz;
1422
+ nk_f16_to_f32_haswell(&a[i * 3 + 0], &ax);
1423
+ nk_f16_to_f32_haswell(&a[i * 3 + 1], &ay);
1424
+ nk_f16_to_f32_haswell(&a[i * 3 + 2], &az);
1425
+ nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
1426
+ nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
1427
+ nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
1428
+ total_ax += ax;
1429
+ total_ay += ay;
1430
+ total_az += az;
1431
+ total_bx += bx;
1432
+ total_by += by;
1433
+ total_bz += bz;
1434
+ nk_f32_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
1435
+ total_sq_x += delta_x * delta_x;
1436
+ total_sq_y += delta_y * delta_y;
1437
+ total_sq_z += delta_z * delta_z;
1438
+ }
1439
+
1440
+ // Compute centroids
1441
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1442
+ nk_f32_t centroid_a_x = total_ax * inv_n;
1443
+ nk_f32_t centroid_a_y = total_ay * inv_n;
1444
+ nk_f32_t centroid_a_z = total_az * inv_n;
1445
+ nk_f32_t centroid_b_x = total_bx * inv_n;
1446
+ nk_f32_t centroid_b_y = total_by * inv_n;
1447
+ nk_f32_t centroid_b_z = total_bz * inv_n;
1448
+
1449
+ if (a_centroid) {
1450
+ a_centroid[0] = centroid_a_x;
1451
+ a_centroid[1] = centroid_a_y;
1452
+ a_centroid[2] = centroid_a_z;
1453
+ }
1454
+ if (b_centroid) {
1455
+ b_centroid[0] = centroid_b_x;
1456
+ b_centroid[1] = centroid_b_y;
1457
+ b_centroid[2] = centroid_b_z;
1458
+ }
1459
+
1460
+ // Compute RMSD
1461
+ nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
1462
+ nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
1463
+ nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
1464
+ nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
1465
+ nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
1466
+
1467
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
1468
+ }
1469
+
1470
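+ /* The RMSD expression above (and in the bf16 variant below) relies on
+  * expanding the centered differences:
+  *   (1/n)·Σᵢ |(aᵢ − μ_a) − (bᵢ − μ_b)|² = (1/n)·Σᵢ |aᵢ − bᵢ|² − |μ_a − μ_b|²
+  * because the cross term 2·(μ_a − μ_b)·Σᵢ(aᵢ − bᵢ) equals 2n·|μ_a − μ_b|²,
+  * leaving a single −n·|μ_a − μ_b|². That is exactly
+  * `sum_squared * inv_n - mean_diff_sq` — one fused pass, no second traversal
+  * to re-center the points.
+  */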
+ NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1471
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1472
+ /* RMSD uses identity rotation and scale=1.0 */
1473
+ if (rotation) {
1474
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
1475
+ rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
1476
+ rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
1477
+ }
1478
+ if (scale) *scale = 1.0f;
1479
+
1480
+ __m256 const zeros_f32x8 = _mm256_setzero_ps();
1481
+
1482
+ // Accumulators for centroids and squared differences (all in f32)
1483
+ __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
1484
+ __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
1485
+ __m256 sum_squared_x_f32x8 = zeros_f32x8, sum_squared_y_f32x8 = zeros_f32x8, sum_squared_z_f32x8 = zeros_f32x8;
1486
+
1487
+ __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
1488
+ nk_size_t i = 0;
1489
+
1490
+ // Main loop processing 8 points at a time
1491
+ for (; i + 8 <= n; i += 8) {
1492
+ nk_deinterleave_bf16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
1493
+ nk_deinterleave_bf16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
1494
+
1495
+ sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
1496
+ sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
1497
+ sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
1498
+ sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
1499
+ sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
1500
+ sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
1501
+
1502
+ __m256 delta_x_f32x8 = _mm256_sub_ps(a_x_f32x8, b_x_f32x8);
1503
+ __m256 delta_y_f32x8 = _mm256_sub_ps(a_y_f32x8, b_y_f32x8);
1504
+ __m256 delta_z_f32x8 = _mm256_sub_ps(a_z_f32x8, b_z_f32x8);
1505
+
1506
+ sum_squared_x_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_x_f32x8);
1507
+ sum_squared_y_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_y_f32x8);
1508
+ sum_squared_z_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_z_f32x8);
1509
+ }
1510
+
1511
+ // Reduce vectors to scalars
1512
+ nk_f32_t total_ax = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
1513
+ nk_f32_t total_ay = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
1514
+ nk_f32_t total_az = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
1515
+ nk_f32_t total_bx = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
1516
+ nk_f32_t total_by = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
1517
+ nk_f32_t total_bz = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
1518
+ nk_f32_t total_sq_x = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8);
1519
+ nk_f32_t total_sq_y = nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8);
1520
+ nk_f32_t total_sq_z = nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
1521
+
1522
+ // Scalar tail
1523
+ for (; i < n; ++i) {
1524
+ nk_f32_t ax, ay, az, bx, by, bz;
1525
+ nk_bf16_to_f32_serial(&a[i * 3 + 0], &ax);
1526
+ nk_bf16_to_f32_serial(&a[i * 3 + 1], &ay);
1527
+ nk_bf16_to_f32_serial(&a[i * 3 + 2], &az);
1528
+ nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
1529
+ nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
1530
+ nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
1531
+ total_ax += ax;
1532
+ total_ay += ay;
1533
+ total_az += az;
1534
+ total_bx += bx;
1535
+ total_by += by;
1536
+ total_bz += bz;
1537
+ nk_f32_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
1538
+ total_sq_x += delta_x * delta_x;
1539
+ total_sq_y += delta_y * delta_y;
1540
+ total_sq_z += delta_z * delta_z;
1541
+ }
1542
+
1543
+ // Compute centroids
1544
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1545
+ nk_f32_t centroid_a_x = total_ax * inv_n;
1546
+ nk_f32_t centroid_a_y = total_ay * inv_n;
1547
+ nk_f32_t centroid_a_z = total_az * inv_n;
1548
+ nk_f32_t centroid_b_x = total_bx * inv_n;
1549
+ nk_f32_t centroid_b_y = total_by * inv_n;
1550
+ nk_f32_t centroid_b_z = total_bz * inv_n;
1551
+
1552
+ if (a_centroid) {
1553
+ a_centroid[0] = centroid_a_x;
1554
+ a_centroid[1] = centroid_a_y;
1555
+ a_centroid[2] = centroid_a_z;
1556
+ }
1557
+ if (b_centroid) {
1558
+ b_centroid[0] = centroid_b_x;
1559
+ b_centroid[1] = centroid_b_y;
1560
+ b_centroid[2] = centroid_b_z;
1561
+ }
1562
+
1563
+ // Compute RMSD
1564
+ nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
1565
+ nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
1566
+ nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
1567
+ nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
1568
+ nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
1569
+
1570
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
1571
+ }
1572
+
1573
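+ /* Reflection guard used by the Kabsch and Umeyama variants above and below:
+  * when det(V·Uᵀ) < 0 the least-squares optimum over O(3) is an improper
+  * rotation (a reflection). The nearest proper rotation is obtained by
+  * flipping the column of V paired with the smallest singular value —
+  * assuming the SVD returns singular values in descending order, as the scale
+  * formulas above do, that is the third column, i.e. svd_v[2], svd_v[5],
+  * svd_v[8] in row-major layout — and then recomputing R = V·Uᵀ.
+  */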
+ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1574
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1575
+ // Fused single-pass: load f16, convert to f32, compute centroids and covariance
1576
+ __m256 const zeros_f32x8 = _mm256_setzero_ps();
1577
+
1578
+ // Accumulators for centroids (f32)
1579
+ __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
1580
+ __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
1581
+
1582
+ // Accumulators for covariance matrix (sum of outer products)
1583
+ __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
1584
+ __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
1585
+ __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
1586
+
1587
+ nk_size_t i = 0;
1588
+ __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
1589
+
1590
+ for (; i + 8 <= n; i += 8) {
1591
+ nk_deinterleave_f16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
1592
+ nk_deinterleave_f16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
1593
+
1594
+ // Accumulate centroids
1595
+ sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
1596
+ sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
1597
+ sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
1598
+ sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
1599
+ sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
1600
+ sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
1601
+
1602
+ // Accumulate outer products
1603
+ cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
1604
+ cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
1605
+ cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
1606
+ cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
1607
+ cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
1608
+ cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
1609
+ cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
1610
+ cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
1611
+ cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);
1612
+ }
1613
+
1614
+ // Reduce vector accumulators
1615
+ nk_f32_t sum_a_x = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
1616
+ nk_f32_t sum_a_y = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
1617
+ nk_f32_t sum_a_z = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
1618
+ nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
1619
+ nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
1620
+ nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
1621
+
1622
+ nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
1623
+ nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
1624
+ nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
1625
+ nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
1626
+ nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
1627
+ nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
1628
+ nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
1629
+ nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
1630
+ nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
1631
+
1632
+ // Scalar tail
1633
+ for (; i < n; ++i) {
1634
+ nk_f32_t ax, ay, az, bx, by, bz;
1635
+ nk_f16_to_f32_haswell(&a[i * 3 + 0], &ax);
1636
+ nk_f16_to_f32_haswell(&a[i * 3 + 1], &ay);
1637
+ nk_f16_to_f32_haswell(&a[i * 3 + 2], &az);
1638
+ nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
1639
+ nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
1640
+ nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
1641
+ sum_a_x += ax;
1642
+ sum_a_y += ay;
1643
+ sum_a_z += az;
1644
+ sum_b_x += bx;
1645
+ sum_b_y += by;
1646
+ sum_b_z += bz;
1647
+ covariance_x_x += ax * bx;
1648
+ covariance_x_y += ax * by;
1649
+ covariance_x_z += ax * bz;
1650
+ covariance_y_x += ay * bx;
1651
+ covariance_y_y += ay * by;
1652
+ covariance_y_z += ay * bz;
1653
+ covariance_z_x += az * bx;
1654
+ covariance_z_y += az * by;
1655
+ covariance_z_z += az * bz;
1656
+ }
1657
+
1658
+ // Compute centroids
1659
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1660
+ nk_f32_t centroid_a_x = sum_a_x * inv_n;
1661
+ nk_f32_t centroid_a_y = sum_a_y * inv_n;
1662
+ nk_f32_t centroid_a_z = sum_a_z * inv_n;
1663
+ nk_f32_t centroid_b_x = sum_b_x * inv_n;
1664
+ nk_f32_t centroid_b_y = sum_b_y * inv_n;
1665
+ nk_f32_t centroid_b_z = sum_b_z * inv_n;
1666
+
1667
+ if (a_centroid) {
1668
+ a_centroid[0] = centroid_a_x;
1669
+ a_centroid[1] = centroid_a_y;
1670
+ a_centroid[2] = centroid_a_z;
1671
+ }
1672
+ if (b_centroid) {
1673
+ b_centroid[0] = centroid_b_x;
1674
+ b_centroid[1] = centroid_b_y;
1675
+ b_centroid[2] = centroid_b_z;
1676
+ }
1677
+
1678
+ // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
1679
+ covariance_x_x -= n * centroid_a_x * centroid_b_x;
1680
+ covariance_x_y -= n * centroid_a_x * centroid_b_y;
1681
+ covariance_x_z -= n * centroid_a_x * centroid_b_z;
1682
+ covariance_y_x -= n * centroid_a_y * centroid_b_x;
1683
+ covariance_y_y -= n * centroid_a_y * centroid_b_y;
1684
+ covariance_y_z -= n * centroid_a_y * centroid_b_z;
1685
+ covariance_z_x -= n * centroid_a_z * centroid_b_x;
1686
+ covariance_z_y -= n * centroid_a_z * centroid_b_y;
1687
+ covariance_z_z -= n * centroid_a_z * centroid_b_z;
+
+ // Compute SVD and optimal rotation
+ nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
+ covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
+ nk_f32_t svd_u[9], svd_s[9], svd_v[9];
+ nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
+
+ // R = V * Uᵀ
+ nk_f32_t r[9];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ // Handle reflection: if det(R) < 0, negate third column of V and recompute R
+ if (nk_det3x3_f32_(r) < 0) {
+ svd_v[2] = -svd_v[2];
+ svd_v[5] = -svd_v[5];
+ svd_v[8] = -svd_v[8];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ }
+
+ // Output rotation matrix and a fixed scale of 1.0 (Kabsch is a rigid alignment)
+ if (rotation) {
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ }
+ if (scale) *scale = 1.0f;
+
+ // Compute RMSD after optimal rotation
+ nk_f32_t sum_squared = nk_transformed_ssd_f16_haswell_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
+ centroid_b_x, centroid_b_y, centroid_b_z);
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
+ }
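+
+ // Illustrative call sketch, not part of the header: aligning two
+ // x,y,z-interleaved f16 clouds of n points each (assuming the f16 variant
+ // above shares the signature of nk_kabsch_bf16_haswell below). Every output
+ // except `result` is optional and may be NULL, mirroring the guards above;
+ // `result` is written unconditionally.
+ //
+ //   nk_f32_t rotation[9], rmsd;
+ //   nk_kabsch_f16_haswell(a, b, n, /*a_centroid=*/NULL, /*b_centroid=*/NULL,
+ //                         rotation, /*scale=*/NULL, &rmsd);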
+
+ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
+ // Fused single-pass: load bf16, convert to f32, compute centroids and covariance
+ __m256 const zeros_f32x8 = _mm256_setzero_ps();
+
+ // Accumulators for centroids (f32)
+ __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
+ __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
+
+ // Accumulators for covariance matrix (sum of outer products)
+ __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
+ __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
+ __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
+
+ nk_size_t i = 0;
+ __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
+
+ for (; i + 8 <= n; i += 8) {
+ nk_deinterleave_bf16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
+ nk_deinterleave_bf16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
+
+ // Accumulate centroids
+ sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
+ sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
+ sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
+ sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
+ sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
+ sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
+
+ // Accumulate outer products
+ cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
+ cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
+ cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
+ cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
+ cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
+ cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
+ cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
+ cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
+ cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);
+ }
+
+ // Reduce vector accumulators
+ nk_f32_t sum_a_x = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
+ nk_f32_t sum_a_y = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
+ nk_f32_t sum_a_z = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
+ nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
+ nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
+ nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
+
+ nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
+ nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
+ nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
+ nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
+ nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
+ nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
+ nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
+ nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
+ nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
+
+ // Scalar tail
+ for (; i < n; ++i) {
+ nk_f32_t ax, ay, az, bx, by, bz;
+ nk_bf16_to_f32_serial(&a[i * 3 + 0], &ax);
+ nk_bf16_to_f32_serial(&a[i * 3 + 1], &ay);
+ nk_bf16_to_f32_serial(&a[i * 3 + 2], &az);
+ nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
+ nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
+ nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
+ sum_a_x += ax;
+ sum_a_y += ay;
+ sum_a_z += az;
+ sum_b_x += bx;
+ sum_b_y += by;
+ sum_b_z += bz;
+ covariance_x_x += ax * bx;
+ covariance_x_y += ax * by;
+ covariance_x_z += ax * bz;
+ covariance_y_x += ay * bx;
+ covariance_y_y += ay * by;
+ covariance_y_z += ay * bz;
+ covariance_z_x += az * bx;
+ covariance_z_y += az * by;
+ covariance_z_z += az * bz;
+ }
+
+ // Compute centroids
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
+ nk_f32_t centroid_a_x = sum_a_x * inv_n;
+ nk_f32_t centroid_a_y = sum_a_y * inv_n;
+ nk_f32_t centroid_a_z = sum_a_z * inv_n;
+ nk_f32_t centroid_b_x = sum_b_x * inv_n;
+ nk_f32_t centroid_b_y = sum_b_y * inv_n;
+ nk_f32_t centroid_b_z = sum_b_z * inv_n;
+
+ if (a_centroid) {
+ a_centroid[0] = centroid_a_x;
+ a_centroid[1] = centroid_a_y;
+ a_centroid[2] = centroid_a_z;
+ }
+ if (b_centroid) {
+ b_centroid[0] = centroid_b_x;
+ b_centroid[1] = centroid_b_y;
+ b_centroid[2] = centroid_b_z;
+ }
+
+ // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
+ covariance_x_x -= n * centroid_a_x * centroid_b_x;
+ covariance_x_y -= n * centroid_a_x * centroid_b_y;
+ covariance_x_z -= n * centroid_a_x * centroid_b_z;
+ covariance_y_x -= n * centroid_a_y * centroid_b_x;
+ covariance_y_y -= n * centroid_a_y * centroid_b_y;
+ covariance_y_z -= n * centroid_a_y * centroid_b_z;
+ covariance_z_x -= n * centroid_a_z * centroid_b_x;
+ covariance_z_y -= n * centroid_a_z * centroid_b_y;
+ covariance_z_z -= n * centroid_a_z * centroid_b_z;
+
+ // Compute SVD and optimal rotation
+ nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
+ covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
+ nk_f32_t svd_u[9], svd_s[9], svd_v[9];
+ nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
+
+ // R = V * Uᵀ
+ nk_f32_t r[9];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ // Handle reflection: if det(R) < 0, negate third column of V and recompute R
+ if (nk_det3x3_f32_(r) < 0) {
+ svd_v[2] = -svd_v[2];
+ svd_v[5] = -svd_v[5];
+ svd_v[8] = -svd_v[8];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ }
+
+ // Output rotation matrix and a fixed scale of 1.0 (Kabsch is a rigid alignment)
+ if (rotation) {
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ }
+ if (scale) *scale = 1.0f;
+
+ // Compute RMSD after optimal rotation
+ nk_f32_t sum_squared = nk_transformed_ssd_bf16_haswell_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
+ centroid_b_x, centroid_b_y, centroid_b_z);
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
+ }
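+
+ // Aside: the scalar tail above can fall back to a serial conversion because
+ // bf16 is simply the top 16 bits of an IEEE-754 f32, so widening is a shift.
+ // A minimal sketch of the idea (illustrative only; the library's
+ // nk_bf16_to_f32_serial may differ in details such as NaN handling):
+ //
+ //   static inline float example_bf16_widen_(unsigned short raw_bf16) {
+ //       union { unsigned int u32; float f32; } bits;
+ //       bits.u32 = (unsigned int)raw_bf16 << 16; // bf16 bits become the f32 high half
+ //       return bits.f32;
+ //   }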
+
+ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
+ // Fused single-pass: load f16, convert to f32, compute centroids, covariance, and variance
+ __m256 const zeros_f32x8 = _mm256_setzero_ps();
+
+ __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
+ __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
+ __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
+ __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
+ __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
+ __m256 variance_a_f32x8 = zeros_f32x8;
+
+ nk_size_t i = 0;
+ __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
+
+ for (; i + 8 <= n; i += 8) {
+ nk_deinterleave_f16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
+ nk_deinterleave_f16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
+
+ // Accumulate centroids
+ sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
+ sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
+ sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
+ sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
+ sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
+ sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
+
+ // Accumulate outer products
+ cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
+ cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
+ cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
+ cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
+ cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
+ cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
+ cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
+ cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
+ cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);
+
+ // Accumulate variance of A
+ variance_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, variance_a_f32x8);
+ variance_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, variance_a_f32x8);
+ variance_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, variance_a_f32x8);
+ }
+
+ // Reduce vector accumulators
+ nk_f32_t sum_a_x = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
+ nk_f32_t sum_a_y = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
+ nk_f32_t sum_a_z = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
+ nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
+ nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
+ nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
+ nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
+ nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
+ nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
+ nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
+ nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
+ nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
+ nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
+ nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
+ nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
+ nk_f32_t variance_a_sum = nk_reduce_add_f32x8_haswell_(variance_a_f32x8);
+
+ // Scalar tail
+ for (; i < n; ++i) {
+ nk_f32_t ax, ay, az, bx, by, bz;
+ nk_f16_to_f32_haswell(&a[i * 3 + 0], &ax);
+ nk_f16_to_f32_haswell(&a[i * 3 + 1], &ay);
+ nk_f16_to_f32_haswell(&a[i * 3 + 2], &az);
+ nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
+ nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
+ nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
+ sum_a_x += ax;
+ sum_a_y += ay;
+ sum_a_z += az;
+ sum_b_x += bx;
+ sum_b_y += by;
+ sum_b_z += bz;
+ covariance_x_x += ax * bx;
+ covariance_x_y += ax * by;
+ covariance_x_z += ax * bz;
+ covariance_y_x += ay * bx;
+ covariance_y_y += ay * by;
+ covariance_y_z += ay * bz;
+ covariance_z_x += az * bx;
+ covariance_z_y += az * by;
+ covariance_z_z += az * bz;
+ variance_a_sum += ax * ax + ay * ay + az * az;
+ }
+
+ // Compute centroids
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
+ nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
+ nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
+
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
+
+ // Compute centered covariance and variance
+ nk_f32_t variance_a = variance_a_sum * inv_n -
+ (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
+
+ // Apply centering correction to covariance matrix
+ covariance_x_x -= n * centroid_a_x * centroid_b_x;
+ covariance_x_y -= n * centroid_a_x * centroid_b_y;
+ covariance_x_z -= n * centroid_a_x * centroid_b_z;
+ covariance_y_x -= n * centroid_a_y * centroid_b_x;
+ covariance_y_y -= n * centroid_a_y * centroid_b_y;
+ covariance_y_z -= n * centroid_a_y * centroid_b_z;
+ covariance_z_x -= n * centroid_a_z * centroid_b_x;
+ covariance_z_y -= n * centroid_a_z * centroid_b_y;
+ covariance_z_z -= n * centroid_a_z * centroid_b_z;
+
+ nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
+ covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
+
+ // SVD
+ nk_f32_t svd_u[9], svd_s[9], svd_v[9];
+ nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
+
+ // R = V * Uᵀ
+ nk_f32_t r[9];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ // Scale factor: c = trace(D × S) / (n × variance(a))
+ nk_f32_t det = nk_det3x3_f32_(r);
+ nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
+ nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
+ nk_f32_t c = trace_ds / (n * variance_a);
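+ // Bookkeeping note: trace_ds is built from the unnormalized covariance sums
+ // (no 1/n was applied before the SVD), while variance_a is already a
+ // per-point quantity; the extra factor of n in the denominator reconciles
+ // the two, matching Umeyama's c = tr(D·S) / (n·σ²_a).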
+ if (scale) *scale = c;
+
+ // Handle reflection
+ if (det < 0) {
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ }
+
+ // Output rotation matrix
+ if (rotation) {
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ }
+
+ // Compute RMSD with scaling
+ nk_f32_t sum_squared = nk_transformed_ssd_f16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
+ centroid_b_x, centroid_b_y, centroid_b_z);
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
+ }
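+
+ // Aside: every reduction above funnels through nk_reduce_add_f32x8_haswell_.
+ // One common AVX horizontal-add pattern, as an illustrative sketch only (the
+ // library's actual helper may be implemented differently):
+ //
+ //   static inline float example_reduce_add_f32x8_(__m256 v) {
+ //       __m128 lo = _mm256_castps256_ps128(v);              // lanes 0..3
+ //       __m128 hi = _mm256_extractf128_ps(v, 1);            // lanes 4..7
+ //       __m128 sum = _mm_add_ps(lo, hi);                    // four partial sums
+ //       sum = _mm_add_ps(sum, _mm_movehl_ps(sum, sum));     // fold 4 -> 2
+ //       sum = _mm_add_ss(sum, _mm_shuffle_ps(sum, sum, 1)); // fold 2 -> 1
+ //       return _mm_cvtss_f32(sum);
+ //   }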
+
+ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
+ // Fused single-pass: load bf16, convert to f32, compute centroids, covariance, and variance
+ __m256 const zeros_f32x8 = _mm256_setzero_ps();
+
+ __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
+ __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
+ __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
+ __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
+ __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
+ __m256 variance_a_f32x8 = zeros_f32x8;
+
+ nk_size_t i = 0;
+ __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
+
+ for (; i + 8 <= n; i += 8) {
+ nk_deinterleave_bf16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
+ nk_deinterleave_bf16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
+
+ // Accumulate centroids
+ sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
+ sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
+ sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
+ sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
+ sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
+ sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
+
+ // Accumulate outer products
+ cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
+ cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
+ cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
+ cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
+ cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
+ cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
+ cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
+ cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
+ cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);
+
+ // Accumulate variance of A
+ variance_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, variance_a_f32x8);
+ variance_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, variance_a_f32x8);
+ variance_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, variance_a_f32x8);
+ }
+
+ // Reduce vector accumulators
+ nk_f32_t sum_a_x = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
+ nk_f32_t sum_a_y = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
+ nk_f32_t sum_a_z = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
+ nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
+ nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
+ nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
+ nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
+ nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
+ nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
+ nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
+ nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
+ nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
+ nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
+ nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
+ nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
+ nk_f32_t variance_a_sum = nk_reduce_add_f32x8_haswell_(variance_a_f32x8);
+
+ // Scalar tail
+ for (; i < n; ++i) {
+ nk_f32_t ax, ay, az, bx, by, bz;
+ nk_bf16_to_f32_serial(&a[i * 3 + 0], &ax);
+ nk_bf16_to_f32_serial(&a[i * 3 + 1], &ay);
+ nk_bf16_to_f32_serial(&a[i * 3 + 2], &az);
+ nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
+ nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
+ nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
+ sum_a_x += ax;
+ sum_a_y += ay;
+ sum_a_z += az;
+ sum_b_x += bx;
+ sum_b_y += by;
+ sum_b_z += bz;
+ covariance_x_x += ax * bx;
+ covariance_x_y += ax * by;
+ covariance_x_z += ax * bz;
+ covariance_y_x += ay * bx;
+ covariance_y_y += ay * by;
+ covariance_y_z += ay * bz;
+ covariance_z_x += az * bx;
+ covariance_z_y += az * by;
+ covariance_z_z += az * bz;
+ variance_a_sum += ax * ax + ay * ay + az * az;
+ }
+
+ // Compute centroids
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
+ nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
+ nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
+
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
+
+ // Compute centered covariance and variance
+ nk_f32_t variance_a = variance_a_sum * inv_n -
+ (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
+
+ // Apply centering correction to covariance matrix
+ covariance_x_x -= n * centroid_a_x * centroid_b_x;
+ covariance_x_y -= n * centroid_a_x * centroid_b_y;
+ covariance_x_z -= n * centroid_a_x * centroid_b_z;
+ covariance_y_x -= n * centroid_a_y * centroid_b_x;
+ covariance_y_y -= n * centroid_a_y * centroid_b_y;
+ covariance_y_z -= n * centroid_a_y * centroid_b_z;
+ covariance_z_x -= n * centroid_a_z * centroid_b_x;
+ covariance_z_y -= n * centroid_a_z * centroid_b_y;
+ covariance_z_z -= n * centroid_a_z * centroid_b_z;
+
+ nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
+ covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
+
+ // SVD
+ nk_f32_t svd_u[9], svd_s[9], svd_v[9];
+ nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
+
+ // R = V * Uᵀ
+ nk_f32_t r[9];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ // Scale factor: c = trace(D × S) / (n × variance(a))
+ nk_f32_t det = nk_det3x3_f32_(r);
+ nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
+ nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
+ nk_f32_t c = trace_ds / (n * variance_a);
+ if (scale) *scale = c;
+
+ // Handle reflection
+ if (det < 0) {
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ }
+
+ // Output rotation matrix
+ if (rotation) {
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ }
+
+ // Compute RMSD with scaling
+ nk_f32_t sum_squared = nk_transformed_ssd_bf16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
+ centroid_b_x, centroid_b_y, centroid_b_z);
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
+ }
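+
+ // Taken together, the Umeyama outputs parameterize a similarity transform
+ // mapping cloud A onto cloud B in the usual form
+ //   b ≈ c · R · (a - centroid_a) + centroid_b,
+ // and `result` is the RMSD of B against A under that transform, matching
+ // the arguments forwarded to nk_transformed_ssd_bf16_haswell_ above.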
+
+ #if defined(__clang__)
+ #pragma clang attribute pop
+ #elif defined(__GNUC__)
+ #pragma GCC pop_options
+ #endif
+
+ #if defined(__cplusplus)
+ } // extern "C"
+ #endif
+
+ #endif // NK_TARGET_HASWELL
+ #endif // NK_TARGET_X86_
+ #endif // NK_MESH_HASWELL_H