numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,18 +8,18 @@
8
8
  *
9
9
  * @section sapphire_elementwise_instructions Relevant Instructions
10
10
  *
11
- * Intrinsic Instruction Sapphire Genoa
12
- * _mm512_add_ph VADDPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
13
- * _mm512_mul_ph VMULPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
14
- * _mm512_fmadd_ph VFMADD (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
15
- * _mm512_cvtepi16_ph VCVTW2PH (ZMM, ZMM) 4cy @ p05 4cy @ p01
16
- * _mm512_cvtph_epi16 VCVTPH2W (ZMM, ZMM) 4cy @ p05 4cy @ p01
17
- * _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 3cy @ p12
18
- * _mm512_cvtsepi16_epi8 VPMOVSWB (YMM, ZMM) 4cy @ p5 4cy @ p12
19
- * _mm512_packus_epi16 VPACKUSWB (ZMM, ZMM, ZMM) 1cy @ p5 1cy @ p12
20
- * _mm256_add_ph VADDPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
21
- * _mm512_maskz_loadu_epi16 VMOVDQU16 (ZMM {K}, M512) 7cy @ p23 7cy @ p23
22
- * _mm512_mask_storeu_epi16 VMOVDQU16 (M512 {K}, ZMM) 4cy @ p4 4cy @ p4
11
+ * Intrinsic Instruction Sapphire Genoa
12
+ * _mm512_add_ph VADDPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
13
+ * _mm512_mul_ph VMULPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
14
+ * _mm512_fmadd_ph VFMADD (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
15
+ * _mm512_cvtepi16_ph VCVTW2PH (ZMM, ZMM) 4cy @ p05 4cy @ p01
16
+ * _mm512_cvtph_epi16 VCVTPH2W (ZMM, ZMM) 4cy @ p05 4cy @ p01
17
+ * _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 3cy @ p12
18
+ * _mm512_cvtsepi16_epi8 VPMOVSWB (YMM, ZMM) 4cy @ p5 4cy @ p12
19
+ * _mm512_packus_epi16 VPACKUSWB (ZMM, ZMM, ZMM) 1cy @ p5 1cy @ p12
20
+ * _mm256_add_ph VADDPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
21
+ * _mm512_maskz_loadu_epi16 VMOVDQU16 (ZMM {K}, M512) 7cy @ p23 7cy @ p23
22
+ * _mm512_mask_storeu_epi16 VMOVDQU16 (M512 {K}, ZMM) 4cy @ p4 4cy @ p4
23
23
  */
24
24
  #ifndef NK_EACH_SAPPHIRE_H
25
25
  #define NK_EACH_SAPPHIRE_H
@@ -54,8 +54,8 @@ nk_each_sum_f16_sapphire_cycle:
54
54
  n = 0;
55
55
  }
56
56
  else {
57
- a_f16_vec = _mm512_loadu_ph(a);
58
- b_f16_vec = _mm512_loadu_ph(b);
57
+ a_f16_vec = _mm512_castsi512_ph(_mm512_loadu_epi16(a));
58
+ b_f16_vec = _mm512_castsi512_ph(_mm512_loadu_epi16(b));
59
59
  a += 32, b += 32, n -= 32;
60
60
  }
61
61
  sum_f16_vec = _mm512_add_ph(a_f16_vec, b_f16_vec);
@@ -287,146 +287,11 @@ nk_each_blend_i8_sapphire_cycle:
287
287
  if (n) goto nk_each_blend_i8_sapphire_cycle;
288
288
  }
289
289
 
