numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -94,11 +94,11 @@ extern "C" {
94
94
  #define nk_define_dot_(input_type, accumulator_type, output_type, load_and_convert) \
95
95
  NK_PUBLIC void nk_dot_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
96
96
  nk_size_t n, nk_##output_type##_t *result) { \
97
- nk_##accumulator_type##_t sum = 0, a_val, b_val; \
97
+ nk_##accumulator_type##_t sum = 0, a_value, b_value; \
98
98
  for (nk_size_t i = 0; i != n; ++i) { \
99
- load_and_convert(a + i, &a_val); \
100
- load_and_convert(b + i, &b_val); \
101
- sum += a_val * b_val; \
99
+ load_and_convert(a + i, &a_value); \
100
+ load_and_convert(b + i, &b_value); \
101
+ sum += a_value * b_value; \
102
102
  } \
103
103
  *result = (nk_##output_type##_t)sum; \
104
104
  }
@@ -139,15 +139,15 @@ extern "C" {
139
139
  result->imag = sum_imag; \
140
140
  }
141
141
 
142
- #pragma region - Traditional Floats
142
+ #pragma region F32 and F64 Floats
143
143
 
144
144
  nk_define_dot_(f32, f64, f64, nk_assign_from_to_) // nk_dot_f32_serial
145
145
  nk_define_dot_complex_(f32c, f64, f64c, nk_assign_from_to_) // nk_dot_f32c_serial
146
146
  nk_define_vdot_complex_(f32c, f64, f64c, nk_assign_from_to_) // nk_vdot_f32c_serial
147
147
 
148
- #pragma endregion - Traditional Floats
148
+ #pragma endregion F32 and F64 Floats
149
149
 
150
- #pragma region - Smaller Floats
150
+ #pragma region F16 and BF16 Floats
151
151
 
152
152
  nk_define_dot_(f16, f32, f32, nk_f16_to_f32_serial) // nk_dot_f16_serial
153
153
  nk_define_dot_complex_(f16c, f32, f32c, nk_f16_to_f32_serial) // nk_dot_f16c_serial
@@ -162,9 +162,9 @@ nk_define_dot_(e5m2, f32, f32, nk_e5m2_to_f32_serial) // nk_dot_e5m2_serial
162
162
  nk_define_dot_(e2m3, f32, f32, nk_e2m3_to_f32_serial) // nk_dot_e2m3_serial
163
163
  nk_define_dot_(e3m2, f32, f32, nk_e3m2_to_f32_serial) // nk_dot_e3m2_serial
164
164
 
165
- #pragma endregion - Smaller Floats
165
+ #pragma endregion F16 and BF16 Floats
166
166
 
167
- #pragma region - Small Integers
167
+ #pragma region I8 and U8 Integers
168
168
 
169
169
  nk_define_dot_(i8, i32, i32, nk_assign_from_to_) // nk_dot_i8_serial
170
170
  nk_define_dot_(u8, u32, u32, nk_assign_from_to_) // nk_dot_u8_serial
@@ -207,9 +207,9 @@ NK_PUBLIC void nk_dot_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_
207
207
  *result = sum;
208
208
  }
209
209
 
210
- #pragma endregion - Small Integers
210
+ #pragma endregion I8 and U8 Integers
211
211
 
212
- #pragma region - Traditional Floats
212
+ #pragma region F32 and F64 Floats
213
213
 
