numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,856 @@
+ /**
+  * @brief SIMD-accelerated Type Conversions for Skylake.
+  * @file include/numkong/cast/skylake.h
+  * @author Ash Vardanian
+  * @date December 27, 2025
+  *
+  * @sa include/numkong/cast.h
+  *
+  * @section skylake_cast_instructions AVX-512 Conversion Instructions
+  *
+  *     Intrinsic               Instruction                   SKL         ICL         Genoa
+  *     _mm512_cvtph_ps         VCVTPH2PS (ZMM, YMM)          5cy @ p05   5cy @ p05   4cy @ p01
+  *     _mm512_cvtps_ph         VCVTPS2PH (YMM, ZMM, imm)     5cy @ p05   5cy @ p05   4cy @ p01
+  *     _mm512_cvtps_epi32      VCVTPS2DQ (ZMM, ZMM)          4cy @ p0    4cy @ p0    3cy @ p01
+  *     _mm512_cvtepi32_ps      VCVTDQ2PS (ZMM, ZMM)          4cy @ p0    4cy @ p0    3cy @ p01
+  *     _mm512_cvtepi32_epi16   VPMOVDW (YMM, ZMM)            3cy @ p5    3cy @ p5    2cy @ p12
+  *     _mm512_cvtsepi32_epi8   VPMOVSDB (XMM, ZMM)           3cy @ p5    3cy @ p5    2cy @ p12
+  *
+  * F16 conversions use hardware F16C via VCVTPH2PS/VCVTPS2PH. BF16 lacks hardware support on Skylake,
+  * requiring emulation via VPMOVZXWD + VPSLLD for bf16-to-f32, achieving ~4cy total. FP8 (E4M3/E5M2)
+  * conversions use bit manipulation with VPTERNLOGD for sign/exp/mantissa composition.
+  */
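For readers new to the emulation the header describes, here is a minimal scalar sketch of the same BF16 arithmetic (our own illustration, not part of the package; NaN payloads get no special treatment here):

    #include <stdint.h>
    #include <string.h>

    /* bf16 -> f32: the 16 stored bits become the upper half of the f32. */
    static inline float bf16_bits_to_f32(uint16_t bf16_bits) {
        uint32_t f32_bits = (uint32_t)bf16_bits << 16;
        float result;
        memcpy(&result, &f32_bits, sizeof result);
        return result;
    }

    /* f32 -> bf16 with round-to-nearest-even: add (0x7FFF + lsb), then truncate. */
    static inline uint16_t f32_to_bf16_bits(float value) {
        uint32_t bits;
        memcpy(&bits, &value, sizeof bits);
        uint32_t lsb = (bits >> 16) & 1; /* makes ties round to the even result */
        return (uint16_t)((bits + 0x7FFF + lsb) >> 16);
    }

The vectorized nk_bf16x16_to_f32x16_skylake_ and nk_f32x16_to_bf16x16_skylake_ further down apply exactly this arithmetic across 16 lanes.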
+ #ifndef NK_CAST_SKYLAKE_H
+ #define NK_CAST_SKYLAKE_H
+
+ #if NK_TARGET_X86_
+ #if NK_TARGET_SKYLAKE
+
+ #include "numkong/types.h"
+ #include "numkong/cast/serial.h" // `nk_dtype_bits`
+
+ #if defined(__cplusplus)
+ extern "C" {
+ #endif
+
+ #if defined(__clang__)
+ #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
+                              apply_to = function)
+ #elif defined(__GNUC__)
+ #pragma GCC push_options
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
+ #endif
+
+ #pragma region - Type Punned Loads and Stores
+
+ /** @brief Type-agnostic 512-bit full load (Skylake AVX-512). */
+ NK_INTERNAL void nk_load_b512_skylake_(void const *src, nk_b512_vec_t *dst) { dst->zmm = _mm512_loadu_si512(src); }
+
+ /** @brief Type-agnostic partial load for 64-bit elements (8 elements max) into 512-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_b64x8_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     __mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)n);
+     dst->zmm = _mm512_maskz_loadu_epi64(mask, src);
+ }
+
+ /** @brief Type-agnostic partial load for 32-bit elements (16 elements max) into 512-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_b32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
+     dst->zmm = _mm512_maskz_loadu_epi32(mask, src);
+ }
+
+ /** @brief Type-agnostic partial load for 16-bit elements (32 elements max) into 512-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_b16x32_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
+     dst->zmm = _mm512_maskz_loadu_epi16(mask, src);
+ }
+
+ /** @brief Partial load for 8-bit elements (64 max) into 512-bit vector (zeros in remaining slots). */
+ NK_INTERNAL void nk_partial_load_b8x64_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     __mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n);
+     dst->zmm = _mm512_maskz_loadu_epi8(mask, src);
+ }
+
+ /** @brief Partial load for 4-bit nibbles (128 max = 64 bytes) into 512-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_b4x128_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
+     __mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n_bytes);
+     dst->zmm = _mm512_maskz_loadu_epi8(mask, src);
+ }
+
+ /** @brief Type-agnostic partial load for 32-bit elements (8 elements max) into 256-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_b32x8_skylake_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
+     __mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)n);
+     dst->ymm = _mm256_maskz_loadu_epi32(mask, src);
+ }
+
+ /** @brief Type-agnostic partial load for 16-bit elements (16 elements max) into 256-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_b16x16_skylake_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
+     __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
+     dst->ymm = _mm256_maskz_loadu_epi16(mask, src);
+ }
+
+ /** @brief Type-agnostic partial load for 8-bit elements (16 elements max) into 128-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_b8x16_skylake_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
+     __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
+     dst->xmm = _mm_maskz_loadu_epi8(mask, src);
+ }
+
+ /** @brief Partial load for 1-bit elements (512 max bits = 64 bytes) into 512-bit vector (Skylake AVX-512).
+  *  Wrapper that converts bit count to byte count and delegates to byte-level masked load. */
+ NK_INTERNAL void nk_partial_load_b1x512_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n_bits) {
+     nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, 8);
+     nk_partial_load_b8x64_skylake_(src, dst, n_bytes);
+ }
+
+ /** @brief Type-agnostic partial load for 32-bit elements (4 elements max) into 128-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_b32x4_skylake_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
+     __mmask8 mask = (__mmask8)_bzhi_u32(0xF, (unsigned int)n);
+     dst->xmm = _mm_maskz_loadu_epi32(mask, src);
+ }
+
+ /** @brief Type-agnostic partial load for 64-bit elements (4 elements max) into 256-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_b64x4_skylake_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
+     __mmask8 mask = (__mmask8)_bzhi_u32(0xF, (unsigned int)n);
+     dst->ymm = _mm256_maskz_loadu_epi64(mask, src);
+ }
+
+ /** @brief Type-agnostic partial store for 32-bit elements (16 elements max) from 512-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_store_b32x16_skylake_(nk_b512_vec_t const *src, void *dst, nk_size_t n) {
+     __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)n);
+     _mm512_mask_storeu_epi32(dst, mask, src->zmm);
+ }
+
+ /** @brief Type-agnostic partial store for 32-bit elements (4 elements max) from 128-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_store_b32x4_skylake_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
+     __mmask8 mask = (__mmask8)_bzhi_u32(0xF, (unsigned int)n);
+     _mm_mask_storeu_epi32(dst, mask, src->xmm);
+ }
+
+ /** @brief Type-agnostic partial store for 64-bit elements (4 elements max) from 256-bit vector (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_store_b64x4_skylake_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
+     __mmask8 mask = (__mmask8)_bzhi_u32(0xF, (unsigned int)n);
+     _mm256_mask_storeu_epi64(dst, mask, src->ymm);
+ }
+
+ #pragma endregion - Type Punned Loads and Stores
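A note on the masking idiom used throughout this region: _bzhi_u32(m, n) clears all bits of m at index n and above, so _bzhi_u32(0xFFFF, n) yields a mask with exactly the n lowest lanes set, and the maskz loads zero every excluded lane. A hypothetical caller loading a 10-element f32 tail (the buffer and count are our own example values):

    float tail[10] = {0};
    nk_b512_vec_t vec;
    /* mask = _bzhi_u32(0xFFFF, 10) = 0x03FF: lanes 0..9 loaded, lanes 10..15 zeroed */
    nk_partial_load_b32x16_skylake_(tail, &vec, 10);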
+
+ #pragma region - Vectorized Conversions
+
+ /** @brief Convert 16x bf16 → 16x f32 (Skylake AVX-512). */
+ NK_INTERNAL __m512 nk_bf16x16_to_f32x16_skylake_(__m256i a) {
+     // Upcasting from `bf16` to `f32` just shifts the `bf16` bits 16 positions to the left:
+     return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
+ }
+
+ /** @brief Convert 16x f32 → 16x bf16 (Skylake AVX-512). */
+ NK_INTERNAL __m256i nk_f32x16_to_bf16x16_skylake_(__m512 a) {
+     // Round-to-nearest-even: add (0x7FFF + lsb) to match hardware BF16 behavior
+     __m512i bits = _mm512_castps_si512(a);
+     __m512i lsb = _mm512_and_si512(_mm512_srli_epi32(bits, 16), _mm512_set1_epi32(1));
+     __m512i rounded = _mm512_add_epi32(bits, _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), lsb));
+     __m512i x = _mm512_srli_epi32(rounded, 16);
+     return _mm512_cvtepi32_epi16(x);
+ }
+
+ /** @brief Convert 16x e4m3 → 16x f32 via bit manipulation (AVX-512).
+  *  E4M3 format: S EEEE MMM (bias=7). F32: sign<<31, (exp+120)<<23, mantissa<<20.
+  *  Subnormals (exp=0): value = mantissa × 2⁽¹⁻⁷⁾ × 2⁻³ = mantissa ÷ 512. */
+ NK_INTERNAL __m512 nk_e4m3x16_to_f32x16_skylake_(__m128i e4m3_i8x16) {
+     __m512i e4m3_i32x16 = _mm512_cvtepu8_epi32(e4m3_i8x16);
+
+     // Extract fields
+     __m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e4m3_i32x16, 3), _mm512_set1_epi32(0x0F));
+     __m512i mantissa_i32x16 = _mm512_and_si512(e4m3_i32x16, _mm512_set1_epi32(0x07));
+     __m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e4m3_i32x16, 7), 31);
+
+     // Normal path: sign | ((exp+120)<<23) | (mantissa<<20)
+     __m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(120)), 23);
+     __m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 20);
+     __m512 result_f32x16 = _mm512_castsi512_ps(
+         _mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
+
+     // Subnormal fix: for exp==0 lanes, replace with (mantissa / 512) | sign using masked OR
+     __mmask16 is_subnormal = _mm512_testn_epi32_mask(e4m3_i32x16, _mm512_set1_epi32(0x78));
+     __m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 512.0f));
+     result_f32x16 = _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16,
+                                       _mm512_castsi512_ps(sign_i32x16));
+
+     // NaN path: E4M3FN has NaN only when exp=15 AND mant=7 (0x7F or 0xFF)
+     __mmask16 is_nan = _mm512_mask_cmpeq_epi32_mask(                //
+         _mm512_cmpeq_epi32_mask(exp_i32x16, _mm512_set1_epi32(15)), //
+         mantissa_i32x16, _mm512_set1_epi32(7));                     //
+     __m512i nan_bits = _mm512_or_si512(sign_i32x16, _mm512_set1_epi32(0x7FC00000)); // F32 quiet NaN
+     return _mm512_mask_blend_ps(is_nan, result_f32x16, _mm512_castsi512_ps(nan_bits));
+ }
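To make the field arithmetic above concrete, a worked example of our own (not from the package):

    /* Normal E4M3 value 0x48 = 0b0'1001'000: sign=0, exp=9, mant=0, bias=7.
       f32 bits = (9 + 120) << 23 = 0x40800000 = 4.0f, i.e. 1.0 x 2^(9-7).
       Subnormal E4M3 value 0x04 = 0b0'0000'100: exp=0, mant=4,
       value = 4 / 512 = 0.0078125 = 2^-7, matching the masked-OR path. */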
+
+ /** @brief Convert 16x e5m2 → 16x f32 via bit manipulation (AVX-512).
+  *  E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mantissa<<21.
+  *  Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁵⁾ × 2⁻² = mantissa ÷ 65536. */
+ NK_INTERNAL __m512 nk_e5m2x16_to_f32x16_skylake_(__m128i e5m2_i8x16) {
+     __m512i e5m2_i32x16 = _mm512_cvtepu8_epi32(e5m2_i8x16);
+
+     // Extract fields
+     __m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e5m2_i32x16, 2), _mm512_set1_epi32(0x1F));
+     __m512i mantissa_i32x16 = _mm512_and_si512(e5m2_i32x16, _mm512_set1_epi32(0x03));
+     __m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e5m2_i32x16, 7), 31);
+
+     // Normal path: sign | ((exp+112)<<23) | (mantissa<<21)
+     __m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(112)), 23);
+     __m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 21);
+     __m512 result_f32x16 = _mm512_castsi512_ps(
+         _mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
+
+     // Subnormal fix: for exp==0 lanes, replace with (mantissa / 65536) | sign using masked OR
+     __mmask16 is_subnormal = _mm512_testn_epi32_mask(e5m2_i32x16, _mm512_set1_epi32(0x7C));
+     __m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 65536.0f));
+     return _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16, _mm512_castsi512_ps(sign_i32x16));
+ }
+
+ /** @brief Convert 16x e2m3 → 16x f32 via bit manipulation (AVX-512).
+  *  E2M3 format: S EE MMM (bias=1, only 6 bits used). F32: sign<<31, (exp+126)<<23, mantissa<<20.
+  *  Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁾ × 2⁻³ = mantissa ÷ 8. */
+ NK_INTERNAL __m512 nk_e2m3x16_to_f32x16_skylake_(__m128i e2m3_i8x16) {
+     __m512i e2m3_i32x16 = _mm512_cvtepu8_epi32(e2m3_i8x16);
+
+     // Extract fields (only 6 bits used: S EE MMM)
+     __m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e2m3_i32x16, 3), _mm512_set1_epi32(0x03));
+     __m512i mantissa_i32x16 = _mm512_and_si512(e2m3_i32x16, _mm512_set1_epi32(0x07));
+     __m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e2m3_i32x16, 5), 31);
+
+     // Normal path: sign | ((exp+126)<<23) | (mantissa<<20)
+     __m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(126)), 23);
+     __m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 20);
+     __m512 result_f32x16 = _mm512_castsi512_ps(
+         _mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
+
+     // Subnormal fix: for exp==0 lanes, replace with (mantissa / 8) | sign using masked OR
+     __mmask16 is_subnormal = _mm512_testn_epi32_mask(e2m3_i32x16, _mm512_set1_epi32(0x18));
+     __m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 8.0f));
+     return _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16, _mm512_castsi512_ps(sign_i32x16));
+ }
+
+ /** @brief Convert 16x e3m2 → 16x f32 via bit manipulation (AVX-512).
+  *  E3M2 format: S EEE MM (bias=3, only 6 bits used). F32: sign<<31, (exp+124)<<23, mantissa<<21.
+  *  Subnormals (exp=0): value = mantissa × 2⁽¹⁻³⁾ × 2⁻² = mantissa ÷ 16. */
+ NK_INTERNAL __m512 nk_e3m2x16_to_f32x16_skylake_(__m128i e3m2_i8x16) {
+     __m512i e3m2_i32x16 = _mm512_cvtepu8_epi32(e3m2_i8x16);
+
+     // Extract fields (only 6 bits used: S EEE MM)
+     __m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e3m2_i32x16, 2), _mm512_set1_epi32(0x07));
+     __m512i mantissa_i32x16 = _mm512_and_si512(e3m2_i32x16, _mm512_set1_epi32(0x03));
+     __m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e3m2_i32x16, 5), 31);
+
+     // Normal path: sign | ((exp+124)<<23) | (mantissa<<21)
+     __m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(124)), 23);
+     __m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 21);
+     __m512 result_f32x16 = _mm512_castsi512_ps(
+         _mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
+
+     // Subnormal fix: for exp==0 lanes, replace with (mantissa / 16) | sign using masked OR
+     __mmask16 is_subnormal = _mm512_testn_epi32_mask(e3m2_i32x16, _mm512_set1_epi32(0x1C));
+     __m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 16.0f));
+     return _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16, _mm512_castsi512_ps(sign_i32x16));
+ }
+
+ /** @brief Convert 16x f32 → 16x e2m3 via bit manipulation (AVX-512).
+  *  E2M3 format: S EE MMM (bias=1). Handles normal, subnormal, and overflow cases.
+  *  Subnormals (f32_exp ≤ 126): mantissa = round(abs_f32 * 8), clamped to [0,7]. */
+ NK_INTERNAL __m128i nk_f32x16_to_e2m3x16_skylake_(__m512 f32x16) {
+     __m512i bits_i32x16 = _mm512_castps_si512(f32x16);
+     __m512i sign_i32x16 = _mm512_srli_epi32(bits_i32x16, 31);
+     __m512i f32_exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(bits_i32x16, 23), _mm512_set1_epi32(0xFF));
+
+     // Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
+     __m512i significand_i32x16 = _mm512_or_si512(_mm512_and_si512(bits_i32x16, _mm512_set1_epi32(0x007FFFFF)),
+                                                  _mm512_set1_epi32(0x00800000)); // (a & mask) | implicit_one
+     __m512i lsb_i32x16 = _mm512_and_si512(_mm512_srli_epi32(significand_i32x16, 20), _mm512_set1_epi32(1));
+     __m512i rounding_bias_i32x16 = _mm512_add_epi32(_mm512_set1_epi32(0x0007FFFF), lsb_i32x16);
+     __m512i rounded_sig_i32x16 = _mm512_add_epi32(significand_i32x16, rounding_bias_i32x16);
+     __m512i carry_i32x16 = _mm512_srli_epi32(rounded_sig_i32x16, 24); // Carry into exponent if bit 24 set
+     __m512i f32_mantissa_i32x16 = _mm512_and_si512(_mm512_srli_epi32(rounded_sig_i32x16, 20), _mm512_set1_epi32(0x07));
+     // If carry, mantissa becomes 0 (we rounded up to next power of 2)
+     f32_mantissa_i32x16 = _mm512_andnot_si512(_mm512_slli_epi32(carry_i32x16, 31), f32_mantissa_i32x16);
+     __m512i e2m3_exp_i32x16 = _mm512_sub_epi32(_mm512_add_epi32(f32_exp_i32x16, carry_i32x16), _mm512_set1_epi32(126));
+
+     // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 3)
+     __mmask16 is_subnormal = _mm512_cmpgt_epi32_mask(_mm512_set1_epi32(1), e2m3_exp_i32x16);
+     __mmask16 overflow = _mm512_cmpgt_epi32_mask(e2m3_exp_i32x16, _mm512_set1_epi32(3));
+
+     // Normal path: clamp exp to [1,3], extract mantissa bits
+     __m512i clamped_exp_i32x16 = _mm512_max_epi32(e2m3_exp_i32x16, _mm512_set1_epi32(1));
+     clamped_exp_i32x16 = _mm512_min_epi32(clamped_exp_i32x16, _mm512_set1_epi32(3));
+     __m512i normal_mantissa_i32x16 = _mm512_mask_blend_epi32(overflow, f32_mantissa_i32x16, _mm512_set1_epi32(0x07));
+     __m512i normal_e2m3_i32x16 = _mm512_ternarylogic_epi32(_mm512_slli_epi32(sign_i32x16, 5),
+                                                            _mm512_slli_epi32(clamped_exp_i32x16, 3),
+                                                            normal_mantissa_i32x16, 0xFE); // a | b | c
+
+     // Subnormal path: mantissa = round(abs_f32 * 8)
+     // If mantissa rounds to 8 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x08
+     __m512 abs_f32x16 = _mm512_and_ps(f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
+     __m512 scaled_f32x16 = _mm512_mul_ps(abs_f32x16, _mm512_set1_ps(8.0f));
+     __m512i subnorm_mantissa_i32x16 = _mm512_cvtps_epi32(scaled_f32x16);
+     __mmask16 promotes_to_normal = _mm512_cmpgt_epi32_mask(subnorm_mantissa_i32x16, _mm512_set1_epi32(7));
+     subnorm_mantissa_i32x16 = _mm512_min_epi32(subnorm_mantissa_i32x16, _mm512_set1_epi32(7));
+     subnorm_mantissa_i32x16 = _mm512_max_epi32(subnorm_mantissa_i32x16, _mm512_setzero_si512());
+     __m512i subnorm_e2m3_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 5), subnorm_mantissa_i32x16);
+     // When mantissa rounds to 8, use first normal value (0x08) instead of clamped subnormal
+     __m512i first_normal_e2m3_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 5), _mm512_set1_epi32(0x08));
+     subnorm_e2m3_i32x16 = _mm512_mask_blend_epi32(promotes_to_normal, subnorm_e2m3_i32x16, first_normal_e2m3_i32x16);
+
+     // Blend: use subnormal result when exp <= 0, else normal
+     __m512i e2m3_i32x16 = _mm512_mask_blend_epi32(is_subnormal, normal_e2m3_i32x16, subnorm_e2m3_i32x16);
+
+     // Pack 16 i32s to 16 unsigned i8s via AVX-512 cvtepi32_epi8
+     return _mm512_cvtepi32_epi8(e2m3_i32x16);
+ }
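A quick sanity check of the half − 1 + lsb rounding bias used in the downcast above, for the 23→3-bit case (shift = 20, half = 0x80000); the numbers are our own:

    /* Tie, even keep-bits: sig = 0x00880000, lsb = 0, bias = 0x7FFFF;
       sig + bias = 0x008FFFFF, >> 20 gives 0x8 -> mantissa stays 0 (round down to even).
       Tie, odd keep-bits:  sig = 0x00980000, lsb = 1, bias = 0x80000;
       sig + bias = 0x00A00000, >> 20 gives 0xA -> mantissa 2 (round up to even). */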
+
+ /** @brief Convert 16x f32 → 16x e3m2 via bit manipulation (AVX-512).
+  *  E3M2 format: S EEE MM (bias=3). Handles normal, subnormal, and overflow cases.
+  *  Subnormals (f32_exp ≤ 124): mantissa = round(abs_f32 * 16), clamped to [0,3]. */
+ NK_INTERNAL __m128i nk_f32x16_to_e3m2x16_skylake_(__m512 f32x16) {
+     __m512i bits_i32x16 = _mm512_castps_si512(f32x16);
+     __m512i sign_i32x16 = _mm512_srli_epi32(bits_i32x16, 31);
+     __m512i f32_exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(bits_i32x16, 23), _mm512_set1_epi32(0xFF));
+
+     // Round mantissa from 23 to 2 bits using RNE (round to nearest, ties to even)
+     __m512i significand_i32x16 = _mm512_or_si512(_mm512_and_si512(bits_i32x16, _mm512_set1_epi32(0x007FFFFF)),
+                                                  _mm512_set1_epi32(0x00800000)); // (a & mask) | implicit_one
+     __m512i lsb_i32x16 = _mm512_and_si512(_mm512_srli_epi32(significand_i32x16, 21), _mm512_set1_epi32(1));
+     __m512i rounding_bias_i32x16 = _mm512_add_epi32(_mm512_set1_epi32(0x000FFFFF), lsb_i32x16);
+     __m512i rounded_sig_i32x16 = _mm512_add_epi32(significand_i32x16, rounding_bias_i32x16);
+     __m512i carry_i32x16 = _mm512_srli_epi32(rounded_sig_i32x16, 24); // Carry into exponent if bit 24 set
+     __m512i f32_mantissa_i32x16 = _mm512_and_si512(_mm512_srli_epi32(rounded_sig_i32x16, 21), _mm512_set1_epi32(0x03));
+     // If carry, mantissa becomes 0 (we rounded up to next power of 2)
+     f32_mantissa_i32x16 = _mm512_andnot_si512(_mm512_slli_epi32(carry_i32x16, 31), f32_mantissa_i32x16);
+     __m512i e3m2_exp_i32x16 = _mm512_sub_epi32(_mm512_add_epi32(f32_exp_i32x16, carry_i32x16), _mm512_set1_epi32(124));
+
+     // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 7)
+     __mmask16 is_subnormal = _mm512_cmpgt_epi32_mask(_mm512_set1_epi32(1), e3m2_exp_i32x16);
+     __mmask16 overflow = _mm512_cmpgt_epi32_mask(e3m2_exp_i32x16, _mm512_set1_epi32(7));
+
+     // Normal path: clamp exp to [1,7], extract mantissa bits
+     __m512i clamped_exp_i32x16 = _mm512_max_epi32(e3m2_exp_i32x16, _mm512_set1_epi32(1));
+     clamped_exp_i32x16 = _mm512_min_epi32(clamped_exp_i32x16, _mm512_set1_epi32(7));
+     __m512i normal_mantissa_i32x16 = _mm512_mask_blend_epi32(overflow, f32_mantissa_i32x16, _mm512_set1_epi32(0x03));
+     __m512i normal_e3m2_i32x16 = _mm512_ternarylogic_epi32(_mm512_slli_epi32(sign_i32x16, 5),
+                                                            _mm512_slli_epi32(clamped_exp_i32x16, 2),
+                                                            normal_mantissa_i32x16, 0xFE); // a | b | c
+
+     // Subnormal path: mantissa = round(abs_f32 * 16)
+     // If mantissa rounds to 4 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x04
+     __m512 abs_f32x16 = _mm512_and_ps(f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
+     __m512 scaled_f32x16 = _mm512_mul_ps(abs_f32x16, _mm512_set1_ps(16.0f));
+     __m512i subnorm_mantissa_i32x16 = _mm512_cvtps_epi32(scaled_f32x16);
+     __mmask16 promotes_to_normal = _mm512_cmpgt_epi32_mask(subnorm_mantissa_i32x16, _mm512_set1_epi32(3));
+     subnorm_mantissa_i32x16 = _mm512_min_epi32(subnorm_mantissa_i32x16, _mm512_set1_epi32(3));
+     subnorm_mantissa_i32x16 = _mm512_max_epi32(subnorm_mantissa_i32x16, _mm512_setzero_si512());
+     __m512i subnorm_e3m2_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 5), subnorm_mantissa_i32x16);
+     // When mantissa rounds to 4, use first normal value (0x04) instead of clamped subnormal
+     __m512i first_normal_e3m2_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 5), _mm512_set1_epi32(0x04));
+     subnorm_e3m2_i32x16 = _mm512_mask_blend_epi32(promotes_to_normal, subnorm_e3m2_i32x16, first_normal_e3m2_i32x16);
+
+     // Blend: use subnormal result when exp <= 0, else normal
+     __m512i e3m2_i32x16 = _mm512_mask_blend_epi32(is_subnormal, normal_e3m2_i32x16, subnorm_e3m2_i32x16);
+
+     // Pack 16 i32s to 16 unsigned i8s via AVX-512 cvtepi32_epi8
+     return _mm512_cvtepi32_epi8(e3m2_i32x16);
+ }
+
+ /** @brief Convert 16x f32 → 16x e4m3 via bit manipulation (AVX-512).
+  *  E4M3 format: S EEEE MMM (bias=7). Handles normal, subnormal, and overflow cases.
+  *  Subnormals (f32_exp ≤ 120): mantissa = round(abs_f32 * 512), clamped to [0,7]. */
+ NK_INTERNAL __m128i nk_f32x16_to_e4m3x16_skylake_(__m512 f32x16) {
+     __m512i bits_i32x16 = _mm512_castps_si512(f32x16);
+     __m512i sign_i32x16 = _mm512_srli_epi32(bits_i32x16, 31);
+     __m512i f32_exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(bits_i32x16, 23), _mm512_set1_epi32(0xFF));
+
+     // Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
+     // RNE trick: add (half - 1 + lsb) where lsb is the bit that will become the new lsb after shift
+     __m512i significand_i32x16 = _mm512_or_si512(_mm512_and_si512(bits_i32x16, _mm512_set1_epi32(0x007FFFFF)),
+                                                  _mm512_set1_epi32(0x00800000)); // (a & mask) | implicit_one
+     __m512i lsb_i32x16 = _mm512_and_si512(_mm512_srli_epi32(significand_i32x16, 20), _mm512_set1_epi32(1));
+     __m512i rounding_bias_i32x16 = _mm512_add_epi32(_mm512_set1_epi32(0x0007FFFF), lsb_i32x16);
+     __m512i rounded_sig_i32x16 = _mm512_add_epi32(significand_i32x16, rounding_bias_i32x16);
+     __m512i carry_i32x16 = _mm512_srli_epi32(rounded_sig_i32x16, 24); // Carry into exponent if bit 24 set
+     __m512i f32_mantissa_i32x16 = _mm512_and_si512(_mm512_srli_epi32(rounded_sig_i32x16, 20), _mm512_set1_epi32(0x07));
+     // If carry, mantissa becomes 0 (we rounded up to next power of 2)
+     f32_mantissa_i32x16 = _mm512_andnot_si512(_mm512_slli_epi32(carry_i32x16, 31), f32_mantissa_i32x16);
+     __m512i e4m3_exp_i32x16 = _mm512_sub_epi32(_mm512_add_epi32(f32_exp_i32x16, carry_i32x16), _mm512_set1_epi32(120));
+
+     // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 15)
+     __mmask16 is_subnormal = _mm512_cmpgt_epi32_mask(_mm512_set1_epi32(1), e4m3_exp_i32x16);
+     __mmask16 overflow = _mm512_cmpgt_epi32_mask(e4m3_exp_i32x16, _mm512_set1_epi32(15));
+
+     // Normal path: clamp exp to [1,15], extract mantissa bits
+     // E4M3FN quirk: exp=15 with mantissa=7 is NaN (0x7F), so clamp mantissa to 6 when exp=15.
+     __m512i clamped_exp_i32x16 = _mm512_max_epi32(e4m3_exp_i32x16, _mm512_set1_epi32(1));
+     clamped_exp_i32x16 = _mm512_min_epi32(clamped_exp_i32x16, _mm512_set1_epi32(15));
+     __mmask16 is_max_exp = _mm512_cmpeq_epi32_mask(clamped_exp_i32x16, _mm512_set1_epi32(15));
+     __m512i max_mantissa_i32x16 = _mm512_mask_blend_epi32(is_max_exp, _mm512_set1_epi32(7), _mm512_set1_epi32(6));
+     __m512i normal_mantissa_i32x16 = _mm512_min_epi32(f32_mantissa_i32x16, max_mantissa_i32x16);
+     normal_mantissa_i32x16 = _mm512_mask_blend_epi32(overflow, normal_mantissa_i32x16, _mm512_set1_epi32(0x06));
+     __m512i normal_e4m3_i32x16 = _mm512_ternarylogic_epi32(_mm512_slli_epi32(sign_i32x16, 7),
+                                                            _mm512_slli_epi32(clamped_exp_i32x16, 3),
+                                                            normal_mantissa_i32x16, 0xFE); // a | b | c
+
+     // Subnormal path: mantissa = round(abs_f32 * 512)
+     // If mantissa rounds to 8 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x08
+     __m512 abs_f32x16 = _mm512_and_ps(f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
+     __m512 scaled_f32x16 = _mm512_mul_ps(abs_f32x16, _mm512_set1_ps(512.0f));
+     __m512i subnorm_mantissa_i32x16 = _mm512_cvtps_epi32(scaled_f32x16);
+     __mmask16 promotes_to_normal = _mm512_cmpgt_epi32_mask(subnorm_mantissa_i32x16, _mm512_set1_epi32(7));
+     subnorm_mantissa_i32x16 = _mm512_min_epi32(subnorm_mantissa_i32x16, _mm512_set1_epi32(7));
+     subnorm_mantissa_i32x16 = _mm512_max_epi32(subnorm_mantissa_i32x16, _mm512_setzero_si512());
+     __m512i subnorm_e4m3_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 7), subnorm_mantissa_i32x16);
+     // When mantissa rounds to 8, use first normal value (0x08) instead of clamped subnormal
+     __m512i first_normal_e4m3_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 7), _mm512_set1_epi32(0x08));
+     subnorm_e4m3_i32x16 = _mm512_mask_blend_epi32(promotes_to_normal, subnorm_e4m3_i32x16, first_normal_e4m3_i32x16);
+
+     // Blend: use subnormal result when exp <= 0, else normal
+     __m512i e4m3_i32x16 = _mm512_mask_blend_epi32(is_subnormal, normal_e4m3_i32x16, subnorm_e4m3_i32x16);
+
+     // Pack 16 i32s to 16 unsigned i8s via AVX-512 cvtepi32_epi8
+     return _mm512_cvtepi32_epi8(e4m3_i32x16);
+ }
+
+ /** @brief Convert 16x f32 → 16x e5m2 via bit manipulation (AVX-512).
+  *  E5M2 format: S EEEEE MM (bias=15). Handles normal, subnormal, and overflow cases.
+  *  Uses RNE (round to nearest even) for mantissa rounding. */
+ NK_INTERNAL __m128i nk_f32x16_to_e5m2x16_skylake_(__m512 f32x16) {
+     __m512i bits_i32x16 = _mm512_castps_si512(f32x16);
+     __m512i sign_i32x16 = _mm512_srli_epi32(bits_i32x16, 31);
+     __m512i f32_exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(bits_i32x16, 23), _mm512_set1_epi32(0xFF));
+
+     // Round mantissa from 23 to 2 bits using RNE (round to nearest, ties to even)
+     // RNE trick: add (half - 1 + lsb) where lsb is the bit that will become the new lsb after shift
+     __m512i significand_i32x16 = _mm512_or_si512(_mm512_and_si512(bits_i32x16, _mm512_set1_epi32(0x007FFFFF)),
+                                                  _mm512_set1_epi32(0x00800000)); // (a & mask) | implicit_one
+     __m512i lsb_i32x16 = _mm512_and_si512(_mm512_srli_epi32(significand_i32x16, 21), _mm512_set1_epi32(1));
+     __m512i rounding_bias_i32x16 = _mm512_add_epi32(_mm512_set1_epi32(0x000FFFFF), lsb_i32x16); // half = 0x100000
+     __m512i rounded_sig_i32x16 = _mm512_add_epi32(significand_i32x16, rounding_bias_i32x16);
+     __m512i carry_i32x16 = _mm512_srli_epi32(rounded_sig_i32x16, 24); // Carry into exponent if bit 24 set
+     __m512i f32_mantissa_i32x16 = _mm512_and_si512(_mm512_srli_epi32(rounded_sig_i32x16, 21), _mm512_set1_epi32(0x03));
+     // If carry, mantissa becomes 0 (we rounded up to next power of 2)
+     f32_mantissa_i32x16 = _mm512_andnot_si512(_mm512_slli_epi32(carry_i32x16, 31), f32_mantissa_i32x16);
+     __m512i e5m2_exp_i32x16 = _mm512_sub_epi32(_mm512_add_epi32(f32_exp_i32x16, carry_i32x16), _mm512_set1_epi32(112));
+
+     // Detect subnormal (exp <= 0) and overflow (exp > 31)
+     __mmask16 is_subnormal = _mm512_cmpgt_epi32_mask(_mm512_set1_epi32(1), e5m2_exp_i32x16);
+     __mmask16 overflow = _mm512_cmpgt_epi32_mask(e5m2_exp_i32x16, _mm512_set1_epi32(31));
+
+     // Normal path: clamp exp to [1,31], on overflow return infinity (exp=31, mantissa=0 = 0x7C)
+     __m512i clamped_exp_i32x16 = _mm512_max_epi32(e5m2_exp_i32x16, _mm512_set1_epi32(1));
+     clamped_exp_i32x16 = _mm512_min_epi32(clamped_exp_i32x16, _mm512_set1_epi32(31));
+     __m512i normal_mantissa_i32x16 = _mm512_mask_blend_epi32(overflow, f32_mantissa_i32x16, _mm512_setzero_si512());
+     __m512i normal_e5m2_i32x16 = _mm512_ternarylogic_epi32(_mm512_slli_epi32(sign_i32x16, 7),
+                                                            _mm512_slli_epi32(clamped_exp_i32x16, 2),
+                                                            normal_mantissa_i32x16, 0xFE); // a | b | c
+
+     // Subnormal path: mantissa = round(abs_f32 * 65536)
+     // If mantissa rounds to 4 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x04
+     __m512 abs_f32x16 = _mm512_and_ps(f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
+     __m512 scaled_f32x16 = _mm512_mul_ps(abs_f32x16, _mm512_set1_ps(65536.0f));
+     __m512i subnorm_mantissa_i32x16 = _mm512_cvtps_epi32(scaled_f32x16);
+     __mmask16 promotes_to_normal = _mm512_cmpgt_epi32_mask(subnorm_mantissa_i32x16, _mm512_set1_epi32(3));
+     subnorm_mantissa_i32x16 = _mm512_min_epi32(subnorm_mantissa_i32x16, _mm512_set1_epi32(3));
+     subnorm_mantissa_i32x16 = _mm512_max_epi32(subnorm_mantissa_i32x16, _mm512_setzero_si512());
+     __m512i subnorm_e5m2_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 7), subnorm_mantissa_i32x16);
+     // When mantissa rounds to 4, use first normal value (0x04) instead of clamped subnormal
+     __m512i first_normal_e5m2_i32x16 = _mm512_or_si512(_mm512_slli_epi32(sign_i32x16, 7), _mm512_set1_epi32(0x04));
+     subnorm_e5m2_i32x16 = _mm512_mask_blend_epi32(promotes_to_normal, subnorm_e5m2_i32x16, first_normal_e5m2_i32x16);
+
+     // Blend: use subnormal result when exp <= 0
+     __m512i e5m2_i32x16 = _mm512_mask_blend_epi32(is_subnormal, normal_e5m2_i32x16, subnorm_e5m2_i32x16);
+
+     // Pack 16 i32s to 16 unsigned i8s via AVX-512 cvtepi32_epi8
+     return _mm512_cvtepi32_epi8(e5m2_i32x16);
+ }
+
+ NK_INTERNAL __m512 nk_i8x16_to_f32x16_skylake_(__m128i i8x16) {
+     return _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(i8x16));
+ }
+ NK_INTERNAL __m512 nk_u8x16_to_f32x16_skylake_(__m128i u8x16) {
+     return _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(u8x16));
+ }
+ NK_INTERNAL __m512 nk_i16x16_to_f32x16_skylake_(__m256i i16x16) {
+     return _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(i16x16));
+ }
+ NK_INTERNAL __m512 nk_u16x16_to_f32x16_skylake_(__m256i u16x16) {
+     return _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16x16));
+ }
+
+ NK_INTERNAL __m128i nk_f32x16_to_i8x16_skylake_(__m512 f32x16) {
+     __m512 clamped = _mm512_min_ps(_mm512_max_ps(f32x16, _mm512_set1_ps(-128.0f)), _mm512_set1_ps(127.0f));
+     return _mm512_cvtsepi32_epi8(_mm512_cvtps_epi32(clamped));
+ }
+ NK_INTERNAL __m128i nk_f32x16_to_u8x16_skylake_(__m512 f32x16) {
+     __m512 clamped = _mm512_min_ps(_mm512_max_ps(f32x16, _mm512_setzero_ps()), _mm512_set1_ps(255.0f));
+     return _mm512_cvtusepi32_epi8(_mm512_cvtps_epu32(clamped));
+ }
+ NK_INTERNAL __m256i nk_f32x16_to_i16x16_skylake_(__m512 f32x16) {
+     __m512 clamped = _mm512_min_ps(_mm512_max_ps(f32x16, _mm512_set1_ps(-32768.0f)), _mm512_set1_ps(32767.0f));
+     return _mm512_cvtsepi32_epi16(_mm512_cvtps_epi32(clamped));
+ }
+ NK_INTERNAL __m256i nk_f32x16_to_u16x16_skylake_(__m512 f32x16) {
+     __m512 clamped = _mm512_min_ps(_mm512_max_ps(f32x16, _mm512_setzero_ps()), _mm512_set1_ps(65535.0f));
+     return _mm512_cvtusepi32_epi16(_mm512_cvtps_epu32(clamped));
+ }
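The four converters above share one clamp-then-convert pattern; per lane the i8 path behaves like this scalar reference (ours, not from the package; note the SIMD path sends NaN through VCVTPS2DQ, which yields INT_MIN and then saturates to -128, while scalar lrintf on NaN is unspecified):

    #include <math.h>
    #include <stdint.h>

    static inline int8_t f32_to_i8_saturating(float x) {
        if (x < -128.0f) x = -128.0f; /* clamp below the i8 range */
        if (x > 127.0f) x = 127.0f;   /* clamp above the i8 range */
        return (int8_t)lrintf(x);     /* round to nearest, ties to even */
    }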
+
+ NK_INTERNAL __m512i nk_u8x8_to_u64x8_skylake_(__m128i u8x8) { return _mm512_cvtepu8_epi64(u8x8); }
+ NK_INTERNAL __m512i nk_u16x8_to_u64x8_skylake_(__m128i u16x8) { return _mm512_cvtepu16_epi64(u16x8); }
+ NK_INTERNAL __m512i nk_u32x8_to_u64x8_skylake_(__m256i u32x8) { return _mm512_cvtepu32_epi64(u32x8); }
+
+ NK_INTERNAL __m128i nk_u64x8_to_u8x8_skylake_(__m512i u64x8) {
+     __m512i clamped = _mm512_min_epu64(u64x8, _mm512_set1_epi64(255));
+     return _mm512_cvtepi64_epi8(clamped);
+ }
+ NK_INTERNAL __m128i nk_u64x8_to_u16x8_skylake_(__m512i u64x8) {
+     __m512i clamped = _mm512_min_epu64(u64x8, _mm512_set1_epi64(65535));
+     return _mm512_cvtepi64_epi16(clamped);
+ }
+ NK_INTERNAL __m256i nk_u64x8_to_u32x8_skylake_(__m512i u64x8) {
+     __m512i clamped = _mm512_min_epu64(u64x8, _mm512_set1_epi64(0xFFFFFFFFULL));
+     return _mm512_cvtepi64_epi32(clamped);
+ }
+
+ NK_INTERNAL __m512i nk_i8x8_to_i64x8_skylake_(__m128i i8x8) { return _mm512_cvtepi8_epi64(i8x8); }
+ NK_INTERNAL __m512i nk_i16x8_to_i64x8_skylake_(__m128i i16x8) { return _mm512_cvtepi16_epi64(i16x8); }
+ NK_INTERNAL __m512i nk_i32x8_to_i64x8_skylake_(__m256i i32x8) { return _mm512_cvtepi32_epi64(i32x8); }
+ NK_INTERNAL __m512i nk_u8x8_to_i64x8_skylake_(__m128i u8x8) { return _mm512_cvtepu8_epi64(u8x8); }
+ NK_INTERNAL __m512i nk_u16x8_to_i64x8_skylake_(__m128i u16x8) { return _mm512_cvtepu16_epi64(u16x8); }
+ NK_INTERNAL __m512i nk_u32x8_to_i64x8_skylake_(__m256i u32x8) { return _mm512_cvtepu32_epi64(u32x8); }
+
+ NK_INTERNAL __m128i nk_i64x8_to_i8x8_skylake_(__m512i i64x8) {
+     __m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(127)), _mm512_set1_epi64(-128));
+     return _mm512_cvtepi64_epi8(clamped);
+ }
+ NK_INTERNAL __m128i nk_i64x8_to_u8x8_skylake_(__m512i i64x8) {
+     __m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(255)), _mm512_setzero_si512());
+     return _mm512_cvtepi64_epi8(clamped);
+ }
+ NK_INTERNAL __m128i nk_i64x8_to_i16x8_skylake_(__m512i i64x8) {
+     __m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(32767)), _mm512_set1_epi64(-32768));
+     return _mm512_cvtepi64_epi16(clamped);
+ }
+ NK_INTERNAL __m128i nk_i64x8_to_u16x8_skylake_(__m512i i64x8) {
+     __m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(65535)), _mm512_setzero_si512());
+     return _mm512_cvtepi64_epi16(clamped);
+ }
+ NK_INTERNAL __m256i nk_i64x8_to_i32x8_skylake_(__m512i i64x8) {
+     __m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(NK_I32_MAX)),
+                                        _mm512_set1_epi64(NK_I32_MIN));
+     return _mm512_cvtepi64_epi32(clamped);
+ }
+ NK_INTERNAL __m256i nk_i64x8_to_u32x8_skylake_(__m512i i64x8) {
+     __m512i clamped = _mm512_max_epi64(_mm512_min_epi64(i64x8, _mm512_set1_epi64(NK_U32_MAX)), _mm512_setzero_si512());
+     return _mm512_cvtepi64_epi32(clamped);
+ }
+
+ NK_INTERNAL __m512d nk_f32x8_to_f64x8_skylake_(__m256 f32x8) { return _mm512_cvtps_pd(f32x8); }
+ NK_INTERNAL __m512d nk_i32x8_to_f64x8_skylake_(__m256i i32x8) { return _mm512_cvtepi32_pd(i32x8); }
+ NK_INTERNAL __m512d nk_u32x8_to_f64x8_skylake_(__m256i u32x8) { return _mm512_cvtepu32_pd(u32x8); }
+
+ NK_INTERNAL __m256 nk_f64x8_to_f32x8_skylake_(__m512d f64x8) { return _mm512_cvtpd_ps(f64x8); }
+ NK_INTERNAL __m256i nk_f64x8_to_i32x8_skylake_(__m512d f64x8) {
+     __m512d clamped = _mm512_min_pd(_mm512_max_pd(f64x8, _mm512_set1_pd((double)NK_I32_MIN)),
+                                     _mm512_set1_pd((double)NK_I32_MAX));
+     return _mm512_cvtpd_epi32(clamped);
+ }
+ NK_INTERNAL __m256i nk_f64x8_to_u32x8_skylake_(__m512d f64x8) {
+     __m512d clamped = _mm512_min_pd(_mm512_max_pd(f64x8, _mm512_setzero_pd()), _mm512_set1_pd((double)NK_U32_MAX));
+     return _mm512_cvtpd_epu32(clamped);
+ }
+
+ #pragma endregion - Vectorized Conversions
+
+ #pragma region - Converting Loads and Stores
+
+ /** @brief Load 16 f16 values and convert to 16 f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_load_f16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
+     dst->zmm_ps = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)src));
+ }
+
+ /** @brief Partial load of up to 16 f16 values with conversion to f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_f16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     nk_b256_vec_t f16_partial;
+     nk_partial_load_b16x16_skylake_(src, &f16_partial, n);
+     dst->zmm_ps = _mm512_cvtph_ps(f16_partial.ymm);
+ }
+
+ /** @brief Load 16 bf16 values and convert to 16 f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_load_bf16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
+     dst->zmm_ps = nk_bf16x16_to_f32x16_skylake_(_mm256_loadu_si256((__m256i const *)src));
+ }
+
+ /** @brief Partial load of up to 16 bf16 values with conversion to f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_bf16x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     nk_b256_vec_t bf16_partial;
+     nk_partial_load_b16x16_skylake_(src, &bf16_partial, n);
+     dst->zmm_ps = nk_bf16x16_to_f32x16_skylake_(bf16_partial.ymm);
+ }
+
+ /** @brief Load 16 e4m3 values and convert to 16 f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_load_e4m3x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
+     dst->zmm_ps = nk_e4m3x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
+ }
+
+ /** @brief Partial load of up to 16 e4m3 values with conversion to f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_e4m3x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     nk_b128_vec_t e4m3_partial;
+     nk_partial_load_b8x16_skylake_(src, &e4m3_partial, n);
+     dst->zmm_ps = nk_e4m3x16_to_f32x16_skylake_(e4m3_partial.xmm);
+ }
+
+ /** @brief Load 16 e5m2 values and convert to 16 f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_load_e5m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
+     dst->zmm_ps = nk_e5m2x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
+ }
+
+ /** @brief Partial load of up to 16 e5m2 values with conversion to f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_e5m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     nk_b128_vec_t e5m2_partial;
+     nk_partial_load_b8x16_skylake_(src, &e5m2_partial, n);
+     dst->zmm_ps = nk_e5m2x16_to_f32x16_skylake_(e5m2_partial.xmm);
+ }
+
+ /** @brief Load 16 e2m3 values and convert to 16 f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_load_e2m3x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
+     dst->zmm_ps = nk_e2m3x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
+ }
+
+ /** @brief Partial load of up to 16 e2m3 values with conversion to f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_e2m3x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     nk_b128_vec_t e2m3_partial;
+     nk_partial_load_b8x16_skylake_(src, &e2m3_partial, n);
+     dst->zmm_ps = nk_e2m3x16_to_f32x16_skylake_(e2m3_partial.xmm);
+ }
+
+ /** @brief Load 16 e3m2 values and convert to 16 f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_load_e3m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
+     dst->zmm_ps = nk_e3m2x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
+ }
+
+ /** @brief Partial load of up to 16 e3m2 values with conversion to f32 (Skylake AVX-512). */
+ NK_INTERNAL void nk_partial_load_e3m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
+     nk_b128_vec_t e3m2_partial;
+     nk_partial_load_b8x16_skylake_(src, &e3m2_partial, n);
+     dst->zmm_ps = nk_e3m2x16_to_f32x16_skylake_(e3m2_partial.xmm);
+ }
+
+ #pragma endregion - Converting Loads and Stores
641
+
642
+ #pragma region - Public API
643
+
644
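+ /** @brief Cast `n` scalars from `from_type` to `to_type`, routing through one of
+  *  four widened hubs: f32x16 for floats and small integers, u64x8 for unsigned
+  *  pairs, i64x8 for the remaining integer pairs, and f64x8 for double-precision
+  *  legs. Everything else defers to `nk_cast_serial`. */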
+ NK_PUBLIC void nk_cast_skylake(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
+     // Same-type fast path
+     if (from_type == to_type) {
+         nk_size_t size_bits = nk_dtype_bits(from_type);
+         if (size_bits > 0) nk_copy_bytes_(to, from, nk_size_divide_round_up_(n * size_bits, 8));
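+         // e.g. five u4 values occupy 5 * 4 = 20 bits, rounded up to 3 bytes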
+         return;
+     }
+
+     // Type classification for hub selection
+     int from_f32_hub = (from_type == nk_f32_k || from_type == nk_f16_k || from_type == nk_bf16_k ||
+                         from_type == nk_e4m3_k || from_type == nk_e5m2_k || from_type == nk_e2m3_k ||
+                         from_type == nk_e3m2_k || from_type == nk_i8_k || from_type == nk_u8_k ||
+                         from_type == nk_i16_k || from_type == nk_u16_k);
+     int to_f32_hub = (to_type == nk_f32_k || to_type == nk_f16_k || to_type == nk_bf16_k || to_type == nk_e4m3_k ||
+                       to_type == nk_e5m2_k || to_type == nk_e2m3_k || to_type == nk_e3m2_k || to_type == nk_i8_k ||
+                       to_type == nk_u8_k || to_type == nk_i16_k || to_type == nk_u16_k);
+     int from_unsigned = (from_type == nk_u8_k || from_type == nk_u16_k || from_type == nk_u32_k ||
+                          from_type == nk_u64_k);
+     int to_unsigned = (to_type == nk_u8_k || to_type == nk_u16_k || to_type == nk_u32_k || to_type == nk_u64_k);
+     int from_signed = (from_type == nk_i8_k || from_type == nk_i16_k || from_type == nk_i32_k || from_type == nk_i64_k);
+     int to_signed = (to_type == nk_i8_k || to_type == nk_i16_k || to_type == nk_i32_k || to_type == nk_i64_k);
+     int from_f64 = (from_type == nk_f64_k);
+     int to_f64 = (to_type == nk_f64_k);
+
+     nk_u8_t const *from_ptr = (nk_u8_t const *)from;
+     nk_u8_t *to_ptr = (nk_u8_t *)to;
+
+     // Hub 1: f32x16 - float types + small integers (16 elements/batch)
+     if (from_f32_hub && to_f32_hub) {
+         nk_size_t from_bytes = nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
+         nk_size_t to_bytes = nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;
+         while (n > 0) {
+             nk_size_t batch = n < 16 ? n : 16;
+             __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)batch);
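+             // _bzhi_u32 keeps only the low `batch` bits: batch == 5 yields mask 0x001F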
+             __m512 hub_f32x16;
+
+             // Upcast to f32x16
+             if (from_type == nk_f32_k) hub_f32x16 = _mm512_maskz_loadu_ps(mask, from_ptr);
+             else if (from_type == nk_f16_k) hub_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, from_ptr));
+             else if (from_type == nk_bf16_k)
+                 hub_f32x16 = nk_bf16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask, from_ptr));
+             else if (from_type == nk_e4m3_k)
+                 hub_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
+             else if (from_type == nk_e5m2_k)
+                 hub_f32x16 = nk_e5m2x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
+             else if (from_type == nk_e2m3_k)
+                 hub_f32x16 = nk_e2m3x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
+             else if (from_type == nk_e3m2_k)
+                 hub_f32x16 = nk_e3m2x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
+             else if (from_type == nk_i8_k)
+                 hub_f32x16 = nk_i8x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
+             else if (from_type == nk_u8_k)
+                 hub_f32x16 = nk_u8x16_to_f32x16_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
+             else if (from_type == nk_i16_k)
+                 hub_f32x16 = nk_i16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask, from_ptr));
+             else if (from_type == nk_u16_k)
+                 hub_f32x16 = nk_u16x16_to_f32x16_skylake_(_mm256_maskz_loadu_epi16(mask, from_ptr));
+             else hub_f32x16 = _mm512_setzero_ps();
+
+             // Downcast from f32x16
+             if (to_type == nk_f32_k) _mm512_mask_storeu_ps(to_ptr, mask, hub_f32x16);
+             else if (to_type == nk_f16_k)
+                 _mm256_mask_storeu_epi16(to_ptr, mask, _mm512_cvtps_ph(hub_f32x16, _MM_FROUND_TO_NEAREST_INT));
+             else if (to_type == nk_bf16_k)
+                 _mm256_mask_storeu_epi16(to_ptr, mask, nk_f32x16_to_bf16x16_skylake_(hub_f32x16));
+             else if (to_type == nk_e4m3_k)
+                 _mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_e4m3x16_skylake_(hub_f32x16));
+             else if (to_type == nk_e5m2_k)
+                 _mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_e5m2x16_skylake_(hub_f32x16));
+             else if (to_type == nk_e2m3_k)
+                 _mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_e2m3x16_skylake_(hub_f32x16));
+             else if (to_type == nk_e3m2_k)
+                 _mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_e3m2x16_skylake_(hub_f32x16));
+             else if (to_type == nk_i8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_i8x16_skylake_(hub_f32x16));
+             else if (to_type == nk_u8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_f32x16_to_u8x16_skylake_(hub_f32x16));
+             else if (to_type == nk_i16_k)
+                 _mm256_mask_storeu_epi16(to_ptr, mask, nk_f32x16_to_i16x16_skylake_(hub_f32x16));
+             else if (to_type == nk_u16_k)
+                 _mm256_mask_storeu_epi16(to_ptr, mask, nk_f32x16_to_u16x16_skylake_(hub_f32x16));
+
+             from_ptr += batch * from_bytes;
+             to_ptr += batch * to_bytes;
+             n -= batch;
+         }
+         return;
+     }
+
+     // Hub 2: u64x8 - unsigned ↔ unsigned integers (8 elements/batch)
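+     // Zero-extending any unsigned lane to u64 is a single VPMOVZX* instruction, which keeps this 8-lane hub cheap.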
+     if (from_unsigned && to_unsigned) {
+         nk_size_t from_bytes = nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
+         nk_size_t to_bytes = nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;
+         while (n > 0) {
+             nk_size_t batch = n < 8 ? n : 8;
+             __mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)batch);
+             __m512i hub_u64x8;
+
+             // Upcast to u64x8
+             if (from_type == nk_u8_k) hub_u64x8 = nk_u8x8_to_u64x8_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
+             else if (from_type == nk_u16_k)
+                 hub_u64x8 = nk_u16x8_to_u64x8_skylake_(_mm_maskz_loadu_epi16(mask, from_ptr));
+             else if (from_type == nk_u32_k)
+                 hub_u64x8 = nk_u32x8_to_u64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
+             else if (from_type == nk_u64_k) hub_u64x8 = _mm512_maskz_loadu_epi64(mask, from_ptr);
+             else hub_u64x8 = _mm512_setzero_si512();
+
+             // Downcast from u64x8
+             if (to_type == nk_u8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_u64x8_to_u8x8_skylake_(hub_u64x8));
+             else if (to_type == nk_u16_k) _mm_mask_storeu_epi16(to_ptr, mask, nk_u64x8_to_u16x8_skylake_(hub_u64x8));
+             else if (to_type == nk_u32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_u64x8_to_u32x8_skylake_(hub_u64x8));
+             else if (to_type == nk_u64_k) _mm512_mask_storeu_epi64(to_ptr, mask, hub_u64x8);
+
+             from_ptr += batch * from_bytes;
+             to_ptr += batch * to_bytes;
+             n -= batch;
+         }
+         return;
+     }
+
+     // Hub 3: i64x8 - signed/mixed integer conversions (8 elements/batch)
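+     // i64 and u64 of equal width share a bit pattern, so both ends reuse the same masked 64-bit loads and stores.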
+     if ((from_signed || from_unsigned) && (to_signed || to_unsigned)) {
+         nk_size_t from_bytes = nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
+         nk_size_t to_bytes = nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;
+         while (n > 0) {
+             nk_size_t batch = n < 8 ? n : 8;
+             __mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)batch);
+             __m512i hub_i64x8;
+
+             // Upcast to i64x8
+             if (from_type == nk_i8_k) hub_i64x8 = nk_i8x8_to_i64x8_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
+             else if (from_type == nk_u8_k) hub_i64x8 = nk_u8x8_to_i64x8_skylake_(_mm_maskz_loadu_epi8(mask, from_ptr));
+             else if (from_type == nk_i16_k)
+                 hub_i64x8 = nk_i16x8_to_i64x8_skylake_(_mm_maskz_loadu_epi16(mask, from_ptr));
+             else if (from_type == nk_u16_k)
+                 hub_i64x8 = nk_u16x8_to_i64x8_skylake_(_mm_maskz_loadu_epi16(mask, from_ptr));
+             else if (from_type == nk_i32_k)
+                 hub_i64x8 = nk_i32x8_to_i64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
+             else if (from_type == nk_u32_k)
+                 hub_i64x8 = nk_u32x8_to_i64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
+             else if (from_type == nk_i64_k || from_type == nk_u64_k)
+                 hub_i64x8 = _mm512_maskz_loadu_epi64(mask, from_ptr);
+             else hub_i64x8 = _mm512_setzero_si512();
+
+             // Downcast from i64x8
+             if (to_type == nk_i8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_i64x8_to_i8x8_skylake_(hub_i64x8));
+             else if (to_type == nk_u8_k) _mm_mask_storeu_epi8(to_ptr, mask, nk_i64x8_to_u8x8_skylake_(hub_i64x8));
+             else if (to_type == nk_i16_k) _mm_mask_storeu_epi16(to_ptr, mask, nk_i64x8_to_i16x8_skylake_(hub_i64x8));
+             else if (to_type == nk_u16_k) _mm_mask_storeu_epi16(to_ptr, mask, nk_i64x8_to_u16x8_skylake_(hub_i64x8));
+             else if (to_type == nk_i32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_i64x8_to_i32x8_skylake_(hub_i64x8));
+             else if (to_type == nk_u32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_i64x8_to_u32x8_skylake_(hub_i64x8));
+             else if (to_type == nk_i64_k || to_type == nk_u64_k) _mm512_mask_storeu_epi64(to_ptr, mask, hub_i64x8);
+
+             from_ptr += batch * from_bytes;
+             to_ptr += batch * to_bytes;
+             n -= batch;
+         }
+         return;
+     }
+
+     // Hub 4: f64x8 - f64 conversions (8 elements/batch)
+     // Only enter when both sides are types we can actually handle: f64, f32, i32, u32.
+     // Unsupported pairs (e.g. i8→f64, f16→f64) fall through to serial fallback.
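+     // Every i32/u32 value is exactly representable in f64 (52 mantissa bits), so the integer legs here are lossless.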
+     if ((from_f64 || to_f64) && //
+         (from_type == nk_f64_k || from_type == nk_f32_k || from_type == nk_i32_k || from_type == nk_u32_k) && //
+         (to_type == nk_f64_k || to_type == nk_f32_k || to_type == nk_i32_k || to_type == nk_u32_k)) {
+         nk_size_t from_bytes = nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
+         nk_size_t to_bytes = nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;
+         while (n > 0) {
+             nk_size_t batch = n < 8 ? n : 8;
+             __mmask8 mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)batch);
+             __m512d hub_f64x8;
+
+             // Upcast to f64x8
+             if (from_type == nk_f64_k) hub_f64x8 = _mm512_maskz_loadu_pd(mask, from_ptr);
+             else if (from_type == nk_f32_k)
+                 hub_f64x8 = nk_f32x8_to_f64x8_skylake_(_mm256_maskz_loadu_ps(mask, from_ptr));
+             else if (from_type == nk_i32_k)
+                 hub_f64x8 = nk_i32x8_to_f64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
+             else if (from_type == nk_u32_k)
+                 hub_f64x8 = nk_u32x8_to_f64x8_skylake_(_mm256_maskz_loadu_epi32(mask, from_ptr));
+             else hub_f64x8 = _mm512_setzero_pd();
+
+             // Downcast from f64x8
+             if (to_type == nk_f64_k) _mm512_mask_storeu_pd(to_ptr, mask, hub_f64x8);
+             else if (to_type == nk_f32_k) _mm256_mask_storeu_ps(to_ptr, mask, nk_f64x8_to_f32x8_skylake_(hub_f64x8));
+             else if (to_type == nk_i32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_f64x8_to_i32x8_skylake_(hub_f64x8));
+             else if (to_type == nk_u32_k) _mm256_mask_storeu_epi32(to_ptr, mask, nk_f64x8_to_u32x8_skylake_(hub_f64x8));
+
+             from_ptr += batch * from_bytes;
+             to_ptr += batch * to_bytes;
+             n -= batch;
+         }
+         return;
+     }
+
+     // Fallback: complex types, i4/u4/u1, unsupported combinations
+     nk_cast_serial(from, from_type, n, to, to_type);
+ }
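+
+ /* Usage sketch with hypothetical caller-side buffers; `nk_f16_t` and `nk_i8_t`
+    are assumed element typedefs, not names confirmed by this header:
+
+        nk_f16_t src[1000];
+        nk_i8_t dst[1000];
+        nk_cast_skylake(src, nk_f16_k, 1000, dst, nk_i8_k); // 62 full 16-lane batches, then one masked tail of 8
+ */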
+
+ #pragma endregion - Public API
+
+ #if defined(__clang__)
+ #pragma clang attribute pop
+ #elif defined(__GNUC__)
+ #pragma GCC pop_options
+ #endif
+
+ #if defined(__cplusplus)
+ } // extern "C"
+ #endif
+
+ #endif // NK_TARGET_SKYLAKE
+ #endif // NK_TARGET_X86_
+ #endif // NK_CAST_SKYLAKE_H