numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,477 @@
1
+ /**
2
+ * @brief SIMD-accelerated Elementwise Arithmetic for Sapphire Rapids.
3
+ * @file include/numkong/each/sapphire.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/each.h
8
+ *
9
+ * @section sapphire_elementwise_instructions Relevant Instructions
10
+ *
11
+ * Intrinsic Instruction Sapphire Genoa
12
+ * _mm512_add_ph VADDPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
13
+ * _mm512_mul_ph VMULPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
14
+ * _mm512_fmadd_ph VFMADD (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
15
+ * _mm512_cvtepi16_ph VCVTW2PH (ZMM, ZMM) 4cy @ p05 4cy @ p01
16
+ * _mm512_cvtph_epi16 VCVTPH2W (ZMM, ZMM) 4cy @ p05 4cy @ p01
17
+ * _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 3cy @ p12
18
+ * _mm512_cvtsepi16_epi8 VPMOVSWB (YMM, ZMM) 4cy @ p5 4cy @ p12
19
+ * _mm512_packus_epi16 VPACKUSWB (ZMM, ZMM, ZMM) 1cy @ p5 1cy @ p12
20
+ * _mm256_add_ph VADDPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
21
+ * _mm512_maskz_loadu_epi16 VMOVDQU16 (ZMM {K}, M512) 7cy @ p23 7cy @ p23
22
+ * _mm512_mask_storeu_epi16 VMOVDQU16 (M512 {K}, ZMM) 4cy @ p4 4cy @ p4
23
+ */
24
+ #ifndef NK_EACH_SAPPHIRE_H
25
+ #define NK_EACH_SAPPHIRE_H
26
+
27
+ #if NK_TARGET_X86_
28
+ #if NK_TARGET_SAPPHIRE
29
+
30
+ #include "numkong/types.h"
31
+ #include "numkong/cast/sapphire.h" // `nk_f32_to_f16_sapphire`
32
+
33
+ #if defined(__cplusplus)
34
+ extern "C" {
35
+ #endif
36
+
37
+ #if defined(__clang__)
38
+ #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512fp16,f16c,fma,bmi,bmi2"))), \
39
+ apply_to = function)
40
+ #elif defined(__GNUC__)
41
+ #pragma GCC push_options
42
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
43
+ #endif
44
+
45
+ NK_PUBLIC void nk_each_sum_f16_sapphire(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result) {
46
+ __mmask32 mask = 0xFFFFFFFF;
47
+ __m512h a_f16_vec, b_f16_vec;
48
+ __m512h sum_f16_vec;
49
+ nk_each_sum_f16_sapphire_cycle:
50
+ if (n < 32) {
51
+ mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
52
+ a_f16_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a));
53
+ b_f16_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b));
54
+ n = 0;
55
+ }
56
+ else {
57
+ a_f16_vec = _mm512_loadu_ph(a);
58
+ b_f16_vec = _mm512_loadu_ph(b);
59
+ a += 32, b += 32, n -= 32;
60
+ }
61
+ sum_f16_vec = _mm512_add_ph(a_f16_vec, b_f16_vec);
62
+ _mm512_mask_storeu_epi16(result, mask, _mm512_castph_si512(sum_f16_vec));
63
+ result += 32;
64
+ if (n) goto nk_each_sum_f16_sapphire_cycle;
65
+ }
66
+
67
/**
 *  @brief Affine scaling of a `u8` vector: result[i] = alpha * a[i] + beta,
 *  computed in half precision and saturated back to the [0, 255] range.
 *
 *  Processes 64 bytes per iteration; the final partial chunk uses masked
 *  load/store, so no scalar tail is needed.
 */
