numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1901 @@
1
+ /**
2
+ * @brief Batched Spatial Distances for ARM SME.
3
+ * @file include/numkong/spatials/sme.h
4
+ * @author Ash Vardanian
5
+ * @date February 23, 2026
6
+ *
7
+ * @sa include/numkong/spatials.h
8
+ */
9
+ #ifndef NK_SPATIALS_SME_H
10
+ #define NK_SPATIALS_SME_H
11
+
12
+ #if NK_TARGET_ARM_
13
+ #if NK_TARGET_SME
14
+
15
+ #include "numkong/dots/serial.h"
16
+ #include "numkong/dots/sme.h"
17
+
18
+ #if defined(__cplusplus)
19
+ extern "C" {
20
+ #endif
21
+
22
+ #if defined(__clang__)
23
+ #pragma clang attribute push(__attribute__((target("sme,sve"))), apply_to = function)
24
+ #elif defined(__GNUC__)
25
+ #pragma GCC push_options
26
+ #pragma GCC target("+sme")
27
+ #endif
28
+
29
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_f16_ssve_(nk_f16_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    // Sum of squares of `count` half-precision values, accumulated in f32.
    // Consumes one f32-vector's worth of elements (svcntw) per iteration.
    svfloat32_t accumulator_f32x = svdup_f32(0.0f);
    nk_size_t const vector_length = svcntw();
    for (nk_size_t i = 0; i < count; i += vector_length) {
        svbool_t predicate_f32x = svwhilelt_b32_u64(i, count);
        // Load consecutive 16-bit lanes, then move each into the bottom half
        // of a 32-bit container with `svunpklo`, so the widening
        // `svcvt_f32_f16` (which converts only the even-indexed, i.e.
        // bottom-half, f16 lanes) sees consecutive elements. Converting
        // straight from the 16-bit load would square only every other element
        // while the loop advances by svcntw(), double-counting even offsets
        // and skipping odd ones - this mirrors the bf16 variant below.
        svuint16_t raw_u16x = svld1_u16(svwhilelt_b16_u64(i, count), (nk_u16_t const *)data + i);
        svfloat16_t values_f16x = svreinterpret_f16_u32(svunpklo_u32(raw_u16x));
        svfloat32_t values_f32x = svcvt_f32_f16_x(predicate_f32x, values_f16x);
        accumulator_f32x = svmla_f32_x(predicate_f32x, accumulator_f32x, values_f32x, values_f32x);
    }
    return svaddv_f32(svptrue_b32(), accumulator_f32x);
}
40
+
41
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    // Sum of squares of `count` bf16 values, accumulated in f32.
    svfloat32_t accumulator_f32x = svdup_f32(0.0f);
    nk_size_t const vector_length = svcntw(); // f32 lanes consumed per iteration
    for (nk_size_t i = 0; i < count; i += vector_length) {
        svbool_t predicate_f32x = svwhilelt_b32_u64(i, count);
        // Load raw 16-bit payloads; lanes past `count` are zeroed by the
        // predicate, so their squares contribute nothing.
        svuint16_t raw_u16x = svld1_u16(svwhilelt_b16_u64(i, count), (nk_u16_t const *)data + i);
        // bf16 -> f32: widen the low half into 32-bit containers, then shift
        // the payload into the top 16 bits (a bf16 is the upper half of the
        // equivalent f32 bit pattern).
        svfloat32_t values_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_f32x, svunpklo_u32(raw_u16x), 16));
        accumulator_f32x = svmla_f32_x(predicate_f32x, accumulator_f32x, values_f32x, values_f32x);
    }
    return svaddv_f32(svptrue_b32(), accumulator_f32x);
}
52
+
53
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e4m3_ssve_(nk_e4m3_t const *data, nk_size_t count) NK_STREAMING_ {
    // Sum of squares of `count` e4m3 (FP8) values: decode each byte to f16 via
    // the shared helper, widen to f32 in two halves (even/odd lanes), and
    // accumulate with fused multiply-adds.
    svfloat32_t accumulator_lo_f32x = svdup_f32(0.0f);
    svfloat32_t accumulator_hi_f32x = svdup_f32(0.0f);
    // 8-entry lookup table consumed by the decoder for subnormal e4m3 inputs.
    svuint16_t subnorm_lut_u16x = svld1_u16(svwhilelt_b16(0u, 8u), nk_e4m3_subnorm_f16_lut_);
    nk_size_t const vector_length = svcnth();      // f16 lanes per iteration
    nk_size_t const half_vector_length = svcntw(); // f32 lanes per half
    for (nk_size_t i = 0; i < count; i += vector_length) {
        nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
        svbool_t predicate_i8x = svwhilelt_b8_u64(0u, batch_size);
        svbool_t predicate_f16x = svwhilelt_b16_u64(0u, batch_size);
        svuint8_t raw_u8x = svld1_u8(predicate_i8x, (nk_u8_t const *)data + i);
        svfloat16_t values_f16x = nk_e4m3x_to_f16x_ssve_(predicate_f16x, raw_u8x, subnorm_lut_u16x);

        // Even-indexed f16 lanes -> f32. The predicated load zeroed the tail,
        // so any over-active lanes here square zeros and add nothing.
        svbool_t predicate_lo_f32x = svwhilelt_b32_u64(0u, batch_size);
        svfloat32_t values_lo_f32x = svcvt_f32_f16_x(predicate_lo_f32x, values_f16x);
        accumulator_lo_f32x = svmla_f32_m(predicate_lo_f32x, accumulator_lo_f32x, values_lo_f32x, values_lo_f32x);

        // Odd-indexed f16 lanes -> f32 via the "convert long top" form.
        svbool_t predicate_hi_f32x = svwhilelt_b32_u64(half_vector_length, batch_size);
        svfloat32_t values_hi_f32x = svcvtlt_f32_f16_x(predicate_hi_f32x, values_f16x);
        accumulator_hi_f32x = svmla_f32_m(predicate_hi_f32x, accumulator_hi_f32x, values_hi_f32x, values_hi_f32x);
    }
    // Both half-accumulators are reduced and combined into a scalar result.
    return svaddv_f32(svptrue_b32(), accumulator_lo_f32x) + svaddv_f32(svptrue_b32(), accumulator_hi_f32x);
}
76
+
77
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e5m2_ssve_(nk_e5m2_t const *data, nk_size_t count) NK_STREAMING_ {
    // Sum of squares of `count` e5m2 (FP8) values: decode bytes to f16, widen
    // to f32 in even/odd halves, and accumulate with fused multiply-adds.
    svfloat32_t accumulator_lo_f32x = svdup_f32(0.0f);
    svfloat32_t accumulator_hi_f32x = svdup_f32(0.0f);
    nk_size_t const vector_length = svcnth();      // f16 lanes per iteration
    nk_size_t const half_vector_length = svcntw(); // f32 lanes per half
    for (nk_size_t i = 0; i < count; i += vector_length) {
        nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
        svbool_t predicate_i8x = svwhilelt_b8_u64(0u, batch_size);
        svbool_t predicate_f16x = svwhilelt_b16_u64(0u, batch_size);
        svuint8_t raw_u8x = svld1_u8(predicate_i8x, (nk_u8_t const *)data + i);
        svfloat16_t values_f16x = nk_e5m2x_to_f16x_ssve_(predicate_f16x, raw_u8x);

        // Even-indexed f16 lanes -> f32; the predicated load zeroed the tail,
        // so inactive lanes add nothing.
        svbool_t predicate_lo_f32x = svwhilelt_b32_u64(0u, batch_size);
        svfloat32_t values_lo_f32x = svcvt_f32_f16_x(predicate_lo_f32x, values_f16x);
        accumulator_lo_f32x = svmla_f32_m(predicate_lo_f32x, accumulator_lo_f32x, values_lo_f32x, values_lo_f32x);

        // Odd-indexed f16 lanes -> f32 via the "convert long top" form.
        svbool_t predicate_hi_f32x = svwhilelt_b32_u64(half_vector_length, batch_size);
        svfloat32_t values_hi_f32x = svcvtlt_f32_f16_x(predicate_hi_f32x, values_f16x);
        accumulator_hi_f32x = svmla_f32_m(predicate_hi_f32x, accumulator_hi_f32x, values_hi_f32x, values_hi_f32x);
    }
    return svaddv_f32(svptrue_b32(), accumulator_lo_f32x) + svaddv_f32(svptrue_b32(), accumulator_hi_f32x);
}
99
+
100
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    // Sum of squares of `count` e2m3 values, accumulated exactly in integers.
    // The decoder presumably maps each e2m3 to value*16 as a signed byte -
    // hence the final /256 on the squared sum; confirm against
    // nk_e2m3x_to_i8x_ssve_.
    svint64_t accumulator_i64x = svdup_s64(0);
    nk_size_t const vector_length = svcntd(); // only the low svcntd() lanes survive the unpack chain
    for (nk_size_t i = 0; i < count; i += vector_length) {
        svbool_t predicate_i64x = svwhilelt_b64_u64(i, count);
        svuint8_t raw_u8x = svld1_u8(svwhilelt_b8_u64(i, count), (nk_u8_t const *)data + i);
        svint8_t values_i8x = nk_e2m3x_to_i8x_ssve_(svwhilelt_b8_u64(i, count), raw_u8x);
        svint16_t values_i16x = svunpklo_s16(values_i8x);
        svint16_t squares_i16x = svmul_s16_z(svwhilelt_b16_u64(i, count), values_i16x, values_i16x);
        // Double-unpack keeps only the lowest svcntd() squares; the loop
        // stride matches, so no elements are dropped.
        svint64_t squares_i64x = svunpklo_s64(svunpklo_s32(squares_i16x));
        accumulator_i64x = svadd_s64_m(predicate_i64x, accumulator_i64x, squares_i64x);
    }
    return (nk_f32_t)svaddv_s64(svptrue_b64(), accumulator_i64x) / 256.0f;
}
114
+
115
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e3m2_ssve_(nk_e3m2_t const *data, nk_size_t count) NK_STREAMING_ {
    // Sum of squares of `count` e3m2 values: decode bytes to f16, widen to f32
    // in even/odd halves, and accumulate with fused multiply-adds.
    svfloat32_t accumulator_lo_f32x = svdup_f32(0.0f);
    svfloat32_t accumulator_hi_f32x = svdup_f32(0.0f);
    nk_size_t const vector_length = svcnth();      // f16 lanes per iteration
    nk_size_t const half_vector_length = svcntw(); // f32 lanes per half
    for (nk_size_t i = 0; i < count; i += vector_length) {
        nk_size_t const batch_size = (i + vector_length < count) ? vector_length : (count - i);
        svbool_t predicate_i8x = svwhilelt_b8_u64(0u, batch_size);
        svbool_t predicate_f16x = svwhilelt_b16_u64(0u, batch_size);
        svuint8_t raw_u8x = svld1_u8(predicate_i8x, (nk_u8_t const *)data + i);
        svfloat16_t values_f16x = nk_e3m2x_to_f16x_ssve_(predicate_f16x, raw_u8x);

        // Even-indexed f16 lanes -> f32; the predicated load zeroed the tail,
        // so inactive lanes add nothing.
        svbool_t predicate_lo_f32x = svwhilelt_b32_u64(0u, batch_size);
        svfloat32_t values_lo_f32x = svcvt_f32_f16_x(predicate_lo_f32x, values_f16x);
        accumulator_lo_f32x = svmla_f32_m(predicate_lo_f32x, accumulator_lo_f32x, values_lo_f32x, values_lo_f32x);

        // Odd-indexed f16 lanes -> f32 via the "convert long top" form.
        svbool_t predicate_hi_f32x = svwhilelt_b32_u64(half_vector_length, batch_size);
        svfloat32_t values_hi_f32x = svcvtlt_f32_f16_x(predicate_hi_f32x, values_f16x);
        accumulator_hi_f32x = svmla_f32_m(predicate_hi_f32x, accumulator_hi_f32x, values_hi_f32x, values_hi_f32x);
    }
    return svaddv_f32(svptrue_b32(), accumulator_lo_f32x) + svaddv_f32(svptrue_b32(), accumulator_hi_f32x);
}
137
+
138
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    // Sum of squares of `count` signed bytes. Each square fits i16 (max
    // 127^2); accumulation is exact in i64, narrowed to u32 on return.
    // NOTE(review): the u32 narrowing can wrap once the true sum exceeds
    // 2^32 - 1 (worst case around count > ~266k) - confirm callers bound
    // `count` accordingly.
    svint64_t accumulator_i64x = svdup_s64(0);
    nk_size_t const vector_length = svcntd(); // only the low svcntd() lanes survive the unpack chain
    for (nk_size_t i = 0; i < count; i += vector_length) {
        svbool_t predicate_i64x = svwhilelt_b64_u64(i, count);
        svint8_t loaded_i8x = svld1_s8(svwhilelt_b8_u64(i, count), data + i);
        svint16_t values_i16x = svunpklo_s16(loaded_i8x);
        svint16_t squares_i16x = svmul_s16_z(svwhilelt_b16_u64(i, count), values_i16x, values_i16x);
        // Double-unpack keeps only the lowest svcntd() squares; the loop
        // stride matches, so no elements are dropped.
        svint64_t squares_i64x = svunpklo_s64(svunpklo_s32(squares_i16x));
        accumulator_i64x = svadd_s64_m(predicate_i64x, accumulator_i64x, squares_i64x);
    }
    return (nk_u32_t)svaddv_s64(svptrue_b64(), accumulator_i64x);
}
151
+
152
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    // Sum of squares of `count` unsigned bytes. Each square fits u16 (max
    // 255^2); accumulation is exact in u64, narrowed to u32 on return.
    // NOTE(review): the u32 narrowing can wrap for large `count` - confirm
    // callers bound it.
    svuint64_t accumulator_u64x = svdup_u64(0);
    nk_size_t const vector_length = svcntd(); // only the low svcntd() lanes survive the unpack chain
    for (nk_size_t i = 0; i < count; i += vector_length) {
        svbool_t predicate_u64x = svwhilelt_b64_u64(i, count);
        svuint8_t raw_u8x = svld1_u8(svwhilelt_b8_u64(i, count), data + i);
        svuint16_t values_u16x = svunpklo_u16(raw_u8x);
        svuint16_t squares_u16x = svmul_u16_z(svwhilelt_b16_u64(i, count), values_u16x, values_u16x);
        // Double-unpack keeps only the lowest svcntd() squares; the loop
        // stride matches, so no elements are dropped.
        svuint64_t squares_u64x = svunpklo_u64(svunpklo_u32(squares_u16x));
        accumulator_u64x = svadd_u64_m(predicate_u64x, accumulator_u64x, squares_u64x);
    }
    return (nk_u32_t)svaddv_u64(svptrue_b64(), accumulator_u64x);
}
165
+
166
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    // Sum of squares of `count` signed 4-bit values packed two per byte,
    // decoded as signed nibbles in [-8, 7] before squaring.
    svint64_t accumulator_i64x = svdup_s64(0);
    nk_u8_t const *bytes = (nk_u8_t const *)data;
    nk_size_t const byte_count = (count + 1) / 2; // round up for an odd count
    nk_size_t const vector_length = svcntd();     // bytes consumed per iteration
    for (nk_size_t i = 0; i < byte_count; i += vector_length) {
        svbool_t predicate_u8x = svwhilelt_b8_u64(i, byte_count);
        svuint8_t packed_u8x = svld1_u8(predicate_u8x, bytes + i);
        svuint8_t low_u8x = svand_n_u8_x(predicate_u8x, packed_u8x, 0x0F);
        svuint8_t high_u8x = svlsr_n_u8_x(predicate_u8x, packed_u8x, 4);
        // Sign-extend 4-bit to 8-bit: shift left 4, arithmetic shift right 4
        svint8_t low_i8x = svasr_n_s8_x(predicate_u8x, svreinterpret_s8_u8(svlsl_n_u8_x(predicate_u8x, low_u8x, 4)), 4);
        svint8_t high_i8x = svasr_n_s8_x(predicate_u8x, svreinterpret_s8_u8(svlsl_n_u8_x(predicate_u8x, high_u8x, 4)),
                                         4);
        // Widen to i16, square, sum per byte
        svbool_t predicate_i16x = svwhilelt_b16_u64(i, byte_count);
        svint16_t low_i16x = svunpklo_s16(low_i8x);
        svint16_t high_i16x = svunpklo_s16(high_i8x);
        svint16_t squares_low_i16x = svmul_s16_z(predicate_i16x, low_i16x, low_i16x);
        svint16_t squares_high_i16x = svmul_s16_z(predicate_i16x, high_i16x, high_i16x);
        svint16_t sum_i16x = svadd_s16_z(predicate_i16x, squares_low_i16x, squares_high_i16x);
        // Double-unpack keeps only the lowest svcntd() per-byte sums; the
        // loop stride matches, so no bytes are dropped.
        svbool_t predicate_i64x = svwhilelt_b64_u64(i, byte_count);
        svint64_t sum_i64x = svunpklo_s64(svunpklo_s32(sum_i16x));
        accumulator_i64x = svadd_s64_m(predicate_i64x, accumulator_i64x, sum_i64x);
    }
    // NOTE(review): for odd `count` the final byte's high nibble is still
    // squared; callers presumably zero-pad the last byte - confirm.
    return (nk_u32_t)svaddv_s64(svptrue_b64(), accumulator_i64x);
}
193
+
194
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
    // Sum of squares of `count` unsigned 4-bit values packed two per byte,
    // decoded as nibbles in [0, 15] before squaring.
    svuint64_t accumulator_u64x = svdup_u64(0);
    nk_u8_t const *bytes = (nk_u8_t const *)data;
    nk_size_t const byte_count = (count + 1) / 2; // round up for an odd count
    nk_size_t const vector_length = svcntd();     // bytes consumed per iteration
    for (nk_size_t i = 0; i < byte_count; i += vector_length) {
        svbool_t predicate_u8x = svwhilelt_b8_u64(i, byte_count);
        svuint8_t packed_u8x = svld1_u8(predicate_u8x, bytes + i);
        svuint8_t low_u8x = svand_n_u8_x(predicate_u8x, packed_u8x, 0x0F);
        svuint8_t high_u8x = svlsr_n_u8_x(predicate_u8x, packed_u8x, 4);
        // Widen to u16, square, sum per byte
        svbool_t predicate_u16x = svwhilelt_b16_u64(i, byte_count);
        svuint16_t low_u16x = svunpklo_u16(low_u8x);
        svuint16_t high_u16x = svunpklo_u16(high_u8x);
        svuint16_t squares_low_u16x = svmul_u16_z(predicate_u16x, low_u16x, low_u16x);
        svuint16_t squares_high_u16x = svmul_u16_z(predicate_u16x, high_u16x, high_u16x);
        svuint16_t sum_u16x = svadd_u16_z(predicate_u16x, squares_low_u16x, squares_high_u16x);
        // Double-unpack keeps only the lowest svcntd() per-byte sums; the
        // loop stride matches, so no bytes are dropped.
        svbool_t predicate_u64x = svwhilelt_b64_u64(i, byte_count);
        svuint64_t sum_u64x = svunpklo_u64(svunpklo_u32(sum_u16x));
        accumulator_u64x = svadd_u64_m(predicate_u64x, accumulator_u64x, sum_u64x);
    }
    // NOTE(review): for odd `count` the final byte's high nibble is still
    // squared; callers presumably zero-pad the last byte - confirm.
    return (nk_u32_t)svaddv_u64(svptrue_b64(), accumulator_u64x);
}
217
+
218
NK_PUBLIC svfloat32_t nk_angulars_from_dot_f32x_ssve_(svbool_t predicate_f32x, svfloat32_t dots_f32x,
                                                      svfloat32_t query_norm_sq_f32x,
                                                      svfloat32_t target_norms_sq_f32x) NK_STREAMING_COMPATIBLE_ {
    // Angular (cosine) distance per lane: 1 - dot / sqrt(|q|^2 * |t|^2),
    // clamped below at zero to absorb rounding noise.
    svfloat32_t squared_norms_product_f32x = svmul_f32_x(predicate_f32x, query_norm_sq_f32x, target_norms_sq_f32x);
    // Reciprocal square-root estimate, sharpened by two Newton-Raphson steps.
    svfloat32_t inv_norm_f32x = svrsqrte_f32(squared_norms_product_f32x);
    for (int refinement = 0; refinement < 2; ++refinement) {
        svfloat32_t step_f32x = svrsqrts_f32(
            svmul_f32_x(predicate_f32x, squared_norms_product_f32x, inv_norm_f32x), inv_norm_f32x);
        inv_norm_f32x = svmul_f32_x(predicate_f32x, inv_norm_f32x, step_f32x);
    }
    svfloat32_t cosine_f32x = svmul_f32_x(predicate_f32x, dots_f32x, inv_norm_f32x);
    svfloat32_t distance_f32x = svsub_f32_x(predicate_f32x, svdup_n_f32(1.0f), cosine_f32x);
    return svmax_f32_x(predicate_f32x, distance_f32x, svdup_n_f32(0.0f));
}
231
+
232
NK_PUBLIC svfloat32_t nk_euclideans_from_dot_f32x_ssve_(svbool_t predicate_f32x, svfloat32_t dots_f32x,
                                                        svfloat32_t query_norm_sq_f32x,
                                                        svfloat32_t target_norms_sq_f32x) NK_STREAMING_COMPATIBLE_ {
    // Euclidean distance per lane via |a-b|^2 = |a|^2 + |b|^2 - 2*a.b,
    // clamped below at zero before the square root to absorb rounding noise.
    svfloat32_t twice_dots_f32x = svmul_f32_x(predicate_f32x, svdup_n_f32(2.0f), dots_f32x);
    svfloat32_t norms_total_f32x = svadd_f32_x(predicate_f32x, query_norm_sq_f32x, target_norms_sq_f32x);
    svfloat32_t squared_distance_f32x = svsub_f32_x(predicate_f32x, norms_total_f32x, twice_dots_f32x);
    squared_distance_f32x = svmax_f32_x(predicate_f32x, squared_distance_f32x, svdup_n_f32(0.0f));
    return svsqrt_f32_x(predicate_f32x, squared_distance_f32x);
}
241
+
242
+ #pragma region Half Precision Floats
243
+
244
__arm_locally_streaming static void nk_angulars_packed_f16_sme_finalize_streaming_( //
    nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    // Second pass after the SME dot-product kernel: rewrites each raw dot in
    // `c` as an angular distance, using per-column squared norms stored in the
    // packed `b` blob (located via the header's norms_offset) and a freshly
    // computed per-row query norm.
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_f16_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared L2 norm of the query row, broadcast across all f32 lanes.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            // In-place conversion: dot -> angular distance.
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
265
+
266
NK_PUBLIC void nk_angulars_packed_f16_sme( //
    nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
    // Compute raw dot products with the SME packed kernel, then convert every
    // entry of `c` into an angular distance in a second streaming pass.
    nk_size_t const a_elements_per_row = a_stride_in_bytes / sizeof(nk_f16_t);
    nk_size_t const c_elements_per_row = c_stride_in_bytes / sizeof(nk_f32_t);
    nk_dots_packed_f16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_elements_per_row, c_elements_per_row);
    nk_angulars_packed_f16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_elements_per_row,
                                                   c_elements_per_row);
}
276
+
277
__arm_locally_streaming static void nk_euclideans_packed_f16_sme_finalize_streaming_( //
    nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    // Second pass after the SME dot-product kernel: rewrites each raw dot in
    // `c` as a Euclidean distance, using per-column squared norms stored in
    // the packed `b` blob (located via the header's norms_offset) and a
    // freshly computed per-row query norm.
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_f16_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Squared L2 norm of the query row, broadcast across all f32 lanes.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_f16_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            // In-place conversion: dot -> Euclidean distance.
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
298
+
299
+ NK_PUBLIC void nk_euclideans_packed_f16_sme( //
300
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
301
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
302
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
303
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
304
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
305
+ nk_dots_packed_f16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
306
+ nk_euclideans_packed_f16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
307
+ c_stride_elements);
308
+ }
309
+
310
// In-place finalizer for the symmetric all-pairs kernel: rewrites the raw dot
// products in `result` rows [row_start, row_start + row_count) as angular
// distances. Only the strict upper triangle (col > row) is rewritten, and the
// diagonal of each processed row ends up 0.
__arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_streaming_( //
    nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal
    // (the diagonal is used as scratch for |row|^2 and re-zeroed in Phase 3)
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f16_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed per 256-wide chunk instead of being read off
    // the diagonal — NOTE(review): presumably because columns outside this
    // worker's row range may not have had Phase 1 run yet; confirm against the
    // caller's row partitioning.
    nk_f32_t norms_cache[256]; // capacity must match the 256-column chunk step below
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Restrict to the strict upper triangle within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
343
+
344
+ NK_PUBLIC void nk_angulars_symmetric_f16_sme( //
345
+ nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
346
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
347
+ nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
348
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
349
+ nk_dots_symmetric_f16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
350
+ row_start, row_count);
351
+ nk_angulars_symmetric_f16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
352
+ result_stride_elements, row_start, row_count);
353
+ }
354
+
355
// In-place finalizer for the symmetric all-pairs kernel: rewrites the raw dot
// products in `result` rows [row_start, row_start + row_count) as Euclidean
// distances via nk_euclideans_from_dot_f32x_ssve_. Only the strict upper
// triangle (col > row) is rewritten; processed diagonals end up 0.
__arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_streaming_( //
    nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal
    // (diagonal serves as scratch for |row|^2; re-zeroed in Phase 3)
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_f16_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed per 256-wide chunk — NOTE(review): presumably
    // because out-of-range columns may belong to another worker; confirm.
    nk_f32_t norms_cache[256]; // capacity must match the 256-column chunk step below
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Restrict to the strict upper triangle within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
388
+
389
+ NK_PUBLIC void nk_euclideans_symmetric_f16_sme( //
390
+ nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
391
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
392
+ nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
393
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
394
+ nk_dots_symmetric_f16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
395
+ row_start, row_count);
396
+ nk_euclideans_symmetric_f16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
397
+ result_stride_elements, row_start, row_count);
398
+ }
399
+
400
+ #pragma endregion // Half Precision Floats
401
+
402
+ #pragma region Brain Float 16
403
+
404
// In-place finalizer: rewrites the raw dot products in `c` as angular
// distances via nk_angulars_from_dot_f32x_ssve_. Target-column squared norms
// come from the packed blob (header->norms_offset); each bf16 query row's
// squared norm is reduced here. __arm_locally_streaming: svcntw() below is
// the streaming vector length.
__arm_locally_streaming static void nk_angulars_packed_bf16_sme_finalize_streaming_( //
    nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,                           //
    nk_size_t rows, nk_size_t columns, nk_size_t depth,                              //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_bf16_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Scalar squared L2 norm of the query row, broadcast to a full vector.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail-safe sweep over the output row.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
425
+
426
+ NK_PUBLIC void nk_angulars_packed_bf16_sme( //
427
+ nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
428
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
429
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
430
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
431
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
432
+ nk_dots_packed_bf16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
433
+ nk_angulars_packed_bf16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
434
+ c_stride_elements);
435
+ }
436
+
437
// In-place finalizer: rewrites the raw dot products in `c` as Euclidean
// distances via nk_euclideans_from_dot_f32x_ssve_ (presumably from
// |a|^2 + |b|^2 - 2*dot — confirm in that helper). Target-column squared
// norms come from the packed blob (header->norms_offset); each bf16 query
// row's squared norm is reduced here.
__arm_locally_streaming static void nk_euclideans_packed_bf16_sme_finalize_streaming_( //
    nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,                             //
    nk_size_t rows, nk_size_t columns, nk_size_t depth,                                //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_bf16_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Scalar squared L2 norm of the query row, broadcast to a full vector.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_bf16_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail-safe sweep over the output row.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
458
+
459
+ NK_PUBLIC void nk_euclideans_packed_bf16_sme( //
460
+ nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
461
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
462
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
463
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
464
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
465
+ nk_dots_packed_bf16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
466
+ nk_euclideans_packed_bf16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
467
+ c_stride_elements);
468
+ }
469
+
470
// In-place finalizer for the symmetric all-pairs kernel (bf16): rewrites the
// raw dot products in `result` rows [row_start, row_start + row_count) as
// angular distances. Only the strict upper triangle (col > row) is rewritten;
// processed diagonals end up 0.
__arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_streaming_( //
    nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal
    // (diagonal serves as scratch for |row|^2; re-zeroed in Phase 3)
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed per 256-wide chunk — NOTE(review): presumably
    // because out-of-range columns may belong to another worker; confirm.
    nk_f32_t norms_cache[256]; // capacity must match the 256-column chunk step below
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Restrict to the strict upper triangle within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
503
+
504
+ NK_PUBLIC void nk_angulars_symmetric_bf16_sme( //
505
+ nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
506
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
507
+ nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
508
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
509
+ nk_dots_symmetric_bf16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
510
+ row_start, row_count);
511
+ nk_angulars_symmetric_bf16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
512
+ result_stride_elements, row_start, row_count);
513
+ }
514
+
515
// In-place finalizer for the symmetric all-pairs kernel (bf16): rewrites the
// raw dot products in `result` rows [row_start, row_start + row_count) as
// Euclidean distances. Only the strict upper triangle (col > row) is
// rewritten; processed diagonals end up 0.
__arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_streaming_( //
    nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal
    // (diagonal serves as scratch for |row|^2; re-zeroed in Phase 3)
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed per 256-wide chunk — NOTE(review): presumably
    // because out-of-range columns may belong to another worker; confirm.
    nk_f32_t norms_cache[256]; // capacity must match the 256-column chunk step below
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Restrict to the strict upper triangle within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
548
+
549
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_sme( //
550
+ nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
551
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
552
+ nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
553
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
554
+ nk_dots_symmetric_bf16_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
555
+ row_start, row_count);
556
+ nk_euclideans_symmetric_bf16_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
557
+ result_stride_elements, row_start, row_count);
558
+ }
559
+
560
+ #pragma endregion // Brain Float 16
561
+
562
+ #pragma region Quarter Precision E4M3
563
+
564
// In-place finalizer: rewrites the raw dot products in `c` as angular
// distances via nk_angulars_from_dot_f32x_ssve_. Target-column squared norms
// come from the packed blob (header->norms_offset); each e4m3 query row's
// squared norm is reduced here.
__arm_locally_streaming static void nk_angulars_packed_e4m3_sme_finalize_streaming_( //
    nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,                           //
    nk_size_t rows, nk_size_t columns, nk_size_t depth,                              //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e4m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Scalar squared L2 norm of the query row, broadcast to a full vector.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail-safe sweep over the output row.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
585
+
586
+ NK_PUBLIC void nk_angulars_packed_e4m3_sme( //
587
+ nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
588
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
589
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
590
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
591
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
592
+ nk_dots_packed_e4m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
593
+ nk_angulars_packed_e4m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
594
+ c_stride_elements);
595
+ }
596
+
597
// In-place finalizer: rewrites the raw dot products in `c` as Euclidean
// distances via nk_euclideans_from_dot_f32x_ssve_ (presumably from
// |a|^2 + |b|^2 - 2*dot — confirm in that helper). Target-column squared
// norms come from the packed blob (header->norms_offset); each e4m3 query
// row's squared norm is reduced here.
__arm_locally_streaming static void nk_euclideans_packed_e4m3_sme_finalize_streaming_( //
    nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,                             //
    nk_size_t rows, nk_size_t columns, nk_size_t depth,                                //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e4m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Scalar squared L2 norm of the query row, broadcast to a full vector.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e4m3_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail-safe sweep over the output row.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
618
+
619
+ NK_PUBLIC void nk_euclideans_packed_e4m3_sme( //
620
+ nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
621
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
622
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
623
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
624
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
625
+ nk_dots_packed_e4m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
626
+ nk_euclideans_packed_e4m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
627
+ c_stride_elements);
628
+ }
629
+
630
// In-place finalizer for the symmetric all-pairs kernel (e4m3): rewrites the
// raw dot products in `result` rows [row_start, row_start + row_count) as
// angular distances. Only the strict upper triangle (col > row) is rewritten;
// processed diagonals end up 0.
__arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_streaming_( //
    nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal
    // (diagonal serves as scratch for |row|^2; re-zeroed in Phase 3)
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed per 256-wide chunk — NOTE(review): presumably
    // because out-of-range columns may belong to another worker; confirm.
    nk_f32_t norms_cache[256]; // capacity must match the 256-column chunk step below
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Restrict to the strict upper triangle within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
663
+
664
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_sme( //
665
+ nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
666
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
667
+ nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
668
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
669
+ nk_dots_symmetric_e4m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
670
+ row_start, row_count);
671
+ nk_angulars_symmetric_e4m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
672
+ result_stride_elements, row_start, row_count);
673
+ }
674
+
675
// In-place finalizer for the symmetric all-pairs kernel (e4m3): rewrites the
// raw dot products in `result` rows [row_start, row_start + row_count) as
// Euclidean distances. Only the strict upper triangle (col > row) is
// rewritten; processed diagonals end up 0.
__arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_streaming_( //
    nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal
    // (diagonal serves as scratch for |row|^2; re-zeroed in Phase 3)
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed per 256-wide chunk — NOTE(review): presumably
    // because out-of-range columns may belong to another worker; confirm.
    nk_f32_t norms_cache[256]; // capacity must match the 256-column chunk step below
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Restrict to the strict upper triangle within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
708
+
709
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme( //
710
+ nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
711
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
712
+ nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
713
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
714
+ nk_dots_symmetric_e4m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
715
+ row_start, row_count);
716
+ nk_euclideans_symmetric_e4m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
717
+ result_stride_elements, row_start, row_count);
718
+ }
719
+
720
+ #pragma endregion // Quarter Precision E4M3
721
+
722
+ #pragma region Quarter Precision E5M2
723
+
724
// In-place finalizer: rewrites the raw dot products in `c` as angular
// distances via nk_angulars_from_dot_f32x_ssve_. Target-column squared norms
// come from the packed blob (header->norms_offset); each e5m2 query row's
// squared norm is reduced here.
__arm_locally_streaming static void nk_angulars_packed_e5m2_sme_finalize_streaming_( //
    nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,                           //
    nk_size_t rows, nk_size_t columns, nk_size_t depth,                              //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e5m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Scalar squared L2 norm of the query row, broadcast to a full vector.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail-safe sweep over the output row.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
745
+
746
+ NK_PUBLIC void nk_angulars_packed_e5m2_sme( //
747
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
748
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
749
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
750
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
751
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
752
+ nk_dots_packed_e5m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
753
+ nk_angulars_packed_e5m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
754
+ c_stride_elements);
755
+ }
756
+
757
// In-place finalizer: rewrites the raw dot products in `c` as Euclidean
// distances via nk_euclideans_from_dot_f32x_ssve_ (presumably from
// |a|^2 + |b|^2 - 2*dot — confirm in that helper). Target-column squared
// norms come from the packed blob (header->norms_offset); each e5m2 query
// row's squared norm is reduced here.
__arm_locally_streaming static void nk_euclideans_packed_e5m2_sme_finalize_streaming_( //
    nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,                             //
    nk_size_t rows, nk_size_t columns, nk_size_t depth,                                //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e5m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Scalar squared L2 norm of the query row, broadcast to a full vector.
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e5m2_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail-safe sweep over the output row.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
778
+
779
+ NK_PUBLIC void nk_euclideans_packed_e5m2_sme( //
780
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
781
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
782
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
783
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
784
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
785
+ nk_dots_packed_e5m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
786
+ nk_euclideans_packed_e5m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
787
+ c_stride_elements);
788
+ }
789
+
790
// In-place finalizer for the symmetric all-pairs kernel (e5m2): rewrites the
// raw dot products in `result` rows [row_start, row_start + row_count) as
// angular distances. Only the strict upper triangle (col > row) is rewritten;
// processed diagonals end up 0.
__arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_streaming_( //
    nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal
    // (diagonal serves as scratch for |row|^2; re-zeroed in Phase 3)
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed per 256-wide chunk — NOTE(review): presumably
    // because out-of-range columns may belong to another worker; confirm.
    nk_f32_t norms_cache[256]; // capacity must match the 256-column chunk step below
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Restrict to the strict upper triangle within this chunk.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
823
+
824
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_sme( //
825
+ nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
826
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
827
+ nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
828
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
829
+ nk_dots_symmetric_e5m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
830
+ row_start, row_count);
831
+ nk_angulars_symmetric_e5m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
832
+ result_stride_elements, row_start, row_count);
833
+ }
834
+
835
__arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_streaming_( //
    nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Rewrites the raw dot products in `result` (strictly-upper-triangular entries of
    // rows [row_start, row_start + row_count)) as Euclidean distances, in place.
    // Runs in streaming mode, so only streaming-safe (SSVE) helpers are called.
    // Phase 1: cache row norms on diagonal
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed in 256-wide chunks; the chunk width must not
    // exceed the capacity of `norms_cache`.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly right of the diagonal are touched, so the norm
            // cached on the diagonal in phase 1 stays valid for the whole pass.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                // `svwhilelt` masks the final partial vector at the chunk boundary.
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    // The diagonal held the cached norms; a vector's distance to itself is 0.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
868
+
869
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme( //
870
+ nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
871
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
872
+ nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
873
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
874
+ nk_dots_symmetric_e5m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
875
+ row_start, row_count);
876
+ nk_euclideans_symmetric_e5m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
877
+ result_stride_elements, row_start, row_count);
878
+ }
879
+
880
+ #pragma endregion // Quarter Precision E5M2
881
+
882
+ #pragma region Micro Precision E2M3
883
+
884
__arm_locally_streaming static void nk_angulars_packed_e2m3_sme_finalize_streaming_( //
    nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    // Converts the raw dot products in `c` into angular distances, in place.
    // Target (B-side) squared norms are read from the packed buffer at the offset
    // recorded in its header; query (A-side) squared norms are recomputed per row.
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e2m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail handling: `svwhilelt` masks the final partial vector.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
905
+
906
+ NK_PUBLIC void nk_angulars_packed_e2m3_sme( //
907
+ nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
908
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
909
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
910
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
911
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
912
+ nk_dots_packed_e2m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
913
+ nk_angulars_packed_e2m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
914
+ c_stride_elements);
915
+ }
916
+
917
__arm_locally_streaming static void nk_euclideans_packed_e2m3_sme_finalize_streaming_( //
    nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    // Converts the raw dot products in `c` into Euclidean distances, in place.
    // Target (B-side) squared norms are read from the packed buffer at the offset
    // recorded in its header; query (A-side) squared norms are recomputed per row.
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e2m3_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e2m3_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail handling: `svwhilelt` masks the final partial vector.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
938
+
939
+ NK_PUBLIC void nk_euclideans_packed_e2m3_sme( //
940
+ nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
941
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
942
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
943
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
944
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
945
+ nk_dots_packed_e2m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
946
+ nk_euclideans_packed_e2m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
947
+ c_stride_elements);
948
+ }
949
+
950
__arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_streaming_( //
    nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Rewrites the raw dot products in `result` (strictly-upper-triangular entries of
    // rows [row_start, row_start + row_count)) as angular distances, in place.
    // Runs in streaming mode, so only streaming-safe (SSVE) helpers are called.
    // Phase 1: cache row norms on diagonal
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed in 256-wide chunks; the chunk width must not
    // exceed the capacity of `norms_cache`.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly right of the diagonal are touched, so the norm
            // cached on the diagonal in phase 1 stays valid for the whole pass.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                // `svwhilelt` masks the final partial vector at the chunk boundary.
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    // The diagonal held the cached norms; a vector's distance to itself is 0.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
983
+
984
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_sme( //
985
+ nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
986
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
987
+ nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
988
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
989
+ nk_dots_symmetric_e2m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
990
+ row_start, row_count);
991
+ nk_angulars_symmetric_e2m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
992
+ result_stride_elements, row_start, row_count);
993
+ }
994
+
995
__arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_streaming_( //
    nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Rewrites the raw dot products in `result` (strictly-upper-triangular entries of
    // rows [row_start, row_start + row_count)) as Euclidean distances, in place.
    // Runs in streaming mode, so only streaming-safe (SSVE) helpers are called.
    // Phase 1: cache row norms on diagonal
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed in 256-wide chunks; the chunk width must not
    // exceed the capacity of `norms_cache`.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly right of the diagonal are touched, so the norm
            // cached on the diagonal in phase 1 stays valid for the whole pass.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                // `svwhilelt` masks the final partial vector at the chunk boundary.
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    // The diagonal held the cached norms; a vector's distance to itself is 0.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1028
+
1029
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme( //
1030
+ nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1031
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1032
+ nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
1033
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1034
+ nk_dots_symmetric_e2m3_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1035
+ row_start, row_count);
1036
+ nk_euclideans_symmetric_e2m3_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1037
+ result_stride_elements, row_start, row_count);
1038
+ }
1039
+
1040
+ #pragma endregion // Micro Precision E2M3
1041
+
1042
+ #pragma region Micro Precision E3M2
1043
+
1044
__arm_locally_streaming static void nk_angulars_packed_e3m2_sme_finalize_streaming_( //
    nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    // Converts the raw dot products in `c` into angular distances, in place.
    // Target (B-side) squared norms are read from the packed buffer at the offset
    // recorded in its header; query (A-side) squared norms are recomputed per row.
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e3m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail handling: `svwhilelt` masks the final partial vector.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1065
+
1066
+ NK_PUBLIC void nk_angulars_packed_e3m2_sme( //
1067
+ nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
1068
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1069
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1070
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
1071
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1072
+ nk_dots_packed_e3m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1073
+ nk_angulars_packed_e3m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1074
+ c_stride_elements);
1075
+ }
1076
+
1077
__arm_locally_streaming static void nk_euclideans_packed_e3m2_sme_finalize_streaming_( //
    nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    // Converts the raw dot products in `c` into Euclidean distances, in place.
    // Target (B-side) squared norms are read from the packed buffer at the offset
    // recorded in its header; query (A-side) squared norms are recomputed per row.
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_e3m2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_f32_t query_norm_sq_f32 = nk_dots_reduce_sumsq_e3m2_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32(query_norm_sq_f32);
        // Predicated tail handling: `svwhilelt` masks the final partial vector.
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
            svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, b_norms + col_index);
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1098
+
1099
+ NK_PUBLIC void nk_euclideans_packed_e3m2_sme( //
1100
+ nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
1101
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1102
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1103
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
1104
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1105
+ nk_dots_packed_e3m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1106
+ nk_euclideans_packed_e3m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1107
+ c_stride_elements);
1108
+ }
1109
+
1110
__arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_streaming_( //
    nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Rewrites the raw dot products in `result` (strictly-upper-triangular entries of
    // rows [row_start, row_start + row_count)) as angular distances, in place.
    // Runs in streaming mode, so only streaming-safe (SSVE) helpers are called.
    // Phase 1: cache row norms on diagonal
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed in 256-wide chunks; the chunk width must not
    // exceed the capacity of `norms_cache`.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly right of the diagonal are touched, so the norm
            // cached on the diagonal in phase 1 stays valid for the whole pass.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                // `svwhilelt` masks the final partial vector at the chunk boundary.
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    // The diagonal held the cached norms; a vector's distance to itself is 0.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1143
+
1144
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_sme( //
1145
+ nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1146
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1147
+ nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
1148
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1149
+ nk_dots_symmetric_e3m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1150
+ row_start, row_count);
1151
+ nk_angulars_symmetric_e3m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1152
+ result_stride_elements, row_start, row_count);
1153
+ }
1154
+
1155
__arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_streaming_( //
    nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Rewrites the raw dot products in `result` (strictly-upper-triangular entries of
    // rows [row_start, row_start + row_count)) as Euclidean distances, in place.
    // Runs in streaming mode, so only streaming-safe (SSVE) helpers are called.
    // Phase 1: cache row norms on diagonal
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_f32_t *result_row = result + row_index * result_stride_elements;
        result_row[row_index] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + row_index * stride_elements, depth);
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed in 256-wide chunks; the chunk width must not
    // exceed the capacity of `norms_cache`.
    nk_f32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly right of the diagonal are touched, so the norm
            // cached on the diagonal in phase 1 stays valid for the whole pass.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            svfloat32_t query_norm_sq_f32x = svdup_n_f32(result_row[row_index]);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                // `svwhilelt` masks the final partial vector at the chunk boundary.
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                svfloat32_t dots_f32x = svld1_f32(predicate_f32x, result_row + col_index);
                svfloat32_t target_norms_sq_f32x = svld1_f32(predicate_f32x, norms_cache + (col_index - chunk_start));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    // The diagonal held the cached norms; a vector's distance to itself is 0.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1188
+
1189
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme( //
1190
+ nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1191
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1192
+ nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
1193
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1194
+ nk_dots_symmetric_e3m2_sme_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
1195
+ row_start, row_count);
1196
+ nk_euclideans_symmetric_e3m2_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1197
+ result_stride_elements, row_start, row_count);
1198
+ }
1199
+
1200
+ #pragma endregion // Micro Precision E3M2
1201
+ #pragma region Signed 8-bit Integers
1202
+
1203
__arm_locally_streaming static void nk_angulars_packed_i8_sme_finalize_streaming_( //
    nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    // Converts raw i32 dot products (stored bit-wise in the f32 buffer `c` by the
    // integer kernel) into f32 angular distances, in place.  B-side squared norms
    // are u32 values read from the packed buffer at `header->norms_offset`.
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_i8_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i8_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            // Reinterpret the stored i32 dots, then widen to f32 for the distance math.
            svfloat32_t dots_f32x = svcvt_f32_s32_x(
                predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
            svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
                                                               svld1_u32(predicate_f32x, b_norms + col_index));
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1226
+
1227
+ NK_PUBLIC void nk_angulars_packed_i8_sme( //
1228
+ nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
1229
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1230
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1231
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
1232
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1233
+ nk_dots_packed_i8_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
1234
+ c_stride_elements);
1235
+ nk_angulars_packed_i8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1236
+ c_stride_elements);
1237
+ }
1238
+
1239
__arm_locally_streaming static void nk_euclideans_packed_i8_sme_finalize_streaming_( //
    nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    // Converts raw i32 dot products (stored bit-wise in the f32 buffer `c` by the
    // integer kernel) into f32 Euclidean distances, in place.  B-side squared norms
    // are u32 values read from the packed buffer at `header->norms_offset`.
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_i8_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i8_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            // Reinterpret the stored i32 dots, then widen to f32 for the distance math.
            svfloat32_t dots_f32x = svcvt_f32_s32_x(
                predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
            svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
                                                               svld1_u32(predicate_f32x, b_norms + col_index));
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1262
+
1263
+ NK_PUBLIC void nk_euclideans_packed_i8_sme( //
1264
+ nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
1265
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1266
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1267
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
1268
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1269
+ nk_dots_packed_i8_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
1270
+ c_stride_elements);
1271
+ nk_euclideans_packed_i8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1272
+ c_stride_elements);
1273
+ }
1274
+
1275
__arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_streaming_( //
    nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Rewrites the raw i32 dot products (stored bit-wise in the f32 `result` buffer
    // by the integer kernel) as f32 angular distances, in place, for the
    // strictly-upper-triangular entries of rows [row_start, row_start + row_count).
    // Row norms are u32 sums of squares, bit-stored in the f32 diagonal slots.
    // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i8_ssve_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
    }
    // Phase 2: column-first post-processing
    // Column norms are recomputed in 256-wide chunks; the chunk width must not
    // exceed the capacity of `norms_cache`.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly right of the diagonal are touched, so the norm
            // bit-stored on the diagonal in phase 1 stays valid for the whole pass.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
            svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                // Reinterpret stored i32 dots and widen to f32 for the distance math.
                svfloat32_t dots_f32x = svcvt_f32_s32_x(
                    predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
                svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals
    // The diagonal held the bit-stored norms; a vector's distance to itself is 0.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1311
+
1312
+ NK_PUBLIC void nk_angulars_symmetric_i8_sme( //
1313
+ nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1314
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1315
+ nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
1316
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1317
+ nk_dots_symmetric_i8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
1318
+ result_stride_elements, row_start, row_count);
1319
+ nk_angulars_symmetric_i8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1320
+ result_stride_elements, row_start, row_count);
1321
+ }
1322
+
1323
// Converts the raw signed i32 dot-products already sitting in `result` into
// Euclidean distances (exact form defined by nk_euclideans_from_dot_f32x_ssve_),
// in place, for rows [row_start, row_start + row_count) of a symmetric i8
// all-pairs matrix. Only the strictly-upper triangle is written; the diagonal is
// zeroed at the end. Runs in SME streaming mode.
__arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_streaming_( //
    nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
    // The diagonal doubles as per-row scratch until Phase 3 overwrites it.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i8_ssve_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
    }
    // Phase 2: column-first post-processing
    // Column norms are computed once per 256-column chunk and shared by all rows.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly above the diagonal are finalized.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Recover the row norm stashed on the diagonal in Phase 1.
            nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
            svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                // Result row still holds signed i32 dot products; widen them to f32.
                svfloat32_t dots_f32x = svcvt_f32_s32_x(
                    predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
                svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals (self-distance), overwriting the cached norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1359
+
1360
+ NK_PUBLIC void nk_euclideans_symmetric_i8_sme( //
1361
+ nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1362
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1363
+ nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
1364
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1365
+ nk_dots_symmetric_i8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
1366
+ result_stride_elements, row_start, row_count);
1367
+ nk_euclideans_symmetric_i8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1368
+ result_stride_elements, row_start, row_count);
1369
+ }
1370
+
1371
+ #pragma endregion // Signed 8-bit Integers
1372
+
1373
+ #pragma region Unsigned 8-bit Integers
1374
+
1375
// Converts the raw u32 dot-products already sitting in `c` into angular
// distances, in place, for a rows x columns block computed against a pre-packed
// B matrix. Target (column) norms were precomputed at pack time and are found
// behind `norms_offset` in the packed blob. Runs in SME streaming mode.
__arm_locally_streaming static void nk_angulars_packed_u8_sme_finalize_streaming_( //
    nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_u8_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is computed per row; target norms come from the packed blob.
        nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u8_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            // Result row still holds unsigned u32 dot products; widen them to f32.
            svfloat32_t dots_f32x = svcvt_f32_u32_x(
                predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
            svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
                                                              svld1_u32(predicate_f32x, b_norms + col_index));
            // Distance formula lives in the helper; store overwrites the dots in place.
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1398
+
1399
+ NK_PUBLIC void nk_angulars_packed_u8_sme( //
1400
+ nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
1401
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1402
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1403
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
1404
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1405
+ nk_dots_packed_u8_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
1406
+ c_stride_elements);
1407
+ nk_angulars_packed_u8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1408
+ c_stride_elements);
1409
+ }
1410
+
1411
// Converts the raw u32 dot-products already sitting in `c` into Euclidean
// distances (exact form defined by nk_euclideans_from_dot_f32x_ssve_), in
// place, for a rows x columns block computed against a pre-packed B matrix.
// Column norms were precomputed at pack time, behind `norms_offset`.
__arm_locally_streaming static void nk_euclideans_packed_u8_sme_finalize_streaming_( //
    nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_u8_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is computed per row; target norms come from the packed blob.
        nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u8_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            // Result row still holds unsigned u32 dot products; widen them to f32.
            svfloat32_t dots_f32x = svcvt_f32_u32_x(
                predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
            svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
                                                              svld1_u32(predicate_f32x, b_norms + col_index));
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1434
+
1435
+ NK_PUBLIC void nk_euclideans_packed_u8_sme( //
1436
+ nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
1437
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1438
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1439
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
1440
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1441
+ nk_dots_packed_u8_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
1442
+ c_stride_elements);
1443
+ nk_euclideans_packed_u8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1444
+ c_stride_elements);
1445
+ }
1446
+
1447
// Converts the raw u32 dot-products already sitting in `result` into angular
// distances, in place, for rows [row_start, row_start + row_count) of a
// symmetric u8 all-pairs matrix. Only the strictly-upper triangle is written;
// the diagonal is zeroed at the end. Runs in SME streaming mode.
__arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_streaming_( //
    nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
    // The diagonal doubles as per-row scratch until Phase 3 overwrites it.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u8_ssve_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
    }
    // Phase 2: column-first post-processing
    // Column norms are computed once per 256-column chunk and shared by all rows.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly above the diagonal are finalized.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Recover the row norm stashed on the diagonal in Phase 1.
            nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
            svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                // Result row still holds unsigned u32 dot products; widen them to f32.
                svfloat32_t dots_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
                svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals (self-distance), overwriting the cached norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1483
