numkong 7.0.0 → 7.4.1

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,14 +8,13 @@
8
8
  *
9
9
  * @section dot_icelake_instructions VNNI Instructions Performance
10
10
  *
11
- * Intrinsic Instruction Ice Genoa
12
- * _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
13
- * _mm512_dpbusd_epi32 VPDPBUSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
14
- * _mm512_madd_epi16 VPMADDWD (ZMM, ZMM, ZMM) 5cy @ p05 3cy @ p01
11
+ * Intrinsic Instruction Icelake Genoa
12
+ * _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
13
+ * _mm512_dpbusd_epi32 VPDPBUSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
14
+ * _mm512_madd_epi16 VPMADDWD (ZMM, ZMM, ZMM) 5cy @ p05 3cy @ p01
15
15
  *
16
16
  * Ice Lake introduces AVX-512 VNNI for accelerated integer dot products. VNNI instructions bottleneck
17
17
  * on port 0, limiting throughput to 1/cy. AMD Genoa dual-issues on ports 0-1, achieving 0.5/cy throughput.
18
- * We use VPDPWSSD for signed i8 inputs after widening to i16, since VPDPBUSD is asymmetric (unsigned x signed).
19
18
  *
20
19
  * @section dot_icelake_stateful Stateful Streaming Logic
21
20
  *
@@ -80,6 +79,7 @@
80
79
  #if NK_TARGET_ICELAKE
81
80
 
82
81
  #include "numkong/types.h"
82
+ #include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
83
83
 
84
84
  #if defined(__cplusplus)
85
85
  extern "C" {
@@ -268,13 +268,13 @@ NK_INTERNAL void nk_dot_i8x64_finalize_icelake(
268
268
  __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
269
269
 
270
270
  // 4-way transpose reduce
271
- __m128i t_ab_lo = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
272
- __m128i t_cd_lo = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
273
- __m128i t_ab_hi = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
274
- __m128i t_cd_hi = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
271
+ __m128i t_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
272
+ __m128i t_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
273
+ __m128i t_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
274
+ __m128i t_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
275
275
  __m128i biased_i32x4 = _mm_add_epi32(
276
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
277
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));
276
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
277
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
278
278
 
279
279
  // Apply compensation: result = biased − 128 × Σb
280
280
  __m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
@@ -335,13 +335,13 @@ NK_INTERNAL void nk_dot_u8x64_finalize_icelake(
335
335
  __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
336
336
 
337
337
  // 4-way transpose reduce
338
- __m128i t_ab_lo = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
339
- __m128i t_cd_lo = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
340
- __m128i t_ab_hi = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
341
- __m128i t_cd_hi = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
338
+ __m128i t_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
339
+ __m128i t_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
340
+ __m128i t_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
341
+ __m128i t_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
342
342
  __m128i biased_i32x4 = _mm_add_epi32(
343
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
344
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));
343
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
344
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
345
345
 
346
346
  // Apply compensation: result = biased + 128 × Σb
347
347
  __m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
@@ -402,18 +402,18 @@ NK_INTERNAL void nk_sum_i4x128_update_icelake(nk_sum_i4x128_state_icelake_t *sta
402
402
  __m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
403
403
  __m512i const xor_mask_u8x64 = _mm512_set1_epi8(0x08);
404
404
  __m512i const zeros_u8x64 = _mm512_setzero_si512();
405
- /* Extract low and high nibbles, XOR with 8 to get unsigned representation */
405
+ // Extract low and high nibbles, XOR with 8 to get unsigned representation
406
406
  __m512i low_u8x64 = _mm512_and_si512(v.zmm, nibble_mask_u8x64);
407
407
  __m512i high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(v.zmm, 4), nibble_mask_u8x64);
408
408
  __m512i low_biased_u8x64 = _mm512_xor_si512(low_u8x64, xor_mask_u8x64);
409
409
  __m512i high_biased_u8x64 = _mm512_xor_si512(high_u8x64, xor_mask_u8x64);
410
- /* SAD against zero gives sum of unsigned values, accumulate in u64 lanes */
410
+ // SAD against zero gives sum of unsigned values, accumulate in u64 lanes
411
411
  state->biased_sum_u64x8 = _mm512_add_epi64(state->biased_sum_u64x8, _mm512_sad_epu8(low_biased_u8x64, zeros_u8x64));
412
412
  state->biased_sum_u64x8 = _mm512_add_epi64(state->biased_sum_u64x8,
413
413
  _mm512_sad_epu8(high_biased_u8x64, zeros_u8x64));
414
414
  }
