numkong: diff of package versions 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,15 +8,13 @@
8
8
  *
9
9
  * @section spatial_icelake_instructions Key AVX-512 VNNI Spatial Instructions
10
10
  *
11
- * Intrinsic Instruction Ice Genoa
12
- * _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
13
- * _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 3cy @ p12
14
- * _mm512_sub_epi16 VPSUBW (ZMM, ZMM, ZMM) 1cy @ p05 1cy @ p0123
15
- * _mm512_reduce_add_epi32 (pseudo: shuffle chain) ~8cy ~8cy
11
+ * Intrinsic Instruction Icelake Genoa
12
+ * _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
13
+ * _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 3cy @ p12
14
+ * _mm512_sub_epi16 VPSUBW (ZMM, ZMM, ZMM) 1cy @ p05 1cy @ p0123
16
15
  *
17
16
  * Ice Lake's VNNI enables efficient i8 distance computations via VPDPWSSD for squared differences.
18
17
  * After widening i8 to i16, the same instruction computes both multiply and horizontal pair addition.
19
- * This approach avoids the asymmetric VPDPBUSD issues with signed values like -128.
20
18
  */
21
19
  #ifndef NK_SPATIAL_ICELAKE_H
22
20
  #define NK_SPATIAL_ICELAKE_H
@@ -25,18 +23,21 @@
25
23
  #if NK_TARGET_ICELAKE
26
24
 
27
25
  #include "numkong/types.h"
26
+ #include "numkong/spatial/haswell.h" // `nk_angular_normalize_f32_haswell_`, `nk_f32_sqrt_haswell`
27
+ #include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
28
28
 
29
29
  #if defined(__cplusplus)
30
30
  extern "C" {
31
31
  #endif
32
32
 
33
33
  #if defined(__clang__)
34
- #pragma clang attribute push( \
35
- __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,f16c,fma,bmi,bmi2"))), \
34
+ #pragma clang attribute push( \
35
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,avx512vbmi,f16c,fma,bmi,bmi2"))), \
36
36
  apply_to = function)
37
37
  #elif defined(__GNUC__)
38
38
  #pragma GCC push_options
39
- #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "f16c", "fma", "bmi", "bmi2")
39
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "avx512vbmi", "f16c", "fma", \
40
+ "bmi", "bmi2")
40
41
  #endif
41
42
 
42
43
  NK_PUBLIC void nk_sqeuclidean_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
@@ -142,7 +143,7 @@ nk_angular_i8_icelake_cycle:
142
143
  //
143
144
  // VNNI instruction performance (Ice Lake vs Zen4 Genoa):
144
145
  //
145
- // Instruction Ice Genoa
146
+ // Instruction Icelake Genoa
146
147
  // VPDPBUSDS (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
147
148
  // VPDPWSSDS (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
148
149
  // VPMADDWD (ZMM, ZMM, ZMM) 5cy @ p05 3cy @ p01
@@ -173,7 +174,8 @@ nk_angular_i8_icelake_cycle:
173
174
  nk_i32_t dot_product_i32 = _mm512_reduce_add_epi32(dot_product_i32x16);
174
175
  nk_i32_t a_norm_sq_i32 = _mm512_reduce_add_epi32(a_norm_sq_i32x16);
175
176
  nk_i32_t b_norm_sq_i32 = _mm512_reduce_add_epi32(b_norm_sq_i32x16);
176
- *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
177
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
178
+ (nk_f32_t)b_norm_sq_i32);
177
179
  }
178
180
  NK_PUBLIC void nk_sqeuclidean_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
179
181
  __m512i distance_sq_low_i32x16 = _mm512_setzero_si512();
@@ -258,7 +260,8 @@ nk_angular_u8_icelake_cycle:
258
260
  _mm512_add_epi32(dot_product_low_i32x16, dot_product_high_i32x16));
259
261
  nk_i32_t a_norm_sq_i32 = _mm512_reduce_add_epi32(_mm512_add_epi32(a_norm_sq_low_i32x16, a_norm_sq_high_i32x16));
260
262
  nk_i32_t b_norm_sq_i32 = _mm512_reduce_add_epi32(_mm512_add_epi32(b_norm_sq_low_i32x16, b_norm_sq_high_i32x16));
261
- *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
263
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
264
+ (nk_f32_t)b_norm_sq_i32);
262
265
  }
263
266
 
264
267
  NK_PUBLIC void nk_sqeuclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result) {
@@ -285,7 +288,7 @@ NK_PUBLIC void nk_sqeuclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b,
285
288
  __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
286
289
  __m512i const eight_i8x64 = _mm512_set1_epi8(8);
287
290
 
288
- __m512i a_i4_vec, b_i4_vec;
291
+ __m512i a_i4_u8x64, b_i4_u8x64;
289
292
  __m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
290
293
  __m512i a_low_i8x64, a_high_i8x64, b_low_i8x64, b_high_i8x64;
291
294
  __m512i diff_low_u8x64, diff_high_u8x64;
@@ -294,22 +297,22 @@ NK_PUBLIC void nk_sqeuclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b,
294
297
  nk_sqeuclidean_i4_icelake_cycle:
295
298
  if (n_bytes < 64) {
296
299
  __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
297
- a_i4_vec = _mm512_maskz_loadu_epi8(mask, a);
298
- b_i4_vec = _mm512_maskz_loadu_epi8(mask, b);
300
+ a_i4_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
301
+ b_i4_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
299
302
  n_bytes = 0;
300
303
  }
301
304
  else {
302
- a_i4_vec = _mm512_loadu_epi8(a);
303
- b_i4_vec = _mm512_loadu_epi8(b);
305
+ a_i4_u8x64 = _mm512_loadu_epi8(a);
306
+ b_i4_u8x64 = _mm512_loadu_epi8(b);
304
307
  a += 64, b += 64, n_bytes -= 64;
305
308
  }
306
309
 
307
310
  // Extract nibbles as unsigned [0,15]. VPSHUFB ignores high 4 bits of index,
308
311
  // so no AND needed for low nibbles when used with lookup, but we need it here.
309
- a_low_u8x64 = _mm512_and_si512(a_i4_vec, nibble_mask_u8x64);
310
- a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4_vec, 4), nibble_mask_u8x64);
311
- b_low_u8x64 = _mm512_and_si512(b_i4_vec, nibble_mask_u8x64);
312
- b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4_vec, 4), nibble_mask_u8x64);
312
+ a_low_u8x64 = _mm512_and_si512(a_i4_u8x64, nibble_mask_u8x64);
313
+ a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4_u8x64, 4), nibble_mask_u8x64);
314
+ b_low_u8x64 = _mm512_and_si512(b_i4_u8x64, nibble_mask_u8x64);
315
+ b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4_u8x64, 4), nibble_mask_u8x64);
313
316
 
