numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,12 +8,12 @@
8
8
  *
9
9
  * @section haswell_mesh_instructions Key AVX2 Mesh Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm256_fmadd_ps VFMADD (YMM, YMM, YMM) 5cy 0.5/cy p01
13
- * _mm256_hadd_ps VHADDPS (YMM, YMM, YMM) 7cy 0.5/cy p01+p5
14
- * _mm256_permute2f128_ps VPERM2F128 (YMM, YMM, YMM, I8) 3cy 1/cy p5
15
- * _mm256_extractf128_ps VEXTRACTF128 (XMM, YMM, I8) 3cy 1/cy p5
16
- * _mm256_i32gather_ps VGATHERDPS (YMM, M, YMM, YMM) 12cy 5/cy p0+p23
11
+ * Intrinsic Instruction Haswell Genoa
12
+ * _mm256_fmadd_ps VFMADD (YMM, YMM, YMM) 5cy @ p01 4cy @ p01
13
+ * _mm256_hadd_ps VHADDPS (YMM, YMM, YMM) 7cy @ p1+p5 4cy @ p123+p23+p23
14
+ * _mm256_permute2f128_ps VPERM2F128 (YMM, YMM, YMM, I8) 3cy @ p5 2cy @ p12
15
+ * _mm256_extractf128_ps VEXTRACTF128 (XMM, YMM, I8) 3cy @ p5 1cy @ p0123
16
+ * _mm256_i32gather_ps VGATHERDPS (YMM, M, YMM, YMM) 22cy (34 uops) 19cy (17 uops)
17
17
  *
18
18
  * Point cloud operations (centroid, covariance, Kabsch alignment) use gather instructions for
19
19
  * stride-3 xyz deinterleaving. Multiple FMA accumulators hide the 5-cycle FMA latency. VHADDPS
@@ -50,10 +50,10 @@ extern "C" {
50
50
  */
51
51
  NK_INTERNAL void nk_deinterleave_f32x8_haswell_(nk_f32_t const *ptr, __m256 *x_out, __m256 *y_out, __m256 *z_out) {
52
52
  // Gather indices: 0, 3, 6, 9, 12, 15, 18, 21 (stride 3)
53
- __m256i idx = _mm256_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21);
54
- *x_out = _mm256_i32gather_ps(ptr + 0, idx, 4);
55
- *y_out = _mm256_i32gather_ps(ptr + 1, idx, 4);
56
- *z_out = _mm256_i32gather_ps(ptr + 2, idx, 4);
53
+ __m256i idx_i32x8 = _mm256_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21);
54
+ *x_out = _mm256_i32gather_ps(ptr + 0, idx_i32x8, 4);
55
+ *y_out = _mm256_i32gather_ps(ptr + 1, idx_i32x8, 4);
56
+ *z_out = _mm256_i32gather_ps(ptr + 2, idx_i32x8, 4);
57
57
  }
58
58
 
