numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294):
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1960 @@
1
+ /**
2
+ * @brief Batched Spatial Distances for RISC-V Vector (RVV).
3
+ * @file include/numkong/spatials/rvv.h
4
+ * @author Ash Vardanian
5
+ * @date February 23, 2026
6
+ *
7
+ * @sa include/numkong/spatials.h
8
+ */
9
+ #ifndef NK_SPATIALS_RVV_H
10
+ #define NK_SPATIALS_RVV_H
11
+
12
+ #if NK_TARGET_RISCV_
13
+ #if NK_TARGET_RVV
14
+
15
+ #include "numkong/dots/serial.h"
16
+ #include "numkong/dots/rvv.h"
17
+ #include "numkong/spatial/rvv.h"
18
+
19
+ #if defined(__cplusplus)
20
+ extern "C" {
21
+ #endif
22
+
23
+ #if defined(__clang__)
24
+ #pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
25
+ #elif defined(__GNUC__)
26
+ #pragma GCC push_options
27
+ #pragma GCC target("arch=+v")
28
+ #endif
29
+
30
+ #pragma region Single Precision Floats
31
+
32
/**
 *  @brief Converts raw dot products in @p c into angular (cosine) distances in place:
 *         `c[i][j] = max(0, 1 - dot / sqrt(|a_i|^2 * |b_j|^2))`.
 *
 *  The packed buffer layout is: header, then `column_count * depth_padded_values`
 *  f32 values of column data, then one f64 squared norm per column.
 *  Strides here are in elements, not bytes.
 *
 *  NOTE(review): assumes `c` already holds the dot products produced by
 *  `nk_dots_packed_f32_rvv` — this helper only post-processes them.
 */
NK_INTERNAL void nk_angulars_packed_f32_rvv_finalize_(nk_f32_t const *a, void const *b_packed, nk_f64_t *c,
                                                      nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                      nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Pre-computed squared norms of the packed columns live right after the padded column data
    nk_f64_t const *target_norms = (nk_f64_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_f32_t const *a_row = a + row_index * a_stride_elements;
        nk_f64_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row, accumulated in f64
        nk_f64_t query_norm_sq_f64 = nk_dots_reduce_sumsq_f32_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f64_t *result_ptr = result_row;
        nk_f64_t const *norms_ptr = target_norms;
        // Strip-mined loop: vsetvl picks the vector length for the remaining columns
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e64m1(count_columns);
            vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
            vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
            // |a|^2 * |b|^2, so a single rsqrt yields 1 / (|a| * |b|)
            vfloat64m1_t norms_product_f64m1 = __riscv_vfmul_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64,
                                                                      vector_length);
            vfloat64m1_t rsqrt_f64m1 = nk_rsqrt_f64m1_rvv_(norms_product_f64m1, vector_length);
            vfloat64m1_t normalized_dots_f64m1 = __riscv_vfmul_vv_f64m1(dots_f64m1, rsqrt_f64m1, vector_length);
            // vfrsub computes (1.0 - x); clamp at zero to absorb rounding error
            vfloat64m1_t angular_f64m1 = __riscv_vfrsub_vf_f64m1(normalized_dots_f64m1, 1.0, vector_length);
            angular_f64m1 = __riscv_vfmax_vf_f64m1(angular_f64m1, 0.0, vector_length);
            __riscv_vse64_v_f64m1(result_ptr, angular_f64m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
63
+
64
+ NK_PUBLIC void nk_angulars_packed_f32_rvv( //
65
+ nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
66
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
67
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
68
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
69
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
70
+ nk_dots_packed_f32_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
71
+ nk_angulars_packed_f32_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
72
+ }
73
+
74
/**
 *  @brief Converts raw dot products in @p c into Euclidean distances in place:
 *         `c[i][j] = sqrt(max(0, |a_i|^2 + |b_j|^2 - 2 * dot))`.
 *
 *  The packed buffer layout is: header, then `column_count * depth_padded_values`
 *  f32 values of column data, then one f64 squared norm per column.
 *  Strides here are in elements, not bytes.
 *
 *  NOTE(review): assumes `c` already holds the dot products produced by
 *  `nk_dots_packed_f32_rvv` — this helper only post-processes them.
 */
NK_INTERNAL void nk_euclideans_packed_f32_rvv_finalize_(nk_f32_t const *a, void const *b_packed, nk_f64_t *c,
                                                        nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                        nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Pre-computed squared norms of the packed columns live right after the padded column data
    nk_f64_t const *target_norms = (nk_f64_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_f32_t const *a_row = a + row_index * a_stride_elements;
        nk_f64_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row, accumulated in f64
        nk_f64_t query_norm_sq_f64 = nk_dots_reduce_sumsq_f32_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f64_t *result_ptr = result_row;
        nk_f64_t const *norms_ptr = target_norms;
        // Strip-mined loop: vsetvl picks the vector length for the remaining columns
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e64m1(count_columns);
            vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
            vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
            // |a|^2 + |b|^2 - 2*dot, the squared-distance expansion
            vfloat64m1_t sum_sq_f64m1 = __riscv_vfadd_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64, vector_length);
            vfloat64m1_t dist_sq_f64m1 = __riscv_vfsub_vv_f64m1(
                sum_sq_f64m1, __riscv_vfmul_vf_f64m1(dots_f64m1, 2.0, vector_length), vector_length);
            // Clamp at zero before sqrt to absorb rounding error
            dist_sq_f64m1 = __riscv_vfmax_vf_f64m1(dist_sq_f64m1, 0.0, vector_length);
            __riscv_vse64_v_f64m1(result_ptr, __riscv_vfsqrt_v_f64m1(dist_sq_f64m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
103
+
104
+ NK_PUBLIC void nk_euclideans_packed_f32_rvv( //
105
+ nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
106
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
107
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
108
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
109
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
110
+ nk_dots_packed_f32_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
111
+ nk_euclideans_packed_f32_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
112
+ }
113
+
114
/**
 *  @brief Converts raw dot products in the upper triangle of @p result into angular
 *         (cosine) distances, for the row band `[row_start, row_start + row_count)`.
 *
 *  Temporarily parks each row's squared norm on the diagonal `result[i][i]`,
 *  reuses it as the query norm while sweeping columns `j > i`, and finally
 *  zeroes the diagonal (distance of a vector to itself).
 *  Column norms are recomputed in 256-entry chunks to keep them in a small
 *  stack cache. Strides here are in elements, not bytes.
 *
 *  NOTE(review): only columns `j >= i + 1` are written — assumes the caller
 *  (or `nk_dots_symmetric_f32_rvv`) handles the lower triangle; confirm.
 */
NK_INTERNAL void nk_angulars_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                         nk_size_t stride_elements, nk_f64_t *result,
                                                         nk_size_t result_stride_elements, nk_size_t row_start,
                                                         nk_size_t row_count) {
    // Stash each band row's squared norm on the diagonal for reuse below
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f64_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f32_(vectors + row_index * stride_elements, depth);
    }
    nk_f64_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        // Squared norms of this chunk's columns
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f32_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Upper triangle only: start at max(row_index + 1, chunk_start)
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f64_t *result_row = result + row_index * result_stride_elements;
            // Diagonal still holds this row's squared norm at this point
            nk_f64_t query_norm_sq_f64 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f64_t *result_ptr = result_row + col_start;
            nk_f64_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined loop over the chunk's columns
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e64m1(count_remaining);
                vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
                vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
                // |a|^2 * |b|^2, so a single rsqrt yields 1 / (|a| * |b|)
                vfloat64m1_t norms_product_f64m1 = __riscv_vfmul_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64,
                                                                          vector_length);
                vfloat64m1_t rsqrt_f64m1 = nk_rsqrt_f64m1_rvv_(norms_product_f64m1, vector_length);
                vfloat64m1_t normalized_dots_f64m1 = __riscv_vfmul_vv_f64m1(dots_f64m1, rsqrt_f64m1, vector_length);
                // vfrsub computes (1.0 - x); clamp at zero to absorb rounding error
                vfloat64m1_t angular_f64m1 = __riscv_vfrsub_vf_f64m1(normalized_dots_f64m1, 1.0, vector_length);
                angular_f64m1 = __riscv_vfmax_vf_f64m1(angular_f64m1, 0.0, vector_length);
                __riscv_vse64_v_f64m1(result_ptr, angular_f64m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: overwrite the parked norms on the diagonal
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
155
+
156
+ NK_PUBLIC void nk_angulars_symmetric_f32_rvv( //
157
+ nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
158
+ nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
159
+ nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
160
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
161
+ nk_dots_symmetric_f32_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
162
+ nk_angulars_symmetric_f32_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
163
+ row_start, row_count);
164
+ }
165
+
166
/**
 *  @brief Converts raw dot products in the upper triangle of @p result into Euclidean
 *         distances `sqrt(max(0, |a|^2 + |b|^2 - 2*dot))`, for the row band
 *         `[row_start, row_start + row_count)`.
 *
 *  Temporarily parks each row's squared norm on the diagonal `result[i][i]`,
 *  reuses it as the query norm while sweeping columns `j > i`, and finally
 *  zeroes the diagonal (distance of a vector to itself).
 *  Column norms are recomputed in 256-entry chunks to keep them in a small
 *  stack cache. Strides here are in elements, not bytes.
 *
 *  NOTE(review): only columns `j >= i + 1` are written — assumes the caller
 *  (or `nk_dots_symmetric_f32_rvv`) handles the lower triangle; confirm.
 */
NK_INTERNAL void nk_euclideans_symmetric_f32_rvv_finalize_(nk_f32_t const *vectors, nk_size_t n_vectors,
                                                           nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
                                                           nk_size_t result_stride_elements, nk_size_t row_start,
                                                           nk_size_t row_count) {
    // Stash each band row's squared norm on the diagonal for reuse below
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f64_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f32_(vectors + row_index * stride_elements, depth);
    }
    nk_f64_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        // Squared norms of this chunk's columns
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f32_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Upper triangle only: start at max(row_index + 1, chunk_start)
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f64_t *result_row = result + row_index * result_stride_elements;
            // Diagonal still holds this row's squared norm at this point
            nk_f64_t query_norm_sq_f64 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f64_t *result_ptr = result_row + col_start;
            nk_f64_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined loop over the chunk's columns
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e64m1(count_remaining);
                vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
                vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
                // |a|^2 + |b|^2 - 2*dot, the squared-distance expansion
                vfloat64m1_t sum_sq_f64m1 = __riscv_vfadd_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64,
                                                                   vector_length);
                vfloat64m1_t dist_sq_f64m1 = __riscv_vfsub_vv_f64m1(
                    sum_sq_f64m1, __riscv_vfmul_vf_f64m1(dots_f64m1, 2.0, vector_length), vector_length);
                // Clamp at zero before sqrt to absorb rounding error
                dist_sq_f64m1 = __riscv_vfmax_vf_f64m1(dist_sq_f64m1, 0.0, vector_length);
                __riscv_vse64_v_f64m1(result_ptr, __riscv_vfsqrt_v_f64m1(dist_sq_f64m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: overwrite the parked norms on the diagonal
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
206
+
207
+ NK_PUBLIC void nk_euclideans_symmetric_f32_rvv( //
208
+ nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
209
+ nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
210
+ nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
211
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
212
+ nk_dots_symmetric_f32_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
213
+ nk_euclideans_symmetric_f32_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
214
+ result_stride_elements, row_start, row_count);
215
+ }
216
+
217
+ #pragma endregion // Single Precision Floats
218
+
219
+ #pragma region Double Precision Floats
220
+
221
/**
 *  @brief Converts raw dot products in @p c into angular (cosine) distances in place:
 *         `c[i][j] = max(0, 1 - dot / sqrt(|a_i|^2 * |b_j|^2))`.
 *
 *  f64 variant: the packed buffer layout is header, then
 *  `column_count * depth_padded_values` f64 values of column data, then one
 *  f64 squared norm per column. Strides here are in elements, not bytes.
 *
 *  NOTE(review): assumes `c` already holds the dot products produced by
 *  `nk_dots_packed_f64_rvv` — this helper only post-processes them.
 */
NK_INTERNAL void nk_angulars_packed_f64_rvv_finalize_(nk_f64_t const *a, void const *b_packed, nk_f64_t *c,
                                                      nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                      nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Pre-computed squared norms of the packed columns live right after the padded column data
    nk_f64_t const *target_norms = (nk_f64_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f64_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_f64_t const *a_row = a + row_index * a_stride_elements;
        nk_f64_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row
        nk_f64_t query_norm_sq_f64 = nk_dots_reduce_sumsq_f64_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f64_t *result_ptr = result_row;
        nk_f64_t const *norms_ptr = target_norms;
        // Strip-mined loop: vsetvl picks the vector length for the remaining columns
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e64m1(count_columns);
            vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
            vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
            // |a|^2 * |b|^2, so a single rsqrt yields 1 / (|a| * |b|)
            vfloat64m1_t norms_product_f64m1 = __riscv_vfmul_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64,
                                                                      vector_length);
            vfloat64m1_t rsqrt_f64m1 = nk_rsqrt_f64m1_rvv_(norms_product_f64m1, vector_length);
            vfloat64m1_t normalized_dots_f64m1 = __riscv_vfmul_vv_f64m1(dots_f64m1, rsqrt_f64m1, vector_length);
            // vfrsub computes (1.0 - x); clamp at zero to absorb rounding error
            vfloat64m1_t angular_f64m1 = __riscv_vfrsub_vf_f64m1(normalized_dots_f64m1, 1.0, vector_length);
            angular_f64m1 = __riscv_vfmax_vf_f64m1(angular_f64m1, 0.0, vector_length);
            __riscv_vse64_v_f64m1(result_ptr, angular_f64m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
252
+
253
+ NK_PUBLIC void nk_angulars_packed_f64_rvv( //
254
+ nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
255
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
256
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
257
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f64_t);
258
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
259
+ nk_dots_packed_f64_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
260
+ nk_angulars_packed_f64_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
261
+ }
262
+
263
/**
 *  @brief Converts raw dot products in @p c into Euclidean distances in place:
 *         `c[i][j] = sqrt(max(0, |a_i|^2 + |b_j|^2 - 2 * dot))`.
 *
 *  f64 variant: the packed buffer layout is header, then
 *  `column_count * depth_padded_values` f64 values of column data, then one
 *  f64 squared norm per column. Strides here are in elements, not bytes.
 *
 *  NOTE(review): assumes `c` already holds the dot products produced by
 *  `nk_dots_packed_f64_rvv` — this helper only post-processes them.
 */
NK_INTERNAL void nk_euclideans_packed_f64_rvv_finalize_(nk_f64_t const *a, void const *b_packed, nk_f64_t *c,
                                                        nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                        nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Pre-computed squared norms of the packed columns live right after the padded column data
    nk_f64_t const *target_norms = (nk_f64_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f64_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_f64_t const *a_row = a + row_index * a_stride_elements;
        nk_f64_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row
        nk_f64_t query_norm_sq_f64 = nk_dots_reduce_sumsq_f64_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f64_t *result_ptr = result_row;
        nk_f64_t const *norms_ptr = target_norms;
        // Strip-mined loop: vsetvl picks the vector length for the remaining columns
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e64m1(count_columns);
            vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
            vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
            // |a|^2 + |b|^2 - 2*dot, the squared-distance expansion
            vfloat64m1_t sum_sq_f64m1 = __riscv_vfadd_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64, vector_length);
            vfloat64m1_t dist_sq_f64m1 = __riscv_vfsub_vv_f64m1(
                sum_sq_f64m1, __riscv_vfmul_vf_f64m1(dots_f64m1, 2.0, vector_length), vector_length);
            // Clamp at zero before sqrt to absorb rounding error
            dist_sq_f64m1 = __riscv_vfmax_vf_f64m1(dist_sq_f64m1, 0.0, vector_length);
            __riscv_vse64_v_f64m1(result_ptr, __riscv_vfsqrt_v_f64m1(dist_sq_f64m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
292
+
293
+ NK_PUBLIC void nk_euclideans_packed_f64_rvv( //
294
+ nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
295
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
296
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
297
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f64_t);
298
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
299
+ nk_dots_packed_f64_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
300
+ nk_euclideans_packed_f64_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
301
+ }
302
+
303
/*  In-place finalization for a symmetric angular-distance matrix: rewrites the
 *  raw dot products in `result` as 1 - dot/sqrt(|a|^2*|b|^2), clamped at zero.
 *  Only the strictly-upper triangle of rows [row_start, row_start+row_count) is
 *  rewritten; the diagonal is temporarily used to stage each row's squared norm
 *  and reset to zero at the end. Strides are in elements. */
NK_INTERNAL void nk_angulars_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                         nk_size_t stride_elements, nk_f64_t *result,
                                                         nk_size_t result_stride_elements, nk_size_t row_start,
                                                         nk_size_t row_count) {
    // Stage each row's squared norm on its own diagonal slot.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f64_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f64_(vectors + row_index * stride_elements, depth);
    }
    // Column squared norms are recomputed per 256-wide chunk to bound stack use.
    nk_f64_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f64_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only touch columns strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f64_t *result_row = result + row_index * result_stride_elements;
            // Query norm was staged on the diagonal above.
            nk_f64_t query_norm_sq_f64 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f64_t *result_ptr = result_row + col_start;
            nk_f64_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                // Strip-mine over the remaining columns of this chunk.
                size_t vector_length = __riscv_vsetvl_e64m1(count_remaining);
                vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
                vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
                vfloat64m1_t norms_product_f64m1 = __riscv_vfmul_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64,
                                                                          vector_length);
                // Reciprocal square root of |a|^2 * |b|^2 via the shared helper.
                vfloat64m1_t rsqrt_f64m1 = nk_rsqrt_f64m1_rvv_(norms_product_f64m1, vector_length);
                vfloat64m1_t normalized_dots_f64m1 = __riscv_vfmul_vv_f64m1(dots_f64m1, rsqrt_f64m1, vector_length);
                // angular = 1 - cosine similarity, clamped to be non-negative.
                vfloat64m1_t angular_f64m1 = __riscv_vfrsub_vf_f64m1(normalized_dots_f64m1, 1.0, vector_length);
                angular_f64m1 = __riscv_vfmax_vf_f64m1(angular_f64m1, 0.0, vector_length);
                __riscv_vse64_v_f64m1(result_ptr, angular_f64m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; clear the staged norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
344
+
345
+ NK_PUBLIC void nk_angulars_symmetric_f64_rvv( //
346
+ nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
347
+ nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
348
+ nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
349
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
350
+ nk_dots_symmetric_f64_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
351
+ nk_angulars_symmetric_f64_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
352
+ row_start, row_count);
353
+ }
354
+
355
/*  In-place finalization for a symmetric Euclidean-distance matrix: rewrites raw
 *  dot products in `result` as sqrt(|a|^2 + |b|^2 - 2*dot). Only the
 *  strictly-upper triangle of rows [row_start, row_start+row_count) is touched;
 *  the diagonal stages each row's squared norm and is zeroed at the end.
 *  Strides are in elements. */
NK_INTERNAL void nk_euclideans_symmetric_f64_rvv_finalize_(nk_f64_t const *vectors, nk_size_t n_vectors,
                                                           nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
                                                           nk_size_t result_stride_elements, nk_size_t row_start,
                                                           nk_size_t row_count) {
    // Stage each row's squared norm on its own diagonal slot.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f64_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f64_(vectors + row_index * stride_elements, depth);
    }
    // Column squared norms are recomputed per 256-wide chunk to bound stack use.
    nk_f64_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f64_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only touch columns strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f64_t *result_row = result + row_index * result_stride_elements;
            // Query norm was staged on the diagonal above.
            nk_f64_t query_norm_sq_f64 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f64_t *result_ptr = result_row + col_start;
            nk_f64_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                // Strip-mine over the remaining columns of this chunk.
                size_t vector_length = __riscv_vsetvl_e64m1(count_remaining);
                vfloat64m1_t dots_f64m1 = __riscv_vle64_v_f64m1(result_ptr, vector_length);
                vfloat64m1_t target_norms_sq_f64m1 = __riscv_vle64_v_f64m1(norms_ptr, vector_length);
                vfloat64m1_t sum_sq_f64m1 = __riscv_vfadd_vf_f64m1(target_norms_sq_f64m1, query_norm_sq_f64,
                                                                   vector_length);
                // dist^2 = |a|^2 + |b|^2 - 2 * dot
                vfloat64m1_t dist_sq_f64m1 = __riscv_vfsub_vv_f64m1(
                    sum_sq_f64m1, __riscv_vfmul_vf_f64m1(dots_f64m1, 2.0, vector_length), vector_length);
                // Clamp tiny negatives caused by rounding before taking the root.
                dist_sq_f64m1 = __riscv_vfmax_vf_f64m1(dist_sq_f64m1, 0.0, vector_length);
                __riscv_vse64_v_f64m1(result_ptr, __riscv_vfsqrt_v_f64m1(dist_sq_f64m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; clear the staged norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
395
+
396
+ NK_PUBLIC void nk_euclideans_symmetric_f64_rvv( //
397
+ nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
398
+ nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
399
+ nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
400
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
401
+ nk_dots_symmetric_f64_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
402
+ nk_euclideans_symmetric_f64_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
403
+ result_stride_elements, row_start, row_count);
404
+ }
405
+
406
+ #pragma endregion // Double Precision Floats
407
+
408
+ #pragma region Half Precision Floats
409
+
410
/*  In-place finalization for packed f16 angular distances: converts the raw f32
 *  dot products in `c` into 1 - dot/sqrt(|a|^2*|b|^2), clamped at zero.
 *  `b_packed` starts with a header, then the packed column payload, then the
 *  precomputed squared column norms (f32). Strides are in elements.
 *  NOTE(review): payload size assumes f32-word storage (depth_padded_values *
 *  sizeof(nk_f32_t)) — confirm against the f16 packing routine. */
NK_INTERNAL void nk_angulars_packed_f16_rvv_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
                                                      nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                      nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live right after the header and the packed payload.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_f16_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row (accumulated in f32), shared per row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            // Strip-mine: process as many lanes as the hardware grants per pass.
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            // Reciprocal square root of |a|^2 * |b|^2 via the shared helper.
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            // angular = 1 - cosine similarity, clamped to be non-negative.
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
441
+
442
+ NK_PUBLIC void nk_angulars_packed_f16_rvv( //
443
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
444
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
445
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
446
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
447
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
448
+ nk_dots_packed_f16_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
449
+ nk_angulars_packed_f16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
450
+ }
451
+
452
/*  In-place finalization for packed f16 Euclidean distances: converts the raw
 *  f32 dot products in `c` into sqrt(|a|^2 + |b|^2 - 2*dot). `b_packed` starts
 *  with a header, then the packed column payload, then the precomputed squared
 *  column norms (f32). Strides are in elements.
 *  NOTE(review): payload size assumes f32-word storage (depth_padded_values *
 *  sizeof(nk_f32_t)) — confirm against the f16 packing routine. */
NK_INTERNAL void nk_euclideans_packed_f16_rvv_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
                                                        nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                        nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live right after the header and the packed payload.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_f16_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row (accumulated in f32), shared per row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            // Strip-mine: process as many lanes as the hardware grants per pass.
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            // dist^2 = |a|^2 + |b|^2 - 2 * dot
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp tiny negatives caused by rounding before taking the root.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
481
+
482
+ NK_PUBLIC void nk_euclideans_packed_f16_rvv( //
483
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
484
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
485
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
486
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
487
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
488
+ nk_dots_packed_f16_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
489
+ nk_euclideans_packed_f16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
490
+ }
491
+
492
/*  In-place finalization for a symmetric f16 angular-distance matrix (f32
 *  results): rewrites raw dot products as 1 - dot/sqrt(|a|^2*|b|^2), clamped at
 *  zero. Only the strictly-upper triangle of rows [row_start,
 *  row_start+row_count) is rewritten; the diagonal stages each row's squared
 *  norm and is zeroed at the end. Strides are in elements. */
NK_INTERNAL void nk_angulars_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                         nk_size_t stride_elements, nk_f32_t *result,
                                                         nk_size_t result_stride_elements, nk_size_t row_start,
                                                         nk_size_t row_count) {
    // Stage each row's squared norm on its own diagonal slot.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f16_(vectors + row_index * stride_elements, depth);
    }
    // Column squared norms are recomputed per 256-wide chunk to bound stack use.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only touch columns strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Query norm was staged on the diagonal above.
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                // Strip-mine over the remaining columns of this chunk.
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                // Reciprocal square root of |a|^2 * |b|^2 via the shared helper.
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                // angular = 1 - cosine similarity, clamped to be non-negative.
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; clear the staged norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
533
+
534
+ NK_PUBLIC void nk_angulars_symmetric_f16_rvv( //
535
+ nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
536
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
537
+ nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
538
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
539
+ nk_dots_symmetric_f16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
540
+ nk_angulars_symmetric_f16_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
541
+ row_start, row_count);
542
+ }
543
+
544
/*  In-place finalization for a symmetric f16 Euclidean-distance matrix (f32
 *  results): rewrites raw dot products as sqrt(|a|^2 + |b|^2 - 2*dot). Only the
 *  strictly-upper triangle of rows [row_start, row_start+row_count) is touched;
 *  the diagonal stages each row's squared norm and is zeroed at the end.
 *  Strides are in elements. */
NK_INTERNAL void nk_euclideans_symmetric_f16_rvv_finalize_(nk_f16_t const *vectors, nk_size_t n_vectors,
                                                           nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                           nk_size_t result_stride_elements, nk_size_t row_start,
                                                           nk_size_t row_count) {
    // Stage each row's squared norm on its own diagonal slot.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f16_(vectors + row_index * stride_elements, depth);
    }
    // Column squared norms are recomputed per 256-wide chunk to bound stack use.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only touch columns strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Query norm was staged on the diagonal above.
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                // Strip-mine over the remaining columns of this chunk.
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                // dist^2 = |a|^2 + |b|^2 - 2 * dot
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives caused by rounding before taking the root.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; clear the staged norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
584
+
585
+ NK_PUBLIC void nk_euclideans_symmetric_f16_rvv( //
586
+ nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
587
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
588
+ nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
589
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
590
+ nk_dots_symmetric_f16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
591
+ nk_euclideans_symmetric_f16_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
592
+ result_stride_elements, row_start, row_count);
593
+ }
594
+
595
+ #pragma endregion // Half Precision Floats
596
+
597
+ #pragma region Brain Float 16
598
+
599
/*  In-place finalization for packed bf16 angular distances: converts the raw
 *  f32 dot products in `c` into 1 - dot/sqrt(|a|^2*|b|^2), clamped at zero.
 *  `b_packed` starts with a header, then the packed column payload, then the
 *  precomputed squared column norms (f32). Strides are in elements.
 *  NOTE(review): payload size assumes f32-word storage (depth_padded_values *
 *  sizeof(nk_f32_t)) — confirm against the bf16 packing routine. */
NK_INTERNAL void nk_angulars_packed_bf16_rvv_finalize_(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live right after the header and the packed payload.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_bf16_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row (accumulated in f32), shared per row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            // Strip-mine: process as many lanes as the hardware grants per pass.
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            // Reciprocal square root of |a|^2 * |b|^2 via the shared helper.
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            // angular = 1 - cosine similarity, clamped to be non-negative.
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
630
+
631
+ NK_PUBLIC void nk_angulars_packed_bf16_rvv( //
632
+ nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
633
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
634
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
635
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
636
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
637
+ nk_dots_packed_bf16_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
638
+ nk_angulars_packed_bf16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
639
+ }
640
+
641
/*  In-place finalization for packed bf16 Euclidean distances: converts the raw
 *  f32 dot products in `c` into sqrt(|a|^2 + |b|^2 - 2*dot). `b_packed` starts
 *  with a header, then the packed column payload, then the precomputed squared
 *  column norms (f32). Strides are in elements.
 *  NOTE(review): payload size assumes f32-word storage (depth_padded_values *
 *  sizeof(nk_f32_t)) — confirm against the bf16 packing routine. */
NK_INTERNAL void nk_euclideans_packed_bf16_rvv_finalize_(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
                                                         nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                         nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live right after the header and the packed payload.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_bf16_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row (accumulated in f32), shared per row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            // Strip-mine: process as many lanes as the hardware grants per pass.
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            // dist^2 = |a|^2 + |b|^2 - 2 * dot
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp tiny negatives caused by rounding before taking the root.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
670
+
671
+ NK_PUBLIC void nk_euclideans_packed_bf16_rvv( //
672
+ nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
673
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
674
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
675
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
676
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
677
+ nk_dots_packed_bf16_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
678
+ nk_euclideans_packed_bf16_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
679
+ }
680
+
681
/*  In-place finalization for a symmetric bf16 angular-distance matrix (f32
 *  results): rewrites raw dot products as 1 - dot/sqrt(|a|^2*|b|^2), clamped at
 *  zero. Only the strictly-upper triangle of rows [row_start,
 *  row_start+row_count) is rewritten; the diagonal stages each row's squared
 *  norm and is zeroed at the end. Strides are in elements. */
NK_INTERNAL void nk_angulars_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vectors, nk_size_t n_vectors,
                                                          nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stage each row's squared norm on its own diagonal slot.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_bf16_(vectors + row_index * stride_elements, depth);
    }
    // Column squared norms are recomputed per 256-wide chunk to bound stack use.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only touch columns strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Query norm was staged on the diagonal above.
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                // Strip-mine over the remaining columns of this chunk.
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                // Reciprocal square root of |a|^2 * |b|^2 via the shared helper.
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                // angular = 1 - cosine similarity, clamped to be non-negative.
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; clear the staged norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
722
+
723
+ NK_PUBLIC void nk_angulars_symmetric_bf16_rvv( //
724
+ nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
725
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
726
+ nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
727
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
728
+ nk_dots_symmetric_bf16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
729
+ nk_angulars_symmetric_bf16_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
730
+ row_start, row_count);
731
+ }
732
+
733
/*  In-place finalization for a symmetric bf16 Euclidean-distance matrix (f32
 *  results): rewrites raw dot products as sqrt(|a|^2 + |b|^2 - 2*dot). Only the
 *  strictly-upper triangle of rows [row_start, row_start+row_count) is touched;
 *  the diagonal stages each row's squared norm and is zeroed at the end.
 *  Strides are in elements. */
NK_INTERNAL void nk_euclideans_symmetric_bf16_rvv_finalize_(nk_bf16_t const *vectors, nk_size_t n_vectors,
                                                            nk_size_t depth, nk_size_t stride_elements,
                                                            nk_f32_t *result, nk_size_t result_stride_elements,
                                                            nk_size_t row_start, nk_size_t row_count) {
    // Stage each row's squared norm on its own diagonal slot.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_bf16_(vectors + row_index * stride_elements, depth);
    }
    // Column squared norms are recomputed per 256-wide chunk to bound stack use.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only touch columns strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Query norm was staged on the diagonal above.
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                // Strip-mine over the remaining columns of this chunk.
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                // dist^2 = |a|^2 + |b|^2 - 2 * dot
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives caused by rounding before taking the root.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero by definition; clear the staged norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
773
+
774
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_rvv( //
775
+ nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
776
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
777
+ nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
778
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
779
+ nk_dots_symmetric_bf16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
780
+ nk_euclideans_symmetric_bf16_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
781
+ result_stride_elements, row_start, row_count);
782
+ }
783
+
784
+ #pragma endregion // Brain Float 16
785
+
786
+ #pragma region Micro Precision E2M3
787
+
788
/*  In-place finalization for packed e2m3 angular distances: converts the raw
 *  f32 dot products in `c` into 1 - dot/sqrt(|a|^2*|b|^2), clamped at zero.
 *  `b_packed` starts with a header, then the packed column payload (one
 *  nk_e2m3_t per value), then the precomputed squared column norms (f32).
 *  Strides are in elements.
 *  NOTE(review): unlike the f16/bf16 variants, payload size here uses
 *  sizeof(nk_e2m3_t) — confirm against the e2m3 packing routine. */
NK_INTERNAL void nk_angulars_packed_e2m3_rvv_finalize_(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Squared column norms live right after the header and the packed payload.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_e2m3_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e2m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared norm of the query row (accumulated in f32), shared per row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            // Strip-mine: process as many lanes as the hardware grants per pass.
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            // Reciprocal square root of |a|^2 * |b|^2 via the shared helper.
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            // angular = 1 - cosine similarity, clamped to be non-negative.
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
819
+
820
+ NK_PUBLIC void nk_angulars_packed_e2m3_rvv( //
821
+ nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
822
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
823
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
824
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
825
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
826
+ nk_dots_packed_e2m3_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
827
+ nk_angulars_packed_e2m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
828
+ }
829
+
830
// Rewrites each dot product in `c` as a Euclidean distance:
// sqrt(max(0, |query|^2 + |target|^2 - 2 * dot)).
NK_INTERNAL void nk_euclideans_packed_e2m3_rvv_finalize_(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                         nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                         nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // Target squared norms follow the header plus the packed e2m3 payload.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_e2m3_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e2m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared L2 norm of the query row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mine the row over the vector length.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp tiny negatives from rounding before the square root.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
859
+
860
+ NK_PUBLIC void nk_euclideans_packed_e2m3_rvv( //
861
+ nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
862
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
863
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
864
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
865
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
866
+ nk_dots_packed_e2m3_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
867
+ nk_euclideans_packed_e2m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
868
+ }
869
+
870
// Finalizes a row band of the symmetric angular-distance matrix in-place.
// Diagonal cells temporarily hold each row's squared norm (reused as the
// query norm); only columns strictly past the diagonal are rewritten here —
// presumably the lower triangle is mirrored elsewhere; not visible in this
// file region. Target norms are recomputed in 256-column chunks so the
// scratch cache stays on the stack.
NK_INTERNAL void nk_angulars_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vectors, nk_size_t n_vectors,
                                                          nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each band row's squared norm on the diagonal.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e2m3_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        // Squared norms of the columns in this chunk.
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Start strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Vectorized rewrite: max(0, 1 - dot / sqrt(|q|^2 * |t|^2)).
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal that held the norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
911
+
912
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_rvv( //
913
+ nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
914
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
915
+ nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
916
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
917
+ nk_dots_symmetric_e2m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
918
+ nk_angulars_symmetric_e2m3_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
919
+ row_start, row_count);
920
+ }
921
+
922
// Finalizes a row band of the symmetric Euclidean-distance matrix in-place:
// sqrt(max(0, |q|^2 + |t|^2 - 2 * dot)). Diagonal cells temporarily hold each
// row's squared norm; only columns strictly past the diagonal are rewritten.
// Target norms are recomputed in 256-column chunks to bound the stack cache.
NK_INTERNAL void nk_euclideans_symmetric_e2m3_rvv_finalize_(nk_e2m3_t const *vectors, nk_size_t n_vectors,
                                                            nk_size_t depth, nk_size_t stride_elements,
                                                            nk_f32_t *result, nk_size_t result_stride_elements,
                                                            nk_size_t row_start, nk_size_t row_count) {
    // Stash each band row's squared norm on the diagonal.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e2m3_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Start strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives from rounding before the square root.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal that held the norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
962
+
963
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_rvv( //
964
+ nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
965
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
966
+ nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
967
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
968
+ nk_dots_symmetric_e2m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
969
+ nk_euclideans_symmetric_e2m3_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
970
+ result_stride_elements, row_start, row_count);
971
+ }
972
+
973
+ #pragma endregion // Micro Precision E2M3
974
+
975
+ #pragma region Micro Precision E3M2
976
+
977
// Rewrites each dot product in `c` as an angular (cosine) distance:
// max(0, 1 - dot / sqrt(|query|^2 * |target|^2)).
NK_INTERNAL void nk_angulars_packed_e3m2_rvv_finalize_(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // NOTE(review): the payload width here is sizeof(nk_i16_t), while the
    // e2m3 variant uses sizeof(nk_e2m3_t) and e4m3 uses sizeof(nk_f32_t) —
    // presumably the width each format is widened to during packing; confirm
    // against the matching nk_dots_packed_* pack routine.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_i16_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e3m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared L2 norm of the query row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mine the row over the vector length.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            // rsqrt avoids a per-lane division.
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            // Clamp tiny negatives caused by rounding.
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
1008
+
1009
+ NK_PUBLIC void nk_angulars_packed_e3m2_rvv( //
1010
+ nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
1011
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1012
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1013
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
1014
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1015
+ nk_dots_packed_e3m2_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1016
+ nk_angulars_packed_e3m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1017
+ }
1018
+
1019
// Rewrites each dot product in `c` as a Euclidean distance:
// sqrt(max(0, |query|^2 + |target|^2 - 2 * dot)).
NK_INTERNAL void nk_euclideans_packed_e3m2_rvv_finalize_(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                         nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                         nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // NOTE(review): payload width is sizeof(nk_i16_t) here (vs nk_e2m3_t /
    // nk_f32_t in the sibling formats) — presumably the packed storage width
    // for e3m2; confirm against the pack routine.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_i16_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e3m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared L2 norm of the query row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp tiny negatives from rounding before the square root.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
1048
+
1049
+ NK_PUBLIC void nk_euclideans_packed_e3m2_rvv( //
1050
+ nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
1051
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1052
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1053
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
1054
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1055
+ nk_dots_packed_e3m2_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1056
+ nk_euclideans_packed_e3m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1057
+ }
1058
+
1059
// Finalizes a row band of the symmetric angular-distance matrix in-place.
// Diagonal cells temporarily hold each row's squared norm; only columns
// strictly past the diagonal are rewritten. Target norms are recomputed in
// 256-column chunks so the scratch cache stays on the stack.
NK_INTERNAL void nk_angulars_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vectors, nk_size_t n_vectors,
                                                          nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each band row's squared norm on the diagonal.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e3m2_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Start strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Vectorized rewrite: max(0, 1 - dot / sqrt(|q|^2 * |t|^2)).
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal that held the norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1100
+
1101
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_rvv( //
1102
+ nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1103
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1104
+ nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
1105
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1106
+ nk_dots_symmetric_e3m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
1107
+ nk_angulars_symmetric_e3m2_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1108
+ row_start, row_count);
1109
+ }
1110
+
1111
// Finalizes a row band of the symmetric Euclidean-distance matrix in-place:
// sqrt(max(0, |q|^2 + |t|^2 - 2 * dot)). Diagonal cells temporarily hold each
// row's squared norm; only columns strictly past the diagonal are rewritten.
NK_INTERNAL void nk_euclideans_symmetric_e3m2_rvv_finalize_(nk_e3m2_t const *vectors, nk_size_t n_vectors,
                                                            nk_size_t depth, nk_size_t stride_elements,
                                                            nk_f32_t *result, nk_size_t result_stride_elements,
                                                            nk_size_t row_start, nk_size_t row_count) {
    // Stash each band row's squared norm on the diagonal.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e3m2_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Start strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives from rounding before the square root.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal that held the norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1151
+
1152
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_rvv( //
1153
+ nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1154
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1155
+ nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
1156
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1157
+ nk_dots_symmetric_e3m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
1158
+ nk_euclideans_symmetric_e3m2_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
1159
+ result_stride_elements, row_start, row_count);
1160
+ }
1161
+
1162
+ #pragma endregion // Micro Precision E3M2
1163
+
1164
+ #pragma region Quarter Precision E4M3
1165
+
1166
// Rewrites each dot product in `c` as an angular (cosine) distance:
// max(0, 1 - dot / sqrt(|query|^2 * |target|^2)).
NK_INTERNAL void nk_angulars_packed_e4m3_rvv_finalize_(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                       nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                       nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // NOTE(review): payload width here is sizeof(nk_f32_t) (vs nk_e2m3_t /
    // nk_i16_t in the sibling formats) — presumably e4m3 is widened to f32
    // during packing; confirm against the pack routine.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e4m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared L2 norm of the query row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        // Strip-mine the row over the vector length.
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                      vector_length);
            // rsqrt avoids a per-lane division.
            vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
            vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
            vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
            // Clamp tiny negatives caused by rounding.
            angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
1197
+
1198
+ NK_PUBLIC void nk_angulars_packed_e4m3_rvv( //
1199
+ nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
1200
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1201
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1202
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
1203
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1204
+ nk_dots_packed_e4m3_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1205
+ nk_angulars_packed_e4m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1206
+ }
1207
+
1208
// Rewrites each dot product in `c` as a Euclidean distance:
// sqrt(max(0, |query|^2 + |target|^2 - 2 * dot)).
NK_INTERNAL void nk_euclideans_packed_e4m3_rvv_finalize_(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                         nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                         nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    // NOTE(review): payload width is sizeof(nk_f32_t) — presumably e4m3 is
    // widened to f32 during packing; confirm against the pack routine.
    nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
                                                      header->column_count * header->depth_padded_values *
                                                          sizeof(nk_f32_t));
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e4m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared L2 norm of the query row.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_(a_row, depth);
        nk_size_t count_columns = columns;
        nk_f32_t *result_ptr = result_row;
        nk_f32_t const *norms_ptr = target_norms;
        while (count_columns > 0) {
            size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
            vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
            vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
            vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
            vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
            // Clamp tiny negatives from rounding before the square root.
            dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
            __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
            result_ptr += vector_length;
            norms_ptr += vector_length;
            count_columns -= vector_length;
        }
    }
}
1237
+
1238
+ NK_PUBLIC void nk_euclideans_packed_e4m3_rvv( //
1239
+ nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
1240
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1241
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1242
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
1243
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1244
+ nk_dots_packed_e4m3_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1245
+ nk_euclideans_packed_e4m3_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1246
+ }
1247
+
1248
// Finalizes a row band of the symmetric angular-distance matrix in-place.
// Diagonal cells temporarily hold each row's squared norm; only columns
// strictly past the diagonal are rewritten. Target norms are recomputed in
// 256-column chunks so the scratch cache stays on the stack.
NK_INTERNAL void nk_angulars_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vectors, nk_size_t n_vectors,
                                                          nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each band row's squared norm on the diagonal.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e4m3_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Start strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Vectorized rewrite: max(0, 1 - dot / sqrt(|q|^2 * |t|^2)).
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal that held the norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1289
+
1290
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_rvv( //
1291
+ nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1292
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1293
+ nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
1294
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1295
+ nk_dots_symmetric_e4m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
1296
+ nk_angulars_symmetric_e4m3_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1297
+ row_start, row_count);
1298
+ }
1299
+
1300
// Finalizes a row band of the symmetric Euclidean-distance matrix in-place:
// sqrt(max(0, |q|^2 + |t|^2 - 2 * dot)). Diagonal cells temporarily hold each
// row's squared norm; only columns strictly past the diagonal are rewritten.
NK_INTERNAL void nk_euclideans_symmetric_e4m3_rvv_finalize_(nk_e4m3_t const *vectors, nk_size_t n_vectors,
                                                            nk_size_t depth, nk_size_t stride_elements,
                                                            nk_f32_t *result, nk_size_t result_stride_elements,
                                                            nk_size_t row_start, nk_size_t row_count) {
    // Stash each band row's squared norm on the diagonal.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e4m3_(vectors + row_index * stride_elements, depth);
    }
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Start strictly above the diagonal within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_f32_t query_norm_sq_f32 = result_row[row_index];
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives from rounding before the square root.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // Self-distance is zero: clear the diagonal that held the norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1340
+
1341
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_rvv( //
1342
+ nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1343
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1344
+ nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
1345
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1346
+ nk_dots_symmetric_e4m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
1347
+ nk_euclideans_symmetric_e4m3_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
1348
+ result_stride_elements, row_start, row_count);
1349
+ }
1350
+
1351
+ #pragma endregion // Quarter Precision E4M3
1352
+
1353
+ #pragma region Quarter Precision E5M2
1354
+
1355
+ NK_INTERNAL void nk_angulars_packed_e5m2_rvv_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
1356
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
1357
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1358
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
1359
+ nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
1360
+ header->column_count * header->depth_padded_values *
1361
+ sizeof(nk_f32_t));
1362
+ for (nk_size_t row_index = 0; row_index < rows; row_index++) {
1363
+ nk_e5m2_t const *a_row = a + row_index * a_stride_elements;
1364
+ nk_f32_t *result_row = c + row_index * c_stride_elements;
1365
+ nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_(a_row, depth);
1366
+ nk_size_t count_columns = columns;
1367
+ nk_f32_t *result_ptr = result_row;
1368
+ nk_f32_t const *norms_ptr = target_norms;
1369
+ while (count_columns > 0) {
1370
+ size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
1371
+ vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
1372
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
1373
+ vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
1374
+ vector_length);
1375
+ vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
1376
+ vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
1377
+ vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
1378
+ angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
1379
+ __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
1380
+ result_ptr += vector_length;
1381
+ norms_ptr += vector_length;
1382
+ count_columns -= vector_length;
1383
+ }
1384
+ }
1385
+ }
1386
+
1387
+ NK_PUBLIC void nk_angulars_packed_e5m2_rvv( //
1388
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
1389
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1390
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1391
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
1392
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1393
+ nk_dots_packed_e5m2_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1394
+ nk_angulars_packed_e5m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1395
+ }
1396
+
1397
+ NK_INTERNAL void nk_euclideans_packed_e5m2_rvv_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
1398
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
1399
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1400
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
1401
+ nk_f32_t const *target_norms = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
1402
+ header->column_count * header->depth_padded_values *
1403
+ sizeof(nk_f32_t));
1404
+ for (nk_size_t row_index = 0; row_index < rows; row_index++) {
1405
+ nk_e5m2_t const *a_row = a + row_index * a_stride_elements;
1406
+ nk_f32_t *result_row = c + row_index * c_stride_elements;
1407
+ nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_(a_row, depth);
1408
+ nk_size_t count_columns = columns;
1409
+ nk_f32_t *result_ptr = result_row;
1410
+ nk_f32_t const *norms_ptr = target_norms;
1411
+ while (count_columns > 0) {
1412
+ size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
1413
+ vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
1414
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
1415
+ vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
1416
+ vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
1417
+ sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
1418
+ dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
1419
+ __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
1420
+ result_ptr += vector_length;
1421
+ norms_ptr += vector_length;
1422
+ count_columns -= vector_length;
1423
+ }
1424
+ }
1425
+ }
1426
+
1427
+ NK_PUBLIC void nk_euclideans_packed_e5m2_rvv( //
1428
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
1429
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1430
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1431
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
1432
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1433
+ nk_dots_packed_e5m2_rvv(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1434
+ nk_euclideans_packed_e5m2_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1435
+ }
1436
+
1437
+ NK_INTERNAL void nk_angulars_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vectors, nk_size_t n_vectors,
1438
+ nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1439
+ nk_size_t result_stride_elements, nk_size_t row_start,
1440
+ nk_size_t row_count) {
1441
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1442
+ nk_f32_t *result_row = result + row_index * result_stride_elements;
1443
+ result_row[row_index] = nk_dots_reduce_sumsq_e5m2_(vectors + row_index * stride_elements, depth);
1444
+ }
1445
+ nk_f32_t norms_cache[256];
1446
+ for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1447
+ nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1448
+ for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1449
+ norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
1450
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1451
+ nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
1452
+ if (col_start >= chunk_end) continue;
1453
+ nk_f32_t *result_row = result + row_index * result_stride_elements;
1454
+ nk_f32_t query_norm_sq_f32 = result_row[row_index];
1455
+ nk_size_t count_remaining = chunk_end - col_start;
1456
+ nk_f32_t *result_ptr = result_row + col_start;
1457
+ nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
1458
+ while (count_remaining > 0) {
1459
+ size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
1460
+ vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
1461
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
1462
+ vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
1463
+ vector_length);
1464
+ vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
1465
+ vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
1466
+ vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
1467
+ angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
1468
+ __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
1469
+ result_ptr += vector_length;
1470
+ norms_ptr += vector_length;
1471
+ count_remaining -= vector_length;
1472
+ }
1473
+ }
1474
+ }
1475
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
1476
+ result[row_index * result_stride_elements + row_index] = 0;
1477
+ }
1478
+
1479
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_rvv( //
1480
+ nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1481
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1482
+ nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
1483
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1484
+ nk_dots_symmetric_e5m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
1485
+ nk_angulars_symmetric_e5m2_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1486
+ row_start, row_count);
1487
+ }
1488
+
1489
+ NK_INTERNAL void nk_euclideans_symmetric_e5m2_rvv_finalize_(nk_e5m2_t const *vectors, nk_size_t n_vectors,
1490
+ nk_size_t depth, nk_size_t stride_elements,
1491
+ nk_f32_t *result, nk_size_t result_stride_elements,
1492
+ nk_size_t row_start, nk_size_t row_count) {
1493
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1494
+ nk_f32_t *result_row = result + row_index * result_stride_elements;
1495
+ result_row[row_index] = nk_dots_reduce_sumsq_e5m2_(vectors + row_index * stride_elements, depth);
1496
+ }
1497
+ nk_f32_t norms_cache[256];
1498
+ for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1499
+ nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1500
+ for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1501
+ norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
1502
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1503
+ nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
1504
+ if (col_start >= chunk_end) continue;
1505
+ nk_f32_t *result_row = result + row_index * result_stride_elements;
1506
+ nk_f32_t query_norm_sq_f32 = result_row[row_index];
1507
+ nk_size_t count_remaining = chunk_end - col_start;
1508
+ nk_f32_t *result_ptr = result_row + col_start;
1509
+ nk_f32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
1510
+ while (count_remaining > 0) {
1511
+ size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
1512
+ vfloat32m1_t dots_f32m1 = __riscv_vle32_v_f32m1(result_ptr, vector_length);
1513
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vle32_v_f32m1(norms_ptr, vector_length);
1514
+ vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
1515
+ vector_length);
1516
+ vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
1517
+ sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
1518
+ dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
1519
+ __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
1520
+ result_ptr += vector_length;
1521
+ norms_ptr += vector_length;
1522
+ count_remaining -= vector_length;
1523
+ }
1524
+ }
1525
+ }
1526
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
1527
+ result[row_index * result_stride_elements + row_index] = 0;
1528
+ }
1529
+
1530
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_rvv( //
1531
+ nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1532
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1533
+ nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
1534
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1535
+ nk_dots_symmetric_e5m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
1536
+ nk_euclideans_symmetric_e5m2_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result,
1537
+ result_stride_elements, row_start, row_count);
1538
+ }
1539
+
1540
+ #pragma endregion // Quarter Precision E5M2
1541
+
1542
+ #pragma region Signed 8-bit Integers
1543
+
1544
+ NK_INTERNAL void nk_angulars_packed_i8_rvv_finalize_(nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
1545
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
1546
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1547
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
1548
+ nk_u32_t const *target_norms = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
1549
+ header->column_count * header->depth_padded_values *
1550
+ sizeof(nk_i8_t));
1551
+ for (nk_size_t row_index = 0; row_index < rows; row_index++) {
1552
+ nk_i8_t const *a_row = a + row_index * a_stride_elements;
1553
+ nk_f32_t *result_row = c + row_index * c_stride_elements;
1554
+ nk_u32_t query_norm_sq = nk_dots_reduce_sumsq_i8_(a_row, depth);
1555
+ nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
1556
+ nk_size_t count_columns = columns;
1557
+ nk_f32_t *result_ptr = result_row;
1558
+ nk_u32_t const *norms_ptr = target_norms;
1559
+ while (count_columns > 0) {
1560
+ size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
1561
+ vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_x_v_f32m1(
1562
+ __riscv_vle32_v_i32m1((nk_i32_t const *)result_ptr, vector_length), vector_length);
1563
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
1564
+ __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
1565
+ vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
1566
+ vector_length);
1567
+ vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
1568
+ vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
1569
+ vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
1570
+ angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
1571
+ __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
1572
+ result_ptr += vector_length;
1573
+ norms_ptr += vector_length;
1574
+ count_columns -= vector_length;
1575
+ }
1576
+ }
1577
+ }
1578
+
1579
+ NK_PUBLIC void nk_angulars_packed_i8_rvv( //
1580
+ nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
1581
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1582
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1583
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
1584
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1585
+ nk_dots_packed_i8_rvv(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1586
+ nk_angulars_packed_i8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1587
+ }
1588
+
1589
+ NK_INTERNAL void nk_euclideans_packed_i8_rvv_finalize_(nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
1590
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
1591
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1592
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
1593
+ nk_u32_t const *target_norms = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
1594
+ header->column_count * header->depth_padded_values *
1595
+ sizeof(nk_i8_t));
1596
+ for (nk_size_t row_index = 0; row_index < rows; row_index++) {
1597
+ nk_i8_t const *a_row = a + row_index * a_stride_elements;
1598
+ nk_f32_t *result_row = c + row_index * c_stride_elements;
1599
+ nk_u32_t query_norm_sq = nk_dots_reduce_sumsq_i8_(a_row, depth);
1600
+ nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
1601
+ nk_size_t count_columns = columns;
1602
+ nk_f32_t *result_ptr = result_row;
1603
+ nk_u32_t const *norms_ptr = target_norms;
1604
+ while (count_columns > 0) {
1605
+ size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
1606
+ vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_x_v_f32m1(
1607
+ __riscv_vle32_v_i32m1((nk_i32_t const *)result_ptr, vector_length), vector_length);
1608
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
1609
+ __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
1610
+ vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
1611
+ vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
1612
+ sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
1613
+ dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
1614
+ __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
1615
+ result_ptr += vector_length;
1616
+ norms_ptr += vector_length;
1617
+ count_columns -= vector_length;
1618
+ }
1619
+ }
1620
+ }
1621
+
1622
+ NK_PUBLIC void nk_euclideans_packed_i8_rvv( //
1623
+ nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
1624
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1625
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1626
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
1627
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1628
+ nk_dots_packed_i8_rvv(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1629
+ nk_euclideans_packed_i8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1630
+ }
1631
+
1632
+ NK_INTERNAL void nk_angulars_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1633
+ nk_size_t stride_elements, nk_f32_t *result,
1634
+ nk_size_t result_stride_elements, nk_size_t row_start,
1635
+ nk_size_t row_count) {
1636
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1637
+ nk_u32_t norm = nk_dots_reduce_sumsq_i8_(vectors + row_index * stride_elements, depth);
1638
+ ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
1639
+ }
1640
+ nk_u32_t norms_cache[256];
1641
+ for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1642
+ nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1643
+ for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1644
+ norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
1645
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1646
+ nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
1647
+ if (col_start >= chunk_end) continue;
1648
+ nk_f32_t *result_row = result + row_index * result_stride_elements;
1649
+ nk_u32_t query_norm_sq = ((nk_u32_t *)result_row)[row_index];
1650
+ nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
1651
+ nk_size_t count_remaining = chunk_end - col_start;
1652
+ nk_f32_t *result_ptr = result_row + col_start;
1653
+ nk_u32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
1654
+ while (count_remaining > 0) {
1655
+ size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
1656
+ vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_x_v_f32m1(
1657
+ __riscv_vle32_v_i32m1((nk_i32_t const *)result_ptr, vector_length), vector_length);
1658
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
1659
+ __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
1660
+ vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
1661
+ vector_length);
1662
+ vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
1663
+ vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
1664
+ vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
1665
+ angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
1666
+ __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
1667
+ result_ptr += vector_length;
1668
+ norms_ptr += vector_length;
1669
+ count_remaining -= vector_length;
1670
+ }
1671
+ }
1672
+ }
1673
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
1674
+ result[row_index * result_stride_elements + row_index] = 0;
1675
+ }
1676
+
1677
+ NK_PUBLIC void nk_angulars_symmetric_i8_rvv( //
1678
+ nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1679
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1680
+ nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
1681
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1682
+ nk_dots_symmetric_i8_rvv(vectors, n_vectors, depth, stride, (nk_i32_t *)result, result_stride, row_start,
1683
+ row_count);
1684
+ nk_angulars_symmetric_i8_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1685
+ row_start, row_count);
1686
+ }
1687
+
1688
+ NK_INTERNAL void nk_euclideans_symmetric_i8_rvv_finalize_(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1689
+ nk_size_t stride_elements, nk_f32_t *result,
1690
+ nk_size_t result_stride_elements, nk_size_t row_start,
1691
+ nk_size_t row_count) {
1692
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1693
+ nk_u32_t norm = nk_dots_reduce_sumsq_i8_(vectors + row_index * stride_elements, depth);
1694
+ ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
1695
+ }
1696
+ nk_u32_t norms_cache[256];
1697
+ for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1698
+ nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1699
+ for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1700
+ norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
1701
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1702
+ nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
1703
+ if (col_start >= chunk_end) continue;
1704
+ nk_f32_t *result_row = result + row_index * result_stride_elements;
1705
+ nk_u32_t query_norm_sq = ((nk_u32_t *)result_row)[row_index];
1706
+ nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
1707
+ nk_size_t count_remaining = chunk_end - col_start;
1708
+ nk_f32_t *result_ptr = result_row + col_start;
1709
+ nk_u32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
1710
+ while (count_remaining > 0) {
1711
+ size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
1712
+ vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_x_v_f32m1(
1713
+ __riscv_vle32_v_i32m1((nk_i32_t const *)result_ptr, vector_length), vector_length);
1714
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
1715
+ __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
1716
+ vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
1717
+ vector_length);
1718
+ vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
1719
+ sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
1720
+ dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
1721
+ __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
1722
+ result_ptr += vector_length;
1723
+ norms_ptr += vector_length;
1724
+ count_remaining -= vector_length;
1725
+ }
1726
+ }
1727
+ }
1728
+ for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
1729
+ result[row_index * result_stride_elements + row_index] = 0;
1730
+ }
1731
+
1732
+ NK_PUBLIC void nk_euclideans_symmetric_i8_rvv( //
1733
+ nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1734
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1735
+ nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
1736
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1737
+ nk_dots_symmetric_i8_rvv(vectors, n_vectors, depth, stride, (nk_i32_t *)result, result_stride, row_start,
1738
+ row_count);
1739
+ nk_euclideans_symmetric_i8_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1740
+ row_start, row_count);
1741
+ }
1742
+
1743
+ #pragma endregion // Signed 8-bit Integers
1744
+
1745
+ #pragma region Unsigned 8-bit Integers
1746
+
1747
+ NK_INTERNAL void nk_angulars_packed_u8_rvv_finalize_(nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
1748
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
1749
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1750
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
1751
+ nk_u32_t const *target_norms = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
1752
+ header->column_count * header->depth_padded_values *
1753
+ sizeof(nk_u8_t));
1754
+ for (nk_size_t row_index = 0; row_index < rows; row_index++) {
1755
+ nk_u8_t const *a_row = a + row_index * a_stride_elements;
1756
+ nk_f32_t *result_row = c + row_index * c_stride_elements;
1757
+ nk_u32_t query_norm_sq = nk_dots_reduce_sumsq_u8_(a_row, depth);
1758
+ nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
1759
+ nk_size_t count_columns = columns;
1760
+ nk_f32_t *result_ptr = result_row;
1761
+ nk_u32_t const *norms_ptr = target_norms;
1762
+ while (count_columns > 0) {
1763
+ size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
1764
+ vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
1765
+ __riscv_vle32_v_u32m1((nk_u32_t const *)result_ptr, vector_length), vector_length);
1766
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
1767
+ __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
1768
+ vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
1769
+ vector_length);
1770
+ vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
1771
+ vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
1772
+ vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
1773
+ angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
1774
+ __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
1775
+ result_ptr += vector_length;
1776
+ norms_ptr += vector_length;
1777
+ count_columns -= vector_length;
1778
+ }
1779
+ }
1780
+ }
1781
+
1782
+ NK_PUBLIC void nk_angulars_packed_u8_rvv( //
1783
+ nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
1784
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1785
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1786
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
1787
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1788
+ nk_dots_packed_u8_rvv(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1789
+ nk_angulars_packed_u8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1790
+ }
1791
+
1792
+ NK_INTERNAL void nk_euclideans_packed_u8_rvv_finalize_(nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
1793
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
1794
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1795
+ nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
1796
+ nk_u32_t const *target_norms = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_cross_packed_buffer_header_t) +
1797
+ header->column_count * header->depth_padded_values *
1798
+ sizeof(nk_u8_t));
1799
+ for (nk_size_t row_index = 0; row_index < rows; row_index++) {
1800
+ nk_u8_t const *a_row = a + row_index * a_stride_elements;
1801
+ nk_f32_t *result_row = c + row_index * c_stride_elements;
1802
+ nk_u32_t query_norm_sq = nk_dots_reduce_sumsq_u8_(a_row, depth);
1803
+ nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
1804
+ nk_size_t count_columns = columns;
1805
+ nk_f32_t *result_ptr = result_row;
1806
+ nk_u32_t const *norms_ptr = target_norms;
1807
+ while (count_columns > 0) {
1808
+ size_t vector_length = __riscv_vsetvl_e32m1(count_columns);
1809
+ vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
1810
+ __riscv_vle32_v_u32m1((nk_u32_t const *)result_ptr, vector_length), vector_length);
1811
+ vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
1812
+ __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
1813
+ vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32, vector_length);
1814
+ vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
1815
+ sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
1816
+ dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
1817
+ __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
1818
+ result_ptr += vector_length;
1819
+ norms_ptr += vector_length;
1820
+ count_columns -= vector_length;
1821
+ }
1822
+ }
1823
+ }
1824
+
1825
+ NK_PUBLIC void nk_euclideans_packed_u8_rvv( //
1826
+ nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
1827
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1828
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1829
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
1830
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1831
+ nk_dots_packed_u8_rvv(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
1832
+ nk_euclideans_packed_u8_rvv_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1833
+ }
1834
+
1835
/*  Finalizes a symmetric all-pairs angular-distance computation over u8 vectors.
 *
 *  On entry, `result` holds raw integer dot-products for rows in
 *  [row_start, row_start + row_count), stored as u32 bit-patterns inside the
 *  f32 matrix. This pass rewrites the strict upper triangle of those rows with
 *  the angular distance 1 - dot / sqrt(|query|^2 * |target|^2), clamped to be
 *  non-negative, and zeroes the diagonal. Strides are in elements, not bytes.
 *  NOTE(review): only columns > row are written here — presumably the caller
 *  mirrors or ignores the lower triangle; verify against callers. */
NK_INTERNAL void nk_angulars_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                        nk_size_t stride_elements, nk_f32_t *result,
                                                        nk_size_t result_stride_elements, nk_size_t row_start,
                                                        nk_size_t row_count) {
    // Stash each row's squared norm on its own diagonal cell, type-punned as u32,
    // so it can be re-read cheaply inside the column loop below.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t norm = nk_dots_reduce_sumsq_u8_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
    }
    // Process columns in chunks of 256 so their squared norms fit in a small
    // stack cache instead of being recomputed once per row.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Strict upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Re-read this row's squared norm from the diagonal (stored as u32 above).
            nk_u32_t query_norm_sq = ((nk_u32_t *)result_row)[row_index];
            nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_u32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined vector loop: each iteration converts up to VLEN/32
            // integer dot-products into angular distances in-place.
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                // Load raw u32 dot-products from the result row and widen to f32.
                vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)result_ptr, vector_length), vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
                // |query|^2 * |target|^2 — the square of the norms' product.
                vfloat32m1_t norms_product_f32m1 = __riscv_vfmul_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                          vector_length);
                // Helper presumably computes 1/sqrt(x) elementwise — TODO confirm
                // its accuracy guarantees against nk_rsqrt_f32m1_rvv_'s definition.
                vfloat32m1_t rsqrt_f32m1 = nk_rsqrt_f32m1_rvv_(norms_product_f32m1, vector_length);
                vfloat32m1_t normalized_dots_f32m1 = __riscv_vfmul_vv_f32m1(dots_f32m1, rsqrt_f32m1, vector_length);
                // Angular distance: 1 - cos(theta); clamp tiny negatives from rounding.
                vfloat32m1_t angular_f32m1 = __riscv_vfrsub_vf_f32m1(normalized_dots_f32m1, 1.0f, vector_length);
                angular_f32m1 = __riscv_vfmax_vf_f32m1(angular_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, angular_f32m1, vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // The diagonal held type-punned norms until now; a vector's distance to
    // itself is exactly zero.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1879
+
1880
+ NK_PUBLIC void nk_angulars_symmetric_u8_rvv( //
1881
+ nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1882
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1883
+ nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
1884
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1885
+ nk_dots_symmetric_u8_rvv(vectors, n_vectors, depth, stride, (nk_u32_t *)result, result_stride, row_start,
1886
+ row_count);
1887
+ nk_angulars_symmetric_u8_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1888
+ row_start, row_count);
1889
+ }
1890
+
1891
/*  Finalizes a symmetric all-pairs Euclidean-distance computation over u8 vectors.
 *
 *  On entry, `result` holds raw integer dot-products for rows in
 *  [row_start, row_start + row_count), stored as u32 bit-patterns inside the
 *  f32 matrix. This pass rewrites the strict upper triangle of those rows with
 *  sqrt(max(|query|^2 + |target|^2 - 2*dot, 0)) and zeroes the diagonal.
 *  Strides are in elements, not bytes.
 *  NOTE(review): only columns > row are written — presumably the caller mirrors
 *  or ignores the lower triangle; verify against callers. */