+
1484
+ NK_PUBLIC void nk_angulars_symmetric_u8_sme( //
1485
+ nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1486
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1487
+ nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
1488
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1489
+ nk_dots_symmetric_u8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
1490
+ result_stride_elements, row_start, row_count);
1491
+ nk_angulars_symmetric_u8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1492
+ result_stride_elements, row_start, row_count);
1493
+ }
1494
+
1495
// Converts the raw u32 dot-products already sitting in `result` into Euclidean
// distances (exact form defined by nk_euclideans_from_dot_f32x_ssve_), in
// place, for rows [row_start, row_start + row_count) of a symmetric u8
// all-pairs matrix. Only the strictly-upper triangle is written; the diagonal
// is zeroed at the end. Runs in SME streaming mode.
__arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_streaming_( //
    nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
    // The diagonal doubles as per-row scratch until Phase 3 overwrites it.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u8_ssve_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
    }
    // Phase 2: column-first post-processing
    // Column norms are computed once per 256-column chunk and shared by all rows.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly above the diagonal are finalized.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Recover the row norm stashed on the diagonal in Phase 1.
            nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
            svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                // Result row still holds unsigned u32 dot products; widen them to f32.
                svfloat32_t dots_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
                svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals (self-distance), overwriting the cached norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1531
