numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,470 @@
1
+ /**
2
+ * @brief SIMD-accelerated Type Conversions for Ice Lake.
3
+ * @file include/numkong/cast/icelake.h
4
+ * @author Ash Vardanian
5
+ * @date January 2, 2026
6
+ *
7
+ * @section ice_cast_instructions AVX-512 VBMI2 Instructions
8
+ *
9
+ * Intrinsic Instruction Ice Genoa
10
+ * _mm512_permutex2var_epi16 VPERMI2W (ZMM, ZMM, ZMM) 3cy @ p5 2cy @ p12
11
+ * _mm512_test_epi16_mask VPTESTMW (k, ZMM, ZMM) 3cy @ p5 2cy @ p01
12
+ * _mm512_mask_mov_epi16 VMOVDQU16 (ZMM{k}, ZMM) 1cy @ p05 1cy @ p05
13
+ * _mm512_cvtepi16_epi8 VPMOVWB (YMM, ZMM) 3cy @ p5 2cy @ p12
14
+ *
15
+ * Ice Lake's AVX-512 VBMI2 enables efficient 128-entry LUT lookups via dual VPERMI2W operations.
16
+ * FP8-to-BF16/F16 conversions use 4 ZMM LUT registers with VPTESTMW for range selection, achieving
17
+ * ~6 cycles for 32 FP8 conversions. E5M2-to-F16 simplifies to VPSLLW due to matching exponent bias.
18
+ */
19
+ #ifndef NK_CAST_ICELAKE_H
20
+ #define NK_CAST_ICELAKE_H
21
+
22
+ #if NK_TARGET_X86_
23
+ #if NK_TARGET_ICELAKE
24
+
25
+ #include "numkong/types.h"
26
+ #include "numkong/cast/skylake.h"
27
+
28
+ #if defined(__cplusplus)
29
+ extern "C" {
30
+ #endif
31
+
32
+ #if defined(__clang__)
33
+ #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
34
+ apply_to = function)
35
+ #elif defined(__GNUC__)
36
+ #pragma GCC push_options
37
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
38
+ #endif
39
+
40
+ #pragma region - Vectorized Conversions
41
+
42
/** @brief Convert 32x e4m3 → 32x bf16 via arithmetic + 8-entry subnormal LUT (AVX-512BW).
 *  E4M3 format: S EEEE MMM (bias=7). BF16: S EEEEEEEE MMMMMMM (bias=127).
 *  Normal values (exp != 0): BF16 = sign | ((lower7 << 4) + 0x3C00).
 *  Subnormals (exp == 0, 8 values): looked up from 8-entry LUT via permutexvar.
 *  Memory: 16 bytes (8 × 16-bit entries) vs 256 bytes (128-entry LUT). OCP FP8 v1.0.
 *  NOTE(review): the E4M3FN NaN encoding (0x7F) takes the normal path below and maps
 *  to the finite BF16 value 448.0 (0x43F0) rather than a BF16 NaN — confirm this
 *  finite/saturating treatment of NaN inputs is intended. */
NK_INTERNAL __m512i nk_e4m3x32_to_bf16x32_icelake_(__m256i e4m3x32) {
    // Zero-extend each byte to 16 bits so all arithmetic happens in epi16 lanes.
    __m512i e4m3_i16x32 = _mm512_cvtepu8_epi16(e4m3x32);
    __m512i sign_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16((short)0x80));
    __m512i lower7_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16(0x7F));

    // Normal path: BF16 = ((lower7 << 4) + 0x3C00) | (sign << 8)
    // Formula: E4M3 exp=e, mant=m → BF16 exp = e+120 (bias 7→127), mant = m<<4;
    // lower7<<4 places e at bits [10:7] and m at [6:4], and adding 120<<7 = 0x3C00
    // rebiases the exponent in one step.
    __m512i normal_abs_i16x32 = _mm512_add_epi16(_mm512_slli_epi16(lower7_i16x32, 4), _mm512_set1_epi16(0x3C00));

    // Subnormal LUT (8 entries, repeated 4x for all lanes): E4M3 subnormals are mant × 2^(-9)
    // Values: 0, 1/512, 2/512, 3/512, 4/512, 5/512, 6/512, 7/512
    // (only indices 0-7 are ever selected, so the repetition is inert padding)
    __m512i subn_lut_i16x32 = _mm512_set_epi16( //
        0x3C60, 0x3C40, 0x3C20, 0x3C00, 0x3BC0, 0x3B80, 0x3B00, 0x0000, // lane 3
        0x3C60, 0x3C40, 0x3C20, 0x3C00, 0x3BC0, 0x3B80, 0x3B00, 0x0000, // lane 2
        0x3C60, 0x3C40, 0x3C20, 0x3C00, 0x3BC0, 0x3B80, 0x3B00, 0x0000, // lane 1
        0x3C60, 0x3C40, 0x3C20, 0x3C00, 0x3BC0, 0x3B80, 0x3B00, 0x0000); // lane 0

    // Lookup subnormals via permutexvar (use lower 3 bits of mantissa as index)
    __m512i mant_idx_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16(0x07));
    __m512i subnorm_abs_i16x32 = _mm512_permutexvar_epi16(mant_idx_i16x32, subn_lut_i16x32);

    // Blend: if exponent == 0 (E4M3 exponent field lives in bits [6:3], mask 0x78),
    // use the subnormal lookup; else use the arithmetic normal path
    __m512i exp_bits_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16(0x78));
    __mmask32 is_subnormal = _mm512_cmpeq_epi16_mask(exp_bits_i16x32, _mm512_setzero_si512());
    __m512i result_abs_i16x32 = _mm512_mask_blend_epi16(is_subnormal, normal_abs_i16x32, subnorm_abs_i16x32);

    // Apply sign: shift E4M3 bit 7 to BF16 bit 15
    sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
    return _mm512_or_si512(result_abs_i16x32, sign_i16x32);
}
77
+
78
/** @brief Convert 32x e5m2 → 32x bf16 via arithmetic + 4-entry subnormal LUT (AVX-512BW).
 *  E5M2 format: S EEEEE MM (bias=15). BF16: S EEEEEEEE MMMMMMM (bias=127).
 *  Normal values (exp != 0): BF16 = sign | ((lower7 << 5) + 0x3800).
 *  Subnormals (exp == 0, 4 values): looked up from 4-entry LUT via permutexvar.
 *  Memory: 8 bytes (4 × 16-bit entries) vs 256 bytes (128-entry LUT). OCP FP8 v1.0.
 *  NOTE(review): E5M2 Inf (0x7C) and NaN (0x7D-0x7F) take the normal path and map to
 *  finite BF16 values (Inf → 0x4780 = 65536.0), not BF16 Inf/NaN — confirm intended. */
