numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -20,226 +20,211 @@ extern "C" {
20
20
  #endif
21
21
 
22
22
  #if defined(__clang__)
23
- #pragma clang attribute push(__attribute__((target("sme,sve"))), apply_to = function)
23
+ #pragma clang attribute push(__attribute__((target("sme"))), apply_to = function)
24
24
  #elif defined(__GNUC__)
25
25
  #pragma GCC push_options
26
26
  #pragma GCC target("+sme")
27
27
  #endif
28
28
 
29
- NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_f16_ssve_(nk_f16_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
30
- svfloat32_t accumulator_f32x = svdup_f32(0.0f);
31
- nk_size_t const vector_length = svcntw();
29
+ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_f16_ssve_(nk_f16_t const *data, nk_size_t count) NK_STREAMING_ {
30
+ svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
31
+ svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
32
+ nk_size_t const vector_length = svcnth();
33
+ nk_size_t const half_vector_length = svcntw();
32
34
  for (nk_size_t i = 0; i < count; i += vector_length) {
33
- svbool_t predicate_f32x = svwhilelt_b32_u64(i, count);
34
- svfloat32_t values_f32x = svcvt_f32_f16_x(
35
- predicate_f32x, svld1_f16(svwhilelt_b16_u64(i, count), (nk_f16_for_arm_simd_t const *)(data + i)));
36
- accumulator_f32x = svmla_f32_x(predicate_f32x, accumulator_f32x, values_f32x, values_f32x);
35
+ svbool_t predicate_b16x = svwhilelt_b16_u64(i, count);
36
+ svfloat16_t values_f16x = svld1_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(data + i));
37
+
38
+ svbool_t predicate_even_b32x = svwhilelt_b32_u64(i, count);
39
+ svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_b32x, values_f16x);
40
+ accumulator_even_f32x = svmla_f32_m(predicate_even_b32x, accumulator_even_f32x, values_even_f32x,
41
+ values_even_f32x);
42
+
43
+ svbool_t predicate_odd_b32x = svwhilelt_b32_u64(i + half_vector_length, count);
44
+ svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
45
+ accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
37
46
  }
38
- return svaddv_f32(svptrue_b32(), accumulator_f32x);
47
+ return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
39
48
  }
40
49
 
41
- NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
50
+ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_size_t count) NK_STREAMING_ {
42
51
  svfloat32_t accumulator_f32x = svdup_f32(0.0f);
43
- nk_size_t const vector_length = svcntw();
52
+ nk_size_t const vector_length = svcnth();
44
53
  for (nk_size_t i = 0; i < count; i += vector_length) {
45
- svbool_t predicate_f32x = svwhilelt_b32_u64(i, count);
46
- svuint16_t raw_u16x = svld1_u16(svwhilelt_b16_u64(i, count), (nk_u16_t const *)data + i);
47
- svfloat32_t values_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_f32x, svunpklo_u32(raw_u16x), 16));
48
- accumulator_f32x = svmla_f32_x(predicate_f32x, accumulator_f32x, values_f32x, values_f32x);
54
+ svbool_t predicate_b16x = svwhilelt_b16_u64(i, count);
55
+ svbfloat16_t values_bf16x = svld1_bf16(predicate_b16x, (nk_bf16_for_arm_simd_t const *)(data + i));
56
+ accumulator_f32x = svbfdot_f32(accumulator_f32x, values_bf16x, values_bf16x);
49
57
  }
50
58
  return svaddv_f32(svptrue_b32(), accumulator_f32x);
51
59
  }
52
60
 
53
61
  NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e4m3_ssve_(nk_e4m3_t const *data, nk_size_t count) NK_STREAMING_ {
54
- svfloat32_t accumulator_lo_f32x = svdup_f32(0.0f);
55
- svfloat32_t accumulator_hi_f32x = svdup_f32(0.0f);
56
- svuint16_t subnorm_lut_u16x = svld1_u16(svwhilelt_b16(0u, 8u), nk_e4m3_subnorm_f16_lut_);
62
+ svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
63
+ svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
57
64
  nk_size_t const vector_length = svcnth();
58
65
  nk_size_t const half_vector_length = svcntw();
59
66
  for (nk_size_t i = 0; i < count; i += vector_length) {
60
67
  nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
61
- svbool_t predicate_i8x = svwhilelt_b8_u64(0u, batch_size);
62
- svbool_t predicate_f16x = svwhilelt_b16_u64(0u, batch_size);
63
- svuint8_t raw_u8x = svld1_u8(predicate_i8x, (nk_u8_t const *)data + i);
64
- svfloat16_t values_f16x = nk_e4m3x_to_f16x_ssve_(predicate_f16x, raw_u8x, subnorm_lut_u16x);
68
+ svbool_t predicate_b8x = svwhilelt_b8_u64(0u, batch_size);
69
+ svbool_t predicate_b16x = svwhilelt_b16_u64(0u, batch_size);
70
+ svuint8_t raw_u8x = svld1_u8(predicate_b8x, (nk_u8_t const *)data + i);
71
+ svfloat16_t values_f16x = nk_e4m3x_to_f16x_ssve_(predicate_b16x, raw_u8x);
65
72
 
66
- svbool_t predicate_lo_f32x = svwhilelt_b32_u64(0u, batch_size);
67
- svfloat32_t values_lo_f32x = svcvt_f32_f16_x(predicate_lo_f32x, values_f16x);
68
- accumulator_lo_f32x = svmla_f32_m(predicate_lo_f32x, accumulator_lo_f32x, values_lo_f32x, values_lo_f32x);
73
+ svbool_t predicate_even_b32x = svwhilelt_b32_u64(0u, batch_size);
74
+ svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_b32x, values_f16x);
75
+ accumulator_even_f32x = svmla_f32_m(predicate_even_b32x, accumulator_even_f32x, values_even_f32x,
76
+ values_even_f32x);
69
77
 
70
- svbool_t predicate_hi_f32x = svwhilelt_b32_u64(half_vector_length, batch_size);
71
- svfloat32_t values_hi_f32x = svcvtlt_f32_f16_x(predicate_hi_f32x, values_f16x);
72
- accumulator_hi_f32x = svmla_f32_m(predicate_hi_f32x, accumulator_hi_f32x, values_hi_f32x, values_hi_f32x);
78
+ svbool_t predicate_odd_b32x = svwhilelt_b32_u64(half_vector_length, batch_size);
79
+ svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
80
+ accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
73
81
  }
74
- return svaddv_f32(svptrue_b32(), accumulator_lo_f32x) + svaddv_f32(svptrue_b32(), accumulator_hi_f32x);
82
+ return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
75
83
  }
76
84
 
77
85
  NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e5m2_ssve_(nk_e5m2_t const *data, nk_size_t count) NK_STREAMING_ {
78
- svfloat32_t accumulator_lo_f32x = svdup_f32(0.0f);
79
- svfloat32_t accumulator_hi_f32x = svdup_f32(0.0f);
86
+ svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
87
+ svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
80
88
  nk_size_t const vector_length = svcnth();
81
89
  nk_size_t const half_vector_length = svcntw();
82
90
  for (nk_size_t i = 0; i < count; i += vector_length) {
83
91
  nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
84
- svbool_t predicate_i8x = svwhilelt_b8_u64(0u, batch_size);
85
- svbool_t predicate_f16x = svwhilelt_b16_u64(0u, batch_size);
86
- svuint8_t raw_u8x = svld1_u8(predicate_i8x, (nk_u8_t const *)data + i);
87
- svfloat16_t values_f16x = nk_e5m2x_to_f16x_ssve_(predicate_f16x, raw_u8x);
92
+ svbool_t predicate_b8x = svwhilelt_b8_u64(0u, batch_size);
93
+ svbool_t predicate_b16x = svwhilelt_b16_u64(0u, batch_size);
94
+ svuint8_t raw_u8x = svld1_u8(predicate_b8x, (nk_u8_t const *)data + i);
95
+ svfloat16_t values_f16x = nk_e5m2x_to_f16x_ssve_(predicate_b16x, raw_u8x);
88
96
 
89
- svbool_t predicate_lo_f32x = svwhilelt_b32_u64(0u, batch_size);
90
- svfloat32_t values_lo_f32x = svcvt_f32_f16_x(predicate_lo_f32x, values_f16x);
91
- accumulator_lo_f32x = svmla_f32_m(predicate_lo_f32x, accumulator_lo_f32x, values_lo_f32x, values_lo_f32x);
97
+ svbool_t predicate_even_b32x = svwhilelt_b32_u64(0u, batch_size);
98
+ svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_b32x, values_f16x);
99
+ accumulator_even_f32x = svmla_f32_m(predicate_even_b32x, accumulator_even_f32x, values_even_f32x,
100
+ values_even_f32x);
92
101
 
93
- svbool_t predicate_hi_f32x = svwhilelt_b32_u64(half_vector_length, batch_size);
94
- svfloat32_t values_hi_f32x = svcvtlt_f32_f16_x(predicate_hi_f32x, values_f16x);
95
- accumulator_hi_f32x = svmla_f32_m(predicate_hi_f32x, accumulator_hi_f32x, values_hi_f32x, values_hi_f32x);
102
+ svbool_t predicate_odd_b32x = svwhilelt_b32_u64(half_vector_length, batch_size);
103
+ svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
104
+ accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
96
105
  }
97
- return svaddv_f32(svptrue_b32(), accumulator_lo_f32x) + svaddv_f32(svptrue_b32(), accumulator_hi_f32x);
106
+ return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
98
107
  }
99
108
 
