numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,13 +8,13 @@
8
8
  *
9
9
  * @section skylake_cast_instructions AVX-512 Conversion Instructions
10
10
  *
11
- * Intrinsic Instruction SKL ICL Genoa
12
- * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 5cy @ p05 4cy @ p01
13
- * _mm512_cvtps_ph VCVTPS2PH (YMM, ZMM, imm) 5cy @ p05 5cy @ p05 4cy @ p01
14
- * _mm512_cvtps_epi32 VCVTPS2DQ (ZMM, ZMM) 4cy @ p0 4cy @ p0 3cy @ p01
15
- * _mm512_cvtepi32_ps VCVTDQ2PS (ZMM, ZMM) 4cy @ p0 4cy @ p0 3cy @ p01
16
- * _mm512_cvtepi32_epi16 VPMOVDW (YMM, ZMM) 3cy @ p5 3cy @ p5 2cy @ p12
17
- * _mm512_cvtsepi32_epi8 VPMOVSDB (XMM, ZMM) 3cy @ p5 3cy @ p5 2cy @ p12
11
+ * Intrinsic Instruction SKL ICL Genoa
12
+ * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 5cy @ p05 4cy @ p01
13
+ * _mm512_cvtps_ph VCVTPS2PH (YMM, ZMM, imm) 5cy @ p05 5cy @ p05 4cy @ p01
14
+ * _mm512_cvtps_epi32 VCVTPS2DQ (ZMM, ZMM) 4cy @ p0 4cy @ p0 3cy @ p01
15
+ * _mm512_cvtepi32_ps VCVTDQ2PS (ZMM, ZMM) 4cy @ p0 4cy @ p0 3cy @ p01
16
+ * _mm512_cvtepi32_epi16 VPMOVDW (YMM, ZMM) 3cy @ p5 3cy @ p5 2cy @ p12
17
+ * _mm512_cvtsepi32_epi8 VPMOVSDB (XMM, ZMM) 3cy @ p5 3cy @ p5 2cy @ p12
18
18
  *
19
19
  * F16 conversions use hardware F16C via VCVTPH2PS/VCVTPS2PH. BF16 lacks hardware support on Skylake,
20
20
  * requiring emulation via VPMOVZXWD + VPSLLD for bf16-to-f32, achieving ~4cy total. FP8 (E4M3/E5M2)
@@ -41,7 +41,7 @@ extern "C" {
41
41
  #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
42
42
  #endif
43
43
 
44
- #pragma region - Type Punned Loads and Stores
44
+ #pragma region Type Punned Loads and Stores
45
45
 
46
46
  /** @brief Type-agnostic 512-bit full load (Skylake AVX-512). */
47
47
  NK_INTERNAL void nk_load_b512_skylake_(void const *src, nk_b512_vec_t *dst) { dst->zmm = _mm512_loadu_si512(src); }
@@ -132,9 +132,32 @@ NK_INTERNAL void nk_partial_store_b64x4_skylake_(nk_b256_vec_t const *src, void
132
132
  _mm256_mask_storeu_epi64(dst, mask, src->ymm);
133
133
  }
134
134
 
135
- #pragma endregion - Type Punned Loads and Stores
135
+ /** @brief Type-agnostic full store for 512-bit vector (Skylake AVX-512). */
136
+ NK_INTERNAL void nk_store_b512_skylake_(nk_b512_vec_t const *src, void *dst) {
137
+ _mm512_storeu_si512((__m512i *)dst, src->zmm);
138
+ }
139
+
140
+ /** @brief Type-agnostic partial store for 16-bit elements (32 elements max) from 512-bit vector (Skylake AVX-512). */
141
+ NK_INTERNAL void nk_partial_store_b16x32_skylake_(nk_b512_vec_t const *src, void *dst, nk_size_t n) {
142
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
143
+ _mm512_mask_storeu_epi16(dst, mask, src->zmm);
144
+ }
145
+
146
+ /** @brief Type-agnostic partial store for 8-bit elements (64 elements max) from 512-bit vector (Skylake AVX-512). */
147
+ NK_INTERNAL void nk_partial_store_b8x64_skylake_(nk_b512_vec_t const *src, void *dst, nk_size_t n) {
148
+ __mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n);
149
+ _mm512_mask_storeu_epi8(dst, mask, src->zmm);
150
+ }
151
+
152
+ /** @brief Type-agnostic partial store for 64-bit elements (8 elements max) from 512-bit vector (Skylake AVX-512). */
153
+ NK_INTERNAL void nk_partial_store_b64x8_skylake_(nk_b512_vec_t const *src, void *dst, nk_size_t n) {
154
+ __mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)n);
155
+ _mm512_mask_storeu_epi64(dst, mask, src->zmm);
156
+ }
136
157
 