+
1532
+ NK_PUBLIC void nk_euclideans_symmetric_u8_sme( //
1533
+ nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1534
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1535
+ nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
1536
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1537
+ nk_dots_symmetric_u8_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
1538
+ result_stride_elements, row_start, row_count);
1539
+ nk_euclideans_symmetric_u8_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1540
+ result_stride_elements, row_start, row_count);
1541
+ }
1542
+
1543
+ #pragma endregion // Unsigned 8-bit Integers
1544
+
1545
+ #pragma region Nibble Signed Integers
1546
+
1547
// Converts the raw signed i32 dot-products already sitting in `c` into angular
// distances, in place, for a rows x columns block of packed signed-nibble
// (i4x2) inputs against a pre-packed B matrix. Column norms were precomputed at
// pack time, behind `norms_offset`. Runs in SME streaming mode.
__arm_locally_streaming static void nk_angulars_packed_i4_sme_finalize_streaming_( //
    nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_i4x2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is computed per row; target norms come from the packed blob.
        nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i4_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            // Result row still holds signed i32 dot products; widen them to f32.
            svfloat32_t dots_f32x = svcvt_f32_s32_x(
                predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
            svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
                                                               svld1_u32(predicate_f32x, b_norms + col_index));
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1570
+
1571
+ NK_PUBLIC void nk_angulars_packed_i4_sme( //
1572
+ nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1573
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1574
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1575
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i4x2_t);
1576
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1577
+ nk_dots_packed_i4_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
1578
+ c_stride_elements);
1579
+ nk_angulars_packed_i4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1580
+ c_stride_elements);
1581
+ }
1582
+
1583
// Converts the raw signed i32 dot-products already sitting in `c` into
// Euclidean distances (exact form defined by nk_euclideans_from_dot_f32x_ssve_),
// in place, for a rows x columns block of packed signed-nibble (i4x2) inputs
// against a pre-packed B matrix. Column norms live behind `norms_offset`.
__arm_locally_streaming static void nk_euclideans_packed_i4_sme_finalize_streaming_( //
    nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_i4x2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is computed per row; target norms come from the packed blob.
        nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_i4_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            // Result row still holds signed i32 dot products; widen them to f32.
            svfloat32_t dots_f32x = svcvt_f32_s32_x(
                predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t const *)(result_row + col_index)));
            svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
                                                               svld1_u32(predicate_f32x, b_norms + col_index));
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1606
+
1607
+ NK_PUBLIC void nk_euclideans_packed_i4_sme( //
1608
+ nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1609
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1610
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1611
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i4x2_t);
1612
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1613
+ nk_dots_packed_i4_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
1614
+ c_stride_elements);
1615
+ nk_euclideans_packed_i4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1616
+ c_stride_elements);
1617
+ }
1618
+
1619
// Converts the raw signed i32 dot-products already sitting in `result` into
// angular distances, in place, for rows [row_start, row_start + row_count) of a
// symmetric packed-i4 all-pairs matrix. Only the strictly-upper triangle is
// written; the diagonal is zeroed at the end. Runs in SME streaming mode.
__arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_streaming_( //
    nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
    // The diagonal doubles as per-row scratch until Phase 3 overwrites it.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i4_ssve_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
    }
    // Phase 2: column-first post-processing
    // Column norms are computed once per 256-column chunk and shared by all rows.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i4_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly above the diagonal are finalized.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Recover the row norm stashed on the diagonal in Phase 1.
            nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
            svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                // Result row still holds signed i32 dot products; widen them to f32.
                svfloat32_t dots_f32x = svcvt_f32_s32_x(
                    predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
                svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals (self-distance), overwriting the cached norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1655
+
1656
+ NK_PUBLIC void nk_angulars_symmetric_i4_sme( //
1657
+ nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1658
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1659
+ nk_size_t const stride_elements = stride / sizeof(nk_i4x2_t);
1660
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1661
+ nk_dots_symmetric_i4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
1662
+ result_stride_elements, row_start, row_count);
1663
+ nk_angulars_symmetric_i4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1664
+ result_stride_elements, row_start, row_count);
1665
+ }
1666
+
1667
// Converts the raw signed i32 dot-products already sitting in `result` into
// Euclidean distances (exact form defined by nk_euclideans_from_dot_f32x_ssve_),
// in place, for rows [row_start, row_start + row_count) of a symmetric
// packed-i4 all-pairs matrix. Only the strictly-upper triangle is written; the
// diagonal is zeroed at the end. Runs in SME streaming mode.
__arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_streaming_( //
    nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
    // The diagonal doubles as per-row scratch until Phase 3 overwrites it.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i4_ssve_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
    }
    // Phase 2: column-first post-processing
    // Column norms are computed once per 256-column chunk and shared by all rows.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i4_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly above the diagonal are finalized.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Recover the row norm stashed on the diagonal in Phase 1.
            nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
            svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                // Result row still holds signed i32 dot products; widen them to f32.
                svfloat32_t dots_f32x = svcvt_f32_s32_x(
                    predicate_f32x, svld1_s32(predicate_f32x, (nk_i32_t *)(result_row + col_index)));
                svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals (self-distance), overwriting the cached norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1703