214
214
  /* Double-precision dot-produce variants
215
215
  *
@@ -325,9 +325,9 @@ NK_INTERNAL void nk_dot_f32x4_finalize_serial(
325
325
  result->f64s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
326
326
  }
327
327
 
328
- #pragma endregion - Traditional Floats
328
+ #pragma endregion F32 and F64 Floats
329
329
 
330
- #pragma region - Smaller Floats
330
+ #pragma region F16 and BF16 Floats
331
331
 
332
332
  typedef struct nk_dot_f16x8_state_serial_t {
333
333
  nk_f32_t sums[4];
@@ -364,6 +364,36 @@ NK_INTERNAL void nk_dot_f16x8_finalize_serial(
364
364
  result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
365
365
  }
366
366
 
367
+ typedef struct nk_dot_through_f32x4_state_serial_t {
368
+ nk_f32_t sums[4];
369
+ } nk_dot_through_f32x4_state_serial_t;
370
+
371
+ NK_INTERNAL void nk_dot_through_f32x4_init_serial(nk_dot_through_f32x4_state_serial_t *state) {
372
+ state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
373
+ }
374
+
375
+ NK_INTERNAL void nk_dot_through_f32x4_update_serial(nk_dot_through_f32x4_state_serial_t *state, nk_b128_vec_t a,
376
+ nk_b128_vec_t b, nk_size_t depth_offset,
377
+ nk_size_t active_dimensions) {
378
+ nk_unused_(depth_offset);
379
+ nk_unused_(active_dimensions);
380
+ state->sums[0] += a.f32s[0] * b.f32s[0];
381
+ state->sums[1] += a.f32s[1] * b.f32s[1];
382
+ state->sums[2] += a.f32s[2] * b.f32s[2];
383
+ state->sums[3] += a.f32s[3] * b.f32s[3];
384
+ }
385
+
386
+ NK_INTERNAL void nk_dot_through_f32x4_finalize_serial( //
387
+ nk_dot_through_f32x4_state_serial_t const *state_a, nk_dot_through_f32x4_state_serial_t const *state_b, //
388
+ nk_dot_through_f32x4_state_serial_t const *state_c, nk_dot_through_f32x4_state_serial_t const *state_d, //
389
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
390
+ nk_unused_(total_dimensions);
391
+ result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
392
+ result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
393
+ result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
394
+ result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
395
+ }
396
+
367
397
  typedef struct nk_dot_bf16x8_state_serial_t {
368
398
  nk_f32_t sums[4];
369
399
  } nk_dot_bf16x8_state_serial_t;
@@ -399,9 +429,9 @@ NK_INTERNAL void nk_dot_bf16x8_finalize_serial(
399
429
  result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
400
430
  }
401
431
 
402
- #pragma endregion - Smaller Floats
432
+ #pragma endregion F16 and BF16 Floats
403
433
 
404
- #pragma region - Small Integers
434
+ #pragma region I8 and U8 Integers
405
435
 
406
436
  typedef struct nk_dot_i8x16_state_serial_t {
407
437
  nk_i64_t sums[2];
@@ -476,9 +506,9 @@ NK_INTERNAL void nk_dot_u8x16_finalize_serial(
476
506
  result->u32s[3] = (nk_u32_t)(state_d->sums[0] + state_d->sums[1]);
477
507
  }
478
508
 
479
- #pragma endregion - Small Integers
509
+ #pragma endregion I8 and U8 Integers
480
510
 
481
- #pragma region - Smaller Floats
511
+ #pragma region F16 and BF16 Floats
482
512
 
483
513
  typedef struct nk_dot_e4m3x16_state_serial_t {
484
514
  nk_f32_t sums[4];
@@ -640,9 +670,9 @@ NK_INTERNAL void nk_dot_e3m2x16_finalize_serial(
640
670
  result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
641
671
  }
642
672
 
643
- #pragma endregion - Smaller Floats
673
+ #pragma endregion F16 and BF16 Floats
644
674
 
645
- #pragma region - Small Integers
675
+ #pragma region I8 and U8 Integers
646
676
 
647
677
  // U4x2 state: processes 16 nibbles (8 bytes = 64 bits) per update
648
678
  typedef struct nk_dot_u4x16_state_serial_t {
@@ -694,20 +724,26 @@ NK_INTERNAL void nk_dot_u4x16_finalize_serial(nk_dot_u4x16_state_serial_t const
694
724
  }
695
725
 
696
726
  NK_INTERNAL void nk_load_i4x16_to_i8x16_serial_(void const *src, nk_b128_vec_t *dst) {
697
- nk_i4_to_i8_serial_((nk_i4x2_t const *)src, dst->i8s, 16);
727
+ nk_i4x2_t const *pairs = (nk_i4x2_t const *)src;
728
+ for (nk_size_t i = 0; i < 8; ++i) nk_i4x2_to_i8x2_serial(&pairs[i], &dst->i8s[i * 2]);
698
729
  }
699
730
 
700
731
  NK_INTERNAL void nk_partial_load_i4x16_to_i8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
701
- nk_i4_to_i8_serial_((nk_i4x2_t const *)src, dst->i8s, n);
732
+ nk_i4x2_t const *pairs = (nk_i4x2_t const *)src;
733
+ nk_size_t count_pairs = n / 2;
734
+ for (nk_size_t i = 0; i < count_pairs; ++i) nk_i4x2_to_i8x2_serial(&pairs[i], &dst->i8s[i * 2]);
702
735
  for (nk_size_t i = n; i < 16; ++i) dst->i8s[i] = 0;
703
736
  }
704
737
 
705
738
  NK_INTERNAL void nk_load_u4x16_to_u8x16_serial_(void const *src, nk_b128_vec_t *dst) {
706
- nk_u4_to_u8_serial_((nk_u4x2_t const *)src, dst->u8s, 16);
739
+ nk_u4x2_t const *pairs = (nk_u4x2_t const *)src;
740
+ for (nk_size_t i = 0; i < 8; ++i) nk_u4x2_to_u8x2_serial(&pairs[i], &dst->u8s[i * 2]);
707
741
  }
708
742
 
709
743
  NK_INTERNAL void nk_partial_load_u4x16_to_u8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
710
- nk_u4_to_u8_serial_((nk_u4x2_t const *)src, dst->u8s, n);
744
+ nk_u4x2_t const *pairs = (nk_u4x2_t const *)src;
745
+ nk_size_t count_pairs = n / 2;
746
+ for (nk_size_t i = 0; i < count_pairs; ++i) nk_u4x2_to_u8x2_serial(&pairs[i], &dst->u8s[i * 2]);
711
747
  for (nk_size_t i = n; i < 16; ++i) dst->u8s[i] = 0;
712
748
  }
713
749
 
@@ -759,9 +795,9 @@ NK_INTERNAL void nk_dot_i4x16_finalize_serial(nk_dot_i4x16_state_serial_t const
759
795
  result->i32s[3] = (nk_i32_t)(state_d->sums[0] + state_d->sums[1]);
760
796
  }
761
797
 
762
- #pragma endregion - Small Integers
798
+ #pragma endregion I8 and U8 Integers
763
799
 
764
- #pragma region - Binary
800
+ #pragma region Binary
765
801
 
766
802
  NK_PUBLIC void nk_dot_u1_serial(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
767
803
  nk_u32_t dot = 0;
@@ -798,7 +834,7 @@ NK_INTERNAL void nk_dot_u1x128_finalize_serial(nk_dot_u1x128_state_serial_t cons
798
834
  result->u32s[3] = state_d->dot_count;
799
835
  }
800
836
 
801
- #pragma endregion - Binary
837
+ #pragma endregion Binary
802
838
 
803
839
  /**
804
840
  * Serial fallback sum helpers for progressive element-sum accumulation.
@@ -806,7 +842,7 @@ NK_INTERNAL void nk_dot_u1x128_finalize_serial(nk_dot_u1x128_state_serial_t cons
806
842
  * on the depth loop's already-loaded vectors, avoiding a separate sum pass.
807
843
  */