59
59
  /* Deinterleave 12 f64 values (4 xyz triplets) into separate x, y, z vectors.
@@ -134,84 +134,84 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_haswell_(nk_f32_t const *a, nk_f32_t
134
134
  nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
135
135
  nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
136
136
 
137
- __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
138
- __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
139
- __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
140
- __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
141
- __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
142
- __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
143
- __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
144
- __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
145
- __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
146
- __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
147
- __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
148
- __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
149
-
150
- __m256d centered_a_x_lower_f64x4 = _mm256_sub_pd(a_x_lower_f64x4, centroid_a_x_f64x4);
151
- __m256d centered_a_x_upper_f64x4 = _mm256_sub_pd(a_x_upper_f64x4, centroid_a_x_f64x4);
152
- __m256d centered_a_y_lower_f64x4 = _mm256_sub_pd(a_y_lower_f64x4, centroid_a_y_f64x4);
153
- __m256d centered_a_y_upper_f64x4 = _mm256_sub_pd(a_y_upper_f64x4, centroid_a_y_f64x4);
154
- __m256d centered_a_z_lower_f64x4 = _mm256_sub_pd(a_z_lower_f64x4, centroid_a_z_f64x4);
155
- __m256d centered_a_z_upper_f64x4 = _mm256_sub_pd(a_z_upper_f64x4, centroid_a_z_f64x4);
156
- __m256d centered_b_x_lower_f64x4 = _mm256_sub_pd(b_x_lower_f64x4, centroid_b_x_f64x4);
157
- __m256d centered_b_x_upper_f64x4 = _mm256_sub_pd(b_x_upper_f64x4, centroid_b_x_f64x4);
158
- __m256d centered_b_y_lower_f64x4 = _mm256_sub_pd(b_y_lower_f64x4, centroid_b_y_f64x4);
159
- __m256d centered_b_y_upper_f64x4 = _mm256_sub_pd(b_y_upper_f64x4, centroid_b_y_f64x4);
160
- __m256d centered_b_z_lower_f64x4 = _mm256_sub_pd(b_z_lower_f64x4, centroid_b_z_f64x4);
161
- __m256d centered_b_z_upper_f64x4 = _mm256_sub_pd(b_z_upper_f64x4, centroid_b_z_f64x4);
162
-
163
- __m256d rotated_a_x_lower_f64x4 = _mm256_fmadd_pd(
164
- scaled_rotation_x_z_f64x4, centered_a_z_lower_f64x4,
165
- _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_lower_f64x4,
166
- _mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_lower_f64x4)));
167
- __m256d rotated_a_x_upper_f64x4 = _mm256_fmadd_pd(
168
- scaled_rotation_x_z_f64x4, centered_a_z_upper_f64x4,
169
- _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_upper_f64x4,
170
- _mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_upper_f64x4)));
171
- __m256d rotated_a_y_lower_f64x4 = _mm256_fmadd_pd(
172
- scaled_rotation_y_z_f64x4, centered_a_z_lower_f64x4,
173
- _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_lower_f64x4,
174
- _mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_lower_f64x4)));
175
- __m256d rotated_a_y_upper_f64x4 = _mm256_fmadd_pd(
176
- scaled_rotation_y_z_f64x4, centered_a_z_upper_f64x4,
177
- _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_upper_f64x4,
178
- _mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_upper_f64x4)));
179
- __m256d rotated_a_z_lower_f64x4 = _mm256_fmadd_pd(
180
- scaled_rotation_z_z_f64x4, centered_a_z_lower_f64x4,
181
- _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_lower_f64x4,
182
- _mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_lower_f64x4)));
183
- __m256d rotated_a_z_upper_f64x4 = _mm256_fmadd_pd(
184
- scaled_rotation_z_z_f64x4, centered_a_z_upper_f64x4,
185
- _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_upper_f64x4,
186
- _mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_upper_f64x4)));
187
-
188
- __m256d delta_x_lower_f64x4 = _mm256_sub_pd(rotated_a_x_lower_f64x4, centered_b_x_lower_f64x4);
189
- __m256d delta_x_upper_f64x4 = _mm256_sub_pd(rotated_a_x_upper_f64x4, centered_b_x_upper_f64x4);
190
- __m256d delta_y_lower_f64x4 = _mm256_sub_pd(rotated_a_y_lower_f64x4, centered_b_y_lower_f64x4);
191
- __m256d delta_y_upper_f64x4 = _mm256_sub_pd(rotated_a_y_upper_f64x4, centered_b_y_upper_f64x4);
192
- __m256d delta_z_lower_f64x4 = _mm256_sub_pd(rotated_a_z_lower_f64x4, centered_b_z_lower_f64x4);
193
- __m256d delta_z_upper_f64x4 = _mm256_sub_pd(rotated_a_z_upper_f64x4, centered_b_z_upper_f64x4);
194
-
195
- __m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_lower_f64x4, delta_x_lower_f64x4),
196
- _mm256_mul_pd(delta_x_upper_f64x4, delta_x_upper_f64x4));
197
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_lower_f64x4, delta_y_lower_f64x4, batch_sum_squared_f64x4);
198
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_upper_f64x4, delta_y_upper_f64x4, batch_sum_squared_f64x4);
199
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_lower_f64x4, delta_z_lower_f64x4, batch_sum_squared_f64x4);
200
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_upper_f64x4, delta_z_upper_f64x4, batch_sum_squared_f64x4);
137
+ __m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
138
+ __m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
139
+ __m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
140
+ __m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
141
+ __m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
142
+ __m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
143
+ __m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
144
+ __m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
145
+ __m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
146
+ __m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
147
+ __m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
148
+ __m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
149
+
150
+ __m256d centered_a_x_low_f64x4 = _mm256_sub_pd(a_x_low_f64x4, centroid_a_x_f64x4);
151
+ __m256d centered_a_x_high_f64x4 = _mm256_sub_pd(a_x_high_f64x4, centroid_a_x_f64x4);
152
+ __m256d centered_a_y_low_f64x4 = _mm256_sub_pd(a_y_low_f64x4, centroid_a_y_f64x4);
153
+ __m256d centered_a_y_high_f64x4 = _mm256_sub_pd(a_y_high_f64x4, centroid_a_y_f64x4);
154
+ __m256d centered_a_z_low_f64x4 = _mm256_sub_pd(a_z_low_f64x4, centroid_a_z_f64x4);
155
+ __m256d centered_a_z_high_f64x4 = _mm256_sub_pd(a_z_high_f64x4, centroid_a_z_f64x4);
156
+ __m256d centered_b_x_low_f64x4 = _mm256_sub_pd(b_x_low_f64x4, centroid_b_x_f64x4);
157
+ __m256d centered_b_x_high_f64x4 = _mm256_sub_pd(b_x_high_f64x4, centroid_b_x_f64x4);
158
+ __m256d centered_b_y_low_f64x4 = _mm256_sub_pd(b_y_low_f64x4, centroid_b_y_f64x4);
159
+ __m256d centered_b_y_high_f64x4 = _mm256_sub_pd(b_y_high_f64x4, centroid_b_y_f64x4);
160
+ __m256d centered_b_z_low_f64x4 = _mm256_sub_pd(b_z_low_f64x4, centroid_b_z_f64x4);
161
+ __m256d centered_b_z_high_f64x4 = _mm256_sub_pd(b_z_high_f64x4, centroid_b_z_f64x4);
162
+
163
+ __m256d rotated_a_x_low_f64x4 = _mm256_fmadd_pd(
164
+ scaled_rotation_x_z_f64x4, centered_a_z_low_f64x4,
165
+ _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_low_f64x4,
166
+ _mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_low_f64x4)));
167
+ __m256d rotated_a_x_high_f64x4 = _mm256_fmadd_pd(
168
+ scaled_rotation_x_z_f64x4, centered_a_z_high_f64x4,
169
+ _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_high_f64x4,
170
+ _mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_high_f64x4)));
171
+ __m256d rotated_a_y_low_f64x4 = _mm256_fmadd_pd(
172
+ scaled_rotation_y_z_f64x4, centered_a_z_low_f64x4,
173
+ _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_low_f64x4,
174
+ _mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_low_f64x4)));
175
+ __m256d rotated_a_y_high_f64x4 = _mm256_fmadd_pd(
176
+ scaled_rotation_y_z_f64x4, centered_a_z_high_f64x4,
177
+ _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_high_f64x4,
178
+ _mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_high_f64x4)));
179
+ __m256d rotated_a_z_low_f64x4 = _mm256_fmadd_pd(
180
+ scaled_rotation_z_z_f64x4, centered_a_z_low_f64x4,
181
+ _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_low_f64x4,
182
+ _mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_low_f64x4)));
183
+ __m256d rotated_a_z_high_f64x4 = _mm256_fmadd_pd(
184
+ scaled_rotation_z_z_f64x4, centered_a_z_high_f64x4,
185
+ _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_high_f64x4,
186
+ _mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_high_f64x4)));
187
+
188
+ __m256d delta_x_low_f64x4 = _mm256_sub_pd(rotated_a_x_low_f64x4, centered_b_x_low_f64x4);
189
+ __m256d delta_x_high_f64x4 = _mm256_sub_pd(rotated_a_x_high_f64x4, centered_b_x_high_f64x4);
190
+ __m256d delta_y_low_f64x4 = _mm256_sub_pd(rotated_a_y_low_f64x4, centered_b_y_low_f64x4);
191
+ __m256d delta_y_high_f64x4 = _mm256_sub_pd(rotated_a_y_high_f64x4, centered_b_y_high_f64x4);
192
+ __m256d delta_z_low_f64x4 = _mm256_sub_pd(rotated_a_z_low_f64x4, centered_b_z_low_f64x4);
193
+ __m256d delta_z_high_f64x4 = _mm256_sub_pd(rotated_a_z_high_f64x4, centered_b_z_high_f64x4);
194
+
195
+ __m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_low_f64x4, delta_x_low_f64x4),
196
+ _mm256_mul_pd(delta_x_high_f64x4, delta_x_high_f64x4));
197
+ batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_low_f64x4, delta_y_low_f64x4, batch_sum_squared_f64x4);
198
+ batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_high_f64x4, delta_y_high_f64x4, batch_sum_squared_f64x4);
199
+ batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_low_f64x4, delta_z_low_f64x4, batch_sum_squared_f64x4);
200
+ batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_high_f64x4, delta_z_high_f64x4, batch_sum_squared_f64x4);
201
201
  sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
202
202
  }
203
203
 
204
204
  nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);
205
205
  for (; index < n; ++index) {
206
- nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x;
207
- nk_f64_t centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y;
208
- nk_f64_t centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
209
- nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x;
210
- nk_f64_t centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y;
211
- nk_f64_t centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
212
- nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z);
213
- nk_f64_t rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z);
214
- nk_f64_t rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
206
+ nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x,
207
+ centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y,
208
+ centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
209
+ nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x,
210
+ centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y,
211
+ centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
212
+ nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z),
213
+ rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z),
214
+ rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
215
215
  nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
216
216
  delta_z = rotated_a_z - centered_b_z;
217
217
  sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
@@ -290,20 +290,15 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_haswell_(nk_f64_t const *a, nk_f64_t
290
290
 
291
291
  // Scalar tail
292
292
  for (; j < n; ++j) {
293
- nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x;
294
- nk_f64_t pa_y = a[j * 3 + 1] - centroid_a_y;
295
- nk_f64_t pa_z = a[j * 3 + 2] - centroid_a_z;
296
- nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x;
297
- nk_f64_t pb_y = b[j * 3 + 1] - centroid_b_y;
298
- nk_f64_t pb_z = b[j * 3 + 2] - centroid_b_z;
299
-
300
- nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
301
- nk_f64_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
302
- nk_f64_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
303
-
304
- nk_f64_t delta_x = ra_x - pb_x;
305
- nk_f64_t delta_y = ra_y - pb_y;
306
- nk_f64_t delta_z = ra_z - pb_z;
293
+ nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x, pa_y = a[j * 3 + 1] - centroid_a_y,
294
+ pa_z = a[j * 3 + 2] - centroid_a_z;
295
+ nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x, pb_y = b[j * 3 + 1] - centroid_b_y,
296
+ pb_z = b[j * 3 + 2] - centroid_b_z;
297
+ nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
298
+ ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
299
+ ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
300
+
301
+ nk_f64_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
307
302
  nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
308
303
  nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
309
304
  nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
@@ -330,38 +325,38 @@ NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size
330
325
  nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
331
326
  nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
332
327
 
333
- __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
334
- __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
335
- __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
336
- __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
337
- __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
338
- __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
339
- __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
340
- __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
341
- __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
342
- __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
343
- __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
344
- __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
345
-
346
- sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_lower_f64x4, a_x_upper_f64x4));
347
- sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_lower_f64x4, a_y_upper_f64x4));
348
- sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_lower_f64x4, a_z_upper_f64x4));
349
- sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_lower_f64x4, b_x_upper_f64x4));
350
- sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_lower_f64x4, b_y_upper_f64x4));
351
- sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_lower_f64x4, b_z_upper_f64x4));
352
-
353
- __m256d delta_x_lower_f64x4 = _mm256_sub_pd(a_x_lower_f64x4, b_x_lower_f64x4);
354
- __m256d delta_x_upper_f64x4 = _mm256_sub_pd(a_x_upper_f64x4, b_x_upper_f64x4);
355
- __m256d delta_y_lower_f64x4 = _mm256_sub_pd(a_y_lower_f64x4, b_y_lower_f64x4);
356
- __m256d delta_y_upper_f64x4 = _mm256_sub_pd(a_y_upper_f64x4, b_y_upper_f64x4);
357
- __m256d delta_z_lower_f64x4 = _mm256_sub_pd(a_z_lower_f64x4, b_z_lower_f64x4);
358
- __m256d delta_z_upper_f64x4 = _mm256_sub_pd(a_z_upper_f64x4, b_z_upper_f64x4);
359
- __m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_lower_f64x4, delta_x_lower_f64x4),
360
- _mm256_mul_pd(delta_x_upper_f64x4, delta_x_upper_f64x4));
361
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_lower_f64x4, delta_y_lower_f64x4, batch_sum_squared_f64x4);
362
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_upper_f64x4, delta_y_upper_f64x4, batch_sum_squared_f64x4);
363
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_lower_f64x4, delta_z_lower_f64x4, batch_sum_squared_f64x4);
364
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_upper_f64x4, delta_z_upper_f64x4, batch_sum_squared_f64x4);
328
+ __m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
329
+ __m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
330
+ __m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
331
+ __m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
332
+ __m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
333
+ __m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
334
+ __m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
335
+ __m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
336
+ __m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
337
+ __m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
338
+ __m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
339
+ __m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
340
+
341
+ sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_low_f64x4, a_x_high_f64x4));
342
+ sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_low_f64x4, a_y_high_f64x4));
343
+ sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_low_f64x4, a_z_high_f64x4));
344
+ sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_low_f64x4, b_x_high_f64x4));
345
+ sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_low_f64x4, b_y_high_f64x4));
346
+ sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_low_f64x4, b_z_high_f64x4));
347
+
348
+ __m256d delta_x_low_f64x4 = _mm256_sub_pd(a_x_low_f64x4, b_x_low_f64x4);
349
+ __m256d delta_x_high_f64x4 = _mm256_sub_pd(a_x_high_f64x4, b_x_high_f64x4);
350
+ __m256d delta_y_low_f64x4 = _mm256_sub_pd(a_y_low_f64x4, b_y_low_f64x4);
351
+ __m256d delta_y_high_f64x4 = _mm256_sub_pd(a_y_high_f64x4, b_y_high_f64x4);
352
+ __m256d delta_z_low_f64x4 = _mm256_sub_pd(a_z_low_f64x4, b_z_low_f64x4);
353
+ __m256d delta_z_high_f64x4 = _mm256_sub_pd(a_z_high_f64x4, b_z_high_f64x4);
354
+ __m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_low_f64x4, delta_x_low_f64x4),
355
+ _mm256_mul_pd(delta_x_high_f64x4, delta_x_high_f64x4));
356
+ batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_low_f64x4, delta_y_low_f64x4, batch_sum_squared_f64x4);
357
+ batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_high_f64x4, delta_y_high_f64x4, batch_sum_squared_f64x4);
358
+ batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_low_f64x4, delta_z_low_f64x4, batch_sum_squared_f64x4);
359
+ batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_high_f64x4, delta_z_high_f64x4, batch_sum_squared_f64x4);
365
360
  sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
366
361
  }
367
362
 
@@ -401,12 +396,10 @@ NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size
401
396
 
402
397
  NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
403
398
  nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
404
- /* RMSD uses identity rotation and scale=1.0 */
405
- if (rotation) {
406
- rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
407
- rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
399
+ // RMSD uses identity rotation and scale=1.0
400
+ if (rotation)
401
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
408
402
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
409
- }
410
403
  if (scale) *scale = 1.0;