290
- NK_PUBLIC void nk_each_fma_i8_sapphire( //
291
- nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, //
292
- nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {
293
-
294
- short alpha_short, beta_short;
295
- nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
296
- nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
297
- __mmask64 mask = 0xFFFFFFFFFFFFFFFF;
298
- __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
299
- __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
300
- __m256i a_low_i8x32, a_high_i8x32, b_low_i8x32, b_high_i8x32, c_low_i8x32, c_high_i8x32;
301
- __m512i result_i8x64;
302
- __m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
303
- __m512h c_low_f16x32, c_high_f16x32, ab_low_f16x32, ab_high_f16x32;
304
- __m512h ab_scaled_low_f16x32, ab_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
305
- __m512i result_low_i16x32, result_high_i16x32;
306
- __m512h min_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(-128));
307
- __m512h max_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(127));
308
-
309
- nk_each_fma_i8_sapphire_cycle:
310
- if (n < 64) {
311
- // Tail: use masked 512-bit loads and extract (runs once)
312
- mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
313
- __m512i a_i8x64 = _mm512_maskz_loadu_epi8(mask, a);
314
- __m512i b_i8x64 = _mm512_maskz_loadu_epi8(mask, b);
315
- __m512i c_i8x64 = _mm512_maskz_loadu_epi8(mask, c);
316
- a_low_i8x32 = _mm512_castsi512_si256(a_i8x64);
317
- a_high_i8x32 = _mm512_extracti64x4_epi64(a_i8x64, 1);
318
- b_low_i8x32 = _mm512_castsi512_si256(b_i8x64);
319
- b_high_i8x32 = _mm512_extracti64x4_epi64(b_i8x64, 1);
320
- c_low_i8x32 = _mm512_castsi512_si256(c_i8x64);
321
- c_high_i8x32 = _mm512_extracti64x4_epi64(c_i8x64, 1);
322
- n = 0;
323
- }
324
- else {
325
- // Hot path: 2×256-bit loads per vector to avoid VEXTRACTI64X4 (Port 5)
326
- a_low_i8x32 = _mm256_loadu_epi8(a);
327
- a_high_i8x32 = _mm256_loadu_epi8(a + 32);
328
- b_low_i8x32 = _mm256_loadu_epi8(b);
329
- b_high_i8x32 = _mm256_loadu_epi8(b + 32);
330
- c_low_i8x32 = _mm256_loadu_epi8(c);
331
- c_high_i8x32 = _mm256_loadu_epi8(c + 32);
332
- a += 64, b += 64, c += 64, n -= 64;
333
- }
334
- // Upcast from 256-bit halves:
335
- a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_low_i8x32));
336
- a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_high_i8x32));
337
- b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_low_i8x32));
338
- b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_high_i8x32));
339
- c_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(c_low_i8x32));
340
- c_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(c_high_i8x32));
341
- // Multiply:
342
- ab_low_f16x32 = _mm512_mul_ph(a_low_f16x32, b_low_f16x32);
343
- ab_high_f16x32 = _mm512_mul_ph(a_high_f16x32, b_high_f16x32);
344
- // Scale:
345
- ab_scaled_low_f16x32 = _mm512_mul_ph(ab_low_f16x32, alpha_f16x32);
346
- ab_scaled_high_f16x32 = _mm512_mul_ph(ab_high_f16x32, alpha_f16x32);
347
- // Add:
348
- result_low_f16x32 = _mm512_fmadd_ph(c_low_f16x32, beta_f16x32, ab_scaled_low_f16x32);
349
- result_high_f16x32 = _mm512_fmadd_ph(c_high_f16x32, beta_f16x32, ab_scaled_high_f16x32);
350
- // Clip the 16-bit result to 8-bit:
351
- result_low_f16x32 = _mm512_max_ph(_mm512_min_ph(result_low_f16x32, max_f16x32), min_f16x32);
352
- result_high_f16x32 = _mm512_max_ph(_mm512_min_ph(result_high_f16x32, max_f16x32), min_f16x32);
353
- // Downcast:
354
- result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
355
- result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
356
- // Merge back:
357
- result_i8x64 = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(result_low_i16x32)),
358
- _mm512_cvtsepi16_epi8(result_high_i16x32), 1);
359
- _mm512_mask_storeu_epi8(result, mask, result_i8x64);
360
- result += 64;
361
- if (n) goto nk_each_fma_i8_sapphire_cycle;
362
- }
363
-
364
- NK_PUBLIC void nk_each_fma_u8_sapphire( //
365
- nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, //
366
- nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {
367
-
368
- short alpha_short, beta_short;
369
- nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
370
- nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
371
- __mmask64 mask = 0xFFFFFFFFFFFFFFFF;
372
- __m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
373
- __m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
374
- __m512i a_u8x64, b_u8x64, c_u8x64, result_u8x64;
375
- __m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
376
- __m512h c_low_f16x32, c_high_f16x32, ab_low_f16x32, ab_high_f16x32;
377
- __m512h ab_scaled_low_f16x32, ab_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
378
- __m512i result_low_i16x32, result_high_i16x32;
379
- __m512h min_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(0));
380
- __m512h max_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(255));
381
-
382
- nk_each_fma_u8_sapphire_cycle:
383
- if (n < 64) {
384
- mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
385
- a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
386
- b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
387
- c_u8x64 = _mm512_maskz_loadu_epi8(mask, c);
388
- n = 0;
389
- }
390
- else {
391
- a_u8x64 = _mm512_loadu_epi8(a);
392
- b_u8x64 = _mm512_loadu_epi8(b);
393
- c_u8x64 = _mm512_loadu_epi8(c);
394
- a += 64, b += 64, c += 64, n -= 64;
395
- }
396
- // Upcast:
397
- a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(a_u8x64, _mm512_setzero_si512()));
398
- a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(a_u8x64, _mm512_setzero_si512()));
399
- b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(b_u8x64, _mm512_setzero_si512()));
400
- b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(b_u8x64, _mm512_setzero_si512()));
401
- c_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(c_u8x64, _mm512_setzero_si512()));
402
- c_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(c_u8x64, _mm512_setzero_si512()));
403
- // Multiply:
404
- ab_low_f16x32 = _mm512_mul_ph(a_low_f16x32, b_low_f16x32);
405
- ab_high_f16x32 = _mm512_mul_ph(a_high_f16x32, b_high_f16x32);
406
- // Scale:
407
- ab_scaled_low_f16x32 = _mm512_mul_ph(ab_low_f16x32, alpha_f16x32);
408
- ab_scaled_high_f16x32 = _mm512_mul_ph(ab_high_f16x32, alpha_f16x32);
409
- // Add:
410
- result_low_f16x32 = _mm512_fmadd_ph(c_low_f16x32, beta_f16x32, ab_scaled_low_f16x32);
411
- result_high_f16x32 = _mm512_fmadd_ph(c_high_f16x32, beta_f16x32, ab_scaled_high_f16x32);
412
- // Clip the 16-bit result to 8-bit:
413
- result_low_f16x32 = _mm512_max_ph(_mm512_min_ph(result_low_f16x32, max_f16x32), min_f16x32);
414
- result_high_f16x32 = _mm512_max_ph(_mm512_min_ph(result_high_f16x32, max_f16x32), min_f16x32);
415
- // Downcast:
416
- result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
417
- result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
418
- // Merge back:
419
- result_u8x64 = _mm512_packus_epi16(result_low_i16x32, result_high_i16x32);
420
- _mm512_mask_storeu_epi8(result, mask, result_u8x64);
421
- result += 64;
422
- if (n) goto nk_each_fma_u8_sapphire_cycle;
423
- }
424
-
425
290
  NK_PUBLIC void nk_each_sum_e4m3_sapphire(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result) {
426
291
  __m256i a_e4m3x32, b_e4m3x32;
427
- __m256h a_lo_f16x16, a_hi_f16x16, b_lo_f16x16, b_hi_f16x16;
428
- __m256h sum_lo_f16x16, sum_hi_f16x16;
429
- __m128i result_lo_e4m3x16, result_hi_e4m3x16;
292
+ __m256h a_low_f16x16, a_high_f16x16, b_low_f16x16, b_high_f16x16;
293
+ __m256h sum_low_f16x16, sum_high_f16x16;
294
+ __m128i result_low_e4m3x16, result_high_e4m3x16;
430
295
  __mmask32 mask = 0xFFFFFFFF;
431
296
  nk_each_sum_e4m3_sapphire_cycle:
432
297
  if (n < 32) {
@@ -442,21 +307,22 @@ nk_each_sum_e4m3_sapphire_cycle:
442
307
  }
443
308
 
444
309
  // Convert e4m3x16 → f16x16 (two halves)
445
- a_lo_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(a_e4m3x32));
446
- a_hi_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(a_e4m3x32, 1));
447
- b_lo_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(b_e4m3x32));
448
- b_hi_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(b_e4m3x32, 1));
310
+ a_low_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(a_e4m3x32));
311
+ a_high_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(a_e4m3x32, 1));
312
+ b_low_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(b_e4m3x32));
313
+ b_high_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(b_e4m3x32, 1));
449
314
 