314
317
  // Sign extend using XOR trick: signed = (nibble ^ 8) - 8
315
318
  a_low_i8x64 = _mm512_sub_epi8(_mm512_xor_si512(a_low_u8x64, eight_i8x64), eight_i8x64);
@@ -363,7 +366,7 @@ NK_PUBLIC void nk_angular_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_
363
366
  __m512i const eight_i8x64 = _mm512_set1_epi8(8);
364
367
  __m512i const zeros_i8x64 = _mm512_setzero_si512();
365
368
 
366
- __m512i a_i4_vec, b_i4_vec;
369
+ __m512i a_i4_u8x64, b_i4_u8x64;
367
370
  __m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
368
371
  __m512i ax_low_u8x64, ax_high_u8x64, bx_low_u8x64, bx_high_u8x64;
369
372
  __m512i a_low_i8x64, a_high_i8x64, b_low_i8x64, b_high_i8x64;
@@ -379,21 +382,21 @@ NK_PUBLIC void nk_angular_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_
379
382
  nk_angular_i4_icelake_cycle:
380
383
  if (n_bytes < 64) {
381
384
  __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
382
- a_i4_vec = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, a);
383
- b_i4_vec = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, b);
385
+ a_i4_u8x64 = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, a);
386
+ b_i4_u8x64 = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0x88), mask, b);
384
387
  n_bytes = 0;
385
388
  }
386
389
  else {
387
- a_i4_vec = _mm512_loadu_epi8(a);
388
- b_i4_vec = _mm512_loadu_epi8(b);
390
+ a_i4_u8x64 = _mm512_loadu_epi8(a);
391
+ b_i4_u8x64 = _mm512_loadu_epi8(b);
389
392
  a += 64, b += 64, n_bytes -= 64;
390
393
  }
391
394
 
392
395
  // Extract nibbles as unsigned [0,15]
393
- a_low_u8x64 = _mm512_and_si512(a_i4_vec, nibble_mask_u8x64);
394
- a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4_vec, 4), nibble_mask_u8x64);
395
- b_low_u8x64 = _mm512_and_si512(b_i4_vec, nibble_mask_u8x64);
396
- b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4_vec, 4), nibble_mask_u8x64);
396
+ a_low_u8x64 = _mm512_and_si512(a_i4_u8x64, nibble_mask_u8x64);
397
+ a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4_u8x64, 4), nibble_mask_u8x64);
398
+ b_low_u8x64 = _mm512_and_si512(b_i4_u8x64, nibble_mask_u8x64);
399
+ b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4_u8x64, 4), nibble_mask_u8x64);
397
400
 
398
401
  // Compute biased values: ax = a ^ 8 (still ∈ [0,15], just reordered)
399
402
  ax_low_u8x64 = _mm512_xor_si512(a_low_u8x64, eight_i8x64);
@@ -440,7 +443,7 @@ nk_angular_i4_icelake_cycle:
440
443
  nk_i32_t norm_excess = 128 * (nk_i32_t)(nk_size_round_up_to_multiple_(n_bytes_total, 64) - n_bytes_total);
441
444
  nk_i32_t a2 = _mm512_reduce_add_epi32(a2_i32x16) - norm_excess;
442
445
  nk_i32_t b2 = _mm512_reduce_add_epi32(b2_i32x16) - norm_excess;
443
- *result = nk_angular_normalize_f32_haswell_(ab, (nk_f32_t)a2, (nk_f32_t)b2);
446
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
444
447
  }
445
448
 
446
449
  NK_PUBLIC void nk_sqeuclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
@@ -457,7 +460,7 @@ NK_PUBLIC void nk_sqeuclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b,
457
460
  // No sign extension needed since values are unsigned.
458
461
  __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
459
462
 
460
- __m512i a_u4_vec, b_u4_vec;
463
+ __m512i a_u4_u8x64, b_u4_u8x64;
461
464
  __m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
462
465
  __m512i diff_low_u8x64, diff_high_u8x64;
463
466
  __m512i d2_i32x16 = _mm512_setzero_si512();
@@ -465,21 +468,21 @@ NK_PUBLIC void nk_sqeuclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b,
465
468
  nk_sqeuclidean_u4_icelake_cycle:
466
469
  if (n_bytes < 64) {
467
470
  __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
468
- a_u4_vec = _mm512_maskz_loadu_epi8(mask, a);
469
- b_u4_vec = _mm512_maskz_loadu_epi8(mask, b);
471
+ a_u4_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
472
+ b_u4_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
470
473
  n_bytes = 0;
471
474
  }
472
475
  else {
473
- a_u4_vec = _mm512_loadu_epi8(a);
474
- b_u4_vec = _mm512_loadu_epi8(b);
476
+ a_u4_u8x64 = _mm512_loadu_epi8(a);
477
+ b_u4_u8x64 = _mm512_loadu_epi8(b);
475
478
  a += 64, b += 64, n_bytes -= 64;
476
479
  }
477
480
 
478
481
  // Extract nibbles as unsigned [0,15]
479
- a_low_u8x64 = _mm512_and_si512(a_u4_vec, nibble_mask_u8x64);
480
- a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4_vec, 4), nibble_mask_u8x64);
481
- b_low_u8x64 = _mm512_and_si512(b_u4_vec, nibble_mask_u8x64);
482
- b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4_vec, 4), nibble_mask_u8x64);
482
+ a_low_u8x64 = _mm512_and_si512(a_u4_u8x64, nibble_mask_u8x64);
483
+ a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4_u8x64, 4), nibble_mask_u8x64);
484
+ b_low_u8x64 = _mm512_and_si512(b_u4_u8x64, nibble_mask_u8x64);
485
+ b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4_u8x64, 4), nibble_mask_u8x64);
483
486
 
484
487
  // Absolute difference for unsigned: |a-b| = (a ⊖ b) | (b ⊖ a) where ⊖ is saturating sub