411
404
  __m256d const zeros_f64x4 = _mm256_setzero_pd();
412
405
 
@@ -521,16 +514,8 @@ NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size
521
514
  nk_f64_t centroid_b_y = total_by * inv_n;
522
515
  nk_f64_t centroid_b_z = total_bz * inv_n;
523
516
 
524
- if (a_centroid) {
525
- a_centroid[0] = centroid_a_x;
526
- a_centroid[1] = centroid_a_y;
527
- a_centroid[2] = centroid_a_z;
528
- }
529
- if (b_centroid) {
530
- b_centroid[0] = centroid_b_x;
531
- b_centroid[1] = centroid_b_y;
532
- b_centroid[2] = centroid_b_z;
533
- }
517
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
518
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
534
519
 
535
520
  // Compute RMSD
536
521
  nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
@@ -559,53 +544,53 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
559
544
  for (; index + 8 <= n; index += 8) {
560
545
  nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
561
546
  nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
562
- __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
563
- __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
564
- __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
565
- __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
566
- __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
567
- __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
568
- __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
569
- __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
570
- __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
571
- __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
572
- __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
573
- __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
574
-
575
- sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_lower_f64x4, a_x_upper_f64x4));
576
- sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_lower_f64x4, a_y_upper_f64x4));
577
- sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_lower_f64x4, a_z_upper_f64x4));
578
- sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_lower_f64x4, b_x_upper_f64x4));
579
- sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_lower_f64x4, b_y_upper_f64x4));
580
- sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_lower_f64x4, b_z_upper_f64x4));
581
-
582
- covariance_00_f64x4 = _mm256_add_pd(covariance_00_f64x4,
583
- _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_x_lower_f64x4),
584
- _mm256_mul_pd(a_x_upper_f64x4, b_x_upper_f64x4)));
585
- covariance_01_f64x4 = _mm256_add_pd(covariance_01_f64x4,
586
- _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_y_lower_f64x4),
587
- _mm256_mul_pd(a_x_upper_f64x4, b_y_upper_f64x4)));
588
- covariance_02_f64x4 = _mm256_add_pd(covariance_02_f64x4,
589
- _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_z_lower_f64x4),
590
- _mm256_mul_pd(a_x_upper_f64x4, b_z_upper_f64x4)));
591
- covariance_10_f64x4 = _mm256_add_pd(covariance_10_f64x4,
592
- _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_x_lower_f64x4),
593
- _mm256_mul_pd(a_y_upper_f64x4, b_x_upper_f64x4)));
594
- covariance_11_f64x4 = _mm256_add_pd(covariance_11_f64x4,
595
- _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_y_lower_f64x4),
596
- _mm256_mul_pd(a_y_upper_f64x4, b_y_upper_f64x4)));
597
- covariance_12_f64x4 = _mm256_add_pd(covariance_12_f64x4,
598
- _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_z_lower_f64x4),
599
- _mm256_mul_pd(a_y_upper_f64x4, b_z_upper_f64x4)));
600
- covariance_20_f64x4 = _mm256_add_pd(covariance_20_f64x4,
601
- _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_x_lower_f64x4),
602
- _mm256_mul_pd(a_z_upper_f64x4, b_x_upper_f64x4)));
603
- covariance_21_f64x4 = _mm256_add_pd(covariance_21_f64x4,
604
- _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_y_lower_f64x4),
605
- _mm256_mul_pd(a_z_upper_f64x4, b_y_upper_f64x4)));
606
- covariance_22_f64x4 = _mm256_add_pd(covariance_22_f64x4,
607
- _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_z_lower_f64x4),
608
- _mm256_mul_pd(a_z_upper_f64x4, b_z_upper_f64x4)));
547
+ __m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
548
+ __m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
549
+ __m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
550
+ __m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
551
+ __m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
552
+ __m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
553
+ __m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
554
+ __m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
555
+ __m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
556
+ __m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
557
+ __m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
558
+ __m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
559
+
560
+ sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_low_f64x4, a_x_high_f64x4));
561
+ sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_low_f64x4, a_y_high_f64x4));
562
+ sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_low_f64x4, a_z_high_f64x4));
563
+ sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_low_f64x4, b_x_high_f64x4));
564
+ sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_low_f64x4, b_y_high_f64x4));
565
+ sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_low_f64x4, b_z_high_f64x4));
566
+
567
+ covariance_00_f64x4 = _mm256_add_pd(
568
+ covariance_00_f64x4,
569
+ _mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_x_high_f64x4)));
570
+ covariance_01_f64x4 = _mm256_add_pd(
571
+ covariance_01_f64x4,
572
+ _mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_y_high_f64x4)));
573
+ covariance_02_f64x4 = _mm256_add_pd(
574
+ covariance_02_f64x4,
575
+ _mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_z_high_f64x4)));
576
+ covariance_10_f64x4 = _mm256_add_pd(
577
+ covariance_10_f64x4,
578
+ _mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_x_high_f64x4)));
579
+ covariance_11_f64x4 = _mm256_add_pd(
580
+ covariance_11_f64x4,
581
+ _mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_y_high_f64x4)));
582
+ covariance_12_f64x4 = _mm256_add_pd(
583
+ covariance_12_f64x4,
584
+ _mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_z_high_f64x4)));
585
+ covariance_20_f64x4 = _mm256_add_pd(
586
+ covariance_20_f64x4,
587
+ _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_x_high_f64x4)));
588
+ covariance_21_f64x4 = _mm256_add_pd(
589
+ covariance_21_f64x4,
590
+ _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_y_high_f64x4)));
591
+ covariance_22_f64x4 = _mm256_add_pd(
592
+ covariance_22_f64x4,
593
+ _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_z_high_f64x4)));
609
594
  }
