numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294):
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,262 @@
1
+ /**
2
+ * @brief SIMD-accelerated Type Conversions for Sapphire Rapids.
3
+ * @file include/numkong/cast/sapphire.h
4
+ * @author Ash Vardanian
5
+ * @date January 2, 2026
6
+ *
7
+ * @section sapphire_cast_instructions Relevant Instructions
8
+ *
9
+ * Intrinsic Instruction Sapphire Genoa
10
+ * _mm_cvtss_sh VCVTSS2SH (XMM, XMM, XMM) 5cy @ p05 5cy @ p01
11
+ * _mm_cvtsh_ss VCVTSH2SS (XMM, XMM, XMM) 5cy @ p05 5cy @ p01
12
+ * _mm256_cvtepu8_epi16 VPMOVZXBW (YMM, XMM) 3cy @ p5 3cy @ p12
13
+ * _mm256_mul_ph VMULPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
14
+ * _mm256_cvtepi16_ph VCVTW2PH (YMM, YMM) 4cy @ p05 4cy @ p01
15
+ * _mm256_cvtph_epi16 VCVTPH2W (YMM, YMM) 4cy @ p05 4cy @ p01
16
+ * _mm256_mask_blend_epi16 VPBLENDMW (YMM, K, YMM, YMM) 1cy @ p05 1cy @ p0123
17
+ * _mm256_testn_epi16_mask VPTESTNMW (K, YMM, YMM) 3cy @ p5 3cy @ p0
18
+ * _mm256_cvtepi16_epi8 VPMOVWB (XMM, YMM) 4cy @ p5 4cy @ p12
19
+ * _mm_maskz_loadu_epi8 VMOVDQU8 (XMM {K}, M128) 7cy @ p23 7cy @ p23
20
+ * _mm256_mask_storeu_epi16 VMOVDQU16 (M256 {K}, YMM) 4cy @ p4 4cy @ p4
21
+ */
22
+ #ifndef NK_CAST_SAPPHIRE_H
23
+ #define NK_CAST_SAPPHIRE_H
24
+
25
+ #if NK_TARGET_X86_
26
+ #if NK_TARGET_SAPPHIRE
27
+
28
+ #include "numkong/types.h"
29
+ #include "numkong/cast/icelake.h" // `nk_cast_icelake`
30
+
31
+ #if defined(__cplusplus)
32
+ extern "C" {
33
+ #endif
34
+
35
+ #if defined(__clang__)
36
+ #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512fp16,f16c,fma,bmi,bmi2"))), \
37
+ apply_to = function)
38
+ #elif defined(__GNUC__)
39
+ #pragma GCC push_options
40
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
41
+ #endif
42
+
43
+ NK_PUBLIC void nk_f32_to_f16_sapphire(nk_f32_t const *from, nk_f16_t *to) {
44
+ *to = _mm_cvtsi128_si32(_mm_castph_si128(_mm_cvtss_sh(_mm_setzero_ph(), _mm_set_ss(*from))));
45
+ }
46
+
47
+ NK_PUBLIC void nk_f16_to_f32_sapphire(nk_f16_t const *from, nk_f32_t *to) {
48
+ *to = _mm_cvtss_f32(_mm_cvtsh_ss(_mm_setzero_ps(), _mm_castsi128_ph(_mm_cvtsi32_si128(*from))));
49
+ }
50
+
51
+ #pragma region - Vectorized Conversions
52
+
53
+ /** @brief Convert 16x e4m3 → 16x f16 via bit manipulation (AVX-512 FP16).
54
+ * E4M3 format: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
55
+ * Normal: sign | ((exp+8)<<10) | (mant<<7).
56
+ * Subnormals (exp=0): value = mantissa ÷ 512, computed via f16 arithmetic. */
57
+ NK_INTERNAL __m256h nk_e4m3x16_to_f16x16_sapphire_(__m128i e4m3_i8x16) {
58
+ __m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3_i8x16);
59
+
60
+ // Extract fields
61
+ __m256i mantissa_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x07));
62
+ __m256i sign_i16x16 = _mm256_and_si256(_mm256_slli_epi16(e4m3_i16x16, 8), _mm256_set1_epi16((short)0x8000));
63
+
64
+ // Normal path: sign | ((exp+8)<<10) | (mantissa<<7) via single shift + bias add
65
+ __m256i exp_mantissa_i16x16 = _mm256_slli_epi16(_mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F)), 7);
66
+ __m256i exp_mantissa_biased_i16x16 = _mm256_add_epi16(exp_mantissa_i16x16, _mm256_set1_epi16(0x2000));
67
+ __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, exp_mantissa_biased_i16x16);
68
+
69
+ // Subnormal fix: for exp==0 lanes, use (subnorm_abs | sign); else keep normal
70
+ __mmask16 is_subnormal = _mm256_testn_epi16_mask(e4m3_i16x16, _mm256_set1_epi16(0x78));
71
+ __m256h subnorm_abs_f16x16 = _mm256_mul_ph(_mm256_cvtepi16_ph(mantissa_i16x16),
72
+ _mm256_castsi256_ph(_mm256_set1_epi16(0x1800))); // 1/512
73
+ __m256i subnorm_signed_i16x16 = _mm256_or_si256(_mm256_castph_si256(subnorm_abs_f16x16), sign_i16x16);
74
+ __m256i result_i16x16 = _mm256_mask_blend_epi16(is_subnormal, normal_i16x16, subnorm_signed_i16x16);
75
+
76
+ // NaN path: E4M3FN has NaN only when exp=15 AND mant=7 (lower 7 bits == 0x7F)
77
+ __mmask16 is_nan = _mm256_cmpeq_epi16_mask( //
78
+ _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F)), //
79
+ _mm256_set1_epi16(0x7F)); //
80
+ __m256i nan_bits = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7E00)); // F16 quiet NaN
81
+ return _mm256_castsi256_ph(_mm256_mask_blend_epi16(is_nan, result_i16x16, nan_bits));
82
+ }
83
+
84
+ /** @brief Convert 16x e5m2 → 16x f16 via bit manipulation (AVX-512 FP16).
85
+ * E5M2 format: S EEEEE MM (bias=15). F16: S EEEEE MMMMMMMMMM (bias=15).
86
+ * Normal: sign | (exp<<10) | (mant<<8) (same exponent bias).
87
+ * Subnormals (exp=0): value = mantissa ÷ 65536, computed via f16 arithmetic. */
88
+ NK_INTERNAL __m256h nk_e5m2x16_to_f16x16_sapphire_(__m128i e5m2_i8x16) {
89
+ __m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2_i8x16);
90
+
91
+ // Extract fields
92
+ __m256i mantissa_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x03));
93
+ __m256i sign_i16x16 = _mm256_and_si256(_mm256_slli_epi16(e5m2_i16x16, 8), _mm256_set1_epi16((short)0x8000));
94
+
95
+ // Normal path: sign | (exp<<10) | (mant<<8) - same exponent bias so just shift lower7 by 8
96
+ __m256i exp_mantissa_i16x16 = _mm256_slli_epi16(_mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F)), 8);
97
+ __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, exp_mantissa_i16x16);
98
+
99
+ // Subnormal fix: for exp==0 lanes, use (subnorm_abs | sign); else keep normal
100
+ __mmask16 is_subnormal = _mm256_testn_epi16_mask(e5m2_i16x16, _mm256_set1_epi16(0x7C));
101
+ __m256h subnorm_abs_f16x16 = _mm256_mul_ph(_mm256_cvtepi16_ph(mantissa_i16x16),
102
+ _mm256_castsi256_ph(_mm256_set1_epi16(0x0100))); // 1/65536
103
+ __m256i subnorm_signed_i16x16 = _mm256_or_si256(_mm256_castph_si256(subnorm_abs_f16x16), sign_i16x16);
104
+ return _mm256_castsi256_ph(_mm256_mask_blend_epi16(is_subnormal, normal_i16x16, subnorm_signed_i16x16));
105
+ }
106
+
107
+ /** @brief Convert 16x f16 → 16x e4m3 via bit manipulation (AVX-512 FP16).
108
+ * F16: S EEEEE MMMMMMMMMM (bias=15). E4M3: S EEEE MMM (bias=7).
109
+ * Handles normal, subnormal, and overflow cases with RNE rounding. */
110
+ NK_INTERNAL __m128i nk_f16x16_to_e4m3x16_sapphire_(__m256h f16x16) {
111
+ __m256i bits_i16x16 = _mm256_castph_si256(f16x16);
112
+ __m256i sign_i16x16 = _mm256_srli_epi16(bits_i16x16, 15);
113
+ __m256i f16_exp_i16x16 = _mm256_and_si256(_mm256_srli_epi16(bits_i16x16, 10), _mm256_set1_epi16(0x1F));
114
+
115
+ // Round mantissa from 10 to 3 bits using RNE (round to nearest, ties to even)
116
+ __m256i significand_i16x16 = _mm256_or_si256(_mm256_and_si256(bits_i16x16, _mm256_set1_epi16(0x03FF)),
117
+ _mm256_set1_epi16(0x0400)); // Add implicit 1 bit
118
+ __m256i lsb_i16x16 = _mm256_and_si256(_mm256_srli_epi16(significand_i16x16, 7), _mm256_set1_epi16(1));
119
+ __m256i rounding_bias_i16x16 = _mm256_add_epi16(_mm256_set1_epi16(0x003F), lsb_i16x16);
120
+ __m256i rounded_sig_i16x16 = _mm256_add_epi16(significand_i16x16, rounding_bias_i16x16);
121
+ __m256i carry_i16x16 = _mm256_srli_epi16(rounded_sig_i16x16, 11); // Carry into exponent if bit 11 set
122
+ __m256i f16_mantissa_i16x16 = _mm256_and_si256(_mm256_srli_epi16(rounded_sig_i16x16, 7), _mm256_set1_epi16(0x07));
123
+ // If carry, mantissa becomes 0 (we rounded up to next power of 2)
124
+ f16_mantissa_i16x16 = _mm256_andnot_si256(_mm256_slli_epi16(carry_i16x16, 15), f16_mantissa_i16x16);
125
+ __m256i e4m3_exp_i16x16 = _mm256_sub_epi16(_mm256_add_epi16(f16_exp_i16x16, carry_i16x16), _mm256_set1_epi16(8));
126
+
127
+ // Detect underflow (exp <= 0) and overflow (exp > 15)
128
+ __mmask16 is_subnormal = _mm256_cmpgt_epi16_mask(_mm256_set1_epi16(1), e4m3_exp_i16x16);
129
+ __mmask16 overflow = _mm256_cmpgt_epi16_mask(e4m3_exp_i16x16, _mm256_set1_epi16(15));
130
+
131
+ // Normal path: clamp exp to [1,15]
132
+ // e4m3FN quirk: exp=15 with mantissa=7 is NaN (0x7F), so clamp mantissa to 6 when exp=15.
133
+ __m256i clamped_exp_i16x16 = _mm256_max_epi16(e4m3_exp_i16x16, _mm256_set1_epi16(1));
134
+ clamped_exp_i16x16 = _mm256_min_epi16(clamped_exp_i16x16, _mm256_set1_epi16(15));
135
+ __mmask16 is_max_exp = _mm256_cmpeq_epi16_mask(clamped_exp_i16x16, _mm256_set1_epi16(15));
136
+ __m256i max_mantissa_i16x16 = _mm256_mask_blend_epi16(is_max_exp, _mm256_set1_epi16(7), _mm256_set1_epi16(6));
137
+ __m256i normal_mantissa_i16x16 = _mm256_min_epi16(f16_mantissa_i16x16, max_mantissa_i16x16);
138
+ normal_mantissa_i16x16 = _mm256_mask_blend_epi16(overflow, normal_mantissa_i16x16, _mm256_set1_epi16(0x06));
139
+ __m256i normal_e4m3_i16x16 = _mm256_or_si256(
140
+ _mm256_slli_epi16(sign_i16x16, 7),
141
+ _mm256_or_si256(_mm256_slli_epi16(clamped_exp_i16x16, 3), normal_mantissa_i16x16));
142
+
143
+ // Subnormal path: mantissa = round(abs_f16 * 512)
144
+ __m256h abs_f16x16 = _mm256_castsi256_ph(_mm256_and_si256(_mm256_castph_si256(f16x16), _mm256_set1_epi16(0x7FFF)));
145
+ __m256h scaled_f16x16 = _mm256_mul_ph(abs_f16x16, _mm256_castsi256_ph(_mm256_set1_epi16(0x6000))); // 512
146
+ __m256i subnorm_mantissa_i16x16 = _mm256_cvtph_epi16(scaled_f16x16);
147
+ __mmask16 promotes_to_normal = _mm256_cmpgt_epi16_mask(subnorm_mantissa_i16x16, _mm256_set1_epi16(7));
148
+ subnorm_mantissa_i16x16 = _mm256_min_epi16(subnorm_mantissa_i16x16, _mm256_set1_epi16(7));
149
+ subnorm_mantissa_i16x16 = _mm256_max_epi16(subnorm_mantissa_i16x16, _mm256_setzero_si256());
150
+ __m256i subnorm_e4m3_i16x16 = _mm256_or_si256(_mm256_slli_epi16(sign_i16x16, 7), subnorm_mantissa_i16x16);
151
+ __m256i first_normal_e4m3_i16x16 = _mm256_or_si256(_mm256_slli_epi16(sign_i16x16, 7), _mm256_set1_epi16(0x08));
152
+ subnorm_e4m3_i16x16 = _mm256_mask_blend_epi16(promotes_to_normal, subnorm_e4m3_i16x16, first_normal_e4m3_i16x16);
153
+
154
+ // Blend: use subnormal result when exp <= 0
155
+ __m256i e4m3_i16x16 = _mm256_mask_blend_epi16(is_subnormal, normal_e4m3_i16x16, subnorm_e4m3_i16x16);
156
+
157
+ // Pack 16 i16s to 16 unsigned i8s via AVX-512BW
158
+ return _mm256_cvtepi16_epi8(e4m3_i16x16);
159
+ }
160
+
161
+ /** @brief Convert 16x f16 → 16x e5m2 via bit manipulation (AVX-512 FP16).
162
+ * F16: S EEEEE MMMMMMMMMM (bias=15). E5M2: S EEEEE MM (bias=15).
163
+ * Same exponent bias, so just round mantissa from 10 to 2 bits. */
164
+ NK_INTERNAL __m128i nk_f16x16_to_e5m2x16_sapphire_(__m256h f16x16) {
165
+ __m256i bits_i16x16 = _mm256_castph_si256(f16x16);
166
+ __m256i sign_i16x16 = _mm256_srli_epi16(bits_i16x16, 15);
167
+ __m256i f16_exp_i16x16 = _mm256_and_si256(_mm256_srli_epi16(bits_i16x16, 10), _mm256_set1_epi16(0x1F));
168
+
169
+ // Round mantissa from 10 to 2 bits using RNE (round to nearest, ties to even)
170
+ __m256i significand_i16x16 = _mm256_or_si256(_mm256_and_si256(bits_i16x16, _mm256_set1_epi16(0x03FF)),
171
+ _mm256_set1_epi16(0x0400)); // Add implicit 1 bit
172
+ __m256i lsb_i16x16 = _mm256_and_si256(_mm256_srli_epi16(significand_i16x16, 8), _mm256_set1_epi16(1));
173
+ __m256i rounding_bias_i16x16 = _mm256_add_epi16(_mm256_set1_epi16(0x007F), lsb_i16x16);
174
+ __m256i rounded_sig_i16x16 = _mm256_add_epi16(significand_i16x16, rounding_bias_i16x16);
175
+ __m256i carry_i16x16 = _mm256_srli_epi16(rounded_sig_i16x16, 11); // Carry into exponent if bit 11 set
176
+ __m256i f16_mantissa_i16x16 = _mm256_and_si256(_mm256_srli_epi16(rounded_sig_i16x16, 8), _mm256_set1_epi16(0x03));
177
+ // If carry, mantissa becomes 0 (we rounded up to next power of 2)
178
+ f16_mantissa_i16x16 = _mm256_andnot_si256(_mm256_slli_epi16(carry_i16x16, 15), f16_mantissa_i16x16);
179
+ __m256i e5m2_exp_i16x16 = _mm256_add_epi16(f16_exp_i16x16, carry_i16x16);
180
+
181
+ // Detect subnormal (exp <= 0) and overflow (exp > 31)
182
+ __mmask16 is_subnormal = _mm256_cmpeq_epi16_mask(f16_exp_i16x16, _mm256_setzero_si256());
183
+ __mmask16 overflow = _mm256_cmpgt_epi16_mask(e5m2_exp_i16x16, _mm256_set1_epi16(31));
184
+
185
+ // Normal path: clamp exp to [1,31], on overflow return infinity
186
+ __m256i clamped_exp_i16x16 = _mm256_max_epi16(e5m2_exp_i16x16, _mm256_set1_epi16(1));
187
+ clamped_exp_i16x16 = _mm256_min_epi16(clamped_exp_i16x16, _mm256_set1_epi16(31));
188
+ __m256i normal_mantissa_i16x16 = _mm256_mask_blend_epi16(overflow, f16_mantissa_i16x16, _mm256_setzero_si256());
189
+ __m256i normal_e5m2_i16x16 = _mm256_or_si256(
190
+ _mm256_slli_epi16(sign_i16x16, 7),
191
+ _mm256_or_si256(_mm256_slli_epi16(clamped_exp_i16x16, 2), normal_mantissa_i16x16));
192
+
193
+ // Subnormal path: mantissa = round(abs_f16 * 65536)
194
+ __m256h abs_f16x16 = _mm256_castsi256_ph(_mm256_and_si256(_mm256_castph_si256(f16x16), _mm256_set1_epi16(0x7FFF)));
195
+ __m256h scaled_f16x16 = _mm256_mul_ph(abs_f16x16, _mm256_castsi256_ph(_mm256_set1_epi16(0x7C00))); // 65536 (inf)
196
+ __m256i subnorm_mantissa_i16x16 = _mm256_cvtph_epi16(scaled_f16x16);
197
+ __mmask16 promotes_to_normal = _mm256_cmpgt_epi16_mask(subnorm_mantissa_i16x16, _mm256_set1_epi16(3));
198
+ subnorm_mantissa_i16x16 = _mm256_min_epi16(subnorm_mantissa_i16x16, _mm256_set1_epi16(3));
199
+ subnorm_mantissa_i16x16 = _mm256_max_epi16(subnorm_mantissa_i16x16, _mm256_setzero_si256());
200
+ __m256i subnorm_e5m2_i16x16 = _mm256_or_si256(_mm256_slli_epi16(sign_i16x16, 7), subnorm_mantissa_i16x16);
201
+ __m256i first_normal_e5m2_i16x16 = _mm256_or_si256(_mm256_slli_epi16(sign_i16x16, 7), _mm256_set1_epi16(0x04));
202
+ subnorm_e5m2_i16x16 = _mm256_mask_blend_epi16(promotes_to_normal, subnorm_e5m2_i16x16, first_normal_e5m2_i16x16);
203
+
204
+ // Blend: use subnormal result when exp == 0
205
+ __m256i e5m2_i16x16 = _mm256_mask_blend_epi16(is_subnormal, normal_e5m2_i16x16, subnorm_e5m2_i16x16);
206
+
207
+ // Pack 16 i16s to 16 unsigned i8s via AVX-512BW
208
+ return _mm256_cvtepi16_epi8(e5m2_i16x16);
209
+ }
210
+
211
+ #pragma endregion - Vectorized Conversions
212
+
213
+ #pragma region - Public API
214
+
215
+ NK_PUBLIC void nk_cast_sapphire(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
216
+ // Group 1: Conversions to f16 (e4m3 → f16, e5m2 → f16)
217
+ if (to_type == nk_f16_k && (from_type == nk_e4m3_k || from_type == nk_e5m2_k)) {
218
+ nk_e4m3_t const *from_ptr = (nk_e4m3_t const *)from;
219
+ nk_f16_t *to_ptr = (nk_f16_t *)to;
220
+ for (nk_size_t idx = 0; idx < n; idx += 16) {
221
+ nk_size_t remaining = n - idx;
222
+ __mmask16 mask = (remaining >= 16) ? 0xFFFF : (unsigned short)_bzhi_u32(0xFFFF, (unsigned)remaining);
223
+ __m128i in_f8x16 = _mm_maskz_loadu_epi8(mask, from_ptr + idx);
224
+ __m256h out_f16x16 = (from_type == nk_e4m3_k) ? nk_e4m3x16_to_f16x16_sapphire_(in_f8x16)
225
+ : nk_e5m2x16_to_f16x16_sapphire_(in_f8x16);
226
+ _mm256_mask_storeu_epi16(to_ptr + idx, mask, _mm256_castph_si256(out_f16x16));
227
+ }
228
+ }
229
+
230
+ // Group 2: Conversions from f16 (f16 → e4m3, f16 → e5m2)
231
+ else if (from_type == nk_f16_k && (to_type == nk_e4m3_k || to_type == nk_e5m2_k)) {
232
+ nk_f16_t const *from_ptr = (nk_f16_t const *)from;
233
+ nk_e4m3_t *to_ptr = (nk_e4m3_t *)to;
234
+ for (nk_size_t idx = 0; idx < n; idx += 16) {
235
+ nk_size_t remaining = n - idx;
236
+ __mmask16 mask = (remaining >= 16) ? 0xFFFF : (unsigned short)_bzhi_u32(0xFFFF, (unsigned)remaining);
237
+ __m256h in_f16x16 = _mm256_castsi256_ph(_mm256_maskz_loadu_epi16(mask, from_ptr + idx));
238
+ __m128i out_f8x16 = (to_type == nk_e4m3_k) ? nk_f16x16_to_e4m3x16_sapphire_(in_f16x16)
239
+ : nk_f16x16_to_e5m2x16_sapphire_(in_f16x16);
240
+ _mm_mask_storeu_epi8(to_ptr + idx, mask, out_f8x16);
241
+ }
242
+ }
243
+
244
+ // Default: delegate to Ice for all other conversions
245
+ else nk_cast_icelake(from, from_type, n, to, to_type);
246
+ }
247
+
248
+ #pragma endregion - Public API
249
+
250
+ #if defined(__clang__)
251
+ #pragma clang attribute pop
252
+ #elif defined(__GNUC__)
253
+ #pragma GCC pop_options
254
+ #endif
255
+
256
+ #if defined(__cplusplus)
257
+ } // extern "C"
258
+ #endif
259
+
260
+ #endif // NK_TARGET_SAPPHIRE
261
+ #endif // NK_TARGET_X86_
262
+ #endif // NK_CAST_SAPPHIRE_H