NK_PUBLIC void nk_each_scale_u8_sapphire(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_u8_t *result) {
    // Convert the `f32` weights to `f16` bit patterns once, then broadcast.
    short alpha_short, beta_short;
    nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
    nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
    __mmask64 mask = 0xFFFFFFFFFFFFFFFFull;
    __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
    __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
    __m512i a_u8x64, result_u8x64;
    __m512h a_low_f16x32, a_high_f16x32;
    __m512h result_low_f16x32, result_high_f16x32;
    __m512i result_low_i16x32, result_high_i16x32;
nk_each_scale_u8_sapphire_cycle:
    if (n < 64) {
        // Tail (runs at most once): mask covers the `n` remaining bytes.
        mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
        a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
        n = 0;
    }
    else {
        a_u8x64 = _mm512_loadu_epi8(a);
        a += 64, n -= 64;
    }
    // Upcast: interleaving with zero yields zero-extended 16-bit values.
    // Note `unpacklo`/`unpackhi` interleave within each 128-bit lane; the
    // matching lane-wise `packus` below restores the original byte order.
    a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(a_u8x64, _mm512_setzero_si512()));
    a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(a_u8x64, _mm512_setzero_si512()));
    // Scale: alpha * x + beta in a single FMA.
    result_low_f16x32 = _mm512_fmadd_ph(a_low_f16x32, alpha_f16x32, beta_f16x32);
    result_high_f16x32 = _mm512_fmadd_ph(a_high_f16x32, alpha_f16x32, beta_f16x32);
    // Downcast: f16 -> i16, then `packus` saturates each i16 to [0, 255].
    result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
    result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
    result_u8x64 = _mm512_packus_epi16(result_low_i16x32, result_high_i16x32);
    _mm512_mask_storeu_epi8(result, mask, result_u8x64);
    result += 64;
    if (n) goto nk_each_scale_u8_sapphire_cycle;
}
103
+
104
/**
 *  @brief Weighted blend of two `u8` vectors: result[i] = alpha * a[i] + beta * b[i],
 *  computed in half precision and saturated back to the [0, 255] range.
 *
 *  Trivial weight combinations are dispatched to cheaper kernels before
 *  entering the SIMD loop.
 */
NK_PUBLIC void nk_each_blend_u8_sapphire( //
    nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {

    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;

    // There are several special cases we may want to implement:
    // 1. Simple addition, when both weights are equal to 1.0.
    if (alpha_val == 1 && beta_val == 1) {
        // In this case we can avoid expensive multiplications.
        nk_each_sum_u8_icelake(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is equal to zero.
    else if (alpha_val == 0 || beta_val == 0) {
        // In this case we can avoid half of the load instructions.
        // (When both weights are zero, the first branch scales `a` by 0.)
        nk_f32_t zero = 0;
        if (beta_val == 0) { nk_each_scale_u8_sapphire(a, n, alpha, &zero, result); }
        else { nk_each_scale_u8_sapphire(b, n, beta, &zero, result); }
        return;
    }

    // The general case: convert weights to `f16` once and broadcast them.
    short alpha_short, beta_short;
    nk_f32_to_f16_sapphire(&alpha_val, (nk_f16_t *)&alpha_short);
    nk_f32_to_f16_sapphire(&beta_val, (nk_f16_t *)&beta_short);
    __mmask64 mask = 0xFFFFFFFFFFFFFFFFull;
    __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
    __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
    __m512i a_u8x64, b_u8x64, result_u8x64;
    __m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
    __m512h a_scaled_low_f16x32, a_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
    __m512i result_low_i16x32, result_high_i16x32;
nk_each_blend_u8_sapphire_cycle:
    if (n < 64) {
        // Tail (runs at most once): mask covers the `n` remaining bytes.
        mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
        a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
        b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_u8x64 = _mm512_loadu_epi8(a);
        b_u8x64 = _mm512_loadu_epi8(b);
        a += 64, b += 64, n -= 64;
    }
    // Upcast: interleaving with zero yields zero-extended 16-bit values.
    // `unpacklo`/`unpackhi` interleave within each 128-bit lane; the matching
    // lane-wise `packus` below restores the original byte order.
    a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(a_u8x64, _mm512_setzero_si512()));
    a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(a_u8x64, _mm512_setzero_si512()));
    b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(b_u8x64, _mm512_setzero_si512()));
    b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(b_u8x64, _mm512_setzero_si512()));
    // Scale:
    a_scaled_low_f16x32 = _mm512_mul_ph(a_low_f16x32, alpha_f16x32);
    a_scaled_high_f16x32 = _mm512_mul_ph(a_high_f16x32, alpha_f16x32);
    // Add: beta * b + (alpha * a) via FMA.
    result_low_f16x32 = _mm512_fmadd_ph(b_low_f16x32, beta_f16x32, a_scaled_low_f16x32);
    result_high_f16x32 = _mm512_fmadd_ph(b_high_f16x32, beta_f16x32, a_scaled_high_f16x32);
    // Downcast: f16 -> i16, then `packus` saturates each i16 to [0, 255].
    result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
    result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
    result_u8x64 = _mm512_packus_epi16(result_low_i16x32, result_high_i16x32);
    _mm512_mask_storeu_epi8(result, mask, result_u8x64);
    result += 64;
    if (n) goto nk_each_blend_u8_sapphire_cycle;
}
169
+
170
+ NK_PUBLIC void nk_each_scale_i8_sapphire(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
171
+ nk_i8_t *result) {
172
+ short alpha_short, beta_short;
173
+ nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
174
+ nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
175
+ __mmask64 mask = 0xFFFFFFFFFFFFFFFFull;
176
+ __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
177
+ __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
178
+ __m256i a_low_i8x32, a_high_i8x32;
179
+ __m512i result_i8x64;
180
+ __m512h a_low_f16x32, a_high_f16x32;
181
+ __m512h result_low_f16x32, result_high_f16x32;
182
+ __m512i result_low_i16x32, result_high_i16x32;
183
+ nk_each_scale_i8_sapphire_cycle:
184
+ if (n < 64) {
185
+ // Tail: use masked 512-bit load and extract (runs once)
186
+ mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
187
+ __m512i a_i8x64 = _mm512_maskz_loadu_epi8(mask, a);
188
+ a_low_i8x32 = _mm512_castsi512_si256(a_i8x64);
189
+ a_high_i8x32 = _mm512_extracti64x4_epi64(a_i8x64, 1);
190
+ n = 0;
191
+ }
192
+ else {
193
+ // Hot path: 2×256-bit loads to avoid VEXTRACTI64X4 (Port 5)
194
+ a_low_i8x32 = _mm256_loadu_epi8(a);
195
+ a_high_i8x32 = _mm256_loadu_epi8(a + 32);
196
+ a += 64, n -= 64;
197
+ }
198
+ // Upcast from 256-bit halves:
199
+ a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_low_i8x32));
200
+ a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_high_i8x32));
201
+ // Scale:
202
+ result_low_f16x32 = _mm512_fmadd_ph(a_low_f16x32, alpha_f16x32, beta_f16x32);
203
+ result_high_f16x32 = _mm512_fmadd_ph(a_high_f16x32, alpha_f16x32, beta_f16x32);
204
+ // Downcast:
205
+ result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
206
+ result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
207
+ result_i8x64 = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(result_low_i16x32)),
208
+ _mm512_cvtsepi16_epi8(result_high_i16x32), 1);
209
+ _mm512_mask_storeu_epi8(result, mask, result_i8x64);
210
+ result += 64;
211
+ if (n) goto nk_each_scale_i8_sapphire_cycle;
212
+ }
213
+
214
/**
 *  @brief Weighted blend of two `i8` vectors: result[i] = alpha * a[i] + beta * b[i],
 *  computed in half precision and saturated back to the [-128, 127] range.
 *
 *  Trivial weight combinations are dispatched to cheaper kernels before
 *  entering the SIMD loop.
 */