NK_INTERNAL __m512i nk_e5m2x32_to_bf16x32_icelake_(__m256i e5m2x32) {
    // Zero-extend each byte to 16 bits so all arithmetic happens in epi16 lanes.
    __m512i e5m2_i16x32 = _mm512_cvtepu8_epi16(e5m2x32);
    __m512i sign_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16((short)0x80));
    __m512i lower7_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16(0x7F));

    // Normal path: BF16 = ((lower7 << 5) + 0x3800) | (sign << 8)
    // Formula: E5M2 exp=e, mant=m → BF16 exp = e+112 (bias 15→127), mant = m<<5;
    // adding 112<<7 = 0x3800 rebiases the exponent in one step.
    __m512i normal_abs_i16x32 = _mm512_add_epi16(_mm512_slli_epi16(lower7_i16x32, 5), _mm512_set1_epi16(0x3800));

    // Subnormal LUT: E5M2 subnormals are mant × 2^(-16), i.e. 0, 1/65536, 2/65536, 3/65536.
    // Each row below is 8 entries (indices 7..0 left-to-right); only indices 0-3 are ever
    // selected, so entries 4-7 and the repeated rows are inert padding.
    __m512i subn_lut_i16x32 = _mm512_set_epi16( //
        0x0000, 0x0000, 0x0000, 0x0000, 0x3840, 0x3800, 0x3780, 0x0000, // idx 7-0 (pad, 3/65536, 2/65536, 1/65536, 0)
        0x0000, 0x0000, 0x0000, 0x0000, 0x3840, 0x3800, 0x3780, 0x0000, // repeat
        0x0000, 0x0000, 0x0000, 0x0000, 0x3840, 0x3800, 0x3780, 0x0000, // repeat
        0x0000, 0x0000, 0x0000, 0x0000, 0x3840, 0x3800, 0x3780, 0x0000); // repeat

    // Lookup subnormals via permutexvar (use lower 2 bits of mantissa as index)
    __m512i mant_idx_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16(0x03));
    __m512i subnorm_abs_i16x32 = _mm512_permutexvar_epi16(mant_idx_i16x32, subn_lut_i16x32);

    // Blend: if exponent == 0 (E5M2 exponent field lives in bits [6:2], mask 0x7C),
    // use the subnormal lookup; else use the arithmetic normal path
    __m512i exp_bits_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16(0x7C));
    __mmask32 is_subnormal = _mm512_cmpeq_epi16_mask(exp_bits_i16x32, _mm512_setzero_si512());
    __m512i result_abs_i16x32 = _mm512_mask_blend_epi16(is_subnormal, normal_abs_i16x32, subnorm_abs_i16x32);

    // Apply sign: shift E5M2 bit 7 to BF16 bit 15
    sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
    return _mm512_or_si512(result_abs_i16x32, sign_i16x32);
}
113
+
114
+ /** @brief Convert 32x e2m3 → 32x bf16 via 32-entry LUT lookup (AVX-512BW).
115
+ * E2M3 format: S EE MMM (bias=1, 6 bits total: sign at bit 5, magnitude bits 4-0).
116
+ * BF16: S EEEEEEEE MMMMMMM (bias=127). Uses single permutexvar; sign handled separately.
117
+ * Subnormals (exp=0): value = mant/8. OCP Microscaling Formats v1.0. */
118
+ NK_INTERNAL __m512i nk_e2m3x32_to_bf16x32_icelake_(__m256i e2m3x32) {
119
+ __m512i e2m3_i16x32 = _mm512_cvtepu8_epi16(e2m3x32);
120
+ __m512i sign_i16x32 = _mm512_and_si512(e2m3_i16x32, _mm512_set1_epi16(0x20)); // E2M3 sign at bit 5
121
+ __m512i idx_i16x32 = _mm512_and_si512(e2m3_i16x32, _mm512_set1_epi16(0x1F));
122
+
123
+ // 32-entry LUT for E2M3 magnitude (5 bits: bits [4:3]=exp, bits [2:0]=mant)
124
+ // E2M3: bias=1, range [0, 7.5] for positive, subnormals = mant/8 (OCP MX v1.0)
125
+ // BF16 = (bf16_exp << 7) | (bf16_mant), where bf16_exp = e2m3_exp + 126, bf16_mant = e2m3_mant << 4
126
+ __m512i const lut_i16x32 = _mm512_set_epi16( //
127
+ 0x40F0, 0x40E0, 0x40D0, 0x40C0, 0x40B0, 0x40A0, 0x4090, 0x4080, // [31-24] exp=3: bf16_exp=129
128
+ 0x4070, 0x4060, 0x4050, 0x4040, 0x4030, 0x4020, 0x4010, 0x4000, // [23-16] exp=2: bf16_exp=128
129
+ 0x3FF0, 0x3FE0, 0x3FD0, 0x3FC0, 0x3FB0, 0x3FA0, 0x3F90, 0x3F80, // [15-8] exp=1: bf16_exp=127
130
+ 0x3F60, 0x3F40, 0x3F20, 0x3F00, 0x3EC0, 0x3E80, 0x3E00, 0x0000); // [7-0] exp=0: subnormals 7/8..1/8, 0
131
+
132
+ // Single permutexvar for 32-entry lookup
133
+ __m512i result_i16x32 = _mm512_permutexvar_epi16(idx_i16x32, lut_i16x32);
134
+
135
+ // Apply sign: shift E2M3 bit 5 to BF16 bit 15, then OR
136
+ sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 10);
137
+ return _mm512_or_si512(result_i16x32, sign_i16x32);
138
+ }
139
+
140
/** @brief Convert 32x e3m2 → 32x bf16 via 32-entry LUT lookup (AVX-512BW).
 *  E3M2 format: S EEE MM (bias=3, 6 bits total: sign at bit 5, magnitude bits 4-0).
 *  BF16: S EEEEEEEE MMMMMMM (bias=127). Uses single permutexvar; sign handled separately.
 *  Subnormals (exp=0): value = mant/16. OCP Microscaling Formats v1.0. */