610
595
 
611
596
  nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
@@ -775,27 +760,19 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
775
760
  nk_f64_t centroid_b_y = sum_b_y * inv_n;
776
761
  nk_f64_t centroid_b_z = sum_b_z * inv_n;
777
762
 
778
- if (a_centroid) {
779
- a_centroid[0] = centroid_a_x;
780
- a_centroid[1] = centroid_a_y;
781
- a_centroid[2] = centroid_a_z;
782
- }
783
- if (b_centroid) {
784
- b_centroid[0] = centroid_b_x;
785
- b_centroid[1] = centroid_b_y;
786
- b_centroid[2] = centroid_b_z;
787
- }
763
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
764
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
788
765
 
789
766
  // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
790
- covariance_x_x -= n * centroid_a_x * centroid_b_x;
791
- covariance_x_y -= n * centroid_a_x * centroid_b_y;
792
- covariance_x_z -= n * centroid_a_x * centroid_b_z;
793
- covariance_y_x -= n * centroid_a_y * centroid_b_x;
794
- covariance_y_y -= n * centroid_a_y * centroid_b_y;
795
- covariance_y_z -= n * centroid_a_y * centroid_b_z;
796
- covariance_z_x -= n * centroid_a_z * centroid_b_x;
797
- covariance_z_y -= n * centroid_a_z * centroid_b_y;
798
- covariance_z_z -= n * centroid_a_z * centroid_b_z;
767
+ covariance_x_x -= (nk_f64_t)n * centroid_a_x * centroid_b_x;
768
+ covariance_x_y -= (nk_f64_t)n * centroid_a_x * centroid_b_y;
769
+ covariance_x_z -= (nk_f64_t)n * centroid_a_x * centroid_b_z;
770
+ covariance_y_x -= (nk_f64_t)n * centroid_a_y * centroid_b_x;
771
+ covariance_y_y -= (nk_f64_t)n * centroid_a_y * centroid_b_y;
772
+ covariance_y_z -= (nk_f64_t)n * centroid_a_y * centroid_b_z;
773
+ covariance_z_x -= (nk_f64_t)n * centroid_a_z * centroid_b_x;
774
+ covariance_z_y -= (nk_f64_t)n * centroid_a_z * centroid_b_y;
775
+ covariance_z_z -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
799
776
 
