numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -43,41 +43,45 @@ extern "C" {
43
43
  /** @brief Compensated horizontal sum of RVV f64m1 lanes via TwoSum tree reduction.
44
44
  *
45
45
  * Uses vslidedown to extract the upper half at each tree level (same pattern as
46
- * nk_reduce_vsaddu_u64m1_rvv_ in reduce/rvv.h). Tail lanes beyond vlmax are zero
46
+ * nk_reduce_vsaddu_u64m1_rvv_ in reduce/rvv.h). Tail lanes beyond vector_length are zero
47
47
  * from the initial vfmv_v_f, so they are harmless in the reduction.
48
48
  */
49
49
  NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64m1_rvv_(vfloat64m1_t sum_f64m1, vfloat64m1_t compensation_f64m1) {
50
- nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
50
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
51
51
  // Stage 0: TwoSum merge of sum + compensation
52
- vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_f64m1, compensation_f64m1, vlmax);
53
- vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_f64m1, vlmax);
52
+ vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(sum_f64m1, compensation_f64m1, max_vector_length);
53
+ vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, sum_f64m1, max_vector_length);
54
54
  vfloat64m1_t accumulated_error_f64m1 = __riscv_vfadd_vv_f64m1(
55
- __riscv_vfsub_vv_f64m1(sum_f64m1, __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vlmax),
56
- vlmax),
57
- __riscv_vfsub_vv_f64m1(compensation_f64m1, virtual_addend_f64m1, vlmax), vlmax);
55
+ __riscv_vfsub_vv_f64m1(sum_f64m1,
56
+ __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, max_vector_length),
57
+ max_vector_length),
58
+ __riscv_vfsub_vv_f64m1(compensation_f64m1, virtual_addend_f64m1, max_vector_length), max_vector_length);
58
59
  // Tree reduction: TwoSum halving at each level
59
- for (nk_size_t half = vlmax / 2; half > 0; half >>= 1) {
60
- vfloat64m1_t upper_sum_f64m1 = __riscv_vslidedown_vx_f64m1(tentative_sum_f64m1, half, vlmax);
61
- vfloat64m1_t upper_error_f64m1 = __riscv_vslidedown_vx_f64m1(accumulated_error_f64m1, half, vlmax);
62
- vfloat64m1_t halved_tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(tentative_sum_f64m1, upper_sum_f64m1, vlmax);
60
+ for (nk_size_t half = max_vector_length / 2; half > 0; half >>= 1) {
61
+ vfloat64m1_t upper_sum_f64m1 = __riscv_vslidedown_vx_f64m1(tentative_sum_f64m1, half, max_vector_length);
62
+ vfloat64m1_t upper_error_f64m1 = __riscv_vslidedown_vx_f64m1(accumulated_error_f64m1, half, max_vector_length);
63
+ vfloat64m1_t halved_tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(tentative_sum_f64m1, upper_sum_f64m1,
64
+ max_vector_length);
63
65
  vfloat64m1_t halved_virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(halved_tentative_sum_f64m1,
64
- tentative_sum_f64m1, vlmax);
66
+ tentative_sum_f64m1, max_vector_length);
65
67
  vfloat64m1_t rounding_error_f64m1 = __riscv_vfadd_vv_f64m1(
66
68
  __riscv_vfsub_vv_f64m1(
67
69
  tentative_sum_f64m1,
68
- __riscv_vfsub_vv_f64m1(halved_tentative_sum_f64m1, halved_virtual_addend_f64m1, vlmax), vlmax),
69
- __riscv_vfsub_vv_f64m1(upper_sum_f64m1, halved_virtual_addend_f64m1, vlmax), vlmax);
70
+ __riscv_vfsub_vv_f64m1(halved_tentative_sum_f64m1, halved_virtual_addend_f64m1, max_vector_length),
71
+ max_vector_length),
72
+ __riscv_vfsub_vv_f64m1(upper_sum_f64m1, halved_virtual_addend_f64m1, max_vector_length), max_vector_length);
70
73
  tentative_sum_f64m1 = halved_tentative_sum_f64m1;