NK_PUBLIC void nk_each_blend_i8_sapphire( //
    nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {

    nk_f32_t alpha_val = *alpha;
    nk_f32_t beta_val = *beta;

    // There are several special cases we may want to implement:
    // 1. Simple addition, when both weights are equal to 1.0.
    if (alpha_val == 1 && beta_val == 1) {
        // In this case we can avoid expensive multiplications.
        nk_each_sum_i8_icelake(a, b, n, result);
        return;
    }
    // 2. Just scaling, when one of the weights is equal to zero.
    else if (alpha_val == 0 || beta_val == 0) {
        // In this case we can avoid half of the load instructions.
        // (When both weights are zero, the first branch scales `a` by 0.)
        nk_f32_t zero = 0;
        if (beta_val == 0) { nk_each_scale_i8_sapphire(a, n, alpha, &zero, result); }
        else { nk_each_scale_i8_sapphire(b, n, beta, &zero, result); }
        return;
    }

    // The general case: convert weights to `f16` once and broadcast them.
    short alpha_short, beta_short;
    nk_f32_to_f16_sapphire(&alpha_val, (nk_f16_t *)&alpha_short);
    nk_f32_to_f16_sapphire(&beta_val, (nk_f16_t *)&beta_short);
    __mmask64 mask = 0xFFFFFFFFFFFFFFFFull;
    __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
    __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
    __m256i a_low_i8x32, a_high_i8x32, b_low_i8x32, b_high_i8x32;
    __m512i result_i8x64;
    __m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
    __m512h a_scaled_low_f16x32, a_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
    __m512i result_low_i16x32, result_high_i16x32;
nk_each_blend_i8_sapphire_cycle:
    if (n < 64) {
        // Tail: use masked 512-bit loads and extract (runs once)
        mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
        __m512i a_i8x64 = _mm512_maskz_loadu_epi8(mask, a);
        __m512i b_i8x64 = _mm512_maskz_loadu_epi8(mask, b);
        a_low_i8x32 = _mm512_castsi512_si256(a_i8x64);
        a_high_i8x32 = _mm512_extracti64x4_epi64(a_i8x64, 1);
        b_low_i8x32 = _mm512_castsi512_si256(b_i8x64);
        b_high_i8x32 = _mm512_extracti64x4_epi64(b_i8x64, 1);
        n = 0;
    }
    else {
        // Hot path: 2×256-bit loads per vector to avoid VEXTRACTI64X4 (Port 5)
        a_low_i8x32 = _mm256_loadu_epi8(a);
        a_high_i8x32 = _mm256_loadu_epi8(a + 32);
        b_low_i8x32 = _mm256_loadu_epi8(b);
        b_high_i8x32 = _mm256_loadu_epi8(b + 32);
        a += 64, b += 64, n -= 64;
    }
    // Upcast from 256-bit halves: i8 -> i16 (sign-extend) -> f16.
    a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_low_i8x32));
    a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_high_i8x32));
    b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_low_i8x32));
    b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_high_i8x32));
    // Scale:
    a_scaled_low_f16x32 = _mm512_mul_ph(a_low_f16x32, alpha_f16x32);
    a_scaled_high_f16x32 = _mm512_mul_ph(a_high_f16x32, alpha_f16x32);
    // Add: beta * b + (alpha * a) via FMA.
    result_low_f16x32 = _mm512_fmadd_ph(b_low_f16x32, beta_f16x32, a_scaled_low_f16x32);
    result_high_f16x32 = _mm512_fmadd_ph(b_high_f16x32, beta_f16x32, a_scaled_high_f16x32);
    // Downcast: f16 -> i16, then signed-saturating i16 -> i8; re-join halves.
    result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
    result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
    result_i8x64 = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(result_low_i16x32)),
                                      _mm512_cvtsepi16_epi8(result_high_i16x32), 1);
    _mm512_mask_storeu_epi8(result, mask, result_i8x64);
    result += 64;
    if (n) goto nk_each_blend_i8_sapphire_cycle;
}
289
+
290
+ NK_PUBLIC void nk_each_fma_i8_sapphire( //
291
+ nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, //
292
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {
293
+
294
+ short alpha_short, beta_short;
295
+ nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
296
+ nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
297
+ __mmask64 mask = 0xFFFFFFFFFFFFFFFF;
298
+ __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
299
+ __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
300
+ __m256i a_low_i8x32, a_high_i8x32, b_low_i8x32, b_high_i8x32, c_low_i8x32, c_high_i8x32;
301
+ __m512i result_i8x64;
302
+ __m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
303
+ __m512h c_low_f16x32, c_high_f16x32, ab_low_f16x32, ab_high_f16x32;
304
+ __m512h ab_scaled_low_f16x32, ab_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
305
+ __m512i result_low_i16x32, result_high_i16x32;
306
+ __m512h min_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(-128));
307
+ __m512h max_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(127));
308
+
309
+ nk_each_fma_i8_sapphire_cycle:
310
+ if (n < 64) {
311
+ // Tail: use masked 512-bit loads and extract (runs once)
312
+ mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
313
+ __m512i a_i8x64 = _mm512_maskz_loadu_epi8(mask, a);
314
+ __m512i b_i8x64 = _mm512_maskz_loadu_epi8(mask, b);
315
+ __m512i c_i8x64 = _mm512_maskz_loadu_epi8(mask, c);
316
+ a_low_i8x32 = _mm512_castsi512_si256(a_i8x64);
317
+ a_high_i8x32 = _mm512_extracti64x4_epi64(a_i8x64, 1);
318
+ b_low_i8x32 = _mm512_castsi512_si256(b_i8x64);
319
+ b_high_i8x32 = _mm512_extracti64x4_epi64(b_i8x64, 1);
320
+ c_low_i8x32 = _mm512_castsi512_si256(c_i8x64);
321
+ c_high_i8x32 = _mm512_extracti64x4_epi64(c_i8x64, 1);
322
+ n = 0;
323
+ }
324
+ else {
325
+ // Hot path: 2×256-bit loads per vector to avoid VEXTRACTI64X4 (Port 5)
326
+ a_low_i8x32 = _mm256_loadu_epi8(a);
327
+ a_high_i8x32 = _mm256_loadu_epi8(a + 32);
328
+ b_low_i8x32 = _mm256_loadu_epi8(b);
329
+ b_high_i8x32 = _mm256_loadu_epi8(b + 32);
330
+ c_low_i8x32 = _mm256_loadu_epi8(c);
331
+ c_high_i8x32 = _mm256_loadu_epi8(c + 32);
332
+ a += 64, b += 64, c += 64, n -= 64;
333
+ }
334
+ // Upcast from 256-bit halves:
335
+ a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_low_i8x32));
336
+ a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_high_i8x32));
337
+ b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_low_i8x32));
338
+ b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_high_i8x32));
339
+ c_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(c_low_i8x32));
340
+ c_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(c_high_i8x32));
341
+ // Multiply:
342
+ ab_low_f16x32 = _mm512_mul_ph(a_low_f16x32, b_low_f16x32);
343
+ ab_high_f16x32 = _mm512_mul_ph(a_high_f16x32, b_high_f16x32);
344
+ // Scale:
345
+ ab_scaled_low_f16x32 = _mm512_mul_ph(ab_low_f16x32, alpha_f16x32);
346
+ ab_scaled_high_f16x32 = _mm512_mul_ph(ab_high_f16x32, alpha_f16x32);
347
+ // Add:
348
+ result_low_f16x32 = _mm512_fmadd_ph(c_low_f16x32, beta_f16x32, ab_scaled_low_f16x32);
349
+ result_high_f16x32 = _mm512_fmadd_ph(c_high_f16x32, beta_f16x32, ab_scaled_high_f16x32);
350
+ // Clip the 16-bit result to 8-bit:
351
+ result_low_f16x32 = _mm512_max_ph(_mm512_min_ph(result_low_f16x32, max_f16x32), min_f16x32);
352
+ result_high_f16x32 = _mm512_max_ph(_mm512_min_ph(result_high_f16x32, max_f16x32), min_f16x32);
353
+ // Downcast:
354
+ result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
355
+ result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
356
+ // Merge back:
357
+ result_i8x64 = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(result_low_i16x32)),
358
+ _mm512_cvtsepi16_epi8(result_high_i16x32), 1);
359
+ _mm512_mask_storeu_epi8(result, mask, result_i8x64);
360
+ result += 64;
361
+ if (n) goto nk_each_fma_i8_sapphire_cycle;
362
+ }
363
+
364
+ NK_PUBLIC void nk_each_fma_u8_sapphire( //
365
+ nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, //
366
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {
367
+
368
+ short alpha_short, beta_short;
369
+ nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
370
+ nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
371
+ __mmask64 mask = 0xFFFFFFFFFFFFFFFF;
372
+ __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
373
+ __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
374
+ __m512i a_u8x64, b_u8x64, c_u8x64, result_u8x64;
375
+ __m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
376
+ __m512h c_low_f16x32, c_high_f16x32, ab_low_f16x32, ab_high_f16x32;
377
+ __m512h ab_scaled_low_f16x32, ab_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
378
+ __m512i result_low_i16x32, result_high_i16x32;
379
+ __m512h min_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(0));
380
+ __m512h max_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(255));
381
+
382
+ nk_each_fma_u8_sapphire_cycle:
383
+ if (n < 64) {
384
+ mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
385
+ a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
386
+ b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
387
+ c_u8x64 = _mm512_maskz_loadu_epi8(mask, c);
388
+ n = 0;
389
+ }
390
+ else {
391
+ a_u8x64 = _mm512_loadu_epi8(a);
392
+ b_u8x64 = _mm512_loadu_epi8(b);
393
+ c_u8x64 = _mm512_loadu_epi8(c);
394
+ a += 64, b += 64, c += 64, n -= 64;
395
+ }
396
+ // Upcast:
397
+ a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(a_u8x64, _mm512_setzero_si512()));
398
+ a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(a_u8x64, _mm512_setzero_si512()));
399
+ b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(b_u8x64, _mm512_setzero_si512()));
400
+ b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(b_u8x64, _mm512_setzero_si512()));
401
+ c_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(c_u8x64, _mm512_setzero_si512()));
402
+ c_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(c_u8x64, _mm512_setzero_si512()));
403
+ // Multiply:
404
+ ab_low_f16x32 = _mm512_mul_ph(a_low_f16x32, b_low_f16x32);
405
+ ab_high_f16x32 = _mm512_mul_ph(a_high_f16x32, b_high_f16x32);
406
+ // Scale:
407
+ ab_scaled_low_f16x32 = _mm512_mul_ph(ab_low_f16x32, alpha_f16x32);
408
+ ab_scaled_high_f16x32 = _mm512_mul_ph(ab_high_f16x32, alpha_f16x32);
409
+ // Add:
410
+ result_low_f16x32 = _mm512_fmadd_ph(c_low_f16x32, beta_f16x32, ab_scaled_low_f16x32);
411
+ result_high_f16x32 = _mm512_fmadd_ph(c_high_f16x32, beta_f16x32, ab_scaled_high_f16x32);
412
+ // Clip the 16-bit result to 8-bit:
413
+ result_low_f16x32 = _mm512_max_ph(_mm512_min_ph(result_low_f16x32, max_f16x32), min_f16x32);
414
+ result_high_f16x32 = _mm512_max_ph(_mm512_min_ph(result_high_f16x32, max_f16x32), min_f16x32);
415
+ // Downcast:
416
+ result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
417
+ result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
418
+ // Merge back:
419
+ result_u8x64 = _mm512_packus_epi16(result_low_i16x32, result_high_i16x32);
420
+ _mm512_mask_storeu_epi8(result, mask, result_u8x64);
421
+ result += 64;
422
+ if (n) goto nk_each_fma_u8_sapphire_cycle;
423
+ }
424
+
425
+ NK_PUBLIC void nk_each_sum_e4m3_sapphire(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result) {
426
+ __m256i a_e4m3x32, b_e4m3x32;
427
+ __m256h a_lo_f16x16, a_hi_f16x16, b_lo_f16x16, b_hi_f16x16;
428
+ __m256h sum_lo_f16x16, sum_hi_f16x16;
429
+ __m128i result_lo_e4m3x16, result_hi_e4m3x16;
430
+ __mmask32 mask = 0xFFFFFFFF;
431
+ nk_each_sum_e4m3_sapphire_cycle:
432
+ if (n < 32) {
433
+ mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
434
+ a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a);
435
+ b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b);
436
+ n = 0;
437
+ }
438
+ else {
439
+ a_e4m3x32 = _mm256_loadu_si256((__m256i const *)a);
440
+ b_e4m3x32 = _mm256_loadu_si256((__m256i const *)b);
441
+ a += 32, b += 32, n -= 32;
442
+ }
443
+
444
+ // Convert e4m3x16 → f16x16 (two halves)
445
+ a_lo_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(a_e4m3x32));
446
+ a_hi_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(a_e4m3x32, 1));
447
+ b_lo_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(b_e4m3x32));
448
+ b_hi_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(b_e4m3x32, 1));
449
+
450
+ // Add in F16 - e4m3 sum is safe (max 896 < 65504)
451
+ sum_lo_f16x16 = _mm256_add_ph(a_lo_f16x16, b_lo_f16x16);
452
+ sum_hi_f16x16 = _mm256_add_ph(a_hi_f16x16, b_hi_f16x16);
453
+
454
+ // Convert f16x16 → e4m3x16
455
+ result_lo_e4m3x16 = nk_f16x16_to_e4m3x16_sapphire_(sum_lo_f16x16);
456
+ result_hi_e4m3x16 = nk_f16x16_to_e4m3x16_sapphire_(sum_hi_f16x16);
457
+
458
+ // Pack and store
459
+ __m256i result_e4m3x32 = _mm256_inserti128_si256(_mm256_castsi128_si256(result_lo_e4m3x16), result_hi_e4m3x16, 1);
460
+ _mm256_mask_storeu_epi8(result, mask, result_e4m3x32);
461
+ result += 32;
462
+ if (n) goto nk_each_sum_e4m3_sapphire_cycle;
463
+ }
464
+
465
+ #if defined(__clang__)
466
+ #pragma clang attribute pop
467
+ #elif defined(__GNUC__)
468
+ #pragma GCC pop_options
469
+ #endif
470
+
471
+ #if defined(__cplusplus)
472
+ } // extern "C"
473
+ #endif
474
+
475
+ #endif // NK_TARGET_SAPPHIRE
476
+ #endif // NK_TARGET_X86_
477
+ #endif // NK_EACH_SAPPHIRE_H