450
315
  // Add in F16 - e4m3 sum is safe (max 896 < 65504)
451
- sum_lo_f16x16 = _mm256_add_ph(a_lo_f16x16, b_lo_f16x16);
452
- sum_hi_f16x16 = _mm256_add_ph(a_hi_f16x16, b_hi_f16x16);
316
+ sum_low_f16x16 = _mm256_add_ph(a_low_f16x16, b_low_f16x16);
317
+ sum_high_f16x16 = _mm256_add_ph(a_high_f16x16, b_high_f16x16);
453
318
 
454
319
  // Convert f16x16 → e4m3x16
455
- result_lo_e4m3x16 = nk_f16x16_to_e4m3x16_sapphire_(sum_lo_f16x16);
456
- result_hi_e4m3x16 = nk_f16x16_to_e4m3x16_sapphire_(sum_hi_f16x16);
320
+ result_low_e4m3x16 = nk_f16x16_to_e4m3x16_sapphire_(sum_low_f16x16);
321
+ result_high_e4m3x16 = nk_f16x16_to_e4m3x16_sapphire_(sum_high_f16x16);
457
322
 
458
323
  // Pack and store
459
- __m256i result_e4m3x32 = _mm256_inserti128_si256(_mm256_castsi128_si256(result_lo_e4m3x16), result_hi_e4m3x16, 1);
324
+ __m256i result_e4m3x32 = _mm256_inserti128_si256(_mm256_castsi128_si256(result_low_e4m3x16), result_high_e4m3x16,
325
+ 1);
460
326
  _mm256_mask_storeu_epi8(result, mask, result_e4m3x32);