100
- NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
101
- svint64_t accumulator_i64x = svdup_s64(0);
102
- nk_size_t const vector_length = svcntd();
109
+ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_size_t count) NK_STREAMING_ {
110
+ svint32_t accumulator_i32x = svdup_s32(0);
111
+ nk_size_t const vector_length = svcntb();
103
112
  for (nk_size_t i = 0; i < count; i += vector_length) {
104
- svbool_t predicate_i64x = svwhilelt_b64_u64(i, count);
105
- svuint8_t raw_u8x = svld1_u8(svwhilelt_b8_u64(i, count), (nk_u8_t const *)data + i);
106
- svint8_t values_i8x = nk_e2m3x_to_i8x_ssve_(svwhilelt_b8_u64(i, count), raw_u8x);
107
- svint16_t values_i16x = svunpklo_s16(values_i8x);
108
- svint16_t squares_i16x = svmul_s16_z(svwhilelt_b16_u64(i, count), values_i16x, values_i16x);
109
- svint64_t squares_i64x = svunpklo_s64(svunpklo_s32(squares_i16x));
110
- accumulator_i64x = svadd_s64_m(predicate_i64x, accumulator_i64x, squares_i64x);
113
+ svbool_t predicate_b8x = svwhilelt_b8_u64(i, count);
114
+ svuint8_t raw_u8x = svld1_u8(predicate_b8x, (nk_u8_t const *)data + i);
115
+ svint8_t values_i8x = nk_e2m3x_to_i8x_ssve_(predicate_b8x, raw_u8x);
116
+ accumulator_i32x = svdot_s32(accumulator_i32x, values_i8x, values_i8x);
111
117
  }
112
- return (nk_f32_t)svaddv_s64(svptrue_b64(), accumulator_i64x) / 256.0f;
118
+ return (nk_f32_t)svaddv_s32(svptrue_b32(), accumulator_i32x) / 256.0f;
113
119
  }
114
120
 
115
121
  NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e3m2_ssve_(nk_e3m2_t const *data, nk_size_t count) NK_STREAMING_ {
116
- svfloat32_t accumulator_lo_f32x = svdup_f32(0.0f);
117
- svfloat32_t accumulator_hi_f32x = svdup_f32(0.0f);
122
+ svfloat32_t accumulator_even_f32x = svdup_f32(0.0f);
123
+ svfloat32_t accumulator_odd_f32x = svdup_f32(0.0f);
118
124
  nk_size_t const vector_length = svcnth();
119
125
  nk_size_t const half_vector_length = svcntw();
120
126
  for (nk_size_t i = 0; i < count; i += vector_length) {
121
127
  nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
122
- svbool_t predicate_i8x = svwhilelt_b8_u64(0u, batch_size);
123
- svbool_t predicate_f16x = svwhilelt_b16_u64(0u, batch_size);
124
- svuint8_t raw_u8x = svld1_u8(predicate_i8x, (nk_u8_t const *)data + i);
125
- svfloat16_t values_f16x = nk_e3m2x_to_f16x_ssve_(predicate_f16x, raw_u8x);
128
+ svbool_t predicate_b8x = svwhilelt_b8_u64(0u, batch_size);
129
+ svbool_t predicate_b16x = svwhilelt_b16_u64(0u, batch_size);
130
+ svuint8_t raw_u8x = svld1_u8(predicate_b8x, (nk_u8_t const *)data + i);
131
+ svfloat16_t values_f16x = nk_e3m2x_to_f16x_ssve_(predicate_b16x, raw_u8x);
126
132
 
127
- svbool_t predicate_lo_f32x = svwhilelt_b32_u64(0u, batch_size);
128
- svfloat32_t values_lo_f32x = svcvt_f32_f16_x(predicate_lo_f32x, values_f16x);
129
- accumulator_lo_f32x = svmla_f32_m(predicate_lo_f32x, accumulator_lo_f32x, values_lo_f32x, values_lo_f32x);
133
+ svbool_t predicate_even_b32x = svwhilelt_b32_u64(0u, batch_size);
134
+ svfloat32_t values_even_f32x = svcvt_f32_f16_x(predicate_even_b32x, values_f16x);
135
+ accumulator_even_f32x = svmla_f32_m(predicate_even_b32x, accumulator_even_f32x, values_even_f32x,
136
+ values_even_f32x);
130
137
 
131
- svbool_t predicate_hi_f32x = svwhilelt_b32_u64(half_vector_length, batch_size);
132
- svfloat32_t values_hi_f32x = svcvtlt_f32_f16_x(predicate_hi_f32x, values_f16x);
133
- accumulator_hi_f32x = svmla_f32_m(predicate_hi_f32x, accumulator_hi_f32x, values_hi_f32x, values_hi_f32x);
138
+ svbool_t predicate_odd_b32x = svwhilelt_b32_u64(half_vector_length, batch_size);
139
+ svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
140
+ accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
134
141
  }
135
- return svaddv_f32(svptrue_b32(), accumulator_lo_f32x) + svaddv_f32(svptrue_b32(), accumulator_hi_f32x);
142
+ return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
136
143
  }
137
144
 
138
- NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
139
- svint64_t accumulator_i64x = svdup_s64(0);
140
- nk_size_t const vector_length = svcntd();
145
+ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t count) NK_STREAMING_ {
146
+ svint32_t accumulator_i32x = svdup_s32(0);
147
+ nk_size_t const vector_length = svcntb();
141
148
  for (nk_size_t i = 0; i < count; i += vector_length) {
142
- svbool_t predicate_i64x = svwhilelt_b64_u64(i, count);
143
- svint8_t loaded_i8x = svld1_s8(svwhilelt_b8_u64(i, count), data + i);
144
- svint16_t values_i16x = svunpklo_s16(loaded_i8x);
145
- svint16_t squares_i16x = svmul_s16_z(svwhilelt_b16_u64(i, count), values_i16x, values_i16x);
146
- svint64_t squares_i64x = svunpklo_s64(svunpklo_s32(squares_i16x));
147
- accumulator_i64x = svadd_s64_m(predicate_i64x, accumulator_i64x, squares_i64x);
149
+ svbool_t predicate_b8x = svwhilelt_b8_u64(i, count);
150
+ svint8_t loaded_i8x = svld1_s8(predicate_b8x, data + i);
151
+ accumulator_i32x = svdot_s32(accumulator_i32x, loaded_i8x, loaded_i8x);
148
152
  }
149
- return (nk_u32_t)svaddv_s64(svptrue_b64(), accumulator_i64x);
153
+ return (nk_u32_t)svaddv_s32(svptrue_b32(), accumulator_i32x);
150
154
  }
151
155
 
152
- NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
153
- svuint64_t accumulator_u64x = svdup_u64(0);
154
- nk_size_t const vector_length = svcntd();
156
+ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t count) NK_STREAMING_ {
157
+ svuint32_t accumulator_u32x = svdup_u32(0);
158
+ nk_size_t const vector_length = svcntb();
155
159
  for (nk_size_t i = 0; i < count; i += vector_length) {
156
- svbool_t predicate_u64x = svwhilelt_b64_u64(i, count);
157
- svuint8_t raw_u8x = svld1_u8(svwhilelt_b8_u64(i, count), data + i);
158
- svuint16_t values_u16x = svunpklo_u16(raw_u8x);
159
- svuint16_t squares_u16x = svmul_u16_z(svwhilelt_b16_u64(i, count), values_u16x, values_u16x);
160
- svuint64_t squares_u64x = svunpklo_u64(svunpklo_u32(squares_u16x));
161
- accumulator_u64x = svadd_u64_m(predicate_u64x, accumulator_u64x, squares_u64x);
160
+ svbool_t predicate_b8x = svwhilelt_b8_u64(i, count);
161
+ svuint8_t loaded_u8x = svld1_u8(predicate_b8x, data + i);
162
+ accumulator_u32x = svdot_u32(accumulator_u32x, loaded_u8x, loaded_u8x);
162
163
  }
163
- return (nk_u32_t)svaddv_u64(svptrue_b64(), accumulator_u64x);
164
+ return (nk_u32_t)svaddv_u32(svptrue_b32(), accumulator_u32x);
164
165
  }
165
166
 
