numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,11 +8,11 @@
8
8
  *
9
9
  * @section skylake_mesh_instructions Key AVX-512 Mesh Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p05
13
- * _mm512_permutexvar_ps VPERMPS (ZMM, ZMM, ZMM) 3cy 1/cy p5
14
- * _mm512_permutex2var_ps VPERMT2PS (ZMM, ZMM, ZMM) 3cy 1/cy p5
15
- * _mm512_extractf32x8_ps VEXTRACTF32X8 (YMM, ZMM, I8) 3cy 1/cy p5
11
+ * Intrinsic Instruction Skylake-X Genoa
12
+ * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
13
+ * _mm512_permutexvar_ps VPERMPS (ZMM, ZMM, ZMM) 3cy @ p5 4cy @ p12
14
+ * _mm512_permutex2var_ps VPERMT2PS (ZMM, ZMM, ZMM) 3cy @ p5 4cy @ p12
15
+ * _mm512_extractf32x8_ps VEXTRACTF32X8 (YMM, ZMM, I8) 3cy @ p5 1cy @ p0123
16
16
  *
17
17
  * Point cloud operations use VPERMT2PS for stride-3 deinterleaving of xyz coordinates, avoiding
18
18
  * expensive gather instructions. This achieves ~1.8x speedup over scalar deinterleaving. Dual FMA
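The stride-3 deinterleave described above can be sketched in isolation. Three 512-bit loads cover 48 interleaved floats (16 xyz points), and two VPERMT2PS steps (`_mm512_permutex2var_ps`) pull every third lane into a contiguous x-vector; y and z reuse the pattern with shifted indices. This is a minimal illustrative sketch, not the package's `nk_deinterleave_f32x16_skylake_` itself; the function name is hypothetical, though the index constants match the ones used by the f16/bf16 variants added further down in this diff.

```c
#include <immintrin.h>

/* Hypothetical sketch: extract the x coordinates of 16 interleaved xyz points
 * (48 consecutive f32 values) into one __m512 with two two-source permutes,
 * instead of a stride-3 gather. y and z follow the same pattern. */
static inline __m512 deinterleave_x_f32x16_sketch(float const *xyz) {
    __m512 reg0 = _mm512_loadu_ps(xyz);      /* values 0..15: points 0..4 plus x of point 5 */
    __m512 reg1 = _mm512_loadu_ps(xyz + 16); /* values 16..31 */
    __m512 reg2 = _mm512_loadu_ps(xyz + 32); /* values 32..47 */
    /* x of points 0..10 sits at positions 0, 3, ..., 30 inside reg0:reg1 */
    __m512i idx_x_01 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
    /* keep those 11 lanes, then take x of points 11..15 (positions 33, 36, 39, 42, 45) from reg2 */
    __m512i idx_x_2 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 23, 26, 29);
    __m512 x01 = _mm512_permutex2var_ps(reg0, idx_x_01, reg1);
    return _mm512_permutex2var_ps(x01, idx_x_2, reg2);
}
```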
@@ -28,6 +28,7 @@
28
28
  #include "numkong/dot/skylake.h"
29
29
  #include "numkong/mesh/serial.h"
30
30
  #include "numkong/spatial/haswell.h"
31
+ #include "numkong/cast/skylake.h"
31
32
 
32
33
  #if defined(__cplusplus)
33
34
  extern "C" {
@@ -112,6 +113,115 @@ NK_INTERNAL void nk_deinterleave_f64x8_skylake_(
112
113
  *z_f64x8_out = _mm512_permutex2var_pd(z01_f64x8, idx_z_2_i64x8, reg2_f64x8);
113
114
  }
114
115
 
116
+ /* Deinterleave 16 f16 3D points from xyz,xyz,xyz... to separate x,y,z vectors in f32.
117
+ * Input: 48 consecutive f16 values (16 points * 3 coordinates)
118
+ * Output: Three __m512 vectors containing the x, y, z coordinates separately (as f32).
119
+ */
120
+ NK_INTERNAL void nk_deinterleave_f16x16_to_f32x16_skylake_( //
121
+ nk_f16_t const *ptr, __m512 *x_f32x16_out, __m512 *y_f32x16_out, __m512 *z_f32x16_out) { //
122
+ __m512 reg0_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)(ptr)));
123
+ __m512 reg1_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)(ptr + 16)));
124
+ __m512 reg2_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)(ptr + 32)));
125
+
126
+ __m512i idx_x_01_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
127
+ __m512i idx_x_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 23, 26, 29);
128
+ __m512 x01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_x_01_i32x16, reg1_f32x16);
129
+ *x_f32x16_out = _mm512_permutex2var_ps(x01_f32x16, idx_x_2_i32x16, reg2_f32x16);
130
+
131
+ __m512i idx_y_01_i32x16 = _mm512_setr_epi32(1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 0, 0, 0, 0, 0);
132
+ __m512i idx_y_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 18, 21, 24, 27, 30);
133
+ __m512 y01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_y_01_i32x16, reg1_f32x16);
134
+ *y_f32x16_out = _mm512_permutex2var_ps(y01_f32x16, idx_y_2_i32x16, reg2_f32x16);
135
+
136
+ __m512i idx_z_01_i32x16 = _mm512_setr_epi32(2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
137
+ __m512i idx_z_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 19, 22, 25, 28, 31);
138
+ __m512 z01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_z_01_i32x16, reg1_f32x16);
139
+ *z_f32x16_out = _mm512_permutex2var_ps(z01_f32x16, idx_z_2_i32x16, reg2_f32x16);
140
+ }
141
+
142
+ /* Deinterleave 16 bf16 3D points from xyz,xyz,xyz... to separate x,y,z vectors in f32.
143
+ * Input: 48 consecutive bf16 values (16 points * 3 coordinates)
144
+ * Output: Three __m512 vectors containing the x, y, z coordinates separately (as f32).
145
+ */
146
+ NK_INTERNAL void nk_deinterleave_bf16x16_to_f32x16_skylake_( //
147
+ nk_bf16_t const *ptr, __m512 *x_f32x16_out, __m512 *y_f32x16_out, __m512 *z_f32x16_out) { //
148
+ __m512 reg0_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_loadu_si256((__m256i const *)(ptr)));
149
+ __m512 reg1_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_loadu_si256((__m256i const *)(ptr + 16)));
150
+ __m512 reg2_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_loadu_si256((__m256i const *)(ptr + 32)));
151
+
152
+ __m512i idx_x_01_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
153
+ __m512i idx_x_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 23, 26, 29);
154
+ __m512 x01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_x_01_i32x16, reg1_f32x16);
155
+ *x_f32x16_out = _mm512_permutex2var_ps(x01_f32x16, idx_x_2_i32x16, reg2_f32x16);
156
+
157
+ __m512i idx_y_01_i32x16 = _mm512_setr_epi32(1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 0, 0, 0, 0, 0);
158
+ __m512i idx_y_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 18, 21, 24, 27, 30);
159
+ __m512 y01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_y_01_i32x16, reg1_f32x16);
160
+ *y_f32x16_out = _mm512_permutex2var_ps(y01_f32x16, idx_y_2_i32x16, reg2_f32x16);
161
+
162
+ __m512i idx_z_01_i32x16 = _mm512_setr_epi32(2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
163
+ __m512i idx_z_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 19, 22, 25, 28, 31);
164
+ __m512 z01_f32x16 = _mm512_permutex2var_ps(reg0_f32x16, idx_z_01_i32x16, reg1_f32x16);
165
+ *z_f32x16_out = _mm512_permutex2var_ps(z01_f32x16, idx_z_2_i32x16, reg2_f32x16);
166
+ }
167
+
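The `nk_bf16x16_to_f32x16_skylake_` helper used above is presumably provided by the newly included `numkong/cast/skylake.h`; its body does not appear in this diff. A widening of that shape is conventionally a zero-extend followed by a 16-bit left shift, since a bf16 value is just the upper half of the corresponding f32 bit pattern. A hypothetical sketch, not the package's actual helper:

```c
#include <immintrin.h>

/* Hypothetical bf16 -> f32 widening: zero-extend 16 bf16 lanes to 32 bits,
 * then shift left by 16 so the bits land in the upper half of each f32. */
static inline __m512 bf16x16_to_f32x16_sketch(__m256i bf16x16) {
    __m512i widened_i32x16 = _mm512_cvtepu16_epi32(bf16x16);
    return _mm512_castsi512_ps(_mm512_slli_epi32(widened_i32x16, 16));
}
```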
168
+ /* Masked-tail deinterleave for f16: loads up to 16 xyz points using AVX-512 masked loads,
169
+ * converts f16→f32, and deinterleaves into separate x,y,z vectors.
170
+ * Unused lanes are zero. Uses the same permutex2var shuffle as the full-width version.
171
+ */
172
+ NK_INTERNAL void nk_deinterleave_f16_tail_to_f32x16_skylake_( //
173
+ nk_f16_t const *ptr, nk_size_t count, __m512 *x_f32x16_out, __m512 *y_f32x16_out, __m512 *z_f32x16_out) { //
174
+ nk_size_t total = count * 3;
175
+ __mmask16 mask0_i16x16 = (__mmask16)_bzhi_u32(0xFFFF, total >= 16 ? 16 : total);
176
+ __mmask16 mask1_i16x16 = total > 16 ? (__mmask16)_bzhi_u32(0xFFFF, total >= 32 ? 16 : total - 16) : 0;
177
+ __mmask16 mask2_i16x16 = total > 32 ? (__mmask16)_bzhi_u32(0xFFFF, total - 32) : 0;
178
+ __m512 reg0_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask0_i16x16, ptr));
179
+ __m512 reg1_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask1_i16x16, ptr + 16));
180
+ __m512 reg2_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask2_i16x16, ptr + 32));
181
+
182
+ __m512i idx_x_01_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
183
+ __m512i idx_x_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 23, 26, 29);
184
+ *x_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_x_01_i32x16, reg1_f32x16),
185
+ idx_x_2_i32x16, reg2_f32x16);
186
+
187
+ __m512i idx_y_01_i32x16 = _mm512_setr_epi32(1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 0, 0, 0, 0, 0);
188
+ __m512i idx_y_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 18, 21, 24, 27, 30);
189
+ *y_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_y_01_i32x16, reg1_f32x16),
190
+ idx_y_2_i32x16, reg2_f32x16);
191
+
192
+ __m512i idx_z_01_i32x16 = _mm512_setr_epi32(2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
193
+ __m512i idx_z_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 19, 22, 25, 28, 31);
194
+ *z_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_z_01_i32x16, reg1_f32x16),
195
+ idx_z_2_i32x16, reg2_f32x16);
196
+ }
197
+
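To make the tail-mask arithmetic above concrete, here is a worked example (hypothetical test code, not part of the package): for a tail of `count = 7` points, `total = 21` halfwords, so the first 16-lane load is full, the second covers 5 lanes, and the third is skipped.

```c
#include <stdint.h>
#include <stdio.h>

/* Simplified stand-in for _bzhi_u32, valid for k <= 16 as used here:
 * keep the low k bits, i.e. enable the first k lanes of a 16-lane load. */
static uint32_t bzhi16(uint32_t value, uint32_t k) { return value & ((1u << k) - 1u); }

int main(void) {
    uint32_t total = 7 * 3; /* 21 halfwords for 7 xyz points */
    uint32_t mask0 = bzhi16(0xFFFF, total >= 16 ? 16 : total);                       /* 0xFFFF */
    uint32_t mask1 = total > 16 ? bzhi16(0xFFFF, total >= 32 ? 16 : total - 16) : 0; /* 0x001F */
    uint32_t mask2 = total > 32 ? bzhi16(0xFFFF, total - 32) : 0;                    /* 0x0000 */
    printf("%04X %04X %04X\n", mask0, mask1, mask2); /* prints FFFF 001F 0000 */
    return 0;
}
```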
198
+ /* Masked-tail deinterleave for bf16: same as f16 but with bf16→f32 conversion. */
199
+ NK_INTERNAL void nk_deinterleave_bf16_tail_to_f32x16_skylake_( //
200
+ nk_bf16_t const *ptr, nk_size_t count, __m512 *x_f32x16_out, __m512 *y_f32x16_out, __m512 *z_f32x16_out) { //
201
+ nk_size_t total = count * 3;
202
+ __mmask16 mask0_i16x16 = (__mmask16)_bzhi_u32(0xFFFF, total >= 16 ? 16 : total);
203
+ __mmask16 mask1_i16x16 = total > 16 ? (__mmask16)_bzhi_u32(0xFFFF, total >= 32 ? 16 : total - 16) : 0;
204
+ __mmask16 mask2_i16x16 = total > 32 ? (__mmask16)_bzhi_u32(0xFFFF, total - 32) : 0;
205
+ __m512 reg0_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask0_i16x16, ptr));
206
+ __m512 reg1_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask1_i16x16, ptr + 16));
207
+ __m512 reg2_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask2_i16x16, ptr + 32));
208
+
209
+ __m512i idx_x_01_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
210
+ __m512i idx_x_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 23, 26, 29);
211
+ *x_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_x_01_i32x16, reg1_f32x16),
212
+ idx_x_2_i32x16, reg2_f32x16);
213
+
214
+ __m512i idx_y_01_i32x16 = _mm512_setr_epi32(1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 0, 0, 0, 0, 0);
215
+ __m512i idx_y_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 18, 21, 24, 27, 30);
216
+ *y_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_y_01_i32x16, reg1_f32x16),
217
+ idx_y_2_i32x16, reg2_f32x16);
218
+
219
+ __m512i idx_z_01_i32x16 = _mm512_setr_epi32(2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
220
+ __m512i idx_z_2_i32x16 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 19, 22, 25, 28, 31);
221
+ *z_f32x16_out = _mm512_permutex2var_ps(_mm512_permutex2var_ps(reg0_f32x16, idx_z_01_i32x16, reg1_f32x16),
222
+ idx_z_2_i32x16, reg2_f32x16);
223
+ }
224
+
115
225
  NK_INTERNAL nk_f64_t nk_reduce_stable_f64x8_skylake_(__m512d values_f64x8) {
116
226
  nk_b512_vec_t values;
117
227
  values.zmm_pd = values_f64x8;
@@ -166,84 +276,84 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_skylake_(nk_f32_t const *a, nk_f32_t
166
276
  for (; index + 16 <= n; index += 16) {
167
277
  nk_deinterleave_f32x16_skylake_(a + index * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16),
168
278
  nk_deinterleave_f32x16_skylake_(b + index * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
169
- __m512d a_x_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
170
- __m512d a_x_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
171
- __m512d a_y_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
172
- __m512d a_y_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
173
- __m512d a_z_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
174
- __m512d a_z_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
175
- __m512d b_x_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
176
- __m512d b_x_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
177
- __m512d b_y_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
178
- __m512d b_y_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
179
- __m512d b_z_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
180
- __m512d b_z_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
181
-
182
- __m512d centered_a_x_lower_f64x8 = _mm512_sub_pd(a_x_lower_f64x8, centroid_a_x_f64x8);
183
- __m512d centered_a_x_upper_f64x8 = _mm512_sub_pd(a_x_upper_f64x8, centroid_a_x_f64x8);
184
- __m512d centered_a_y_lower_f64x8 = _mm512_sub_pd(a_y_lower_f64x8, centroid_a_y_f64x8);
185
- __m512d centered_a_y_upper_f64x8 = _mm512_sub_pd(a_y_upper_f64x8, centroid_a_y_f64x8);
186
- __m512d centered_a_z_lower_f64x8 = _mm512_sub_pd(a_z_lower_f64x8, centroid_a_z_f64x8);
187
- __m512d centered_a_z_upper_f64x8 = _mm512_sub_pd(a_z_upper_f64x8, centroid_a_z_f64x8);
188
- __m512d centered_b_x_lower_f64x8 = _mm512_sub_pd(b_x_lower_f64x8, centroid_b_x_f64x8);
189
- __m512d centered_b_x_upper_f64x8 = _mm512_sub_pd(b_x_upper_f64x8, centroid_b_x_f64x8);
190
- __m512d centered_b_y_lower_f64x8 = _mm512_sub_pd(b_y_lower_f64x8, centroid_b_y_f64x8);
191
- __m512d centered_b_y_upper_f64x8 = _mm512_sub_pd(b_y_upper_f64x8, centroid_b_y_f64x8);
192
- __m512d centered_b_z_lower_f64x8 = _mm512_sub_pd(b_z_lower_f64x8, centroid_b_z_f64x8);
193
- __m512d centered_b_z_upper_f64x8 = _mm512_sub_pd(b_z_upper_f64x8, centroid_b_z_f64x8);
194
-
195
- __m512d rotated_a_x_lower_f64x8 = _mm512_fmadd_pd(
196
- scaled_rotation_x_z_f64x8, centered_a_z_lower_f64x8,
197
- _mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_lower_f64x8,
198
- _mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_lower_f64x8)));
199
- __m512d rotated_a_x_upper_f64x8 = _mm512_fmadd_pd(
200
- scaled_rotation_x_z_f64x8, centered_a_z_upper_f64x8,
201
- _mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_upper_f64x8,
202
- _mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_upper_f64x8)));
203
- __m512d rotated_a_y_lower_f64x8 = _mm512_fmadd_pd(
204
- scaled_rotation_y_z_f64x8, centered_a_z_lower_f64x8,
205
- _mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_lower_f64x8,
206
- _mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_lower_f64x8)));
207
- __m512d rotated_a_y_upper_f64x8 = _mm512_fmadd_pd(
208
- scaled_rotation_y_z_f64x8, centered_a_z_upper_f64x8,
209
- _mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_upper_f64x8,
210
- _mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_upper_f64x8)));
211
- __m512d rotated_a_z_lower_f64x8 = _mm512_fmadd_pd(
212
- scaled_rotation_z_z_f64x8, centered_a_z_lower_f64x8,
213
- _mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_lower_f64x8,
214
- _mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_lower_f64x8)));
215
- __m512d rotated_a_z_upper_f64x8 = _mm512_fmadd_pd(
216
- scaled_rotation_z_z_f64x8, centered_a_z_upper_f64x8,
217
- _mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_upper_f64x8,
218
- _mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_upper_f64x8)));
219
-
220
- __m512d delta_x_lower_f64x8 = _mm512_sub_pd(rotated_a_x_lower_f64x8, centered_b_x_lower_f64x8);
221
- __m512d delta_x_upper_f64x8 = _mm512_sub_pd(rotated_a_x_upper_f64x8, centered_b_x_upper_f64x8);
222
- __m512d delta_y_lower_f64x8 = _mm512_sub_pd(rotated_a_y_lower_f64x8, centered_b_y_lower_f64x8);
223
- __m512d delta_y_upper_f64x8 = _mm512_sub_pd(rotated_a_y_upper_f64x8, centered_b_y_upper_f64x8);
224
- __m512d delta_z_lower_f64x8 = _mm512_sub_pd(rotated_a_z_lower_f64x8, centered_b_z_lower_f64x8);
225
- __m512d delta_z_upper_f64x8 = _mm512_sub_pd(rotated_a_z_upper_f64x8, centered_b_z_upper_f64x8);
226
-
227
- __m512d batch_sum_squared_f64x8 = _mm512_add_pd(_mm512_mul_pd(delta_x_lower_f64x8, delta_x_lower_f64x8),
228
- _mm512_mul_pd(delta_x_upper_f64x8, delta_x_upper_f64x8));
229
- batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_lower_f64x8, delta_y_lower_f64x8, batch_sum_squared_f64x8);
230
- batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_upper_f64x8, delta_y_upper_f64x8, batch_sum_squared_f64x8);
231
- batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_lower_f64x8, delta_z_lower_f64x8, batch_sum_squared_f64x8);
232
- batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_upper_f64x8, delta_z_upper_f64x8, batch_sum_squared_f64x8);
279
+ __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
280
+ __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
281
+ __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
282
+ __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
283
+ __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
284
+ __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
285
+ __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
286
+ __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
287
+ __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
288
+ __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
289
+ __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
290
+ __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
291
+
292
+ __m512d centered_a_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, centroid_a_x_f64x8);
293
+ __m512d centered_a_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, centroid_a_x_f64x8);
294
+ __m512d centered_a_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, centroid_a_y_f64x8);
295
+ __m512d centered_a_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, centroid_a_y_f64x8);
296
+ __m512d centered_a_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, centroid_a_z_f64x8);
297
+ __m512d centered_a_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, centroid_a_z_f64x8);
298
+ __m512d centered_b_x_low_f64x8 = _mm512_sub_pd(b_x_low_f64x8, centroid_b_x_f64x8);
299
+ __m512d centered_b_x_high_f64x8 = _mm512_sub_pd(b_x_high_f64x8, centroid_b_x_f64x8);
300
+ __m512d centered_b_y_low_f64x8 = _mm512_sub_pd(b_y_low_f64x8, centroid_b_y_f64x8);
301
+ __m512d centered_b_y_high_f64x8 = _mm512_sub_pd(b_y_high_f64x8, centroid_b_y_f64x8);
302
+ __m512d centered_b_z_low_f64x8 = _mm512_sub_pd(b_z_low_f64x8, centroid_b_z_f64x8);
303
+ __m512d centered_b_z_high_f64x8 = _mm512_sub_pd(b_z_high_f64x8, centroid_b_z_f64x8);
304
+
305
+ __m512d rotated_a_x_low_f64x8 = _mm512_fmadd_pd(
306
+ scaled_rotation_x_z_f64x8, centered_a_z_low_f64x8,
307
+ _mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_low_f64x8,
308
+ _mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_low_f64x8)));
309
+ __m512d rotated_a_x_high_f64x8 = _mm512_fmadd_pd(
310
+ scaled_rotation_x_z_f64x8, centered_a_z_high_f64x8,
311
+ _mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_high_f64x8,
312
+ _mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_high_f64x8)));
313
+ __m512d rotated_a_y_low_f64x8 = _mm512_fmadd_pd(
314
+ scaled_rotation_y_z_f64x8, centered_a_z_low_f64x8,
315
+ _mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_low_f64x8,
316
+ _mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_low_f64x8)));
317
+ __m512d rotated_a_y_high_f64x8 = _mm512_fmadd_pd(
318
+ scaled_rotation_y_z_f64x8, centered_a_z_high_f64x8,
319
+ _mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_high_f64x8,
320
+ _mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_high_f64x8)));
321
+ __m512d rotated_a_z_low_f64x8 = _mm512_fmadd_pd(
322
+ scaled_rotation_z_z_f64x8, centered_a_z_low_f64x8,
323
+ _mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_low_f64x8,
324
+ _mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_low_f64x8)));
325
+ __m512d rotated_a_z_high_f64x8 = _mm512_fmadd_pd(
326
+ scaled_rotation_z_z_f64x8, centered_a_z_high_f64x8,
327
+ _mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_high_f64x8,
328
+ _mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_high_f64x8)));
329
+
330
+ __m512d delta_x_low_f64x8 = _mm512_sub_pd(rotated_a_x_low_f64x8, centered_b_x_low_f64x8);
331
+ __m512d delta_x_high_f64x8 = _mm512_sub_pd(rotated_a_x_high_f64x8, centered_b_x_high_f64x8);
332
+ __m512d delta_y_low_f64x8 = _mm512_sub_pd(rotated_a_y_low_f64x8, centered_b_y_low_f64x8);
333
+ __m512d delta_y_high_f64x8 = _mm512_sub_pd(rotated_a_y_high_f64x8, centered_b_y_high_f64x8);
334
+ __m512d delta_z_low_f64x8 = _mm512_sub_pd(rotated_a_z_low_f64x8, centered_b_z_low_f64x8);
335
+ __m512d delta_z_high_f64x8 = _mm512_sub_pd(rotated_a_z_high_f64x8, centered_b_z_high_f64x8);
336
+
337
+ __m512d batch_sum_squared_f64x8 = _mm512_add_pd(_mm512_mul_pd(delta_x_low_f64x8, delta_x_low_f64x8),
338
+ _mm512_mul_pd(delta_x_high_f64x8, delta_x_high_f64x8));
339
+ batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, batch_sum_squared_f64x8);
340
+ batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, batch_sum_squared_f64x8);
341
+ batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, batch_sum_squared_f64x8);
342
+ batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, batch_sum_squared_f64x8);
233
343
  sum_squared_f64x8 = _mm512_add_pd(sum_squared_f64x8, batch_sum_squared_f64x8);
234
344
  }
235
345
 
236
346
  nk_f64_t sum_squared = _mm512_reduce_add_pd(sum_squared_f64x8);
237
347
  for (; index < n; ++index) {
238
- nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x;
239
- nk_f64_t centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y;
240
- nk_f64_t centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
241
- nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x;
242
- nk_f64_t centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y;
243
- nk_f64_t centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
244
- nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z);
245
- nk_f64_t rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z);
246
- nk_f64_t rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
348
+ nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x,
349
+ centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y,
350
+ centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
351
+ nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x,
352
+ centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y,
353
+ centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
354
+ nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z),
355
+ rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z),
356
+ rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
247
357
  nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
248
358
  delta_z = rotated_a_z - centered_b_z;
249
359
  sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
@@ -322,20 +432,16 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_skylake_(nk_f64_t const *a, nk_f64_t
322
432
 
323
433
  // Scalar tail
324
434
  for (; j < n; ++j) {
325
- nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x;
326
- nk_f64_t pa_y = a[j * 3 + 1] - centroid_a_y;
327
- nk_f64_t pa_z = a[j * 3 + 2] - centroid_a_z;
328
- nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x;
329
- nk_f64_t pb_y = b[j * 3 + 1] - centroid_b_y;
330
- nk_f64_t pb_z = b[j * 3 + 2] - centroid_b_z;
331
-
332
- nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z);
333
- nk_f64_t ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z);
334
- nk_f64_t ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
335
-
336
- nk_f64_t delta_x = ra_x - pb_x;
337
- nk_f64_t delta_y = ra_y - pb_y;
338
- nk_f64_t delta_z = ra_z - pb_z;
435
+ nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x, pa_y = a[j * 3 + 1] - centroid_a_y,
436
+ pa_z = a[j * 3 + 2] - centroid_a_z;
437
+ nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x, pb_y = b[j * 3 + 1] - centroid_b_y,
438
+ pb_z = b[j * 3 + 2] - centroid_b_z;
439
+
440
+ nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
441
+ ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
442
+ ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
443
+
444
+ nk_f64_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
339
445
  nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
340
446
  nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
341
447
  nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
@@ -344,139 +450,526 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_skylake_(nk_f64_t const *a, nk_f64_t
344
450
  return sum_squared + sum_squared_compensation;
345
451
  }
346
452
 
347
- NK_INTERNAL void nk_centroid_and_cross_covariance_f32_skylake_( //
348
- nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, //
349
- nk_f64_t *centroid_a_x, nk_f64_t *centroid_a_y, nk_f64_t *centroid_a_z, nk_f64_t *centroid_b_x,
350
- nk_f64_t *centroid_b_y, nk_f64_t *centroid_b_z, nk_f64_t cross_covariance_f64[9]) {
351
- __m512d sum_a_x_f64x8 = _mm512_setzero_pd(), sum_a_y_f64x8 = _mm512_setzero_pd();
352
- __m512d sum_a_z_f64x8 = _mm512_setzero_pd(), sum_b_x_f64x8 = _mm512_setzero_pd();
353
- __m512d sum_b_y_f64x8 = _mm512_setzero_pd(), sum_b_z_f64x8 = _mm512_setzero_pd();
354
- __m512d covariance_00_f64x8 = _mm512_setzero_pd(), covariance_01_f64x8 = _mm512_setzero_pd();
355
- __m512d covariance_02_f64x8 = _mm512_setzero_pd(), covariance_10_f64x8 = _mm512_setzero_pd();
356
- __m512d covariance_11_f64x8 = _mm512_setzero_pd(), covariance_12_f64x8 = _mm512_setzero_pd();
357
- __m512d covariance_20_f64x8 = _mm512_setzero_pd(), covariance_21_f64x8 = _mm512_setzero_pd();
358
- __m512d covariance_22_f64x8 = _mm512_setzero_pd();
453
+ /* Compute sum of squared distances for f16 data after applying rotation (and optional scale).
454
+ * Loads f16, converts to f32 for computation. Rotation matrix, scale, and centroids are f32.
455
+ */
456
+ NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_skylake_(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
457
+ nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
458
+ nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
459
+ nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
460
+ nk_f32_t centroid_b_z) {
461
+ __m512 scaled_rotation_x_x_f32x16 = _mm512_set1_ps(scale * r[0]);
462
+ __m512 scaled_rotation_x_y_f32x16 = _mm512_set1_ps(scale * r[1]);
463
+ __m512 scaled_rotation_x_z_f32x16 = _mm512_set1_ps(scale * r[2]);
464
+ __m512 scaled_rotation_y_x_f32x16 = _mm512_set1_ps(scale * r[3]);
465
+ __m512 scaled_rotation_y_y_f32x16 = _mm512_set1_ps(scale * r[4]);
466
+ __m512 scaled_rotation_y_z_f32x16 = _mm512_set1_ps(scale * r[5]);
467
+ __m512 scaled_rotation_z_x_f32x16 = _mm512_set1_ps(scale * r[6]);
468
+ __m512 scaled_rotation_z_y_f32x16 = _mm512_set1_ps(scale * r[7]);
469
+ __m512 scaled_rotation_z_z_f32x16 = _mm512_set1_ps(scale * r[8]);
470
+
471
+ __m512 centroid_a_x_f32x16 = _mm512_set1_ps(centroid_a_x);
472
+ __m512 centroid_a_y_f32x16 = _mm512_set1_ps(centroid_a_y);
473
+ __m512 centroid_a_z_f32x16 = _mm512_set1_ps(centroid_a_z);
474
+ __m512 centroid_b_x_f32x16 = _mm512_set1_ps(centroid_b_x);
475
+ __m512 centroid_b_y_f32x16 = _mm512_set1_ps(centroid_b_y);
476
+ __m512 centroid_b_z_f32x16 = _mm512_set1_ps(centroid_b_z);
477
+
478
+ __m512 sum_squared_f32x16 = _mm512_setzero_ps();
359
479
  __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
360
- nk_size_t index = 0;
480
+ nk_size_t j = 0;
361
481
 
362
- for (; index + 16 <= n; index += 16) {
363
- nk_deinterleave_f32x16_skylake_(a + index * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16),
364
- nk_deinterleave_f32x16_skylake_(b + index * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
365
- __m512d a_x_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
366
- __m512d a_x_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
367
- __m512d a_y_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
368
- __m512d a_y_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
369
- __m512d a_z_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
370
- __m512d a_z_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
371
- __m512d b_x_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
372
- __m512d b_x_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
373
- __m512d b_y_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
374
- __m512d b_y_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
375
- __m512d b_z_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
376
- __m512d b_z_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
377
-
378
- sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_lower_f64x8, a_x_upper_f64x8)),
379
- sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_lower_f64x8, a_y_upper_f64x8)),
380
- sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_lower_f64x8, a_z_upper_f64x8));
381
- sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_lower_f64x8, b_x_upper_f64x8)),
382
- sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_lower_f64x8, b_y_upper_f64x8)),
383
- sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_lower_f64x8, b_z_upper_f64x8));
384
- covariance_00_f64x8 = _mm512_add_pd(covariance_00_f64x8,
385
- _mm512_add_pd(_mm512_mul_pd(a_x_lower_f64x8, b_x_lower_f64x8),
386
- _mm512_mul_pd(a_x_upper_f64x8, b_x_upper_f64x8))),
387
- covariance_01_f64x8 = _mm512_add_pd(covariance_01_f64x8,
388
- _mm512_add_pd(_mm512_mul_pd(a_x_lower_f64x8, b_y_lower_f64x8),
389
- _mm512_mul_pd(a_x_upper_f64x8, b_y_upper_f64x8))),
390
- covariance_02_f64x8 = _mm512_add_pd(covariance_02_f64x8,
391
- _mm512_add_pd(_mm512_mul_pd(a_x_lower_f64x8, b_z_lower_f64x8),
392
- _mm512_mul_pd(a_x_upper_f64x8, b_z_upper_f64x8)));
393
- covariance_10_f64x8 = _mm512_add_pd(covariance_10_f64x8,
394
- _mm512_add_pd(_mm512_mul_pd(a_y_lower_f64x8, b_x_lower_f64x8),
395
- _mm512_mul_pd(a_y_upper_f64x8, b_x_upper_f64x8))),
396
- covariance_11_f64x8 = _mm512_add_pd(covariance_11_f64x8,
397
- _mm512_add_pd(_mm512_mul_pd(a_y_lower_f64x8, b_y_lower_f64x8),
398
- _mm512_mul_pd(a_y_upper_f64x8, b_y_upper_f64x8))),
399
- covariance_12_f64x8 = _mm512_add_pd(covariance_12_f64x8,
400
- _mm512_add_pd(_mm512_mul_pd(a_y_lower_f64x8, b_z_lower_f64x8),
401
- _mm512_mul_pd(a_y_upper_f64x8, b_z_upper_f64x8)));
402
- covariance_20_f64x8 = _mm512_add_pd(covariance_20_f64x8,
403
- _mm512_add_pd(_mm512_mul_pd(a_z_lower_f64x8, b_x_lower_f64x8),
404
- _mm512_mul_pd(a_z_upper_f64x8, b_x_upper_f64x8))),
405
- covariance_21_f64x8 = _mm512_add_pd(covariance_21_f64x8,
406
- _mm512_add_pd(_mm512_mul_pd(a_z_lower_f64x8, b_y_lower_f64x8),
407
- _mm512_mul_pd(a_z_upper_f64x8, b_y_upper_f64x8))),
408
- covariance_22_f64x8 = _mm512_add_pd(covariance_22_f64x8,
409
- _mm512_add_pd(_mm512_mul_pd(a_z_lower_f64x8, b_z_lower_f64x8),
410
- _mm512_mul_pd(a_z_upper_f64x8, b_z_upper_f64x8)));
482
+ for (; j + 16 <= n; j += 16) {
483
+ nk_deinterleave_f16x16_to_f32x16_skylake_(a + j * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
484
+ nk_deinterleave_f16x16_to_f32x16_skylake_(b + j * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
485
+
486
+ __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
487
+ __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
488
+ __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
489
+ __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
490
+ __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
491
+ __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
492
+
493
+ __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
494
+ _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
495
+ _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
496
+ __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
497
+ _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
498
+ _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
499
+ __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
500
+ _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
501
+ _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
502
+
503
+ __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
504
+ __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
505
+ __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
506
+
507
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
508
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
509
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
411
510
  }
412
511
 
413
- nk_f64_t sum_a_x = _mm512_reduce_add_pd(sum_a_x_f64x8), sum_a_y = _mm512_reduce_add_pd(sum_a_y_f64x8),
414
- sum_a_z = _mm512_reduce_add_pd(sum_a_z_f64x8);
415
- nk_f64_t sum_b_x = _mm512_reduce_add_pd(sum_b_x_f64x8), sum_b_y = _mm512_reduce_add_pd(sum_b_y_f64x8),
416
- sum_b_z = _mm512_reduce_add_pd(sum_b_z_f64x8);
417
- nk_f64_t covariance_00 = _mm512_reduce_add_pd(covariance_00_f64x8),
418
- covariance_01 = _mm512_reduce_add_pd(covariance_01_f64x8),
419
- covariance_02 = _mm512_reduce_add_pd(covariance_02_f64x8);
420
- nk_f64_t covariance_10 = _mm512_reduce_add_pd(covariance_10_f64x8),
421
- covariance_11 = _mm512_reduce_add_pd(covariance_11_f64x8),
422
- covariance_12 = _mm512_reduce_add_pd(covariance_12_f64x8);
423
- nk_f64_t covariance_20 = _mm512_reduce_add_pd(covariance_20_f64x8),
424
- covariance_21 = _mm512_reduce_add_pd(covariance_21_f64x8),
425
- covariance_22 = _mm512_reduce_add_pd(covariance_22_f64x8);
512
+ // Tail: deinterleave remaining points into zero-initialized vectors
513
+ if (j < n) {
514
+ nk_size_t tail = n - j;
515
+ nk_deinterleave_f16_tail_to_f32x16_skylake_(a + j * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
516
+ nk_deinterleave_f16_tail_to_f32x16_skylake_(b + j * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
517
+
518
+ __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
519
+ __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
520
+ __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
521
+ __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
522
+ __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
523
+ __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
524
+
525
+ __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
526
+ _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
527
+ _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
528
+ __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
529
+ _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
530
+ _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
531
+ __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
532
+ _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
533
+ _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
534
+
535
+ __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
536
+ __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
537
+ __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
538
+
539
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
540
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
541
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
542
+ }
426
543
 
427
- for (; index < n; ++index) {
428
- nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
429
- nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
430
- sum_a_x += a_x, sum_a_y += a_y, sum_a_z += a_z;
431
- sum_b_x += b_x, sum_b_y += b_y, sum_b_z += b_z;
432
- covariance_00 += a_x * b_x, covariance_01 += a_x * b_y, covariance_02 += a_x * b_z;
433
- covariance_10 += a_y * b_x, covariance_11 += a_y * b_y, covariance_12 += a_y * b_z;
434
- covariance_20 += a_z * b_x, covariance_21 += a_z * b_y, covariance_22 += a_z * b_z;
544
+ return _mm512_reduce_add_ps(sum_squared_f32x16);
545
+ }
546
+
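In formula terms, this f16 kernel and the bf16 variant that follows evaluate the same objective as the f32/f64 paths above: the sum of squared distances after centering, rotating, and scaling the first cloud,

$$
\mathrm{SSD}(a, b) \;=\; \sum_{i=1}^{n} \bigl\| \, s\,R\,(a_i - \bar a) \;-\; (b_i - \bar b) \, \bigr\|^{2},
$$

where $\bar a$, $\bar b$ are the centroids, $R$ the 3×3 rotation, and $s$ the optional scale. Pre-scaling the rotation rows by $s$ lets each output coordinate be formed with one multiply and two FMAs per lane, as in the loop bodies above.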
547
+ /* Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
548
+ * Loads bf16, converts to f32 for computation. Rotation matrix, scale, and centroids are f32.
549
+ */
550
+ NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_skylake_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
551
+ nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
552
+ nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
553
+ nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
554
+ nk_f32_t centroid_b_z) {
555
+ __m512 scaled_rotation_x_x_f32x16 = _mm512_set1_ps(scale * r[0]);
556
+ __m512 scaled_rotation_x_y_f32x16 = _mm512_set1_ps(scale * r[1]);
557
+ __m512 scaled_rotation_x_z_f32x16 = _mm512_set1_ps(scale * r[2]);
558
+ __m512 scaled_rotation_y_x_f32x16 = _mm512_set1_ps(scale * r[3]);
559
+ __m512 scaled_rotation_y_y_f32x16 = _mm512_set1_ps(scale * r[4]);
560
+ __m512 scaled_rotation_y_z_f32x16 = _mm512_set1_ps(scale * r[5]);
561
+ __m512 scaled_rotation_z_x_f32x16 = _mm512_set1_ps(scale * r[6]);
562
+ __m512 scaled_rotation_z_y_f32x16 = _mm512_set1_ps(scale * r[7]);
563
+ __m512 scaled_rotation_z_z_f32x16 = _mm512_set1_ps(scale * r[8]);
564
+
565
+ __m512 centroid_a_x_f32x16 = _mm512_set1_ps(centroid_a_x);
566
+ __m512 centroid_a_y_f32x16 = _mm512_set1_ps(centroid_a_y);
567
+ __m512 centroid_a_z_f32x16 = _mm512_set1_ps(centroid_a_z);
568
+ __m512 centroid_b_x_f32x16 = _mm512_set1_ps(centroid_b_x);
569
+ __m512 centroid_b_y_f32x16 = _mm512_set1_ps(centroid_b_y);
570
+ __m512 centroid_b_z_f32x16 = _mm512_set1_ps(centroid_b_z);
571
+
572
+ __m512 sum_squared_f32x16 = _mm512_setzero_ps();
573
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
574
+ nk_size_t j = 0;
575
+
576
+ for (; j + 16 <= n; j += 16) {
577
+ nk_deinterleave_bf16x16_to_f32x16_skylake_(a + j * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
578
+ nk_deinterleave_bf16x16_to_f32x16_skylake_(b + j * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
579
+
580
+ __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
581
+ __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
582
+ __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
583
+ __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
584
+ __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
585
+ __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
586
+
587
+ __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
588
+ _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
589
+ _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
590
+ __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
591
+ _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
592
+ _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
593
+ __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
594
+ _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
595
+ _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
596
+
597
+ __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
598
+ __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
599
+ __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
600
+
601
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
602
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
603
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
604
+ }
605
+
606
+ // Tail: deinterleave remaining points into zero-initialized vectors
607
+ if (j < n) {
608
+ nk_size_t tail = n - j;
609
+ nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + j * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
610
+ nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + j * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
611
+
612
+ __m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
613
+ __m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
614
+ __m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
615
+ __m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
616
+ __m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
617
+ __m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
618
+
619
+ __m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
620
+ _mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
621
+ _mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
622
+ __m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
623
+ _mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
624
+ _mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
625
+ __m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
626
+ _mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
627
+ _mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
628
+
629
+ __m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
630
+ __m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
631
+ __m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
632
+
633
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
634
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
635
+ sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
435
636
  }
436
637
 
437
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n, n_f64 = (nk_f64_t)n;
438
- *centroid_a_x = sum_a_x * inv_n, *centroid_a_y = sum_a_y * inv_n, *centroid_a_z = sum_a_z * inv_n;
439
- *centroid_b_x = sum_b_x * inv_n, *centroid_b_y = sum_b_y * inv_n, *centroid_b_z = sum_b_z * inv_n;
440
- cross_covariance_f64[0] = covariance_00 - n_f64 * (*centroid_a_x) * (*centroid_b_x),
441
- cross_covariance_f64[1] = covariance_01 - n_f64 * (*centroid_a_x) * (*centroid_b_y),
442
- cross_covariance_f64[2] = covariance_02 - n_f64 * (*centroid_a_x) * (*centroid_b_z);
443
- cross_covariance_f64[3] = covariance_10 - n_f64 * (*centroid_a_y) * (*centroid_b_x),
444
- cross_covariance_f64[4] = covariance_11 - n_f64 * (*centroid_a_y) * (*centroid_b_y),
445
- cross_covariance_f64[5] = covariance_12 - n_f64 * (*centroid_a_y) * (*centroid_b_z);
446
- cross_covariance_f64[6] = covariance_20 - n_f64 * (*centroid_a_z) * (*centroid_b_x),
447
- cross_covariance_f64[7] = covariance_21 - n_f64 * (*centroid_a_z) * (*centroid_b_y),
448
- cross_covariance_f64[8] = covariance_22 - n_f64 * (*centroid_a_z) * (*centroid_b_z);
638
+ return _mm512_reduce_add_ps(sum_squared_f32x16);
449
639
  }
450
640
 
451
641
  NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
452
642
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
453
- nk_f64_t identity[9] = {1, 0, 0, 0, 1, 0, 0, 0, 1};
454
- nk_f64_t centroid_a_x, centroid_a_y, centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z;
455
- nk_f64_t cross_covariance_f64[9];
456
643
  if (rotation)
457
644
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
458
645
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
459
646
  if (scale) *scale = 1.0f;
460
- nk_centroid_and_cross_covariance_f32_skylake_(a, b, n, &centroid_a_x, &centroid_a_y, &centroid_a_z, &centroid_b_x,
461
- &centroid_b_y, &centroid_b_z, cross_covariance_f64);
647
+
648
+ // Fused single-pass: centroids + squared differences in f64, using the identity:
649
+ // RMSD = √(E[‖a−b‖²] − ‖ā − b̄‖²)
650
+ __m512d const zeros_f64x8 = _mm512_setzero_pd();
651
+ __m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
652
+ __m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
653
+ __m512d sum_squared_x_f64x8 = zeros_f64x8, sum_squared_y_f64x8 = zeros_f64x8, sum_squared_z_f64x8 = zeros_f64x8;
654
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
655
+ nk_size_t i = 0;
656
+
657
+ // Main loop with 2x unrolling (32 points per iteration)
658
+ for (; i + 32 <= n; i += 32) {
659
+ // Iteration 0: points i..i+15
660
+ nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
661
+ nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
662
+ __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
663
+ __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
664
+ __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
665
+ __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
666
+ __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
667
+ __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
668
+ __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
669
+ __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
670
+ __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
671
+ __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
672
+ __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
673
+ __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
674
+
675
+ sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
676
+ sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
677
+ sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
678
+ sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
679
+ sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
680
+ sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
681
+
682
+ __m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
683
+ __m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
684
+ __m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
685
+ __m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
686
+ __m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
687
+ __m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
688
+ sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
689
+ sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
690
+ sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
691
+ sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
692
+ sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
693
+ sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
694
+
695
+ // Iteration 1: points i+16..i+31
696
+ nk_deinterleave_f32x16_skylake_(a + (i + 16) * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
697
+ nk_deinterleave_f32x16_skylake_(b + (i + 16) * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
698
+ a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
699
+ a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
700
+ a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
701
+ a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
702
+ a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
703
+ a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
704
+ b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
705
+ b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
706
+ b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
707
+ b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
708
+ b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
709
+ b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
710
+
711
+ sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
712
+ sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
713
+ sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
714
+ sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
715
+ sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
716
+ sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
717
+
718
+ delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
719
+ delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
720
+ delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
721
+ delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
722
+ delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
723
+ delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
724
+ sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
725
+ sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
726
+ sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
727
+ sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
728
+ sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
729
+ sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
730
+ }
731
+
732
+ // Handle 16-point remainder
733
+ for (; i + 16 <= n; i += 16) {
734
+ nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
735
+ nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
736
+ __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
737
+ __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
738
+ __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
739
+ __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
740
+ __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
741
+ __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
742
+ __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
743
+ __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
744
+ __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
745
+ __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
746
+ __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
747
+ __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
748
+
749
+ sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
750
+ sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
751
+ sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
752
+ sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
753
+ sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
754
+ sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
755
+
756
+ __m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
757
+ __m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
758
+ __m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
759
+ __m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
760
+ __m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
761
+ __m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
762
+ sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
763
+ sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
764
+ sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
765
+ sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
766
+ sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
767
+ sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
768
+ }
769
+
770
+ // Tail: use masked gather for remaining < 16 points
771
+ if (i < n) {
772
+ nk_size_t tail = n - i;
773
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
774
+ __m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
775
+ __m512 zeros_f32x16 = _mm512_setzero_ps();
776
+ nk_f32_t const *a_tail = a + i * 3;
777
+ nk_f32_t const *b_tail = b + i * 3;
778
+
779
+ a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
780
+ a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
781
+ a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
782
+ b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
783
+ b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
784
+ b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
785
+
786
+ __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
787
+ __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
788
+ __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
789
+ __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
790
+ __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
791
+ __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
792
+ __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
793
+ __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
794
+ __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
795
+ __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
796
+ __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
797
+ __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
798
+
799
+ sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
800
+ sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
801
+ sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
802
+ sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
803
+ sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
804
+ sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
805
+
806
+ __m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
807
+ __m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
808
+ __m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
809
+ __m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
810
+ __m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
811
+ __m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
812
+ sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
813
+ sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
814
+ sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
815
+ sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
816
+ sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
817
+ sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
818
+ }
819
+
820
+ // Reduce and compute centroids
821
+ nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
822
+ nk_f64_t total_ax = _mm512_reduce_add_pd(sum_a_x_f64x8);
823
+ nk_f64_t total_ay = _mm512_reduce_add_pd(sum_a_y_f64x8);
824
+ nk_f64_t total_az = _mm512_reduce_add_pd(sum_a_z_f64x8);
825
+ nk_f64_t total_bx = _mm512_reduce_add_pd(sum_b_x_f64x8);
826
+ nk_f64_t total_by = _mm512_reduce_add_pd(sum_b_y_f64x8);
827
+ nk_f64_t total_bz = _mm512_reduce_add_pd(sum_b_z_f64x8);
828
+ nk_f64_t total_sq_x = _mm512_reduce_add_pd(sum_squared_x_f64x8);
829
+ nk_f64_t total_sq_y = _mm512_reduce_add_pd(sum_squared_y_f64x8);
830
+ nk_f64_t total_sq_z = _mm512_reduce_add_pd(sum_squared_z_f64x8);
831
+
832
+ nk_f64_t centroid_a_x = total_ax * inv_n, centroid_a_y = total_ay * inv_n, centroid_a_z = total_az * inv_n;
833
+ nk_f64_t centroid_b_x = total_bx * inv_n, centroid_b_y = total_by * inv_n, centroid_b_z = total_bz * inv_n;
462
834
  if (a_centroid)
463
835
  a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
464
836
  a_centroid[2] = (nk_f32_t)centroid_a_z;
465
837
  if (b_centroid)
466
838
  b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
467
839
  b_centroid[2] = (nk_f32_t)centroid_b_z;
468
- *result = nk_f64_sqrt_haswell(nk_transformed_ssd_f32_skylake_(a, b, n, identity, 1.0, centroid_a_x, centroid_a_y,
469
- centroid_a_z, centroid_b_x, centroid_b_y,
470
- centroid_b_z) /
471
- (nk_f64_t)n);
840
+
841
+ nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x, mean_diff_y = centroid_a_y - centroid_b_y,
842
+ mean_diff_z = centroid_a_z - centroid_b_z;
843
+ nk_f64_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
844
+ nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
845
+ *result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
472
846
  }
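Note on the tail handling above: the `< 16`-point remainder avoids a scalar loop by gathering the interleaved x/y/z components with a stride-3 index vector and a `_bzhi_u32` lane mask; masked-off lanes keep the zero from the source operand, so they contribute nothing to the running sums and squared differences. A standalone sketch of the same idea follows (AVX-512F + BMI2 intrinsics; the function name is illustrative, not the library's internal helper):

```c
#include <immintrin.h>
#include <stddef.h>

/* Deinterleave the last `tail` (< 16) packed xyz points into three f32x16
 * registers with a masked, stride-3 gather. Lanes past `tail` stay zero. */
static inline void deinterleave_xyz_tail_f32(float const *points, size_t tail,
                                             __m512 *x, __m512 *y, __m512 *z) {
    __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned)tail);
    /* Element offsets of the x component of points 0..15 in xyzxyz... layout;
     * the gather's scale of 4 converts them to byte offsets. */
    __m512i idx = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21,
                                    24, 27, 30, 33, 36, 39, 42, 45);
    __m512 zeros = _mm512_setzero_ps();
    *x = _mm512_mask_i32gather_ps(zeros, mask, idx, points + 0, 4);
    *y = _mm512_mask_i32gather_ps(zeros, mask, idx, points + 1, 4);
    *z = _mm512_mask_i32gather_ps(zeros, mask, idx, points + 2, 4);
}
```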
473
847
 
474
848
  NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
475
849
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
476
- nk_f64_t centroid_a_x, centroid_a_y, centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z;
477
- nk_f64_t cross_covariance_f64[9];
478
- nk_centroid_and_cross_covariance_f32_skylake_(a, b, n, &centroid_a_x, &centroid_a_y, &centroid_a_z, &centroid_b_x,
479
- &centroid_b_y, &centroid_b_z, cross_covariance_f64);
850
+ // Fused single-pass: centroids + covariance in f64
851
+ __m512d const zeros_f64x8 = _mm512_setzero_pd();
852
+ __m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
853
+ __m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
854
+ __m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
855
+ __m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
856
+ __m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
857
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
858
+ nk_size_t i = 0;
859
+
860
+ for (; i + 16 <= n; i += 16) {
861
+ nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
862
+ nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
863
+ __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
864
+ __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
865
+ __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
866
+ __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
867
+ __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
868
+ __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
869
+ __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
870
+ __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
871
+ __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
872
+ __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
873
+ __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
874
+ __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
875
+
876
+ sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
877
+ sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
878
+ sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
879
+ sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
880
+ sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
881
+ sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
882
+
883
+ cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
884
+ _mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
885
+ cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
886
+ _mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
887
+ cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
888
+ _mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
889
+ cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
890
+ _mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
891
+ cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
892
+ _mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
893
+ cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
894
+ _mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
895
+ cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
896
+ _mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
897
+ cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
898
+ _mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
899
+ cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
900
+ _mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
901
+ }
902
+
903
+ // Tail: use masked gather for remaining < 16 points
904
+ if (i < n) {
905
+ nk_size_t tail = n - i;
906
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
907
+ __m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
908
+ __m512 zeros_f32x16 = _mm512_setzero_ps();
909
+ nk_f32_t const *a_tail = a + i * 3;
910
+ nk_f32_t const *b_tail = b + i * 3;
911
+
912
+ a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
913
+ a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
914
+ a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
915
+ b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
916
+ b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
917
+ b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
918
+
919
+ __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
920
+ __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
921
+ __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
922
+ __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
923
+ __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
924
+ __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
925
+ __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
926
+ __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
927
+ __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
928
+ __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
929
+ __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
930
+ __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
931
+
932
+ sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
933
+ sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
934
+ sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
935
+ sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
936
+ sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
937
+ sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
938
+
939
+ cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
940
+ _mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
941
+ cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
942
+ _mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
943
+ cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
944
+ _mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
945
+ cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
946
+ _mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
947
+ cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
948
+ _mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
949
+ cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
950
+ _mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
951
+ cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
952
+ _mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
953
+ cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
954
+ _mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
955
+ cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
956
+ _mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
957
+ }
958
+
959
+ nk_f64_t sum_a_x = _mm512_reduce_add_pd(sum_a_x_f64x8), sum_a_y = _mm512_reduce_add_pd(sum_a_y_f64x8),
960
+ sum_a_z = _mm512_reduce_add_pd(sum_a_z_f64x8);
961
+ nk_f64_t sum_b_x = _mm512_reduce_add_pd(sum_b_x_f64x8), sum_b_y = _mm512_reduce_add_pd(sum_b_y_f64x8),
962
+ sum_b_z = _mm512_reduce_add_pd(sum_b_z_f64x8);
963
+ nk_f64_t covariance_x_x = _mm512_reduce_add_pd(cov_xx_f64x8), covariance_x_y = _mm512_reduce_add_pd(cov_xy_f64x8),
964
+ covariance_x_z = _mm512_reduce_add_pd(cov_xz_f64x8);
965
+ nk_f64_t covariance_y_x = _mm512_reduce_add_pd(cov_yx_f64x8), covariance_y_y = _mm512_reduce_add_pd(cov_yy_f64x8),
966
+ covariance_y_z = _mm512_reduce_add_pd(cov_yz_f64x8);
967
+ nk_f64_t covariance_z_x = _mm512_reduce_add_pd(cov_zx_f64x8), covariance_z_y = _mm512_reduce_add_pd(cov_zy_f64x8),
968
+ covariance_z_z = _mm512_reduce_add_pd(cov_zz_f64x8);
969
+
970
+ nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
971
+ nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
972
+ nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
480
973
  if (a_centroid)
481
974
  a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
482
975
  a_centroid[2] = (nk_f32_t)centroid_a_z;
@@ -485,51 +978,40 @@ NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_si
485
978
  b_centroid[2] = (nk_f32_t)centroid_b_z;
486
979
  if (scale) *scale = 1.0f;
487
980
 
488
- nk_f64_t svd_u[9], svd_s[9], svd_v[9], r[9];
489
- nk_svd3x3_f64_(cross_covariance_f64, svd_u, svd_s, svd_v);
490
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
491
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
492
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
493
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
494
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
495
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
496
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
497
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
498
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
981
+ nk_f64_t n_f64 = (nk_f64_t)n;
982
+ nk_f64_t cross_covariance[9];
983
+ cross_covariance[0] = covariance_x_x - n_f64 * centroid_a_x * centroid_b_x;
984
+ cross_covariance[1] = covariance_x_y - n_f64 * centroid_a_x * centroid_b_y;
985
+ cross_covariance[2] = covariance_x_z - n_f64 * centroid_a_x * centroid_b_z;
986
+ cross_covariance[3] = covariance_y_x - n_f64 * centroid_a_y * centroid_b_x;
987
+ cross_covariance[4] = covariance_y_y - n_f64 * centroid_a_y * centroid_b_y;
988
+ cross_covariance[5] = covariance_y_z - n_f64 * centroid_a_y * centroid_b_z;
989
+ cross_covariance[6] = covariance_z_x - n_f64 * centroid_a_z * centroid_b_x;
990
+ cross_covariance[7] = covariance_z_y - n_f64 * centroid_a_z * centroid_b_y;
991
+ cross_covariance[8] = covariance_z_z - n_f64 * centroid_a_z * centroid_b_z;
992
+
993
+ nk_f64_t svd_u[9], svd_s[9], svd_v[9];
994
+ nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
995
+ nk_f64_t r[9];
996
+ nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
499
997
  if (nk_det3x3_f64_(r) < 0) {
500
998
  svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
501
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
502
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
503
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
504
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
505
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
506
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
507
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
508
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
509
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
999
+ nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
510
1000
  }
511
1001
  if (rotation)
512
- for (int index = 0; index != 9; ++index) rotation[index] = (nk_f32_t)r[index];
1002
+ for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)r[j];
513
1003
  *result = nk_f64_sqrt_haswell(nk_transformed_ssd_f32_skylake_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y,
514
1004
  centroid_a_z, centroid_b_x, centroid_b_y,
515
1005
  centroid_b_z) /
516
- (nk_f64_t)n);
1006
+ n_f64);
517
1007
  }
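The nine scalar products that previously built the rotation inline are now factored into nk_rotation_from_svd_f64_skylake_. Judging from the removed lines, the product is R = V·Uᵀ on row-major 3×3 matrices, recomputed with the last column of V negated whenever det(R) < 0 so that a proper rotation (no reflection) is returned. A minimal scalar sketch of that product, with an illustrative name rather than the library's API:

```c
/* R = V · Uᵀ for row-major 3×3 matrices: R[row][col] = Σ_k V[row][k] · U[col][k].
 * Hypothetical helper for illustration; the library's internal routine may differ. */
static void rotation_from_svd_3x3(double const u[9], double const v[9], double r[9]) {
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col)
            r[row * 3 + col] = v[row * 3 + 0] * u[col * 3 + 0] +
                               v[row * 3 + 1] * u[col * 3 + 1] +
                               v[row * 3 + 2] * u[col * 3 + 2];
}
```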
518
1008
 
519
1009
  NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
520
1010
  nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
521
1011
  // RMSD uses identity rotation and scale=1.0.
522
- if (rotation) {
523
- rotation[0] = 1;
524
- rotation[1] = 0;
525
- rotation[2] = 0;
526
- rotation[3] = 0;
527
- rotation[4] = 1;
528
- rotation[5] = 0;
529
- rotation[6] = 0;
530
- rotation[7] = 0;
531
- rotation[8] = 1;
532
- }
1012
+ if (rotation)
1013
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
1014
+ rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
533
1015
  if (scale) *scale = 1.0;
534
1016
  // Optimized fused single-pass implementation for f64.
535
1017
  // Computes centroids and squared differences in one pass using the identity:
@@ -633,6 +1115,7 @@ NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size
633
1115
  sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_f64x8, delta_x_f64x8, sum_squared_x_f64x8);
634
1116
  sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_f64x8, delta_y_f64x8, sum_squared_y_f64x8);
635
1117
  sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_f64x8, delta_z_f64x8, sum_squared_z_f64x8);
1118
+ i = n;
636
1119
  }
637
1120
 
638
1121
  // Reduce and compute centroids.
@@ -759,6 +1242,7 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
759
1242
  cov_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, cov_zx_f64x8),
760
1243
  cov_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, cov_zy_f64x8),
761
1244
  cov_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, cov_zz_f64x8);
1245
+ i = n;
762
1246
  }
763
1247
 
764
1248
  // Reduce centroids and covariance.
@@ -840,9 +1324,8 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
840
1324
  }
841
1325
 
842
1326
  // Output rotation matrix and scale=1.0.
843
- if (rotation) {
1327
+ if (rotation)
844
1328
  for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)r[j];
845
- }
846
1329
  if (scale) *scale = 1.0;
847
1330
 
848
1331
  // Compute RMSD after optimal rotation
@@ -851,51 +1334,153 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
851
1334
  *result = nk_f64_sqrt_haswell(sum_squared * inv_n);
852
1335
  }
853
1336
 
854
- NK_INTERNAL void nk_centroid_and_cross_covariance_and_variance_f32_skylake_( //
855
- nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, //
856
- nk_f64_t *centroid_a_x, nk_f64_t *centroid_a_y, nk_f64_t *centroid_a_z, nk_f64_t *centroid_b_x,
857
- nk_f64_t *centroid_b_y, nk_f64_t *centroid_b_z, nk_f64_t cross_covariance_f64[9], nk_f64_t *variance_a) {
858
- nk_centroid_and_cross_covariance_f32_skylake_(a, b, n, centroid_a_x, centroid_a_y, centroid_a_z, centroid_b_x,
859
- centroid_b_y, centroid_b_z, cross_covariance_f64);
860
- __m512d variance_a_f64x8 = _mm512_setzero_pd();
861
- __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16;
862
- nk_size_t index = 0;
1337
+ NK_PUBLIC void nk_umeyama_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1338
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
1339
+ // Fused single-pass: centroids + covariance + variance of A, all in f64
1340
+ __m512d const zeros_f64x8 = _mm512_setzero_pd();
1341
+ __m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
1342
+ __m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
1343
+ __m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
1344
+ __m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
1345
+ __m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
1346
+ __m512d variance_a_f64x8 = zeros_f64x8;
1347
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
1348
+ nk_size_t i = 0;
863
1349
 
864
- for (; index + 16 <= n; index += 16) {
865
- nk_deinterleave_f32x16_skylake_(a + index * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
866
- __m512d a_x_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
867
- __m512d a_x_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
868
- __m512d a_y_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
869
- __m512d a_y_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
870
- __m512d a_z_lower_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
871
- __m512d a_z_upper_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
872
- __m512d batch_norm_squared_f64x8 = _mm512_add_pd(_mm512_mul_pd(a_x_lower_f64x8, a_x_lower_f64x8),
873
- _mm512_mul_pd(a_x_upper_f64x8, a_x_upper_f64x8));
874
- batch_norm_squared_f64x8 = _mm512_fmadd_pd(a_y_lower_f64x8, a_y_lower_f64x8, batch_norm_squared_f64x8);
875
- batch_norm_squared_f64x8 = _mm512_fmadd_pd(a_y_upper_f64x8, a_y_upper_f64x8, batch_norm_squared_f64x8);
876
- batch_norm_squared_f64x8 = _mm512_fmadd_pd(a_z_lower_f64x8, a_z_lower_f64x8, batch_norm_squared_f64x8);
877
- batch_norm_squared_f64x8 = _mm512_fmadd_pd(a_z_upper_f64x8, a_z_upper_f64x8, batch_norm_squared_f64x8);
878
- variance_a_f64x8 = _mm512_add_pd(variance_a_f64x8, batch_norm_squared_f64x8);
1350
+ for (; i + 16 <= n; i += 16) {
1351
+ nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
1352
+ nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
1353
+ __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
1354
+ __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
1355
+ __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
1356
+ __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
1357
+ __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
1358
+ __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
1359
+ __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
1360
+ __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
1361
+ __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
1362
+ __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
1363
+ __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
1364
+ __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
1365
+
1366
+ sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
1367
+ sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
1368
+ sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
1369
+ sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
1370
+ sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
1371
+ sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
1372
+
1373
+ cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
1374
+ _mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
1375
+ cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
1376
+ _mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
1377
+ cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
1378
+ _mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
1379
+ cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
1380
+ _mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
1381
+ cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
1382
+ _mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
1383
+ cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
1384
+ _mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
1385
+ cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
1386
+ _mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
1387
+ cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
1388
+ _mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
1389
+ cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
1390
+ _mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
1391
+
1392
+ variance_a_f64x8 = _mm512_add_pd(
1393
+ variance_a_f64x8,
1394
+ _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, a_x_low_f64x8), _mm512_mul_pd(a_x_high_f64x8, a_x_high_f64x8)));
1395
+ variance_a_f64x8 = _mm512_add_pd(
1396
+ variance_a_f64x8,
1397
+ _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, a_y_low_f64x8), _mm512_mul_pd(a_y_high_f64x8, a_y_high_f64x8)));
1398
+ variance_a_f64x8 = _mm512_add_pd(
1399
+ variance_a_f64x8,
1400
+ _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, a_z_low_f64x8), _mm512_mul_pd(a_z_high_f64x8, a_z_high_f64x8)));
879
1401
  }
880
1402
 
881
- nk_f64_t variance_sum = _mm512_reduce_add_pd(variance_a_f64x8);
882
- for (; index < n; ++index) {
883
- nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
884
- variance_sum += a_x * a_x + a_y * a_y + a_z * a_z;
1403
+ // Tail: use masked gather for remaining < 16 points
1404
+ if (i < n) {
1405
+ nk_size_t tail = n - i;
1406
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
1407
+ __m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
1408
+ __m512 zeros_f32x16 = _mm512_setzero_ps();
1409
+ nk_f32_t const *a_tail = a + i * 3;
1410
+ nk_f32_t const *b_tail = b + i * 3;
1411
+
1412
+ a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
1413
+ a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
1414
+ a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
1415
+ b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
1416
+ b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
1417
+ b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
1418
+
1419
+ __m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
1420
+ __m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
1421
+ __m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
1422
+ __m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
1423
+ __m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
1424
+ __m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
1425
+ __m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
1426
+ __m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
1427
+ __m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
1428
+ __m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
1429
+ __m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
1430
+ __m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
1431
+
1432
+ sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
1433
+ sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
1434
+ sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
1435
+ sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
1436
+ sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
1437
+ sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
1438
+
1439
+ cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
1440
+ _mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
1441
+ cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
1442
+ _mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
1443
+ cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
1444
+ _mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
1445
+ cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
1446
+ _mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
1447
+ cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
1448
+ _mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
1449
+ cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
1450
+ _mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
1451
+ cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
1452
+ _mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
1453
+ cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
1454
+ _mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
1455
+ cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
1456
+ _mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
1457
+
1458
+ variance_a_f64x8 = _mm512_add_pd(
1459
+ variance_a_f64x8,
1460
+ _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, a_x_low_f64x8), _mm512_mul_pd(a_x_high_f64x8, a_x_high_f64x8)));
1461
+ variance_a_f64x8 = _mm512_add_pd(
1462
+ variance_a_f64x8,
1463
+ _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, a_y_low_f64x8), _mm512_mul_pd(a_y_high_f64x8, a_y_high_f64x8)));
1464
+ variance_a_f64x8 = _mm512_add_pd(
1465
+ variance_a_f64x8,
1466
+ _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, a_z_low_f64x8), _mm512_mul_pd(a_z_high_f64x8, a_z_high_f64x8)));
885
1467
  }
886
1468
 
887
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
888
- *variance_a = variance_sum * inv_n - ((*centroid_a_x) * (*centroid_a_x) + (*centroid_a_y) * (*centroid_a_y) +
889
- (*centroid_a_z) * (*centroid_a_z));
890
- }
1469
+ nk_f64_t sum_a_x = _mm512_reduce_add_pd(sum_a_x_f64x8), sum_a_y = _mm512_reduce_add_pd(sum_a_y_f64x8),
1470
+ sum_a_z = _mm512_reduce_add_pd(sum_a_z_f64x8);
1471
+ nk_f64_t sum_b_x = _mm512_reduce_add_pd(sum_b_x_f64x8), sum_b_y = _mm512_reduce_add_pd(sum_b_y_f64x8),
1472
+ sum_b_z = _mm512_reduce_add_pd(sum_b_z_f64x8);
1473
+ nk_f64_t covariance_x_x = _mm512_reduce_add_pd(cov_xx_f64x8), covariance_x_y = _mm512_reduce_add_pd(cov_xy_f64x8),
1474
+ covariance_x_z = _mm512_reduce_add_pd(cov_xz_f64x8);
1475
+ nk_f64_t covariance_y_x = _mm512_reduce_add_pd(cov_yx_f64x8), covariance_y_y = _mm512_reduce_add_pd(cov_yy_f64x8),
1476
+ covariance_y_z = _mm512_reduce_add_pd(cov_yz_f64x8);
1477
+ nk_f64_t covariance_z_x = _mm512_reduce_add_pd(cov_zx_f64x8), covariance_z_y = _mm512_reduce_add_pd(cov_zy_f64x8),
1478
+ covariance_z_z = _mm512_reduce_add_pd(cov_zz_f64x8);
1479
+ nk_f64_t variance_a_sum = _mm512_reduce_add_pd(variance_a_f64x8);
891
1480
 
892
- NK_PUBLIC void nk_umeyama_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
893
- nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
894
- nk_f64_t centroid_a_x, centroid_a_y, centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z, variance_a;
895
- nk_f64_t cross_covariance_f64[9];
896
- nk_centroid_and_cross_covariance_and_variance_f32_skylake_(a, b, n, &centroid_a_x, &centroid_a_y, &centroid_a_z,
897
- &centroid_b_x, &centroid_b_y, &centroid_b_z,
898
- cross_covariance_f64, &variance_a);
1481
+ nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
1482
+ nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
1483
+ nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
899
1484
  if (a_centroid)
900
1485
  a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
901
1486
  a_centroid[2] = (nk_f32_t)centroid_a_z;
@@ -903,41 +1488,49 @@ NK_PUBLIC void nk_umeyama_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_s
903
1488
  b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
904
1489
  b_centroid[2] = (nk_f32_t)centroid_b_z;
905
1490
 
906
- nk_f64_t svd_u[9], svd_s[9], svd_v[9], r[9];
907
- nk_svd3x3_f64_(cross_covariance_f64, svd_u, svd_s, svd_v);
908
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
909
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
910
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
911
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
912
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
913
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
914
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
915
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
916
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1491
+ // Variance of A about its centroid
1492
+ nk_f64_t variance_a = variance_a_sum * inv_n -
1493
+ (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
917
1494
 
1495
+ // Centered cross-covariance: Hᵢⱼ = Σ aᵢbⱼ − (Σaᵢ)(Σbⱼ)/n = Σ aᵢbⱼ − n·āᵢ·b̄ⱼ
1496
+ nk_f64_t n_f64 = (nk_f64_t)n;
1497
+ nk_f64_t cross_covariance[9];
1498
+ cross_covariance[0] = covariance_x_x - n_f64 * centroid_a_x * centroid_b_x;
1499
+ cross_covariance[1] = covariance_x_y - n_f64 * centroid_a_x * centroid_b_y;
1500
+ cross_covariance[2] = covariance_x_z - n_f64 * centroid_a_x * centroid_b_z;
1501
+ cross_covariance[3] = covariance_y_x - n_f64 * centroid_a_y * centroid_b_x;
1502
+ cross_covariance[4] = covariance_y_y - n_f64 * centroid_a_y * centroid_b_y;
1503
+ cross_covariance[5] = covariance_y_z - n_f64 * centroid_a_y * centroid_b_z;
1504
+ cross_covariance[6] = covariance_z_x - n_f64 * centroid_a_z * centroid_b_x;
1505
+ cross_covariance[7] = covariance_z_y - n_f64 * centroid_a_z * centroid_b_y;
1506
+ cross_covariance[8] = covariance_z_z - n_f64 * centroid_a_z * centroid_b_z;
1507
+
1508
+ // SVD using f64 for full precision
1509
+ nk_f64_t svd_u[9], svd_s[9], svd_v[9];
1510
+ nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
1511
+
1512
+ nk_f64_t r[9];
1513
+ nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
1514
+
1515
+ // Scale factor: c = trace(D·S) / (n · var(a)), with D = diag(1, 1, sign(det(V·Uᵀ)))
918
1516
  nk_f64_t det = nk_det3x3_f64_(r);
919
- nk_f64_t trace_signed_singular_values = svd_s[0] + svd_s[4] + (det < 0 ? -svd_s[8] : svd_s[8]);
920
- nk_f64_t applied_scale = trace_signed_singular_values / ((nk_f64_t)n * variance_a);
1517
+ nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
1518
+ nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
1519
+ nk_f64_t applied_scale = trace_ds / ((nk_f64_t)n * variance_a);
1520
+ if (scale) *scale = (nk_f32_t)applied_scale;
1521
+
1522
+ // Handle reflection
921
1523
  if (det < 0) {
922
1524
  svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
923
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
924
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
925
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
926
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
927
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
928
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
929
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
930
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
931
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1525
+ nk_rotation_from_svd_f64_skylake_(svd_u, svd_v, r);
932
1526
  }
933
1527
 
934
1528
  if (rotation)
935
- for (int index = 0; index != 9; ++index) rotation[index] = (nk_f32_t)r[index];
936
- if (scale) *scale = (nk_f32_t)applied_scale;
1529
+ for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)r[j];
937
1530
  *result = nk_f64_sqrt_haswell(nk_transformed_ssd_f32_skylake_(a, b, n, r, applied_scale, centroid_a_x, centroid_a_y,
938
1531
  centroid_a_z, centroid_b_x, centroid_b_y,
939
1532
  centroid_b_z) /
940
- (nk_f64_t)n);
1533
+ n_f64);
941
1534
  }
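For reference, the scale applied here is the standard Umeyama estimate, restated in the same unnormalized quantities the code accumulates. Assuming nk_svd3x3_f64_ factors the centered cross-covariance as H = U·S·Vᵀ (which the naming suggests), the code's `trace_ds / (n * variance_a)` corresponds to:

```latex
c \;=\; \frac{\operatorname{tr}(D\,S)}{n\,\sigma_a^2},
\qquad D=\operatorname{diag}\!\bigl(1,\,1,\,\operatorname{sign}(\det(V U^{\mathsf T}))\bigr),
\qquad \sigma_a^2=\frac{1}{n}\sum_i\lVert a_i-\bar a\rVert^2 .
```

The extra factor of n in the denominator appears because S comes from the unnormalized covariance sum while `variance_a` is already the per-point σ²ₐ.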
942
1535
 
943
1536
  NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
@@ -1013,6 +1606,7 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
1013
1606
  variance_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, variance_a_f64x8);
1014
1607
  variance_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, variance_a_f64x8);
1015
1608
  variance_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, variance_a_f64x8);
1609
+ i = n;
1016
1610
  }
1017
1611
 
1018
1612
  // Reduce centroids, covariance, and variance.
@@ -1100,7 +1694,7 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
1100
1694
  nk_f64_t det = nk_det3x3_f64_(r);
1101
1695
  nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
1102
1696
  nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
1103
- nk_f64_t c = trace_ds / (n * variance_a);
1697
+ nk_f64_t c = trace_ds / ((nk_f64_t)n * variance_a);
1104
1698
  if (scale) *scale = c;
1105
1699
 
1106
1700
  // Handle reflection
@@ -1110,9 +1704,8 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
1110
1704
  }
1111
1705
 
1112
1706
  // Output rotation matrix.
1113
- if (rotation) {
1707
+ if (rotation)
1114
1708
  for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)r[j];
1115
- }
1116
1709
 
1117
1710
  // Compute RMSD with scaling
1118
1711
  nk_f64_t sum_squared = nk_transformed_ssd_f64_skylake_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
@@ -1120,6 +1713,738 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
1120
1713
  *result = nk_f64_sqrt_haswell(sum_squared * inv_n);
1121
1714
  }
1122
1715
 
1716
+ NK_PUBLIC void nk_rmsd_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
+ if (rotation)
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
+ rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
+ if (scale) *scale = 1.0f;
+
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
+ __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
+ __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
+ __m512 sum_squared_x_f32x16 = zeros_f32x16, sum_squared_y_f32x16 = zeros_f32x16;
+ __m512 sum_squared_z_f32x16 = zeros_f32x16;
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+ nk_size_t i = 0;
+
+ for (; i + 16 <= n; i += 16) {
+ nk_deinterleave_f16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_f16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
+ __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
+ __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
+
+ sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
+ sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
+ sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
+ }
+
+ // Tail: deinterleave remaining points into zero-initialized vectors
+ if (i < n) {
+ nk_size_t tail = n - i;
+ nk_deinterleave_f16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_f16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
+ __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
+ __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
+
+ sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
+ sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
+ sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
+ }
+
+ nk_f32_t total_ax = _mm512_reduce_add_ps(sum_a_x_f32x16);
+ nk_f32_t total_ay = _mm512_reduce_add_ps(sum_a_y_f32x16);
+ nk_f32_t total_az = _mm512_reduce_add_ps(sum_a_z_f32x16);
+ nk_f32_t total_bx = _mm512_reduce_add_ps(sum_b_x_f32x16);
+ nk_f32_t total_by = _mm512_reduce_add_ps(sum_b_y_f32x16);
+ nk_f32_t total_bz = _mm512_reduce_add_ps(sum_b_z_f32x16);
+ nk_f32_t total_sq_x = _mm512_reduce_add_ps(sum_squared_x_f32x16);
+ nk_f32_t total_sq_y = _mm512_reduce_add_ps(sum_squared_y_f32x16);
+ nk_f32_t total_sq_z = _mm512_reduce_add_ps(sum_squared_z_f32x16);
+
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
+ nk_f32_t centroid_a_x = total_ax * inv_n;
+ nk_f32_t centroid_a_y = total_ay * inv_n;
+ nk_f32_t centroid_a_z = total_az * inv_n;
+ nk_f32_t centroid_b_x = total_bx * inv_n;
+ nk_f32_t centroid_b_y = total_by * inv_n;
+ nk_f32_t centroid_b_z = total_bz * inv_n;
+
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
+
+ nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
+ nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
+ nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
+ nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
+ nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
+
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
+ }
+
+ NK_PUBLIC void nk_rmsd_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
+ if (rotation)
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
+ rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
+ if (scale) *scale = 1.0f;
+
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
+ __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
+ __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
+ __m512 sum_squared_x_f32x16 = zeros_f32x16, sum_squared_y_f32x16 = zeros_f32x16;
+ __m512 sum_squared_z_f32x16 = zeros_f32x16;
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+ nk_size_t i = 0;
+
+ for (; i + 16 <= n; i += 16) {
+ nk_deinterleave_bf16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_bf16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
+ __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
+ __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
+
+ sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
+ sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
+ sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
+ }
+
+ // Tail: deinterleave remaining points into zero-initialized vectors
+ if (i < n) {
+ nk_size_t tail = n - i;
+ nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ __m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
+ __m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
+ __m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
+
+ sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
+ sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
+ sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
+ }
+
+ nk_f32_t total_ax = _mm512_reduce_add_ps(sum_a_x_f32x16);
+ nk_f32_t total_ay = _mm512_reduce_add_ps(sum_a_y_f32x16);
+ nk_f32_t total_az = _mm512_reduce_add_ps(sum_a_z_f32x16);
+ nk_f32_t total_bx = _mm512_reduce_add_ps(sum_b_x_f32x16);
+ nk_f32_t total_by = _mm512_reduce_add_ps(sum_b_y_f32x16);
+ nk_f32_t total_bz = _mm512_reduce_add_ps(sum_b_z_f32x16);
+ nk_f32_t total_sq_x = _mm512_reduce_add_ps(sum_squared_x_f32x16);
+ nk_f32_t total_sq_y = _mm512_reduce_add_ps(sum_squared_y_f32x16);
+ nk_f32_t total_sq_z = _mm512_reduce_add_ps(sum_squared_z_f32x16);
+
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
+ nk_f32_t centroid_a_x = total_ax * inv_n;
+ nk_f32_t centroid_a_y = total_ay * inv_n;
+ nk_f32_t centroid_a_z = total_az * inv_n;
+ nk_f32_t centroid_b_x = total_bx * inv_n;
+ nk_f32_t centroid_b_y = total_by * inv_n;
+ nk_f32_t centroid_b_z = total_bz * inv_n;
+
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
+
+ nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
+ nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
+ nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
+ nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
+ nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
+
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
+ }
+
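Both RMSD kernels above return the deviation after the best rigid translation only, without solving for a rotation, which is why they report an identity rotation and unit scale. The final expression is the standard identity

$$
\min_{t \in \mathbb{R}^3} \frac{1}{n} \sum_{i=1}^{n} \lVert a_i - b_i - t \rVert^2
= \frac{1}{n} \sum_{i=1}^{n} \lVert a_i - b_i \rVert^2 - \lVert \bar a - \bar b \rVert^2,
\qquad t^\star = \bar a - \bar b,
$$

so `*result` is the square root of the right-hand side, computed from per-axis sums gathered in a single pass over the interleaved points.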
+ NK_PUBLIC void nk_kabsch_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
+
+ __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
+ __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
+ __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
+ __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
+ __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
+
+ nk_size_t i = 0;
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+
+ for (; i + 16 <= n; i += 16) {
+ nk_deinterleave_f16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_f16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
+ cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
+ cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
+ cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
+ cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
+ cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
+ cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
+ cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
+ cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
+ }
+
+ // Tail: deinterleave remaining points into zero-initialized vectors
+ if (i < n) {
+ nk_size_t tail = n - i;
+ nk_deinterleave_f16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_f16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
+ cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
+ cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
+ cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
+ cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
+ cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
+ cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
+ cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
+ cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
+ }
+
+ nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
+ nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
+ nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
+ nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
+ nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
+ nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
+ nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
+ nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
+ nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
+ nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
+ nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
+ nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
+ nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
+ nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
+ nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
+
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
+ nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
+ nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
+
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
+
+ covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
+ covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
+ covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
+ covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
+ covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
+ covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
+ covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
+ covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
+ covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
+
+ nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
+ covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
+ nk_f32_t svd_u[9], svd_s[9], svd_v[9];
+ nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
+
+ nk_f32_t r[9];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ if (nk_det3x3_f32_(r) < 0) {
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ }
+
+ if (rotation)
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ if (scale) *scale = 1.0f;
+
+ nk_f32_t sum_squared = nk_transformed_ssd_f16_skylake_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
+ centroid_b_x, centroid_b_y, centroid_b_z);
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
+ }
+
+ NK_PUBLIC void nk_kabsch_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
+
+ __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
+ __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
+ __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
+ __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
+ __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
+
+ nk_size_t i = 0;
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+
+ for (; i + 16 <= n; i += 16) {
+ nk_deinterleave_bf16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_bf16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
+ cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
+ cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
+ cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
+ cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
+ cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
+ cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
+ cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
+ cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
+ }
+
+ // Tail: deinterleave remaining points into zero-initialized vectors
+ if (i < n) {
+ nk_size_t tail = n - i;
+ nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
+ cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
+ cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
+ cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
+ cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
+ cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
+ cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
+ cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
+ cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
+ }
+
+ nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
+ nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
+ nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
+ nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
+ nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
+ nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
+ nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
+ nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
+ nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
+ nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
+ nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
+ nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
+ nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
+ nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
+ nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
+
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
+ nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
+ nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
+
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
+
+ covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
+ covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
+ covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
+ covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
+ covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
+ covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
+ covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
+ covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
+ covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
+
+ nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
+ covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
+ nk_f32_t svd_u[9], svd_s[9], svd_v[9];
+ nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
+
+ nk_f32_t r[9];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ if (nk_det3x3_f32_(r) < 0) {
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ }
+
+ if (rotation)
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ if (scale) *scale = 1.0f;
+
+ nk_f32_t sum_squared = nk_transformed_ssd_bf16_skylake_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
+ centroid_b_x, centroid_b_y, centroid_b_z);
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
+ }
+
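The two Kabsch kernels above follow the usual recipe, with the centering folded into a post-pass correction: the cross-covariance is accumulated from raw products and the centroid term is subtracted afterwards,

$$
H = \sum_{i=1}^{n} (a_i - \bar a)(b_i - \bar b)^{\mathsf T}
  = \sum_{i=1}^{n} a_i b_i^{\mathsf T} - n\, \bar a\, \bar b^{\mathsf T},
\qquad H = U S V^{\mathsf T},
\qquad R = V \operatorname{diag}(1, 1, d)\, U^{\mathsf T},
$$

where d is the sign of det(V Uᵀ). Negating the third column of `svd_v` when the determinant check fails is exactly the diag(1, 1, -1) correction that keeps R a proper rotation rather than a reflection.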
+ NK_PUBLIC void nk_umeyama_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
+
+ __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
+ __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
+ __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
+ __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
+ __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
+ __m512 variance_a_f32x16 = zeros_f32x16;
+
+ nk_size_t i = 0;
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+
+ for (; i + 16 <= n; i += 16) {
+ nk_deinterleave_f16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_f16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
+ cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
+ cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
+ cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
+ cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
+ cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
+ cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
+ cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
+ cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
+
+ variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
+ variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
+ variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
+ }
+
+ // Tail: deinterleave remaining points into zero-initialized vectors
+ if (i < n) {
+ nk_size_t tail = n - i;
+ nk_deinterleave_f16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_f16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
+ cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
+ cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
+ cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
+ cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
+ cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
+ cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
+ cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
+ cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
+
+ variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
+ variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
+ variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
+ }
+
+ nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
+ nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
+ nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
+ nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
+ nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
+ nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
+ nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
+ nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
+ nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
+ nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
+ nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
+ nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
+ nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
+ nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
+ nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
+ nk_f32_t variance_a_sum = _mm512_reduce_add_ps(variance_a_f32x16);
+
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
+ nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
+ nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
+
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
+
+ nk_f32_t variance_a = variance_a_sum * inv_n -
+ (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
+
+ covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
+ covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
+ covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
+ covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
+ covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
+ covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
+ covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
+ covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
+ covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
+
+ nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
+ covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
+
+ nk_f32_t svd_u[9], svd_s[9], svd_v[9];
+ nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
+
+ nk_f32_t r[9];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ nk_f32_t det = nk_det3x3_f32_(r);
+ nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
+ nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
+ nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
+ if (scale) *scale = c;
+
+ if (det < 0) {
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ }
+
+ if (rotation)
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+
+ nk_f32_t sum_squared = nk_transformed_ssd_f16_skylake_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
+ centroid_b_x, centroid_b_y, centroid_b_z);
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
+ }
+
+ NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
+
+ __m512 sum_a_x_f32x16 = zeros_f32x16, sum_a_y_f32x16 = zeros_f32x16, sum_a_z_f32x16 = zeros_f32x16;
+ __m512 sum_b_x_f32x16 = zeros_f32x16, sum_b_y_f32x16 = zeros_f32x16, sum_b_z_f32x16 = zeros_f32x16;
+ __m512 cov_xx_f32x16 = zeros_f32x16, cov_xy_f32x16 = zeros_f32x16, cov_xz_f32x16 = zeros_f32x16;
+ __m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
+ __m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
+ __m512 variance_a_f32x16 = zeros_f32x16;
+
+ nk_size_t i = 0;
+ __m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
+
+ for (; i + 16 <= n; i += 16) {
+ nk_deinterleave_bf16x16_to_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_bf16x16_to_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
+ cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
+ cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
+ cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
+ cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
+ cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
+ cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
+ cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
+ cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
+
+ variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
+ variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
+ variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
+ }
+
+ // Tail: deinterleave remaining points into zero-initialized vectors
+ if (i < n) {
+ nk_size_t tail = n - i;
+ nk_deinterleave_bf16_tail_to_f32x16_skylake_(a + i * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
+ nk_deinterleave_bf16_tail_to_f32x16_skylake_(b + i * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
+
+ sum_a_x_f32x16 = _mm512_add_ps(sum_a_x_f32x16, a_x_f32x16);
+ sum_a_y_f32x16 = _mm512_add_ps(sum_a_y_f32x16, a_y_f32x16);
+ sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
+ sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
+ sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
+ sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
+
+ cov_xx_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_x_f32x16, cov_xx_f32x16);
+ cov_xy_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_y_f32x16, cov_xy_f32x16);
+ cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
+ cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
+ cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
+ cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
+ cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
+ cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
+ cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
+
+ variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
+ variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
+ variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
+ }
+
+ nk_f32_t sum_a_x = _mm512_reduce_add_ps(sum_a_x_f32x16);
+ nk_f32_t sum_a_y = _mm512_reduce_add_ps(sum_a_y_f32x16);
+ nk_f32_t sum_a_z = _mm512_reduce_add_ps(sum_a_z_f32x16);
+ nk_f32_t sum_b_x = _mm512_reduce_add_ps(sum_b_x_f32x16);
+ nk_f32_t sum_b_y = _mm512_reduce_add_ps(sum_b_y_f32x16);
+ nk_f32_t sum_b_z = _mm512_reduce_add_ps(sum_b_z_f32x16);
+ nk_f32_t covariance_x_x = _mm512_reduce_add_ps(cov_xx_f32x16);
+ nk_f32_t covariance_x_y = _mm512_reduce_add_ps(cov_xy_f32x16);
+ nk_f32_t covariance_x_z = _mm512_reduce_add_ps(cov_xz_f32x16);
+ nk_f32_t covariance_y_x = _mm512_reduce_add_ps(cov_yx_f32x16);
+ nk_f32_t covariance_y_y = _mm512_reduce_add_ps(cov_yy_f32x16);
+ nk_f32_t covariance_y_z = _mm512_reduce_add_ps(cov_yz_f32x16);
+ nk_f32_t covariance_z_x = _mm512_reduce_add_ps(cov_zx_f32x16);
+ nk_f32_t covariance_z_y = _mm512_reduce_add_ps(cov_zy_f32x16);
+ nk_f32_t covariance_z_z = _mm512_reduce_add_ps(cov_zz_f32x16);
+ nk_f32_t variance_a_sum = _mm512_reduce_add_ps(variance_a_f32x16);
+
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
+ nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
+ nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
+
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
+
+ nk_f32_t variance_a = variance_a_sum * inv_n -
+ (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
+
+ covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
+ covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
+ covariance_x_z -= (nk_f32_t)n * centroid_a_x * centroid_b_z;
+ covariance_y_x -= (nk_f32_t)n * centroid_a_y * centroid_b_x;
+ covariance_y_y -= (nk_f32_t)n * centroid_a_y * centroid_b_y;
+ covariance_y_z -= (nk_f32_t)n * centroid_a_y * centroid_b_z;
+ covariance_z_x -= (nk_f32_t)n * centroid_a_z * centroid_b_x;
+ covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
+ covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
+
+ nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
+ covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
+
+ nk_f32_t svd_u[9], svd_s[9], svd_v[9];
+ nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
+
+ nk_f32_t r[9];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ nk_f32_t det = nk_det3x3_f32_(r);
+ nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
+ nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
+ nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
+ if (scale) *scale = c;
+
+ if (det < 0) {
+ svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
+ r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
+ r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
+ r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
+ r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
+ r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
+ r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
+ r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
+ r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
+ r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ }
+
+ if (rotation)
+ for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+
+ nk_f32_t sum_squared = nk_transformed_ssd_bf16_skylake_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
+ centroid_b_x, centroid_b_y, centroid_b_z);
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
+ }
+
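The Umeyama kernels differ from the Kabsch ones only in the extra variance accumulator and the similarity scale derived from it. Reading the code above, with s1, s2, s3 the singular values of the unnormalized cross-covariance H and d the same reflection sign as before, the value written through `*scale` is

$$
\sigma_a^2 = \frac{1}{n} \sum_{i=1}^{n} \lVert a_i - \bar a \rVert^2,
\qquad
c = \frac{s_1 + s_2 + d\, s_3}{n\, \sigma_a^2},
$$

which corresponds to `trace_ds / ((nk_f32_t)n * variance_a)`, since `variance_a` already carries the 1/n factor.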
  #if defined(__clang__)
  #pragma clang attribute pop
  #elif defined(__GNUC__)