numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -0,0 +1,323 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for NEON FP8DOT4.
3
+ * @file include/numkong/dot/neonfp8.h
4
+ * @author Ash Vardanian
5
+ * @date March 23, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_neonfp8_instructions ARM NEON FP8DOT4 Instructions (FEAT_FP8DOT4)
10
+ *
11
+ * Intrinsic Instruction V1
12
+ * vdotq_f32_mf8_fpm FDOT (V.4S, V.16B, V.16B) 4cy @ 2p
13
+ * vld1q_u8 LD1 (V.16B) 4cy @ 2p
14
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy @ 1p
15
+ * vpaddq_f32 FADDP (V.4S, V.4S, V.4S) 2cy @ 2p
16
+ *
17
+ * FEAT_FP8DOT4 adds NEON FDOT instructions that take two 128-bit vectors of FP8 (E4M3 or E5M2) and
18
+ * perform 4-way multiply-accumulate into FP32 per lane. Each FDOT processes 16 FP8 elements
19
+ * into 4 FP32 accumulators. The FP8 format of each operand is selected by FPMR, set from the `fpm` argument of the `_fpm` intrinsics.
20
+ *
21
+ * FP6 types (E2M3, E3M2) are losslessly promoted to FP8 (E4M3, E5M2) by rebiasing the exponent.
22
+ * Normal values: magnitude += 48. Subnormal values (exp=0): 8-entry (E2M3) or 4-entry (E3M2) TBL lookup.
23
+ *
24
+ * @section dot_neonfp8_stateful Stateful Streaming Logic
25
+ *
26
+ * Defines stateful init/update/finalize helpers for tiled GEMM via the dots/ macros:
27
+ * - nk_dot_e4m3x16_state_neonfp8_t, nk_dot_e5m2x16_state_neonfp8_t
28
+ * - nk_dot_e2m3x16_state_neonfp8_t, nk_dot_e3m2x16_state_neonfp8_t
29
+ */
30
+ #ifndef NK_DOT_NEONFP8_H
31
+ #define NK_DOT_NEONFP8_H
32
+
33
+ #if NK_TARGET_ARM_
34
+ #if NK_TARGET_NEONFP8
35
+
36
+ #include "numkong/types.h"
37
+ #include "numkong/cast/serial.h" // `nk_partial_load_b8x16_serial_`
38
+
39
/**
 * @brief FPMR value for E4M3 × E4M3 FDOT: bit 0 selects E4M3 for the first source operand and
 *        bit 3 selects E4M3 for the second. All other fields are left at zero.
 *
 * This is not an instruction immediate. It is the `fpm_t` argument of the `_fpm` intrinsics,
 * which load it into FPMR at run time before issuing the FP8 instruction.
 */
#define NK_FPM_E4M3_ ((fpm_t)((1ull << 0) | (1ull << 3)))
/**
 * @brief FPMR value for E5M2 × E5M2 FDOT: every field is zero, and zero is the E5M2 encoding in
 *        both source-format fields. Like `NK_FPM_E4M3_`, it is passed to the `_fpm` intrinsics.
 */