NK_INTERNAL __m512i nk_e3m2x32_to_bf16x32_icelake_(__m256i e3m2x32) {
    // Zero-extend each 6-bit-in-a-byte input to 16 bits, then split sign and magnitude.
    __m512i e3m2_i16x32 = _mm512_cvtepu8_epi16(e3m2x32);
    __m512i sign_i16x32 = _mm512_and_si512(e3m2_i16x32, _mm512_set1_epi16(0x20)); // E3M2 sign at bit 5
    __m512i idx_i16x32 = _mm512_and_si512(e3m2_i16x32, _mm512_set1_epi16(0x1F));

    // 32-entry LUT for E3M2 magnitude (5 bits: bits [4:2]=exp, bits [1:0]=mant)
    // E3M2: bias=3, range [0, 28] for positive, subnormals = mant/16 (OCP Microscaling v1.0)
    // BF16 = (bf16_exp << 7) | (bf16_mant), where bf16_exp = e3m2_exp + 124, bf16_mant = e3m2_mant << 5
    __m512i const lut_i16x32 = _mm512_set_epi16( //
        0x41E0, 0x41C0, 0x41A0, 0x4180, // [31-28] exp=7, mant=3-0: bf16_exp=131
        0x4160, 0x4140, 0x4120, 0x4100, // [27-24] exp=6, mant=3-0: bf16_exp=130
        0x40E0, 0x40C0, 0x40A0, 0x4080, // [23-20] exp=5, mant=3-0: bf16_exp=129
        0x4060, 0x4040, 0x4020, 0x4000, // [19-16] exp=4, mant=3-0: bf16_exp=128
        0x3FE0, 0x3FC0, 0x3FA0, 0x3F80, // [15-12] exp=3, mant=3-0: bf16_exp=127
        0x3F60, 0x3F40, 0x3F20, 0x3F00, // [11-8] exp=2, mant=3-0: bf16_exp=126
        0x3EE0, 0x3EC0, 0x3EA0, 0x3E80, // [7-4] exp=1, mant=3-0: bf16_exp=125
        0x3E40, 0x3E00, 0x3D80, 0x0000); // [3-0] exp=0: subnormals 3/16, 2/16, 1/16, 0

    // Single permutexvar for 32-entry lookup
    __m512i result_i16x32 = _mm512_permutexvar_epi16(idx_i16x32, lut_i16x32);

    // Apply sign: shift E3M2 bit 5 to BF16 bit 15, then OR
    sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 10);
    return _mm512_or_si512(result_i16x32, sign_i16x32);
}
168
+
169
/** @brief Convert 32x e4m3 → 32x f16 via 128-entry LUT lookup (AVX-512BW).
 *  E4M3 format: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
 *  Uses permutex2var for fast LUT lookup; sign handled separately via shift+OR.
 *  Handles all corner cases: zero, subnormals, normals, and NaN (idx 127 → 0x7E00). */