166
- NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
167
- svint64_t accumulator_i64x = svdup_s64(0);
167
+ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_t count) NK_STREAMING_ {
168
+ svint32_t accumulator_i32x = svdup_s32(0);
168
169
  nk_u8_t const *bytes = (nk_u8_t const *)data;
169
170
  nk_size_t const byte_count = (count + 1) / 2;
170
- nk_size_t const vector_length = svcntd();
171
+ nk_size_t const vector_length = svcntb();
171
172
  for (nk_size_t i = 0; i < byte_count; i += vector_length) {
172
- svbool_t predicate_u8x = svwhilelt_b8_u64(i, byte_count);
173
- svuint8_t packed_u8x = svld1_u8(predicate_u8x, bytes + i);
174
- svuint8_t low_u8x = svand_n_u8_x(predicate_u8x, packed_u8x, 0x0F);
175
- svuint8_t high_u8x = svlsr_n_u8_x(predicate_u8x, packed_u8x, 4);
173
+ svbool_t predicate_b8x = svwhilelt_b8_u64(i, byte_count);
174
+ svuint8_t packed_u8x = svld1_u8(predicate_b8x, bytes + i);
175
+ svuint8_t low_u8x = svand_n_u8_x(predicate_b8x, packed_u8x, 0x0F);
176
+ svuint8_t high_u8x = svlsr_n_u8_x(predicate_b8x, packed_u8x, 4);
176
177
  // Sign-extend 4-bit to 8-bit: shift left 4, arithmetic shift right 4
177
- svint8_t low_i8x = svasr_n_s8_x(predicate_u8x, svreinterpret_s8_u8(svlsl_n_u8_x(predicate_u8x, low_u8x, 4)), 4);
178
- svint8_t high_i8x = svasr_n_s8_x(predicate_u8x, svreinterpret_s8_u8(svlsl_n_u8_x(predicate_u8x, high_u8x, 4)),
178
+ svint8_t low_i8x = svasr_n_s8_x(predicate_b8x, svreinterpret_s8_u8(svlsl_n_u8_x(predicate_b8x, low_u8x, 4)), 4);
179
+ svint8_t high_i8x = svasr_n_s8_x(predicate_b8x, svreinterpret_s8_u8(svlsl_n_u8_x(predicate_b8x, high_u8x, 4)),
179
180
  4);
180
- // Widen to i16, square, sum per byte
181
- svbool_t predicate_i16x = svwhilelt_b16_u64(i, byte_count);
182
- svint16_t low_i16x = svunpklo_s16(low_i8x);
183
- svint16_t high_i16x = svunpklo_s16(high_i8x);
184
- svint16_t squares_low_i16x = svmul_s16_z(predicate_i16x, low_i16x, low_i16x);
185
- svint16_t squares_high_i16x = svmul_s16_z(predicate_i16x, high_i16x, high_i16x);
186
- svint16_t sum_i16x = svadd_s16_z(predicate_i16x, squares_low_i16x, squares_high_i16x);
187
- svbool_t predicate_i64x = svwhilelt_b64_u64(i, byte_count);
188
- svint64_t sum_i64x = svunpklo_s64(svunpklo_s32(sum_i16x));
189
- accumulator_i64x = svadd_s64_m(predicate_i64x, accumulator_i64x, sum_i64x);
190
- }
191
- return (nk_u32_t)svaddv_s64(svptrue_b64(), accumulator_i64x);
192
- }
193
-
194
- NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
195
- svuint64_t accumulator_u64x = svdup_u64(0);
181
+ accumulator_i32x = svdot_s32(accumulator_i32x, low_i8x, low_i8x);
182
+ accumulator_i32x = svdot_s32(accumulator_i32x, high_i8x, high_i8x);
183
+ }
184
+ return (nk_u32_t)svaddv_s32(svptrue_b32(), accumulator_i32x);
185
+ }
186
+
187
+ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_t count) NK_STREAMING_ {
188
+ svuint32_t accumulator_u32x = svdup_u32(0);
196
189
  nk_u8_t const *bytes = (nk_u8_t const *)data;
197
190
  nk_size_t const byte_count = (count + 1) / 2;
198
- nk_size_t const vector_length = svcntd();
191
+ nk_size_t const vector_length = svcntb();
199
192
  for (nk_size_t i = 0; i < byte_count; i += vector_length) {
200
- svbool_t predicate_u8x = svwhilelt_b8_u64(i, byte_count);
201
- svuint8_t packed_u8x = svld1_u8(predicate_u8x, bytes + i);
202
- svuint8_t low_u8x = svand_n_u8_x(predicate_u8x, packed_u8x, 0x0F);
203
- svuint8_t high_u8x = svlsr_n_u8_x(predicate_u8x, packed_u8x, 4);
204
- // Widen to u16, square, sum per byte
205
- svbool_t predicate_u16x = svwhilelt_b16_u64(i, byte_count);
206
- svuint16_t low_u16x = svunpklo_u16(low_u8x);
207
- svuint16_t high_u16x = svunpklo_u16(high_u8x);
208
- svuint16_t squares_low_u16x = svmul_u16_z(predicate_u16x, low_u16x, low_u16x);
209
- svuint16_t squares_high_u16x = svmul_u16_z(predicate_u16x, high_u16x, high_u16x);
210
- svuint16_t sum_u16x = svadd_u16_z(predicate_u16x, squares_low_u16x, squares_high_u16x);
211
- svbool_t predicate_u64x = svwhilelt_b64_u64(i, byte_count);
212
- svuint64_t sum_u64x = svunpklo_u64(svunpklo_u32(sum_u16x));
213
- accumulator_u64x = svadd_u64_m(predicate_u64x, accumulator_u64x, sum_u64x);
214
- }
215
- return (nk_u32_t)svaddv_u64(svptrue_b64(), accumulator_u64x);
216
- }
217
-
218
- NK_PUBLIC svfloat32_t nk_angulars_from_dot_f32x_ssve_(svbool_t predicate_f32x, svfloat32_t dots_f32x,
193
+ svbool_t predicate_b8x = svwhilelt_b8_u64(i, byte_count);
194
+ svuint8_t packed_u8x = svld1_u8(predicate_b8x, bytes + i);
195
+ svuint8_t low_u8x = svand_n_u8_x(predicate_b8x, packed_u8x, 0x0F);
196
+ svuint8_t high_u8x = svlsr_n_u8_x(predicate_b8x, packed_u8x, 4);
197
+ accumulator_u32x = svdot_u32(accumulator_u32x, low_u8x, low_u8x);
198
+ accumulator_u32x = svdot_u32(accumulator_u32x, high_u8x, high_u8x);
199
+ }
200
+ return (nk_u32_t)svaddv_u32(svptrue_b32(), accumulator_u32x);
201
+ }
202
+
203
+ NK_PUBLIC svfloat32_t nk_angulars_from_dot_f32x_ssve_(svbool_t predicate_b32x, svfloat32_t dots_f32x,
219
204
  svfloat32_t query_norm_sq_f32x,
220
- svfloat32_t target_norms_sq_f32x) NK_STREAMING_COMPATIBLE_ {
221
- svfloat32_t norms_product_f32x = svmul_f32_x(predicate_f32x, query_norm_sq_f32x, target_norms_sq_f32x);
205
+ svfloat32_t target_norms_sq_f32x) NK_STREAMING_ {
206
+ svfloat32_t norms_product_f32x = svmul_f32_x(predicate_b32x, query_norm_sq_f32x, target_norms_sq_f32x);
222
207
  svfloat32_t rsqrt_f32x = svrsqrte_f32(norms_product_f32x);
223
- rsqrt_f32x = svmul_f32_x(predicate_f32x, rsqrt_f32x,
224
- svrsqrts_f32(svmul_f32_x(predicate_f32x, norms_product_f32x, rsqrt_f32x), rsqrt_f32x));
225
- rsqrt_f32x = svmul_f32_x(predicate_f32x, rsqrt_f32x,
226
- svrsqrts_f32(svmul_f32_x(predicate_f32x, norms_product_f32x, rsqrt_f32x), rsqrt_f32x));
227
- svfloat32_t angular_f32x = svsub_f32_x(predicate_f32x, svdup_n_f32(1.0f),
228
- svmul_f32_x(predicate_f32x, dots_f32x, rsqrt_f32x));
229
- return svmax_f32_x(predicate_f32x, angular_f32x, svdup_n_f32(0.0f));
208
+ rsqrt_f32x = svmul_f32_x(predicate_b32x, rsqrt_f32x,
209
+ svrsqrts_f32(svmul_f32_x(predicate_b32x, norms_product_f32x, rsqrt_f32x), rsqrt_f32x));
210
+ rsqrt_f32x = svmul_f32_x(predicate_b32x, rsqrt_f32x,
211
+ svrsqrts_f32(svmul_f32_x(predicate_b32x, norms_product_f32x, rsqrt_f32x), rsqrt_f32x));
212
+ svfloat32_t angular_f32x = svsub_f32_x(predicate_b32x, svdup_n_f32(1.0f),
213
+ svmul_f32_x(predicate_b32x, dots_f32x, rsqrt_f32x));
214
+ return svmax_f32_x(predicate_b32x, angular_f32x, svdup_n_f32(0.0f));
230
215
  }
231
216
 
232
- NK_PUBLIC svfloat32_t nk_euclideans_from_dot_f32x_ssve_(svbool_t predicate_f32x, svfloat32_t dots_f32x,
217
+ NK_PUBLIC svfloat32_t nk_euclideans_from_dot_f32x_ssve_(svbool_t predicate_b32x, svfloat32_t dots_f32x,
233
218
  svfloat32_t query_norm_sq_f32x,
234
- svfloat32_t target_norms_sq_f32x) NK_STREAMING_COMPATIBLE_ {
235
- svfloat32_t sum_sq_f32x = svadd_f32_x(predicate_f32x, query_norm_sq_f32x, target_norms_sq_f32x);
236
- svfloat32_t dist_sq_f32x = svsub_f32_x(predicate_f32x, sum_sq_f32x,
237
- svmul_f32_x(predicate_f32x, svdup_n_f32(2.0f), dots_f32x));
238
- dist_sq_f32x = svmax_f32_x(predicate_f32x, dist_sq_f32x, svdup_n_f32(0.0f));
239
- return svsqrt_f32_x(predicate_f32x, dist_sq_f32x);
219
+ svfloat32_t target_norms_sq_f32x) NK_STREAMING_ {
220
+ svfloat32_t sum_sq_f32x = svadd_f32_x(predicate_b32x, query_norm_sq_f32x, target_norms_sq_f32x);
221
+ svfloat32_t dist_sq_f32x = svsub_f32_x(predicate_b32x, sum_sq_f32x,
222
+ svmul_f32_x(predicate_b32x, svdup_n_f32(2.0f), dots_f32x));
223
+ dist_sq_f32x = svmax_f32_x(predicate_b32x, dist_sq_f32x, svdup_n_f32(0.0f));
224
+ return svsqrt_f32_x(predicate_b32x, dist_sq_f32x);
240
225
  }
241
226
 
242
- #pragma region Half Precision Floats
227
+ #pragma region F16 Floats
243
228
 
244
229
  __arm_locally_streaming static void nk_angulars_packed_f16_sme_finalize_streaming_( //
245
230
  nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -253,12 +238,12 @@ __arm_locally_streaming static void nk_angulars_packed_f16_sme_finalize_streamin
253
238
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_ssve_(a_row, depth);
254
239
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
255
240
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
256
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
257
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
258
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
241
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
242
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
243
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
259
244
  svst1_f32(
260
- predicate_f32x, result_row + col_index,
261
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
245
+ predicate_b32x, result_row + col_index,
246
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
262
247
  }
263
248
  }
264
249
  }
@@ -286,12 +271,12 @@ __arm_locally_streaming static void nk_euclideans_packed_f16_sme_finalize_stream
286
271
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_ssve_(a_row, depth);
287
272
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
288
273
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
289
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
290
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
291
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
274
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
275
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
276
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
292
277
  svst1_f32(
293
- predicate_f32x, result_row + col_index,
294
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
278
+ predicate_b32x, result_row + col_index,
279
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
295
280
  }
296
281
  }
297
282
  }
@@ -307,8 +292,8 @@ NK_PUBLIC void nk_euclideans_packed_f16_sme( //
307
292
  c_stride_elements);
308
293
  }
309
294
 
310
- __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_streaming_( //
311
- nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
295
+ __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_streaming_( //
296
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
312
297
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
313
298
  // Phase 1: cache row norms on diagonal
314
299
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -317,8 +302,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_strea
317
302
  }
318
303
  // Phase 2: column-first post-processing
319
304
  nk_f32_t norms_cache[256];
320
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
321
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
305
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
306
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
322
307
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
323
308
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_ssve_(vectors + col * stride_elements, depth);
324
309
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -327,11 +312,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_strea
327
312
  nk_f32_t *result_row = result + row_index * result_stride_elements;
328
313
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
329
314
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
330
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
331
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
332
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
333
- svst1_f32(predicate_f32x, result_row + col_index,
334
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
315
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
316
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
317
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
318
+ svst1_f32(predicate_b32x, result_row + col_index,
319
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
335
320
  target_norms_sq_f32x));
336
321
  }