808
844
 
809
- #pragma region - Stateful Element Sum Helpers (for compensated GEMM)
845
+ #pragma region Stateful Element Sum Helpers (for compensated GEMM)
810
846
 
811
847
  /* i4x32: Haswell i4 (nk_b128_vec_t containing 32 nibbles in 16 bytes) */
812
848
  typedef struct nk_sum_i4x32_state_serial_t {
@@ -818,8 +854,8 @@ NK_INTERNAL void nk_sum_i4x32_init_serial(nk_sum_i4x32_state_serial_t *state) {
818
854
  NK_INTERNAL void nk_sum_i4x32_update_serial(nk_sum_i4x32_state_serial_t *state, nk_b128_vec_t v) {
819
855
  nk_u8_t const *d = (nk_u8_t const *)&v;
820
856
  for (int i = 0; i < 16; i++) {
821
- nk_i8_t low = (nk_i8_t)((d[i] & 0x0F) ^ 0x08) - 8; /* sign-extend low nibble */
822
- nk_i8_t high = (nk_i8_t)((d[i] >> 4) ^ 0x08) - 8; /* sign-extend high nibble */
857
+ nk_i8_t low = (nk_i8_t)((d[i] & 0x0F) ^ 0x08) - 8; // sign-extend low nibble
858
+ nk_i8_t high = (nk_i8_t)((d[i] >> 4) ^ 0x08) - 8; // sign-extend high nibble
823
859
  state->sum += low + high;
824
860
  }
825
861
  }
@@ -829,7 +865,7 @@ NK_INTERNAL nk_i32_t nk_sum_i4x32_finalize_serial(nk_sum_i4x32_state_serial_t co
829
865
  return (nk_i32_t)state->sum;
830
866
  }
831
867
 
832
- #pragma endregion - Stateful Element Sum Helpers
868
+ #pragma endregion Stateful Element Sum Helpers
833
869
 
834
870
  #if defined(__cplusplus)
835
871
  } // extern "C"
@@ -8,9 +8,9 @@
8
8
  *
9
9
  * @section dot_sierra_instructions AVX-VNNI-INT8 Instructions
10
10
  *
11
- * Intrinsic Instruction
12
- * _mm256_dpbssd_epi32 VPDPBSSD (YMM, YMM, YMM) i8 x i8 -> i32
13
- * _mm256_dpbuud_epi32 VPDPBUUD (YMM, YMM, YMM) u8 x u8 -> u32
11
+ * Intrinsic Instruction
12
+ * _mm256_dpbssd_epi32 VPDPBSSD (YMM, YMM, YMM) i8 × i8 → i32
13
+ * _mm256_dpbuud_epi32 VPDPBUUD (YMM, YMM, YMM) u8 × u8 → u32
14
14
  *
15
15
  * Sierra Forest CPUs support AVX-VNNI-INT8, adding native signed*signed and
16
16
  * unsigned*unsigned 8-bit dot products. This eliminates the algebraic sign
@@ -248,10 +248,10 @@ NK_PUBLIC void nk_dot_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b
248
248
  // Uses dpbssd instead of dpbusd — both operands are already signed i8 after
249
249
  // LUT + sign application, so no unsigned conversion is needed.
250
250
  //
251
- __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
252
- 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
253
- __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
254
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
251
+ __m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
252
+ 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
253
+ __m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
254
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
255
255
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
256
256
  __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
257
257
  __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
@@ -277,11 +277,11 @@ nk_dot_e2m3_sierra_cycle:
277
277
  // Decode a: extract magnitude, dual-VPSHUFB LUT, apply sign
278
278
  __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
279
279
  __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
280
- __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
281
- half_select_u8x32);
282
- __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
283
- _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
284
- a_upper_select_u8x32);
280
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
281
+ half_select_u8x32);
282
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
283
+ _mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
284
+ a_high_select_u8x32);
285
285
  __m256i a_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