461
327
  result += 32;
462
328
  if (n) goto nk_each_sum_e4m3_sapphire_cycle;
@@ -107,8 +107,8 @@ nk_define_each_scale_(i16, f32, nk_assign_from_to_, nk_f32_to_i16_serial) /
107
107
  nk_define_each_scale_(u16, f32, nk_assign_from_to_, nk_f32_to_u16_serial) // nk_each_scale_u16_serial
108
108
  nk_define_each_scale_(i32, f64, nk_assign_from_to_, nk_f64_to_i32_serial) // nk_each_scale_i32_serial
109
109
  nk_define_each_scale_(u32, f64, nk_assign_from_to_, nk_f64_to_u32_serial) // nk_each_scale_u32_serial
110
- nk_define_each_scale_(i64, f64, nk_assign_from_to_, nk_f64_to_i64_serial) // nk_each_scale_i64_serial
111
- nk_define_each_scale_(u64, f64, nk_assign_from_to_, nk_f64_to_u64_serial) // nk_each_scale_u64_serial
110
+ nk_define_each_scale_(i64, f64, nk_f64_from_i64_, nk_f64_to_i64_serial) // nk_each_scale_i64_serial
111
+ nk_define_each_scale_(u64, f64, nk_f64_from_u64_, nk_f64_to_u64_serial) // nk_each_scale_u64_serial
112
112
 
113
113
  nk_define_each_blend_(f64, f64, nk_assign_from_to_, nk_assign_from_to_) // nk_each_blend_f64_serial
114
114
  nk_define_each_blend_(f32, f32, nk_assign_from_to_, nk_assign_from_to_) // nk_each_blend_f32_serial
@@ -124,8 +124,8 @@ nk_define_each_blend_(i16, f32, nk_assign_from_to_, nk_f32_to_i16_serial) /
124
124
  nk_define_each_blend_(u16, f32, nk_assign_from_to_, nk_f32_to_u16_serial) // nk_each_blend_u16_serial
125
125
  nk_define_each_blend_(i32, f64, nk_assign_from_to_, nk_f64_to_i32_serial) // nk_each_blend_i32_serial