485
488
  diff_low_u8x64 = _mm512_or_si512(_mm512_subs_epu8(a_low_u8x64, b_low_u8x64),
@@ -515,7 +518,7 @@ NK_PUBLIC void nk_angular_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_
515
518
  __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
516
519
  __m512i const zeros_i8x64 = _mm512_setzero_si512();
517
520
 
518
- __m512i a_u4_vec, b_u4_vec;
521
+ __m512i a_u4_u8x64, b_u4_u8x64;
519
522
  __m512i a_low_u8x64, a_high_u8x64, b_low_u8x64, b_high_u8x64;
520
523
 
521
524
  __m512i ab_i32x16 = zeros_i8x64;
@@ -525,21 +528,21 @@ NK_PUBLIC void nk_angular_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_
525
528
  nk_angular_u4_icelake_cycle:
526
529
  if (n_bytes < 64) {
527
530
  __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_bytes);
528
- a_u4_vec = _mm512_maskz_loadu_epi8(mask, a);
529
- b_u4_vec = _mm512_maskz_loadu_epi8(mask, b);
531
+ a_u4_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
532
+ b_u4_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
530
533
  n_bytes = 0;
531
534
  }
532
535
  else {
533
- a_u4_vec = _mm512_loadu_epi8(a);
534
- b_u4_vec = _mm512_loadu_epi8(b);
536
+ a_u4_u8x64 = _mm512_loadu_epi8(a);
537
+ b_u4_u8x64 = _mm512_loadu_epi8(b);
535
538
  a += 64, b += 64, n_bytes -= 64;
536
539
  }
537
540
 
538
541
  // Extract nibbles as unsigned [0,15]
539
- a_low_u8x64 = _mm512_and_si512(a_u4_vec, nibble_mask_u8x64);
540
- a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4_vec, 4), nibble_mask_u8x64);
541
- b_low_u8x64 = _mm512_and_si512(b_u4_vec, nibble_mask_u8x64);
542
- b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4_vec, 4), nibble_mask_u8x64);
542
+ a_low_u8x64 = _mm512_and_si512(a_u4_u8x64, nibble_mask_u8x64);
543
+ a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4_u8x64, 4), nibble_mask_u8x64);
544
+ b_low_u8x64 = _mm512_and_si512(b_u4_u8x64, nibble_mask_u8x64);
545
+ b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4_u8x64, 4), nibble_mask_u8x64);
543
546
 
544
547
  // Dot product with DPBUSD (safe for unsigned [0,15])
545
548
  ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, a_low_u8x64, b_low_u8x64);
@@ -553,22 +556,500 @@ nk_angular_u4_icelake_cycle:
553
556
  (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
554
557
  (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0);
555
558
 
556
- __m512i a2_lo_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, a_low_u8x64);
557
- __m512i a2_hi_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, a_high_u8x64);
558
- __m512i b2_lo_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, b_low_u8x64);
559
- __m512i b2_hi_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, b_high_u8x64);
559
+ __m512i a2_low_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, a_low_u8x64);
560
+ __m512i a2_high_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, a_high_u8x64);
561
+ __m512i b2_low_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, b_low_u8x64);
562
+ __m512i b2_high_u8x64 = _mm512_shuffle_epi8(u4_squares_lookup_u8x64, b_high_u8x64);
560
563
 
561
564
  // Accumulate low and high squares separately using SAD to avoid u8 overflow
562
- a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(a2_lo_u8x64, zeros_i8x64));
563
- a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(a2_hi_u8x64, zeros_i8x64));
564
- b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(b2_lo_u8x64, zeros_i8x64));
565
- b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(b2_hi_u8x64, zeros_i8x64));
565
+ a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(a2_low_u8x64, zeros_i8x64));
566
+ a2_i64x8 = _mm512_add_epi64(a2_i64x8, _mm512_sad_epu8(a2_high_u8x64, zeros_i8x64));
567
+ b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(b2_low_u8x64, zeros_i8x64));
568
+ b2_i64x8 = _mm512_add_epi64(b2_i64x8, _mm512_sad_epu8(b2_high_u8x64, zeros_i8x64));
566
569
  if (n_bytes) goto nk_angular_u4_icelake_cycle;
567
570
 
568
571
  nk_i32_t ab = _mm512_reduce_add_epi32(ab_i32x16);
569
572
  nk_i64_t a2 = _mm512_reduce_add_epi64(a2_i64x8);
570
573
  nk_i64_t b2 = _mm512_reduce_add_epi64(b2_i64x8);