71
74
  accumulated_error_f64m1 = __riscv_vfadd_vv_f64m1(
72
- __riscv_vfadd_vv_f64m1(accumulated_error_f64m1, upper_error_f64m1, vlmax), rounding_error_f64m1, vlmax);
75
+ __riscv_vfadd_vv_f64m1(accumulated_error_f64m1, upper_error_f64m1, max_vector_length), rounding_error_f64m1,
76
+ max_vector_length);
73
77
  }
74
78
  return __riscv_vfmv_f_s_f64m1_f64(tentative_sum_f64m1) + __riscv_vfmv_f_s_f64m1_f64(accumulated_error_f64m1);
75
79
  }
76
80
 
77
81
  NK_PUBLIC void nk_dot_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
78
82
  nk_i32_t *result) {
79
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
80
- vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
83
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
84
+ vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
81
85
  for (nk_size_t vector_length; count_scalars > 0;
82
86
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
83
87
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -89,14 +93,14 @@ NK_PUBLIC void nk_dot_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars,
89
93
  sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, ab_i16m2, vector_length);
90
94
  }
91
95
  // Single horizontal reduction at the end
92
- vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
93
- *result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, vlmax));
96
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
97
+ *result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, max_vector_length));
94
98
  }
95
99
 
96
100
  NK_PUBLIC void nk_dot_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
97
101
  nk_u32_t *result) {
98
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
99
- vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
102
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
103
+ vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
100
104
  for (nk_size_t vector_length; count_scalars > 0;
101
105
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
102
106
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -108,14 +112,14 @@ NK_PUBLIC void nk_dot_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars,
108
112
  sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, ab_u16m2, vector_length);
109
113
  }
110
114
  // Single horizontal reduction at the end
111
- vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
112
- *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, vlmax));
115
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
116
+ *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, max_vector_length));
113
117
  }
114
118
 
115
119
  NK_PUBLIC void nk_dot_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
116
120
  nk_f64_t *result) {
117
- nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
118
- vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
121
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
122
+ vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
119
123
  for (nk_size_t vector_length; count_scalars > 0;
120
124
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
121
125
  vector_length = __riscv_vsetvl_e32m1(count_scalars);
@@ -125,16 +129,16 @@ NK_PUBLIC void nk_dot_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scala
125
129
  sum_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_f64m2, a_f32m1, b_f32m1, vector_length);
126
130
  }
127
131
  // Single horizontal reduction at the end
128
- vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
129
- *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1, vlmax));
132
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
133
+ *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1, max_vector_length));
130
134
  }
131
135
 
132
136
  NK_PUBLIC void nk_dot_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
133
137
  nk_f64_t *result) {
134
138
  // Dot2 (Ogita-Rump-Oishi) compensated accumulation via TwoProd + TwoSum
135
- nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
136
- vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
137
- vfloat64m1_t compensation_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
139
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
140
+ vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
141
+ vfloat64m1_t compensation_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
138
142
  for (nk_size_t vector_length; count_scalars > 0;
139
143
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
140
144
  vector_length = __riscv_vsetvl_e64m1(count_scalars);
@@ -163,8 +167,8 @@ NK_PUBLIC void nk_dot_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_scala
163
167
 
164
168
  NK_PUBLIC void nk_dot_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
165
169
  nk_f32_t *result) {
166
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
167
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
170
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
171
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
168
172
  for (nk_size_t vector_length; count_scalars > 0;
169
173
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
170
174
  vector_length = __riscv_vsetvl_e16m1(count_scalars);
@@ -179,14 +183,14 @@ NK_PUBLIC void nk_dot_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scala
179
183
  sum_f32m2 = __riscv_vfmacc_vv_f32m2_tu(sum_f32m2, a_f32m2, b_f32m2, vector_length);
180
184
  }
181
185
  // Single horizontal reduction at the end
182
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
183
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
186
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
187
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
184
188
  }
185
189
 
186
190
  NK_PUBLIC void nk_dot_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
187
191
  nk_f32_t *result) {
188
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
189
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
192
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
193
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
190
194
  for (nk_size_t vector_length; count_scalars > 0;
191
195
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
192
196
  vector_length = __riscv_vsetvl_e16m1(count_scalars);
@@ -201,14 +205,14 @@ NK_PUBLIC void nk_dot_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_sc
201
205
  sum_f32m2 = __riscv_vfmacc_vv_f32m2_tu(sum_f32m2, a_f32m2, b_f32m2, vector_length);
202
206
  }
203
207
  // Single horizontal reduction at the end
204
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
205
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
208
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
209
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
206
210
  }