NK_INTERNAL __m512i nk_e4m3x32_to_f16x32_icelake_(__m256i e4m3x32) {
    // Zero-extend each byte to 16 bits, then split the sign from the 7-bit magnitude index.
    __m512i e4m3_i16x32 = _mm512_cvtepu8_epi16(e4m3x32);
    __m512i sign_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16((short)0x80));
    __m512i idx_i16x32 = _mm512_and_si512(e4m3_i16x32, _mm512_set1_epi16(0x7F));

    // 128-entry LUT for E4M3 absolute values to F16, split into 4x32 chunks
    // Subnormals (idx 0-7): 0, 1/512, ..., 7/512 mapped to F16
    // Normals (idx 8-126): F16 = (lower7 << 7) + 0x2000 (exponent rebias 7→15)
    // NaN (idx 127): 0x7E00
    __m512i const lut0_i16x32 = _mm512_set_epi16( // indices 0-31
        0x2F80, 0x2F00, 0x2E80, 0x2E00, 0x2D80, 0x2D00, 0x2C80, 0x2C00, // idx 31-24
        0x2B80, 0x2B00, 0x2A80, 0x2A00, 0x2980, 0x2900, 0x2880, 0x2800, // idx 23-16
        0x2780, 0x2700, 0x2680, 0x2600, 0x2580, 0x2500, 0x2480, 0x2400, // idx 15-8
        0x2300, 0x2200, 0x2100, 0x2000, 0x1E00, 0x1C00, 0x1800, 0x0000); // idx 7-0
    __m512i const lut1_i16x32 = _mm512_set_epi16( // indices 32-63
        0x3F80, 0x3F00, 0x3E80, 0x3E00, 0x3D80, 0x3D00, 0x3C80, 0x3C00, // idx 63-56
        0x3B80, 0x3B00, 0x3A80, 0x3A00, 0x3980, 0x3900, 0x3880, 0x3800, // idx 55-48
        0x3780, 0x3700, 0x3680, 0x3600, 0x3580, 0x3500, 0x3480, 0x3400, // idx 47-40
        0x3380, 0x3300, 0x3280, 0x3200, 0x3180, 0x3100, 0x3080, 0x3000); // idx 39-32
    __m512i const lut2_i16x32 = _mm512_set_epi16( // indices 64-95
        0x4F80, 0x4F00, 0x4E80, 0x4E00, 0x4D80, 0x4D00, 0x4C80, 0x4C00, // idx 95-88
        0x4B80, 0x4B00, 0x4A80, 0x4A00, 0x4980, 0x4900, 0x4880, 0x4800, // idx 87-80
        0x4780, 0x4700, 0x4680, 0x4600, 0x4580, 0x4500, 0x4480, 0x4400, // idx 79-72
        0x4380, 0x4300, 0x4280, 0x4200, 0x4180, 0x4100, 0x4080, 0x4000); // idx 71-64
    __m512i const lut3_i16x32 = _mm512_set_epi16( // indices 96-127
        0x7E00, 0x5F00, 0x5E80, 0x5E00, 0x5D80, 0x5D00, 0x5C80, 0x5C00, // idx 127-120 (127 = NaN)
        0x5B80, 0x5B00, 0x5A80, 0x5A00, 0x5980, 0x5900, 0x5880, 0x5800, // idx 119-112
        0x5780, 0x5700, 0x5680, 0x5600, 0x5580, 0x5500, 0x5480, 0x5400, // idx 111-104
        0x5380, 0x5300, 0x5280, 0x5200, 0x5180, 0x5100, 0x5080, 0x5000); // idx 103-96

    // 2x permutex2var for 64-entry lookup each: VPERMI2W consumes the low 6 index
    // bits (bit 5 selects between the two table operands), so each call resolves
    // indices 0-63 within its pair of tables.
    __m512i result_low_i16x32 = _mm512_permutex2var_epi16(lut0_i16x32, idx_i16x32, lut1_i16x32);
    __m512i result_high_i16x32 = _mm512_permutex2var_epi16(lut2_i16x32, idx_i16x32, lut3_i16x32);

    // Select between low (idx 0-63) and high (idx 64-127) based on bit 6
    __mmask32 use_high_mask = _mm512_test_epi16_mask(idx_i16x32, _mm512_set1_epi16(0x40));
    __m512i result_i16x32 = _mm512_mask_mov_epi16(result_low_i16x32, use_high_mask, result_high_i16x32);

    // Apply sign: shift sign bit to bit 15, then OR
    sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
    return _mm512_or_si512(result_i16x32, sign_i16x32);
}
215
+
216
+ /** @brief Convert 32x e5m2 → 32x f16 via simple bit shift (AVX-512BW).
217
+ * E5M2 format: S EEEEE MM (bias=15). F16: S EEEEE MMMMMMMMMM (bias=15).
218
+ * Same exponent bias means F16 = (lower7 << 8) | (sign << 15).
219
+ * Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
220
+ NK_INTERNAL __m512i nk_e5m2x32_to_f16x32_icelake_(__m256i e5m2x32) {
221
+ __m512i e5m2_i16x32 = _mm512_cvtepu8_epi16(e5m2x32);
222
+ __m512i sign_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16((short)0x80));
223
+ __m512i lower7_i16x32 = _mm512_and_si512(e5m2_i16x32, _mm512_set1_epi16(0x7F));
224
+
225
+ // F16 = (lower7 << 8) | (sign << 15)
226
+ // Works for all cases: subnormals, normals, infinity, and NaN
227
+ __m512i result_i16x32 = _mm512_slli_epi16(lower7_i16x32, 8);
228
+ sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
229
+ return _mm512_or_si512(result_i16x32, sign_i16x32);
230
+ }
231
+
232
+ /** @brief Convert 32x bf16 → 32x e4m3 via bit manipulation (AVX-512BW).
233
+ * BF16: S EEEEEEEE MMMMMMM (bias=127). E4M3: S EEEE MMM (bias=7).
234
+ * Handles normal, subnormal, and overflow cases with RNE rounding. */
235
NK_INTERNAL __m256i nk_bf16x32_to_e4m3x32_icelake_(__m512i bf16x32) {
    // Split bf16 fields: sign in bit 15, biased 8-bit exponent in bits 14..7.
    __m512i sign_i16x32 = _mm512_srli_epi16(bf16x32, 15);
    __m512i bf16_exp_i16x32 = _mm512_and_si512(_mm512_srli_epi16(bf16x32, 7), _mm512_set1_epi16(0xFF));

    // Round mantissa from 7 to 3 bits using RNE (round to nearest, ties to even)
    __m512i significand_i16x32 = _mm512_or_si512(_mm512_and_si512(bf16x32, _mm512_set1_epi16(0x7F)),
                                                 _mm512_set1_epi16(0x80)); // Add implicit 1 bit
    // RNE bias for dropping 4 bits: (half − 1) + LSB of the kept mantissa, i.e. 0x07 + lsb.
    __m512i lsb_i16x32 = _mm512_and_si512(_mm512_srli_epi16(significand_i16x32, 4), _mm512_set1_epi16(1));
    __m512i rounding_bias_i16x32 = _mm512_add_epi16(_mm512_set1_epi16(0x07), lsb_i16x32);
    __m512i rounded_sig_i16x32 = _mm512_add_epi16(significand_i16x32, rounding_bias_i16x32);
    __m512i carry_i16x32 = _mm512_srli_epi16(rounded_sig_i16x32, 8); // Carry into exponent if bit 8 set
    __m512i bf16_mantissa_i16x32 = _mm512_and_si512(_mm512_srli_epi16(rounded_sig_i16x32, 4), _mm512_set1_epi16(0x07));
    // If carry, mantissa becomes 0 (we rounded up to next power of 2)
    // NOTE(review): this andnot masks bit 15, which never overlaps the 3-bit mantissa, so it is
    // effectively a no-op; correctness still holds because a carry implies rounded_sig is in
    // [0x100, 0x107], whose bits 6..4 (the kept mantissa field) are already zero.
    bf16_mantissa_i16x32 = _mm512_andnot_si512(_mm512_slli_epi16(carry_i16x32, 15), bf16_mantissa_i16x32);
    // Re-bias exponent: 127 (bf16) − 7 (e4m3) = 120.
    __m512i e4m3_exp_i16x32 = _mm512_sub_epi16(_mm512_add_epi16(bf16_exp_i16x32, carry_i16x32), _mm512_set1_epi16(120));

    // Detect underflow (exp <= 0) and overflow (exp > 15)
    __mmask32 is_subnormal = _mm512_cmpgt_epi16_mask(_mm512_set1_epi16(1), e4m3_exp_i16x32);
    __mmask32 overflow = _mm512_cmpgt_epi16_mask(e4m3_exp_i16x32, _mm512_set1_epi16(15));
    // NOTE(review): bf16 NaN/Inf inputs (exp == 0xFF) also land in the overflow path and saturate
    // to the max finite value 0x7E (±448) instead of producing the e4m3 NaN 0x7F — confirm that
    // this saturating behavior is intended.

    // Normal path: clamp exp to [1,15]
    // e4m3FN quirk: exp=15 with mantissa=7 is NaN (0x7F), so clamp mantissa to 6 when exp=15.
    __m512i clamped_exp_i16x32 = _mm512_max_epi16(e4m3_exp_i16x32, _mm512_set1_epi16(1));
    clamped_exp_i16x32 = _mm512_min_epi16(clamped_exp_i16x32, _mm512_set1_epi16(15));
    __mmask32 is_max_exp = _mm512_cmpeq_epi16_mask(clamped_exp_i16x32, _mm512_set1_epi16(15));
    __m512i max_mantissa_i16x32 = _mm512_mask_blend_epi16(is_max_exp, _mm512_set1_epi16(7), _mm512_set1_epi16(6));
    __m512i normal_mantissa_i16x32 = _mm512_min_epi16(bf16_mantissa_i16x32, max_mantissa_i16x32);
    // Overflow saturates to 0x7E = max finite (exp=15, mantissa=6).
    normal_mantissa_i16x32 = _mm512_mask_blend_epi16(overflow, normal_mantissa_i16x32, _mm512_set1_epi16(0x06));
    __m512i normal_e4m3_i16x32 = _mm512_or_si512(
        _mm512_slli_epi16(sign_i16x32, 7),
        _mm512_or_si512(_mm512_slli_epi16(clamped_exp_i16x32, 3), normal_mantissa_i16x32));

    // Subnormal path: compute via f32 to get correct rounding
    // bf16 to f32 is just left shift by 16
    __m512i bf16_low_i32x16 = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(bf16x32));
    __m512i bf16_high_i32x16 = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(bf16x32, 1));
    __m512 f32_low = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
    __m512 f32_high = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
    __m512 abs_f32_low = _mm512_and_ps(f32_low, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
    __m512 abs_f32_high = _mm512_and_ps(f32_high, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
    // e4m3 subnormal ULP is 2^-9 (min normal 2^-6, 3 mantissa bits), so |x| * 512 expresses the
    // value in ULPs; cvtps_epi32 rounds RNE under the default MXCSR rounding mode.
    __m512 scaled_low = _mm512_mul_ps(abs_f32_low, _mm512_set1_ps(512.0f));
    __m512 scaled_high = _mm512_mul_ps(abs_f32_high, _mm512_set1_ps(512.0f));
    __m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low);
    __m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high);
    __m256i subnorm_mant_low_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_low_i32x16);
    __m256i subnorm_mant_high_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_high_i32x16);
    __m512i subnorm_mantissa_i16x32 = _mm512_inserti64x4(_mm512_castsi256_si512(subnorm_mant_low_i16x16),
                                                         subnorm_mant_high_i16x16, 1);
    // Rounding up to 8 ULPs means the value reached the smallest normal 2^-6 (encoding 0x08).
    __mmask32 promotes_to_normal = _mm512_cmpgt_epi16_mask(subnorm_mantissa_i16x32, _mm512_set1_epi16(7));
    subnorm_mantissa_i16x32 = _mm512_min_epi16(subnorm_mantissa_i16x32, _mm512_set1_epi16(7));
    subnorm_mantissa_i16x32 = _mm512_max_epi16(subnorm_mantissa_i16x32, _mm512_setzero_si512());
    __m512i subnorm_e4m3_i16x32 = _mm512_or_si512(_mm512_slli_epi16(sign_i16x32, 7), subnorm_mantissa_i16x32);
    __m512i first_normal_e4m3_i16x32 = _mm512_or_si512(_mm512_slli_epi16(sign_i16x32, 7), _mm512_set1_epi16(0x08));
    subnorm_e4m3_i16x32 = _mm512_mask_blend_epi16(promotes_to_normal, subnorm_e4m3_i16x32, first_normal_e4m3_i16x32);

    // Blend: use subnormal result when exp <= 0
    __m512i e4m3_i16x32 = _mm512_mask_blend_epi16(is_subnormal, normal_e4m3_i16x32, subnorm_e4m3_i16x32);

    // Pack 32 i16s to 32 unsigned i8s via AVX-512BW
    return _mm512_cvtepi16_epi8(e4m3_i16x32);
}
296
+
297
/** @brief Convert 32x bf16 → 32x e5m2 via bit manipulation (AVX-512BW).
 * BF16: S EEEEEEEE MMMMMMM (bias=127). E5M2: S EEEEE MM (bias=15).
 * Handles normal, subnormal, and overflow cases with RNE rounding. */