800
777
  // Compute SVD and optimal rotation using f64 precision (svd_s is 9-element diagonal matrix)
801
778
  nk_f64_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
@@ -808,16 +785,13 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
808
785
 
809
786
  // Handle reflection: if det(R) < 0, negate third column of V and recompute R
810
787
  if (nk_det3x3_f64_(r) < 0) {
811
- svd_v[2] = -svd_v[2];
812
- svd_v[5] = -svd_v[5];
813
- svd_v[8] = -svd_v[8];
788
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
814
789
  nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
815
790
  }
816
791
 
817
- /* Output rotation matrix and scale=1.0 */
818
- if (rotation) {
792
+ // Output rotation matrix and scale=1.0
793
+ if (rotation)
819
794
  for (int j = 0; j < 9; ++j) rotation[j] = r[j];
820
- }
821
795
  if (scale) *scale = 1.0;
822
796
 
823
797
  // Compute RMSD after optimal rotation
@@ -842,60 +816,60 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
842
816
  for (; index + 8 <= n; index += 8) {
843
817
  nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
844
818
  nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
845
- __m256d a_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
846
- __m256d a_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
847
- __m256d a_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
848
- __m256d a_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
849
- __m256d a_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
850
- __m256d a_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
851
- __m256d b_x_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
852
- __m256d b_x_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
853
- __m256d b_y_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
854
- __m256d b_y_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
855
- __m256d b_z_lower_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
856
- __m256d b_z_upper_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
857
-
858
- sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_lower_f64x4, a_x_upper_f64x4));
859
- sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_lower_f64x4, a_y_upper_f64x4));
860
- sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_lower_f64x4, a_z_upper_f64x4));
861
- sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_lower_f64x4, b_x_upper_f64x4));
862
- sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_lower_f64x4, b_y_upper_f64x4));
863
- sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_lower_f64x4, b_z_upper_f64x4));
864
- covariance_00_f64x4 = _mm256_add_pd(covariance_00_f64x4,
865
- _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_x_lower_f64x4),
866
- _mm256_mul_pd(a_x_upper_f64x4, b_x_upper_f64x4)));
867
- covariance_01_f64x4 = _mm256_add_pd(covariance_01_f64x4,
868
- _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_y_lower_f64x4),
869
- _mm256_mul_pd(a_x_upper_f64x4, b_y_upper_f64x4)));
870
- covariance_02_f64x4 = _mm256_add_pd(covariance_02_f64x4,
871
- _mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, b_z_lower_f64x4),
872
- _mm256_mul_pd(a_x_upper_f64x4, b_z_upper_f64x4)));
873
- covariance_10_f64x4 = _mm256_add_pd(covariance_10_f64x4,
874
- _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_x_lower_f64x4),
875
- _mm256_mul_pd(a_y_upper_f64x4, b_x_upper_f64x4)));
876
- covariance_11_f64x4 = _mm256_add_pd(covariance_11_f64x4,
877
- _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_y_lower_f64x4),
878
- _mm256_mul_pd(a_y_upper_f64x4, b_y_upper_f64x4)));
879
- covariance_12_f64x4 = _mm256_add_pd(covariance_12_f64x4,
880
- _mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, b_z_lower_f64x4),
881
- _mm256_mul_pd(a_y_upper_f64x4, b_z_upper_f64x4)));
882
- covariance_20_f64x4 = _mm256_add_pd(covariance_20_f64x4,
883
- _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_x_lower_f64x4),
884
- _mm256_mul_pd(a_z_upper_f64x4, b_x_upper_f64x4)));
885
- covariance_21_f64x4 = _mm256_add_pd(covariance_21_f64x4,
886
- _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_y_lower_f64x4),
887
- _mm256_mul_pd(a_z_upper_f64x4, b_y_upper_f64x4)));
888
- covariance_22_f64x4 = _mm256_add_pd(covariance_22_f64x4,
889
- _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, b_z_lower_f64x4),
890
- _mm256_mul_pd(a_z_upper_f64x4, b_z_upper_f64x4)));
819
+ __m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
820
+ __m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
821
+ __m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
822
+ __m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
823
+ __m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
824
+ __m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
825
+ __m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
826
+ __m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
827
+ __m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
828
+ __m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
829
+ __m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
830
+ __m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
831
+
832
+ sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_low_f64x4, a_x_high_f64x4));
833
+ sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_low_f64x4, a_y_high_f64x4));
834
+ sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_low_f64x4, a_z_high_f64x4));
835
+ sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_low_f64x4, b_x_high_f64x4));
836
+ sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_low_f64x4, b_y_high_f64x4));
837
+ sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_low_f64x4, b_z_high_f64x4));
838
+ covariance_00_f64x4 = _mm256_add_pd(
839
+ covariance_00_f64x4,
840
+ _mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_x_high_f64x4)));
841
+ covariance_01_f64x4 = _mm256_add_pd(
842
+ covariance_01_f64x4,
843
+ _mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_y_high_f64x4)));
844
+ covariance_02_f64x4 = _mm256_add_pd(
845
+ covariance_02_f64x4,
846
+ _mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, b_z_high_f64x4)));
847
+ covariance_10_f64x4 = _mm256_add_pd(
848
+ covariance_10_f64x4,
849
+ _mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_x_high_f64x4)));
850
+ covariance_11_f64x4 = _mm256_add_pd(
851
+ covariance_11_f64x4,
852
+ _mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_y_high_f64x4)));
853
+ covariance_12_f64x4 = _mm256_add_pd(
854
+ covariance_12_f64x4,
855
+ _mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, b_z_high_f64x4)));
856
+ covariance_20_f64x4 = _mm256_add_pd(
857
+ covariance_20_f64x4,
858
+ _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_x_high_f64x4)));
859
+ covariance_21_f64x4 = _mm256_add_pd(
860
+ covariance_21_f64x4,
861
+ _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_y_high_f64x4)));
862
+ covariance_22_f64x4 = _mm256_add_pd(
863
+ covariance_22_f64x4,
864
+ _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_z_high_f64x4)));
891
865
  variance_a_f64x4 = _mm256_add_pd(
892
866
  variance_a_f64x4,
893
- _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_x_lower_f64x4, a_x_lower_f64x4),
894
- _mm256_mul_pd(a_x_upper_f64x4, a_x_upper_f64x4)),
895
- _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_y_lower_f64x4, a_y_lower_f64x4),
896
- _mm256_mul_pd(a_y_upper_f64x4, a_y_upper_f64x4)),
897
- _mm256_add_pd(_mm256_mul_pd(a_z_lower_f64x4, a_z_lower_f64x4),
898
- _mm256_mul_pd(a_z_upper_f64x4, a_z_upper_f64x4)))));
867
+ _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, a_x_low_f64x4),
868
+ _mm256_mul_pd(a_x_high_f64x4, a_x_high_f64x4)),
869
+ _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, a_y_low_f64x4),
870
+ _mm256_mul_pd(a_y_high_f64x4, a_y_high_f64x4)),
871
+ _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, a_z_low_f64x4),
872
+ _mm256_mul_pd(a_z_high_f64x4, a_z_high_f64x4)))));
899
873
  }