137
- #pragma region - Vectorized Conversions
158
+ #pragma endregion Type Punned Loads and Stores
159
+
160
+ #pragma region Vectorized Conversions
138
161
 
139
162
  /** @brief Convert 16x bf16 → 16x f32 (Skylake AVX-512). */
140
163
  NK_INTERNAL __m512 nk_bf16x16_to_f32x16_skylake_(__m256i a) {
@@ -169,18 +192,20 @@ NK_INTERNAL __m512 nk_e4m3x16_to_f32x16_skylake_(__m128i e4m3_i8x16) {
169
192
  __m512 result_f32x16 = _mm512_castsi512_ps(
170
193
  _mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
171
194
 
172
- // Subnormal fix: for exp==0 lanes, replace with (mantissa / 512) | sign using masked OR
195
+ // Subnormal fix: vpermps from 8-entry LUT (repeated to fill 16 lanes)
173
196
  __mmask16 is_subnormal = _mm512_testn_epi32_mask(e4m3_i32x16, _mm512_set1_epi32(0x78));
174
- __m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 512.0f));
197
+ __m512 subnorm_lut_f32x16 = _mm512_setr_ps( //
198
+ 0, 1.0f / 512, 2.0f / 512, 3.0f / 512, 4.0f / 512, 5.0f / 512, 6.0f / 512, 7.0f / 512, //
199
+ 0, 1.0f / 512, 2.0f / 512, 3.0f / 512, 4.0f / 512, 5.0f / 512, 6.0f / 512, 7.0f / 512);
200
+ __m512 subnorm_abs_f32x16 = _mm512_permutexvar_ps(mantissa_i32x16, subnorm_lut_f32x16);
175
201
  result_f32x16 = _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16,
176
202
  _mm512_castsi512_ps(sign_i32x16));
177
203
 
178
- // NaN path: E4M3FN has NaN only when exp=15 AND mant=7 (0x7F or 0xFF)
179
- __mmask16 is_nan = _mm512_mask_cmpeq_epi32_mask( //
180
- _mm512_cmpeq_epi32_mask(exp_i32x16, _mm512_set1_epi32(15)), //
181
- mantissa_i32x16, _mm512_set1_epi32(7)); //
182
- __m512i nan_bits = _mm512_or_si512(sign_i32x16, _mm512_set1_epi32(0x7FC00000)); // F32 quiet NaN
183
- return _mm512_mask_blend_ps(is_nan, result_f32x16, _mm512_castsi512_ps(nan_bits));
204
+ // NaN: E4M3FN has NaN only at magnitude 0x7F (single mask comparison)
205
+ __m512i lower7_i32x16 = _mm512_and_si512(e4m3_i32x16, _mm512_set1_epi32(0x7F));
206
+ __mmask16 is_nan = _mm512_cmpeq_epi32_mask(lower7_i32x16, _mm512_set1_epi32(0x7F));
207
+ __m512i nan_i32x16 = _mm512_or_si512(sign_i32x16, _mm512_set1_epi32(0x7FC00000));
208
+ return _mm512_mask_blend_ps(is_nan, result_f32x16, _mm512_castsi512_ps(nan_i32x16));
184
209
  }
185
210
 