337
322
  }
@@ -341,19 +326,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_strea
341
326
  result[row_index * result_stride_elements + row_index] = 0;
342
327
  }
343
328
 
344
- NK_PUBLIC void nk_angulars_symmetric_f16_sme( //
345
- nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
346
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
347
- nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
348
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
349
- nk_dots_symmetric_f16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
329
+ NK_PUBLIC void nk_angulars_symmetric_f16_sme( //
330
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
331
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
332
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
333
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
334
+ nk_dots_symmetric_f16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
350
335
  row_start, row_count);
351
- nk_angulars_symmetric_f16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
336
+ nk_angulars_symmetric_f16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
352
337
  result_stride_elements, row_start, row_count);
353
338
  }
354
339
 
355
- __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_streaming_( //
356
- nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
340
+ __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_streaming_( //
341
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
357
342
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
358
343
  // Phase 1: cache row norms on diagonal
359
344
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -362,8 +347,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_str
362
347
  }
363
348
  // Phase 2: column-first post-processing
364
349
  nk_f32_t norms_cache[256];
365
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
366
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
350
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
351
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
367
352
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
368
353
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_ssve_(vectors + col * stride_elements, depth);
369
354
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -372,11 +357,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_str
372
357
  nk_f32_t *result_row = result + row_index * result_stride_elements;
373
358
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
374
359
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
375
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
376
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
377
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
378
- svst1_f32(predicate_f32x, result_row + col_index,
379
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
360
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
361
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
362
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
363
+ svst1_f32(predicate_b32x, result_row + col_index,
364
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
380
365
  target_norms_sq_f32x));
381
366
  }
382
367
  }
@@ -386,20 +371,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_str
386
371
  result[row_index * result_stride_elements + row_index] = 0;
387
372
  }
388
373
 
389
- NK_PUBLIC void nk_euclideans_symmetric_f16_sme( //
390
- nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
391
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
392
- nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
393
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
394
- nk_dots_symmetric_f16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
374
+ NK_PUBLIC void nk_euclideans_symmetric_f16_sme( //
375
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
376
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
377
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
378
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
379
+ nk_dots_symmetric_f16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
395
380
  row_start, row_count);
396
- nk_euclideans_symmetric_f16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
381
+ nk_euclideans_symmetric_f16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
397
382
  result_stride_elements, row_start, row_count);
398
383
  }
399
384
 
400
- #pragma endregion // Half Precision Floats
385
+ #pragma endregion F16 Floats
401
386
 
402
- #pragma region Brain Float 16
387
+ #pragma region BF16 Floats
403
388
 
404
389
  __arm_locally_streaming static void nk_angulars_packed_bf16_sme_finalize_streaming_( //
405
390
  nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -413,12 +398,12 @@ __arm_locally_streaming static void nk_angulars_packed_bf16_sme_finalize_streami
413
398
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_ssve_(a_row, depth);
414
399
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
415
400
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
416
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
417
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
418
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
401
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
402
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
403
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
419
404
  svst1_f32(
420
- predicate_f32x, result_row + col_index,
421
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
405
+ predicate_b32x, result_row + col_index,
406
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
422
407
  }
423
408
  }
424
409
  }
@@ -446,12 +431,12 @@ __arm_locally_streaming static void nk_euclideans_packed_bf16_sme_finalize_strea
446
431
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_ssve_(a_row, depth);
447
432
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
448
433
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
449
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
450
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
451
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
434
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
435
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
436
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
452
437
  svst1_f32(
453
- predicate_f32x, result_row + col_index,
454
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
438
+ predicate_b32x, result_row + col_index,
439
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
455
440
  }
456
441
  }
457
442
  }
@@ -467,8 +452,8 @@ NK_PUBLIC void nk_euclideans_packed_bf16_sme( //
467
452
  c_stride_elements);
468
453
  }
469
454
 
470
- __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_streaming_( //
471
- nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
455
+ __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_streaming_( //
456
+ nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
472
457
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
473
458
  // Phase 1: cache row norms on diagonal
474
459
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -477,8 +462,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_stre
477
462
  }
478
463
  // Phase 2: column-first post-processing
479
464
  nk_f32_t norms_cache[256];
480
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
481
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
465
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
466
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
482
467
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
483
468
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + col * stride_elements, depth);
484
469
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -487,11 +472,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_stre
487
472
  nk_f32_t *result_row = result + row_index * result_stride_elements;
488
473
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
489
474
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
490
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
491
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
492
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
493
- svst1_f32(predicate_f32x, result_row + col_index,
494
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
475
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
476
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
477
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
478
+ svst1_f32(predicate_b32x, result_row + col_index,
479
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
495
480
  target_norms_sq_f32x));
496
481
  }
497
482
  }
@@ -501,19 +486,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_stre
501
486
  result[row_index * result_stride_elements + row_index] = 0;
502
487
  }
503
488
 
504
- NK_PUBLIC void nk_angulars_symmetric_bf16_sme( //
505
- nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
506
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
507
- nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
508
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
509
- nk_dots_symmetric_bf16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
510
- row_start, row_count);
511
- nk_angulars_symmetric_bf16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
489
+ NK_PUBLIC void nk_angulars_symmetric_bf16_sme( //
490
+ nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
491
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
492
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
493
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
494
+ nk_dots_symmetric_bf16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
495
+ result_stride_elements, row_start, row_count);
496
+ nk_angulars_symmetric_bf16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
512
497
  result_stride_elements, row_start, row_count);
513
498
  }
514
499
 
515
- __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_streaming_( //
516
- nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
500
+ __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_streaming_( //
501
+ nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
517
502
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
518
503
  // Phase 1: cache row norms on diagonal
519
504
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -522,8 +507,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_st
522
507
  }
523
508
  // Phase 2: column-first post-processing
524
509
  nk_f32_t norms_cache[256];
525
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
526
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
510
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
511
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
527
512
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
528
513
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + col * stride_elements, depth);
529
514
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -532,11 +517,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_st
532
517
  nk_f32_t *result_row = result + row_index * result_stride_elements;
533
518
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
534
519
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
535
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
536
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
537
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
538
- svst1_f32(predicate_f32x, result_row + col_index,
539
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
520
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
521
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
522
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
523
+ svst1_f32(predicate_b32x, result_row + col_index,
524
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
540
525
  target_norms_sq_f32x));
541
526
  }
542
527
  }
@@ -546,20 +531,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_st
546
531
  result[row_index * result_stride_elements + row_index] = 0;
547
532
  }
548
533
 
549
- NK_PUBLIC void nk_euclideans_symmetric_bf16_sme( //
550
- nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
551
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
552
- nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
553
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
554
- nk_dots_symmetric_bf16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
555
- row_start, row_count);
556
- nk_euclideans_symmetric_bf16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
534
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_sme( //
535
+ nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
536
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
537
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
538
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
539
+ nk_dots_symmetric_bf16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
540
+ result_stride_elements, row_start, row_count);
541
+ nk_euclideans_symmetric_bf16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
557
542
  result_stride_elements, row_start, row_count);
558
543
  }
559
544
 
560
- #pragma endregion // Brain Float 16
545
+ #pragma endregion BF16 Floats
561
546
 
562
- #pragma region Quarter Precision E4M3
547
+ #pragma region E4M3 Floats
563
548
 
564
549
  __arm_locally_streaming static void nk_angulars_packed_e4m3_sme_finalize_streaming_( //
565
550
  nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -573,12 +558,12 @@ __arm_locally_streaming static void nk_angulars_packed_e4m3_sme_finalize_streami
573
558
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_ssve_(a_row, depth);
574
559
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
575
560
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
576
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
577
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
578
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
561
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
562
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
563
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
579
564
  svst1_f32(
580
- predicate_f32x, result_row + col_index,
581
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
565
+ predicate_b32x, result_row + col_index,
566
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
582
567
  }
583
568
  }
584
569
  }
@@ -606,12 +591,12 @@ __arm_locally_streaming static void nk_euclideans_packed_e4m3_sme_finalize_strea
606
591
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_ssve_(a_row, depth);
607
592
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
608
593
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
609
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
610
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
611
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
594
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
595
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
596
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
612
597
  svst1_f32(
613
- predicate_f32x, result_row + col_index,
614
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
598
+ predicate_b32x, result_row + col_index,
599
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
615
600
  }
616
601
  }
617
602
  }
@@ -627,8 +612,8 @@ NK_PUBLIC void nk_euclideans_packed_e4m3_sme( //
627
612
  c_stride_elements);
628
613
  }
629
614
 
630
- __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_streaming_( //
631
- nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
615
+ __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_streaming_( //
616
+ nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
632
617
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
633
618
  // Phase 1: cache row norms on diagonal
634
619
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -637,8 +622,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_stre
637
622
  }
638
623
  // Phase 2: column-first post-processing
639
624
  nk_f32_t norms_cache[256];
640
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
641
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
625
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
626
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
642
627
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
643
628
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + col * stride_elements, depth);
644
629
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -647,11 +632,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_stre
647
632
  nk_f32_t *result_row = result + row_index * result_stride_elements;
648
633
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
649
634
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
650
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
651
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
652
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
653
- svst1_f32(predicate_f32x, result_row + col_index,
654
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
635
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
636
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
637
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
638
+ svst1_f32(predicate_b32x, result_row + col_index,
639
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
655
640
  target_norms_sq_f32x));
656
641
  }
657
642
  }
@@ -661,19 +646,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_stre
661
646
  result[row_index * result_stride_elements + row_index] = 0;
662
647
  }
663
648
 
664
- NK_PUBLIC void nk_angulars_symmetric_e4m3_sme( //
665
- nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
666
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
667
- nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
668
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
669
- nk_dots_symmetric_e4m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
670
- row_start, row_count);
671
- nk_angulars_symmetric_e4m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
649
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_sme( //
650
+ nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
651
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
652
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
653
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
654
+ nk_dots_symmetric_e4m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
655
+ result_stride_elements, row_start, row_count);
656
+ nk_angulars_symmetric_e4m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
672
657
  result_stride_elements, row_start, row_count);
673
658
  }
674
659
 