900
874
 
901
875
  nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
@@ -1106,7 +1080,7 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
1106
1080
  nk_f64_t det = nk_det3x3_f64_(r);
1107
1081
  nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
1108
1082
  nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
1109
- nk_f64_t c = trace_ds / (n * variance_a);
1083
+ nk_f64_t c = trace_ds / ((nk_f64_t)n * variance_a);
1110
1084
  if (scale) *scale = c;
1111
1085
 
1112
1086
  // Handle reflection
@@ -1115,10 +1089,9 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
1115
1089
  nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
1116
1090
  }
1117
1091
 
1118
- /* Output rotation matrix */
1119
- if (rotation) {
1092
+ // Output rotation matrix
1093
+ if (rotation)
1120
1094
  for (int j = 0; j < 9; ++j) rotation[j] = r[j];
1121
- }
1122
1095
 
1123
1096
  // Compute RMSD with scaling
1124
1097
  nk_f64_t sum_squared = nk_transformed_ssd_f64_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
@@ -1247,20 +1220,13 @@ NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_haswell_(nk_f16_t const *a, nk_f16_t
1247
1220
  nk_f16_to_f32_haswell(&b[j * 3 + 1], &b_y_f32);
1248
1221
  nk_f16_to_f32_haswell(&b[j * 3 + 2], &b_z_f32);
1249
1222
 
1250
- nk_f32_t pa_x = a_x_f32 - centroid_a_x;
1251
- nk_f32_t pa_y = a_y_f32 - centroid_a_y;
1252
- nk_f32_t pa_z = a_z_f32 - centroid_a_z;
1253
- nk_f32_t pb_x = b_x_f32 - centroid_b_x;
1254
- nk_f32_t pb_y = b_y_f32 - centroid_b_y;
1255
- nk_f32_t pb_z = b_z_f32 - centroid_b_z;
1223
+ nk_f32_t pa_x = a_x_f32 - centroid_a_x, pa_y = a_y_f32 - centroid_a_y, pa_z = a_z_f32 - centroid_a_z;
1224
+ nk_f32_t pb_x = b_x_f32 - centroid_b_x, pb_y = b_y_f32 - centroid_b_y, pb_z = b_z_f32 - centroid_b_z;
1225
+ nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
1226
+ ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
1227
+ ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
1256
1228
 
1257
- nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
1258
- nk_f32_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
1259
- nk_f32_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
1260
-
1261
- nk_f32_t delta_x = ra_x - pb_x;
1262
- nk_f32_t delta_y = ra_y - pb_y;
1263
- nk_f32_t delta_z = ra_z - pb_z;
1229
+ nk_f32_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
1264
1230
  sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
1265
1231
  }
1266
1232
 
@@ -1344,20 +1310,13 @@ NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_haswell_(nk_bf16_t const *a, nk_bf1
1344
1310
  nk_bf16_to_f32_serial(&b[j * 3 + 1], &b_y_f32);
1345
1311
  nk_bf16_to_f32_serial(&b[j * 3 + 2], &b_z_f32);
1346
1312
 
1347
- nk_f32_t pa_x = a_x_f32 - centroid_a_x;
1348
- nk_f32_t pa_y = a_y_f32 - centroid_a_y;
1349
- nk_f32_t pa_z = a_z_f32 - centroid_a_z;
1350
- nk_f32_t pb_x = b_x_f32 - centroid_b_x;
1351
- nk_f32_t pb_y = b_y_f32 - centroid_b_y;
1352
- nk_f32_t pb_z = b_z_f32 - centroid_b_z;
1353
-
1354
- nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
1355
- nk_f32_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
1356
- nk_f32_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
1313
+ nk_f32_t pa_x = a_x_f32 - centroid_a_x, pa_y = a_y_f32 - centroid_a_y, pa_z = a_z_f32 - centroid_a_z;
1314
+ nk_f32_t pb_x = b_x_f32 - centroid_b_x, pb_y = b_y_f32 - centroid_b_y, pb_z = b_z_f32 - centroid_b_z;
1315
+ nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
1316
+ ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
1317
+ ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
1357
1318
 
1358
- nk_f32_t delta_x = ra_x - pb_x;
1359
- nk_f32_t delta_y = ra_y - pb_y;
1360
- nk_f32_t delta_z = ra_z - pb_z;
1319
+ nk_f32_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
1361
1320
  sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
1362
1321
  }
1363
1322
 
@@ -1366,12 +1325,10 @@ NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_haswell_(nk_bf16_t const *a, nk_bf1
1366
1325
 
1367
1326
  NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1368
1327
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1369
- /* RMSD uses identity rotation and scale=1.0 */
1370
- if (rotation) {
1371
- rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
1372
- rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
1328
+ // RMSD uses identity rotation and scale=1.0
1329
+ if (rotation)
1330
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
1373
1331
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
1374
- }
1375
1332
  if (scale) *scale = 1.0f;
1376
1333
 
1377
1334
  __m256 const zeros_f32x8 = _mm256_setzero_ps();
@@ -1446,16 +1403,8 @@ NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size
1446
1403
  nk_f32_t centroid_b_y = total_by * inv_n;
1447
1404
  nk_f32_t centroid_b_z = total_bz * inv_n;
1448
1405
 
1449
- if (a_centroid) {
1450
- a_centroid[0] = centroid_a_x;
1451
- a_centroid[1] = centroid_a_y;
1452
- a_centroid[2] = centroid_a_z;
1453
- }
1454
- if (b_centroid) {
1455
- b_centroid[0] = centroid_b_x;
1456
- b_centroid[1] = centroid_b_y;
1457
- b_centroid[2] = centroid_b_z;
1458
- }
1406
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1407
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1459
1408
 
1460
1409
  // Compute RMSD
1461
1410
  nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
@@ -1469,12 +1418,10 @@ NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size
1469
1418
 
1470
1419
  NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1471
1420
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1472
- /* RMSD uses identity rotation and scale=1.0 */
1473
- if (rotation) {
1474
- rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
1475
- rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
1421
+ // RMSD uses identity rotation and scale=1.0
1422
+ if (rotation)
1423
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
1476
1424
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
1477
- }
1478
1425
  if (scale) *scale = 1.0f;
1479
1426
 
1480
1427
  __m256 const zeros_f32x8 = _mm256_setzero_ps();
@@ -1549,16 +1496,8 @@ NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_s
1549
1496
  nk_f32_t centroid_b_y = total_by * inv_n;
1550
1497
  nk_f32_t centroid_b_z = total_bz * inv_n;
1551
1498
 
1552
- if (a_centroid) {
1553
- a_centroid[0] = centroid_a_x;
1554
- a_centroid[1] = centroid_a_y;
1555
- a_centroid[2] = centroid_a_z;
1556
- }
1557
- if (b_centroid) {
1558
- b_centroid[0] = centroid_b_x;
1559
- b_centroid[1] = centroid_b_y;
1560
- b_centroid[2] = centroid_b_z;
1561
- }
1499
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1500
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1562
1501
 
1563
1502
  // Compute RMSD
1564
1503
  nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
@@ -1638,21 +1577,11 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
1638
1577
  nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
1639
1578
  nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
1640
1579
  nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
1641
- sum_a_x += ax;
1642
- sum_a_y += ay;
1643
- sum_a_z += az;
1644
- sum_b_x += bx;
1645
- sum_b_y += by;
1646
- sum_b_z += bz;
1647
- covariance_x_x += ax * bx;
1648
- covariance_x_y += ax * by;
1649
- covariance_x_z += ax * bz;
1650
- covariance_y_x += ay * bx;
1651
- covariance_y_y += ay * by;
1652
- covariance_y_z += ay * bz;
1653
- covariance_z_x += az * bx;
1654
- covariance_z_y += az * by;
1655
- covariance_z_z += az * bz;
1580
+ sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
1581
+ sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
1582
+ covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
1583
+ covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
1584
+ covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
1656
1585
  }