NK_INTERNAL void nk_euclideans_symmetric_u8_rvv_finalize_(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                          nk_size_t stride_elements, nk_f32_t *result,
                                                          nk_size_t result_stride_elements, nk_size_t row_start,
                                                          nk_size_t row_count) {
    // Stash each row's squared norm on its own diagonal cell, type-punned as u32,
    // so it can be re-read cheaply inside the column loop below.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t norm = nk_dots_reduce_sumsq_u8_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = norm;
    }
    // Process columns in chunks of 256 so their squared norms fit in a small
    // stack cache instead of being recomputed once per row.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Strict upper triangle only: start at max(row_index + 1, chunk_start).
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Re-read this row's squared norm from the diagonal (stored as u32 above).
            nk_u32_t query_norm_sq = ((nk_u32_t *)result_row)[row_index];
            nk_f32_t query_norm_sq_f32 = (nk_f32_t)query_norm_sq;
            nk_size_t count_remaining = chunk_end - col_start;
            nk_f32_t *result_ptr = result_row + col_start;
            nk_u32_t const *norms_ptr = norms_cache + (col_start - chunk_start);
            // Strip-mined vector loop: each iteration converts up to VLEN/32
            // integer dot-products into Euclidean distances in-place.
            while (count_remaining > 0) {
                size_t vector_length = __riscv_vsetvl_e32m1(count_remaining);
                // Load raw u32 dot-products from the result row and widen to f32.
                vfloat32m1_t dots_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)result_ptr, vector_length), vector_length);
                vfloat32m1_t target_norms_sq_f32m1 = __riscv_vfcvt_f_xu_v_f32m1(
                    __riscv_vle32_v_u32m1((nk_u32_t const *)norms_ptr, vector_length), vector_length);
                // |query|^2 + |target|^2.
                vfloat32m1_t sum_sq_f32m1 = __riscv_vfadd_vf_f32m1(target_norms_sq_f32m1, query_norm_sq_f32,
                                                                   vector_length);
                // Squared distance: |q|^2 + |t|^2 - 2*dot(q, t).
                vfloat32m1_t dist_sq_f32m1 = __riscv_vfsub_vv_f32m1(
                    sum_sq_f32m1, __riscv_vfmul_vf_f32m1(dots_f32m1, 2.0f, vector_length), vector_length);
                // Clamp tiny negatives caused by f32 rounding before the sqrt.
                dist_sq_f32m1 = __riscv_vfmax_vf_f32m1(dist_sq_f32m1, 0.0f, vector_length);
                __riscv_vse32_v_f32m1(result_ptr, __riscv_vfsqrt_v_f32m1(dist_sq_f32m1, vector_length), vector_length);
                result_ptr += vector_length;
                norms_ptr += vector_length;
                count_remaining -= vector_length;
            }
        }
    }
    // The diagonal held type-punned norms until now; a vector's distance to
    // itself is exactly zero.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1934
+
1935
+ NK_PUBLIC void nk_euclideans_symmetric_u8_rvv( //
1936
+ nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1937
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1938
+ nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
1939
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1940
+ nk_dots_symmetric_u8_rvv(vectors, n_vectors, depth, stride, (nk_u32_t *)result, result_stride, row_start,
1941
+ row_count);
1942
+ nk_euclideans_symmetric_u8_rvv_finalize_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1943
+ row_start, row_count);
1944
+ }
1945
+
1946
+ #pragma endregion // Unsigned 8-bit Integers
1947
+
1948
+ #if defined(__clang__)
1949
+ #pragma clang attribute pop
1950
+ #elif defined(__GNUC__)
1951
+ #pragma GCC pop_options
1952
+ #endif
1953
+
1954
+ #if defined(__cplusplus)
1955
+ } // extern "C"
1956
+ #endif
1957
+
1958
+ #endif // NK_TARGET_RVV
1959
+ #endif // NK_TARGET_RISCV_
1960
+ #endif // NK_SPATIALS_RVV_H