675
- __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_streaming_( //
676
- nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
660
+ __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_streaming_( //
661
+ nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
677
662
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
678
663
  // Phase 1: cache row norms on diagonal
679
664
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -682,8 +667,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_st
682
667
  }
683
668
  // Phase 2: column-first post-processing
684
669
  nk_f32_t norms_cache[256];
685
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
686
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
670
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
671
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
687
672
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
688
673
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + col * stride_elements, depth);
689
674
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -692,11 +677,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_st
692
677
  nk_f32_t *result_row = result + row_index * result_stride_elements;
693
678
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
694
679
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
695
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
696
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
697
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
698
- svst1_f32(predicate_f32x, result_row + col_index,
699
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
680
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
681
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
682
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
683
+ svst1_f32(predicate_b32x, result_row + col_index,
684
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
700
685
  target_norms_sq_f32x));
701
686
  }
702
687
  }
@@ -706,20 +691,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_st
706
691
  result[row_index * result_stride_elements + row_index] = 0;
707
692
  }
708
693
 
709
- NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme( //
710
- nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
711
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
712
- nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
713
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
714
- nk_dots_symmetric_e4m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
715
- row_start, row_count);
716
- nk_euclideans_symmetric_e4m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
694
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme( //
695
+ nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
696
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
697
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
698
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
699
+ nk_dots_symmetric_e4m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
700
+ result_stride_elements, row_start, row_count);
701
+ nk_euclideans_symmetric_e4m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
717
702
  result_stride_elements, row_start, row_count);
718
703
  }
719
704
 
720
- #pragma endregion // Quarter Precision E4M3
705
+ #pragma endregion E4M3 Floats
721
706
 
722
- #pragma region Quarter Precision E5M2
707
+ #pragma region E5M2 Floats
723
708
 
724
709
  __arm_locally_streaming static void nk_angulars_packed_e5m2_sme_finalize_streaming_( //
725
710
  nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -733,12 +718,12 @@ __arm_locally_streaming static void nk_angulars_packed_e5m2_sme_finalize_streami
733
718
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_ssve_(a_row, depth);
734
719
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
735
720
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
736
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
737
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
738
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
721
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
722
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
723
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
739
724
  svst1_f32(
740
- predicate_f32x, result_row + col_index,
741
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
725
+ predicate_b32x, result_row + col_index,
726
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
742
727
  }
743
728
  }
744
729
  }
@@ -766,12 +751,12 @@ __arm_locally_streaming static void nk_euclideans_packed_e5m2_sme_finalize_strea
766
751
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_ssve_(a_row, depth);
767
752
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
768
753
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
769
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
770
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
771
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
754
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
755
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
756
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
772
757
  svst1_f32(
773
- predicate_f32x, result_row + col_index,
774
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
758
+ predicate_b32x, result_row + col_index,
759
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
775
760
  }
776
761
  }
777
762
  }
@@ -787,8 +772,8 @@ NK_PUBLIC void nk_euclideans_packed_e5m2_sme( //
787
772
  c_stride_elements);
788
773
  }
789
774
 
790
- __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_streaming_( //
791
- nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
775
+ __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_streaming_( //
776
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
792
777
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
793
778
  // Phase 1: cache row norms on diagonal
794
779
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -797,8 +782,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_stre
797
782
  }
798
783
  // Phase 2: column-first post-processing
799
784
  nk_f32_t norms_cache[256];
800
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
801
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
785
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
786
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
802
787
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
803
788
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + col * stride_elements, depth);
804
789
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -807,11 +792,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_stre
807
792
  nk_f32_t *result_row = result + row_index * result_stride_elements;
808
793
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
809
794
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
810
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
811
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
812
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
813
- svst1_f32(predicate_f32x, result_row + col_index,
814
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
795
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
796
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
797
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
798
+ svst1_f32(predicate_b32x, result_row + col_index,
799
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
815
800
  target_norms_sq_f32x));
816
801
  }
817
802
  }
@@ -821,19 +806,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_stre
821
806
  result[row_index * result_stride_elements + row_index] = 0;
822
807
  }
823
808
 
824
- NK_PUBLIC void nk_angulars_symmetric_e5m2_sme( //
825
- nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
826
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
827
- nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
828
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
829
- nk_dots_symmetric_e5m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
830
- row_start, row_count);
831
- nk_angulars_symmetric_e5m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
809
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_sme( //
810
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
811
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
812
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
813
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
814
+ nk_dots_symmetric_e5m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
815
+ result_stride_elements, row_start, row_count);
816
+ nk_angulars_symmetric_e5m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
832
817
  result_stride_elements, row_start, row_count);
833
818
  }
834
819
 
835
- __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_streaming_( //
836
- nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
820
+ __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_streaming_( //
821
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
837
822
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
838
823
  // Phase 1: cache row norms on diagonal
839
824
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -842,8 +827,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_st
842
827
  }
843
828
  // Phase 2: column-first post-processing
844
829
  nk_f32_t norms_cache[256];
845
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
846
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
830
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
831
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
847
832
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
848
833
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + col * stride_elements, depth);
849
834
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -852,11 +837,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_st
852
837
  nk_f32_t *result_row = result + row_index * result_stride_elements;
853
838
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
854
839
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
855
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
856
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
857
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
858
- svst1_f32(predicate_f32x, result_row + col_index,
859
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
840
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
841
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
842
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
843
+ svst1_f32(predicate_b32x, result_row + col_index,
844
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
860
845
  target_norms_sq_f32x));
861
846
  }
862
847
  }
@@ -866,20 +851,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_st
866
851
  result[row_index * result_stride_elements + row_index] = 0;
867
852
  }
868
853
 
869
- NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme( //
870
- nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
871
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
872
- nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
873
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
874
- nk_dots_symmetric_e5m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
875
- row_start, row_count);
876
- nk_euclideans_symmetric_e5m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
854
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme( //
855
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
856
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
857
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
858
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
859
+ nk_dots_symmetric_e5m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
860
+ result_stride_elements, row_start, row_count);
861
+ nk_euclideans_symmetric_e5m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
877
862
  result_stride_elements, row_start, row_count);
878
863
  }
879
864
 
880
- #pragma endregion // Quarter Precision E5M2
865
+ #pragma endregion E5M2 Floats
881
866
 
882
- #pragma region Micro Precision E2M3
867
+ #pragma region E2M3 Floats
883
868
 
884
869
  __arm_locally_streaming static void nk_angulars_packed_e2m3_sme_finalize_streaming_( //
885
870
  nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -893,12 +878,12 @@ __arm_locally_streaming static void nk_angulars_packed_e2m3_sme_finalize_streami
893
878
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_ssve_(a_row, depth);
894
879
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
895
880
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
896
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
897
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
898
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
881
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
882
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
883
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
899
884
  svst1_f32(
900
- predicate_f32x, result_row + col_index,
901
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
885
+ predicate_b32x, result_row + col_index,
886
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
902
887
  }
903
888
  }
904
889
  }
@@ -926,12 +911,12 @@ __arm_locally_streaming static void nk_euclideans_packed_e2m3_sme_finalize_strea
926
911
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_ssve_(a_row, depth);
927
912
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
928
913
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
929
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
930
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
931
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
914
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
915
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
916
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
932
917
  svst1_f32(
933
- predicate_f32x, result_row + col_index,
934
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
918
+ predicate_b32x, result_row + col_index,
919
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
935
920
  }
936
921
  }
937
922
  }
@@ -947,8 +932,8 @@ NK_PUBLIC void nk_euclideans_packed_e2m3_sme( //
947
932
  c_stride_elements);
948
933
  }
949
934
 
950
- __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_streaming_( //
951
- nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
935
+ __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_streaming_( //
936
+ nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
952
937
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
953
938
  // Phase 1: cache row norms on diagonal
954
939
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -957,8 +942,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_stre
957
942
  }
958
943
  // Phase 2: column-first post-processing
959
944
  nk_f32_t norms_cache[256];
960
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
961
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
945
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
946
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
962
947
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
963
948
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + col * stride_elements, depth);
964
949
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -967,11 +952,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_stre
967
952
  nk_f32_t *result_row = result + row_index * result_stride_elements;
968
953
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
969
954
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
970
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
971
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
972
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
973
- svst1_f32(predicate_f32x, result_row + col_index,
974
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
955
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
956
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
957
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
958
+ svst1_f32(predicate_b32x, result_row + col_index,
959
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
975
960
  target_norms_sq_f32x));
976
961
  }
977
962
  }
@@ -981,19 +966,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_stre
981
966
  result[row_index * result_stride_elements + row_index] = 0;
982
967
  }
983
968
 
984
- NK_PUBLIC void nk_angulars_symmetric_e2m3_sme( //
985
- nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
986
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
987
- nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
988
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
989
- nk_dots_symmetric_e2m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
990
- row_start, row_count);
991
- nk_angulars_symmetric_e2m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
969
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_sme( //
970
+ nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
971
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
972
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
973
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
974
+ nk_dots_symmetric_e2m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
975
+ result_stride_elements, row_start, row_count);
976
+ nk_angulars_symmetric_e2m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
992
977
  result_stride_elements, row_start, row_count);
993
978
  }
994
979
 
995
- __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_streaming_( //
996
- nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
980
+ __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_streaming_( //
981
+ nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
997
982
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
998
983
  // Phase 1: cache row norms on diagonal
999
984
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1002,8 +987,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_st
1002
987
  }
1003
988
  // Phase 2: column-first post-processing
1004
989
  nk_f32_t norms_cache[256];
1005
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1006
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
990
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
991
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1007
992
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1008
993
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + col * stride_elements, depth);
1009
994
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1012,11 +997,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_st
1012
997
  nk_f32_t *result_row = result + row_index * result_stride_elements;
1013
998
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
1014
999
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1015
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1016
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
1017
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
1018
- svst1_f32(predicate_f32x, result_row + col_index,
1019
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1000
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1001
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
1002
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
1003
+ svst1_f32(predicate_b32x, result_row + col_index,
1004
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1020
1005
  target_norms_sq_f32x));
1021
1006
  }
1022
1007
  }
@@ -1026,20 +1011,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_st
1026
1011
  result[row_index * result_stride_elements + row_index] = 0;
1027
1012
  }
1028
1013
 