286
286
  __m256i a_signed_i8x32 = _mm256_blendv_epi8(
287
287
  a_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate_mask_u8x32);
@@ -289,11 +289,11 @@ nk_dot_e2m3_sierra_cycle:
289
289
  // Decode b: same LUT decode + sign
290
290
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
291
291
  __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
292
- __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
293
- half_select_u8x32);
294
- __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
295
- _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
296
- b_upper_select_u8x32);
292
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
293
+ half_select_u8x32);
294
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
295
+ _mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
296
+ b_high_select_u8x32);
297
297
  __m256i b_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
298
298
  __m256i b_signed_i8x32 = _mm256_blendv_epi8(
299
299
  b_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate_mask_u8x32);
@@ -318,10 +318,10 @@ NK_INTERNAL void nk_dot_e2m3x32_update_sierra(nk_dot_e2m3x32_state_sierra_t *sta
318
318
  nk_unused_(depth_offset);
319
319
  nk_unused_(active_dimensions);
320
320
  // Same LUT constants...
321
- __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
322
- 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
323
- __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
324
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
321
+ __m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
322
+ 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
323
+ __m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
324
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
325
325
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
326
326
  __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
327
327
  __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
@@ -333,11 +333,11 @@ NK_INTERNAL void nk_dot_e2m3x32_update_sierra(nk_dot_e2m3x32_state_sierra_t *sta
333
333
  // Decode a
334
334
  __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
335
335
  __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
336
- __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
337
- half_select_u8x32);
338
- __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_index_u8x32),
339
- _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_index_u8x32),
340
- a_upper_select_u8x32);
336
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
337
+ half_select_u8x32);
338
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
339
+ _mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
340
+ a_high_select_u8x32);
341
341
  __m256i a_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
342
342
  __m256i a_signed_i8x32 = _mm256_blendv_epi8(
343
343
  a_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate_mask_u8x32);
@@ -345,11 +345,11 @@ NK_INTERNAL void nk_dot_e2m3x32_update_sierra(nk_dot_e2m3x32_state_sierra_t *sta
345
345
  // Decode b
346
346
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
347
347
  __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
348
- __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
349
- half_select_u8x32);
350
- __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_index_u8x32),
351
- _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_index_u8x32),
352
- b_upper_select_u8x32);
348
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
349
+ half_select_u8x32);
350
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
351
+ _mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
352
+ b_high_select_u8x32);
353
353
  __m256i b_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
354
354
  __m256i b_signed_i8x32 = _mm256_blendv_epi8(
355
355
  b_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate_mask_u8x32);