207
211
 
208
212
  NK_PUBLIC void nk_dot_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
209
213
  nk_f32_t *result) {
210
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
211
- vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
214
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
215
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
212
216
  for (nk_size_t vector_length; count_scalars > 0;
213
217
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
214
218
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -223,14 +227,14 @@ NK_PUBLIC void nk_dot_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_sc
223
227
  sum_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sum_f32m4, a_f32m4, b_f32m4, vector_length);
224
228
  }
225
229
  // Single horizontal reduction at the end
226
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
227
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
230
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
231
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
228
232
  }
229
233
 
230
234
  NK_PUBLIC void nk_dot_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
231
235
  nk_f32_t *result) {
232
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
233
- vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
236
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
237
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
234
238
  for (nk_size_t vector_length; count_scalars > 0;
235
239
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
236
240
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -245,8 +249,8 @@ NK_PUBLIC void nk_dot_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_sc
245
249
  sum_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sum_f32m4, a_f32m4, b_f32m4, vector_length);
246
250
  }
247
251
  // Single horizontal reduction at the end
248
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
249
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
252
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
253
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
250
254
  }
251
255
 
252
256
  NK_PUBLIC void nk_dot_e2m3_rvv(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
@@ -257,8 +261,8 @@ NK_PUBLIC void nk_dot_e2m3_rvv(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_sc
257
261
  static nk_u8_t const lut_magnitude[32] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
258
262
  32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120};
259
263
 
260
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
261
- vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
264
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
265
+ vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
262
266
  for (nk_size_t vector_length; count_scalars > 0;
263
267
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
264
268
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -285,8 +289,8 @@ NK_PUBLIC void nk_dot_e2m3_rvv(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_sc
285
289
  vint16m2_t products_i16m2 = __riscv_vwmul_vv_i16m2(a_signed_i8m1, b_signed_i8m1, vector_length);
286
290
  sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, products_i16m2, vector_length);
287
291
  }
288
- vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
289
- nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, vlmax));
292
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
293
+ nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, max_vector_length));
290
294
  *result = (nk_f32_t)sum / 256.0f;
291
295
  }
292
296
 
@@ -298,8 +302,8 @@ NK_PUBLIC void nk_dot_e3m2_rvv(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_sc
298
302
  static nk_u16_t const lut_magnitude[32] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28,
299
303
  32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384, 448};
300
304
 
301
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
302
- vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
305
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
306
+ vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
303
307
  for (nk_size_t vector_length; count_scalars > 0;
304
308
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
305
309
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -333,8 +337,8 @@ NK_PUBLIC void nk_dot_e3m2_rvv(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_sc
333
337
  // Widening multiply-accumulate: i16×i16 → i32
334
338
  sum_i32m4 = __riscv_vwmacc_vv_i32m4_tu(sum_i32m4, a_signed_i16m2, b_signed_i16m2, vector_length);
335
339
  }
336
- vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
337
- nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, vlmax));
340
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
341
+ nk_i32_t sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, max_vector_length));
338
342
  *result = (nk_f32_t)sum / 256.0f;
339
343
  }
340
344
 
@@ -344,8 +348,8 @@ NK_PUBLIC void nk_dot_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scal
344
348
  count_dimensions = nk_size_round_up_to_multiple_(count_dimensions, 2);
345
349
  nk_size_t n_full_bytes = count_dimensions / 2;
346
350
 
347
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
348
- vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
351
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
352
+ vint32m4_t sum_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
349
353
  for (nk_size_t vector_length; n_full_bytes > 0;
350
354
  n_full_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
351
355
  vector_length = __riscv_vsetvl_e8m1(n_full_bytes);
@@ -377,8 +381,8 @@ NK_PUBLIC void nk_dot_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scal
377
381
  sum_i32m4 = __riscv_vwadd_wv_i32m4_tu(sum_i32m4, sum_i32m4, ab_low_i16m2, vector_length);
378
382
  }
379
383
  // Single horizontal reduction at the end
380
- vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
381
- *result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, vlmax));
384
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
385
+ *result = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(sum_i32m4, zero_i32m1, max_vector_length));
382
386
  }
383
387
 
384
388
  NK_PUBLIC void nk_dot_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scalars, nk_size_t count_dimensions,