415
415
  NK_INTERNAL nk_i32_t nk_sum_i4x128_finalize_icelake(nk_sum_i4x128_state_icelake_t const *state, nk_size_t count) {
416
- /* Reduce u64x8 → scalar, then undo XOR bias: signed_sum = unsigned_sum - 8 * count */
416
+ // Reduce u64x8 → scalar, then undo XOR bias: signed_sum = unsigned_sum - 8 * count
417
417
  nk_i64_t unsigned_sum = _mm512_reduce_add_epi64(state->biased_sum_u64x8);
418
418
  return (nk_i32_t)(unsigned_sum - 8 * (nk_i64_t)count);
419
419
  }
@@ -454,26 +454,26 @@ nk_dot_i4_icelake_cycle:
454
454
  }
455
455
 
456
456
  // Extract low and high nibbles
457
- __m512i a_lo_u8x64 = _mm512_and_si512(a_i4x128, nibble_mask_u8x64);
458
- __m512i a_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4x128, 4), nibble_mask_u8x64);
459
- __m512i b_lo_u8x64 = _mm512_and_si512(b_i4x128, nibble_mask_u8x64);
460
- __m512i b_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4x128, 4), nibble_mask_u8x64);
457
+ __m512i a_low_u8x64 = _mm512_and_si512(a_i4x128, nibble_mask_u8x64);
458
+ __m512i a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4x128, 4), nibble_mask_u8x64);
459
+ __m512i b_low_u8x64 = _mm512_and_si512(b_i4x128, nibble_mask_u8x64);
460
+ __m512i b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4x128, 4), nibble_mask_u8x64);
461
461
 
462
462
  // XOR with 8 to get cx, dx values for the algebraic transformation
463
- __m512i c_lo_u8x64 = _mm512_xor_si512(a_lo_u8x64, xor_mask_u8x64);
464
- __m512i c_hi_u8x64 = _mm512_xor_si512(a_hi_u8x64, xor_mask_u8x64);
465
- __m512i d_lo_u8x64 = _mm512_xor_si512(b_lo_u8x64, xor_mask_u8x64);
466
- __m512i d_hi_u8x64 = _mm512_xor_si512(b_hi_u8x64, xor_mask_u8x64);
463
+ __m512i c_low_u8x64 = _mm512_xor_si512(a_low_u8x64, xor_mask_u8x64);
464
+ __m512i c_high_u8x64 = _mm512_xor_si512(a_high_u8x64, xor_mask_u8x64);
465
+ __m512i d_low_u8x64 = _mm512_xor_si512(b_low_u8x64, xor_mask_u8x64);
466
+ __m512i d_high_u8x64 = _mm512_xor_si512(b_high_u8x64, xor_mask_u8x64);
467
467
 
468
468
  // Compute dot products of cx*dx for low and high nibbles
469
- sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16, c_lo_u8x64, d_lo_u8x64);
470
- sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16, c_hi_u8x64, d_hi_u8x64);
469
+ sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16, c_low_u8x64, d_low_u8x64);
470
+ sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16, c_high_u8x64, d_high_u8x64);
471
471
 
472
472
  // Accumulate sums of cx and dx using SAD against zeros
473
- sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(c_lo_u8x64, zeros_u8x64));
474
- sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(c_hi_u8x64, zeros_u8x64));
475
- sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(d_lo_u8x64, zeros_u8x64));
476
- sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(d_hi_u8x64, zeros_u8x64));
473
+ sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(c_low_u8x64, zeros_u8x64));
474
+ sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(c_high_u8x64, zeros_u8x64));
475
+ sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(d_low_u8x64, zeros_u8x64));
476
+ sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(d_high_u8x64, zeros_u8x64));
477
477
  if (n_bytes) goto nk_dot_i4_icelake_cycle;
478
478
 
479
479
  // Reduce partial sums and apply algebraic correction
@@ -509,15 +509,15 @@ nk_dot_u4_icelake_cycle:
509
509
  }
510
510
 
511
511
  // Extract low and high nibbles