126
126
  nk_define_each_blend_(u32, f64, nk_assign_from_to_, nk_f64_to_u32_serial) // nk_each_blend_u32_serial
127
- nk_define_each_blend_(i64, f64, nk_assign_from_to_, nk_f64_to_i64_serial) // nk_each_blend_i64_serial
128
- nk_define_each_blend_(u64, f64, nk_assign_from_to_, nk_f64_to_u64_serial) // nk_each_blend_u64_serial
127
+ nk_define_each_blend_(i64, f64, nk_f64_from_i64_, nk_f64_to_i64_serial) // nk_each_blend_i64_serial
128
+ nk_define_each_blend_(u64, f64, nk_f64_from_u64_, nk_f64_to_u64_serial) // nk_each_blend_u64_serial
129
129
 
130
130
  nk_define_each_fma_(f64, f64, nk_assign_from_to_, nk_assign_from_to_) // nk_each_fma_f64_serial
131
131
  nk_define_each_fma_(f32, f32, nk_assign_from_to_, nk_assign_from_to_) // nk_each_fma_f32_serial
@@ -141,8 +141,8 @@ nk_define_each_fma_(i16, f32, nk_assign_from_to_, nk_f32_to_i16_serial) //
141
141
  nk_define_each_fma_(u16, f32, nk_assign_from_to_, nk_f32_to_u16_serial) // nk_each_fma_u16_serial
142
142
  nk_define_each_fma_(i32, f64, nk_assign_from_to_, nk_f64_to_i32_serial) // nk_each_fma_i32_serial
143
143
  nk_define_each_fma_(u32, f64, nk_assign_from_to_, nk_f64_to_u32_serial) // nk_each_fma_u32_serial
144
- nk_define_each_fma_(i64, f64, nk_assign_from_to_, nk_f64_to_i64_serial) // nk_each_fma_i64_serial
145
- nk_define_each_fma_(u64, f64, nk_assign_from_to_, nk_f64_to_u64_serial) // nk_each_fma_u64_serial
144
+ nk_define_each_fma_(i64, f64, nk_f64_from_i64_, nk_f64_to_i64_serial) // nk_each_fma_i64_serial
145
+ nk_define_each_fma_(u64, f64, nk_f64_from_u64_, nk_f64_to_u64_serial) // nk_each_fma_u64_serial
146
146
 
147
147
  #undef nk_define_each_scale_
148
148
  #undef nk_define_each_sum_
@@ -8,13 +8,13 @@
8
8
  *
9
9
  * @section skylake_elementwise_instructions Relevant Instructions
10
10
  *
11
- * Intrinsic Instruction SKL ICL Genoa
12
- * _mm512_add_ps VADDPS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 3cy @ p01
13
- * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 4cy @ p01
14
- * _mm512_mul_ps VMULPS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 3cy @ p01
15
- * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 7cy @ p0 5cy @ p01
16
- * _mm512_maskz_loadu_ps VMOVUPS (ZMM {K}, M512) 7cy @ p23 7cy @ p23 7cy @ p23
17
- * _mm512_mask_storeu_ps VMOVUPS (M512 {K}, ZMM) 4cy @ p4 4cy @ p4 4cy @ p4
11
+ * Intrinsic Instruction SKL ICL Genoa
12
+ * _mm512_add_ps VADDPS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 3cy @ p01
13
+ * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 4cy @ p01
14
+ * _mm512_mul_ps VMULPS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 3cy @ p01
15
+ * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 7cy @ p0 5cy @ p01
16
+ * _mm512_maskz_loadu_ps VMOVUPS (ZMM {K}, M512) 7cy @ p23 7cy @ p23 7cy @ p23
17
+ * _mm512_mask_storeu_ps VMOVUPS (M512 {K}, ZMM) 4cy @ p4 4cy @ p4 4cy @ p4
18
18
  *
19
19
  * Skylake-X server chips have dual 512-bit FMA units enabling 0.5cy throughput for arithmetic operations.
20
20
  * AVX-512 masked loads and stores eliminate branch misprediction penalties for partial vector processing.