@@ -387,8 +391,8 @@ NK_PUBLIC void nk_dot_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scal
387
391
  count_dimensions = nk_size_round_up_to_multiple_(count_dimensions, 2);
388
392
  nk_size_t n_full_bytes = count_dimensions / 2;
389
393
 
390
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
391
- vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
394
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
395
+ vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
392
396
  for (nk_size_t vector_length; n_full_bytes > 0;
393
397
  n_full_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
394
398
  vector_length = __riscv_vsetvl_e8m1(n_full_bytes);
@@ -410,8 +414,8 @@ NK_PUBLIC void nk_dot_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scal
410
414
  sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, ab_low_u16m2, vector_length);
411
415
  }
412
416
  // Single horizontal reduction at the end
413
- vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
414
- *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, vlmax));
417
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
418
+ *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, max_vector_length));
415
419
  }
416
420
 
417
421
  NK_PUBLIC void nk_dot_u1_rvv(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
@@ -443,9 +447,9 @@ NK_PUBLIC void nk_dot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pair
443
447
  nk_f64c_t *results) {
444
448
  nk_f32_t const *a_f32 = (nk_f32_t const *)a_pairs;
445
449
  nk_f32_t const *b_f32 = (nk_f32_t const *)b_pairs;
446
- nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
447
- vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
448
- vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
450
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
451
+ vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
452
+ vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
449
453
  for (nk_size_t vector_length; count_pairs > 0;
450
454
  count_pairs -= vector_length, a_f32 += vector_length * 2, b_f32 += vector_length * 2) {
451
455
  vector_length = __riscv_vsetvl_e32m1(count_pairs);
@@ -462,18 +466,20 @@ NK_PUBLIC void nk_dot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pair
462
466
  sum_imag_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_imag_f64m2, a_real_f32m1, b_imag_f32m1, vector_length);
463
467
  sum_imag_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_imag_f64m2, a_imag_f32m1, b_real_f32m1, vector_length);
464
468
  }
465
- vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
466
- results->real = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_real_f64m2, zero_f64m1, vlmax));
467
- results->imag = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_imag_f64m2, zero_f64m1, vlmax));
469
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
470
+ results->real = __riscv_vfmv_f_s_f64m1_f64(
471
+ __riscv_vfredusum_vs_f64m2_f64m1(sum_real_f64m2, zero_f64m1, max_vector_length));
472
+ results->imag = __riscv_vfmv_f_s_f64m1_f64(
473
+ __riscv_vfredusum_vs_f64m2_f64m1(sum_imag_f64m2, zero_f64m1, max_vector_length));
468
474
  }
469
475
 
470
476
  NK_PUBLIC void nk_vdot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
471
477
  nk_f64c_t *results) {
472
478
  nk_f32_t const *a_f32 = (nk_f32_t const *)a_pairs;
473
479
  nk_f32_t const *b_f32 = (nk_f32_t const *)b_pairs;
474
- nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
475
- vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
476
- vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
480
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
481
+ vfloat64m2_t sum_real_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
482
+ vfloat64m2_t sum_imag_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
477
483
  for (nk_size_t vector_length; count_pairs > 0;
478
484
  count_pairs -= vector_length, a_f32 += vector_length * 2, b_f32 += vector_length * 2) {
479
485
  vector_length = __riscv_vsetvl_e32m1(count_pairs);
@@ -490,9 +496,11 @@ NK_PUBLIC void nk_vdot_f32c_rvv(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pai
490
496
  sum_imag_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(sum_imag_f64m2, a_real_f32m1, b_imag_f32m1, vector_length);
491
497
  sum_imag_f64m2 = __riscv_vfwnmsac_vv_f64m2_tu(sum_imag_f64m2, a_imag_f32m1, b_real_f32m1, vector_length);
492
498
  }
493
- vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
494
- results->real = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_real_f64m2, zero_f64m1, vlmax));
495
- results->imag = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_imag_f64m2, zero_f64m1, vlmax));
499
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
500
+ results->real = __riscv_vfmv_f_s_f64m1_f64(
501
+ __riscv_vfredusum_vs_f64m2_f64m1(sum_real_f64m2, zero_f64m1, max_vector_length));
502
+ results->imag = __riscv_vfmv_f_s_f64m1_f64(
503
+ __riscv_vfredusum_vs_f64m2_f64m1(sum_imag_f64m2, zero_f64m1, max_vector_length));
496
504
  }