512
- __m512i a_lo_u8x64 = _mm512_and_si512(a_u4x128, nibble_mask_u8x64);
513
- __m512i a_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4x128, 4), nibble_mask_u8x64);
514
- __m512i b_lo_u8x64 = _mm512_and_si512(b_u4x128, nibble_mask_u8x64);
515
- __m512i b_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4x128, 4), nibble_mask_u8x64);
512
+ __m512i a_low_u8x64 = _mm512_and_si512(a_u4x128, nibble_mask_u8x64);
513
+ __m512i a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4x128, 4), nibble_mask_u8x64);
514
+ __m512i b_low_u8x64 = _mm512_and_si512(b_u4x128, nibble_mask_u8x64);
515
+ __m512i b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4x128, 4), nibble_mask_u8x64);
516
516
 
517
517
  // DPBUSD works directly for u4 since values are ∈ [0,15]
518
518
  // and the signed interpretation of [0,15] is the same as unsigned
519
- sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, a_lo_u8x64, b_lo_u8x64);
520
- sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, a_hi_u8x64, b_hi_u8x64);
519
+ sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, a_low_u8x64, b_low_u8x64);
520
+ sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, a_high_u8x64, b_high_u8x64);
521
521
  if (n_bytes) goto nk_dot_u4_icelake_cycle;
522
522
 
523
523
  *result = (nk_u32_t)_mm512_reduce_add_epi32(sum_i32x16);
@@ -545,22 +545,22 @@ NK_INTERNAL void nk_dot_i4x128_update_icelake(nk_dot_i4x128_state_icelake_t *sta
545
545
  __m512i b_i4x128 = b.zmm;
546
546
 
547
547
  // Extract low and high nibbles (all 128 nibbles from 64 bytes)
548
- __m512i a_lo_u8x64 = _mm512_and_si512(a_i4x128, nibble_mask_u8x64);
549
- __m512i a_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4x128, 4), nibble_mask_u8x64);
550
- __m512i b_lo_u8x64 = _mm512_and_si512(b_i4x128, nibble_mask_u8x64);
551
- __m512i b_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4x128, 4), nibble_mask_u8x64);
548
+ __m512i a_low_u8x64 = _mm512_and_si512(a_i4x128, nibble_mask_u8x64);
549
+ __m512i a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4x128, 4), nibble_mask_u8x64);
550
+ __m512i b_low_u8x64 = _mm512_and_si512(b_i4x128, nibble_mask_u8x64);
551
+ __m512i b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4x128, 4), nibble_mask_u8x64);
552
552
 
553
553
  // Apply bias transformation: XOR with 8
554
- __m512i a_biased_lo_u8x64 = _mm512_xor_si512(a_lo_u8x64, bias_xor_mask_u8x64);
555
- __m512i a_biased_hi_u8x64 = _mm512_xor_si512(a_hi_u8x64, bias_xor_mask_u8x64);
556
- __m512i b_biased_lo_u8x64 = _mm512_xor_si512(b_lo_u8x64, bias_xor_mask_u8x64);
557
- __m512i b_biased_hi_u8x64 = _mm512_xor_si512(b_hi_u8x64, bias_xor_mask_u8x64);
554
+ __m512i a_biased_low_u8x64 = _mm512_xor_si512(a_low_u8x64, bias_xor_mask_u8x64);
555
+ __m512i a_biased_high_u8x64 = _mm512_xor_si512(a_high_u8x64, bias_xor_mask_u8x64);
556
+ __m512i b_biased_low_u8x64 = _mm512_xor_si512(b_low_u8x64, bias_xor_mask_u8x64);
557
+ __m512i b_biased_high_u8x64 = _mm512_xor_si512(b_high_u8x64, bias_xor_mask_u8x64);
558
558
 
559
559
  // Compute dot products of a_biased×b_biased — no SAD correction accumulators
560
- state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, a_biased_lo_u8x64,
561
- b_biased_lo_u8x64);
562
- state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, a_biased_hi_u8x64,
563
- b_biased_hi_u8x64);
560
+ state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, a_biased_low_u8x64,
561
+ b_biased_low_u8x64);
562
+ state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, a_biased_high_u8x64,
563
+ b_biased_high_u8x64);
564
564
  }
565
565
 