NK_INTERNAL __m256i nk_bf16x32_to_e5m2x32_icelake_(__m512i bf16x32) {
    // Split bf16 fields: sign in bit 15, biased 8-bit exponent in bits 14..7.
    __m512i sign_i16x32 = _mm512_srli_epi16(bf16x32, 15);
    __m512i bf16_exp_i16x32 = _mm512_and_si512(_mm512_srli_epi16(bf16x32, 7), _mm512_set1_epi16(0xFF));

    // Round mantissa from 7 to 2 bits using RNE (round to nearest, ties to even)
    __m512i significand_i16x32 = _mm512_or_si512(_mm512_and_si512(bf16x32, _mm512_set1_epi16(0x7F)),
                                                 _mm512_set1_epi16(0x80)); // Add implicit 1 bit
    // RNE bias for dropping 5 bits: (half − 1) + LSB of the kept mantissa, i.e. 0x0F + lsb.
    __m512i lsb_i16x32 = _mm512_and_si512(_mm512_srli_epi16(significand_i16x32, 5), _mm512_set1_epi16(1));
    __m512i rounding_bias_i16x32 = _mm512_add_epi16(_mm512_set1_epi16(0x0F), lsb_i16x32);
    __m512i rounded_sig_i16x32 = _mm512_add_epi16(significand_i16x32, rounding_bias_i16x32);
    __m512i carry_i16x32 = _mm512_srli_epi16(rounded_sig_i16x32, 8); // Carry into exponent if bit 8 set
    __m512i bf16_mantissa_i16x32 = _mm512_and_si512(_mm512_srli_epi16(rounded_sig_i16x32, 5), _mm512_set1_epi16(0x03));
    // If carry, mantissa becomes 0 (we rounded up to next power of 2)
    // NOTE(review): this andnot masks bit 15, which never overlaps the 2-bit mantissa, so it is
    // effectively a no-op; correctness still holds because a carry implies rounded_sig is in
    // [0x100, 0x10F], whose bits 6..5 (the kept mantissa field) are already zero.
    bf16_mantissa_i16x32 = _mm512_andnot_si512(_mm512_slli_epi16(carry_i16x32, 15), bf16_mantissa_i16x32);
    // Re-bias exponent: 127 (bf16) − 15 (e5m2) = 112.
    __m512i e5m2_exp_i16x32 = _mm512_sub_epi16(_mm512_add_epi16(bf16_exp_i16x32, carry_i16x32), _mm512_set1_epi16(112));

    // Detect subnormal (exp <= 0) and overflow (exp > 31)
    __mmask32 is_subnormal = _mm512_cmpgt_epi16_mask(_mm512_set1_epi16(1), e5m2_exp_i16x32);
    __mmask32 overflow = _mm512_cmpgt_epi16_mask(e5m2_exp_i16x32, _mm512_set1_epi16(31));
    // NOTE(review): bf16 NaN inputs (exp == 0xFF, mantissa != 0) also land in the overflow path
    // and are emitted as infinity (0x7C) rather than an e5m2 NaN — confirm this is intended.

    // Normal path: clamp exp to [1,31], on overflow return infinity (exp=31, mantissa=0 = 0x7C)
    __m512i clamped_exp_i16x32 = _mm512_max_epi16(e5m2_exp_i16x32, _mm512_set1_epi16(1));
    clamped_exp_i16x32 = _mm512_min_epi16(clamped_exp_i16x32, _mm512_set1_epi16(31));
    __m512i normal_mantissa_i16x32 = _mm512_mask_blend_epi16(overflow, bf16_mantissa_i16x32, _mm512_setzero_si512());
    __m512i normal_e5m2_i16x32 = _mm512_or_si512(
        _mm512_slli_epi16(sign_i16x32, 7),
        _mm512_or_si512(_mm512_slli_epi16(clamped_exp_i16x32, 2), normal_mantissa_i16x32));

    // Subnormal path: compute via f32 to get correct rounding
    // bf16 → f32 is just a left shift by 16.
    __m512i bf16_low_i32x16 = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(bf16x32));
    __m512i bf16_high_i32x16 = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(bf16x32, 1));
    __m512 f32_low = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
    __m512 f32_high = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
    __m512 abs_f32_low = _mm512_and_ps(f32_low, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
    __m512 abs_f32_high = _mm512_and_ps(f32_high, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
    // e5m2 subnormal ULP is 2^-16 (min normal 2^-14, 2 mantissa bits), so |x| * 65536 expresses
    // the value in ULPs; cvtps_epi32 rounds RNE under the default MXCSR rounding mode.
    __m512 scaled_low = _mm512_mul_ps(abs_f32_low, _mm512_set1_ps(65536.0f));
    __m512 scaled_high = _mm512_mul_ps(abs_f32_high, _mm512_set1_ps(65536.0f));
    __m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low);
    __m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high);
    __m256i subnorm_mant_low_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_low_i32x16);
    __m256i subnorm_mant_high_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_high_i32x16);
    __m512i subnorm_mantissa_i16x32 = _mm512_inserti64x4(_mm512_castsi256_si512(subnorm_mant_low_i16x16),
                                                         subnorm_mant_high_i16x16, 1);
    // Rounding up to 4 ULPs means the value reached the smallest normal 2^-14 (encoding 0x04).
    __mmask32 promotes_to_normal = _mm512_cmpgt_epi16_mask(subnorm_mantissa_i16x32, _mm512_set1_epi16(3));
    subnorm_mantissa_i16x32 = _mm512_min_epi16(subnorm_mantissa_i16x32, _mm512_set1_epi16(3));
    subnorm_mantissa_i16x32 = _mm512_max_epi16(subnorm_mantissa_i16x32, _mm512_setzero_si512());
    __m512i subnorm_e5m2_i16x32 = _mm512_or_si512(_mm512_slli_epi16(sign_i16x32, 7), subnorm_mantissa_i16x32);
    __m512i first_normal_e5m2_i16x32 = _mm512_or_si512(_mm512_slli_epi16(sign_i16x32, 7), _mm512_set1_epi16(0x04));
    subnorm_e5m2_i16x32 = _mm512_mask_blend_epi16(promotes_to_normal, subnorm_e5m2_i16x32, first_normal_e5m2_i16x32);

    // Blend: use subnormal result when exp <= 0
    __m512i e5m2_i16x32 = _mm512_mask_blend_epi16(is_subnormal, normal_e5m2_i16x32, subnorm_e5m2_i16x32);

    // Pack 32 i16s to 32 unsigned i8s via AVX-512BW
    return _mm512_cvtepi16_epi8(e5m2_i16x32);
}
356
+
357
+ /** @brief Load 32x e4m3 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
358
+ NK_INTERNAL void nk_load_e4m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst) {
359
+ dst->zmm = nk_e4m3x32_to_bf16x32_icelake_(_mm256_loadu_si256((__m256i const *)src));
360
+ }
361
+
362
+ /** @brief Partial load n e4m3 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
363
+ NK_INTERNAL void nk_partial_load_e4m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
364
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
365
+ __m256i e4m3_partial = _mm256_maskz_loadu_epi8(mask, src);
366
+ dst->zmm = nk_e4m3x32_to_bf16x32_icelake_(e4m3_partial);
367
+ }
368
+
369
+ /** @brief Load 32x e5m2 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
370
+ NK_INTERNAL void nk_load_e5m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst) {
371
+ dst->zmm = nk_e5m2x32_to_bf16x32_icelake_(_mm256_loadu_si256((__m256i const *)src));
372
+ }
373
+
374
+ /** @brief Partial load n e5m2 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
375
+ NK_INTERNAL void nk_partial_load_e5m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
376
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
377
+ __m256i e5m2_partial = _mm256_maskz_loadu_epi8(mask, src);
378
+ dst->zmm = nk_e5m2x32_to_bf16x32_icelake_(e5m2_partial);
379
+ }
380
+
381
+ /** @brief Load 32x e2m3 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
382
+ NK_INTERNAL void nk_load_e2m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst) {
383
+ dst->zmm = nk_e2m3x32_to_bf16x32_icelake_(_mm256_loadu_si256((__m256i const *)src));
384
+ }
385
+
386
+ /** @brief Partial load n e2m3 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
387
+ NK_INTERNAL void nk_partial_load_e2m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
388
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
389
+ __m256i e2m3_partial = _mm256_maskz_loadu_epi8(mask, src);
390
+ dst->zmm = nk_e2m3x32_to_bf16x32_icelake_(e2m3_partial);
391
+ }
392
+
393
+ /** @brief Load 32x e3m2 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
394
+ NK_INTERNAL void nk_load_e3m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst) {
395
+ dst->zmm = nk_e3m2x32_to_bf16x32_icelake_(_mm256_loadu_si256((__m256i const *)src));
396
+ }
397
+
398
+ /** @brief Partial load n e3m2 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
399
+ NK_INTERNAL void nk_partial_load_e3m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
400
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
401
+ __m256i e3m2_partial = _mm256_maskz_loadu_epi8(mask, src);
402
+ dst->zmm = nk_e3m2x32_to_bf16x32_icelake_(e3m2_partial);
403
+ }
404
+
405
+ #pragma endregion - Vectorized Conversions
406
+
407
+ #pragma region - Public API
408
+
409
+ NK_PUBLIC void nk_cast_icelake(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
410
+ // Group 1: Conversions to bf16 (e4m3 → bf16, e5m2 → bf16)
411
+ if (to_type == nk_bf16_k && (from_type == nk_e4m3_k || from_type == nk_e5m2_k)) {
412
+ nk_e4m3_t const *from_ptr = (nk_e4m3_t const *)from;
413
+ nk_bf16_t *to_ptr = (nk_bf16_t *)to;
414
+ for (nk_size_t idx = 0; idx < n; idx += 32) {
415
+ nk_size_t remaining = n - idx;
416
+ __mmask32 mask = (remaining >= 32) ? 0xFFFFFFFF : _bzhi_u32(0xFFFFFFFF, (unsigned)remaining);
417
+ __m256i in_f8x32 = _mm256_maskz_loadu_epi8(mask, from_ptr + idx);
418
+ __m512i out_bf16x32 = (from_type == nk_e4m3_k) ? nk_e4m3x32_to_bf16x32_icelake_(in_f8x32)
419
+ : nk_e5m2x32_to_bf16x32_icelake_(in_f8x32);
420
+ _mm512_mask_storeu_epi16(to_ptr + idx, mask, out_bf16x32);
421
+ }
422
+ }
423
+
424
+ // Group 2: Conversions FROM bf16 (bf16 → e4m3, bf16 → e5m2)
425
+ else if (from_type == nk_bf16_k && (to_type == nk_e4m3_k || to_type == nk_e5m2_k)) {
426
+ nk_bf16_t const *from_ptr = (nk_bf16_t const *)from;
427
+ nk_e4m3_t *to_ptr = (nk_e4m3_t *)to;
428
+ for (nk_size_t idx = 0; idx < n; idx += 32) {
429
+ nk_size_t remaining = n - idx;
430
+ __mmask32 mask = (remaining >= 32) ? 0xFFFFFFFF : _bzhi_u32(0xFFFFFFFF, (unsigned)remaining);
431
+ __m512i in_bf16x32 = _mm512_maskz_loadu_epi16(mask, from_ptr + idx);
432
+ __m256i out_f8x32 = (to_type == nk_e4m3_k) ? nk_bf16x32_to_e4m3x32_icelake_(in_bf16x32)
433
+ : nk_bf16x32_to_e5m2x32_icelake_(in_bf16x32);
434
+ _mm256_mask_storeu_epi8(to_ptr + idx, mask, out_f8x32);
435
+ }
436
+ }
437
+
438
+ // Group 3: Conversions to f16 (e4m3 → f16, e5m2 → f16)
439
+ else if (to_type == nk_f16_k && (from_type == nk_e4m3_k || from_type == nk_e5m2_k)) {
440
+ nk_e4m3_t const *from_ptr = (nk_e4m3_t const *)from;
441
+ nk_f16_t *to_ptr = (nk_f16_t *)to;
442
+ for (nk_size_t idx = 0; idx < n; idx += 32) {
443
+ nk_size_t remaining = n - idx;
444
+ __mmask32 mask = (remaining >= 32) ? 0xFFFFFFFF : _bzhi_u32(0xFFFFFFFF, (unsigned)remaining);
445
+ __m256i in_f8x32 = _mm256_maskz_loadu_epi8(mask, from_ptr + idx);
446
+ __m512i out_f16x32 = (from_type == nk_e4m3_k) ? nk_e4m3x32_to_f16x32_icelake_(in_f8x32)
447
+ : nk_e5m2x32_to_f16x32_icelake_(in_f8x32);
448
+ _mm512_mask_storeu_epi16(to_ptr + idx, mask, out_f16x32);
449
+ }
450
+ }
451
+
452
+ // Default: delegate to Skylake for all other conversions
453
+ else nk_cast_skylake(from, from_type, n, to, to_type);
454
+ }
455
+
456
+ #pragma endregion - Public API
457
+
458
+ #if defined(__clang__)
459
+ #pragma clang attribute pop
460
+ #elif defined(__GNUC__)
461
+ #pragma GCC pop_options
462
+ #endif
463
+
464
+ #if defined(__cplusplus)
465
+ } // extern "C"
466
+ #endif
467
+
468
+ #endif // NK_TARGET_ICELAKE
469
+ #endif // NK_TARGET_X86_
470
+ #endif // NK_CAST_ICELAKE_H