numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -133,7 +133,7 @@ NK_INTERNAL vfloat64m4_t nk_f64m4_reciprocal_rvv_(vfloat64m4_t x_f64m4, nk_size_
  return est_f64m4;
  }

- #pragma region - Small Integers
+ #pragma region I8 and U8 Integers

  NK_PUBLIC void nk_sqeuclidean_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
  nk_u32_t *result) {
@@ -187,13 +187,13 @@ NK_PUBLIC void nk_euclidean_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_sc
  *result = nk_f32_sqrt_rvv((nk_f32_t)d2);
  }

- #pragma endregion - Small Integers
- #pragma region - Traditional Floats
+ #pragma endregion I8 and U8 Integers
+ #pragma region F32 and F64 Floats

  NK_PUBLIC void nk_sqeuclidean_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
  nk_f64_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
- vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
+ vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
  vector_length = __riscv_vsetvl_e32m1(count_scalars);
@@ -206,7 +206,7 @@ NK_PUBLIC void nk_sqeuclidean_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const
  }
  // Single horizontal reduction at the end
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
- *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1, vlmax));
+ *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1, max_vector_length));
  }

  NK_PUBLIC void nk_euclidean_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
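
The loops in these kernels all follow the same RVV strip-mining shape: one per-lane accumulator is zero-initialized across the maximum vector length, `vsetvl` picks the active length for each pass, tail-undisturbed (`_tu`) accumulation leaves inactive lanes' partial sums untouched, and a single horizontal `vfredusum` runs after the loop. A minimal self-contained sketch of that pattern for a plain f32 sum, written against the standard RVV 1.0 intrinsics (the helper name is illustrative, not numkong API):

    #include <riscv_vector.h>
    #include <stddef.h>

    // Strip-mined sum with per-lane accumulators and one deferred reduction.
    static float sum_f32_rvv_sketch(float const *data, size_t count) {
        size_t max_vector_length = __riscv_vsetvlmax_e32m2();
        vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
        for (size_t vector_length; count > 0; count -= vector_length, data += vector_length) {
            vector_length = __riscv_vsetvl_e32m2(count);
            vfloat32m2_t chunk_f32m2 = __riscv_vle32_v_f32m2(data, vector_length);
            // Tail-undisturbed add: lanes past vector_length keep their old sums.
            sum_f32m2 = __riscv_vfadd_vv_f32m2_tu(sum_f32m2, sum_f32m2, chunk_f32m2, vector_length);
        }
        vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
        return __riscv_vfmv_f_s_f32m1_f32(
            __riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
    }

Reducing once over the maximum vector length is what makes the `_tu` accumulation necessary: every lane, active on the final pass or not, must hold a valid partial sum.
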
@@ -239,8 +239,8 @@ NK_PUBLIC void nk_euclidean_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b
  *result = nk_f64_sqrt_rvv(*result);
  }

- #pragma endregion - Traditional Floats
- #pragma region - Small Integers
+ #pragma endregion F32 and F64 Floats
+ #pragma region I8 and U8 Integers

  NK_PUBLIC void nk_angular_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
@@ -320,15 +320,15 @@ NK_PUBLIC void nk_angular_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scal
  }
  }

- #pragma endregion - Small Integers
- #pragma region - Traditional Floats
+ #pragma endregion I8 and U8 Integers
+ #pragma region F32 and F64 Floats

  NK_PUBLIC void nk_angular_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
  nk_f64_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
- vfloat64m2_t dot_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
- vfloat64m2_t a_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
- vfloat64m2_t b_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
+ vfloat64m2_t dot_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
+ vfloat64m2_t a_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
+ vfloat64m2_t b_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);

  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -344,11 +344,12 @@ NK_PUBLIC void nk_angular_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_s

  // Single horizontal reduction at the end for all three accumulators
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
- nk_f64_t dot_f64 = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(dot_f64m2, zero_f64m1, vlmax));
+ nk_f64_t dot_f64 = __riscv_vfmv_f_s_f64m1_f64(
+ __riscv_vfredusum_vs_f64m2_f64m1(dot_f64m2, zero_f64m1, max_vector_length));
  nk_f64_t a_norm_sq_f64 = __riscv_vfmv_f_s_f64m1_f64(
- __riscv_vfredusum_vs_f64m2_f64m1(a_norm_sq_f64m2, zero_f64m1, vlmax));
+ __riscv_vfredusum_vs_f64m2_f64m1(a_norm_sq_f64m2, zero_f64m1, max_vector_length));
  nk_f64_t b_norm_sq_f64 = __riscv_vfmv_f_s_f64m1_f64(
- __riscv_vfredusum_vs_f64m2_f64m1(b_norm_sq_f64m2, zero_f64m1, vlmax));
+ __riscv_vfredusum_vs_f64m2_f64m1(b_norm_sq_f64m2, zero_f64m1, max_vector_length));

  // Normalize: 1 − dot / √(‖a‖² × ‖b‖²)
  if (a_norm_sq_f64 == 0.0 && b_norm_sq_f64 == 0.0) { *result = 0.0; }
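
The `nk_angular_*` kernels here and below share one finalization: angular (cosine) distance 1 − dot / √(‖a‖² × ‖b‖²), per the comment above, with guards for degenerate inputs: two all-zero vectors yield distance 0, and (as the later hunks show) a zero dot product yields distance 1. A scalar sketch of that finalization under those assumptions (hypothetical helper name; the vector loop's only job is to deliver the three sums):

    #include <math.h>

    // Scalar finalization mirroring the guards in the kernels above.
    static double angular_finalize_sketch(double dot, double a_norm_sq, double b_norm_sq) {
        if (a_norm_sq == 0.0 && b_norm_sq == 0.0) return 0.0; // both vectors are zero
        if (dot == 0.0) return 1.0;                           // orthogonal, or one side zero
        return 1.0 - dot / sqrt(a_norm_sq * b_norm_sq);
    }
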
@@ -413,13 +414,13 @@ NK_PUBLIC void nk_angular_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_s
  }
  }

- #pragma endregion - Traditional Floats
- #pragma region - Smaller Floats
+ #pragma endregion F32 and F64 Floats
+ #pragma region F16 and BF16 Floats

  NK_PUBLIC void nk_sqeuclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
  vector_length = __riscv_vsetvl_e16m1(count_scalars);
@@ -436,7 +437,7 @@ NK_PUBLIC void nk_sqeuclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const
  }
  // Single horizontal reduction at the end
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
  }

  NK_PUBLIC void nk_euclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
@@ -447,10 +448,10 @@ NK_PUBLIC void nk_euclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b

  NK_PUBLIC void nk_angular_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
- vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
- vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
- vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
+ vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
+ vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
+ vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);

  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -470,11 +471,12 @@ NK_PUBLIC void nk_angular_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_s

  // Single horizontal reduction at the end for all three accumulators
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, vlmax));
+ nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, max_vector_length));
  nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
- __riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1, vlmax));
+ __riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1, max_vector_length));
  nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
- __riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1, vlmax));
+ __riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1, max_vector_length));

  if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
  else if (dot_f32 == 0.0f) { *result = 1.0f; }
@@ -486,8 +488,8 @@ NK_PUBLIC void nk_angular_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_s

  NK_PUBLIC void nk_sqeuclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
  vector_length = __riscv_vsetvl_e16m1(count_scalars);
@@ -504,7 +506,7 @@ NK_PUBLIC void nk_sqeuclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t con
  }
  // Single horizontal reduction at the end
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
  }

  NK_PUBLIC void nk_euclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
@@ -515,10 +517,10 @@ NK_PUBLIC void nk_euclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const

  NK_PUBLIC void nk_angular_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
- vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
- vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
- vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
+ vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
+ vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
+ vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);

  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -538,11 +540,12 @@ NK_PUBLIC void nk_angular_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *

  // Single horizontal reduction at the end for all three accumulators
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, vlmax));
+ nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, max_vector_length));
  nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
- __riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1, vlmax));
+ __riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1, max_vector_length));
  nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
- __riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1, vlmax));
+ __riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1, max_vector_length));

  if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
  else if (dot_f32 == 0.0f) { *result = 1.0f; }
@@ -554,8 +557,8 @@ NK_PUBLIC void nk_angular_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *

  NK_PUBLIC void nk_sqeuclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
- vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -572,7 +575,7 @@ NK_PUBLIC void nk_sqeuclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t con
  }
  // Single horizontal reduction at the end
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
  }

  NK_PUBLIC void nk_euclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
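
One register-grouping detail in the e4m3/e5m2 hunks: the loop sets its length at e8m1 (8-bit elements, LMUL=1), while the f32 accumulators are m4, since widening an 8-bit lane to 32 bits quadruples the register width; the f16/bf16 kernels above only double it (e16m1 loads, f32m2 accumulators). The same 4× ratio is easiest to see with a plain integer zero-extension, sketched below with standard intrinsics (illustrative only, not numkong code):

    #include <riscv_vector.h>
    #include <stddef.h>

    // u8 lanes at LMUL=1 zero-extended to u32: the destination needs LMUL=4
    // to hold the same lane count at four times the element width.
    static vuint32m4_t widen_u8_to_u32_sketch(unsigned char const *p, size_t vector_length) {
        vuint8m1_t narrow_u8m1 = __riscv_vle8_v_u8m1(p, vector_length);
        return __riscv_vzext_vf4_u32m4(narrow_u8m1, vector_length);
    }
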
@@ -583,10 +586,10 @@ NK_PUBLIC void nk_euclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const

  NK_PUBLIC void nk_angular_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
- vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
- vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
- vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
+ vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
+ vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
+ vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);

  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -606,11 +609,12 @@ NK_PUBLIC void nk_angular_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *

  // Single horizontal reduction at the end for all three accumulators
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(dot_f32m4, zero_f32m1, vlmax));
+ nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m4_f32m1(dot_f32m4, zero_f32m1, max_vector_length));
  nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
- __riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1, vlmax));
+ __riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1, max_vector_length));
  nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
- __riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1, vlmax));
+ __riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1, max_vector_length));

  if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
  else if (dot_f32 == 0.0f) { *result = 1.0f; }
@@ -622,8 +626,8 @@ NK_PUBLIC void nk_angular_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *

  NK_PUBLIC void nk_sqeuclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
- vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
  vector_length = __riscv_vsetvl_e8m1(count_scalars);
@@ -640,7 +644,7 @@ NK_PUBLIC void nk_sqeuclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t con
  }
  // Single horizontal reduction at the end
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, max_vector_length));
  }

  NK_PUBLIC void nk_euclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
@@ -651,10 +655,10 @@ NK_PUBLIC void nk_euclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const

  NK_PUBLIC void nk_angular_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
- vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
- vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
- vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
+ vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
+ vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);
+ vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, max_vector_length);

  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -674,11 +678,12 @@ NK_PUBLIC void nk_angular_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *

  // Single horizontal reduction at the end for all three accumulators
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(dot_f32m4, zero_f32m1, vlmax));
+ nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m4_f32m1(dot_f32m4, zero_f32m1, max_vector_length));
  nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
- __riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1, vlmax));
+ __riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1, max_vector_length));
  nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
- __riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1, vlmax));
+ __riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1, max_vector_length));

  if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
  else if (dot_f32 == 0.0f) { *result = 1.0f; }
@@ -688,8 +693,8 @@ NK_PUBLIC void nk_angular_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *
  }
  }

- #pragma endregion - Smaller Floats
- #pragma region - Small Integers
+ #pragma endregion F16 and BF16 Floats
+ #pragma region I8 and U8 Integers

  NK_PUBLIC void nk_sqeuclidean_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scalars, nk_size_t count_scalars,
  nk_u32_t *result) {
@@ -713,31 +718,31 @@ NK_PUBLIC void nk_sqeuclidean_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const
  };
  count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
  nk_size_t n_bytes = count_scalars / 2;
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
- vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
+ vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
  for (nk_size_t vector_length; n_bytes > 0;
  n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
  vector_length = __riscv_vsetvl_e8m1(n_bytes);
  vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
  vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
- // Build LUT indices: high nibble pair = (a_hi << 4) | b_hi
- vuint8m1_t hi_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
- __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length),
- vector_length);
- // Low nibble pair = (a_lo << 4) | b_lo
- vuint8m1_t lo_idx_u8m1 = __riscv_vor_vv_u8m1(
+ // Build LUT indices: high nibble pair = (a_high << 4) | b_hi
+ vuint8m1_t high_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
+ __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length),
+ vector_length);
+ // Low nibble pair = (a_low << 4) | b_lo
+ vuint8m1_t low_idx_u8m1 = __riscv_vor_vv_u8m1(
  __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length), 4, vector_length),
  __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length), vector_length);
  // Gather squared differences from LUT (0-225, fits u8)
- vuint8m1_t sq_hi_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sqd_lut_, hi_idx_u8m1, vector_length);
- vuint8m1_t sq_lo_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sqd_lut_, lo_idx_u8m1, vector_length);
+ vuint8m1_t sq_high_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sqd_lut_, high_idx_u8m1, vector_length);
+ vuint8m1_t sq_low_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sqd_lut_, low_idx_u8m1, vector_length);
  // Combine and per-lane accumulate: u8+u8→u16, then u32+=u16
- vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(sq_hi_u8m1, sq_lo_u8m1, vector_length);
+ vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(sq_high_u8m1, sq_low_u8m1, vector_length);
  sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, combined_u16m2, vector_length);
  }
  // Single horizontal reduction after loop
- vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
- *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, vlmax));
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
+ *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, max_vector_length));
  }

  NK_PUBLIC void nk_euclidean_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scalars, nk_size_t count_scalars,
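
The i4 kernel above never unpacks nibbles into wider lanes for the arithmetic itself: each byte carries two 4-bit codes, a byte-sized index (a_nibble << 4) | b_nibble addresses a 256-entry table of precomputed squared differences (at most 15² = 225, so each entry fits a u8), and `vluxei8` gathers the results. A scalar reference of the same lookup scheme, assuming a hypothetical `sqd_lut` built the same way as `nk_i4_sqd_lut_`:

    #include <stddef.h>
    #include <stdint.h>

    // sqd_lut[(x << 4) | y] must hold (decode(x) - decode(y))^2 for 4-bit codes x, y.
    static uint32_t sqeuclidean_i4_sketch(uint8_t const *a, uint8_t const *b,
                                          size_t n_bytes, uint8_t const sqd_lut[256]) {
        uint32_t sum = 0;
        for (size_t i = 0; i < n_bytes; ++i) {
            uint8_t high_idx = (uint8_t)((a[i] & 0xF0) | (b[i] >> 4)); // high nibble pair
            uint8_t low_idx = (uint8_t)((a[i] << 4) | (b[i] & 0x0F));  // low nibble pair
            sum += sqd_lut[high_idx] + sqd_lut[low_idx];
        }
        return sum;
    }
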
@@ -770,10 +775,10 @@ NK_PUBLIC void nk_angular_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_
  static nk_u8_t const nk_i4_sq_lut_[16] = {0, 1, 4, 9, 16, 25, 36, 49, 64, 49, 36, 25, 16, 9, 4, 1};
  count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
  nk_size_t n_bytes = count_scalars / 2;
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
- vint32m4_t dot_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
- vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
- vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
+ vint32m4_t dot_i32m4 = __riscv_vmv_v_x_i32m4(0, max_vector_length);
+ vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
+ vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);

  for (nk_size_t vector_length; n_bytes > 0;
  n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -782,44 +787,45 @@ NK_PUBLIC void nk_angular_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_
  vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);

  // Extract nibbles for index building
- vuint8m1_t a_hi_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
- vuint8m1_t b_hi_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
- vuint8m1_t a_lo_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
- vuint8m1_t b_lo_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);
+ vuint8m1_t a_high_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
+ vuint8m1_t b_high_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
+ vuint8m1_t a_low_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
+ vuint8m1_t b_low_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);

  // Dot product via 256-entry LUT: dot_lut[(a<<4)|b] = a_signed * b_signed (i8)
- vuint8m1_t hi_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
- b_hi_u8m1, vector_length);
- vuint8m1_t lo_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(a_lo_u8m1, 4, vector_length), b_lo_u8m1,
- vector_length);
- vint8m1_t dot_hi_i8m1 = __riscv_vluxei8_v_i8m1(nk_i4_dot_lut_, hi_idx_u8m1, vector_length);
- vint8m1_t dot_lo_i8m1 = __riscv_vluxei8_v_i8m1(nk_i4_dot_lut_, lo_idx_u8m1, vector_length);
+ vuint8m1_t high_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
+ b_high_u8m1, vector_length);
+ vuint8m1_t low_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(a_low_u8m1, 4, vector_length), b_low_u8m1,
+ vector_length);
+ vint8m1_t dot_high_i8m1 = __riscv_vluxei8_v_i8m1(nk_i4_dot_lut_, high_idx_u8m1, vector_length);
+ vint8m1_t dot_low_i8m1 = __riscv_vluxei8_v_i8m1(nk_i4_dot_lut_, low_idx_u8m1, vector_length);
  // Widen i8→i16, add hi+lo, then per-lane accumulate i32+=i16
- vint16m2_t dot_combined_i16m2 = __riscv_vwadd_vv_i16m2(dot_hi_i8m1, dot_lo_i8m1, vector_length);
+ vint16m2_t dot_combined_i16m2 = __riscv_vwadd_vv_i16m2(dot_high_i8m1, dot_low_i8m1, vector_length);
  dot_i32m4 = __riscv_vwadd_wv_i32m4_tu(dot_i32m4, dot_i32m4, dot_combined_i16m2, vector_length);

  // Norms via 16-entry squaring LUT + vluxei8
- vuint8m1_t a_hi_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, a_hi_u8m1, vector_length);
- vuint8m1_t a_lo_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, a_lo_u8m1, vector_length);
- vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(a_hi_sq_u8m1, a_lo_sq_u8m1, vector_length);
+ vuint8m1_t a_high_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, a_high_u8m1, vector_length);
+ vuint8m1_t a_low_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, a_low_u8m1, vector_length);
+ vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(a_high_sq_u8m1, a_low_sq_u8m1, vector_length);
  a_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(a_norm_sq_u32m4, a_norm_sq_u32m4, a_sq_combined_u16m2,
  vector_length);

- vuint8m1_t b_hi_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, b_hi_u8m1, vector_length);
- vuint8m1_t b_lo_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, b_lo_u8m1, vector_length);
- vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(b_hi_sq_u8m1, b_lo_sq_u8m1, vector_length);
+ vuint8m1_t b_high_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, b_high_u8m1, vector_length);
+ vuint8m1_t b_low_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, b_low_u8m1, vector_length);
+ vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(b_high_sq_u8m1, b_low_sq_u8m1, vector_length);
  b_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(b_norm_sq_u32m4, b_norm_sq_u32m4, b_sq_combined_u16m2,
  vector_length);
  }

  // Single horizontal reductions after loop
- vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
- vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
- nk_i32_t dot_i32 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(dot_i32m4, zero_i32m1, vlmax));
+ vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, max_vector_length);
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
+ nk_i32_t dot_i32 = __riscv_vmv_x_s_i32m1_i32(
+ __riscv_vredsum_vs_i32m4_i32m1(dot_i32m4, zero_i32m1, max_vector_length));
  nk_u32_t a_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
- __riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1, vlmax));
+ __riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1, max_vector_length));
  nk_u32_t b_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
- __riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1, vlmax));
+ __riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1, max_vector_length));

  if (a_norm_sq_u32 == 0 && b_norm_sq_u32 == 0) { *result = 0.0f; }
  else if (dot_i32 == 0) { *result = 1.0f; }
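
The angular i4 path extends the same trick with two more tables: a 256-entry signed-product table indexed by the same (a << 4) | b byte, and the 16-entry squaring table above for the norms (its mirrored second half, 64, 49, 36, ..., covers the two's-complement encodings of negative nibbles). A sketch of how such a product table could be generated, assuming nibbles decode as 4-bit two's complement (the names are illustrative):

    #include <stdint.h>

    // Decode a 4-bit two's-complement nibble: 0..7 stay positive, 8..15 map to -8..-1.
    static int8_t decode_i4_sketch(unsigned nibble) {
        return (int8_t)(nibble < 8 ? (int)nibble : (int)nibble - 16);
    }

    // dot_lut[(x << 4) | y] = decode(x) * decode(y); products span -56..64, so i8 suffices.
    static void build_i4_dot_lut_sketch(int8_t dot_lut[256]) {
        for (unsigned x = 0; x < 16; ++x)
            for (unsigned y = 0; y < 16; ++y)
                dot_lut[(x << 4) | y] = (int8_t)(decode_i4_sketch(x) * decode_i4_sketch(y));
    }
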
@@ -852,31 +858,31 @@ NK_PUBLIC void nk_sqeuclidean_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const
  };
  count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
  nk_size_t n_bytes = count_scalars / 2;
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
- vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
+ vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
  for (nk_size_t vector_length; n_bytes > 0;
  n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
  vector_length = __riscv_vsetvl_e8m1(n_bytes);
  vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
  vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
- // Build LUT indices: high nibble pair = (a_hi & 0xF0) | (b_hi >> 4)
- vuint8m1_t hi_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
- __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length),
- vector_length);
- // Low nibble pair = (a_lo << 4) | b_lo
- vuint8m1_t lo_idx_u8m1 = __riscv_vor_vv_u8m1(
+ // Build LUT indices: high nibble pair = (a_high & 0xF0) | (b_high >> 4)
+ vuint8m1_t high_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
+ __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length),
+ vector_length);
+ // Low nibble pair = (a_low << 4) | b_lo
+ vuint8m1_t low_idx_u8m1 = __riscv_vor_vv_u8m1(
  __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length), 4, vector_length),
  __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length), vector_length);
  // Gather squared differences from LUT (0-225, fits u8)
- vuint8m1_t sq_hi_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sqd_lut_, hi_idx_u8m1, vector_length);
- vuint8m1_t sq_lo_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sqd_lut_, lo_idx_u8m1, vector_length);
+ vuint8m1_t sq_high_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sqd_lut_, high_idx_u8m1, vector_length);
+ vuint8m1_t sq_low_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sqd_lut_, low_idx_u8m1, vector_length);
  // Combine and per-lane accumulate: u8+u8→u16, then u32+=u16
- vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(sq_hi_u8m1, sq_lo_u8m1, vector_length);
+ vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(sq_high_u8m1, sq_low_u8m1, vector_length);
  sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, combined_u16m2, vector_length);
  }
  // Single horizontal reduction after loop
- vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
- *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, vlmax));
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
+ *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, max_vector_length));
  }

  NK_PUBLIC void nk_euclidean_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scalars, nk_size_t count_scalars,
@@ -909,10 +915,10 @@ NK_PUBLIC void nk_angular_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_
  static nk_u8_t const nk_u4_sq_lut_[16] = {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225};
  count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
  nk_size_t n_bytes = count_scalars / 2;
- nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
- vuint32m4_t dot_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
- vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
- vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
+ vuint32m4_t dot_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
+ vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);
+ vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, max_vector_length);

  for (nk_size_t vector_length; n_bytes > 0;
  n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -921,43 +927,44 @@ NK_PUBLIC void nk_angular_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_
  vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);

  // Extract nibbles
- vuint8m1_t a_hi_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
- vuint8m1_t b_hi_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
- vuint8m1_t a_lo_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
- vuint8m1_t b_lo_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);
+ vuint8m1_t a_high_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
+ vuint8m1_t b_high_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
+ vuint8m1_t a_low_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
+ vuint8m1_t b_low_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);

  // Dot product via 256-entry LUT: dot_lut[(a<<4)|b] = a * b (u8)
- vuint8m1_t hi_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
- b_hi_u8m1, vector_length);
- vuint8m1_t lo_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(a_lo_u8m1, 4, vector_length), b_lo_u8m1,
- vector_length);
- vuint8m1_t dot_hi_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_dot_lut_, hi_idx_u8m1, vector_length);
- vuint8m1_t dot_lo_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_dot_lut_, lo_idx_u8m1, vector_length);
+ vuint8m1_t high_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
+ b_high_u8m1, vector_length);
+ vuint8m1_t low_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(a_low_u8m1, 4, vector_length), b_low_u8m1,
+ vector_length);
+ vuint8m1_t dot_high_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_dot_lut_, high_idx_u8m1, vector_length);
+ vuint8m1_t dot_low_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_dot_lut_, low_idx_u8m1, vector_length);
  // Widen u8→u16, add hi+lo, then per-lane accumulate u32+=u16
- vuint16m2_t dot_combined_u16m2 = __riscv_vwaddu_vv_u16m2(dot_hi_u8m1, dot_lo_u8m1, vector_length);
+ vuint16m2_t dot_combined_u16m2 = __riscv_vwaddu_vv_u16m2(dot_high_u8m1, dot_low_u8m1, vector_length);
  dot_u32m4 = __riscv_vwaddu_wv_u32m4_tu(dot_u32m4, dot_u32m4, dot_combined_u16m2, vector_length);

  // Norms via 16-entry squaring LUT + vluxei8
- vuint8m1_t a_hi_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, a_hi_u8m1, vector_length);
- vuint8m1_t a_lo_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, a_lo_u8m1, vector_length);
- vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(a_hi_sq_u8m1, a_lo_sq_u8m1, vector_length);
+ vuint8m1_t a_high_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, a_high_u8m1, vector_length);
+ vuint8m1_t a_low_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, a_low_u8m1, vector_length);
+ vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(a_high_sq_u8m1, a_low_sq_u8m1, vector_length);
  a_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(a_norm_sq_u32m4, a_norm_sq_u32m4, a_sq_combined_u16m2,
  vector_length);

- vuint8m1_t b_hi_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, b_hi_u8m1, vector_length);
- vuint8m1_t b_lo_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, b_lo_u8m1, vector_length);
- vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(b_hi_sq_u8m1, b_lo_sq_u8m1, vector_length);
+ vuint8m1_t b_high_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, b_high_u8m1, vector_length);
+ vuint8m1_t b_low_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, b_low_u8m1, vector_length);
+ vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(b_high_sq_u8m1, b_low_sq_u8m1, vector_length);
  b_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(b_norm_sq_u32m4, b_norm_sq_u32m4, b_sq_combined_u16m2,
  vector_length);
  }

  // Single horizontal reductions after loop
- vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
- nk_u32_t dot_u32 = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(dot_u32m4, zero_u32m1, vlmax));
+ vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, max_vector_length);
+ nk_u32_t dot_u32 = __riscv_vmv_x_s_u32m1_u32(
+ __riscv_vredsum_vs_u32m4_u32m1(dot_u32m4, zero_u32m1, max_vector_length));
  nk_u32_t a_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
- __riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1, vlmax));
+ __riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1, max_vector_length));
  nk_u32_t b_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
- __riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1, vlmax));
+ __riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1, max_vector_length));

  if (a_norm_sq_u32 == 0 && b_norm_sq_u32 == 0) { *result = 0.0f; }
  else if (dot_u32 == 0) { *result = 1.0f; }
@@ -978,7 +985,7 @@ NK_PUBLIC void nk_angular_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_
  #pragma GCC pop_options
  #endif

- #pragma endregion - Small Integers
+ #pragma endregion I8 and U8 Integers
  #endif // NK_TARGET_RVV
  #endif // NK_TARGET_RISCV_
  #endif // NK_SPATIAL_RVV_H
@@ -37,9 +37,9 @@ extern "C" {
  NK_PUBLIC void nk_sqeuclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars,
  nk_size_t count_scalars, nk_f32_t *result) {
  // Per-lane accumulators — deferred horizontal reduction
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
- vfloat32m2_t sq_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax); // a² + b²
- vfloat32m2_t ab_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax); // a × b
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
+ vfloat32m2_t sq_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length); // a² + b²
+ vfloat32m2_t ab_sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length); // a × b

  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -57,8 +57,10 @@ NK_PUBLIC void nk_sqeuclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t

  // Single horizontal reduction after the loop
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- nk_f32_t sq_sum = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sq_sum_f32m2, zero_f32m1, vlmax));
- nk_f32_t ab_sum = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(ab_sum_f32m2, zero_f32m1, vlmax));
+ nk_f32_t sq_sum = __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m2_f32m1(sq_sum_f32m2, zero_f32m1, max_vector_length));
+ nk_f32_t ab_sum = __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m2_f32m1(ab_sum_f32m2, zero_f32m1, max_vector_length));
  *result = sq_sum - 2.0f * ab_sum;
  }

@@ -72,10 +74,10 @@ NK_PUBLIC void nk_euclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t c
  NK_PUBLIC void nk_angular_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
  // Per-lane accumulators — deferred horizontal reduction
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
- vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
- vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
- vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
+ vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
+ vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
+ vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);

  for (nk_size_t vector_length; count_scalars > 0;
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -95,9 +97,12 @@ NK_PUBLIC void nk_angular_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t con

  // Single horizontal reduction after the loop
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
- nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, vlmax));
- nk_f32_t a_sq = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(a_sq_f32m2, zero_f32m1, vlmax));
- nk_f32_t b_sq = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(b_sq_f32m2, zero_f32m1, vlmax));
+ nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, max_vector_length));
+ nk_f32_t a_sq = __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m2_f32m1(a_sq_f32m2, zero_f32m1, max_vector_length));
+ nk_f32_t b_sq = __riscv_vfmv_f_s_f32m1_f32(
+ __riscv_vfredusum_vs_f32m2_f32m1(b_sq_f32m2, zero_f32m1, max_vector_length));

  // Normalize: 1 − dot / sqrt(‖a‖² × ‖b‖²)
  if (a_sq == 0.0f && b_sq == 0.0f) { *result = 0.0f; }
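
The rvvbf16 kernels above lean on the identity ‖a − b‖² = (‖a‖² + ‖b‖²) − 2·(a·b): the loop accumulates sq_sum = Σ(aᵢ² + bᵢ²) and ab_sum = Σ aᵢbᵢ and combines them as sq_sum − 2·ab_sum at the end, presumably so the inner loop can run entirely on widening multiply-accumulates instead of subtract-then-square. A scalar sketch of the rearrangement (plain f32 standing in for bf16; name illustrative):

    #include <stddef.h>

    // sum((a-b)^2) == sum(a^2 + b^2) - 2 * sum(a*b)
    static float sqeuclidean_via_dot_sketch(float const *a, float const *b, size_t count) {
        float sq_sum = 0.0f, ab_sum = 0.0f;
        for (size_t i = 0; i < count; ++i) {
            sq_sum += a[i] * a[i] + b[i] * b[i];
            ab_sum += a[i] * b[i];
        }
        return sq_sum - 2.0f * ab_sum;
    }
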