numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,72 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for RISC-V with Zvbb.
3
+ * @file include/numkong/dot/rvvbb.h
4
+ * @author Ash Vardanian
5
+ * @date February 22, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * Zvbb (Vector Basic Bit-manipulation) provides native per-element popcount via `vcpop.v`,
10
+ * replacing the 11-instruction SWAR approach with a single instruction for u1 dot products.
11
+ *
12
+ * Only `nk_dot_u1` benefits from Zvbb (it needs byte-level popcount of AND results).
13
+ * Requires: RVV 1.0 + Zvbb extension (GCC 14+ or Clang 18+)
14
+ */
15
+ #ifndef NK_DOT_RVVBB_H
16
+ #define NK_DOT_RVVBB_H
17
+
18
+ #if NK_TARGET_RISCV_
19
+ #if NK_TARGET_RVVBB
20
+
21
+ #include "numkong/types.h"
22
+ #include "numkong/set/rvvbb.h" // `nk_popcount_u8m4_rvvbb_`
23
+
24
+ #if defined(__clang__)
25
+ #pragma clang attribute push(__attribute__((target("arch=+v,+zvbb"))), apply_to = function)
26
+ #elif defined(__GNUC__)
27
+ #pragma GCC push_options
28
+ #pragma GCC target("arch=+v,+zvbb")
29
+ #endif
30
+
31
+ #if defined(__cplusplus)
32
+ extern "C" {
33
+ #endif
34
+
35
+ NK_PUBLIC void nk_dot_u1_rvvbb(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
36
+ nk_size_t count_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
37
+
38
+ vuint32m1_t sum_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
39
+
40
+ nk_size_t i = 0;
41
+ for (nk_size_t vector_length; i + 1 <= count_bytes; i += vector_length) {
42
+ vector_length = __riscv_vsetvl_e8m4(count_bytes - i);
43
+
44
+ // Load and AND to find shared bits (dot product of binary vectors)
45
+ vuint8m4_t a_u8m4 = __riscv_vle8_v_u8m4(a + i, vector_length);
46
+ vuint8m4_t b_u8m4 = __riscv_vle8_v_u8m4(b + i, vector_length);
47
+ vuint8m4_t and_u8m4 = __riscv_vand_vv_u8m4(a_u8m4, b_u8m4, vector_length);
48
+
49
+ // Native per-element popcount via Zvbb (1 instruction vs 11 SWAR)
50
+ vuint8m4_t popcount_u8m4 = nk_popcount_u8m4_rvvbb_(and_u8m4);
51
+
52
+ // Widen to u16 and accumulate via widening reduction sum
53
+ vuint16m8_t popcount_u16m8 = __riscv_vwaddu_vx_u16m8(popcount_u8m4, 0, vector_length);
54
+ sum_u32m1 = __riscv_vwredsumu_vs_u16m8_u32m1(popcount_u16m8, sum_u32m1, vector_length);
55
+ }
56
+
57
+ *result = __riscv_vmv_x_s_u32m1_u32(sum_u32m1);
58
+ }
59
+
60
+ #if defined(__cplusplus)
61
+ } // extern "C"
62
+ #endif
63
+
64
+ #if defined(__clang__)
65
+ #pragma clang attribute pop
66
+ #elif defined(__GNUC__)
67
+ #pragma GCC pop_options
68
+ #endif
69
+
70
+ #endif // NK_TARGET_RVVBB
71
+ #endif // NK_TARGET_RISCV_
72
+ #endif // NK_DOT_RVVBB_H
@@ -0,0 +1,123 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for RISC-V BF16.
3
+ * @file include/numkong/dot/rvvbf16.h
4
+ * @author Ash Vardanian
5
+ * @date January 5, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * Alibaba XuanTie C930 and similar chips implement RVV 1.0 with Zvfbfwma extension.
10
+ * Zvfbfwma provides widening bf16 fused multiply-accumulate to f32:
11
+ * vfwmaccbf16: f32 += bf16 ⨯ bf16 (widening fused multiply-accumulate)
12
+ *
13
+ * All mini-float types use 256-entry VLUXEI16 LUT gathers from cast/rvv.h (3 instructions each).
14
+ * All variants then use vfwmaccbf16_vv for fused bf16 ⨯ bf16 → f32 multiply-accumulate.
15
+ *
16
+ * Requires: RVV 1.0 + Zvfbfwma extension (GCC 14+ or Clang 18+)
17
+ */
18
+ #ifndef NK_DOT_RVVBF16_H
19
+ #define NK_DOT_RVVBF16_H
20
+
21
+ #if NK_TARGET_RISCV_
22
+ #if NK_TARGET_RVVBF16
23
+
24
+ #include "numkong/types.h"
25
+ #include "numkong/cast/rvv.h" // `nk_e4m3m1_to_bf16m2_rvv_`, `nk_e5m2m1_to_bf16m2_rvv_`, etc.
26
+
27
+ #if defined(__clang__)
28
+ #pragma clang attribute push(__attribute__((target("arch=+v,+zvfbfwma"))), apply_to = function)
29
+ #elif defined(__GNUC__)
30
+ #pragma GCC push_options
31
+ #pragma GCC target("arch=+v,+zvfbfwma")
32
+ #endif
33
+
34
+ #if defined(__cplusplus)
35
+ extern "C" {
36
+ #endif
37
+
38
+ NK_PUBLIC void nk_dot_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
39
+ nk_f32_t *result) {
40
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
41
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
42
+ for (nk_size_t vector_length; count_scalars > 0;
43
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
44
+ vector_length = __riscv_vsetvl_e16m1(count_scalars);
45
+ vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)a_scalars, vector_length);
46
+ vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)b_scalars, vector_length);
47
+ vbfloat16m1_t a_bf16m1 = __riscv_vreinterpret_v_u16m1_bf16m1(a_u16m1);
48
+ vbfloat16m1_t b_bf16m1 = __riscv_vreinterpret_v_u16m1_bf16m1(b_u16m1);
49
+ // Widening bf16 FMA: f32 ← bf16 ⨯ bf16, per-lane accumulation
50
+ sum_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(sum_f32m2, a_bf16m1, b_bf16m1, vector_length);
51
+ }
52
+ // Single horizontal reduction at the end
53
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
54
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
55
+ }
56
+
57
+ /** @brief Convert e2m3 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
58
+ NK_INTERNAL vbfloat16m2_t nk_e2m3m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
59
+ return __riscv_vreinterpret_v_u16m2_bf16m2(nk_e2m3m1_to_bf16m2_rvv_(raw_u8m1, vector_length));
60
+ }
61
+
62
+ /** @brief Convert e3m2 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
63
+ NK_INTERNAL vbfloat16m2_t nk_e3m2m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
64
+ return __riscv_vreinterpret_v_u16m2_bf16m2(nk_e3m2m1_to_bf16m2_rvv_(raw_u8m1, vector_length));
65
+ }
66
+
67
+ /** @brief Convert e4m3 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
68
+ NK_INTERNAL vbfloat16m2_t nk_e4m3m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
69
+ return __riscv_vreinterpret_v_u16m2_bf16m2(nk_e4m3m1_to_bf16m2_rvv_(raw_u8m1, vector_length));
70
+ }
71
+
72
+ /** @brief Convert e5m2 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
73
+ NK_INTERNAL vbfloat16m2_t nk_e5m2m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
74
+ return __riscv_vreinterpret_v_u16m2_bf16m2(nk_e5m2m1_to_bf16m2_rvv_(raw_u8m1, vector_length));
75
+ }
76
+
77
+ NK_PUBLIC void nk_dot_e4m3_rvvbf16(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
78
+ nk_f32_t *result) {
79
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
80
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
81
+ for (nk_size_t vector_length; count_scalars > 0;
82
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
83
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
84
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
85
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
86
+ vbfloat16m2_t a_bf16m2 = nk_e4m3m1_to_bf16m2_rvvbf16_(a_u8m1, vector_length);
87
+ vbfloat16m2_t b_bf16m2 = nk_e4m3m1_to_bf16m2_rvvbf16_(b_u8m1, vector_length);
88
+ sum_f32m4 = __riscv_vfwmaccbf16_vv_f32m4_tu(sum_f32m4, a_bf16m2, b_bf16m2, vector_length);
89
+ }
90
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
91
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
92
+ }
93
+
94
+ NK_PUBLIC void nk_dot_e5m2_rvvbf16(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
95
+ nk_f32_t *result) {
96
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
97
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
98
+ for (nk_size_t vector_length; count_scalars > 0;
99
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
100
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
101
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
102
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
103
+ vbfloat16m2_t a_bf16m2 = nk_e5m2m1_to_bf16m2_rvvbf16_(a_u8m1, vector_length);
104
+ vbfloat16m2_t b_bf16m2 = nk_e5m2m1_to_bf16m2_rvvbf16_(b_u8m1, vector_length);
105
+ sum_f32m4 = __riscv_vfwmaccbf16_vv_f32m4_tu(sum_f32m4, a_bf16m2, b_bf16m2, vector_length);
106
+ }
107
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
108
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
109
+ }
110
+
111
+ #if defined(__cplusplus)
112
+ } // extern "C"
113
+ #endif
114
+
115
+ #if defined(__clang__)
116
+ #pragma clang attribute pop
117
+ #elif defined(__GNUC__)
118
+ #pragma GCC pop_options
119
+ #endif
120
+
121
+ #endif // NK_TARGET_RVVBF16
122
+ #endif // NK_TARGET_RISCV_
123
+ #endif // NK_DOT_RVVBF16_H
@@ -0,0 +1,129 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for RISC-V FP16.
3
+ * @file include/numkong/dot/rvvhalf.h
4
+ * @author Ash Vardanian
5
+ * @date January 5, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * SiFive P670/X280 and similar chips implement RVV 1.0 with Zvfh extension.
10
+ * Zvfh provides native half-precision (f16) vector operations.
11
+ * Products are formed with widening multiply-accumulate (f16 ⨯ f16 → f32) to preserve precision, and the f32 lanes are then reduced to a scalar.
12
+ *
13
+ * For e2m3, e3m2, e4m3: conversion uses 256-entry VLUXEI16 LUT gathers from cast/rvv.h (3 instructions each).
14
+ * For e5m2: conversion uses pure shift (vzext + vsll) since e5m2 and f16 share the same exponent bias.
15
+ * All variants then use vfwmacc_vv for widening fused f16 ⨯ f16 → f32 multiply-accumulate.
16
+ *
17
+ * Requires: RVV 1.0 + Zvfh extension (GCC 14+ or Clang 18+)
18
+ */
19
+ #ifndef NK_DOT_RVVHALF_H
20
+ #define NK_DOT_RVVHALF_H
21
+
22
+ #if NK_TARGET_RISCV_
23
+ #if NK_TARGET_RVVHALF
24
+
25
+ #include "numkong/types.h"
26
+ #include "numkong/cast/rvv.h" // `nk_e4m3m1_to_f16m2_rvv_`, `nk_e2m3m1_to_f16m2_rvv_`, etc.
27
+
28
+ #if defined(__clang__)
29
+ #pragma clang attribute push(__attribute__((target("arch=+v,+zvfh"))), apply_to = function)
30
+ #elif defined(__GNUC__)
31
+ #pragma GCC push_options
32
+ #pragma GCC target("arch=+v,+zvfh")
33
+ #endif
34
+
35
+ #if defined(__cplusplus)
36
+ extern "C" {
37
+ #endif
38
+
39
+ NK_PUBLIC void nk_dot_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
40
+ nk_f32_t *result) {
41
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
42
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
43
+ for (nk_size_t vector_length; count_scalars > 0;
44
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
45
+ vector_length = __riscv_vsetvl_e16m1(count_scalars);
46
+ vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)a_scalars, vector_length);
47
+ vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((unsigned short const *)b_scalars, vector_length);
48
+ vfloat16m1_t a_f16m1 = __riscv_vreinterpret_v_u16m1_f16m1(a_u16m1);
49
+ vfloat16m1_t b_f16m1 = __riscv_vreinterpret_v_u16m1_f16m1(b_u16m1);
50
+ // Widening FMA: f32 += f16 ⨯ f16, per-lane accumulation
51
+ sum_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(sum_f32m2, a_f16m1, b_f16m1, vector_length);
52
+ }
53
+ // Single horizontal reduction at the end
54
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
55
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
56
+ }
57
+
58
+ /** @brief Convert e2m3 to f16 via 256-entry LUT in cast/rvv.h + reinterpret. */
59
+ NK_INTERNAL vfloat16m2_t nk_e2m3m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
60
+ return __riscv_vreinterpret_v_u16m2_f16m2(nk_e2m3m1_to_f16m2_rvv_(raw_u8m1, vector_length));
61
+ }
62
+
63
+ /** @brief Convert e3m2 to f16 via 256-entry LUT in cast/rvv.h + reinterpret. */
64
+ NK_INTERNAL vfloat16m2_t nk_e3m2m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
65
+ return __riscv_vreinterpret_v_u16m2_f16m2(nk_e3m2m1_to_f16m2_rvv_(raw_u8m1, vector_length));
66
+ }
67
+
68
+ /** @brief Convert e4m3 to f16 via 256-entry LUT in cast/rvv.h + reinterpret. */
69
+ NK_INTERNAL vfloat16m2_t nk_e4m3m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
70
+ return __riscv_vreinterpret_v_u16m2_f16m2(nk_e4m3m1_to_f16m2_rvv_(raw_u8m1, vector_length));
71
+ }
72
+
73
+ /**
74
+ * @brief Convert e5m2 (1-5-2 sign-exp-mantissa, 8-bit) to f16 via pure shift (no LUT).
75
+ * Same exponent bias (15) means f16 = (lower7 << 8) | (sign << 15). Handles all cases.
76
+ */
77
+ NK_INTERNAL vfloat16m2_t nk_e5m2m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_size_t vector_length) {
78
+ vuint16m2_t wide_u16m2 = __riscv_vzext_vf2_u16m2(raw_u8m1, vector_length);
79
+ vuint16m2_t result_u16m2 = __riscv_vsll_vx_u16m2(wide_u16m2, 8, vector_length);
80
+ return __riscv_vreinterpret_v_u16m2_f16m2(result_u16m2);
81
+ }
82
+
83
+ NK_PUBLIC void nk_dot_e4m3_rvvhalf(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
84
+ nk_f32_t *result) {
85
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
86
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
87
+ for (nk_size_t vector_length; count_scalars > 0;
88
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
89
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
90
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
91
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
92
+ vfloat16m2_t a_f16m2 = nk_e4m3m1_to_f16m2_rvvhalf_(a_u8m1, vector_length);
93
+ vfloat16m2_t b_f16m2 = nk_e4m3m1_to_f16m2_rvvhalf_(b_u8m1, vector_length);
94
+ sum_f32m4 = __riscv_vfwmacc_vv_f32m4_tu(sum_f32m4, a_f16m2, b_f16m2, vector_length);
95
+ }
96
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
97
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
98
+ }
99
+
100
+ NK_PUBLIC void nk_dot_e5m2_rvvhalf(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
101
+ nk_f32_t *result) {
102
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
103
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
104
+ for (nk_size_t vector_length; count_scalars > 0;
105
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
106
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
107
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
108
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
109
+ vfloat16m2_t a_f16m2 = nk_e5m2m1_to_f16m2_rvvhalf_(a_u8m1, vector_length);
110
+ vfloat16m2_t b_f16m2 = nk_e5m2m1_to_f16m2_rvvhalf_(b_u8m1, vector_length);
111
+ sum_f32m4 = __riscv_vfwmacc_vv_f32m4_tu(sum_f32m4, a_f16m2, b_f16m2, vector_length);
112
+ }
113
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
114
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
115
+ }
116
+
117
+ #if defined(__cplusplus)
118
+ } // extern "C"
119
+ #endif
120
+
121
+ #if defined(__clang__)
122
+ #pragma clang attribute pop
123
+ #elif defined(__GNUC__)
124
+ #pragma GCC pop_options
125
+ #endif
126
+
127
+ #endif // NK_TARGET_RVVHALF
128
+ #endif // NK_TARGET_RISCV_
129
+ #endif // NK_DOT_RVVHALF_H
@@ -0,0 +1,141 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for Sapphire Rapids.
3
+ * @file include/numkong/dot/sapphire.h
4
+ * @author Ash Vardanian
5
+ * @date February 7, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_sapphire_instructions Key AVX-512 FP16 Instructions
10
+ *
11
+ * Intrinsic Instruction Latency Throughput Ports
12
+ * _mm512_fmadd_ph VFMADD132PH (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
13
+ * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
14
+ * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 7cy 1/cy p01
15
+ *
16
+ * Sapphire Rapids introduces native AVX-512 FP16 support, enabling 32 FP16 FMAs per instruction at the same
17
+ * throughput as 16 FP32 FMAs — effectively 2x compute density. For FP6 types (E2M3 and E3M2) whose products
18
+ * are small enough to accumulate safely in FP16, this provides near-2x speedup over the Genoa BF16 path.
19
+ *
20
+ * @section dot_sapphire_accumulation Safe FP16 Accumulation
21
+ *
22
+ * E2M3 max product: 7.5² = 56.25; flush every 4 iterations → max lane sum ~225, FP16 ULP ~0.125.
23
+ * E3M2 max product: 28² = 784; flush every 4 iterations → max lane sum ~3136, FP16 ULP ~2.0.
24
+ * After the flush window, we widen the FP16 accumulator to FP32 and reset.
25
+ *
26
+ * @section dot_sapphire_stateful Stateful Streaming Logic
27
+ *
28
+ * Typed wrappers control the flush cadence:
29
+ * - nk_dot_e2m3x32_state_sapphire_t flushes every 4 iterations (128 elements)
30
+ * - nk_dot_e3m2x32_state_sapphire_t flushes every 4 iterations (128 elements)
31
+ */
32
+ #ifndef NK_DOT_SAPPHIRE_H
33
+ #define NK_DOT_SAPPHIRE_H
34
+
35
+ #if NK_TARGET_X86_
36
+ #if NK_TARGET_SAPPHIRE
37
+
38
+ #include "numkong/types.h"
39
+ #include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
40
+ #include "numkong/dot/skylake.h" // `nk_dot_through_f32_finalize_skylake_`
41
+
42
+ #if defined(__cplusplus)
43
+ extern "C" {
44
+ #endif
45
+
46
+ #if defined(__clang__)
47
+ #pragma clang attribute push( \
48
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,f16c,fma,bmi,bmi2"))), \
49
+ apply_to = function)
50
+ #elif defined(__GNUC__)
51
+ #pragma GCC push_options
52
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
53
+ #endif
54
+
55
+ /** @brief Convert 32x e2m3 → 32x f16 via 64-entry signed LUT lookup (AVX-512BW).
56
+ * E2M3 format: S EE MMM (bias=1, 6 bits total: sign at bit 5, magnitude bits 4-0).
57
+ * F16: S EEEEE MMMMMMMMMM (bias=15).
58
+ *
59
+ * Uses permutex2var with two 32-entry LUTs (positive and negative F16 values).
60
+ * The E2M3 sign bit (bit 5) naturally becomes the source-select bit of the 6-bit index,
61
+ * so no separate sign extraction, shift, or OR is needed. After cvtepu8_epi16, bits 15:6
62
+ * are zero and permutex2var only reads bits 5:0, so no AND mask is required either. */
63
+ NK_INTERNAL __m512h nk_e2m3x32_to_f16x32_sapphire_(__m256i e2m3x32) {
64
+ __m512i idx_i16x32 = _mm512_cvtepu8_epi16(e2m3x32);
65
+
66
+ // 32-entry LUT for positive E2M3 magnitudes → F16
67
+ __m512i const lut_pos_i16x32 = _mm512_set_epi16( //
68
+ 0x4780, 0x4700, 0x4680, 0x4600, 0x4580, 0x4500, 0x4480, 0x4400, // [31-24] exp=3: f16_exp=17
69
+ 0x4380, 0x4300, 0x4280, 0x4200, 0x4180, 0x4100, 0x4080, 0x4000, // [23-16] exp=2: f16_exp=16
70
+ 0x3F80, 0x3F00, 0x3E80, 0x3E00, 0x3D80, 0x3D00, 0x3C80, 0x3C00, // [15-8] exp=1: f16_exp=15
71
+ 0x3B00, 0x3A00, 0x3900, 0x3800, 0x3600, 0x3400, 0x3000, 0x0000); // [7-0] exp=0: subnormals (0, 1/8..7/8)
72
+
73
+ // 32-entry LUT for negative E2M3 magnitudes → F16 (= positive | 0x8000)
74
+ __m512i const lut_neg_i16x32 = _mm512_set_epi16( //
75
+ (short)0xC780, (short)0xC700, (short)0xC680, (short)0xC600, //
76
+ (short)0xC580, (short)0xC500, (short)0xC480, (short)0xC400, // [31-24] exp=3
77
+ (short)0xC380, (short)0xC300, (short)0xC280, (short)0xC200, //
78
+ (short)0xC180, (short)0xC100, (short)0xC080, (short)0xC000, // [23-16] exp=2
79
+ (short)0xBF80, (short)0xBF00, (short)0xBE80, (short)0xBE00, //
80
+ (short)0xBD80, (short)0xBD00, (short)0xBC80, (short)0xBC00, // [15-8] exp=1
81
+ (short)0xBB00, (short)0xBA00, (short)0xB900, (short)0xB800, //
82
+ (short)0xB600, (short)0xB400, (short)0xB000, (short)0x8000); // [7-0] exp=0
83
+
84
+ return nk_m512h_from_m512i_(_mm512_permutex2var_epi16(lut_pos_i16x32, idx_i16x32, lut_neg_i16x32));
85
+ }
86
+
87
+ /** @brief Convert 32x e3m2 → 32x f16 via 64-entry signed LUT lookup (AVX-512BW).
88
+ * E3M2 format: S EEE MM (bias=3, 6 bits total: sign at bit 5, magnitude bits 4-0).
89
+ * F16: S EEEEE MMMMMMMMMM (bias=15).
90
+ *
91
+ * Same permutex2var technique as E2M3 — sign bit 5 selects the LUT source. */
92
+ NK_INTERNAL __m512h nk_e3m2x32_to_f16x32_sapphire_(__m256i e3m2x32) {
93
+ __m512i idx_i16x32 = _mm512_cvtepu8_epi16(e3m2x32);
94
+
95
+ // 32-entry LUT for positive E3M2 magnitudes → F16
96
+ __m512i const lut_pos_i16x32 = _mm512_set_epi16( //
97
+ 0x4F00, 0x4E00, 0x4D00, 0x4C00, // [31-28] exp=7: f16_exp=19
98
+ 0x4B00, 0x4A00, 0x4900, 0x4800, // [27-24] exp=6: f16_exp=18
99
+ 0x4700, 0x4600, 0x4500, 0x4400, // [23-20] exp=5: f16_exp=17
100
+ 0x4300, 0x4200, 0x4100, 0x4000, // [19-16] exp=4: f16_exp=16
101
+ 0x3F00, 0x3E00, 0x3D00, 0x3C00, // [15-12] exp=3: f16_exp=15
102
+ 0x3B00, 0x3A00, 0x3900, 0x3800, // [11-8] exp=2: f16_exp=14
103
+ 0x3700, 0x3600, 0x3500, 0x3400, // [7-4] exp=1: f16_exp=13
104
+ 0x3200, 0x3000, 0x2C00, 0x0000); // [3-0] exp=0: subnormals
105
+
106
+ // 32-entry LUT for negative E3M2 magnitudes → F16 (= positive | 0x8000)
107
+ __m512i const lut_neg_i16x32 = _mm512_set_epi16( //
108
+ (short)0xCF00, (short)0xCE00, (short)0xCD00, (short)0xCC00, // [31-28] exp=7
109
+ (short)0xCB00, (short)0xCA00, (short)0xC900, (short)0xC800, // [27-24] exp=6
110
+ (short)0xC700, (short)0xC600, (short)0xC500, (short)0xC400, // [23-20] exp=5
111
+ (short)0xC300, (short)0xC200, (short)0xC100, (short)0xC000, // [19-16] exp=4
112
+ (short)0xBF00, (short)0xBE00, (short)0xBD00, (short)0xBC00, // [15-12] exp=3
113
+ (short)0xBB00, (short)0xBA00, (short)0xB900, (short)0xB800, // [11-8] exp=2
114
+ (short)0xB700, (short)0xB600, (short)0xB500, (short)0xB400, // [7-4] exp=1
115
+ (short)0xB200, (short)0xB000, (short)0xAC00, (short)0x8000); // [3-0] exp=0
116
+
117
+ return nk_m512h_from_m512i_(_mm512_permutex2var_epi16(lut_pos_i16x32, idx_i16x32, lut_neg_i16x32));
118
+ }
119
+
120
+ /** @brief Flush 32 FP16 values to FP32 accumulator by splitting into 2x16 halves. */
121
+ NK_INTERNAL __m512 nk_flush_f16_to_f32_sapphire_(__m512h acc_f16x32, __m512 sum_f32x16) {
122
+ __m256i low_f16x16 = _mm512_castsi512_si256(nk_m512i_from_m512h_(acc_f16x32));
123
+ __m256i high_f16x16 = _mm512_extracti64x4_epi64(nk_m512i_from_m512h_(acc_f16x32), 1);
124
+ sum_f32x16 = _mm512_add_ps(sum_f32x16, _mm512_cvtph_ps(low_f16x16));
125
+ sum_f32x16 = _mm512_add_ps(sum_f32x16, _mm512_cvtph_ps(high_f16x16));
126
+ return sum_f32x16;
127
+ }
128
+
129
+ #if defined(__clang__)
130
+ #pragma clang attribute pop
131
+ #elif defined(__GNUC__)
132
+ #pragma GCC pop_options
133
+ #endif
134
+
135
+ #if defined(__cplusplus)
136
+ } // extern "C"
137
+ #endif
138
+
139
+ #endif // NK_TARGET_SAPPHIRE
140
+ #endif // NK_TARGET_X86_
141
+ #endif // NK_DOT_SAPPHIRE_H