1657
1586
 
1658
1587
  // Compute centroids
@@ -1664,27 +1593,19 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
1664
1593
  nk_f32_t centroid_b_y = sum_b_y * inv_n;
1665
1594
  nk_f32_t centroid_b_z = sum_b_z * inv_n;
1666
1595
 
1667
- if (a_centroid) {
1668
- a_centroid[0] = centroid_a_x;
1669
- a_centroid[1] = centroid_a_y;
1670
- a_centroid[2] = centroid_a_z;
1671
- }
1672
- if (b_centroid) {
1673
- b_centroid[0] = centroid_b_x;
1674
- b_centroid[1] = centroid_b_y;
1675
- b_centroid[2] = centroid_b_z;
1676
- }
1596
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1597
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1677
1598
 
1678
1599
  // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
1679
- covariance_x_x -= n * centroid_a_x * centroid_b_x;
1680
- covariance_x_y -= n * centroid_a_x * centroid_b_y;
1681
- covariance_x_z -= n * centroid_a_x * centroid_b_z;
1682
- covariance_y_x -= n * centroid_a_y * centroid_b_x;
1683
- covariance_y_y -= n * centroid_a_y * centroid_b_y;
1684
- covariance_y_z -= n * centroid_a_y * centroid_b_z;
1685
- covariance_z_x -= n * centroid_a_z * centroid_b_x;
1686
- covariance_z_y -= n * centroid_a_z * centroid_b_y;
1687
- covariance_z_z -= n * centroid_a_z * centroid_b_z;
1600
+ covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
1601
+ covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
1602
+ covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
1603
+ covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
1604
+ covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
1605
+ covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
1606
+ covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
1607
+ covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
1608
+ covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
1688
1609
 
1689
1610
  // Compute SVD and optimal rotation
1690
1611
  nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
@@ -1706,9 +1627,7 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
1706
1627
 
1707
1628
  // Handle reflection: if det(R) < 0, negate third column of V and recompute R
1708
1629
  if (nk_det3x3_f32_(r) < 0) {
1709
- svd_v[2] = -svd_v[2];
1710
- svd_v[5] = -svd_v[5];
1711
- svd_v[8] = -svd_v[8];
1630
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
1712
1631
  r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1713
1632
  r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1714
1633
  r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
@@ -1720,10 +1639,9 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
1720
1639
  r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1721
1640
  }
1722
1641
 
1723
- /* Output rotation matrix and scale=1.0 */
1724
- if (rotation) {
1642
+ // Output rotation matrix and scale=1.0
1643
+ if (rotation)
1725
1644
  for (int j = 0; j < 9; ++j) rotation[j] = r[j];
1726
- }
1727
1645
  if (scale) *scale = 1.0f;
1728
1646
 
1729
1647
  // Compute RMSD after optimal rotation
@@ -1800,21 +1718,11 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
1800
1718
  nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
1801
1719
  nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
1802
1720
  nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
1803
- sum_a_x += ax;
1804
- sum_a_y += ay;
1805
- sum_a_z += az;
1806
- sum_b_x += bx;
1807
- sum_b_y += by;
1808
- sum_b_z += bz;
1809
- covariance_x_x += ax * bx;
1810
- covariance_x_y += ax * by;
1811
- covariance_x_z += ax * bz;
1812
- covariance_y_x += ay * bx;
1813
- covariance_y_y += ay * by;
1814
- covariance_y_z += ay * bz;
1815
- covariance_z_x += az * bx;
1816
- covariance_z_y += az * by;
1817
- covariance_z_z += az * bz;
1721
+ sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
1722
+ sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
1723
+ covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
1724
+ covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
1725
+ covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
1818
1726
  }
1819
1727
 
1820
1728
  // Compute centroids
@@ -1826,27 +1734,19 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
1826
1734
  nk_f32_t centroid_b_y = sum_b_y * inv_n;
1827
1735
  nk_f32_t centroid_b_z = sum_b_z * inv_n;
1828
1736
 
1829
- if (a_centroid) {
1830
- a_centroid[0] = centroid_a_x;
1831
- a_centroid[1] = centroid_a_y;
1832
- a_centroid[2] = centroid_a_z;
1833
- }
1834
- if (b_centroid) {
1835
- b_centroid[0] = centroid_b_x;
1836
- b_centroid[1] = centroid_b_y;
1837
- b_centroid[2] = centroid_b_z;
1838
- }
1737
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1738
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1839
1739
 
1840
1740
  // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
1841
- covariance_x_x -= n * centroid_a_x * centroid_b_x;
1842
- covariance_x_y -= n * centroid_a_x * centroid_b_y;
1843
- covariance_x_z -= n * centroid_a_x * centroid_b_z;
1844
- covariance_y_x -= n * centroid_a_y * centroid_b_x;
1845
- covariance_y_y -= n * centroid_a_y * centroid_b_y;
1846
- covariance_y_z -= n * centroid_a_y * centroid_b_z;
1847
- covariance_z_x -= n * centroid_a_z * centroid_b_x;
1848
- covariance_z_y -= n * centroid_a_z * centroid_b_y;
1849
- covariance_z_z -= n * centroid_a_z * centroid_b_z;
1741
+ covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
1742
+ covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
1743
+ covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
1744
+ covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
1745
+ covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
1746
+ covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
1747
+ covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
1748
+ covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
1749
+ covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
1850
1750
 
1851
1751
  // Compute SVD and optimal rotation
1852
1752
  nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
@@ -1868,9 +1768,7 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
1868
1768
 
1869
1769
  // Handle reflection: if det(R) < 0, negate third column of V and recompute R