1029
- NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme( //
1030
- nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1031
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1032
- nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
1033
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1034
- nk_dots_symmetric_e2m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1035
- row_start, row_count);
1036
- nk_euclideans_symmetric_e2m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1014
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme( //
1015
+ nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1016
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1017
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
1018
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1019
+ nk_dots_symmetric_e2m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
1020
+ result_stride_elements, row_start, row_count);
1021
+ nk_euclideans_symmetric_e2m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1037
1022
  result_stride_elements, row_start, row_count);
1038
1023
  }
1039
1024
 
1040
- #pragma endregion // Micro Precision E2M3
1025
+ #pragma endregion E2M3 Floats
1041
1026
 
1042
- #pragma region Micro Precision E3M2
1027
+ #pragma region E3M2 Floats
1043
1028
 
1044
1029
  __arm_locally_streaming static void nk_angulars_packed_e3m2_sme_finalize_streaming_( //
1045
1030
  nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -1053,12 +1038,12 @@ __arm_locally_streaming static void nk_angulars_packed_e3m2_sme_finalize_streami
1053
1038
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_ssve_(a_row, depth);
1054
1039
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
1055
1040
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1056
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1057
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
1058
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
1041
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1042
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
1043
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
1059
1044
  svst1_f32(
1060
- predicate_f32x, result_row + col_index,
1061
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1045
+ predicate_b32x, result_row + col_index,
1046
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1062
1047
  }
1063
1048
  }
1064
1049
  }
@@ -1086,12 +1071,12 @@ __arm_locally_streaming static void nk_euclideans_packed_e3m2_sme_finalize_strea
1086
1071
  nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_ssve_(a_row, depth);
1087
1072
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
1088
1073
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1089
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1090
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
1091
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
1074
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1075
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
1076
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, b_norms + col_index);
1092
1077
  svst1_f32(
1093
- predicate_f32x, result_row + col_index,
1094
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1078
+ predicate_b32x, result_row + col_index,
1079
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1095
1080
  }
1096
1081
  }
1097
1082
  }
@@ -1107,8 +1092,8 @@ NK_PUBLIC void nk_euclideans_packed_e3m2_sme( //
1107
1092
  c_stride_elements);
1108
1093
  }
1109
1094
 
1110
- __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_streaming_( //
1111
- nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1095
+ __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_streaming_( //
1096
+ nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1112
1097
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1113
1098
  // Phase 1: cache row norms on diagonal
1114
1099
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1117,8 +1102,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_stre
1117
1102
  }
1118
1103
  // Phase 2: column-first post-processing
1119
1104
  nk_f32_t norms_cache[256];
1120
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1121
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1105
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1106
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1122
1107
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1123
1108
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + col * stride_elements, depth);
1124
1109
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1127,11 +1112,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_stre
1127
1112
  nk_f32_t *result_row = result + row_index * result_stride_elements;
1128
1113
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
1129
1114
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1130
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1131
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
1132
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
1133
- svst1_f32(predicate_f32x, result_row + col_index,
1134
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1115
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1116
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
1117
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
1118
+ svst1_f32(predicate_b32x, result_row + col_index,
1119
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1135
1120
  target_norms_sq_f32x));
1136
1121
  }
1137
1122
  }
@@ -1141,19 +1126,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_stre
1141
1126
  result[row_index * result_stride_elements + row_index] = 0;
1142
1127
  }
1143
1128
 
1144
- NK_PUBLIC void nk_angulars_symmetric_e3m2_sme( //
1145
- nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1146
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1147
- nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
1148
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1149
- nk_dots_symmetric_e3m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1150
- row_start, row_count);
1151
- nk_angulars_symmetric_e3m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1129
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_sme( //
1130
+ nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1131
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1132
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
1133
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1134
+ nk_dots_symmetric_e3m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
1135
+ result_stride_elements, row_start, row_count);
1136
+ nk_angulars_symmetric_e3m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1152
1137
  result_stride_elements, row_start, row_count);
1153
1138
  }
1154
1139
 
1155
- __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_streaming_( //
1156
- nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1140
+ __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_streaming_( //
1141
+ nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1157
1142
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1158
1143
  // Phase 1: cache row norms on diagonal
1159
1144
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1162,8 +1147,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_st
1162
1147
  }
1163
1148
  // Phase 2: column-first post-processing
1164
1149
  nk_f32_t norms_cache[256];
1165
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1166
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1150
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1151
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1167
1152
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1168
1153
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + col * stride_elements, depth);
1169
1154
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1172,11 +1157,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_st
1172
1157
  nk_f32_t *result_row = result + row_index * result_stride_elements;
1173
1158
  svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
1174
1159
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1175
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1176
- svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
1177
- svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
1178
- svst1_f32(predicate_f32x, result_row + col_index,
1179
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1160
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1161
+ svfloat32_t dots_f32x = svld1_f32(predicate_b32x, result_row + col_index);
1162
+ svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_b32x, norms_cache + (col_index - chunk_start));
1163
+ svst1_f32(predicate_b32x, result_row + col_index,
1164
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1180
1165
  target_norms_sq_f32x));
1181
1166
  }
1182
1167
  }
@@ -1186,19 +1171,19 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_st
1186
1171
  result[row_index * result_stride_elements + row_index] = 0;
1187
1172
  }
1188
1173
 
1189
- NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme( //
1190
- nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1191
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1192
- nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
1193
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1194
- nk_dots_symmetric_e3m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1195
- row_start, row_count);
1196
- nk_euclideans_symmetric_e3m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1174
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme( //
1175
+ nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1176
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1177
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
1178
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1179
+ nk_dots_symmetric_e3m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
1180
+ result_stride_elements, row_start, row_count);
1181
+ nk_euclideans_symmetric_e3m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1197
1182
  result_stride_elements, row_start, row_count);
1198
1183
  }
1199
1184
 
1200
- #pragma endregion // Micro Precision E3M2
1201
- #pragma region Signed 8-bit Integers
1185
+ #pragma endregion E3M2 Floats
1186
+ #pragma region I8 Integers
1202
1187
 
1203
1188
  __arm_locally_streaming static void nk_angulars_packed_i8_sme_finalize_streaming_( //
1204
1189
  nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -1212,14 +1197,14 @@ __arm_locally_streaming static void nk_angulars_packed_i8_sme_finalize_streaming
1212
1197
  nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i8_ssve_(a_row, depth);
1213
1198
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
1214
1199
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1215
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1200
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1216
1201
  svfloat32_t dots_f32x = svcvt_f32_s32_x(
1217
- predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
1218
- svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
1219
- svld1_u32(predicate_f32x, b_norms + col_index));
1202
+ predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t const *)(result_row + col_index)));
1203
+ svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
1204
+ svld1_u32(predicate_b32x, b_norms + col_index));
1220
1205
  svst1_f32(
1221
- predicate_f32x, result_row + col_index,
1222
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1206
+ predicate_b32x, result_row + col_index,
1207
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1223
1208
  }
1224
1209
  }
1225
1210
  }
@@ -1248,14 +1233,14 @@ __arm_locally_streaming static void nk_euclideans_packed_i8_sme_finalize_streami
1248
1233
  nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i8_ssve_(a_row, depth);
1249
1234
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
1250
1235
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1251
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1236
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1252
1237
  svfloat32_t dots_f32x = svcvt_f32_s32_x(
1253
- predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
1254
- svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
1255
- svld1_u32(predicate_f32x, b_norms + col_index));
1238
+ predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t const *)(result_row + col_index)));
1239
+ svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
1240
+ svld1_u32(predicate_b32x, b_norms + col_index));
1256
1241
  svst1_f32(
1257
- predicate_f32x, result_row + col_index,
1258
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1242
+ predicate_b32x, result_row + col_index,
1243
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1259
1244
  }
1260
1245
  }
1261
1246
  }
@@ -1272,8 +1257,8 @@ NK_PUBLIC void nk_euclideans_packed_i8_sme( //
1272
1257
  c_stride_elements);
1273
1258
  }
1274
1259
 
1275
- __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_streaming_( //
1276
- nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1260
+ __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_streaming_( //
1261
+ nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1277
1262
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1278
1263
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1279
1264
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1282,8 +1267,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_stream
1282
1267
  }
1283
1268
  // Phase 2: column-first post-processing
1284
1269
  nk_u32_t norms_cache[256];
1285
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1286
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1270
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1271
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1287
1272
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1288
1273
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_ssve_(vectors + col * stride_elements, depth);
1289
1274
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1293,13 +1278,13 @@ __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_stream
1293
1278
  nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
1294
1279
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
1295
1280
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1296
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1281
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1297
1282
  svfloat32_t dots_f32x = svcvt_f32_s32_x(
1298
- predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
1283
+ predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t *)(result_row + col_index)));
1299
1284
  svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
1300
- predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
1301
- svst1_f32(predicate_f32x, result_row + col_index,
1302
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1285
+ predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
1286
+ svst1_f32(predicate_b32x, result_row + col_index,
1287
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1303
1288
  target_norms_sq_f32x));
1304
1289
  }
1305
1290
  }
@@ -1309,19 +1294,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_stream
1309
1294
  result[row_index * result_stride_elements + row_index] = 0;
1310
1295
  }
1311
1296
 
1312
- NK_PUBLIC void nk_angulars_symmetric_i8_sme( //
1313
- nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1314
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1315
- nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
1316
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1317
- nk_dots_symmetric_i8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
1297
+ NK_PUBLIC void nk_angulars_symmetric_i8_sme( //
1298
+ nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1299
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1300
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
1301
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1302
+ nk_dots_symmetric_i8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
1318
1303
  result_stride_elements, row_start, row_count);
1319
- nk_angulars_symmetric_i8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1304
+ nk_angulars_symmetric_i8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1320
1305
  result_stride_elements, row_start, row_count);
1321
1306
  }
1322
1307
 
1323
- __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_streaming_( //
1324
- nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1308
+ __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_streaming_( //
1309
+ nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1325
1310
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1326
1311
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1327
1312
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1330,8 +1315,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_stre
1330
1315
  }
1331
1316
  // Phase 2: column-first post-processing
1332
1317
  nk_u32_t norms_cache[256];
1333
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1334
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1318
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1319
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1335
1320
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1336
1321
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_ssve_(vectors + col * stride_elements, depth);
1337
1322
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1341,13 +1326,13 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_stre
1341
1326
  nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
1342
1327
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
1343
1328
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1344
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1329
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1345
1330
  svfloat32_t dots_f32x = svcvt_f32_s32_x(
1346
- predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
1331
+ predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t *)(result_row + col_index)));
1347
1332
  svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