+
1704
+ NK_PUBLIC void nk_euclideans_symmetric_i4_sme( //
1705
+ nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1706
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1707
+ nk_size_t const stride_elements = stride / sizeof(nk_i4x2_t);
1708
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1709
+ nk_dots_symmetric_i4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_i32_t *)result,
1710
+ result_stride_elements, row_start, row_count);
1711
+ nk_euclideans_symmetric_i4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1712
+ result_stride_elements, row_start, row_count);
1713
+ }
1714
+
1715
+ #pragma endregion // Nibble Signed Integers
1716
+
1717
+ #pragma region Nibble Unsigned Integers
1718
+
1719
// Converts the raw u32 dot-products already sitting in `c` into angular
// distances, in place, for a rows x columns block of packed unsigned-nibble
// (u4x2) inputs against a pre-packed B matrix. Column norms were precomputed at
// pack time, behind `norms_offset`. Runs in SME streaming mode.
__arm_locally_streaming static void nk_angulars_packed_u4_sme_finalize_streaming_( //
    nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_u4x2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is computed per row; target norms come from the packed blob.
        nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u4_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            // Result row still holds unsigned u32 dot products; widen them to f32.
            svfloat32_t dots_f32x = svcvt_f32_u32_x(
                predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
            svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
                                                               svld1_u32(predicate_f32x, b_norms + col_index));
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1742
+
1743
+ NK_PUBLIC void nk_angulars_packed_u4_sme( //
1744
+ nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1745
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1746
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1747
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u4x2_t);
1748
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1749
+ nk_dots_packed_u4_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
1750
+ c_stride_elements);
1751
+ nk_angulars_packed_u4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1752
+ c_stride_elements);
1753
+ }
1754
+
1755
// Converts the raw u32 dot-products already sitting in `c` into Euclidean
// distances (exact form defined by nk_euclideans_from_dot_f32x_ssve_), in
// place, for a rows x columns block of packed unsigned-nibble (u4x2) inputs
// against a pre-packed B matrix. Column norms live behind `norms_offset`.
__arm_locally_streaming static void nk_euclideans_packed_u4_sme_finalize_streaming_( //
    nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows, nk_size_t columns, nk_size_t depth, //
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
    for (nk_size_t row_index = 0; row_index < rows; row_index++) {
        nk_u4x2_t const *a_row = a + row_index * a_stride_elements;
        nk_f32_t *result_row = c + row_index * c_stride_elements;
        // Query norm is computed per row; target norms come from the packed blob.
        nk_u32_t query_norm_sq_u32 = nk_dots_reduce_sumsq_u4_ssve_(a_row, depth);
        svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_norm_sq_u32);
        for (nk_size_t col_index = 0; col_index < columns; col_index += svcntw()) {
            svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, columns);
            // Result row still holds unsigned u32 dot products; widen them to f32.
            svfloat32_t dots_f32x = svcvt_f32_u32_x(
                predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t const *)(result_row + col_index)));
            svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(predicate_f32x,
                                                               svld1_u32(predicate_f32x, b_norms + col_index));
            svst1_f32(
                predicate_f32x, result_row + col_index,
                nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x, target_norms_sq_f32x));
        }
    }
}
1778
+
1779
+ NK_PUBLIC void nk_euclideans_packed_u4_sme( //
1780
+ nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1781
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1782
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1783
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u4x2_t);
1784
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1785
+ nk_dots_packed_u4_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
1786
+ c_stride_elements);
1787
+ nk_euclideans_packed_u4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1788
+ c_stride_elements);
1789
+ }
1790
+
1791
// Converts the raw u32 dot-products already sitting in `result` into angular
// distances, in place, for rows [row_start, row_start + row_count) of a
// symmetric packed-u4 all-pairs matrix. Only the strictly-upper triangle is
// written; the diagonal is zeroed at the end. Runs in SME streaming mode.
__arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_streaming_( //
    nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
    // The diagonal doubles as per-row scratch until Phase 3 overwrites it.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u4_ssve_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
    }
    // Phase 2: column-first post-processing
    // Column norms are computed once per 256-column chunk and shared by all rows.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u4_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly above the diagonal are finalized.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Recover the row norm stashed on the diagonal in Phase 1.
            nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
            svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                // Result row still holds unsigned u32 dot products; widen them to f32.
                svfloat32_t dots_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
                svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_angulars_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                          target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals (self-distance), overwriting the cached norms.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1827