1870
1770
  if (nk_det3x3_f32_(r) < 0) {
1871
- svd_v[2] = -svd_v[2];
1872
- svd_v[5] = -svd_v[5];
1873
- svd_v[8] = -svd_v[8];
1771
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
1874
1772
  r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1875
1773
  r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1876
1774
  r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
@@ -1882,10 +1780,9 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
1882
1780
  r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1883
1781
  }
1884
1782
 
1885
- /* Output rotation matrix and scale=1.0 */
1886
- if (rotation) {
1783
+ // Output rotation matrix and scale=1.0
1784
+ if (rotation)
1887
1785
  for (int j = 0; j < 9; ++j) rotation[j] = r[j];
1888
- }
1889
1786
  if (scale) *scale = 1.0f;
1890
1787
 
1891
1788
  // Compute RMSD after optimal rotation
@@ -1965,21 +1862,11 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
1965
1862
  nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
1966
1863
  nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
1967
1864
  nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
1968
- sum_a_x += ax;
1969
- sum_a_y += ay;
1970
- sum_a_z += az;
1971
- sum_b_x += bx;
1972
- sum_b_y += by;
1973
- sum_b_z += bz;
1974
- covariance_x_x += ax * bx;
1975
- covariance_x_y += ax * by;
1976
- covariance_x_z += ax * bz;
1977
- covariance_y_x += ay * bx;
1978
- covariance_y_y += ay * by;
1979
- covariance_y_z += ay * bz;
1980
- covariance_z_x += az * bx;
1981
- covariance_z_y += az * by;
1982
- covariance_z_z += az * bz;
1865
+ sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
1866
+ sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
1867
+ covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
1868
+ covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
1869
+ covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
1983
1870
  variance_a_sum += ax * ax + ay * ay + az * az;
1984
1871
  }
1985
1872
 
@@ -1996,15 +1883,15 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
1996
1883
  (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
1997
1884
 
1998
1885
  // Apply centering correction to covariance matrix
1999
- covariance_x_x -= n * centroid_a_x * centroid_b_x;
2000
- covariance_x_y -= n * centroid_a_x * centroid_b_y;
2001
- covariance_x_z -= n * centroid_a_x * centroid_b_z;
2002
- covariance_y_x -= n * centroid_a_y * centroid_b_x;
2003
- covariance_y_y -= n * centroid_a_y * centroid_b_y;
2004
- covariance_y_z -= n * centroid_a_y * centroid_b_z;
2005
- covariance_z_x -= n * centroid_a_z * centroid_b_x;
2006
- covariance_z_y -= n * centroid_a_z * centroid_b_y;
2007
- covariance_z_z -= n * centroid_a_z * centroid_b_z;
1886
+ covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
1887
+ covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
1888
+ covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
1889
+ covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
1890
+ covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
1891
+ covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
1892
+ covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
1893
+ covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
1894
+ covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
2008
1895
 
2009
1896
  nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
2010
1897
  covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
@@ -2029,7 +1916,7 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
2029
1916
  nk_f32_t det = nk_det3x3_f32_(r);
2030
1917
  nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
2031
1918
  nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
2032
- nk_f32_t c = trace_ds / (n * variance_a);
1919
+ nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
2033
1920
  if (scale) *scale = c;
2034
1921
 
2035
1922
  // Handle reflection
@@ -2046,10 +1933,9 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
2046
1933
  r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
2047
1934
  }
2048
1935
 
2049
- /* Output rotation matrix */
2050
- if (rotation) {
1936
+ // Output rotation matrix
1937
+ if (rotation)
2051
1938
  for (int j = 0; j < 9; ++j) rotation[j] = r[j];
2052
- }
2053
1939
 
2054
1940
  // Compute RMSD with scaling
2055
1941
  nk_f32_t sum_squared = nk_transformed_ssd_f16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
@@ -2128,21 +2014,11 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
2128
2014
  nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
2129
2015
  nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
2130
2016
  nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
2131
- sum_a_x += ax;
2132
- sum_a_y += ay;
2133
- sum_a_z += az;
2134
- sum_b_x += bx;
2135
- sum_b_y += by;
2136
- sum_b_z += bz;
2137
- covariance_x_x += ax * bx;
2138
- covariance_x_y += ax * by;
2139
- covariance_x_z += ax * bz;
2140
- covariance_y_x += ay * bx;
2141
- covariance_y_y += ay * by;
2142
- covariance_y_z += ay * bz;
2143
- covariance_z_x += az * bx;
2144
- covariance_z_y += az * by;
2145
- covariance_z_z += az * bz;
2017
+ sum_a_x += ax, sum_a_y += ay, sum_a_z += az;
2018
+ sum_b_x += bx, sum_b_y += by, sum_b_z += bz;
2019
+ covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
2020
+ covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
2021
+ covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
2146
2022
  variance_a_sum += ax * ax + ay * ay + az * az;
2147
2023
  }
2148
2024
 
@@ -2159,15 +2035,15 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
2159
2035
  (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
2160
2036
 
2161
2037
  // Apply centering correction to covariance matrix
2162
- covariance_x_x -= n * centroid_a_x * centroid_b_x;
2163
- covariance_x_y -= n * centroid_a_x * centroid_b_y;
2164
- covariance_x_z -= n * centroid_a_x * centroid_b_z;
2165
- covariance_y_x -= n * centroid_a_y * centroid_b_x;
2166
- covariance_y_y -= n * centroid_a_y * centroid_b_y;
2167
- covariance_y_z -= n * centroid_a_y * centroid_b_z;
2168
- covariance_z_x -= n * centroid_a_z * centroid_b_x;
2169
- covariance_z_y -= n * centroid_a_z * centroid_b_y;
2170
- covariance_z_z -= n * centroid_a_z * centroid_b_z;
2038
+ covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
2039
+ covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
2040
+ covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
2041
+ covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
2042
+ covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
2043
+ covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
2044
+ covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
2045
+ covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
2046
+ covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
2171
2047
 
2172
2048
  nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
2173
2049
  covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
@@ -2192,7 +2068,7 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
2192
2068
  nk_f32_t det = nk_det3x3_f32_(r);
2193
2069
  nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
2194
2070
  nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
2195
- nk_f32_t c = trace_ds / (n * variance_a);
2071
+ nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
2196
2072
  if (scale) *scale = c;
2197
2073
 
2198
2074
  // Handle reflection
@@ -2209,10 +2085,9 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
2209
2085
  r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
2210
2086
  }
2211
2087
 
2212
- /* Output rotation matrix */
2213
- if (rotation) {
2088
+ // Output rotation matrix
2089
+ if (rotation)
2214
2090
  for (int j = 0; j < 9; ++j) rotation[j] = r[j];
2215
- }
2216
2091
 
2217
2092
  // Compute RMSD with scaling
2218
2093
  nk_f32_t sum_squared = nk_transformed_ssd_bf16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,