#define NK_FPM_E5M2_ ((fpm_t)0)
43
+
44
+ #if defined(__cplusplus)
45
+ extern "C" {
46
+ #endif
47
+
48
+ #if defined(__clang__)
49
+ #pragma clang attribute push(__attribute__((target("arch=armv8-a+simd+fp8dot4"))), apply_to = function)
50
+ #elif defined(__GNUC__)
51
+ #pragma GCC push_options
52
+ #pragma GCC target("arch=armv8-a+simd+fp8dot4")
53
+ #endif
54
+
55
+ /**
56
+ * @brief Convert 16 E2M3 bytes (0b00SEEMMM) to E4M3 bytes (0bSEEEEMMM).
57
+ *
58
+ * Normal values (exp>0, mag>=8): rebias exponent by +6 → magnitude += 48.
59
+ * Subnormal values (exp=0, mag<8): 8-entry TBL lookup for normalization.
60
+ * Zero (mag=0): maps to E4M3 zero. Sign moved from bit 5 to bit 7.
61
+ */
62
+ NK_INTERNAL uint8x16_t nk_e2m3x16_to_e4m3x16_neonfp8_(uint8x16_t raw_u8x16) {
63
+ uint8x16_t sign_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x20));
64
+ uint8x16_t mag_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
65
+
66
+ // Normal path: rebias exponent by +6 → add 48 to magnitude
67
+ uint8x16_t normal_mag_u8x16 = vaddq_u8(mag_u8x16, vdupq_n_u8(48));
68
+
69
+ // Subnormal path: 8-entry LUT for mag 0-7
70
+ // 0→0, 1→32, 2→40, 3→44, 4→48, 5→50, 6→52, 7→54
71
+ uint8x16_t sub_lut_u8x16 = vcombine_u8(vcreate_u8(0x363432302c282000ull), vcreate_u8(0));
72
+ uint8x16_t sub_mag_u8x16 = vqtbl1q_u8(sub_lut_u8x16, mag_u8x16);
73
+
74
+ // Select: subnormal (mag < 8) uses LUT, normal uses +48
75
+ uint8x16_t is_normal_u8x16 = vcgeq_u8(mag_u8x16, vdupq_n_u8(8));
76
+ uint8x16_t result_mag_u8x16 = vbslq_u8(is_normal_u8x16, normal_mag_u8x16, sub_mag_u8x16);
77
+
78
+ // Move sign from bit 5 to bit 7
79
+ uint8x16_t sign_shifted_u8x16 = vshlq_n_u8(sign_u8x16, 2);
80
+ return vorrq_u8(sign_shifted_u8x16, result_mag_u8x16);
81
+ }
82
+
83
+ /**
84
+ * @brief Convert 16 E3M2 bytes (0b00SEEEMM) to E5M2 bytes (0bSEEEEEMM).
85
+ *
86
+ * Normal values (exp>0, mag>=4): rebias exponent by +12 → magnitude += 48.
87
+ * Subnormal values (exp=0, mag<4): 4-entry TBL lookup for normalization.
88
+ * Zero (mag=0): maps to E5M2 zero. Sign moved from bit 5 to bit 7.
89
+ */
90
+ NK_INTERNAL uint8x16_t nk_e3m2x16_to_e5m2x16_neonfp8_(uint8x16_t raw_u8x16) {
91
+ uint8x16_t sign_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x20));
92
+ uint8x16_t mag_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
93
+
94
+ // Normal path: rebias exponent by +12 → add 48 to magnitude
95
+ uint8x16_t normal_mag_u8x16 = vaddq_u8(mag_u8x16, vdupq_n_u8(48));
96
+
97
+ // Subnormal path: 4-entry LUT for mag 0-3
98
+ // 0→0, 1→44, 2→48, 3→50
99
+ uint8x16_t sub_lut_u8x16 = vcombine_u8(vcreate_u8(0x0000000032302c00ull), vcreate_u8(0));
100
+ uint8x16_t sub_mag_u8x16 = vqtbl1q_u8(sub_lut_u8x16, mag_u8x16);
101
+
102
+ // Select: subnormal (mag < 4) uses LUT, normal uses +48
103
+ uint8x16_t is_normal_u8x16 = vcgeq_u8(mag_u8x16, vdupq_n_u8(4));
104
+ uint8x16_t result_mag_u8x16 = vbslq_u8(is_normal_u8x16, normal_mag_u8x16, sub_mag_u8x16);
105
+
106
+ // Move sign from bit 5 to bit 7
107
+ uint8x16_t sign_shifted_u8x16 = vshlq_n_u8(sign_u8x16, 2);
108
+ return vorrq_u8(sign_shifted_u8x16, result_mag_u8x16);
109
+ }
110
+
111
+ NK_PUBLIC void nk_dot_e4m3_neonfp8(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
112
+ nk_f32_t *result) {
113
+ mfloat8x16_t a_mf8x16, b_mf8x16;
114
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
115
+ nk_dot_e4m3_neonfp8_cycle:
116
+ if (count_scalars < 16) {
117
+ nk_b128_vec_t a_vec, b_vec;
118
+ nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
119
+ nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
120
+ a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
121
+ b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
122
+ count_scalars = 0;
123
+ }
124
+ else {
125
+ a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a_scalars));
126
+ b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b_scalars));
127
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
128
+ }
129
+ sum_f32x4 = vdotq_f32_mf8_fpm(sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
130
+ if (count_scalars) goto nk_dot_e4m3_neonfp8_cycle;
131
+ *result = vaddvq_f32(sum_f32x4);
132
+ }
133
+
134
+ NK_PUBLIC void nk_dot_e5m2_neonfp8(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
135
+ nk_f32_t *result) {
136
+ mfloat8x16_t a_mf8x16, b_mf8x16;
137
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
138
+ nk_dot_e5m2_neonfp8_cycle:
139
+ if (count_scalars < 16) {
140
+ nk_b128_vec_t a_vec, b_vec;
141
+ nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
142
+ nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
143
+ a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
144
+ b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
145
+ count_scalars = 0;
146
+ }
147
+ else {
148
+ a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a_scalars));
149
+ b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b_scalars));
150
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
151
+ }
152
+ sum_f32x4 = vdotq_f32_mf8_fpm(sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
153
+ if (count_scalars) goto nk_dot_e5m2_neonfp8_cycle;
154
+ *result = vaddvq_f32(sum_f32x4);
155
+ }
156
+
157
+ NK_PUBLIC void nk_dot_e2m3_neonfp8(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
158
+ nk_f32_t *result) {
159
+ mfloat8x16_t a_mf8x16, b_mf8x16;
160
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
161
+ nk_dot_e2m3_neonfp8_cycle:
162
+ if (count_scalars < 16) {
163
+ nk_b128_vec_t a_vec, b_vec;
164
+ nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
165
+ nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
166
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(a_vec.u8x16));
167
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(b_vec.u8x16));
168
+ count_scalars = 0;
169
+ }
170
+ else {
171
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)a_scalars)));
172
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)b_scalars)));
173
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
174
+ }
175
+ sum_f32x4 = vdotq_f32_mf8_fpm(sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
176
+ if (count_scalars) goto nk_dot_e2m3_neonfp8_cycle;
177
+ *result = vaddvq_f32(sum_f32x4);
178
+ }
179
+
180
+ NK_PUBLIC void nk_dot_e3m2_neonfp8(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
181
+ nk_f32_t *result) {
182
+ mfloat8x16_t a_mf8x16, b_mf8x16;
183
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
184
+ nk_dot_e3m2_neonfp8_cycle:
185
+ if (count_scalars < 16) {
186
+ nk_b128_vec_t a_vec, b_vec;
187
+ nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
188
+ nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
189
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(a_vec.u8x16));
190
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(b_vec.u8x16));
191
+ count_scalars = 0;
192
+ }
193
+ else {
194
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)a_scalars)));
195
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)b_scalars)));
196
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
197
+ }
198
+ sum_f32x4 = vdotq_f32_mf8_fpm(sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
199
+ if (count_scalars) goto nk_dot_e3m2_neonfp8_cycle;
200
+ *result = vaddvq_f32(sum_f32x4);
201
+ }
202
+
203
+ typedef struct nk_dot_e4m3x16_state_neonfp8_t {
204
+ float32x4_t sum_f32x4;
205
+ } nk_dot_e4m3x16_state_neonfp8_t;
206
+
207
+ NK_INTERNAL void nk_dot_e4m3x16_init_neonfp8(nk_dot_e4m3x16_state_neonfp8_t *state) {
208
+ state->sum_f32x4 = vdupq_n_f32(0);
209
+ }
210
+
211
+ NK_INTERNAL void nk_dot_e4m3x16_update_neonfp8(nk_dot_e4m3x16_state_neonfp8_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
212
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
213
+ nk_unused_(depth_offset);
214
+ nk_unused_(active_dimensions);
215
+ mfloat8x16_t a_mf8x16 = vreinterpretq_mf8_u8(a.u8x16);
216
+ mfloat8x16_t b_mf8x16 = vreinterpretq_mf8_u8(b.u8x16);
217
+ state->sum_f32x4 = vdotq_f32_mf8_fpm(state->sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
218
+ }
219
+
220
+ NK_INTERNAL void nk_dot_e4m3x16_finalize_neonfp8( //
221
+ nk_dot_e4m3x16_state_neonfp8_t const *state_a, nk_dot_e4m3x16_state_neonfp8_t const *state_b, //
222
+ nk_dot_e4m3x16_state_neonfp8_t const *state_c, nk_dot_e4m3x16_state_neonfp8_t const *state_d, //
223
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
224
+ nk_unused_(total_dimensions);
225
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
226
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
227
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
228
+ }
229
+
230
+ typedef struct nk_dot_e5m2x16_state_neonfp8_t {
231
+ float32x4_t sum_f32x4;
232
+ } nk_dot_e5m2x16_state_neonfp8_t;
233
+
234
+ NK_INTERNAL void nk_dot_e5m2x16_init_neonfp8(nk_dot_e5m2x16_state_neonfp8_t *state) {
235
+ state->sum_f32x4 = vdupq_n_f32(0);
236
+ }
237
+
238
+ NK_INTERNAL void nk_dot_e5m2x16_update_neonfp8(nk_dot_e5m2x16_state_neonfp8_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
239
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
240
+ nk_unused_(depth_offset);
241
+ nk_unused_(active_dimensions);
242
+ mfloat8x16_t a_mf8x16 = vreinterpretq_mf8_u8(a.u8x16);
243
+ mfloat8x16_t b_mf8x16 = vreinterpretq_mf8_u8(b.u8x16);
244
+ state->sum_f32x4 = vdotq_f32_mf8_fpm(state->sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
245
+ }
246
+
247
+ NK_INTERNAL void nk_dot_e5m2x16_finalize_neonfp8( //
248
+ nk_dot_e5m2x16_state_neonfp8_t const *state_a, nk_dot_e5m2x16_state_neonfp8_t const *state_b, //
249
+ nk_dot_e5m2x16_state_neonfp8_t const *state_c, nk_dot_e5m2x16_state_neonfp8_t const *state_d, //
250
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
251
+ nk_unused_(total_dimensions);
252
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
253
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
254
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
255
+ }
256
+
257
+ typedef struct nk_dot_e2m3x16_state_neonfp8_t {
258
+ float32x4_t sum_f32x4;
259
+ } nk_dot_e2m3x16_state_neonfp8_t;
260
+
261
+ NK_INTERNAL void nk_dot_e2m3x16_init_neonfp8(nk_dot_e2m3x16_state_neonfp8_t *state) {
262
+ state->sum_f32x4 = vdupq_n_f32(0);
263
+ }
264
+
265
+ NK_INTERNAL void nk_dot_e2m3x16_update_neonfp8(nk_dot_e2m3x16_state_neonfp8_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
266
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
267
+ nk_unused_(depth_offset);
268
+ nk_unused_(active_dimensions);
269
+ mfloat8x16_t a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(a.u8x16));
270
+ mfloat8x16_t b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(b.u8x16));
271
+ state->sum_f32x4 = vdotq_f32_mf8_fpm(state->sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
272
+ }
273
+
274
+ NK_INTERNAL void nk_dot_e2m3x16_finalize_neonfp8( //
275
+ nk_dot_e2m3x16_state_neonfp8_t const *state_a, nk_dot_e2m3x16_state_neonfp8_t const *state_b, //
276
+ nk_dot_e2m3x16_state_neonfp8_t const *state_c, nk_dot_e2m3x16_state_neonfp8_t const *state_d, //
277
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
278
+ nk_unused_(total_dimensions);
279
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
280
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
281
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
282
+ }
283
+
284
+ typedef struct nk_dot_e3m2x16_state_neonfp8_t {
285
+ float32x4_t sum_f32x4;
286
+ } nk_dot_e3m2x16_state_neonfp8_t;
287
+
288
+ NK_INTERNAL void nk_dot_e3m2x16_init_neonfp8(nk_dot_e3m2x16_state_neonfp8_t *state) {
289
+ state->sum_f32x4 = vdupq_n_f32(0);
290
+ }
291
+
292
+ NK_INTERNAL void nk_dot_e3m2x16_update_neonfp8(nk_dot_e3m2x16_state_neonfp8_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
293
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
294
+ nk_unused_(depth_offset);
295
+ nk_unused_(active_dimensions);
296
+ mfloat8x16_t a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(a.u8x16));
297
+ mfloat8x16_t b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(b.u8x16));
298
+ state->sum_f32x4 = vdotq_f32_mf8_fpm(state->sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
299
+ }
300
+
301
+ NK_INTERNAL void nk_dot_e3m2x16_finalize_neonfp8( //
302
+ nk_dot_e3m2x16_state_neonfp8_t const *state_a, nk_dot_e3m2x16_state_neonfp8_t const *state_b, //
303
+ nk_dot_e3m2x16_state_neonfp8_t const *state_c, nk_dot_e3m2x16_state_neonfp8_t const *state_d, //
304
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
305
+ nk_unused_(total_dimensions);
306
+ float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
307
+ float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
308
+ result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
309
+ }
310
+
311
+ #if defined(__clang__)
312
+ #pragma clang attribute pop
313
+ #elif defined(__GNUC__)
314
+ #pragma GCC pop_options
315
+ #endif
316
+
317
+ #if defined(__cplusplus)
318
+ } // extern "C"
319
+ #endif
320
+
321
+ #endif // NK_TARGET_NEONFP8
322
+ #endif // NK_TARGET_ARM_
323
+ #endif // NK_DOT_NEONFP8_H