571
- *result = nk_angular_normalize_f32_haswell_(ab, (nk_f32_t)a2, (nk_f32_t)b2);
574
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
575
+ }
576
+
577
+ NK_PUBLIC void nk_sqeuclidean_e4m3_icelake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
578
+ // E4M3 squared Euclidean distance via octave VNNI.
579
+
580
+ __m512i const lut_normal_u8x64 = _mm512_set_epi8( //
581
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
582
+ 30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8, //
583
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
584
+ 30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8); //
585
+ __m512i const lut_subnorm_u8x64 = _mm512_set_epi8( //
586
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
587
+ 0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0, //
588
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
589
+ 0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0); //
590
+ __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x7F);
591
+ __m512i const subnorm_threshold_u8x64 = _mm512_set1_epi8(0x08);
592
+ __m512i const oct_threshold_20_u8x64 = _mm512_set1_epi8(0x20);
593
+ __m512i const oct_threshold_40_u8x64 = _mm512_set1_epi8(0x40);
594
+ __m512i const oct_threshold_60_u8x64 = _mm512_set1_epi8(0x60);
595
+
596
+ __m512i ab0_i32x16 = _mm512_setzero_si512(), ab1_i32x16 = _mm512_setzero_si512();
597
+ __m512i ab2_i32x16 = _mm512_setzero_si512(), ab3_i32x16 = _mm512_setzero_si512();
598
+ __m512i ab4_i32x16 = _mm512_setzero_si512(), ab5_i32x16 = _mm512_setzero_si512();
599
+ __m512i ab6_i32x16 = _mm512_setzero_si512();
600
+ __m512i a2_0_i32x16 = _mm512_setzero_si512(), a2_2_i32x16 = _mm512_setzero_si512();
601
+ __m512i a2_4_i32x16 = _mm512_setzero_si512(), a2_6_i32x16 = _mm512_setzero_si512();
602
+ __m512i b2_0_i32x16 = _mm512_setzero_si512(), b2_2_i32x16 = _mm512_setzero_si512();
603
+ __m512i b2_4_i32x16 = _mm512_setzero_si512(), b2_6_i32x16 = _mm512_setzero_si512();
604
+ __m512i a_e4m3_u8x64, b_e4m3_u8x64;
605
+
606
+ nk_sqeuclidean_e4m3_icelake_cycle:
607
+ if (n < 64) {
608
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
609
+ a_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
610
+ b_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
611
+ n = 0;
612
+ }
613
+ else {
614
+ a_e4m3_u8x64 = _mm512_loadu_si512(a);
615
+ b_e4m3_u8x64 = _mm512_loadu_si512(b);
616
+ a += 64, b += 64, n -= 64;
617
+ }
618
+
619
+ __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e4m3_u8x64, magnitude_mask_u8x64);
620
+ __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e4m3_u8x64, magnitude_mask_u8x64);
621
+ __m512i a_base_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_normal_u8x64);
622
+ __m512i b_base_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_normal_u8x64);
623
+ a_base_u8x64 = _mm512_mask_permutexvar_epi8(a_base_u8x64,
624
+ _mm512_cmplt_epu8_mask(a_magnitude_u8x64, subnorm_threshold_u8x64),
625
+ a_magnitude_u8x64, lut_subnorm_u8x64);
626
+ b_base_u8x64 = _mm512_mask_permutexvar_epi8(b_base_u8x64,
627
+ _mm512_cmplt_epu8_mask(b_magnitude_u8x64, subnorm_threshold_u8x64),
628
+ b_magnitude_u8x64, lut_subnorm_u8x64);
629
+
630
+ __m512i sign_diff_u8x64 = _mm512_ternarylogic_epi64(a_e4m3_u8x64, b_e4m3_u8x64, magnitude_mask_u8x64, 0x14);
631
+ __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_base_u8x64, _mm512_test_epi8_mask(sign_diff_u8x64, sign_diff_u8x64),
632
+ _mm512_setzero_si512(), b_base_u8x64);
633
+
634
+ __mmask64 ka_lt20 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_20_u8x64);
635
+ __mmask64 ka_lt40 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_40_u8x64);
636
+ __mmask64 ka_lt60 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_60_u8x64);
637
+ __mmask64 kb_lt20 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_20_u8x64);
638
+ __mmask64 kb_lt40 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_40_u8x64);
639
+ __mmask64 kb_lt60 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_60_u8x64);
640
+
641
+ __m512i a0_u8x64 = _mm512_maskz_mov_epi8(ka_lt20, a_base_u8x64);
642
+ __m512i a1_u8x64 = _mm512_maskz_mov_epi8(ka_lt40 & ~ka_lt20, a_base_u8x64);
643
+ __m512i a2_u8x64 = _mm512_maskz_mov_epi8(ka_lt60 & ~ka_lt40, a_base_u8x64);
644
+ __m512i a3_u8x64 = _mm512_maskz_mov_epi8(~ka_lt60, a_base_u8x64);
645
+
646
+ __m512i b0_i8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_signed_i8x64);
647
+ __m512i b1_i8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_signed_i8x64);
648
+ __m512i b2_i8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_signed_i8x64);
649
+ __m512i b3_i8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_signed_i8x64);
650
+
651
+ // dot(a,b): 16 VPDPBUSD
652
+ ab0_i32x16 = _mm512_dpbusd_epi32(ab0_i32x16, a0_u8x64, b0_i8x64);
653
+ ab1_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab1_i32x16, a0_u8x64, b1_i8x64), a1_u8x64, b0_i8x64);
654
+ ab2_i32x16 = _mm512_dpbusd_epi32(
655
+ _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab2_i32x16, a0_u8x64, b2_i8x64), a1_u8x64, b1_i8x64), a2_u8x64,
656
+ b0_i8x64);
657
+ ab3_i32x16 = _mm512_dpbusd_epi32(
658
+ _mm512_dpbusd_epi32(
659
+ _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab3_i32x16, a0_u8x64, b3_i8x64), a1_u8x64, b2_i8x64), a2_u8x64,
660
+ b1_i8x64),
661
+ a3_u8x64, b0_i8x64);
662
+ ab4_i32x16 = _mm512_dpbusd_epi32(
663
+ _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab4_i32x16, a1_u8x64, b3_i8x64), a2_u8x64, b2_i8x64), a3_u8x64,
664
+ b1_i8x64);
665
+ ab5_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab5_i32x16, a2_u8x64, b3_i8x64), a3_u8x64, b2_i8x64);
666
+ ab6_i32x16 = _mm512_dpbusd_epi32(ab6_i32x16, a3_u8x64, b3_i8x64);
667
+
668
+ // ||a||²: 4 VPDPBUSD (self-dot, same-octave only)
669
+ a2_0_i32x16 = _mm512_dpbusd_epi32(a2_0_i32x16, a0_u8x64, a0_u8x64);
670
+ a2_2_i32x16 = _mm512_dpbusd_epi32(a2_2_i32x16, a1_u8x64, a1_u8x64);
671
+ a2_4_i32x16 = _mm512_dpbusd_epi32(a2_4_i32x16, a2_u8x64, a2_u8x64);
672
+ a2_6_i32x16 = _mm512_dpbusd_epi32(a2_6_i32x16, a3_u8x64, a3_u8x64);
673
+
674
+ // ||b||²: 4 VPDPBUSD (unsigned b, not signed)
675
+ __m512i b0_u8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_base_u8x64);
676
+ __m512i b1_u8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_base_u8x64);
677
+ __m512i b2_u8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_base_u8x64);
678
+ __m512i b3_u8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_base_u8x64);
679
+ b2_0_i32x16 = _mm512_dpbusd_epi32(b2_0_i32x16, b0_u8x64, b0_u8x64);
680
+ b2_2_i32x16 = _mm512_dpbusd_epi32(b2_2_i32x16, b1_u8x64, b1_u8x64);
681
+ b2_4_i32x16 = _mm512_dpbusd_epi32(b2_4_i32x16, b2_u8x64, b2_u8x64);
682
+ b2_6_i32x16 = _mm512_dpbusd_epi32(b2_6_i32x16, b3_u8x64, b3_u8x64);
683
+
684
+ if (n) goto nk_sqeuclidean_e4m3_icelake_cycle;
685
+
686
+ // Reduce dot(a,b)
687
+ __m512 ab_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(ab0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
688
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab1_i32x16), _mm512_set1_ps(1.52587890625e-05f), ab_f32x16);
689
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab2_i32x16), _mm512_set1_ps(2.44140625e-04f), ab_f32x16);
690
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab3_i32x16), _mm512_set1_ps(3.90625e-03f), ab_f32x16);
691
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab4_i32x16), _mm512_set1_ps(6.25e-02f), ab_f32x16);
692
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab5_i32x16), _mm512_set1_ps(1.0f), ab_f32x16);
693
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab6_i32x16), _mm512_set1_ps(16.0f), ab_f32x16);
694
+
695
+ // Reduce ||a||² and ||b||² (even-k only: scale = 2^(8·oct − 20))
696
+ __m512 a2_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(a2_0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
697
+ a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_2_i32x16), _mm512_set1_ps(2.44140625e-04f), a2_f32x16);
698
+ a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_4_i32x16), _mm512_set1_ps(6.25e-02f), a2_f32x16);
699
+ a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_6_i32x16), _mm512_set1_ps(16.0f), a2_f32x16);
700
+
701
+ __m512 b2_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(b2_0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
702
+ b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_2_i32x16), _mm512_set1_ps(2.44140625e-04f), b2_f32x16);
703
+ b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_4_i32x16), _mm512_set1_ps(6.25e-02f), b2_f32x16);
704
+ b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_6_i32x16), _mm512_set1_ps(16.0f), b2_f32x16);
705
+
706
+ // (a-b)² = ||a||² + ||b||² - 2·dot(a,b)
707
+ __m512 sum_sq_f32x16 = _mm512_add_ps(a2_f32x16, b2_f32x16);
708
+ *result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
709
+ }
710
+
711
+ NK_PUBLIC void nk_euclidean_e4m3_icelake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
712
+ nk_sqeuclidean_e4m3_icelake(a, b, n, result);
713
+ *result = nk_f32_sqrt_haswell(*result);
714
+ }
715
+
716
+ NK_PUBLIC void nk_angular_e4m3_icelake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
717
+ // E4M3 angular distance via octave VNNI.
718
+
719
+ __m512i const lut_normal_u8x64 = _mm512_set_epi8( //
720
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
721
+ 30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8, //
722
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
723
+ 30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8); //
724
+ __m512i const lut_subnorm_u8x64 = _mm512_set_epi8( //
725
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
726
+ 0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0, //
727
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
728
+ 0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0); //
729
+ __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x7F);
730
+ __m512i const subnorm_threshold_u8x64 = _mm512_set1_epi8(0x08);
731
+ __m512i const oct_threshold_20_u8x64 = _mm512_set1_epi8(0x20);
732
+ __m512i const oct_threshold_40_u8x64 = _mm512_set1_epi8(0x40);
733
+ __m512i const oct_threshold_60_u8x64 = _mm512_set1_epi8(0x60);
734
+
735
+ __m512i ab0_i32x16 = _mm512_setzero_si512(), ab1_i32x16 = _mm512_setzero_si512();
736
+ __m512i ab2_i32x16 = _mm512_setzero_si512(), ab3_i32x16 = _mm512_setzero_si512();
737
+ __m512i ab4_i32x16 = _mm512_setzero_si512(), ab5_i32x16 = _mm512_setzero_si512();
738
+ __m512i ab6_i32x16 = _mm512_setzero_si512();
739
+ __m512i a2_0_i32x16 = _mm512_setzero_si512(), a2_2_i32x16 = _mm512_setzero_si512();
740
+ __m512i a2_4_i32x16 = _mm512_setzero_si512(), a2_6_i32x16 = _mm512_setzero_si512();
741
+ __m512i b2_0_i32x16 = _mm512_setzero_si512(), b2_2_i32x16 = _mm512_setzero_si512();
742
+ __m512i b2_4_i32x16 = _mm512_setzero_si512(), b2_6_i32x16 = _mm512_setzero_si512();
743
+ __m512i a_e4m3_u8x64, b_e4m3_u8x64;
744
+
745
+ nk_angular_e4m3_icelake_cycle:
746
+ if (n < 64) {
747
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
748
+ a_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
749
+ b_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
750
+ n = 0;
751
+ }
752
+ else {
753
+ a_e4m3_u8x64 = _mm512_loadu_si512(a);
754
+ b_e4m3_u8x64 = _mm512_loadu_si512(b);
755
+ a += 64, b += 64, n -= 64;
756
+ }
757
+
758
+ __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e4m3_u8x64, magnitude_mask_u8x64);
759
+ __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e4m3_u8x64, magnitude_mask_u8x64);
760
+ __m512i a_base_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_normal_u8x64);
761
+ __m512i b_base_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_normal_u8x64);
762
+ a_base_u8x64 = _mm512_mask_permutexvar_epi8(a_base_u8x64,
763
+ _mm512_cmplt_epu8_mask(a_magnitude_u8x64, subnorm_threshold_u8x64),
764
+ a_magnitude_u8x64, lut_subnorm_u8x64);
765
+ b_base_u8x64 = _mm512_mask_permutexvar_epi8(b_base_u8x64,
766
+ _mm512_cmplt_epu8_mask(b_magnitude_u8x64, subnorm_threshold_u8x64),
767
+ b_magnitude_u8x64, lut_subnorm_u8x64);
768
+
769
+ __m512i sign_diff_u8x64 = _mm512_ternarylogic_epi64(a_e4m3_u8x64, b_e4m3_u8x64, magnitude_mask_u8x64, 0x14);
770
+ __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_base_u8x64, _mm512_test_epi8_mask(sign_diff_u8x64, sign_diff_u8x64),
771
+ _mm512_setzero_si512(), b_base_u8x64);
772
+
773
+ __mmask64 ka_lt20 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_20_u8x64);
774
+ __mmask64 ka_lt40 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_40_u8x64);
775
+ __mmask64 ka_lt60 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_60_u8x64);
776
+ __mmask64 kb_lt20 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_20_u8x64);
777
+ __mmask64 kb_lt40 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_40_u8x64);
778
+ __mmask64 kb_lt60 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_60_u8x64);
779
+
780
+ __m512i a0_u8x64 = _mm512_maskz_mov_epi8(ka_lt20, a_base_u8x64);
781
+ __m512i a1_u8x64 = _mm512_maskz_mov_epi8(ka_lt40 & ~ka_lt20, a_base_u8x64);
782
+ __m512i a2_u8x64 = _mm512_maskz_mov_epi8(ka_lt60 & ~ka_lt40, a_base_u8x64);
783
+ __m512i a3_u8x64 = _mm512_maskz_mov_epi8(~ka_lt60, a_base_u8x64);
784
+
785
+ __m512i b0_i8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_signed_i8x64);
786
+ __m512i b1_i8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_signed_i8x64);
787
+ __m512i b2_i8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_signed_i8x64);
788
+ __m512i b3_i8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_signed_i8x64);
789
+
790
+ // dot(a,b): 16 VPDPBUSD
791
+ ab0_i32x16 = _mm512_dpbusd_epi32(ab0_i32x16, a0_u8x64, b0_i8x64);
792
+ ab1_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab1_i32x16, a0_u8x64, b1_i8x64), a1_u8x64, b0_i8x64);
793
+ ab2_i32x16 = _mm512_dpbusd_epi32(
794
+ _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab2_i32x16, a0_u8x64, b2_i8x64), a1_u8x64, b1_i8x64), a2_u8x64,
795
+ b0_i8x64);
796
+ ab3_i32x16 = _mm512_dpbusd_epi32(
797
+ _mm512_dpbusd_epi32(
798
+ _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab3_i32x16, a0_u8x64, b3_i8x64), a1_u8x64, b2_i8x64), a2_u8x64,
799
+ b1_i8x64),
800
+ a3_u8x64, b0_i8x64);
801
+ ab4_i32x16 = _mm512_dpbusd_epi32(
802
+ _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab4_i32x16, a1_u8x64, b3_i8x64), a2_u8x64, b2_i8x64), a3_u8x64,
803
+ b1_i8x64);
804
+ ab5_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(ab5_i32x16, a2_u8x64, b3_i8x64), a3_u8x64, b2_i8x64);
805
+ ab6_i32x16 = _mm512_dpbusd_epi32(ab6_i32x16, a3_u8x64, b3_i8x64);
806
+
807
+ // ||a||²: 4 VPDPBUSD
808
+ a2_0_i32x16 = _mm512_dpbusd_epi32(a2_0_i32x16, a0_u8x64, a0_u8x64);
809
+ a2_2_i32x16 = _mm512_dpbusd_epi32(a2_2_i32x16, a1_u8x64, a1_u8x64);
810
+ a2_4_i32x16 = _mm512_dpbusd_epi32(a2_4_i32x16, a2_u8x64, a2_u8x64);
811
+ a2_6_i32x16 = _mm512_dpbusd_epi32(a2_6_i32x16, a3_u8x64, a3_u8x64);
812
+
813
+ // ||b||²: 4 VPDPBUSD (unsigned b)
814
+ __m512i b0_u8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_base_u8x64);
815
+ __m512i b1_u8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_base_u8x64);
816
+ __m512i b2_u8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_base_u8x64);
817
+ __m512i b3_u8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_base_u8x64);
818
+ b2_0_i32x16 = _mm512_dpbusd_epi32(b2_0_i32x16, b0_u8x64, b0_u8x64);
819
+ b2_2_i32x16 = _mm512_dpbusd_epi32(b2_2_i32x16, b1_u8x64, b1_u8x64);
820
+ b2_4_i32x16 = _mm512_dpbusd_epi32(b2_4_i32x16, b2_u8x64, b2_u8x64);
821
+ b2_6_i32x16 = _mm512_dpbusd_epi32(b2_6_i32x16, b3_u8x64, b3_u8x64);
822
+
823
+ if (n) goto nk_angular_e4m3_icelake_cycle;
824
+
825
+ // Reduce dot(a,b)
826
+ __m512 ab_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(ab0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
827
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab1_i32x16), _mm512_set1_ps(1.52587890625e-05f), ab_f32x16);
828
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab2_i32x16), _mm512_set1_ps(2.44140625e-04f), ab_f32x16);
829
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab3_i32x16), _mm512_set1_ps(3.90625e-03f), ab_f32x16);
830
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab4_i32x16), _mm512_set1_ps(6.25e-02f), ab_f32x16);
831
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab5_i32x16), _mm512_set1_ps(1.0f), ab_f32x16);
832
+ ab_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(ab6_i32x16), _mm512_set1_ps(16.0f), ab_f32x16);
833
+
834
+ __m512 a2_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(a2_0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
835
+ a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_2_i32x16), _mm512_set1_ps(2.44140625e-04f), a2_f32x16);
836
+ a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_4_i32x16), _mm512_set1_ps(6.25e-02f), a2_f32x16);
837
+ a2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(a2_6_i32x16), _mm512_set1_ps(16.0f), a2_f32x16);
838
+
839
+ __m512 b2_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(b2_0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
840
+ b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_2_i32x16), _mm512_set1_ps(2.44140625e-04f), b2_f32x16);
841
+ b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_4_i32x16), _mm512_set1_ps(6.25e-02f), b2_f32x16);
842
+ b2_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(b2_6_i32x16), _mm512_set1_ps(16.0f), b2_f32x16);
843
+
844
+ nk_f32_t ab_f32 = nk_reduce_add_f32x16_skylake_(ab_f32x16);
845
+ nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a2_f32x16);
846
+ nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b2_f32x16);
847
+ *result = nk_angular_normalize_f32_haswell_(ab_f32, a_norm_sq_f32, b_norm_sq_f32);
848
+ }
849
+
850
+ NK_PUBLIC void nk_sqeuclidean_e2m3_icelake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
851
+ // E2M3 squared Euclidean distance via VPDPBUSD integer MAC.
852
+ __m512i const lut_magnitude_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
853
+ 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0,
854
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
855
+ 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
856
+ __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
857
+ __m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);
858
+ __m512i ab_i32x16 = _mm512_setzero_si512();
859
+ __m512i a2_i32x16 = _mm512_setzero_si512();
860
+ __m512i b2_i32x16 = _mm512_setzero_si512();
861
+ __m512i a_e2m3_u8x64, b_e2m3_u8x64;
862
+
863
+ nk_sqeuclidean_e2m3_icelake_cycle:
864
+ if (n < 64) {
865
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
866
+ a_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
867
+ b_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
868
+ n = 0;
869
+ }
870
+ else {
871
+ a_e2m3_u8x64 = _mm512_loadu_si512(a);
872
+ b_e2m3_u8x64 = _mm512_loadu_si512(b);
873
+ a += 64, b += 64, n -= 64;
874
+ }
875
+
876
+ __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e2m3_u8x64, magnitude_mask_u8x64);
877
+ __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e2m3_u8x64, magnitude_mask_u8x64);
878
+ __m512i a_unsigned_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_magnitude_u8x64);
879
+ __m512i b_unsigned_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_magnitude_u8x64);
880
+
881
+ __m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e2m3_u8x64, b_e2m3_u8x64), sign_mask_u8x64);
882
+ __mmask64 negate_mask = _mm512_test_epi8_mask(sign_combined_u8x64, sign_combined_u8x64);
883
+ __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_unsigned_u8x64, negate_mask, _mm512_setzero_si512(),
884
+ b_unsigned_u8x64);
885
+
886
+ ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, a_unsigned_u8x64, b_signed_i8x64);
887
+ a2_i32x16 = _mm512_dpbusd_epi32(a2_i32x16, a_unsigned_u8x64, a_unsigned_u8x64);
888
+ b2_i32x16 = _mm512_dpbusd_epi32(b2_i32x16, b_unsigned_u8x64, b_unsigned_u8x64);
889
+
890
+ if (n) goto nk_sqeuclidean_e2m3_icelake_cycle;
891
+
892
+ // (a-b)² = a² + b² − 2·ab, scaled by 256 (16² from LUT)
893
+ __m512 a2_f32x16 = _mm512_cvtepi32_ps(a2_i32x16);
894
+ __m512 b2_f32x16 = _mm512_cvtepi32_ps(b2_i32x16);
895
+ __m512 ab_f32x16 = _mm512_cvtepi32_ps(ab_i32x16);
896
+ __m512 sum_sq_f32x16 = _mm512_add_ps(a2_f32x16, b2_f32x16);
897
+ *result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16)) / 256.0f;
898
+ }
899
+
900
+ NK_PUBLIC void nk_euclidean_e2m3_icelake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
901
+ nk_sqeuclidean_e2m3_icelake(a, b, n, result);
902
+ *result = nk_f32_sqrt_haswell(*result);
903
+ }
904
+
905
+ NK_PUBLIC void nk_angular_e2m3_icelake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
906
+ // E2M3 angular distance via VPDPBUSD integer MAC.
907
+ __m512i const lut_magnitude_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
908
+ 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0,
909
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
910
+ 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
911
+ __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
912
+ __m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);
913
+ __m512i ab_i32x16 = _mm512_setzero_si512();
914
+ __m512i a2_i32x16 = _mm512_setzero_si512();
915
+ __m512i b2_i32x16 = _mm512_setzero_si512();
916
+ __m512i a_e2m3_u8x64, b_e2m3_u8x64;
917
+
918
+ nk_angular_e2m3_icelake_cycle:
919
+ if (n < 64) {
920
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n);
921
+ a_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
922
+ b_e2m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
923
+ n = 0;
924
+ }
925
+ else {
926
+ a_e2m3_u8x64 = _mm512_loadu_si512(a);
927
+ b_e2m3_u8x64 = _mm512_loadu_si512(b);
928
+ a += 64, b += 64, n -= 64;
929
+ }
930
+
931
+ __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e2m3_u8x64, magnitude_mask_u8x64);
932
+ __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e2m3_u8x64, magnitude_mask_u8x64);
933
+ __m512i a_unsigned_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_magnitude_u8x64);
934
+ __m512i b_unsigned_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_magnitude_u8x64);
935
+
936
+ __m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e2m3_u8x64, b_e2m3_u8x64), sign_mask_u8x64);
937
+ __mmask64 negate_mask = _mm512_test_epi8_mask(sign_combined_u8x64, sign_combined_u8x64);
938
+ __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_unsigned_u8x64, negate_mask, _mm512_setzero_si512(),
939
+ b_unsigned_u8x64);
940
+
941
+ ab_i32x16 = _mm512_dpbusd_epi32(ab_i32x16, a_unsigned_u8x64, b_signed_i8x64);
942
+ a2_i32x16 = _mm512_dpbusd_epi32(a2_i32x16, a_unsigned_u8x64, a_unsigned_u8x64);
943
+ b2_i32x16 = _mm512_dpbusd_epi32(b2_i32x16, b_unsigned_u8x64, b_unsigned_u8x64);
944
+
945
+ if (n) goto nk_angular_e2m3_icelake_cycle;
946
+
947
+ nk_f32_t ab_f32 = (nk_f32_t)_mm512_reduce_add_epi32(ab_i32x16) / 256.0f;
948
+ nk_f32_t a_norm_sq_f32 = (nk_f32_t)_mm512_reduce_add_epi32(a2_i32x16) / 256.0f;
949
+ nk_f32_t b_norm_sq_f32 = (nk_f32_t)_mm512_reduce_add_epi32(b2_i32x16) / 256.0f;
950
+ *result = nk_angular_normalize_f32_haswell_(ab_f32, a_norm_sq_f32, b_norm_sq_f32);
951
+ }
952
+
953
+ NK_PUBLIC void nk_sqeuclidean_e3m2_icelake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
954
+ // E3M2 squared Euclidean distance via direct difference squaring.
955
+ __m512i const lut_magnitude_i16x32 = _mm512_set_epi16( //
956
+ 448, 384, 320, 256, 224, 192, 160, 128, 112, 96, 80, 64, 56, 48, 40, 32, //
957
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
958
+ __m512i const magnitude_mask_i16x32 = _mm512_set1_epi16(0x1F);
959
+ __m512i const sign_mask_i16x32 = _mm512_set1_epi16(0x20);
960
+ __m512i sum_i32x16 = _mm512_setzero_si512();
961
+ __m256i a_e3m2_u8x32, b_e3m2_u8x32;
962
+
963
+ nk_sqeuclidean_e3m2_icelake_cycle:
964
+ if (n < 32) {
965
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
966
+ a_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, a);
967
+ b_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, b);
968
+ n = 0;
969
+ }
970
+ else {
971
+ a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a);
972
+ b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b);
973
+ a += 32, b += 32, n -= 32;
974
+ }
975
+
976
+ __m512i a_u16x32 = _mm512_cvtepu8_epi16(a_e3m2_u8x32);
977
+ __m512i b_u16x32 = _mm512_cvtepu8_epi16(b_e3m2_u8x32);
978
+ __m512i a_unsigned_i16x32 = _mm512_permutexvar_epi16(_mm512_and_si512(a_u16x32, magnitude_mask_i16x32),
979
+ lut_magnitude_i16x32);
980
+ __m512i b_unsigned_i16x32 = _mm512_permutexvar_epi16(_mm512_and_si512(b_u16x32, magnitude_mask_i16x32),
981
+ lut_magnitude_i16x32);
982
+
983
+ // Apply signs individually
984
+ __mmask32 a_negative_mask = _mm512_test_epi16_mask(a_u16x32, sign_mask_i16x32);
985
+ __mmask32 b_negative_mask = _mm512_test_epi16_mask(b_u16x32, sign_mask_i16x32);
986
+ __m512i a_signed_i16x32 = _mm512_mask_sub_epi16(a_unsigned_i16x32, a_negative_mask, _mm512_setzero_si512(),
987
+ a_unsigned_i16x32);
988
+ __m512i b_signed_i16x32 = _mm512_mask_sub_epi16(b_unsigned_i16x32, b_negative_mask, _mm512_setzero_si512(),
989
+ b_unsigned_i16x32);
990
+
991
+ // Direct difference squaring: (a-b)² via VPMADDWD
992
+ __m512i diff_i16x32 = _mm512_sub_epi16(a_signed_i16x32, b_signed_i16x32);
993
+ sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(diff_i16x32, diff_i16x32));
994
+
995
+ if (n) goto nk_sqeuclidean_e3m2_icelake_cycle;
996
+ *result = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 256.0f;
997
+ }
998
+
999
+ NK_PUBLIC void nk_euclidean_e3m2_icelake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
1000
+ nk_sqeuclidean_e3m2_icelake(a, b, n, result);
1001
+ *result = nk_f32_sqrt_haswell(*result);
1002
+ }
1003
+
1004
+ NK_PUBLIC void nk_angular_e3m2_icelake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
1005
+ // E3M2 angular distance via VPMADDWD integer MAC.
1006
+ __m512i const lut_magnitude_i16x32 = _mm512_set_epi16( //
1007
+ 448, 384, 320, 256, 224, 192, 160, 128, 112, 96, 80, 64, 56, 48, 40, 32, //
1008
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1009
+ __m512i const magnitude_mask_i16x32 = _mm512_set1_epi16(0x1F);
1010
+ __m512i const sign_mask_i16x32 = _mm512_set1_epi16(0x20);
1011
+ __m512i ab_i32x16 = _mm512_setzero_si512();
1012
+ __m512i a2_i32x16 = _mm512_setzero_si512();
1013
+ __m512i b2_i32x16 = _mm512_setzero_si512();
1014
+ __m256i a_e3m2_u8x32, b_e3m2_u8x32;
1015
+
1016
+ nk_angular_e3m2_icelake_cycle:
1017
+ if (n < 32) {
1018
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
1019
+ a_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, a);
1020
+ b_e3m2_u8x32 = _mm256_maskz_loadu_epi8(mask, b);
1021
+ n = 0;
1022
+ }
1023
+ else {
1024
+ a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a);
1025
+ b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b);
1026
+ a += 32, b += 32, n -= 32;
1027
+ }
1028
+
1029
+ __m512i a_u16x32 = _mm512_cvtepu8_epi16(a_e3m2_u8x32);
1030
+ __m512i b_u16x32 = _mm512_cvtepu8_epi16(b_e3m2_u8x32);
1031
+ __m512i a_unsigned_i16x32 = _mm512_permutexvar_epi16(_mm512_and_si512(a_u16x32, magnitude_mask_i16x32),
1032
+ lut_magnitude_i16x32);
1033
+ __m512i b_unsigned_i16x32 = _mm512_permutexvar_epi16(_mm512_and_si512(b_u16x32, magnitude_mask_i16x32),
1034
+ lut_magnitude_i16x32);
1035
+
1036
+ __mmask32 a_negative_mask = _mm512_test_epi16_mask(a_u16x32, sign_mask_i16x32);
1037
+ __mmask32 b_negative_mask = _mm512_test_epi16_mask(b_u16x32, sign_mask_i16x32);
1038
+ __m512i a_signed_i16x32 = _mm512_mask_sub_epi16(a_unsigned_i16x32, a_negative_mask, _mm512_setzero_si512(),
1039
+ a_unsigned_i16x32);
1040
+ __m512i b_signed_i16x32 = _mm512_mask_sub_epi16(b_unsigned_i16x32, b_negative_mask, _mm512_setzero_si512(),
1041
+ b_unsigned_i16x32);
1042
+
1043
+ ab_i32x16 = _mm512_add_epi32(ab_i32x16, _mm512_madd_epi16(a_signed_i16x32, b_signed_i16x32));
1044
+ a2_i32x16 = _mm512_add_epi32(a2_i32x16, _mm512_madd_epi16(a_unsigned_i16x32, a_unsigned_i16x32));
1045
+ b2_i32x16 = _mm512_add_epi32(b2_i32x16, _mm512_madd_epi16(b_unsigned_i16x32, b_unsigned_i16x32));
1046
+
1047
+ if (n) goto nk_angular_e3m2_icelake_cycle;
1048
+
1049
+ nk_f32_t ab_f32 = (nk_f32_t)_mm512_reduce_add_epi32(ab_i32x16) / 256.0f;
1050
+ nk_f32_t a_norm_sq_f32 = (nk_f32_t)_mm512_reduce_add_epi32(a2_i32x16) / 256.0f;
1051
+ nk_f32_t b_norm_sq_f32 = (nk_f32_t)_mm512_reduce_add_epi32(b2_i32x16) / 256.0f;
1052
+ *result = nk_angular_normalize_f32_haswell_(ab_f32, a_norm_sq_f32, b_norm_sq_f32);
572
1053
  }
573
1054
 
574
1055
  #if defined(__clang__)