numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,984 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for RISC-V.
3
+ * @file include/numkong/spatial/rvv.h
4
+ * @author Ash Vardanian
5
+ * @date January 5, 2026
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * RVV uses vector length agnostic programming where:
10
+ * - `vsetvl_e*m*(n)` sets VL = min(n, VLMAX) and returns actual VL
11
+ * - Loads/stores with VL automatically handle partial vectors (tail elements)
12
+ * - No explicit masking needed for simple reductions
13
+ *
14
+ * This file contains base RVV 1.0 operations (i8, u8, f32, f64).
15
+ * For f16 (Zvfh) see rvvhalf.h, for bf16 (Zvfbfwma) see rvvbf16.h.
16
+ *
17
+ * Precision strategies matching Skylake:
18
+ * - i8 L2: diff (i8-i8 → i16), square (i16 × i16 → i32), reduce to i32
19
+ * - u8 L2: |diff| via widening, square → u32, reduce to u32
20
+ * - f32: Widen to f64 for accumulation, downcast result to f32
21
+ * - f64: Direct f64 accumulation
22
+ */
23
+ #ifndef NK_SPATIAL_RVV_H
24
+ #define NK_SPATIAL_RVV_H
25
+
26
+ #if NK_TARGET_RISCV_
27
+ #if NK_TARGET_RVV
28
+
29
+ #include "numkong/types.h"
30
+ #include "numkong/scalar/rvv.h" // `nk_f32_rsqrt_rvv`
31
+ #include "numkong/cast/rvv.h" // `nk_e4m3m1_to_f32m4_rvv_`
32
+ #include "numkong/dot/rvv.h" // `nk_dot_stable_sum_f64m1_rvv_`
33
+
34
+ #if defined(__clang__)
35
+ #pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
36
+ #elif defined(__GNUC__)
37
+ #pragma GCC push_options
38
+ #pragma GCC target("arch=+v")
39
+ #endif
40
+
41
+ #if defined(__cplusplus)
42
+ extern "C" {
43
+ #endif
44
+
45
/**
 *  @brief Vectorized `1/√x` for f32 m1 register group using `vfrsqrt7` + 2 Newton-Raphson steps.
 *
 *  Achieves ~28 bits of precision, sufficient for f32's 23-bit mantissa.
 *  Formula per iteration: y' = y × (3 − x × y²) × 0.5
 */
NK_INTERNAL vfloat32m1_t nk_rsqrt_f32m1_rvv_(vfloat32m1_t values_f32m1, size_t vector_length) {
    // Seed with the 7-bit hardware estimate, then refine twice
    vfloat32m1_t estimate_f32m1 = __riscv_vfrsqrt7_v_f32m1(values_f32m1, vector_length);
    for (int iteration = 0; iteration != 2; ++iteration) {
        vfloat32m1_t estimate_sq_f32m1 = __riscv_vfmul_vv_f32m1(estimate_f32m1, estimate_f32m1, vector_length);
        vfloat32m1_t product_f32m1 = __riscv_vfmul_vv_f32m1(values_f32m1, estimate_sq_f32m1, vector_length);
        vfloat32m1_t correction_f32m1 = __riscv_vfrsub_vf_f32m1(product_f32m1, 3.0f, vector_length);
        vfloat32m1_t refined_f32m1 = __riscv_vfmul_vv_f32m1(estimate_f32m1, correction_f32m1, vector_length);
        estimate_f32m1 = __riscv_vfmul_vf_f32m1(refined_f32m1, 0.5f, vector_length);
    }
    return estimate_f32m1;
}
62
+
63
/**
 *  @brief Vectorized `1/√x` for f64 m1 register group using `vfrsqrt7` + 3 Newton-Raphson steps.
 *
 *  Achieves ~56 bits of precision, sufficient for f64's 52-bit mantissa.
 *  Formula per iteration: y' = y × (3 − x × y²) × 0.5
 */
NK_INTERNAL vfloat64m1_t nk_rsqrt_f64m1_rvv_(vfloat64m1_t values_f64m1, size_t vector_length) {
    // Seed with the 7-bit hardware estimate, then refine three times
    vfloat64m1_t estimate_f64m1 = __riscv_vfrsqrt7_v_f64m1(values_f64m1, vector_length);
    for (int iteration = 0; iteration != 3; ++iteration) {
        vfloat64m1_t estimate_sq_f64m1 = __riscv_vfmul_vv_f64m1(estimate_f64m1, estimate_f64m1, vector_length);
        vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(values_f64m1, estimate_sq_f64m1, vector_length);
        vfloat64m1_t correction_f64m1 = __riscv_vfrsub_vf_f64m1(product_f64m1, 3.0, vector_length);
        vfloat64m1_t refined_f64m1 = __riscv_vfmul_vv_f64m1(estimate_f64m1, correction_f64m1, vector_length);
        estimate_f64m1 = __riscv_vfmul_vf_f64m1(refined_f64m1, 0.5, vector_length);
    }
    return estimate_f64m1;
}
80
+
81
/**
 *  @brief Approximate reciprocal of f32 vector (m4) using vfrec7 + 2 Newton-Raphson steps.
 *  Achieves ~28-bit precision, sufficient for f32 (24-bit mantissa).
 *
 *  Each refinement computes est' = est × (2 − x × est). RVV intrinsics are pure value
 *  functions, so the `2.0` splat is never clobbered by `vfnmsac` and is hoisted out of the
 *  step loop (the previous unrolled form re-materialized it before every step). The loop
 *  also mirrors the structure of `nk_rsqrt_f32m1_rvv_` for consistency.
 */
NK_INTERNAL vfloat32m4_t nk_f32m4_reciprocal_rvv_(vfloat32m4_t x_f32m4, nk_size_t vector_length) {
    vfloat32m4_t est_f32m4 = __riscv_vfrec7_v_f32m4(x_f32m4, vector_length);
    vfloat32m4_t two_f32m4 = __riscv_vfmv_v_f_f32m4(2.0f, vector_length);
    for (int step = 0; step < 2; step++) {
        // est = est * (2 - x * est)
        est_f32m4 = __riscv_vfmul_vv_f32m4(
            est_f32m4, __riscv_vfnmsac_vv_f32m4(two_f32m4, x_f32m4, est_f32m4, vector_length), vector_length);
    }
    return est_f32m4;
}
97
+
98
/**
 *  @brief Approximate reciprocal of f32 vector (m2) using vfrec7 + 2 Newton-Raphson steps.
 *  Achieves ~28-bit precision, sufficient for f32 (24-bit mantissa).
 *
 *  Each refinement computes est' = est × (2 − x × est). RVV intrinsics are pure value
 *  functions, so the `2.0` splat is never clobbered by `vfnmsac` and is hoisted out of the
 *  step loop (the previous unrolled form re-materialized it before every step). The loop
 *  also mirrors the structure of `nk_rsqrt_f32m1_rvv_` for consistency.
 */
NK_INTERNAL vfloat32m2_t nk_f32m2_reciprocal_rvv_(vfloat32m2_t x_f32m2, nk_size_t vector_length) {
    vfloat32m2_t est_f32m2 = __riscv_vfrec7_v_f32m2(x_f32m2, vector_length);
    vfloat32m2_t two_f32m2 = __riscv_vfmv_v_f_f32m2(2.0f, vector_length);
    for (int step = 0; step < 2; step++) {
        // est = est * (2 - x * est)
        est_f32m2 = __riscv_vfmul_vv_f32m2(
            est_f32m2, __riscv_vfnmsac_vv_f32m2(two_f32m2, x_f32m2, est_f32m2, vector_length), vector_length);
    }
    return est_f32m2;
}
114
+
115
/**
 *  @brief Approximate reciprocal of f64 vector (m4) using vfrec7 + 3 Newton-Raphson steps.
 *  Achieves ~56-bit precision, sufficient for f64 (52-bit mantissa).
 *
 *  Each refinement computes est' = est × (2 − x × est). RVV intrinsics are pure value
 *  functions, so the `2.0` splat is never clobbered by `vfnmsac` and is hoisted out of the
 *  step loop (the previous unrolled form re-materialized it before every step). The loop
 *  also mirrors the structure of `nk_rsqrt_f64m1_rvv_` for consistency.
 */
NK_INTERNAL vfloat64m4_t nk_f64m4_reciprocal_rvv_(vfloat64m4_t x_f64m4, nk_size_t vector_length) {
    vfloat64m4_t est_f64m4 = __riscv_vfrec7_v_f64m4(x_f64m4, vector_length);
    vfloat64m4_t two_f64m4 = __riscv_vfmv_v_f_f64m4(2.0, vector_length);
    for (int step = 0; step < 3; step++) {
        // est = est * (2 - x * est)
        est_f64m4 = __riscv_vfmul_vv_f64m4(
            est_f64m4, __riscv_vfnmsac_vv_f64m4(two_f64m4, x_f64m4, est_f64m4, vector_length), vector_length);
    }
    return est_f64m4;
}
135
+
136
+ #pragma region - Small Integers
137
+
138
/** @brief Squared Euclidean distance between two i8 vectors, accumulated exactly in i32. */
NK_PUBLIC void nk_sqeuclidean_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
                                     nk_u32_t *result) {
    // Running sum lives in lane 0 of an m1 register
    vint32m1_t sum_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
    while (count_scalars > 0) {
        nk_size_t vector_length = __riscv_vsetvl_e8m1(count_scalars);
        vint8m1_t a_i8m1 = __riscv_vle8_v_i8m1(a_scalars, vector_length);
        vint8m1_t b_i8m1 = __riscv_vle8_v_i8m1(b_scalars, vector_length);
        // i8 − i8 widened into i16, squared into i32, then folded into lane 0
        vint16m2_t diff_i16m2 = __riscv_vwsub_vv_i16m2(a_i8m1, b_i8m1, vector_length);
        vint32m4_t sq_i32m4 = __riscv_vwmul_vv_i32m4(diff_i16m2, diff_i16m2, vector_length);
        sum_i32m1 = __riscv_vredsum_vs_i32m4_i32m1(sq_i32m4, sum_i32m1, vector_length);
        a_scalars += vector_length, b_scalars += vector_length, count_scalars -= vector_length;
    }
    *result = (nk_u32_t)__riscv_vmv_x_s_i32m1_i32(sum_i32m1);
}
155
+
156
/** @brief Euclidean distance between two i8 vectors: √ of the squared distance. */
NK_PUBLIC void nk_euclidean_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
                                   nk_f32_t *result) {
    nk_u32_t squared_distance;
    nk_sqeuclidean_i8_rvv(a_scalars, b_scalars, count_scalars, &squared_distance);
    *result = nk_f32_sqrt_rvv((nk_f32_t)squared_distance);
}
162
+
163
/** @brief Squared Euclidean distance between two u8 vectors, accumulated in u32. */
NK_PUBLIC void nk_sqeuclidean_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
                                     nk_u32_t *result) {
    // Running sum lives in lane 0 of an m1 register
    vuint32m1_t sum_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
    while (count_scalars > 0) {
        nk_size_t vector_length = __riscv_vsetvl_e8m1(count_scalars);
        vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1(a_scalars, vector_length);
        vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1(b_scalars, vector_length);
        // |a − b| without branches: one of the saturating differences is zero,
        // so OR-ing them yields max(a − b, b − a)
        vuint8m1_t a_minus_b_u8m1 = __riscv_vssubu_vv_u8m1(a_u8m1, b_u8m1, vector_length);
        vuint8m1_t b_minus_a_u8m1 = __riscv_vssubu_vv_u8m1(b_u8m1, a_u8m1, vector_length);
        vuint8m1_t abs_diff_u8m1 = __riscv_vor_vv_u8m1(a_minus_b_u8m1, b_minus_a_u8m1, vector_length);
        // Square into u16 lanes, then widen-reduce into the u32 accumulator
        vuint16m2_t sq_u16m2 = __riscv_vwmulu_vv_u16m2(abs_diff_u8m1, abs_diff_u8m1, vector_length);
        sum_u32m1 = __riscv_vwredsumu_vs_u16m2_u32m1(sq_u16m2, sum_u32m1, vector_length);
        a_scalars += vector_length, b_scalars += vector_length, count_scalars -= vector_length;
    }
    *result = __riscv_vmv_x_s_u32m1_u32(sum_u32m1);
}
182
+
183
/** @brief Euclidean distance between two u8 vectors: √ of the squared distance. */
NK_PUBLIC void nk_euclidean_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
                                   nk_f32_t *result) {
    nk_u32_t squared_distance;
    nk_sqeuclidean_u8_rvv(a_scalars, b_scalars, count_scalars, &squared_distance);
    *result = nk_f32_sqrt_rvv((nk_f32_t)squared_distance);
}
189
+
190
+ #pragma endregion - Small Integers
191
+ #pragma region - Traditional Floats
192
+
193
/** @brief Squared Euclidean distance between two f32 vectors, accumulated in f64 lanes. */
NK_PUBLIC void nk_sqeuclidean_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
                                      nk_f64_t *result) {
    // f32 widens 2x into f64, hence m1 loads feed m2 accumulators
    nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
    vfloat64m2_t sum_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
    while (count_scalars > 0) {
        nk_size_t vector_length = __riscv_vsetvl_e32m1(count_scalars);
        vfloat32m1_t a_f32m1 = __riscv_vle32_v_f32m1(a_scalars, vector_length);
        vfloat32m1_t b_f32m1 = __riscv_vle32_v_f32m1(b_scalars, vector_length);
        // Widen before subtracting to reduce rounding error in the difference
        vfloat64m2_t a_f64m2 = __riscv_vfwcvt_f_f_v_f64m2(a_f32m1, vector_length);
        vfloat64m2_t b_f64m2 = __riscv_vfwcvt_f_f_v_f64m2(b_f32m1, vector_length);
        vfloat64m2_t diff_f64m2 = __riscv_vfsub_vv_f64m2(a_f64m2, b_f64m2, vector_length);
        // Tail-undisturbed FMA keeps lanes past `vector_length` intact on the final iteration
        sum_f64m2 = __riscv_vfmacc_vv_f64m2_tu(sum_f64m2, diff_f64m2, diff_f64m2, vector_length);
        a_scalars += vector_length, b_scalars += vector_length, count_scalars -= vector_length;
    }
    // One horizontal reduction after the loop
    vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
    *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(sum_f64m2, zero_f64m1, vlmax));
}
211
+
212
/** @brief Euclidean distance between two f32 vectors: √ of the squared distance. */
NK_PUBLIC void nk_euclidean_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
                                    nk_f64_t *result) {
    nk_sqeuclidean_f32_rvv(a_scalars, b_scalars, count_scalars, result);
    *result = nk_f64_sqrt_rvv(*result);
}
217
+
218
/** @brief Squared Euclidean distance between two f64 vectors with native f64 accumulation. */
NK_PUBLIC void nk_sqeuclidean_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                                      nk_f64_t *result) {
    nk_size_t vector_length_max = __riscv_vsetvlmax_e64m1();
    vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vector_length_max);
    while (count_scalars > 0) {
        nk_size_t vector_length = __riscv_vsetvl_e64m1(count_scalars);
        vfloat64m1_t a_f64m1 = __riscv_vle64_v_f64m1(a_scalars, vector_length);
        vfloat64m1_t b_f64m1 = __riscv_vle64_v_f64m1(b_scalars, vector_length);
        vfloat64m1_t diff_f64m1 = __riscv_vfsub_vv_f64m1(a_f64m1, b_f64m1, vector_length);
        // Tail-undisturbed FMA keeps lanes past `vector_length` intact on the final iteration
        sum_f64m1 = __riscv_vfmacc_vv_f64m1_tu(sum_f64m1, diff_f64m1, diff_f64m1, vector_length);
        a_scalars += vector_length, b_scalars += vector_length, count_scalars -= vector_length;
    }
    // One horizontal reduction after the loop
    vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vector_length_max);
    *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m1_f64m1(sum_f64m1, zero_f64m1, vector_length_max));
}
235
+
236
/** @brief Euclidean distance between two f64 vectors: √ of the squared distance. */
NK_PUBLIC void nk_euclidean_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                                    nk_f64_t *result) {
    nk_sqeuclidean_f64_rvv(a_scalars, b_scalars, count_scalars, result);
    *result = nk_f64_sqrt_rvv(*result);
}
241
+
242
+ #pragma endregion - Traditional Floats
243
+ #pragma region - Small Integers
244
+
245
/** @brief Angular (cosine) distance between two i8 vectors: 1 − a·b / (‖a‖ × ‖b‖). */
NK_PUBLIC void nk_angular_i8_rvv(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
                                 nk_f32_t *result) {
    // Three scalar accumulators, each held in lane 0 of an m1 register
    vint32m1_t dot_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
    vint32m1_t a_norm_sq_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);
    vint32m1_t b_norm_sq_i32m1 = __riscv_vmv_v_x_i32m1(0, 1);

    while (count_scalars > 0) {
        nk_size_t vector_length = __riscv_vsetvl_e8m1(count_scalars);
        vint8m1_t a_i8m1 = __riscv_vle8_v_i8m1(a_scalars, vector_length);
        vint8m1_t b_i8m1 = __riscv_vle8_v_i8m1(b_scalars, vector_length);

        // Widening products (i8 × i8 → i16), widen-reduced into the i32 accumulators
        vint16m2_t ab_i16m2 = __riscv_vwmul_vv_i16m2(a_i8m1, b_i8m1, vector_length);
        vint16m2_t aa_i16m2 = __riscv_vwmul_vv_i16m2(a_i8m1, a_i8m1, vector_length);
        vint16m2_t bb_i16m2 = __riscv_vwmul_vv_i16m2(b_i8m1, b_i8m1, vector_length);
        dot_i32m1 = __riscv_vwredsum_vs_i16m2_i32m1(ab_i16m2, dot_i32m1, vector_length);
        a_norm_sq_i32m1 = __riscv_vwredsum_vs_i16m2_i32m1(aa_i16m2, a_norm_sq_i32m1, vector_length);
        b_norm_sq_i32m1 = __riscv_vwredsum_vs_i16m2_i32m1(bb_i16m2, b_norm_sq_i32m1, vector_length);

        a_scalars += vector_length, b_scalars += vector_length, count_scalars -= vector_length;
    }

    nk_i32_t dot_i32 = __riscv_vmv_x_s_i32m1_i32(dot_i32m1);
    nk_i32_t a_norm_sq_i32 = __riscv_vmv_x_s_i32m1_i32(a_norm_sq_i32m1);
    nk_i32_t b_norm_sq_i32 = __riscv_vmv_x_s_i32m1_i32(b_norm_sq_i32m1);

    // Normalize: 1 − dot / √(‖a‖² × ‖b‖²), clipped to the non-negative range
    if (a_norm_sq_i32 == 0 && b_norm_sq_i32 == 0) { *result = 0.0f; }
    else if (dot_i32 == 0) { *result = 1.0f; }
    else {
        nk_f32_t unclipped = 1.0f - (nk_f32_t)dot_i32 * nk_f32_rsqrt_rvv((nk_f32_t)a_norm_sq_i32) *
                                        nk_f32_rsqrt_rvv((nk_f32_t)b_norm_sq_i32);
        *result = unclipped > 0 ? unclipped : 0;
    }
}
283
+
284
/** @brief Angular (cosine) distance between two u8 vectors: 1 − a·b / (‖a‖ × ‖b‖). */
NK_PUBLIC void nk_angular_u8_rvv(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
                                 nk_f32_t *result) {
    // Three scalar accumulators, each held in lane 0 of an m1 register
    vuint32m1_t dot_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
    vuint32m1_t a_norm_sq_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);
    vuint32m1_t b_norm_sq_u32m1 = __riscv_vmv_v_x_u32m1(0, 1);

    while (count_scalars > 0) {
        nk_size_t vector_length = __riscv_vsetvl_e8m1(count_scalars);
        vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1(a_scalars, vector_length);
        vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1(b_scalars, vector_length);

        // Widening products (u8 × u8 → u16), widen-reduced into the u32 accumulators
        vuint16m2_t ab_u16m2 = __riscv_vwmulu_vv_u16m2(a_u8m1, b_u8m1, vector_length);
        vuint16m2_t aa_u16m2 = __riscv_vwmulu_vv_u16m2(a_u8m1, a_u8m1, vector_length);
        vuint16m2_t bb_u16m2 = __riscv_vwmulu_vv_u16m2(b_u8m1, b_u8m1, vector_length);
        dot_u32m1 = __riscv_vwredsumu_vs_u16m2_u32m1(ab_u16m2, dot_u32m1, vector_length);
        a_norm_sq_u32m1 = __riscv_vwredsumu_vs_u16m2_u32m1(aa_u16m2, a_norm_sq_u32m1, vector_length);
        b_norm_sq_u32m1 = __riscv_vwredsumu_vs_u16m2_u32m1(bb_u16m2, b_norm_sq_u32m1, vector_length);

        a_scalars += vector_length, b_scalars += vector_length, count_scalars -= vector_length;
    }

    nk_u32_t dot_u32 = __riscv_vmv_x_s_u32m1_u32(dot_u32m1);
    nk_u32_t a_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(a_norm_sq_u32m1);
    nk_u32_t b_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(b_norm_sq_u32m1);

    // Normalize: 1 − dot / √(‖a‖² × ‖b‖²), clipped to the non-negative range
    if (a_norm_sq_u32 == 0 && b_norm_sq_u32 == 0) { *result = 0.0f; }
    else if (dot_u32 == 0) { *result = 1.0f; }
    else {
        nk_f32_t unclipped = 1.0f - (nk_f32_t)dot_u32 * nk_f32_rsqrt_rvv((nk_f32_t)a_norm_sq_u32) *
                                        nk_f32_rsqrt_rvv((nk_f32_t)b_norm_sq_u32);
        *result = unclipped > 0 ? unclipped : 0;
    }
}
322
+
323
+ #pragma endregion - Small Integers
324
+ #pragma region - Traditional Floats
325
+
326
+ NK_PUBLIC void nk_angular_f32_rvv(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
327
+ nk_f64_t *result) {
328
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m2();
329
+ vfloat64m2_t dot_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
330
+ vfloat64m2_t a_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
331
+ vfloat64m2_t b_norm_sq_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, vlmax);
332
+
333
+ for (nk_size_t vector_length; count_scalars > 0;
334
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
335
+ vector_length = __riscv_vsetvl_e32m1(count_scalars);
336
+ vfloat32m1_t a_f32m1 = __riscv_vle32_v_f32m1(a_scalars, vector_length);
337
+ vfloat32m1_t b_f32m1 = __riscv_vle32_v_f32m1(b_scalars, vector_length);
338
+
339
+ // Widening multiply-accumulate into f64 vector lanes
340
+ dot_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(dot_f64m2, a_f32m1, b_f32m1, vector_length);
341
+ a_norm_sq_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(a_norm_sq_f64m2, a_f32m1, a_f32m1, vector_length);
342
+ b_norm_sq_f64m2 = __riscv_vfwmacc_vv_f64m2_tu(b_norm_sq_f64m2, b_f32m1, b_f32m1, vector_length);
343
+ }
344
+
345
+ // Single horizontal reduction at the end for all three accumulators
346
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
347
+ nk_f64_t dot_f64 = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m2_f64m1(dot_f64m2, zero_f64m1, vlmax));
348
+ nk_f64_t a_norm_sq_f64 = __riscv_vfmv_f_s_f64m1_f64(
349
+ __riscv_vfredusum_vs_f64m2_f64m1(a_norm_sq_f64m2, zero_f64m1, vlmax));
350
+ nk_f64_t b_norm_sq_f64 = __riscv_vfmv_f_s_f64m1_f64(
351
+ __riscv_vfredusum_vs_f64m2_f64m1(b_norm_sq_f64m2, zero_f64m1, vlmax));
352
+
353
+ // Normalize: 1 − dot / √(‖a‖² × ‖b‖²)
354
+ if (a_norm_sq_f64 == 0.0 && b_norm_sq_f64 == 0.0) { *result = 0.0; }
355
+ else if (dot_f64 == 0.0) { *result = 1.0; }
356
+ else {
357
+ nk_f64_t unclipped = 1.0 - dot_f64 * nk_f64_rsqrt_rvv(a_norm_sq_f64) * nk_f64_rsqrt_rvv(b_norm_sq_f64);
358
+ *result = unclipped > 0 ? unclipped : 0.0;
359
+ }
360
+ }
361
+
362
NK_PUBLIC void nk_angular_f64_rvv(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                                  nk_f64_t *result) {
    // Angular (cosine) distance over f64 vectors with a compensated dot product:
    // Dot2 (Ogita-Rump-Oishi) for cross-product (may have cancellation),
    // simple FMA for self-products a²/b² (all positive, no cancellation)
    nk_size_t vector_length_max = __riscv_vsetvlmax_e64m1();
    vfloat64m1_t dot_sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vector_length_max);
    vfloat64m1_t dot_compensation_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vector_length_max);
    vfloat64m1_t a_norm_sq_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vector_length_max);
    vfloat64m1_t b_norm_sq_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vector_length_max);

    for (nk_size_t vector_length; count_scalars > 0;
         count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e64m1(count_scalars);
        vfloat64m1_t a_f64m1 = __riscv_vle64_v_f64m1(a_scalars, vector_length);
        vfloat64m1_t b_f64m1 = __riscv_vle64_v_f64m1(b_scalars, vector_length);

        // TwoProd: product = a*b (rounded), product_error = fma(a,b,-product)
        // recovers the exact rounding error of the multiplication
        vfloat64m1_t product_f64m1 = __riscv_vfmul_vv_f64m1(a_f64m1, b_f64m1, vector_length);
        vfloat64m1_t product_error_f64m1 = __riscv_vfmsac_vv_f64m1(product_f64m1, a_f64m1, b_f64m1, vector_length);
        // TwoSum (Knuth, branch-free): tentative_sum = sum + product, then reconstruct
        // the exact addition error without assuming |sum| ≥ |product|
        vfloat64m1_t tentative_sum_f64m1 = __riscv_vfadd_vv_f64m1(dot_sum_f64m1, product_f64m1, vector_length);
        vfloat64m1_t virtual_addend_f64m1 = __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, dot_sum_f64m1, vector_length);
        vfloat64m1_t sum_error_f64m1 = __riscv_vfadd_vv_f64m1(
            __riscv_vfsub_vv_f64m1(dot_sum_f64m1,
                                   __riscv_vfsub_vv_f64m1(tentative_sum_f64m1, virtual_addend_f64m1, vector_length),
                                   vector_length),
            __riscv_vfsub_vv_f64m1(product_f64m1, virtual_addend_f64m1, vector_length), vector_length);
        // Tail-undisturbed updates: preserve zero tails across partial iterations;
        // vslideup by offset 0 acts as a tail-undisturbed whole-register copy
        dot_sum_f64m1 = __riscv_vslideup_vx_f64m1_tu(dot_sum_f64m1, tentative_sum_f64m1, 0, vector_length);
        vfloat64m1_t total_error_f64m1 = __riscv_vfadd_vv_f64m1(sum_error_f64m1, product_error_f64m1, vector_length);
        dot_compensation_f64m1 = __riscv_vfadd_vv_f64m1_tu(dot_compensation_f64m1, dot_compensation_f64m1,
                                                           total_error_f64m1, vector_length);
        // Simple FMA for self-products (squares are non-negative, no cancellation possible)
        a_norm_sq_f64m1 = __riscv_vfmacc_vv_f64m1_tu(a_norm_sq_f64m1, a_f64m1, a_f64m1, vector_length);
        b_norm_sq_f64m1 = __riscv_vfmacc_vv_f64m1_tu(b_norm_sq_f64m1, b_f64m1, b_f64m1, vector_length);
    }

    // Compensated horizontal reduction for cross-product, simple reduction for self-products
    nk_f64_t dot_f64 = nk_dot_stable_sum_f64m1_rvv_(dot_sum_f64m1, dot_compensation_f64m1);
    vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, vector_length_max);
    nk_f64_t a_norm_sq_f64 = __riscv_vfmv_f_s_f64m1_f64(
        __riscv_vfredusum_vs_f64m1_f64m1(a_norm_sq_f64m1, zero_f64m1, vector_length_max));
    nk_f64_t b_norm_sq_f64 = __riscv_vfmv_f_s_f64m1_f64(
        __riscv_vfredusum_vs_f64m1_f64m1(b_norm_sq_f64m1, zero_f64m1, vector_length_max));

    // Normalize: 1 − dot / √(‖a‖² × ‖b‖²); both norms zero → identical zero vectors → 0;
    // zero dot → orthogonal → 1; otherwise clip the result below at zero
    if (a_norm_sq_f64 == 0.0 && b_norm_sq_f64 == 0.0) { *result = 0.0; }
    else if (dot_f64 == 0.0) { *result = 1.0; }
    else {
        nk_f64_t unclipped = 1.0 - dot_f64 * nk_f64_rsqrt_rvv(a_norm_sq_f64) * nk_f64_rsqrt_rvv(b_norm_sq_f64);
        *result = unclipped > 0 ? unclipped : 0;
    }
}
415
+
416
+ #pragma endregion - Traditional Floats
417
+ #pragma region - Smaller Floats
418
+
419
+ NK_PUBLIC void nk_sqeuclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
420
+ nk_f32_t *result) {
421
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
422
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
423
+ for (nk_size_t vector_length; count_scalars > 0;
424
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
425
+ vector_length = __riscv_vsetvl_e16m1(count_scalars);
426
+
427
+ // Load f16 as u16 bits and convert to f32 via helper
428
+ vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a_scalars, vector_length);
429
+ vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)b_scalars, vector_length);
430
+ vfloat32m2_t a_f32m2 = nk_f16m1_to_f32m2_rvv_(a_u16m1, vector_length);
431
+ vfloat32m2_t b_f32m2 = nk_f16m1_to_f32m2_rvv_(b_u16m1, vector_length);
432
+
433
+ // Compute difference in f32, accumulate diff² into vector lanes
434
+ vfloat32m2_t diff_f32m2 = __riscv_vfsub_vv_f32m2(a_f32m2, b_f32m2, vector_length);
435
+ sum_f32m2 = __riscv_vfmacc_vv_f32m2_tu(sum_f32m2, diff_f32m2, diff_f32m2, vector_length);
436
+ }
437
+ // Single horizontal reduction at the end
438
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
439
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
440
+ }
441
+
442
+ NK_PUBLIC void nk_euclidean_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
443
+ nk_f32_t *result) {
444
+ nk_sqeuclidean_f16_rvv(a_scalars, b_scalars, count_scalars, result);
445
+ *result = nk_f32_sqrt_rvv(*result);
446
+ }
447
+
448
+ NK_PUBLIC void nk_angular_f16_rvv(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
449
+ nk_f32_t *result) {
450
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
451
+ vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
452
+ vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
453
+ vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
454
+
455
+ for (nk_size_t vector_length; count_scalars > 0;
456
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
457
+ vector_length = __riscv_vsetvl_e16m1(count_scalars);
458
+
459
+ // Load f16 as u16 bits and convert to f32 via helper
460
+ vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a_scalars, vector_length);
461
+ vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)b_scalars, vector_length);
462
+ vfloat32m2_t a_f32m2 = nk_f16m1_to_f32m2_rvv_(a_u16m1, vector_length);
463
+ vfloat32m2_t b_f32m2 = nk_f16m1_to_f32m2_rvv_(b_u16m1, vector_length);
464
+
465
+ // Multiply-accumulate into f32 vector lanes
466
+ dot_f32m2 = __riscv_vfmacc_vv_f32m2_tu(dot_f32m2, a_f32m2, b_f32m2, vector_length);
467
+ a_norm_sq_f32m2 = __riscv_vfmacc_vv_f32m2_tu(a_norm_sq_f32m2, a_f32m2, a_f32m2, vector_length);
468
+ b_norm_sq_f32m2 = __riscv_vfmacc_vv_f32m2_tu(b_norm_sq_f32m2, b_f32m2, b_f32m2, vector_length);
469
+ }
470
+
471
+ // Single horizontal reduction at the end for all three accumulators
472
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
473
+ nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, vlmax));
474
+ nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
475
+ __riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1, vlmax));
476
+ nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
477
+ __riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1, vlmax));
478
+
479
+ if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
480
+ else if (dot_f32 == 0.0f) { *result = 1.0f; }
481
+ else {
482
+ nk_f32_t unclipped = 1.0f - dot_f32 * nk_f32_rsqrt_rvv(a_norm_sq_f32) * nk_f32_rsqrt_rvv(b_norm_sq_f32);
483
+ *result = unclipped > 0.0f ? unclipped : 0.0f;
484
+ }
485
+ }
486
+
487
+ NK_PUBLIC void nk_sqeuclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
488
+ nk_f32_t *result) {
489
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
490
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
491
+ for (nk_size_t vector_length; count_scalars > 0;
492
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
493
+ vector_length = __riscv_vsetvl_e16m1(count_scalars);
494
+
495
+ // Load bf16 as u16 and convert to f32 via helper
496
+ vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a_scalars, vector_length);
497
+ vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)b_scalars, vector_length);
498
+ vfloat32m2_t a_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_u16m1, vector_length);
499
+ vfloat32m2_t b_f32m2 = nk_bf16m1_to_f32m2_rvv_(b_u16m1, vector_length);
500
+
501
+ // Compute difference in f32, accumulate diff² into vector lanes
502
+ vfloat32m2_t diff_f32m2 = __riscv_vfsub_vv_f32m2(a_f32m2, b_f32m2, vector_length);
503
+ sum_f32m2 = __riscv_vfmacc_vv_f32m2_tu(sum_f32m2, diff_f32m2, diff_f32m2, vector_length);
504
+ }
505
+ // Single horizontal reduction at the end
506
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
507
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
508
+ }
509
+
510
+ NK_PUBLIC void nk_euclidean_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
511
+ nk_f32_t *result) {
512
+ nk_sqeuclidean_bf16_rvv(a_scalars, b_scalars, count_scalars, result);
513
+ *result = nk_f32_sqrt_rvv(*result);
514
+ }
515
+
516
+ NK_PUBLIC void nk_angular_bf16_rvv(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
517
+ nk_f32_t *result) {
518
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
519
+ vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
520
+ vfloat32m2_t a_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
521
+ vfloat32m2_t b_norm_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
522
+
523
+ for (nk_size_t vector_length; count_scalars > 0;
524
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
525
+ vector_length = __riscv_vsetvl_e16m1(count_scalars);
526
+
527
+ // Load bf16 as u16 and convert to f32 via helper
528
+ vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a_scalars, vector_length);
529
+ vuint16m1_t b_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)b_scalars, vector_length);
530
+ vfloat32m2_t a_f32m2 = nk_bf16m1_to_f32m2_rvv_(a_u16m1, vector_length);
531
+ vfloat32m2_t b_f32m2 = nk_bf16m1_to_f32m2_rvv_(b_u16m1, vector_length);
532
+
533
+ // Multiply-accumulate into f32 vector lanes
534
+ dot_f32m2 = __riscv_vfmacc_vv_f32m2_tu(dot_f32m2, a_f32m2, b_f32m2, vector_length);
535
+ a_norm_sq_f32m2 = __riscv_vfmacc_vv_f32m2_tu(a_norm_sq_f32m2, a_f32m2, a_f32m2, vector_length);
536
+ b_norm_sq_f32m2 = __riscv_vfmacc_vv_f32m2_tu(b_norm_sq_f32m2, b_f32m2, b_f32m2, vector_length);
537
+ }
538
+
539
+ // Single horizontal reduction at the end for all three accumulators
540
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
541
+ nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, vlmax));
542
+ nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
543
+ __riscv_vfredusum_vs_f32m2_f32m1(a_norm_sq_f32m2, zero_f32m1, vlmax));
544
+ nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
545
+ __riscv_vfredusum_vs_f32m2_f32m1(b_norm_sq_f32m2, zero_f32m1, vlmax));
546
+
547
+ if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
548
+ else if (dot_f32 == 0.0f) { *result = 1.0f; }
549
+ else {
550
+ nk_f32_t unclipped = 1.0f - dot_f32 * nk_f32_rsqrt_rvv(a_norm_sq_f32) * nk_f32_rsqrt_rvv(b_norm_sq_f32);
551
+ *result = unclipped > 0.0f ? unclipped : 0.0f;
552
+ }
553
+ }
554
+
555
+ NK_PUBLIC void nk_sqeuclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
556
+ nk_f32_t *result) {
557
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
558
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
559
+ for (nk_size_t vector_length; count_scalars > 0;
560
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
561
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
562
+
563
+ // Load e4m3 as u8 and convert to f32 via helper
564
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
565
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
566
+ vfloat32m4_t a_f32m4 = nk_e4m3m1_to_f32m4_rvv_(a_u8m1, vector_length);
567
+ vfloat32m4_t b_f32m4 = nk_e4m3m1_to_f32m4_rvv_(b_u8m1, vector_length);
568
+
569
+ // Compute difference in f32, accumulate diff² into vector lanes
570
+ vfloat32m4_t diff_f32m4 = __riscv_vfsub_vv_f32m4(a_f32m4, b_f32m4, vector_length);
571
+ sum_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sum_f32m4, diff_f32m4, diff_f32m4, vector_length);
572
+ }
573
+ // Single horizontal reduction at the end
574
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
575
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
576
+ }
577
+
578
+ NK_PUBLIC void nk_euclidean_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
579
+ nk_f32_t *result) {
580
+ nk_sqeuclidean_e4m3_rvv(a_scalars, b_scalars, count_scalars, result);
581
+ *result = nk_f32_sqrt_rvv(*result);
582
+ }
583
+
584
+ NK_PUBLIC void nk_angular_e4m3_rvv(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
585
+ nk_f32_t *result) {
586
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
587
+ vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
588
+ vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
589
+ vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
590
+
591
+ for (nk_size_t vector_length; count_scalars > 0;
592
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
593
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
594
+
595
+ // Load e4m3 as u8 and convert to f32 via helper
596
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
597
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
598
+ vfloat32m4_t a_f32m4 = nk_e4m3m1_to_f32m4_rvv_(a_u8m1, vector_length);
599
+ vfloat32m4_t b_f32m4 = nk_e4m3m1_to_f32m4_rvv_(b_u8m1, vector_length);
600
+
601
+ // Multiply-accumulate into f32 vector lanes
602
+ dot_f32m4 = __riscv_vfmacc_vv_f32m4_tu(dot_f32m4, a_f32m4, b_f32m4, vector_length);
603
+ a_norm_sq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(a_norm_sq_f32m4, a_f32m4, a_f32m4, vector_length);
604
+ b_norm_sq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(b_norm_sq_f32m4, b_f32m4, b_f32m4, vector_length);
605
+ }
606
+
607
+ // Single horizontal reduction at the end for all three accumulators
608
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
609
+ nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(dot_f32m4, zero_f32m1, vlmax));
610
+ nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
611
+ __riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1, vlmax));
612
+ nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
613
+ __riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1, vlmax));
614
+
615
+ if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
616
+ else if (dot_f32 == 0.0f) { *result = 1.0f; }
617
+ else {
618
+ nk_f32_t unclipped = 1.0f - dot_f32 * nk_f32_rsqrt_rvv(a_norm_sq_f32) * nk_f32_rsqrt_rvv(b_norm_sq_f32);
619
+ *result = unclipped > 0.0f ? unclipped : 0.0f;
620
+ }
621
+ }
622
+
623
+ NK_PUBLIC void nk_sqeuclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
624
+ nk_f32_t *result) {
625
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
626
+ vfloat32m4_t sum_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
627
+ for (nk_size_t vector_length; count_scalars > 0;
628
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
629
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
630
+
631
+ // Load e5m2 as u8 and convert to f32 via helper
632
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
633
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
634
+ vfloat32m4_t a_f32m4 = nk_e5m2m1_to_f32m4_rvv_(a_u8m1, vector_length);
635
+ vfloat32m4_t b_f32m4 = nk_e5m2m1_to_f32m4_rvv_(b_u8m1, vector_length);
636
+
637
+ // Compute difference in f32, accumulate diff² into vector lanes
638
+ vfloat32m4_t diff_f32m4 = __riscv_vfsub_vv_f32m4(a_f32m4, b_f32m4, vector_length);
639
+ sum_f32m4 = __riscv_vfmacc_vv_f32m4_tu(sum_f32m4, diff_f32m4, diff_f32m4, vector_length);
640
+ }
641
+ // Single horizontal reduction at the end
642
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
643
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(sum_f32m4, zero_f32m1, vlmax));
644
+ }
645
+
646
+ NK_PUBLIC void nk_euclidean_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
647
+ nk_f32_t *result) {
648
+ nk_sqeuclidean_e5m2_rvv(a_scalars, b_scalars, count_scalars, result);
649
+ *result = nk_f32_sqrt_rvv(*result);
650
+ }
651
+
652
+ NK_PUBLIC void nk_angular_e5m2_rvv(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
653
+ nk_f32_t *result) {
654
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
655
+ vfloat32m4_t dot_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
656
+ vfloat32m4_t a_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
657
+ vfloat32m4_t b_norm_sq_f32m4 = __riscv_vfmv_v_f_f32m4(0.0f, vlmax);
658
+
659
+ for (nk_size_t vector_length; count_scalars > 0;
660
+ count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
661
+ vector_length = __riscv_vsetvl_e8m1(count_scalars);
662
+
663
+ // Load e5m2 as u8 and convert to f32 via helper
664
+ vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
665
+ vuint8m1_t b_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
666
+ vfloat32m4_t a_f32m4 = nk_e5m2m1_to_f32m4_rvv_(a_u8m1, vector_length);
667
+ vfloat32m4_t b_f32m4 = nk_e5m2m1_to_f32m4_rvv_(b_u8m1, vector_length);
668
+
669
+ // Multiply-accumulate into f32 vector lanes
670
+ dot_f32m4 = __riscv_vfmacc_vv_f32m4_tu(dot_f32m4, a_f32m4, b_f32m4, vector_length);
671
+ a_norm_sq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(a_norm_sq_f32m4, a_f32m4, a_f32m4, vector_length);
672
+ b_norm_sq_f32m4 = __riscv_vfmacc_vv_f32m4_tu(b_norm_sq_f32m4, b_f32m4, b_f32m4, vector_length);
673
+ }
674
+
675
+ // Single horizontal reduction at the end for all three accumulators
676
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
677
+ nk_f32_t dot_f32 = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m4_f32m1(dot_f32m4, zero_f32m1, vlmax));
678
+ nk_f32_t a_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
679
+ __riscv_vfredusum_vs_f32m4_f32m1(a_norm_sq_f32m4, zero_f32m1, vlmax));
680
+ nk_f32_t b_norm_sq_f32 = __riscv_vfmv_f_s_f32m1_f32(
681
+ __riscv_vfredusum_vs_f32m4_f32m1(b_norm_sq_f32m4, zero_f32m1, vlmax));
682
+
683
+ if (a_norm_sq_f32 == 0.0f && b_norm_sq_f32 == 0.0f) { *result = 0.0f; }
684
+ else if (dot_f32 == 0.0f) { *result = 1.0f; }
685
+ else {
686
+ nk_f32_t unclipped = 1.0f - dot_f32 * nk_f32_rsqrt_rvv(a_norm_sq_f32) * nk_f32_rsqrt_rvv(b_norm_sq_f32);
687
+ *result = unclipped > 0.0f ? unclipped : 0.0f;
688
+ }
689
+ }
690
+
691
+ #pragma endregion - Smaller Floats
692
+ #pragma region - Small Integers
693
+
694
NK_PUBLIC void nk_sqeuclidean_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scalars, nk_size_t count_scalars,
                                     nk_u32_t *result) {
    // Squared Euclidean distance over packed signed 4-bit integers (two per byte).
    // LUT layout: index = (a_nibble << 4) | b_nibble over the RAW nibble bit patterns;
    // nibbles 0..7 encode +0..+7 and 8..15 encode −8..−1 (two's complement), so
    // entry[i] = (signed(a) − signed(b))², max 225 = (7 − (−8))², which fits in u8.
    static nk_u8_t const nk_i4_sqd_lut_[256] = {
        0,   1,   4,   9,   16,  25,  36,  49,  64,  49,  36,  25,  16,  9,   4,   1,   //
        1,   0,   1,   4,   9,   16,  25,  36,  81,  64,  49,  36,  25,  16,  9,   4,   //
        4,   1,   0,   1,   4,   9,   16,  25,  100, 81,  64,  49,  36,  25,  16,  9,   //
        9,   4,   1,   0,   1,   4,   9,   16,  121, 100, 81,  64,  49,  36,  25,  16,  //
        16,  9,   4,   1,   0,   1,   4,   9,   144, 121, 100, 81,  64,  49,  36,  25,  //
        25,  16,  9,   4,   1,   0,   1,   4,   169, 144, 121, 100, 81,  64,  49,  36,  //
        36,  25,  16,  9,   4,   1,   0,   1,   196, 169, 144, 121, 100, 81,  64,  49,  //
        49,  36,  25,  16,  9,   4,   1,   0,   225, 196, 169, 144, 121, 100, 81,  64,  //
        64,  81,  100, 121, 144, 169, 196, 225, 0,   1,   4,   9,   16,  25,  36,  49,  //
        49,  64,  81,  100, 121, 144, 169, 196, 1,   0,   1,   4,   9,   16,  25,  36,  //
        36,  49,  64,  81,  100, 121, 144, 169, 4,   1,   0,   1,   4,   9,   16,  25,  //
        25,  36,  49,  64,  81,  100, 121, 144, 9,   4,   1,   0,   1,   4,   9,   16,  //
        16,  25,  36,  49,  64,  81,  100, 121, 16,  9,   4,   1,   0,   1,   4,   9,   //
        9,   16,  25,  36,  49,  64,  81,  100, 25,  16,  9,   4,   1,   0,   1,   4,   //
        4,   9,   16,  25,  36,  49,  64,  81,  36,  25,  16,  9,   4,   1,   0,   1,   //
        1,   4,   9,   16,  25,  36,  49,  64,  49,  36,  25,  16,  9,   4,   1,   0,   //
    };
    // Round odd element counts up: the trailing half-byte is still processed
    count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
    nk_size_t n_bytes = count_scalars / 2;
    nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
    vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
    for (nk_size_t vector_length; n_bytes > 0;
         n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(n_bytes);
        vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
        vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
        // Build LUT indices: high nibble pair = (a & 0xF0) | (b >> 4)
        vuint8m1_t hi_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
                                                     __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length),
                                                     vector_length);
        // Low nibble pair = (a_lo << 4) | b_lo
        vuint8m1_t lo_idx_u8m1 = __riscv_vor_vv_u8m1(
            __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length), 4, vector_length),
            __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length), vector_length);
        // Gather squared differences from LUT (0-225, fits u8)
        vuint8m1_t sq_hi_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sqd_lut_, hi_idx_u8m1, vector_length);
        vuint8m1_t sq_lo_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sqd_lut_, lo_idx_u8m1, vector_length);
        // Combine and per-lane accumulate: u8+u8→u16 (max 450), then u32+=u16;
        // tail-undisturbed keeps zero tails across partial final iterations
        vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(sq_hi_u8m1, sq_lo_u8m1, vector_length);
        sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, combined_u16m2, vector_length);
    }
    // Single horizontal reduction after loop
    vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
    *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, vlmax));
}
742
+
743
+ NK_PUBLIC void nk_euclidean_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scalars, nk_size_t count_scalars,
744
+ nk_f32_t *result) {
745
+ nk_u32_t d2;
746
+ nk_sqeuclidean_i4_rvv(a_scalars, b_scalars, count_scalars, &d2);
747
+ *result = nk_f32_sqrt_rvv((nk_f32_t)d2);
748
+ }
749
+
750
NK_PUBLIC void nk_angular_i4_rvv(nk_i4x2_t const *a_scalars, nk_i4x2_t const *b_scalars, nk_size_t count_scalars,
                                 nk_f32_t *result) {
    // Angular (cosine) distance over packed signed 4-bit integers (two per byte).
    // Dot-product LUT layout: index = (a_nibble << 4) | b_nibble over RAW nibble bit
    // patterns; nibbles 0..7 encode +0..+7 and 8..15 encode −8..−1, so
    // dot_lut[i] = signed(a) × signed(b), range −56..64, which fits in i8.
    static nk_i8_t const nk_i4_dot_lut_[256] = {
        0, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   //
        0, 1,   2,   3,   4,   5,   6,   7,   -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,  //
        0, 2,   4,   6,   8,   10,  12,  14,  -16, -14, -12, -10, -8,  -6,  -4,  -2,  //
        0, 3,   6,   9,   12,  15,  18,  21,  -24, -21, -18, -15, -12, -9,  -6,  -3,  //
        0, 4,   8,   12,  16,  20,  24,  28,  -32, -28, -24, -20, -16, -12, -8,  -4,  //
        0, 5,   10,  15,  20,  25,  30,  35,  -40, -35, -30, -25, -20, -15, -10, -5,  //
        0, 6,   12,  18,  24,  30,  36,  42,  -48, -42, -36, -30, -24, -18, -12, -6,  //
        0, 7,   14,  21,  28,  35,  42,  49,  -56, -49, -42, -35, -28, -21, -14, -7,  //
        0, -8,  -16, -24, -32, -40, -48, -56, 64,  56,  48,  40,  32,  24,  16,  8,   //
        0, -7,  -14, -21, -28, -35, -42, -49, 56,  49,  42,  35,  28,  21,  14,  7,   //
        0, -6,  -12, -18, -24, -30, -36, -42, 48,  42,  36,  30,  24,  18,  12,  6,   //
        0, -5,  -10, -15, -20, -25, -30, -35, 40,  35,  30,  25,  20,  15,  10,  5,   //
        0, -4,  -8,  -12, -16, -20, -24, -28, 32,  28,  24,  20,  16,  12,  8,   4,   //
        0, -3,  -6,  -9,  -12, -15, -18, -21, 24,  21,  18,  15,  12,  9,   6,   3,   //
        0, -2,  -4,  -6,  -8,  -10, -12, -14, 16,  14,  12,  10,  8,   6,   4,   2,   //
        0, -1,  -2,  -3,  -4,  -5,  -6,  -7,  8,   7,   6,   5,   4,   3,   2,   1,   //
    };
    // 16-entry squaring LUT: sq_lut[nibble] = signed(nibble)², max 64, fits u8
    static nk_u8_t const nk_i4_sq_lut_[16] = {0, 1, 4, 9, 16, 25, 36, 49, 64, 49, 36, 25, 16, 9, 4, 1};
    // Round odd element counts up: the trailing half-byte is still processed
    count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
    nk_size_t n_bytes = count_scalars / 2;
    nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
    vint32m4_t dot_i32m4 = __riscv_vmv_v_x_i32m4(0, vlmax);
    vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
    vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);

    for (nk_size_t vector_length; n_bytes > 0;
         n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(n_bytes);
        vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
        vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);

        // Extract nibbles for index building
        vuint8m1_t a_hi_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
        vuint8m1_t b_hi_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
        vuint8m1_t a_lo_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
        vuint8m1_t b_lo_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);

        // Dot product via 256-entry LUT: dot_lut[(a<<4)|b] = a_signed * b_signed (i8)
        vuint8m1_t hi_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
                                                     b_hi_u8m1, vector_length);
        vuint8m1_t lo_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(a_lo_u8m1, 4, vector_length), b_lo_u8m1,
                                                     vector_length);
        vint8m1_t dot_hi_i8m1 = __riscv_vluxei8_v_i8m1(nk_i4_dot_lut_, hi_idx_u8m1, vector_length);
        vint8m1_t dot_lo_i8m1 = __riscv_vluxei8_v_i8m1(nk_i4_dot_lut_, lo_idx_u8m1, vector_length);
        // Widen i8→i16, add hi+lo, then per-lane accumulate i32+=i16;
        // tail-undisturbed updates keep zero tails across partial final iterations
        vint16m2_t dot_combined_i16m2 = __riscv_vwadd_vv_i16m2(dot_hi_i8m1, dot_lo_i8m1, vector_length);
        dot_i32m4 = __riscv_vwadd_wv_i32m4_tu(dot_i32m4, dot_i32m4, dot_combined_i16m2, vector_length);

        // Norms via 16-entry squaring LUT + vluxei8
        vuint8m1_t a_hi_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, a_hi_u8m1, vector_length);
        vuint8m1_t a_lo_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, a_lo_u8m1, vector_length);
        vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(a_hi_sq_u8m1, a_lo_sq_u8m1, vector_length);
        a_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(a_norm_sq_u32m4, a_norm_sq_u32m4, a_sq_combined_u16m2,
                                                     vector_length);

        vuint8m1_t b_hi_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, b_hi_u8m1, vector_length);
        vuint8m1_t b_lo_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_i4_sq_lut_, b_lo_u8m1, vector_length);
        vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(b_hi_sq_u8m1, b_lo_sq_u8m1, vector_length);
        b_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(b_norm_sq_u32m4, b_norm_sq_u32m4, b_sq_combined_u16m2,
                                                     vector_length);
    }

    // Single horizontal reductions after loop
    vint32m1_t zero_i32m1 = __riscv_vmv_v_x_i32m1(0, vlmax);
    vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
    nk_i32_t dot_i32 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(dot_i32m4, zero_i32m1, vlmax));
    nk_u32_t a_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
        __riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1, vlmax));
    nk_u32_t b_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
        __riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1, vlmax));

    // Normalize: 1 − dot / √(‖a‖² × ‖b‖²); both norms zero → 0; zero dot → 1;
    // otherwise clip the result below at zero
    if (a_norm_sq_u32 == 0 && b_norm_sq_u32 == 0) { *result = 0.0f; }
    else if (dot_i32 == 0) { *result = 1.0f; }
    else {
        nk_f32_t unclipped = 1.0f - (nk_f32_t)dot_i32 * nk_f32_rsqrt_rvv((nk_f32_t)a_norm_sq_u32) *
                                        nk_f32_rsqrt_rvv((nk_f32_t)b_norm_sq_u32);
        *result = unclipped > 0 ? unclipped : 0;
    }
}
832
+
833
NK_PUBLIC void nk_sqeuclidean_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scalars, nk_size_t count_scalars,
                                     nk_u32_t *result) {
    // Squared Euclidean distance over packed unsigned 4-bit integers (two per byte).
    // LUT layout: index = (a_nibble << 4) | b_nibble, values 0..15, so
    // entry[i] = (a − b)², max 225 = 15², which fits in u8.
    static nk_u8_t const nk_u4_sqd_lut_[256] = {
        0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, 121, 144, 169, 196, 225, //
        1,   0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, 121, 144, 169, 196, //
        4,   1,   0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, 121, 144, 169, //
        9,   4,   1,   0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, 121, 144, //
        16,  9,   4,   1,   0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, 121, //
        25,  16,  9,   4,   1,   0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  100, //
        36,  25,  16,  9,   4,   1,   0,   1,   4,   9,   16,  25,  36,  49,  64,  81,  //
        49,  36,  25,  16,  9,   4,   1,   0,   1,   4,   9,   16,  25,  36,  49,  64,  //
        64,  49,  36,  25,  16,  9,   4,   1,   0,   1,   4,   9,   16,  25,  36,  49,  //
        81,  64,  49,  36,  25,  16,  9,   4,   1,   0,   1,   4,   9,   16,  25,  36,  //
        100, 81,  64,  49,  36,  25,  16,  9,   4,   1,   0,   1,   4,   9,   16,  25,  //
        121, 100, 81,  64,  49,  36,  25,  16,  9,   4,   1,   0,   1,   4,   9,   16,  //
        144, 121, 100, 81,  64,  49,  36,  25,  16,  9,   4,   1,   0,   1,   4,   9,   //
        169, 144, 121, 100, 81,  64,  49,  36,  25,  16,  9,   4,   1,   0,   1,   4,   //
        196, 169, 144, 121, 100, 81,  64,  49,  36,  25,  16,  9,   4,   1,   0,   1,   //
        225, 196, 169, 144, 121, 100, 81,  64,  49,  36,  25,  16,  9,   4,   1,   0,   //
    };
    // Round odd element counts up: the trailing half-byte is still processed
    count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
    nk_size_t n_bytes = count_scalars / 2;
    nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
    vuint32m4_t sum_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
    for (nk_size_t vector_length; n_bytes > 0;
         n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(n_bytes);
        vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
        vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);
        // Build LUT indices: high nibble pair = (a_hi & 0xF0) | (b_hi >> 4)
        vuint8m1_t hi_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
                                                     __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length),
                                                     vector_length);
        // Low nibble pair = (a_lo << 4) | b_lo
        vuint8m1_t lo_idx_u8m1 = __riscv_vor_vv_u8m1(
            __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length), 4, vector_length),
            __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length), vector_length);
        // Gather squared differences from LUT (0-225, fits u8)
        vuint8m1_t sq_hi_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sqd_lut_, hi_idx_u8m1, vector_length);
        vuint8m1_t sq_lo_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sqd_lut_, lo_idx_u8m1, vector_length);
        // Combine and per-lane accumulate: u8+u8→u16 (max 450), then u32+=u16;
        // tail-undisturbed keeps zero tails across partial final iterations
        vuint16m2_t combined_u16m2 = __riscv_vwaddu_vv_u16m2(sq_hi_u8m1, sq_lo_u8m1, vector_length);
        sum_u32m4 = __riscv_vwaddu_wv_u32m4_tu(sum_u32m4, sum_u32m4, combined_u16m2, vector_length);
    }
    // Single horizontal reduction after loop
    vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
    *result = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(sum_u32m4, zero_u32m1, vlmax));
}
881
+
882
+ NK_PUBLIC void nk_euclidean_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scalars, nk_size_t count_scalars,
883
+ nk_f32_t *result) {
884
+ nk_u32_t d2;
885
+ nk_sqeuclidean_u4_rvv(a_scalars, b_scalars, count_scalars, &d2);
886
+ *result = nk_f32_sqrt_rvv((nk_f32_t)d2);
887
+ }
888
+
889
NK_PUBLIC void nk_angular_u4_rvv(nk_u4x2_t const *a_scalars, nk_u4x2_t const *b_scalars, nk_size_t count_scalars,
                                 nk_f32_t *result) {
    // Angular (cosine) distance over packed unsigned 4-bit vectors:
    //   1 - dot(a, b) / (|a| * |b|), clipped at zero from below.
    // Products and squares are fetched from byte LUTs with indexed gathers,
    // accumulated in integer vectors, and reduced once after the loop.
    // dot_lut[(a << 4) | b] = a * b for every 4-bit pair; max 15 * 15 = 225 fits in u8.
    static nk_u8_t const nk_u4_dot_lut_[256] = {
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, //
        0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, //
        0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, //
        0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, //
        0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, //
        0, 6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, //
        0, 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91, 98, 105, //
        0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, //
        0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99, 108, 117, 126, 135, //
        0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, //
        0, 11, 22, 33, 44, 55, 66, 77, 88, 99, 110, 121, 132, 143, 154, 165, //
        0, 12, 24, 36, 48, 60, 72, 84, 96, 108, 120, 132, 144, 156, 168, 180, //
        0, 13, 26, 39, 52, 65, 78, 91, 104, 117, 130, 143, 156, 169, 182, 195, //
        0, 14, 28, 42, 56, 70, 84, 98, 112, 126, 140, 154, 168, 182, 196, 210, //
        0, 15, 30, 45, 60, 75, 90, 105, 120, 135, 150, 165, 180, 195, 210, 225, //
    };
    // sq_lut[v] = v * v for a single 4-bit value, used for both norms.
    static nk_u8_t const nk_u4_sq_lut_[16] = {0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225};
    // Two 4-bit scalars per byte; round odd counts up so the last nibble is processed.
    count_scalars = nk_size_round_up_to_multiple_(count_scalars, 2);
    nk_size_t n_bytes = count_scalars / 2;
    nk_size_t vlmax = __riscv_vsetvlmax_e32m4();
    vuint32m4_t dot_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
    vuint32m4_t a_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);
    vuint32m4_t b_norm_sq_u32m4 = __riscv_vmv_v_x_u32m4(0, vlmax);

    for (nk_size_t vector_length; n_bytes > 0;
         n_bytes -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
        vector_length = __riscv_vsetvl_e8m1(n_bytes);
        vuint8m1_t a_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a_scalars, vector_length);
        vuint8m1_t b_packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)b_scalars, vector_length);

        // Extract nibbles: high halves via shift, low halves via mask.
        vuint8m1_t a_hi_u8m1 = __riscv_vsrl_vx_u8m1(a_packed_u8m1, 4, vector_length);
        vuint8m1_t b_hi_u8m1 = __riscv_vsrl_vx_u8m1(b_packed_u8m1, 4, vector_length);
        vuint8m1_t a_lo_u8m1 = __riscv_vand_vx_u8m1(a_packed_u8m1, 0x0F, vector_length);
        vuint8m1_t b_lo_u8m1 = __riscv_vand_vx_u8m1(b_packed_u8m1, 0x0F, vector_length);

        // Dot product via 256-entry LUT: dot_lut[(a<<4)|b] = a * b (u8).
        // For the high pair, (a & 0xF0) already places a's nibble in the index's top half.
        vuint8m1_t hi_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vand_vx_u8m1(a_packed_u8m1, 0xF0, vector_length),
                                                     b_hi_u8m1, vector_length);
        vuint8m1_t lo_idx_u8m1 = __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(a_lo_u8m1, 4, vector_length), b_lo_u8m1,
                                                     vector_length);
        vuint8m1_t dot_hi_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_dot_lut_, hi_idx_u8m1, vector_length);
        vuint8m1_t dot_lo_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_dot_lut_, lo_idx_u8m1, vector_length);
        // Widen u8→u16, add hi+lo, then per-lane accumulate u32+=u16.
        // Tail-undisturbed (_tu) keeps lanes beyond vector_length intact on short final iterations.
        vuint16m2_t dot_combined_u16m2 = __riscv_vwaddu_vv_u16m2(dot_hi_u8m1, dot_lo_u8m1, vector_length);
        dot_u32m4 = __riscv_vwaddu_wv_u32m4_tu(dot_u32m4, dot_u32m4, dot_combined_u16m2, vector_length);

        // Norms via 16-entry squaring LUT + vluxei8 (indexed byte gather).
        vuint8m1_t a_hi_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, a_hi_u8m1, vector_length);
        vuint8m1_t a_lo_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, a_lo_u8m1, vector_length);
        vuint16m2_t a_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(a_hi_sq_u8m1, a_lo_sq_u8m1, vector_length);
        a_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(a_norm_sq_u32m4, a_norm_sq_u32m4, a_sq_combined_u16m2,
                                                     vector_length);

        vuint8m1_t b_hi_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, b_hi_u8m1, vector_length);
        vuint8m1_t b_lo_sq_u8m1 = __riscv_vluxei8_v_u8m1(nk_u4_sq_lut_, b_lo_u8m1, vector_length);
        vuint16m2_t b_sq_combined_u16m2 = __riscv_vwaddu_vv_u16m2(b_hi_sq_u8m1, b_lo_sq_u8m1, vector_length);
        b_norm_sq_u32m4 = __riscv_vwaddu_wv_u32m4_tu(b_norm_sq_u32m4, b_norm_sq_u32m4, b_sq_combined_u16m2,
                                                     vector_length);
    }

    // Single horizontal reductions after loop, over all vlmax lanes.
    vuint32m1_t zero_u32m1 = __riscv_vmv_v_x_u32m1(0, vlmax);
    nk_u32_t dot_u32 = __riscv_vmv_x_s_u32m1_u32(__riscv_vredsum_vs_u32m4_u32m1(dot_u32m4, zero_u32m1, vlmax));
    nk_u32_t a_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
        __riscv_vredsum_vs_u32m4_u32m1(a_norm_sq_u32m4, zero_u32m1, vlmax));
    nk_u32_t b_norm_sq_u32 = __riscv_vmv_x_s_u32m1_u32(
        __riscv_vredsum_vs_u32m4_u32m1(b_norm_sq_u32m4, zero_u32m1, vlmax));

    // Degenerate cases: two zero vectors are defined as identical (distance 0);
    // a zero dot product means orthogonal (distance 1). Otherwise compute
    // 1 - dot * rsqrt(|a|^2) * rsqrt(|b|^2) and clip negative rounding error to 0.
    if (a_norm_sq_u32 == 0 && b_norm_sq_u32 == 0) { *result = 0.0f; }
    else if (dot_u32 == 0) { *result = 1.0f; }
    else {
        nk_f32_t unclipped = 1.0f - (nk_f32_t)dot_u32 * nk_f32_rsqrt_rvv((nk_f32_t)a_norm_sq_u32) *
                                        nk_f32_rsqrt_rvv((nk_f32_t)b_norm_sq_u32);
        *result = unclipped > 0 ? unclipped : 0;
    }
}
970
+
971
+ #if defined(__cplusplus)
972
+ } // extern "C"
973
+ #endif
974
+
975
+ #if defined(__clang__)
976
+ #pragma clang attribute pop
977
+ #elif defined(__GNUC__)
978
+ #pragma GCC pop_options
979
+ #endif
980
+
981
+ #pragma endregion - Small Integers
982
+ #endif // NK_TARGET_RVV
983
+ #endif // NK_TARGET_RISCV_
984
+ #endif // NK_SPATIAL_RVV_H