1348
- predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
1349
- svst1_f32(predicate_f32x, result_row + col_index,
1350
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1333
+ predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
1334
+ svst1_f32(predicate_b32x, result_row + col_index,
1335
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1351
1336
  target_norms_sq_f32x));
1352
1337
  }
1353
1338
  }
@@ -1357,20 +1342,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_stre
1357
1342
  result[row_index * result_stride_elements + row_index] = 0;
1358
1343
  }
1359
1344
 
1360
- NK_PUBLIC void nk_euclideans_symmetric_i8_sme( //
1361
- nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1362
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1363
- nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
1364
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1365
- nk_dots_symmetric_i8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
1345
+ NK_PUBLIC void nk_euclideans_symmetric_i8_sme( //
1346
+ nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1347
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1348
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
1349
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1350
+ nk_dots_symmetric_i8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
1366
1351
  result_stride_elements, row_start, row_count);
1367
- nk_euclideans_symmetric_i8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1352
+ nk_euclideans_symmetric_i8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1368
1353
  result_stride_elements, row_start, row_count);
1369
1354
  }
1370
1355
 
1371
- #pragma endregion // Signed 8-bit Integers
1356
+ #pragma endregion I8 Integers
1372
1357
 
1373
- #pragma region Unsigned 8-bit Integers
1358
+ #pragma region U8 Integers
1374
1359
 
1375
1360
  __arm_locally_streaming static void nk_angulars_packed_u8_sme_finalize_streaming_( //
1376
1361
  nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -1384,14 +1369,14 @@ __arm_locally_streaming static void nk_angulars_packed_u8_sme_finalize_streaming
1384
1369
  nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u8_ssve_(a_row, depth);
1385
1370
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
1386
1371
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1387
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1372
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1388
1373
  svfloat32_t dots_f32x = svcvt_f32_u32_x(
1389
- predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
1390
- svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
1391
- svld1_u32(predicate_f32x, b_norms + col_index));
1374
+ predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t const *)(result_row + col_index)));
1375
+ svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
1376
+ svld1_u32(predicate_b32x, b_norms + col_index));
1392
1377
  svst1_f32(
1393
- predicate_f32x, result_row + col_index,
1394
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1378
+ predicate_b32x, result_row + col_index,
1379
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1395
1380
  }
1396
1381
  }
1397
1382
  }
@@ -1420,14 +1405,14 @@ __arm_locally_streaming static void nk_euclideans_packed_u8_sme_finalize_streami
1420
1405
  nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u8_ssve_(a_row, depth);
1421
1406
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
1422
1407
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1423
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1408
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1424
1409
  svfloat32_t dots_f32x = svcvt_f32_u32_x(
1425
- predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
1426
- svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
1427
- svld1_u32(predicate_f32x, b_norms + col_index));
1410
+ predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t const *)(result_row + col_index)));
1411
+ svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
1412
+ svld1_u32(predicate_b32x, b_norms + col_index));
1428
1413
  svst1_f32(
1429
- predicate_f32x, result_row + col_index,
1430
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1414
+ predicate_b32x, result_row + col_index,
1415
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1431
1416
  }
1432
1417
  }
1433
1418
  }
@@ -1444,8 +1429,8 @@ NK_PUBLIC void nk_euclideans_packed_u8_sme( //
1444
1429
  c_stride_elements);
1445
1430
  }
1446
1431
 
1447
- __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_streaming_( //
1448
- nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1432
+ __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_streaming_( //
1433
+ nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1449
1434
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1450
1435
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1451
1436
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1454,8 +1439,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_stream
1454
1439
  }
1455
1440
  // Phase 2: column-first post-processing
1456
1441
  nk_u32_t norms_cache[256];
1457
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1458
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1442
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1443
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1459
1444
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1460
1445
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_ssve_(vectors + col * stride_elements, depth);
1461
1446
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1465,13 +1450,13 @@ __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_stream
1465
1450
  nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
1466
1451
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
1467
1452
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1468
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1453
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1469
1454
  svfloat32_t dots_f32x = svcvt_f32_u32_x(
1470
- predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
1455
+ predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t *)(result_row + col_index)));
1471
1456
  svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
1472
- predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
1473
- svst1_f32(predicate_f32x, result_row + col_index,
1474
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1457
+ predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
1458
+ svst1_f32(predicate_b32x, result_row + col_index,
1459
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1475
1460
  target_norms_sq_f32x));
1476
1461
  }
1477
1462
  }
@@ -1481,19 +1466,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_stream
1481
1466
  result[row_index * result_stride_elements + row_index] = 0;
1482
1467
  }
1483
1468
 
1484
- NK_PUBLIC void nk_angulars_symmetric_u8_sme( //
1485
- nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1486
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1487
- nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
1488
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1489
- nk_dots_symmetric_u8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
1469
+ NK_PUBLIC void nk_angulars_symmetric_u8_sme( //
1470
+ nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1471
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1472
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
1473
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1474
+ nk_dots_symmetric_u8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
1490
1475
  result_stride_elements, row_start, row_count);
1491
- nk_angulars_symmetric_u8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1476
+ nk_angulars_symmetric_u8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1492
1477
  result_stride_elements, row_start, row_count);
1493
1478
  }
1494
1479
 
1495
- __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_streaming_( //
1496
- nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1480
+ __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_streaming_( //
1481
+ nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1497
1482
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1498
1483
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1499
1484
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1502,8 +1487,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_stre
1502
1487
  }
1503
1488
  // Phase 2: column-first post-processing
1504
1489
  nk_u32_t norms_cache[256];
1505
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1506
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1490
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1491
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1507
1492
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1508
1493
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_ssve_(vectors + col * stride_elements, depth);
1509
1494
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1513,13 +1498,13 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_stre
1513
1498
  nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
1514
1499
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
1515
1500
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1516
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1501
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1517
1502
  svfloat32_t dots_f32x = svcvt_f32_u32_x(
1518
- predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
1503
+ predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t *)(result_row + col_index)));
1519
1504
  svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
1520
- predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
1521
- svst1_f32(predicate_f32x, result_row + col_index,
1522
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1505
+ predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
1506
+ svst1_f32(predicate_b32x, result_row + col_index,
1507
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1523
1508
  target_norms_sq_f32x));
1524
1509
  }
1525
1510
  }
@@ -1529,20 +1514,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_stre
1529
1514
  result[row_index * result_stride_elements + row_index] = 0;
1530
1515
  }
1531
1516
 
1532
- NK_PUBLIC void nk_euclideans_symmetric_u8_sme( //
1533
- nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1534
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1535
- nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
1536
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1537
- nk_dots_symmetric_u8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
1517
+ NK_PUBLIC void nk_euclideans_symmetric_u8_sme( //
1518
+ nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1519
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1520
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
1521
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1522
+ nk_dots_symmetric_u8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
1538
1523
  result_stride_elements, row_start, row_count);
1539
- nk_euclideans_symmetric_u8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1524
+ nk_euclideans_symmetric_u8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1540
1525
  result_stride_elements, row_start, row_count);
1541
1526
  }
1542
1527
 
1543
- #pragma endregion // Unsigned 8-bit Integers
1528
+ #pragma endregion U8 Integers
1544
1529
 
1545
- #pragma region Nibble Signed Integers
1530
+ #pragma region I4 Integers
1546
1531
 
1547
1532
  __arm_locally_streaming static void nk_angulars_packed_i4_sme_finalize_streaming_( //
1548
1533
  nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -1556,14 +1541,14 @@ __arm_locally_streaming static void nk_angulars_packed_i4_sme_finalize_streaming
1556
1541
  nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i4_ssve_(a_row, depth);
1557
1542
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
1558
1543
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1559
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1544
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1560
1545
  svfloat32_t dots_f32x = svcvt_f32_s32_x(
1561
- predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
1562
- svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
1563
- svld1_u32(predicate_f32x, b_norms + col_index));
1546
+ predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t const *)(result_row + col_index)));
1547
+ svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
1548
+ svld1_u32(predicate_b32x, b_norms + col_index));
1564
1549
  svst1_f32(
1565
- predicate_f32x, result_row + col_index,
1566
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1550
+ predicate_b32x, result_row + col_index,
1551
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1567
1552
  }
1568
1553
  }
1569
1554
  }
@@ -1592,14 +1577,14 @@ __arm_locally_streaming static void nk_euclideans_packed_i4_sme_finalize_streami
1592
1577
  nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i4_ssve_(a_row, depth);
1593
1578
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
1594
1579
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1595
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1580
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1596
1581
  svfloat32_t dots_f32x = svcvt_f32_s32_x(
1597
- predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
1598
- svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
1599
- svld1_u32(predicate_f32x, b_norms + col_index));
1582
+ predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t const *)(result_row + col_index)));
1583
+ svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
1584
+ svld1_u32(predicate_b32x, b_norms + col_index));
1600
1585
  svst1_f32(
1601
- predicate_f32x, result_row + col_index,
1602
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1586
+ predicate_b32x, result_row + col_index,
1587
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1603
1588
  }
1604
1589
  }
1605
1590
  }
@@ -1616,8 +1601,8 @@ NK_PUBLIC void nk_euclideans_packed_i4_sme( //
1616
1601
  c_stride_elements);
1617
1602
  }
1618
1603
 
1619
- __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_streaming_( //
1620
- nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1604
+ __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_streaming_( //
1605
+ nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1621
1606
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1622
1607
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1623
1608
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1626,8 +1611,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_stream
1626
1611
  }
1627
1612
  // Phase 2: column-first post-processing
1628
1613
  nk_u32_t norms_cache[256];
1629
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1630
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1614
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1615
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1631
1616
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1632
1617
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i4_ssve_(vectors + col * stride_elements, depth);
1633
1618
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1637,13 +1622,13 @@ __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_stream
1637
1622
  nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
1638
1623
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
1639
1624
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1640
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1625
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1641
1626
  svfloat32_t dots_f32x = svcvt_f32_s32_x(
1642
- predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
1627
+ predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t *)(result_row + col_index)));
1643
1628
  svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
