numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,838 @@
1
+ /**
2
+ * @brief SWAR-accelerated Dot Products for SIMD-free CPUs.
3
+ * @file include/numkong/dot/serial.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_serial_instructions Serial Fallback Implementation
10
+ *
11
+ * The serial backend provides portable scalar implementations for all numeric types without requiring
12
+ * any SIMD extensions. While significantly slower than vectorized implementations, these serve as:
13
+ *
14
+ * - Reference implementations for correctness validation
15
+ * - Fallbacks for platforms without SIMD support (WASM, older CPUs)
16
+ * - Baseline for benchmarking vectorized speedups
17
+ *
18
+ * For f64 dot products, compensated (Kahan-style) summation is used to minimize floating-point
19
+ * accumulation errors. For smaller types (f16, bf16, FP8), values are upcast to f32 for accumulation.
20
+ *
21
+ * @section dot_serial_stateful Stateful Streaming Logic
22
+ *
23
+ * To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
24
+ * `NK_INTERNAL` functions:
25
+ *
26
+ * - nk_dot_f64x2 state with compensated summation for numerical stability,
27
+ * - nk_dot_f32x4 state for f32 inputs with simple f64 accumulation,
28
+ * - nk_dot_f16x8 state for f16 inputs via f32 upcasting,
29
+ * - nk_dot_bf16x8 state for bf16 inputs via f32 upcasting,
30
+ * - nk_dot_i8x16 for 8-bit signed integer inputs,
31
+ * - nk_dot_u8x16 for 8-bit unsigned integer inputs,
32
+ * - nk_dot_e4m3x16, nk_dot_e5m2x16, nk_dot_e2m3x16, nk_dot_e3m2x16 for FP8/FP6 inputs,
33
+ * - nk_dot_i4x16, nk_dot_u4x16 for 4-bit integer inputs.
34
+ *
35
+ * @code{c}
36
+ * nk_dot_f64x2_state_serial_t state_first, state_second, state_third, state_fourth;
37
+ * nk_b128_vec_t query_f64x2, target_first_f64x2, target_second_f64x2, target_third_f64x2, target_fourth_f64x2;
38
+ * nk_dot_f64x2_init_serial(&state_first);
39
+ * nk_dot_f64x2_init_serial(&state_second);
40
+ * nk_dot_f64x2_init_serial(&state_third);
41
+ * nk_dot_f64x2_init_serial(&state_fourth);
42
+ * for (nk_size_t idx = 0; idx + 2 <= depth; idx += 2) {
43
+ * query_f64x2.f64s[0] = query_ptr[idx], query_f64x2.f64s[1] = query_ptr[idx + 1];
44
+ * target_first_f64x2.f64s[0] = target_first_ptr[idx], target_first_f64x2.f64s[1] = target_first_ptr[idx + 1];
45
+ * target_second_f64x2.f64s[0] = target_second_ptr[idx], target_second_f64x2.f64s[1] = target_second_ptr[idx + 1];
46
+ * target_third_f64x2.f64s[0] = target_third_ptr[idx], target_third_f64x2.f64s[1] = target_third_ptr[idx + 1];
47
+ * target_fourth_f64x2.f64s[0] = target_fourth_ptr[idx], target_fourth_f64x2.f64s[1] = target_fourth_ptr[idx + 1];
48
+ * nk_dot_f64x2_update_serial(&state_first, query_f64x2, target_first_f64x2, idx, 2);
49
+ * nk_dot_f64x2_update_serial(&state_second, query_f64x2, target_second_f64x2, idx, 2);
50
+ * nk_dot_f64x2_update_serial(&state_third, query_f64x2, target_third_f64x2, idx, 2);
51
+ * nk_dot_f64x2_update_serial(&state_fourth, query_f64x2, target_fourth_f64x2, idx, 2);
52
+ * }
53
+ * nk_b256_vec_t results_f64x4;
54
+ * nk_dot_f64x2_finalize_serial(&state_first, &state_second, &state_third, &state_fourth, depth, &results_f64x4);
55
+ * @endcode
56
+ *
57
+ * Integer types follow a similar pattern with appropriate type changes:
58
+ *
59
+ * @code{c}
60
+ * nk_dot_i8x16_state_serial_t state_first, state_second, state_third, state_fourth;
61
+ * nk_b128_vec_t query_i8x16, target_first_i8x16, target_second_i8x16, target_third_i8x16, target_fourth_i8x16;
62
+ * nk_dot_i8x16_init_serial(&state_first);
63
+ * nk_dot_i8x16_init_serial(&state_second);
64
+ * nk_dot_i8x16_init_serial(&state_third);
65
+ * nk_dot_i8x16_init_serial(&state_fourth);
66
+ * for (nk_size_t idx = 0; idx + 16 <= depth; idx += 16) {
67
+ * memcpy(query_i8x16.i8s, query_ptr + idx, 16);
68
+ * memcpy(target_first_i8x16.i8s, target_first_ptr + idx, 16);
69
+ * memcpy(target_second_i8x16.i8s, target_second_ptr + idx, 16);
70
+ * memcpy(target_third_i8x16.i8s, target_third_ptr + idx, 16);
71
+ * memcpy(target_fourth_i8x16.i8s, target_fourth_ptr + idx, 16);
72
+ * nk_dot_i8x16_update_serial(&state_first, query_i8x16, target_first_i8x16, idx, 16);
73
+ * nk_dot_i8x16_update_serial(&state_second, query_i8x16, target_second_i8x16, idx, 16);
74
+ * nk_dot_i8x16_update_serial(&state_third, query_i8x16, target_third_i8x16, idx, 16);
75
+ * nk_dot_i8x16_update_serial(&state_fourth, query_i8x16, target_fourth_i8x16, idx, 16);
76
+ * }
77
+ * nk_b128_vec_t results_i32x4;
78
+ * nk_dot_i8x16_finalize_serial(&state_first, &state_second, &state_third, &state_fourth, depth, &results_i32x4);
79
+ * @endcode
80
+ */
81
+ #ifndef NK_DOT_SERIAL_H
82
+ #define NK_DOT_SERIAL_H
83
+
84
+ #include "numkong/types.h"
85
+ #include "numkong/reduce/serial.h" // `nk_f64_abs_`
86
+
87
+ #if defined(__cplusplus)
88
+ extern "C" {
89
+ #endif
90
+
91
+ /**
92
+ * @brief Macro for dot product with simple accumulation. Expands to `nk_dot_<input_type>_serial`, which converts each pair of elements with `load_and_convert`, sums the products in `nk_<accumulator_type>_t`, and stores the total cast to `nk_<output_type>_t`.
93
+ */
94
+ #define nk_define_dot_(input_type, accumulator_type, output_type, load_and_convert) \
95
+ NK_PUBLIC void nk_dot_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
96
+ nk_size_t n, nk_##output_type##_t *result) { \
97
+ nk_##accumulator_type##_t sum = 0, a_val, b_val; /* operands are held in the accumulator type */ \
98
+ for (nk_size_t i = 0; i != n; ++i) { \
99
+ load_and_convert(a + i, &a_val); \
100
+ load_and_convert(b + i, &b_val); \
101
+ sum += a_val * b_val; \
102
+ } \
103
+ *result = (nk_##output_type##_t)sum; /* n == 0 stores 0 */ \
104
+ }
105
+
106
+ #define nk_define_dot_complex_(input_type, accumulator_type, output_complex_type, load_and_convert) /* complex dot product without conjugation: sum of a[i] * b[i] */ \
107
+ NK_PUBLIC void nk_dot_##input_type##_serial(nk_##input_type##_t const *a_pairs, \
108
+ nk_##input_type##_t const *b_pairs, nk_size_t count_pairs, \
109
+ nk_##output_complex_type##_t *result) { \
110
+ nk_##accumulator_type##_t sum_real = 0, sum_imag = 0; /* running real and imaginary totals */ \
111
+ nk_##accumulator_type##_t a_real, b_real, a_imag, b_imag; \
112
+ for (nk_size_t i = 0; i != count_pairs; ++i) { \
113
+ load_and_convert(&(a_pairs + i)->real, &a_real); \
114
+ load_and_convert(&(b_pairs + i)->real, &b_real); \
115
+ load_and_convert(&(a_pairs + i)->imag, &a_imag); \
116
+ load_and_convert(&(b_pairs + i)->imag, &b_imag); \
117
+ sum_real += a_real * b_real - a_imag * b_imag; /* Re(a * b) */ \
118
+ sum_imag += a_real * b_imag + a_imag * b_real; /* Im(a * b) */ \
119
+ } \
120
+ result->real = sum_real; /* assigned without an explicit cast */ \
121
+ result->imag = sum_imag; \
122
+ }
123
+
124
+ #define nk_define_vdot_complex_(input_type, accumulator_type, output_complex_type, load_and_convert) /* conjugated complex dot product: sum of conj(a[i]) * b[i] */ \
125
+ NK_PUBLIC void nk_vdot_##input_type##_serial(nk_##input_type##_t const *a_pairs, \
126
+ nk_##input_type##_t const *b_pairs, nk_size_t count_pairs, \
127
+ nk_##output_complex_type##_t *result) { \
128
+ nk_##accumulator_type##_t sum_real = 0, sum_imag = 0; /* running real and imaginary totals */ \
129
+ nk_##accumulator_type##_t a_real, b_real, a_imag, b_imag; \
130
+ for (nk_size_t i = 0; i != count_pairs; ++i) { \
131
+ load_and_convert(&(a_pairs + i)->real, &a_real); \
132
+ load_and_convert(&(b_pairs + i)->real, &b_real); \
133
+ load_and_convert(&(a_pairs + i)->imag, &a_imag); \
134
+ load_and_convert(&(b_pairs + i)->imag, &b_imag); \
135
+ sum_real += a_real * b_real + a_imag * b_imag; /* Re(conj(a) * b) */ \
136
+ sum_imag += a_real * b_imag - a_imag * b_real; /* Im(conj(a) * b) */ \
137
+ } \
138
+ result->real = sum_real; /* assigned without an explicit cast */ \
139
+ result->imag = sum_imag; \
140
+ }
141
+
142
+ #pragma region - Traditional Floats
143
+
144
+ nk_define_dot_(f32, f64, f64, nk_assign_from_to_) // nk_dot_f32_serial: f32 inputs, f64 accumulator, f64 result
145
+ nk_define_dot_complex_(f32c, f64, f64c, nk_assign_from_to_) // nk_dot_f32c_serial: f32c inputs, f64 accumulators, f64c result
146
+ nk_define_vdot_complex_(f32c, f64, f64c, nk_assign_from_to_) // nk_vdot_f32c_serial: conjugated variant, same types
147
+
148
+ #pragma endregion - Traditional Floats
149
+
150
+ #pragma region - Smaller Floats
151
+
152
+ nk_define_dot_(f16, f32, f32, nk_f16_to_f32_serial) // nk_dot_f16_serial: f16 inputs, f32 accumulator and result
153
+ nk_define_dot_complex_(f16c, f32, f32c, nk_f16_to_f32_serial) // nk_dot_f16c_serial: f16c inputs, f32c result
154
+ nk_define_vdot_complex_(f16c, f32, f32c, nk_f16_to_f32_serial) // nk_vdot_f16c_serial: conjugated variant
155
+
156
+ nk_define_dot_(bf16, f32, f32, nk_bf16_to_f32_serial) // nk_dot_bf16_serial: bf16 inputs, f32 accumulator and result
157
+ nk_define_dot_complex_(bf16c, f32, f32c, nk_bf16_to_f32_serial) // nk_dot_bf16c_serial: bf16c inputs, f32c result
158
+ nk_define_vdot_complex_(bf16c, f32, f32c, nk_bf16_to_f32_serial) // nk_vdot_bf16c_serial: conjugated variant
159
+
160
+ nk_define_dot_(e4m3, f32, f32, nk_e4m3_to_f32_serial) // nk_dot_e4m3_serial: f32 accumulator and result
161
+ nk_define_dot_(e5m2, f32, f32, nk_e5m2_to_f32_serial) // nk_dot_e5m2_serial: f32 accumulator and result
162
+ nk_define_dot_(e2m3, f32, f32, nk_e2m3_to_f32_serial) // nk_dot_e2m3_serial: f32 accumulator and result
163
+ nk_define_dot_(e3m2, f32, f32, nk_e3m2_to_f32_serial) // nk_dot_e3m2_serial: f32 accumulator and result
164
+
165
+ #pragma endregion - Smaller Floats
166
+
167
+ #pragma region - Small Integers
168
+
169
+ nk_define_dot_(i8, i32, i32, nk_assign_from_to_) // nk_dot_i8_serial: i8 inputs, i32 operands/accumulator, i32 result
170
+ nk_define_dot_(u8, u32, u32, nk_assign_from_to_) // nk_dot_u8_serial: u8 inputs, u32 operands/accumulator, u32 result
171
+
172
+ #undef nk_define_dot_ // generator macros are not used below this point in the visible code
173
+ #undef nk_define_dot_complex_
174
+ #undef nk_define_vdot_complex_
175
+
176
+ NK_PUBLIC void nk_dot_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result) {
177
+ // i4 values are packed as nibbles: two 4-bit signed values per byte.
178
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
179
+ // Sign extension: (nibble ^ 8) - 8 maps [0,15] to [-8,7]
180
+ n = nk_size_round_up_to_multiple_(n, 2); // NOTE(review): for odd `n` both nibbles of the last byte are multiplied in - confirm callers zero-pad the unused nibble
181
+ nk_size_t n_bytes = n / 2; // two dimensions per packed byte
182
+ nk_i32_t sum = 0;
183
+ for (nk_size_t i = 0; i < n_bytes; ++i) {
184
+ nk_i32_t a_low = (nk_i32_t)nk_i4x2_low_(a[i]); // presumably the sign extension happens inside the nk_i4x2_* helpers (defined elsewhere)
185
+ nk_i32_t b_low = (nk_i32_t)nk_i4x2_low_(b[i]);
186
+ nk_i32_t a_high = (nk_i32_t)nk_i4x2_high_(a[i]);
187
+ nk_i32_t b_high = (nk_i32_t)nk_i4x2_high_(b[i]);
188
+ sum += a_low * b_low + a_high * b_high; // both nibble products of this byte
189
+ }
190
+ *result = sum;
191
+ }
192
+
193
+ NK_PUBLIC void nk_dot_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
194
+ // u4 values are packed as nibbles: two 4-bit unsigned values per byte.
195
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
196
+ // No sign extension needed - values are ∈ [0,15].
197
+ n = nk_size_round_up_to_multiple_(n, 2); // NOTE(review): for odd `n` both nibbles of the last byte are multiplied in - confirm callers zero-pad the unused nibble
198
+ nk_size_t n_bytes = n / 2; // two dimensions per packed byte
199
+ nk_u32_t sum = 0;
200
+ for (nk_size_t i = 0; i < n_bytes; ++i) {
201
+ nk_u32_t a_low = (nk_u32_t)nk_u4x2_low_(a[i]);
202
+ nk_u32_t b_low = (nk_u32_t)nk_u4x2_low_(b[i]);
203
+ nk_u32_t a_high = (nk_u32_t)nk_u4x2_high_(a[i]);
204
+ nk_u32_t b_high = (nk_u32_t)nk_u4x2_high_(b[i]);
205
+ sum += a_low * b_low + a_high * b_high; // both nibble products of this byte
206
+ }
207
+ *result = sum;
208
+ }
209
+
210
+ #pragma endregion - Small Integers
211
+
212
+ #pragma region - Traditional Floats
213
+
214
+ /* Double-precision dot-product variants
215
+ *
216
+ * Implements Neumaier's Kahan-Babuška variant to minimize floating-point rounding errors.
217
+ * Unlike Kahan, Neumaier handles the case where the term being added is larger than the
218
+ * running sum. Achieves O(1) error growth regardless of vector dimension.
219
+ *
220
+ * Algorithm: For each term, compute t = sum + term, then:
221
+ * - If ‖sum‖ ≥ ‖term‖: c += (sum - t) + term (lost low-order bits of term)
222
+ * - Else: c += (term - t) + sum (lost low-order bits of sum)
223
+ *
224
+ * @see Neumaier, A. (1974). "Rundungsfehleranalyse einiger Verfahren zur Summation endlicher Summen"
225
+ */
226
+ NK_PUBLIC void nk_dot_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
227
+ nk_f64_t sum = 0, compensation = 0; // running total plus a separate correction term, both updated by nk_f64_dot2_
228
+ for (nk_size_t i = 0; i != n; ++i) nk_f64_dot2_(&sum, &compensation, a[i], b[i]); // nk_f64_dot2_ is defined elsewhere; presumably accumulates a[i] * b[i]
229
+ *result = sum + compensation; // the correction is folded in once, after the loop
230
+ }
231
+
232
+ NK_PUBLIC void nk_dot_f64c_serial(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
233
+ nk_f64c_t *result) {
234
+ nk_f64_t sum_real = 0, sum_imag = 0, compensation_real = 0, compensation_imag = 0; // separate (sum, compensation) pairs for the real and imaginary parts
235
+ for (nk_size_t i = 0; i != count_pairs; ++i) {
236
+ nk_f64_t a_real = a_pairs[i].real, b_real = b_pairs[i].real;
237
+ nk_f64_t a_imag = a_pairs[i].imag, b_imag = b_pairs[i].imag;
238
+ nk_f64_dot2_(&sum_real, &compensation_real, a_real, b_real);
239
+ nk_f64_dot2_(&sum_real, &compensation_real, -a_imag, b_imag); // negated operand: assuming nk_f64_dot2_ accumulates the product, this subtracts a_imag * b_imag
240
+ nk_f64_dot2_(&sum_imag, &compensation_imag, a_real, b_imag);
241
+ nk_f64_dot2_(&sum_imag, &compensation_imag, a_imag, b_real);
242
+ }
243
+ result->real = sum_real + compensation_real; // corrections are folded in once, after the loop
244
+ result->imag = sum_imag + compensation_imag;
245
+ }
246
+
247
+ NK_PUBLIC void nk_vdot_f64c_serial(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
248
+ nk_f64c_t *result) {
249
+ nk_f64_t sum_real = 0, sum_imag = 0, compensation_real = 0, compensation_imag = 0; // separate (sum, compensation) pairs for the real and imaginary parts
250
+ for (nk_size_t i = 0; i != count_pairs; ++i) {
251
+ nk_f64_t a_real = a_pairs[i].real, b_real = b_pairs[i].real;
252
+ nk_f64_t a_imag = a_pairs[i].imag, b_imag = b_pairs[i].imag;
253
+ nk_f64_dot2_(&sum_real, &compensation_real, a_real, b_real);
254
+ nk_f64_dot2_(&sum_real, &compensation_real, a_imag, b_imag);
255
+ nk_f64_dot2_(&sum_imag, &compensation_imag, a_real, b_imag);
256
+ nk_f64_dot2_(&sum_imag, &compensation_imag, -a_imag, b_real); // negated operand: assuming nk_f64_dot2_ accumulates the product, this subtracts a_imag * b_real (conjugated `a`)
257
+ }
258
+ result->real = sum_real + compensation_real; // corrections are folded in once, after the loop
259
+ result->imag = sum_imag + compensation_imag;
260
+ }
261
+
262
+ typedef struct nk_dot_f64x2_state_serial_t {
263
+ nk_f64_t sums[2]; // one running sum per f64 lane
264
+ nk_f64_t compensations[2]; // correction term paired with each sum when calling nk_f64_dot2_
265
+ } nk_dot_f64x2_state_serial_t;
266
+
267
+ NK_INTERNAL void nk_dot_f64x2_init_serial(nk_dot_f64x2_state_serial_t *state) {
268
+ state->sums[0] = 0, state->sums[1] = 0; // zero both lane sums
269
+ state->compensations[0] = 0, state->compensations[1] = 0; // and both correction terms
270
+ }
271
+
272
+ NK_INTERNAL void nk_dot_f64x2_update_serial(nk_dot_f64x2_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
273
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
274
+ nk_unused_(depth_offset); // accepted but not read by this serial variant
275
+ nk_unused_(active_dimensions); // both lanes are always processed
276
+ nk_f64_t sum0 = state->sums[0], compensation0 = state->compensations[0];
277
+ nk_f64_t sum1 = state->sums[1], compensation1 = state->compensations[1];
278
+ nk_f64_dot2_(&sum0, &compensation0, a.f64s[0], b.f64s[0]); // lane 0
279
+ nk_f64_dot2_(&sum1, &compensation1, a.f64s[1], b.f64s[1]); // lane 1
280
+
281
+ state->sums[0] = sum0, state->sums[1] = sum1; // write the local copies back to the state
282
+ state->compensations[0] = compensation0, state->compensations[1] = compensation1;
283
+ }
284
+
285
+ NK_INTERNAL void nk_dot_f64x2_finalize_serial( //
286
+ nk_dot_f64x2_state_serial_t const *state_a, nk_dot_f64x2_state_serial_t const *state_b, //
287
+ nk_dot_f64x2_state_serial_t const *state_c, nk_dot_f64x2_state_serial_t const *state_d, //
288
+ nk_size_t total_dimensions, nk_b256_vec_t *result) {
289
+ nk_unused_(total_dimensions); // accepted but not read by this serial variant
290
+ result->f64s[0] = nk_reduce_sum_f64_serial_(state_a->sums, state_a->compensations, 2); // state_a -> output slot 0; the reducer is defined elsewhere
291
+ result->f64s[1] = nk_reduce_sum_f64_serial_(state_b->sums, state_b->compensations, 2); // state_b -> output slot 1
292
+ result->f64s[2] = nk_reduce_sum_f64_serial_(state_c->sums, state_c->compensations, 2); // state_c -> output slot 2
293
+ result->f64s[3] = nk_reduce_sum_f64_serial_(state_d->sums, state_d->compensations, 2); // state_d -> output slot 3
294
+ }
295
+
296
+ typedef struct nk_dot_f32x4_state_serial_t {
297
+ nk_f64_t sums[4]; // f64 accumulators, one per f32 lane
298
+ } nk_dot_f32x4_state_serial_t;
299
+
300
+ NK_INTERNAL void nk_dot_f32x4_init_serial(nk_dot_f32x4_state_serial_t *state) {
301
+ state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0; // zero all four lane accumulators
302
+ }
303
+
304
+ NK_INTERNAL void nk_dot_f32x4_update_serial(nk_dot_f32x4_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
305
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
306
+ nk_unused_(depth_offset); // accepted but not read by this serial variant
307
+ nk_unused_(active_dimensions); // all four lanes are always processed
308
+ nk_f64_t sum0 = state->sums[0];
309
+ nk_f64_t sum1 = state->sums[1];
310
+ nk_f64_t sum2 = state->sums[2];
311
+ nk_f64_t sum3 = state->sums[3];
312
+ sum0 += (nk_f64_t)a.f32s[0] * b.f32s[0], sum1 += (nk_f64_t)a.f32s[1] * b.f32s[1]; // the cast makes each product an f64 multiplication
313
+ sum2 += (nk_f64_t)a.f32s[2] * b.f32s[2], sum3 += (nk_f64_t)a.f32s[3] * b.f32s[3];
314
+ state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3; // write the local copies back to the state
315
+ }
316
+
317
+ NK_INTERNAL void nk_dot_f32x4_finalize_serial( //
318
+ nk_dot_f32x4_state_serial_t const *state_a, nk_dot_f32x4_state_serial_t const *state_b, //
319
+ nk_dot_f32x4_state_serial_t const *state_c, nk_dot_f32x4_state_serial_t const *state_d, //
320
+ nk_size_t total_dimensions, nk_b256_vec_t *result) {
321
+ nk_unused_(total_dimensions); // accepted but not read by this serial variant
322
+ result->f64s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3]; // plain left-to-right f64 sum of the four lanes
323
+ result->f64s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
324
+ result->f64s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
325
+ result->f64s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
326
+ }
327
+
328
+ #pragma endregion - Traditional Floats
329
+
330
+ #pragma region - Smaller Floats
331
+
332
+ typedef struct nk_dot_f16x8_state_serial_t {
333
+ nk_f32_t sums[4];
334
+ } nk_dot_f16x8_state_serial_t;
335
+
336
+ NK_INTERNAL void nk_dot_f16x8_init_serial(nk_dot_f16x8_state_serial_t *state) {
337
+ state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
338
+ }
339
+
340
+ NK_INTERNAL void nk_dot_f16x8_update_serial(nk_dot_f16x8_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
341
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
342
+ nk_unused_(depth_offset);
343
+ nk_unused_(active_dimensions);
344
+ nk_f32_t sum0 = state->sums[0], sum1 = state->sums[1], sum2 = state->sums[2], sum3 = state->sums[3];
345
+ for (nk_size_t i = 0; i < 8; i += 4) {
346
+ nk_f32_t a0, a1, a2, a3, b0, b1, b2, b3;
347
+ nk_f16_to_f32_serial(a.f16s + i + 0, &a0), nk_f16_to_f32_serial(a.f16s + i + 1, &a1);
348
+ nk_f16_to_f32_serial(a.f16s + i + 2, &a2), nk_f16_to_f32_serial(a.f16s + i + 3, &a3);
349
+ nk_f16_to_f32_serial(b.f16s + i + 0, &b0), nk_f16_to_f32_serial(b.f16s + i + 1, &b1);
350
+ nk_f16_to_f32_serial(b.f16s + i + 2, &b2), nk_f16_to_f32_serial(b.f16s + i + 3, &b3);
351
+ sum0 += a0 * b0, sum1 += a1 * b1, sum2 += a2 * b2, sum3 += a3 * b3;
352
+ }
353
+ state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
354
+ }
355
+
356
+ NK_INTERNAL void nk_dot_f16x8_finalize_serial( //
357
+ nk_dot_f16x8_state_serial_t const *state_a, nk_dot_f16x8_state_serial_t const *state_b, //
358
+ nk_dot_f16x8_state_serial_t const *state_c, nk_dot_f16x8_state_serial_t const *state_d, //
359
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
360
+ nk_unused_(total_dimensions);
361
+ result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
362
+ result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
363
+ result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
364
+ result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
365
+ }
366
+
367
+ typedef struct nk_dot_bf16x8_state_serial_t {
368
+ nk_f32_t sums[4];
369
+ } nk_dot_bf16x8_state_serial_t;
370
+
371
+ NK_INTERNAL void nk_dot_bf16x8_init_serial(nk_dot_bf16x8_state_serial_t *state) {
372
+ state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
373
+ }
374
+
375
+ NK_INTERNAL void nk_dot_bf16x8_update_serial(nk_dot_bf16x8_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
376
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
377
+ nk_unused_(depth_offset);
378
+ nk_unused_(active_dimensions);
379
+ nk_f32_t sum0 = state->sums[0], sum1 = state->sums[1], sum2 = state->sums[2], sum3 = state->sums[3];
380
+ for (nk_size_t i = 0; i < 8; i += 4) {
381
+ nk_f32_t a0, a1, a2, a3, b0, b1, b2, b3;
382
+ nk_bf16_to_f32_serial(a.bf16s + i + 0, &a0), nk_bf16_to_f32_serial(a.bf16s + i + 1, &a1);
383
+ nk_bf16_to_f32_serial(a.bf16s + i + 2, &a2), nk_bf16_to_f32_serial(a.bf16s + i + 3, &a3);
384
+ nk_bf16_to_f32_serial(b.bf16s + i + 0, &b0), nk_bf16_to_f32_serial(b.bf16s + i + 1, &b1);
385
+ nk_bf16_to_f32_serial(b.bf16s + i + 2, &b2), nk_bf16_to_f32_serial(b.bf16s + i + 3, &b3);
386
+ sum0 += a0 * b0, sum1 += a1 * b1, sum2 += a2 * b2, sum3 += a3 * b3;
387
+ }
388
+ state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
389
+ }
390
+
391
+ NK_INTERNAL void nk_dot_bf16x8_finalize_serial( //
392
+ nk_dot_bf16x8_state_serial_t const *state_a, nk_dot_bf16x8_state_serial_t const *state_b, //
393
+ nk_dot_bf16x8_state_serial_t const *state_c, nk_dot_bf16x8_state_serial_t const *state_d, //
394
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
395
+ nk_unused_(total_dimensions);
396
+ result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
397
+ result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
398
+ result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
399
+ result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
400
+ }
401
+
402
+ #pragma endregion - Smaller Floats
403
+
404
+ #pragma region - Small Integers
405
+
406
+ typedef struct nk_dot_i8x16_state_serial_t {
407
+ nk_i64_t sums[2];
408
+ } nk_dot_i8x16_state_serial_t;
409
+
410
+ NK_INTERNAL void nk_dot_i8x16_init_serial(nk_dot_i8x16_state_serial_t *state) {
411
+ state->sums[0] = 0, state->sums[1] = 0;
412
+ }
413
+
414
+ NK_INTERNAL void nk_dot_i8x16_update_serial(nk_dot_i8x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
415
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
416
+ nk_unused_(depth_offset);
417
+ nk_unused_(active_dimensions);
418
+ nk_i64_t sum0 = state->sums[0];
419
+ nk_i64_t sum1 = state->sums[1];
420
+ sum0 += (nk_i16_t)a.i8s[0] * (nk_i16_t)b.i8s[0], sum1 += (nk_i16_t)a.i8s[1] * (nk_i16_t)b.i8s[1];
421
+ sum0 += (nk_i16_t)a.i8s[2] * (nk_i16_t)b.i8s[2], sum1 += (nk_i16_t)a.i8s[3] * (nk_i16_t)b.i8s[3];
422
+ sum0 += (nk_i16_t)a.i8s[4] * (nk_i16_t)b.i8s[4], sum1 += (nk_i16_t)a.i8s[5] * (nk_i16_t)b.i8s[5];
423
+ sum0 += (nk_i16_t)a.i8s[6] * (nk_i16_t)b.i8s[6], sum1 += (nk_i16_t)a.i8s[7] * (nk_i16_t)b.i8s[7];
424
+ sum0 += (nk_i16_t)a.i8s[8] * (nk_i16_t)b.i8s[8], sum1 += (nk_i16_t)a.i8s[9] * (nk_i16_t)b.i8s[9];
425
+ sum0 += (nk_i16_t)a.i8s[10] * (nk_i16_t)b.i8s[10], sum1 += (nk_i16_t)a.i8s[11] * (nk_i16_t)b.i8s[11];
426
+ sum0 += (nk_i16_t)a.i8s[12] * (nk_i16_t)b.i8s[12], sum1 += (nk_i16_t)a.i8s[13] * (nk_i16_t)b.i8s[13];
427
+ sum0 += (nk_i16_t)a.i8s[14] * (nk_i16_t)b.i8s[14], sum1 += (nk_i16_t)a.i8s[15] * (nk_i16_t)b.i8s[15];
428
+ state->sums[0] = sum0, state->sums[1] = sum1;
429
+ }
430
+
431
+ NK_INTERNAL void nk_dot_i8x16_finalize_serial( //
432
+ nk_dot_i8x16_state_serial_t const *state_a, nk_dot_i8x16_state_serial_t const *state_b, //
433
+ nk_dot_i8x16_state_serial_t const *state_c, nk_dot_i8x16_state_serial_t const *state_d, //
434
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
435
+ nk_unused_(total_dimensions);
436
+ result->i32s[0] = (nk_i32_t)(state_a->sums[0] + state_a->sums[1]);
437
+ result->i32s[1] = (nk_i32_t)(state_b->sums[0] + state_b->sums[1]);
438
+ result->i32s[2] = (nk_i32_t)(state_c->sums[0] + state_c->sums[1]);
439
+ result->i32s[3] = (nk_i32_t)(state_d->sums[0] + state_d->sums[1]);
440
+ }
441
+
442
+ typedef struct nk_dot_u8x16_state_serial_t {
443
+ nk_u64_t sums[2];
444
+ } nk_dot_u8x16_state_serial_t;
445
+
446
+ NK_INTERNAL void nk_dot_u8x16_init_serial(nk_dot_u8x16_state_serial_t *state) {
447
+ state->sums[0] = 0, state->sums[1] = 0;
448
+ }
449
+
450
+ NK_INTERNAL void nk_dot_u8x16_update_serial(nk_dot_u8x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
451
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
452
+ nk_unused_(depth_offset);
453
+ nk_unused_(active_dimensions);
454
+ nk_u64_t sum0 = state->sums[0];
455
+ nk_u64_t sum1 = state->sums[1];
456
+
457
+ sum0 += (nk_u16_t)a.u8s[0] * (nk_u16_t)b.u8s[0], sum1 += (nk_u16_t)a.u8s[1] * (nk_u16_t)b.u8s[1];
458
+ sum0 += (nk_u16_t)a.u8s[2] * (nk_u16_t)b.u8s[2], sum1 += (nk_u16_t)a.u8s[3] * (nk_u16_t)b.u8s[3];
459
+ sum0 += (nk_u16_t)a.u8s[4] * (nk_u16_t)b.u8s[4], sum1 += (nk_u16_t)a.u8s[5] * (nk_u16_t)b.u8s[5];
460
+ sum0 += (nk_u16_t)a.u8s[6] * (nk_u16_t)b.u8s[6], sum1 += (nk_u16_t)a.u8s[7] * (nk_u16_t)b.u8s[7];
461
+ sum0 += (nk_u16_t)a.u8s[8] * (nk_u16_t)b.u8s[8], sum1 += (nk_u16_t)a.u8s[9] * (nk_u16_t)b.u8s[9];
462
+ sum0 += (nk_u16_t)a.u8s[10] * (nk_u16_t)b.u8s[10], sum1 += (nk_u16_t)a.u8s[11] * (nk_u16_t)b.u8s[11];
463
+ sum0 += (nk_u16_t)a.u8s[12] * (nk_u16_t)b.u8s[12], sum1 += (nk_u16_t)a.u8s[13] * (nk_u16_t)b.u8s[13];
464
+ sum0 += (nk_u16_t)a.u8s[14] * (nk_u16_t)b.u8s[14], sum1 += (nk_u16_t)a.u8s[15] * (nk_u16_t)b.u8s[15];
465
+ state->sums[0] = sum0, state->sums[1] = sum1;
466
+ }
467
+
468
+ NK_INTERNAL void nk_dot_u8x16_finalize_serial( //
469
+ nk_dot_u8x16_state_serial_t const *state_a, nk_dot_u8x16_state_serial_t const *state_b, //
470
+ nk_dot_u8x16_state_serial_t const *state_c, nk_dot_u8x16_state_serial_t const *state_d, //
471
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
472
+ nk_unused_(total_dimensions);
473
+ result->u32s[0] = (nk_u32_t)(state_a->sums[0] + state_a->sums[1]);
474
+ result->u32s[1] = (nk_u32_t)(state_b->sums[0] + state_b->sums[1]);
475
+ result->u32s[2] = (nk_u32_t)(state_c->sums[0] + state_c->sums[1]);
476
+ result->u32s[3] = (nk_u32_t)(state_d->sums[0] + state_d->sums[1]);
477
+ }
478
+
479
+ #pragma endregion - Small Integers
480
+
481
+ #pragma region - Smaller Floats
482
+
483
+ typedef struct nk_dot_e4m3x16_state_serial_t {
484
+ nk_f32_t sums[4];
485
+ } nk_dot_e4m3x16_state_serial_t;
486
+
487
+ NK_INTERNAL void nk_dot_e4m3x16_init_serial(nk_dot_e4m3x16_state_serial_t *state) {
488
+ state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
489
+ }
490
+
491
+ NK_INTERNAL void nk_dot_e4m3x16_update_serial(nk_dot_e4m3x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
492
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
493
+ nk_unused_(depth_offset);
494
+ nk_unused_(active_dimensions);
495
+ nk_f32_t sum0 = state->sums[0];
496
+ nk_f32_t sum1 = state->sums[1];
497
+ nk_f32_t sum2 = state->sums[2];
498
+ nk_f32_t sum3 = state->sums[3];
499
+ nk_f32_t ai0, ai1, ai2, ai3;
500
+ nk_f32_t bi0, bi1, bi2, bi3;
501
+ for (nk_size_t i = 0; i != 16; i += 4) {
502
+ nk_e4m3_to_f32_serial(a.e4m3s + i, &ai0), nk_e4m3_to_f32_serial(b.e4m3s + i, &bi0);
503
+ nk_e4m3_to_f32_serial(a.e4m3s + i + 1, &ai1), nk_e4m3_to_f32_serial(b.e4m3s + i + 1, &bi1);
504
+ nk_e4m3_to_f32_serial(a.e4m3s + i + 2, &ai2), nk_e4m3_to_f32_serial(b.e4m3s + i + 2, &bi2);
505
+ nk_e4m3_to_f32_serial(a.e4m3s + i + 3, &ai3), nk_e4m3_to_f32_serial(b.e4m3s + i + 3, &bi3);
506
+ sum0 += ai0 * bi0, sum1 += ai1 * bi1, sum2 += ai2 * bi2, sum3 += ai3 * bi3;
507
+ }
508
+
509
+ state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
510
+ }
511
+
512
+ NK_INTERNAL void nk_dot_e4m3x16_finalize_serial( //
513
+ nk_dot_e4m3x16_state_serial_t const *state_a, nk_dot_e4m3x16_state_serial_t const *state_b, //
514
+ nk_dot_e4m3x16_state_serial_t const *state_c, nk_dot_e4m3x16_state_serial_t const *state_d, //
515
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
516
+ nk_unused_(total_dimensions);
517
+ result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
518
+ result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
519
+ result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
520
+ result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
521
+ }
522
+
523
+ typedef struct nk_dot_e5m2x16_state_serial_t {
524
+ nk_f32_t sums[4];
525
+ } nk_dot_e5m2x16_state_serial_t;
526
+
527
+ NK_INTERNAL void nk_dot_e5m2x16_init_serial(nk_dot_e5m2x16_state_serial_t *state) {
528
+ state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
529
+ }
530
+
531
+ NK_INTERNAL void nk_dot_e5m2x16_update_serial(nk_dot_e5m2x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
532
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
533
+ nk_unused_(depth_offset);
534
+ nk_unused_(active_dimensions);
535
+ nk_f32_t sum0 = state->sums[0];
536
+ nk_f32_t sum1 = state->sums[1];
537
+ nk_f32_t sum2 = state->sums[2];
538
+ nk_f32_t sum3 = state->sums[3];
539
+ nk_f32_t ai0, ai1, ai2, ai3;
540
+ nk_f32_t bi0, bi1, bi2, bi3;
541
+ for (nk_size_t i = 0; i != 16; i += 4) {
542
+ nk_e5m2_to_f32_serial(a.e5m2s + i, &ai0), nk_e5m2_to_f32_serial(b.e5m2s + i, &bi0);
543
+ nk_e5m2_to_f32_serial(a.e5m2s + i + 1, &ai1), nk_e5m2_to_f32_serial(b.e5m2s + i + 1, &bi1);
544
+ nk_e5m2_to_f32_serial(a.e5m2s + i + 2, &ai2), nk_e5m2_to_f32_serial(b.e5m2s + i + 2, &bi2);
545
+ nk_e5m2_to_f32_serial(a.e5m2s + i + 3, &ai3), nk_e5m2_to_f32_serial(b.e5m2s + i + 3, &bi3);
546
+ sum0 += ai0 * bi0, sum1 += ai1 * bi1, sum2 += ai2 * bi2, sum3 += ai3 * bi3;
547
+ }
548
+
549
+ state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
550
+ }
551
+
552
+ NK_INTERNAL void nk_dot_e5m2x16_finalize_serial( //
553
+ nk_dot_e5m2x16_state_serial_t const *state_a, nk_dot_e5m2x16_state_serial_t const *state_b, //
554
+ nk_dot_e5m2x16_state_serial_t const *state_c, nk_dot_e5m2x16_state_serial_t const *state_d, //
555
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
556
+ nk_unused_(total_dimensions);
557
+ result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
558
+ result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
559
+ result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
560
+ result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
561
+ }
562
+
563
+ typedef struct nk_dot_e2m3x16_state_serial_t {
564
+ nk_f32_t sums[4];
565
+ } nk_dot_e2m3x16_state_serial_t;
566
+
567
+ NK_INTERNAL void nk_dot_e2m3x16_init_serial(nk_dot_e2m3x16_state_serial_t *state) {
568
+ state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
569
+ }
570
+
571
+ NK_INTERNAL void nk_dot_e2m3x16_update_serial(nk_dot_e2m3x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
572
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
573
+ nk_unused_(depth_offset);
574
+ nk_unused_(active_dimensions);
575
+ nk_f32_t sum0 = state->sums[0];
576
+ nk_f32_t sum1 = state->sums[1];
577
+ nk_f32_t sum2 = state->sums[2];
578
+ nk_f32_t sum3 = state->sums[3];
579
+ nk_f32_t ai0, ai1, ai2, ai3;
580
+ nk_f32_t bi0, bi1, bi2, bi3;
581
+ for (nk_size_t i = 0; i != 16; i += 4) {
582
+ nk_e2m3_to_f32_serial(a.e2m3s + i, &ai0), nk_e2m3_to_f32_serial(b.e2m3s + i, &bi0);
583
+ nk_e2m3_to_f32_serial(a.e2m3s + i + 1, &ai1), nk_e2m3_to_f32_serial(b.e2m3s + i + 1, &bi1);
584
+ nk_e2m3_to_f32_serial(a.e2m3s + i + 2, &ai2), nk_e2m3_to_f32_serial(b.e2m3s + i + 2, &bi2);
585
+ nk_e2m3_to_f32_serial(a.e2m3s + i + 3, &ai3), nk_e2m3_to_f32_serial(b.e2m3s + i + 3, &bi3);
586
+ sum0 += ai0 * bi0, sum1 += ai1 * bi1, sum2 += ai2 * bi2, sum3 += ai3 * bi3;
587
+ }
588
+
589
+ state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
590
+ }
591
+
592
+ NK_INTERNAL void nk_dot_e2m3x16_finalize_serial( //
593
+ nk_dot_e2m3x16_state_serial_t const *state_a, nk_dot_e2m3x16_state_serial_t const *state_b, //
594
+ nk_dot_e2m3x16_state_serial_t const *state_c, nk_dot_e2m3x16_state_serial_t const *state_d, //
595
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
596
+ nk_unused_(total_dimensions);
597
+ result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
598
+ result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
599
+ result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
600
+ result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
601
+ }
602
+
603
+ typedef struct nk_dot_e3m2x16_state_serial_t {
604
+ nk_f32_t sums[4];
605
+ } nk_dot_e3m2x16_state_serial_t;
606
+
607
+ NK_INTERNAL void nk_dot_e3m2x16_init_serial(nk_dot_e3m2x16_state_serial_t *state) {
608
+ state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
609
+ }
610
+
611
+ NK_INTERNAL void nk_dot_e3m2x16_update_serial(nk_dot_e3m2x16_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
612
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
613
+ nk_unused_(depth_offset);
614
+ nk_unused_(active_dimensions);
615
+ nk_f32_t sum0 = state->sums[0];
616
+ nk_f32_t sum1 = state->sums[1];
617
+ nk_f32_t sum2 = state->sums[2];
618
+ nk_f32_t sum3 = state->sums[3];
619
+ nk_f32_t ai0, ai1, ai2, ai3;
620
+ nk_f32_t bi0, bi1, bi2, bi3;
621
+ for (nk_size_t i = 0; i != 16; i += 4) {
622
+ nk_e3m2_to_f32_serial(a.e3m2s + i, &ai0), nk_e3m2_to_f32_serial(b.e3m2s + i, &bi0);
623
+ nk_e3m2_to_f32_serial(a.e3m2s + i + 1, &ai1), nk_e3m2_to_f32_serial(b.e3m2s + i + 1, &bi1);
624
+ nk_e3m2_to_f32_serial(a.e3m2s + i + 2, &ai2), nk_e3m2_to_f32_serial(b.e3m2s + i + 2, &bi2);
625
+ nk_e3m2_to_f32_serial(a.e3m2s + i + 3, &ai3), nk_e3m2_to_f32_serial(b.e3m2s + i + 3, &bi3);
626
+ sum0 += ai0 * bi0, sum1 += ai1 * bi1, sum2 += ai2 * bi2, sum3 += ai3 * bi3;
627
+ }
628
+
629
+ state->sums[0] = sum0, state->sums[1] = sum1, state->sums[2] = sum2, state->sums[3] = sum3;
630
+ }
631
+
632
+ NK_INTERNAL void nk_dot_e3m2x16_finalize_serial( //
633
+ nk_dot_e3m2x16_state_serial_t const *state_a, nk_dot_e3m2x16_state_serial_t const *state_b, //
634
+ nk_dot_e3m2x16_state_serial_t const *state_c, nk_dot_e3m2x16_state_serial_t const *state_d, //
635
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
636
+ nk_unused_(total_dimensions);
637
+ result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
638
+ result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
639
+ result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
640
+ result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
641
+ }
642
+
643
+ #pragma endregion - Smaller Floats
644
+
645
+ #pragma region - Small Integers
646
+
647
+ // U4x2 state: processes 16 nibbles (8 bytes = 64 bits) per update
648
+ typedef struct nk_dot_u4x16_state_serial_t {
649
+ nk_u64_t sums[2]; // sums[0]: low nibbles, sums[1]: high nibbles
650
+ } nk_dot_u4x16_state_serial_t;
651
+
652
+ NK_INTERNAL void nk_dot_u4x16_init_serial(nk_dot_u4x16_state_serial_t *state) {
653
+ state->sums[0] = 0, state->sums[1] = 0;
654
+ }
655
+
656
+ NK_INTERNAL void nk_dot_u4x16_update_serial(nk_dot_u4x16_state_serial_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
657
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
658
+ nk_unused_(depth_offset);
659
+ nk_unused_(active_dimensions);
660
+ // Process 8 bytes (16 nibbles total) using SWAR
661
+ // Separate accumulators for low and high nibbles
662
+ nk_u64_t sum_low = state->sums[0];
663
+ nk_u64_t sum_high = state->sums[1];
664
+
665
+ // Process all 8 bytes, extracting and multiplying nibbles
666
+ for (nk_size_t i = 0; i < 8; i++) {
667
+ nk_u8_t a_byte = a.u8s[i];
668
+ nk_u8_t b_byte = b.u8s[i];
669
+
670
+ // Extract low and high nibbles using SWAR masks
671
+ nk_u8_t a_low = a_byte & 0x0F;
672
+ nk_u8_t b_low = b_byte & 0x0F;
673
+ nk_u8_t a_high = (a_byte >> 4) & 0x0F;
674
+ nk_u8_t b_high = (b_byte >> 4) & 0x0F;
675
+
676
+ // Accumulate products into separate accumulators
677
+ sum_low += (nk_u32_t)a_low * (nk_u32_t)b_low;
678
+ sum_high += (nk_u32_t)a_high * (nk_u32_t)b_high;
679
+ }
680
+
681
+ state->sums[0] = sum_low, state->sums[1] = sum_high;
682
+ }
683
+
684
+ NK_INTERNAL void nk_dot_u4x16_finalize_serial(nk_dot_u4x16_state_serial_t const *state_a,
685
+ nk_dot_u4x16_state_serial_t const *state_b,
686
+ nk_dot_u4x16_state_serial_t const *state_c,
687
+ nk_dot_u4x16_state_serial_t const *state_d, nk_size_t total_dimensions,
688
+ nk_b128_vec_t *result) {
689
+ nk_unused_(total_dimensions);
690
+ result->u32s[0] = (nk_u32_t)(state_a->sums[0] + state_a->sums[1]);
691
+ result->u32s[1] = (nk_u32_t)(state_b->sums[0] + state_b->sums[1]);
692
+ result->u32s[2] = (nk_u32_t)(state_c->sums[0] + state_c->sums[1]);
693
+ result->u32s[3] = (nk_u32_t)(state_d->sums[0] + state_d->sums[1]);
694
+ }
695
+
696
+ NK_INTERNAL void nk_load_i4x16_to_i8x16_serial_(void const *src, nk_b128_vec_t *dst) {
697
+ nk_i4_to_i8_serial_((nk_i4x2_t const *)src, dst->i8s, 16);
698
+ }
699
+
700
+ NK_INTERNAL void nk_partial_load_i4x16_to_i8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
701
+ nk_i4_to_i8_serial_((nk_i4x2_t const *)src, dst->i8s, n);
702
+ for (nk_size_t i = n; i < 16; ++i) dst->i8s[i] = 0;
703
+ }
704
+
705
+ NK_INTERNAL void nk_load_u4x16_to_u8x16_serial_(void const *src, nk_b128_vec_t *dst) {
706
+ nk_u4_to_u8_serial_((nk_u4x2_t const *)src, dst->u8s, 16);
707
+ }
708
+
709
+ NK_INTERNAL void nk_partial_load_u4x16_to_u8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
710
+ nk_u4_to_u8_serial_((nk_u4x2_t const *)src, dst->u8s, n);
711
+ for (nk_size_t i = n; i < 16; ++i) dst->u8s[i] = 0;
712
+ }
713
+
714
+ typedef struct nk_dot_i4x16_state_serial_t {
715
+ nk_i64_t sums[2]; // sums[0]: low nibbles, sums[1]: high nibbles
716
+ } nk_dot_i4x16_state_serial_t;
717
+
718
+ NK_INTERNAL void nk_dot_i4x16_init_serial(nk_dot_i4x16_state_serial_t *state) {
719
+ state->sums[0] = 0, state->sums[1] = 0;
720
+ }
721
+
722
+ NK_INTERNAL void nk_dot_i4x16_update_serial(nk_dot_i4x16_state_serial_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
723
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
724
+ nk_unused_(depth_offset);
725
+ nk_unused_(active_dimensions);
726
+ // Process 8 bytes (16 nibbles total) using SWAR with sign extension
727
+ // Separate accumulators for low and high nibbles
728
+ nk_i64_t sum_low = state->sums[0];
729
+ nk_i64_t sum_high = state->sums[1];
730
+
731
+ // Process all 8 bytes, extracting and multiplying signed nibbles
732
+ for (nk_size_t i = 0; i < 8; i++) {
733
+ nk_u8_t a_byte = a.u8s[i];
734
+ nk_u8_t b_byte = b.u8s[i];
735
+
736
+ // Extract nibbles and sign extend: (nibble ^ 8) - 8 maps [0,15] → [-8,7]
737
+ nk_i8_t a_low = (nk_i8_t)(((a_byte & 0x0F) ^ 8) - 8);
738
+ nk_i8_t b_low = (nk_i8_t)(((b_byte & 0x0F) ^ 8) - 8);
739
+ nk_i8_t a_high = (nk_i8_t)((((a_byte >> 4) & 0x0F) ^ 8) - 8);
740
+ nk_i8_t b_high = (nk_i8_t)((((b_byte >> 4) & 0x0F) ^ 8) - 8);
741
+
742
+ // Accumulate products into separate accumulators
743
+ sum_low += (nk_i32_t)a_low * (nk_i32_t)b_low;
744
+ sum_high += (nk_i32_t)a_high * (nk_i32_t)b_high;
745
+ }
746
+
747
+ state->sums[0] = sum_low, state->sums[1] = sum_high;
748
+ }
749
+
750
+ NK_INTERNAL void nk_dot_i4x16_finalize_serial(nk_dot_i4x16_state_serial_t const *state_a,
751
+ nk_dot_i4x16_state_serial_t const *state_b,
752
+ nk_dot_i4x16_state_serial_t const *state_c,
753
+ nk_dot_i4x16_state_serial_t const *state_d, nk_size_t total_dimensions,
754
+ nk_b128_vec_t *result) {
755
+ nk_unused_(total_dimensions);
756
+ result->i32s[0] = (nk_i32_t)(state_a->sums[0] + state_a->sums[1]);
757
+ result->i32s[1] = (nk_i32_t)(state_b->sums[0] + state_b->sums[1]);
758
+ result->i32s[2] = (nk_i32_t)(state_c->sums[0] + state_c->sums[1]);
759
+ result->i32s[3] = (nk_i32_t)(state_d->sums[0] + state_d->sums[1]);
760
+ }
761
+
762
+ #pragma endregion - Small Integers
763
+
764
+ #pragma region - Binary
765
+
766
+ NK_PUBLIC void nk_dot_u1_serial(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
767
+ nk_u32_t dot = 0;
768
+ nk_size_t bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
769
+ for (nk_size_t i = 0; i < bytes; ++i) dot += nk_u1x8_popcount_(((nk_u8_t const *)a)[i] & ((nk_u8_t const *)b)[i]);
770
+ *result = dot;
771
+ }
772
+
773
+ typedef struct nk_dot_u1x128_state_serial_t {
774
+ nk_u32_t dot_count;
775
+ } nk_dot_u1x128_state_serial_t;
776
+
777
+ NK_INTERNAL void nk_dot_u1x128_init_serial(nk_dot_u1x128_state_serial_t *state) { state->dot_count = 0; }
778
+
779
+ NK_INTERNAL void nk_dot_u1x128_update_serial(nk_dot_u1x128_state_serial_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
780
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
781
+ nk_unused_(depth_offset);
782
+ nk_unused_(active_dimensions);
783
+ nk_u64_t and_low = a.u64s[0] & b.u64s[0];
784
+ nk_u64_t and_high = a.u64s[1] & b.u64s[1];
785
+ state->dot_count += (nk_u32_t)nk_u64_popcount_(and_low);
786
+ state->dot_count += (nk_u32_t)nk_u64_popcount_(and_high);
787
+ }
788
+
789
+ NK_INTERNAL void nk_dot_u1x128_finalize_serial(nk_dot_u1x128_state_serial_t const *state_a,
790
+ nk_dot_u1x128_state_serial_t const *state_b,
791
+ nk_dot_u1x128_state_serial_t const *state_c,
792
+ nk_dot_u1x128_state_serial_t const *state_d, nk_size_t total_dimensions,
793
+ nk_b128_vec_t *result) {
794
+ nk_unused_(total_dimensions);
795
+ result->u32s[0] = state_a->dot_count;
796
+ result->u32s[1] = state_b->dot_count;
797
+ result->u32s[2] = state_c->dot_count;
798
+ result->u32s[3] = state_d->dot_count;
799
+ }
800
+
801
+ #pragma endregion - Binary
802
+
803
+ /**
804
+ * Serial fallback sum helpers for progressive element-sum accumulation.
805
+ * Used by the compensated symmetric GEMM macro to piggyback sum computation
806
+ * on the depth loop's already-loaded vectors, avoiding a separate sum pass.
807
+ */
808
+
809
+ #pragma region - Stateful Element Sum Helpers (for compensated GEMM)
810
+
811
+ /* i4x32: Haswell i4 (nk_b128_vec_t containing 32 nibbles in 16 bytes) */
812
+ typedef struct nk_sum_i4x32_state_serial_t {
813
+ nk_i64_t sum;
814
+ } nk_sum_i4x32_state_serial_t;
815
+
816
+ NK_INTERNAL void nk_sum_i4x32_init_serial(nk_sum_i4x32_state_serial_t *state) { state->sum = 0; }
817
+
818
+ NK_INTERNAL void nk_sum_i4x32_update_serial(nk_sum_i4x32_state_serial_t *state, nk_b128_vec_t v) {
819
+ nk_u8_t const *d = (nk_u8_t const *)&v;
820
+ for (int i = 0; i < 16; i++) {
821
+ nk_i8_t low = (nk_i8_t)((d[i] & 0x0F) ^ 0x08) - 8; /* sign-extend low nibble */
822
+ nk_i8_t high = (nk_i8_t)((d[i] >> 4) ^ 0x08) - 8; /* sign-extend high nibble */
823
+ state->sum += low + high;
824
+ }
825
+ }
826
+
827
+ NK_INTERNAL nk_i32_t nk_sum_i4x32_finalize_serial(nk_sum_i4x32_state_serial_t const *state, nk_size_t count) {
828
+ nk_unused_(count);
829
+ return (nk_i32_t)state->sum;
830
+ }
831
+
832
+ #pragma endregion - Stateful Element Sum Helpers
833
+
834
+ #if defined(__cplusplus)
835
+ } // extern "C"
836
+ #endif
837
+
838
+ #endif // NK_DOT_SERIAL_H