566
566
  NK_INTERNAL void nk_dot_i4x128_finalize_icelake( //
@@ -596,13 +596,13 @@ NK_INTERNAL void nk_dot_i4x128_finalize_icelake(
596
596
  _mm256_extracti128_si256(product_d_i32x8, 1));
597
597
 
598
598
  // 4-way transpose reduce
599
- __m128i t_ab_lo = _mm_unpacklo_epi32(product_a_i32x4, product_b_i32x4);
600
- __m128i t_cd_lo = _mm_unpacklo_epi32(product_c_i32x4, product_d_i32x4);
601
- __m128i t_ab_hi = _mm_unpackhi_epi32(product_a_i32x4, product_b_i32x4);
602
- __m128i t_cd_hi = _mm_unpackhi_epi32(product_c_i32x4, product_d_i32x4);
599
+ __m128i t_ab_low = _mm_unpacklo_epi32(product_a_i32x4, product_b_i32x4);
600
+ __m128i t_cd_low = _mm_unpacklo_epi32(product_c_i32x4, product_d_i32x4);
601
+ __m128i t_ab_high = _mm_unpackhi_epi32(product_a_i32x4, product_b_i32x4);
602
+ __m128i t_cd_high = _mm_unpackhi_epi32(product_c_i32x4, product_d_i32x4);
603
603
  __m128i biased_i32x4 = _mm_add_epi32(
604
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_lo, t_cd_lo), _mm_unpackhi_epi64(t_ab_lo, t_cd_lo)),
605
- _mm_add_epi32(_mm_unpacklo_epi64(t_ab_hi, t_cd_hi), _mm_unpackhi_epi64(t_ab_hi, t_cd_hi)));
604
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
605
+ _mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
606
606
 
607
607
  // Apply compensation: result = biased − 8×(Σa + Σb) − 64×depth_padded
608
608
  __m128i a_sum_broadcast_i32x4 = _mm_set1_epi32(a_sum);
@@ -633,14 +633,14 @@ NK_INTERNAL void nk_dot_u4x128_update_icelake(nk_dot_u4x128_state_icelake_t *sta
633
633
  __m512i b_u4x128 = b.zmm;
634
634
 
635
635
  // Extract low and high nibbles (all 128 nibbles from 64 bytes)
636
- __m512i a_lo_u8x64 = _mm512_and_si512(a_u4x128, nibble_mask_u8x64);
637
- __m512i a_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4x128, 4), nibble_mask_u8x64);
638
- __m512i b_lo_u8x64 = _mm512_and_si512(b_u4x128, nibble_mask_u8x64);
639
- __m512i b_hi_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4x128, 4), nibble_mask_u8x64);
636
+ __m512i a_low_u8x64 = _mm512_and_si512(a_u4x128, nibble_mask_u8x64);
637
+ __m512i a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4x128, 4), nibble_mask_u8x64);
638
+ __m512i b_low_u8x64 = _mm512_and_si512(b_u4x128, nibble_mask_u8x64);
639
+ __m512i b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4x128, 4), nibble_mask_u8x64);
640
640
 
641
641
  // DPBUSD works directly for u4 since values are ∈ [0,15]
642
- state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16, a_lo_u8x64, b_lo_u8x64);
643
- state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16, a_hi_u8x64, b_hi_u8x64);
642
+ state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16, a_low_u8x64, b_low_u8x64);
643
+ state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16, a_high_u8x64, b_high_u8x64);
644
644
  }
645
645
 
646
646
  NK_INTERNAL void nk_dot_u4x128_finalize_icelake( //
@@ -667,16 +667,17 @@ NK_INTERNAL void nk_dot_u4x128_finalize_icelake(
667
667
  __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
668
668
 
669
669
  // 4-way transpose to get [a,b,c,d] in lanes
670
- __m128i transpose_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
671
- __m128i transpose_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
672
- __m128i transpose_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
673
- __m128i transpose_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
674
- __m128i sum_lane0 = _mm_unpacklo_epi64(transpose_ab_low, transpose_cd_low);
675
- __m128i sum_lane1 = _mm_unpackhi_epi64(transpose_ab_low, transpose_cd_low);
676
- __m128i sum_lane2 = _mm_unpacklo_epi64(transpose_ab_high, transpose_cd_high);
677
- __m128i sum_lane3 = _mm_unpackhi_epi64(transpose_ab_high, transpose_cd_high);
678
-
679
- __m128i final_i32x4 = _mm_add_epi32(_mm_add_epi32(sum_lane0, sum_lane1), _mm_add_epi32(sum_lane2, sum_lane3));
670
+ __m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
671
+ __m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
672
+ __m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
673
+ __m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
674
+ __m128i sum_lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
675
+ __m128i sum_lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
676
+ __m128i sum_lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
677
+ __m128i sum_lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
678
+
679
+ __m128i final_i32x4 = _mm_add_epi32(_mm_add_epi32(sum_lane0_i32x4, sum_lane1_i32x4),
680
+ _mm_add_epi32(sum_lane2_i32x4, sum_lane3_i32x4));
680
681
  result->xmm = final_i32x4;
681
682
  }
682
683
 
@@ -801,7 +802,120 @@ nk_dot_e3m2_icelake_cycle:
801
802
  *result = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 256.0f;
802
803
  }