497
505
 
498
506
  NK_PUBLIC void nk_dot_f64c_rvv(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
@@ -500,11 +508,11 @@ NK_PUBLIC void nk_dot_f64c_rvv(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pair
500
508
  // Dot2 (Ogita-Rump-Oishi) compensated complex dot product
501
509
  nk_f64_t const *a_f64 = (nk_f64_t const *)a_pairs;
502
510
  nk_f64_t const *b_f64 = (nk_f64_t const *)b_pairs;
503
- nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
504
- vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
505
- vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
506
- vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
507
- vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
511
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
512
+ vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
513
+ vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
514
+ vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
515
+ vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
508
516
  for (nk_size_t vector_length; count_pairs > 0;
509
517
  count_pairs -= vector_length, a_f64 += vector_length * 2, b_f64 += vector_length * 2) {
510
518
  vector_length = __riscv_vsetvl_e64m1(count_pairs);
@@ -602,11 +610,11 @@ NK_PUBLIC void nk_vdot_f64c_rvv(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pai
602
610
  // Dot2 (Ogita-Rump-Oishi) compensated conjugate complex dot product
603
611
  nk_f64_t const *a_f64 = (nk_f64_t const *)a_pairs;
604
612
  nk_f64_t const *b_f64 = (nk_f64_t const *)b_pairs;
605
- nk_size_t vlmax = __riscv_vsetvlmax_e64m1();
606
- vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
607
- vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
608
- vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
609
- vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vlmax);
613
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
614
+ vfloat64m1_t sum_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
615
+ vfloat64m1_t comp_real_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
616
+ vfloat64m1_t sum_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
617
+ vfloat64m1_t comp_imag_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
610
618
  for (nk_size_t vector_length; count_pairs > 0;
611
619
  count_pairs -= vector_length, a_f64 += vector_length * 2, b_f64 += vector_length * 2) {
612
620
  vector_length = __riscv_vsetvl_e64m1(count_pairs);
@@ -37,8 +37,8 @@ extern "C" {
37
37
 
38
38
  NK_PUBLIC void nk_dot_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
39
39
  nk_f32_t *result) {
40
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
41
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
40
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
41
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
42
42
  for (nk_size_t vector_length; count_scalars > 0;
43
43
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
44
44
  vector_length = __riscv_vsetvl_e16m1(count_scalars);
@@ -50,8 +50,8 @@ NK_PUBLIC void nk_dot_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *
50
50
  sum_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(sum_f32m2, a_bf16m1, b_bf16m1, vector_length);
51
51
  }
52
52
  // Single horizontal reduction at the end
53
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
54
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
53
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
54
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
55
55
  }
56
56
 
57
57
  /** @brief Convert e2m3 to bf16 via 256-entry LUT in cast/rvv.h + reinterpret. */
@@ -76,8 +76,8 @@ NK_INTERNAL vbfloat16m2_t nk_e5m2m1_to_bf16m2_rvvbf16_(vuint8m1_t raw_u8m1, nk_s
76
76
 
77
77
  NK_PUBLIC void nk_dot_e4m3_rvvbf16(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
78
78
  nk_f32_t *result) {
79
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
80
- vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
79
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
80
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
81
81
  for (nk_size_t vector_length; count_scalars > 0;
82
82
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
83
83
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -87,14 +87,14 @@ NK_PUBLIC void nk_dot_e4m3_rvvbf16(nk_e4m3_t const *a_scalars, nk_e4m3_t const *
87
87
  vbfloat16m2_t b_bf16m2 = nk_e4m3m1_to_bf16m2_rvvbf16_(b_u8m1, vector_length);
88
88
  sum_f32m4 = __riscv_vfwmaccbf16_vv_f32m4_tu(sum_f32m4, a_bf16m2, b_bf16m2, vector_length);
89
89
  }
90
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
91
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
90
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
91
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
92
92
  }
93
93
 
94
94
  NK_PUBLIC void nk_dot_e5m2_rvvbf16(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
95
95
  nk_f32_t *result) {
96
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
97
- vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
96
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
97
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
98
98
  for (nk_size_t vector_length; count_scalars > 0;
99
99
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
100
100
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -104,8 +104,8 @@ NK_PUBLIC void nk_dot_e5m2_rvvbf16(nk_e5m2_t const *a_scalars, nk_e5m2_t const *
104
104
  vbfloat16m2_t b_bf16m2 = nk_e5m2m1_to_bf16m2_rvvbf16_(b_u8m1, vector_length);
105
105
  sum_f32m4 = __riscv_vfwmaccbf16_vv_f32m4_tu(sum_f32m4, a_bf16m2, b_bf16m2, vector_length);
106
106
  }
107
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
108
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
107
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
108
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
109
109
  }
110
110
 
111
111
  #if defined(__cplusplus)
@@ -38,8 +38,8 @@ extern "C" {
38
38
 
39
39
  NK_PUBLIC void nk_dot_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
40
40
  nk_f32_t *result) {
41
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
42
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
41
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
42
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
43
43
  for (nk_size_t vector_length; count_scalars > 0;
44
44
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
45
45
  vector_length = __riscv_vsetvl_e16m1(count_scalars);
@@ -51,8 +51,8 @@ NK_PUBLIC void nk_dot_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_s
51
51
  sum_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(sum_f32m2, a_f16m1, b_f16m1, vector_length);
52
52
  }
53
53
  // Single horizontal reduction at the end
54
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
55
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
54
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
55
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
56
56
  }
57
57
 
58
58
  /** @brief Convert e2m3 to f16 via 256-entry LUT in cast/rvv.h + reinterpret. */
@@ -82,8 +82,8 @@ NK_INTERNAL vfloat16m2_t nk_e5m2m1_to_f16m2_rvvhalf_(vuint8m1_t raw_u8m1, nk_siz
82
82
 
83
83
  NK_PUBLIC void nk_dot_e4m3_rvvhalf(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
84
84
  nk_f32_t *result) {
85
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
86
- vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
85
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
86
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
87
87
  for (nk_size_t vector_length; count_scalars > 0;
88
88
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
89
89
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -93,14 +93,14 @@ NK_PUBLIC void nk_dot_e4m3_rvvhalf(nk_e4m3_t const *a_scalars, nk_e4m3_t const *
93
93
  vfloat16m2_t b_f16m2 = nk_e4m3m1_to_f16m2_rvvhalf_(b_u8m1, vector_length);
94
94
  sum_f32m4 = __riscv_vfwmacc_vv_f32m4_tu(sum_f32m4, a_f16m2, b_f16m2, vector_length);
95
95
  }
96
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
97
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
96
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
97
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
98
98
  }
99
99
 
100
100
  NK_PUBLIC void nk_dot_e5m2_rvvhalf(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
101
101
  nk_f32_t *result) {
102
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
103
- vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
102
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
103
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
104
104
  for (nk_size_t vector_length; count_scalars > 0;
105
105
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
106
106
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -110,8 +110,8 @@ NK_PUBLIC void nk_dot_e5m2_rvvhalf(nk_e5m2_t const *a_scalars, nk_e5m2_t const *
110
110
  vfloat16m2_t b_f16m2 = nk_e5m2m1_to_f16m2_rvvhalf_(b_u8m1, vector_length);
111
111
  sum_f32m4 = __riscv_vfwmacc_vv_f32m4_tu(sum_f32m4, a_f16m2, b_f16m2, vector_length);
112
112
  }
113
- vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, vlmax);
114
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
113
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, max_vector_length);
114
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
115
115
  }
116
116
 
117
117
  #if defined(__cplusplus)
@@ -8,10 +8,10 @@
8
8
  *
9
9
  * @section dot_sapphire_instructions Key AVX-512 FP16 Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm512_fmadd_ph VFMADDPH (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
13
- * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy 0.5/cy p01
14
- * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 7cy 1/cy p01
11
+ * Intrinsic Instruction Sapphire Rapids
12
+ * _mm512_fmadd_ph VFMADDPH (ZMM, ZMM, ZMM) 4cy @ p01
13
+ * _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p01
14
+ * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 7cy @ p01
15
15
  *
16
16
  * Sapphire Rapids introduces native AVX-512 FP16 support, enabling 32 FP16 FMAs per instruction at the same
17
17
  * throughput as 16 FP32 FMAs — effectively 2x compute density. For FP6 types (E2M3 and E3M2) whose products