1644
- predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
1645
- svst1_f32(predicate_f32x, result_row + col_index,
1646
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1629
+ predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
1630
+ svst1_f32(predicate_b32x, result_row + col_index,
1631
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1647
1632
  target_norms_sq_f32x));
1648
1633
  }
1649
1634
  }
@@ -1653,19 +1638,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_stream
1653
1638
  result[row_index * result_stride_elements + row_index] = 0;
1654
1639
  }
1655
1640
 
1656
- NK_PUBLIC void nk_angulars_symmetric_i4_sme( //
1657
- nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1658
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1659
- nk_size_t const stride_elements = stride / sizeof(nk_i4x2_t);
1660
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1661
- nk_dots_symmetric_i4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
1641
+ NK_PUBLIC void nk_angulars_symmetric_i4_sme( //
1642
+ nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1643
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1644
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i4x2_t);
1645
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1646
+ nk_dots_symmetric_i4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
1662
1647
  result_stride_elements, row_start, row_count);
1663
- nk_angulars_symmetric_i4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1648
+ nk_angulars_symmetric_i4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1664
1649
  result_stride_elements, row_start, row_count);
1665
1650
  }
1666
1651
 
1667
- __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_streaming_( //
1668
- nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1652
+ __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_streaming_( //
1653
+ nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1669
1654
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1670
1655
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1671
1656
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1674,8 +1659,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_stre
1674
1659
  }
1675
1660
  // Phase 2: column-first post-processing
1676
1661
  nk_u32_t norms_cache[256];
1677
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1678
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1662
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1663
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1679
1664
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1680
1665
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i4_ssve_(vectors + col * stride_elements, depth);
1681
1666
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1685,13 +1670,13 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_stre
1685
1670
  nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
1686
1671
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
1687
1672
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1688
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1673
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1689
1674
  svfloat32_t dots_f32x = svcvt_f32_s32_x(
1690
- predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
1675
+ predicate_b32x, svld1_s32(predicate_b32x, (nk_i32_t *)(result_row + col_index)));
1691
1676
  svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
1692
- predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
1693
- svst1_f32(predicate_f32x, result_row + col_index,
1694
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1677
+ predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
1678
+ svst1_f32(predicate_b32x, result_row + col_index,
1679
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1695
1680
  target_norms_sq_f32x));
1696
1681
  }
1697
1682
  }
@@ -1701,20 +1686,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_stre
1701
1686
  result[row_index * result_stride_elements + row_index] = 0;
1702
1687
  }
1703
1688
 
1704
- NK_PUBLIC void nk_euclideans_symmetric_i4_sme( //
1705
- nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1706
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1707
- nk_size_t const stride_elements = stride / sizeof(nk_i4x2_t);
1708
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1709
- nk_dots_symmetric_i4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
1689
+ NK_PUBLIC void nk_euclideans_symmetric_i4_sme( //
1690
+ nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1691
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1692
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i4x2_t);
1693
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1694
+ nk_dots_symmetric_i4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
1710
1695
  result_stride_elements, row_start, row_count);
1711
- nk_euclideans_symmetric_i4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1696
+ nk_euclideans_symmetric_i4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1712
1697
  result_stride_elements, row_start, row_count);
1713
1698
  }
1714
1699
 
1715
- #pragma endregion // Nibble Signed Integers
1700
+ #pragma endregion Signed Integers
1716
1701
 
1717
- #pragma region Nibble Unsigned Integers
1702
+ #pragma region U4 Integers
1718
1703
 
1719
1704
  __arm_locally_streaming static void nk_angulars_packed_u4_sme_finalize_streaming_( //
1720
1705
  nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
@@ -1728,14 +1713,14 @@ __arm_locally_streaming static void nk_angulars_packed_u4_sme_finalize_streaming
1728
1713
  nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u4_ssve_(a_row, depth);
1729
1714
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
1730
1715
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1731
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1716
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1732
1717
  svfloat32_t dots_f32x = svcvt_f32_u32_x(
1733
- predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
1734
- svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
1735
- svld1_u32(predicate_f32x, b_norms + col_index));
1718
+ predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t const *)(result_row + col_index)));
1719
+ svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
1720
+ svld1_u32(predicate_b32x, b_norms + col_index));
1736
1721
  svst1_f32(
1737
- predicate_f32x, result_row + col_index,
1738
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1722
+ predicate_b32x, result_row + col_index,
1723
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1739
1724
  }
1740
1725
  }
1741
1726
  }
@@ -1764,14 +1749,14 @@ __arm_locally_streaming static void nk_euclideans_packed_u4_sme_finalize_streami
1764
1749
  nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u4_ssve_(a_row, depth);
1765
1750
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
1766
1751
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
1767
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
1752
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, columns);
1768
1753
  svfloat32_t dots_f32x = svcvt_f32_u32_x(
1769
- predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
1770
- svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
1771
- svld1_u32(predicate_f32x, b_norms + col_index));
1754
+ predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t const *)(result_row + col_index)));
1755
+ svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_b32x,
1756
+ svld1_u32(predicate_b32x, b_norms + col_index));
1772
1757
  svst1_f32(
1773
- predicate_f32x, result_row + col_index,
1774
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1758
+ predicate_b32x, result_row + col_index,
1759
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
1775
1760
  }
1776
1761
  }
1777
1762
  }
@@ -1788,8 +1773,8 @@ NK_PUBLIC void nk_euclideans_packed_u4_sme( //
1788
1773
  c_stride_elements);
1789
1774
  }
1790
1775
 
1791
- __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_streaming_( //
1792
- nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1776
+ __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_streaming_( //
1777
+ nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1793
1778
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1794
1779
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1795
1780
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1798,8 +1783,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_stream
1798
1783
  }
1799
1784
  // Phase 2: column-first post-processing
1800
1785
  nk_u32_t norms_cache[256];
1801
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1802
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1786
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1787
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1803
1788
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1804
1789
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u4_ssve_(vectors + col * stride_elements, depth);
1805
1790
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1809,13 +1794,13 @@ __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_stream
1809
1794
  nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
1810
1795
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
1811
1796
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1812
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1797
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1813
1798
  svfloat32_t dots_f32x = svcvt_f32_u32_x(
1814
- predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
1799
+ predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t *)(result_row + col_index)));
1815
1800
  svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
1816
- predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
1817
- svst1_f32(predicate_f32x, result_row + col_index,
1818
- nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1801
+ predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
1802
+ svst1_f32(predicate_b32x, result_row + col_index,
1803
+ nk_angulars_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1819
1804
  target_norms_sq_f32x));
1820
1805
  }
1821
1806
  }
@@ -1825,19 +1810,19 @@ __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_stream
1825
1810
  result[row_index * result_stride_elements + row_index] = 0;
1826
1811
  }
1827
1812
 
1828
- NK_PUBLIC void nk_angulars_symmetric_u4_sme( //
1829
- nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1830
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1831
- nk_size_t const stride_elements = stride / sizeof(nk_u4x2_t);
1832
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1833
- nk_dots_symmetric_u4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
1813
+ NK_PUBLIC void nk_angulars_symmetric_u4_sme( //
1814
+ nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1815
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1816
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u4x2_t);
1817
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1818
+ nk_dots_symmetric_u4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
1834
1819
  result_stride_elements, row_start, row_count);
1835
- nk_angulars_symmetric_u4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1820
+ nk_angulars_symmetric_u4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1836
1821
  result_stride_elements, row_start, row_count);
1837
1822
  }
1838
1823
 
1839
- __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_streaming_( //
1840
- nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
1824
+ __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_streaming_( //
1825
+ nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1841
1826
  nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1842
1827
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1843
1828
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1846,8 +1831,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_stre
1846
1831
  }
1847
1832
  // Phase 2: column-first post-processing
1848
1833
  nk_u32_t norms_cache[256];
1849
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
1850
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
1834
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
1835
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
1851
1836
  for (nk_size_t col = chunk_start; col < chunk_end; ++col)
1852
1837
  norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u4_ssve_(vectors + col * stride_elements, depth);
1853
1838
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -1857,13 +1842,13 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_stre
1857
1842
  nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
1858
1843
  svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
1859
1844
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
1860
- svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
1845
+ svbool_t predicate_b32x = svwhilelt_b32_u64(col_index, chunk_end);
1861
1846
  svfloat32_t dots_f32x = svcvt_f32_u32_x(
1862
- predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
1847
+ predicate_b32x, svld1_u32(predicate_b32x, (nk_u32_t *)(result_row + col_index)));
1863
1848
  svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
1864
- predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
1865
- svst1_f32(predicate_f32x, result_row + col_index,
1866
- nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
1849
+ predicate_b32x, svld1_u32(predicate_b32x, norms_cache + (col_index - chunk_start)));
1850
+ svst1_f32(predicate_b32x, result_row + col_index,
1851
+ nk_euclideans_from_dot_f32x_ssve_(predicate_b32x, dots_f32x, query_norm_sq_f32x,
1867
1852
  target_norms_sq_f32x));
1868
1853
  }
1869
1854
  }
@@ -1873,18 +1858,18 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_stre
1873
1858
  result[row_index * result_stride_elements + row_index] = 0;
1874
1859
  }
1875
1860
 
1876
- NK_PUBLIC void nk_euclideans_symmetric_u4_sme( //
1877
- nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1878
- nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1879
- nk_size_t const stride_elements = stride / sizeof(nk_u4x2_t);
1880
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1881
- nk_dots_symmetric_u4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
1861
+ NK_PUBLIC void nk_euclideans_symmetric_u4_sme( //
1862
+ nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1863
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1864
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u4x2_t);
1865
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1866
+ nk_dots_symmetric_u4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
1882
1867
  result_stride_elements, row_start, row_count);
1883
- nk_euclideans_symmetric_u4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1868
+ nk_euclideans_symmetric_u4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1884
1869
  result_stride_elements, row_start, row_count);
1885
1870
  }
1886
1871
 
1887
- #pragma endregion // Nibble Unsigned Integers
1872
+ #pragma endregion Unsigned Integers
1888
1873
 
1889
1874
  #if defined(__clang__)
1890
1875
  #pragma clang attribute pop