803
804
 
804
- #pragma region - Binary
805
+ #pragma region F16 and BF16 Floats
806
+
807
+ NK_PUBLIC void nk_dot_e4m3_icelake(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
808
+ nk_f32_t *result) {
809
+ // E4M3 dot product via octave decomposition + VPDPBUSD integer MAC.
810
+ // Splits 4-bit exponent into 2 octave bits + 2 remainder bits, maps low 5 bits via VPERMB
811
+ // to u8 integers [0, 120], then 16 VPDPBUSD cross-products across 4×4 octave pairs.
812
+
813
+ __m512i const lut_normal_u8x64 = _mm512_set_epi8( //
814
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
815
+ 30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8, //
816
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
817
+ 30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8); //
818
+ __m512i const lut_subnorm_u8x64 = _mm512_set_epi8( //
819
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
820
+ 0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0, //
821
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
822
+ 0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0); //
823
+ __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x7F);
824
+ __m512i const subnorm_threshold_u8x64 = _mm512_set1_epi8(0x08);
825
+ __m512i const oct_threshold_20_u8x64 = _mm512_set1_epi8(0x20);
826
+ __m512i const oct_threshold_40_u8x64 = _mm512_set1_epi8(0x40);
827
+ __m512i const oct_threshold_60_u8x64 = _mm512_set1_epi8(0x60);
828
+
829
+ __m512i dot0_i32x16 = _mm512_setzero_si512();
830
+ __m512i dot1_i32x16 = _mm512_setzero_si512();
831
+ __m512i dot2_i32x16 = _mm512_setzero_si512();
832
+ __m512i dot3_i32x16 = _mm512_setzero_si512();
833
+ __m512i dot4_i32x16 = _mm512_setzero_si512();
834
+ __m512i dot5_i32x16 = _mm512_setzero_si512();
835
+ __m512i dot6_i32x16 = _mm512_setzero_si512();
836
+ __m512i a_e4m3_u8x64, b_e4m3_u8x64;
837
+
838
+ nk_dot_e4m3_icelake_cycle:
839
+ if (count_scalars < 64) {
840
+ __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, count_scalars);
841
+ a_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a_scalars);
842
+ b_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b_scalars);
843
+ count_scalars = 0;
844
+ }
845
+ else {
846
+ a_e4m3_u8x64 = _mm512_loadu_si512(a_scalars);
847
+ b_e4m3_u8x64 = _mm512_loadu_si512(b_scalars);
848
+ a_scalars += 64, b_scalars += 64, count_scalars -= 64;
849
+ }
850
+
851
+ __m512i a_magnitude_u8x64 = _mm512_and_si512(a_e4m3_u8x64, magnitude_mask_u8x64);
852
+ __m512i b_magnitude_u8x64 = _mm512_and_si512(b_e4m3_u8x64, magnitude_mask_u8x64);
853
+ __m512i a_base_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_normal_u8x64);
854
+ __m512i b_base_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_normal_u8x64);
855
+
856
+ // Subnormal fixup via VPERMB (avoids VPADDB on Zen 4 ports 8+9 / SPR port 0)
857
+ a_base_u8x64 = _mm512_mask_permutexvar_epi8(a_base_u8x64,
858
+ _mm512_cmplt_epu8_mask(a_magnitude_u8x64, subnorm_threshold_u8x64),
859
+ a_magnitude_u8x64, lut_subnorm_u8x64);
860
+ b_base_u8x64 = _mm512_mask_permutexvar_epi8(b_base_u8x64,
861
+ _mm512_cmplt_epu8_mask(b_magnitude_u8x64, subnorm_threshold_u8x64),
862
+ b_magnitude_u8x64, lut_subnorm_u8x64);
863
+
864
+ // Sign via ternary logic: (a ^ b) & ~0x7F in one VPTERNLOGD (imm 0x14)
865
+ __m512i sign_diff_u8x64 = _mm512_ternarylogic_epi64(a_e4m3_u8x64, b_e4m3_u8x64, magnitude_mask_u8x64, 0x14);
866
+ __m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_base_u8x64, _mm512_test_epi8_mask(sign_diff_u8x64, sign_diff_u8x64),
867
+ _mm512_setzero_si512(), b_base_u8x64);
868
+
869
+ // Octave masks via cascaded range compares on magnitude
870
+ __mmask64 ka_lt20 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_20_u8x64);
871
+ __mmask64 ka_lt40 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_40_u8x64);
872
+ __mmask64 ka_lt60 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_60_u8x64);
873
+ __mmask64 kb_lt20 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_20_u8x64);
874
+ __mmask64 kb_lt40 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_40_u8x64);
875
+ __mmask64 kb_lt60 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_60_u8x64);
876
+
877
+ __m512i a0_u8x64 = _mm512_maskz_mov_epi8(ka_lt20, a_base_u8x64);
878
+ __m512i a1_u8x64 = _mm512_maskz_mov_epi8(ka_lt40 & ~ka_lt20, a_base_u8x64);
879
+ __m512i a2_u8x64 = _mm512_maskz_mov_epi8(ka_lt60 & ~ka_lt40, a_base_u8x64);
880
+ __m512i a3_u8x64 = _mm512_maskz_mov_epi8(~ka_lt60, a_base_u8x64);
881
+
882
+ __m512i b0_i8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_signed_i8x64);
883
+ __m512i b1_i8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_signed_i8x64);
884
+ __m512i b2_i8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_signed_i8x64);
885
+ __m512i b3_i8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_signed_i8x64);
886
+
887
+ // 16 VPDPBUSD into 7 accumulators grouped by octave sum k = oa + ob
888
+ dot0_i32x16 = _mm512_dpbusd_epi32(dot0_i32x16, a0_u8x64, b0_i8x64);
889
+ dot1_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot1_i32x16, a0_u8x64, b1_i8x64), a1_u8x64, b0_i8x64);
890
+ dot2_i32x16 = _mm512_dpbusd_epi32(
891
+ _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot2_i32x16, a0_u8x64, b2_i8x64), a1_u8x64, b1_i8x64), a2_u8x64,
892
+ b0_i8x64);
893
+ dot3_i32x16 = _mm512_dpbusd_epi32(
894
+ _mm512_dpbusd_epi32(
895
+ _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot3_i32x16, a0_u8x64, b3_i8x64), a1_u8x64, b2_i8x64), a2_u8x64,
896
+ b1_i8x64),
897
+ a3_u8x64, b0_i8x64);
898
+ dot4_i32x16 = _mm512_dpbusd_epi32(
899
+ _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot4_i32x16, a1_u8x64, b3_i8x64), a2_u8x64, b2_i8x64), a3_u8x64,
900
+ b1_i8x64);
901
+ dot5_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot5_i32x16, a2_u8x64, b3_i8x64), a3_u8x64, b2_i8x64);
902
+ dot6_i32x16 = _mm512_dpbusd_epi32(dot6_i32x16, a3_u8x64, b3_i8x64);
903
+
904
+ if (count_scalars) goto nk_dot_e4m3_icelake_cycle;
905
+
906
+ __m512 sum_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(dot0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
907
+ sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot1_i32x16), _mm512_set1_ps(1.52587890625e-05f), sum_f32x16);
908
+ sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot2_i32x16), _mm512_set1_ps(2.44140625e-04f), sum_f32x16);
909
+ sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot3_i32x16), _mm512_set1_ps(3.90625e-03f), sum_f32x16);
910
+ sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot4_i32x16), _mm512_set1_ps(6.25e-02f), sum_f32x16);
911
+ sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot5_i32x16), _mm512_set1_ps(1.0f), sum_f32x16);
912
+ sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot6_i32x16), _mm512_set1_ps(16.0f), sum_f32x16);
913
+ *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
914
+ }
915
+
916
+ #pragma endregion F16 and BF16 Floats
917
+
918
+ #pragma region Binary
805
919
 
806
920
  NK_PUBLIC void nk_dot_u1_icelake(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
807
921
  nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
@@ -866,7 +980,7 @@ NK_INTERNAL void nk_dot_u1x512_finalize_icelake( //
866
980
  result->xmm = _mm_hadd_epi32(ab_i32x4, cd_i32x4);
867
981
  }
868
982
 
869
- #pragma endregion - Binary
983
+ #pragma endregion Binary
870
984
 
871
985
  #if defined(__clang__)
872
986
  #pragma clang attribute pop