186
211
  /** @brief Convert 16x e5m2 → 16x f32 via bit manipulation (AVX-512).
@@ -561,9 +586,43 @@ NK_INTERNAL __m256i nk_f64x8_to_u32x8_skylake_(__m512d f64x8) {
561
586
  return _mm512_cvtpd_epu32(clamped);
562
587
  }
563
588
 
564
- #pragma endregion - Vectorized Conversions
589
+ /**
590
+ * @brief Convert 64x E2M3 → 64x I8 using VPSHUFB LUT (Skylake AVX-512).
591
+ *
592
+ * E2M3 format: [sign:1][magnitude:5] where magnitude indexes a 32-entry LUT
593
+ * that produces the scaled integer value. Sign bit negates the result.
594
+ * The 32-entry LUT is split into two 16-entry halves for VPSHUFB (which
595
+ * indexes within 16-byte lanes). Bit 4 of the magnitude selects the half.
596
+ */
597
+ NK_INTERNAL __m512i nk_e2m3x64_to_i8x64_skylake_(__m512i raw_i8x64) {
598
+ // lut_magnitude[0..15] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}
599
+ // lut_magnitude[16..31] = {32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120}
600
+ // _mm512_set4_epi32(d3, d2, d1, d0) fills bytes [0..3]=d0, [4..7]=d1, [8..11]=d2, [12..15]=d3
601
+ // per 128-bit lane, matching VPSHUFB's per-lane indexing.
602
+ __m512i lut_low_i8x64 = _mm512_set4_epi32( //
603
+ 0x1E1C1A18, 0x16141210, 0x0E0C0A08, 0x06040200);
604
+ __m512i lut_high_i8x64 = _mm512_set4_epi32( //
605
+ 0x78706860, 0x58504840, 0x3C383430, 0x2C282420);
606
+
607
+ __m512i magnitude_i8x64 = _mm512_and_si512(raw_i8x64, _mm512_set1_epi8(0x1F));
608
+ __m512i index_i8x64 = _mm512_and_si512(magnitude_i8x64, _mm512_set1_epi8(0x0F));
609
+
610
+ __m512i val_low_i8x64 = _mm512_shuffle_epi8(lut_low_i8x64, index_i8x64);
611
+ __m512i val_high_i8x64 = _mm512_shuffle_epi8(lut_high_i8x64, index_i8x64);
612
+
613
+ // Select high half when bit 4 of magnitude is set (magnitude >= 16)
614
+ __mmask64 use_high_mask = _mm512_test_epi8_mask(magnitude_i8x64, _mm512_set1_epi8(0x10));
615
+ __m512i val_i8x64 = _mm512_mask_blend_epi8(use_high_mask, val_low_i8x64, val_high_i8x64);
616
+
617
+ // Negate if sign bit (bit 5) is set
618
+ __mmask64 sign_mask = _mm512_test_epi8_mask(raw_i8x64, _mm512_set1_epi8(0x20));
619
+ __m512i negated_i8x64 = _mm512_sub_epi8(_mm512_setzero_si512(), val_i8x64);
620
+ return _mm512_mask_blend_epi8(sign_mask, val_i8x64, negated_i8x64);
621
+ }
622
+
623
+ #pragma endregion Vectorized Conversions
565
624
 
566
- #pragma region - Converting Loads and Stores
625
+ #pragma region Converting Loads and Stores
567
626
 
568
627
  /** @brief Load 16 f16 values and convert to 16 f32 (Skylake AVX-512). */
569
628
  NK_INTERNAL void nk_load_f16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
@@ -637,9 +696,9 @@ NK_INTERNAL void nk_partial_load_e3m2x16_to_f32x16_skylake_(void const *src, nk_
637
696
  dst->zmm_ps = nk_e3m2x16_to_f32x16_skylake_(e3m2_partial.xmm);
638
697
  }
639
698
 
640
- #pragma endregion - Converting Loads and Stores
699
+ #pragma endregion Converting Loads and Stores
641
700
 
642
- #pragma region - Public API
701
+ #pragma region Public API
643
702
 
644
703
  NK_PUBLIC void nk_cast_skylake(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
645
704
  // Same-type fast path
@@ -839,7 +898,7 @@ NK_PUBLIC void nk_cast_skylake(void const *from, nk_dtype_t from_type, nk_size_t
839
898
  nk_cast_serial(from, from_type, n, to, to_type);
840
899
  }
841
900
 
842
- #pragma endregion - Public API
901
+ #pragma endregion Public API
843
902
 
844
903
  #if defined(__clang__)
845
904
  #pragma clang attribute pop