+
1828
+ NK_PUBLIC void nk_angulars_symmetric_u4_sme( //
1829
+ nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1830
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1831
+ nk_size_t const stride_elements = stride / sizeof(nk_u4x2_t);
1832
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1833
+ nk_dots_symmetric_u4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
1834
+ result_stride_elements, row_start, row_count);
1835
+ nk_angulars_symmetric_u4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1836
+ result_stride_elements, row_start, row_count);
1837
+ }
1838
+
1839
// Converts the raw u32 dot products left in `result` by `nk_dots_symmetric_u4_sme_streaming_`
// into Euclidean distances for the strictly-upper triangle of rows
// [row_start, row_start + row_count), then zeroes the diagonal entries.
// NOTE(review): assumes the preceding dots pass stashed u32 dot products in the f32
// output slots (reinterpreted below via pointer casts) — the two passes must run paired.
__arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_streaming_( //
    nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
    nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
    // Phase 1: cache row norms on diagonal (store as u32 in f32 slot);
    // Phase 2 reads them back per row, so each row's sum-of-squares is computed once.
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
        nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u4_ssve_(vectors + row_index * stride_elements, depth);
        ((nk_u32_t *)(result + row_index * result_stride_elements))[row_index] = row_sumsq_u32;
    }
    // Phase 2: column-first post-processing. Column norms are recomputed in fixed
    // 256-wide chunks so the cache is a small stack array rather than an
    // O(n_vectors) heap allocation.
    nk_u32_t norms_cache[256];
    for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
        nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
        for (nk_size_t col = chunk_start; col < chunk_end; ++col)
            norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u4_ssve_(vectors + col * stride_elements, depth);
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
            // Only columns strictly right of the diagonal hold dots to finalize.
            nk_size_t col_start = row_index + 1 > chunk_start ? row_index + 1 : chunk_start;
            if (col_start >= chunk_end) continue;
            nk_f32_t *result_row = result + row_index * result_stride_elements;
            // Diagonal slot holds this row's sum-of-squares, stashed in Phase 1.
            nk_u32_t query_sumsq_u32 = ((nk_u32_t *)result_row)[row_index];
            svfloat32_t query_norm_sq_f32x = svdup_n_f32((nk_f32_t)query_sumsq_u32);
            for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntw()) {
                svbool_t predicate_f32x = svwhilelt_b32_u64(col_index, chunk_end);
                // Reinterpret the stored u32 dots and widen them into f32 lanes.
                svfloat32_t dots_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, (nk_u32_t *)(result_row + col_index)));
                svfloat32_t target_norms_sq_f32x = svcvt_f32_u32_x(
                    predicate_f32x, svld1_u32(predicate_f32x, norms_cache + (col_index - chunk_start)));
                // Overwrite the dot products in place with Euclidean distances.
                svst1_f32(predicate_f32x, result_row + col_index,
                          nk_euclideans_from_dot_f32x_ssve_(predicate_f32x, dots_f32x, query_norm_sq_f32x,
                                                            target_norms_sq_f32x));
            }
        }
    }
    // Phase 3: zero diagonals (self-distance is 0; also clears the Phase 1 norm stash).
    for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index)
        result[row_index * result_stride_elements + row_index] = 0;
}
1875
+
1876
+ NK_PUBLIC void nk_euclideans_symmetric_u4_sme( //
1877
+ nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
1878
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
1879
+ nk_size_t const stride_elements = stride / sizeof(nk_u4x2_t);
1880
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
1881
+ nk_dots_symmetric_u4_sme_streaming_(vectors, n_vectors, depth, stride_elements, (nk_u32_t *)result,
1882
+ result_stride_elements, row_start, row_count);
1883
+ nk_euclideans_symmetric_u4_sme_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
1884
+ result_stride_elements, row_start, row_count);
1885
+ }
1886
+
1887
+ #pragma endregion // Nibble Unsigned Integers
1888
+
1889
+ #if defined(__clang__)
1890
+ #pragma clang attribute pop
1891
+ #elif defined(__GNUC__)
1892
+ #pragma GCC pop_options
1893
+ #endif
1894
+
1895
+ #if defined(__cplusplus)
1896
+ } // extern "C"
1897
+ #endif
1898
+
1899
+ #endif // NK_TARGET_SME
1900
+ #endif // NK_TARGET_ARM_
1901
+ #endif // NK_SPATIALS_SME_H