numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,929 @@
1
+ /**
2
+ * @brief SIMD-accelerated MaxSim (ColBERT late-interaction) for SME.
3
+ * @file include/numkong/maxsim/sme.h
4
+ * @author Ash Vardanian
5
+ * @date February 10, 2026
6
+ *
7
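+ * Computes MaxSim(Q, D) = Σᵢ maxⱼ dot(qᵢ, dⱼ) using ARM SME outer products; the f16 kernel below converts each query's best match to an angular distance max(1 - cosine, 0) and returns the sum of those.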
8
+ *
9
+ * Both Q and D are pre-packed with `nk_dots_pack_bf16_sme` from `dots/sme.h`.
10
+ * This frees all 4 ZA tiles for accumulation (vs 3 with A-side staging).
11
+ *
12
+ * Key optimization: vertical column reads for max reduction.
13
+ * Traditional extraction reads tile rows then calls `svmaxv` (horizontal max, ~8cy).
14
+ * Our approach reads tile columns with `svread_ver_za32_f32_m`:
15
+ *
16
+ * - Each column read gives dot products of all query tokens vs one doc token.
17
+ * - Element-wise `svmax` (~1cy) updates a running max vector across doc tokens.
18
+ * - Only `svaddv` at the very end: ⌈n_q/16⌉ horizontal reductions total (2 for n_q = 32 query tokens with 16-row tiles).
19
+ *
20
+ * This is ~100x fewer horizontal reductions for typical ColBERT dimensions.
21
+ *
22
+ * ZA tile layout after BFMOPA accumulation (16x16 f32):
23
+ *
24
+ * - Row i, Column j = dot(q_{tile_row_start + i}, d_{tile_col_start + j})
25
+ * - Vertical column read of column j → similarities of all 16 q tokens to doc token j
26
+ * - Element-wise max across columns → per-query-token max over doc tokens in this tile group
27
+ *
28
+ * Benchmark results (Apple M4, SVL=512):
29
+ *
30
+ * Dimensions dots_packed GEMM maxsim fused GEMM speedup End-to-end speedup
31
+ * 32×128×128 (ColBERT) 840 GFLOPS 1516 GFLOPS 1.81× 5.10×
32
+ * 32×256×128 1037 GFLOPS 1591 GFLOPS 1.53× 5.17×
33
+ * 64×512×128 1016 GFLOPS 1651 GFLOPS 1.62× 5.42×
34
+ * 32×128×256 859 GFLOPS 1725 GFLOPS 2.01× 4.06×
35
+ * 32×1024×768 (BERT) 1124 GFLOPS 1932 GFLOPS 1.72× 2.61×
36
+ *
37
+ * Speedup sources:
38
+ *
39
+ * 1. Pre-packing both sides → 4 ZA tiles for accumulation (vs 3 with A-staging): +33% MOPA throughput
40
+ * 2. No output matrix materialization → eliminates M×N f32 memory round-trip
41
+ * 3. Vertical column reads → ~128 element-wise svmax (1cy) vs ~256 svmaxv horizontal reductions (8cy)
42
+ */
43
+ #ifndef NK_MAXSIM_SME_H
44
+ #define NK_MAXSIM_SME_H
45
+
46
+ #if NK_TARGET_ARM_
47
+ #if NK_TARGET_SME
48
+
49
+ #include "numkong/dots/sme.h" // nk_dots_sme_packed_header_t, nk_dots_pack_{f16,bf16}_sme, nk_dots_packed_size_{f16,bf16}_sme
50
+
51
+ #if defined(__cplusplus)
52
+ extern "C" {
53
+ #endif
54
+
55
+ #if defined(__clang__)
56
+ #pragma clang attribute push(__attribute__((target("sme,sve"))), apply_to = function)
57
+ #elif defined(__GNUC__)
58
+ #pragma GCC push_options
59
+ #pragma GCC target("+sme")
60
+ #endif
61
+
62
+ /**
63
+ * Packed header for MaxSim SME kernels. Used by f32 (i8 screening + f32 refinement)
64
+ * and bf16/f16 (BFMOPA/FMOPA + angular normalization) kernels.
65
+ *
66
+ * For f32: stores i8 tile-interleaved data, f32 squared norms, AND f32 originals.
67
+ * For bf16/f16: stores tile-interleaved data and f32 inverse norms (1/||v||).
68
+ * originals_offset and original_stride are 0 (unused).
69
+ */
70
+ typedef struct {
71
+ nk_u32_t column_tile_count; // ceil(n / tile_dimension)
72
+ nk_u32_t depth_tile_count; // ceil(depth / expansion)
73
+ nk_u32_t columns; // actual vector count (for predicates)
74
+ nk_u32_t depth; // actual depth
75
+ nk_u32_t svl_bytes; // SVL in bytes at pack time (validation)
76
+ nk_u32_t norms_offset; // byte offset -> per-vector norms (squared for f32, inverse for bf16/f16)
77
+ nk_u32_t originals_offset; // byte offset -> f32 original vectors (0 for bf16/f16)
78
+ nk_u32_t original_stride; // row stride in bytes for originals (64B-aligned, 0 for bf16/f16)
79
+ nk_u32_t reserved[8]; // padding to 64 bytes
80
+ } nk_maxsim_sme_packed_header_t;
81
+
82
+ NK_STATIC_ASSERT(sizeof(nk_maxsim_sme_packed_header_t) == 64, nk_maxsim_sme_packed_header_must_be_64_bytes);
83
+
84
+ /**
85
+ * MaxSim f16 kernel: both Q and D pre-packed, vertical column read extraction.
86
+ *
87
+ * 4-tile fast path: processes 4 doc column tiles simultaneously using ZA0-ZA3.
88
+ * Inner loop per depth_step: 1 Q load + 4 D loads + 4 FMOPA = 9 ops.
89
+ * Extraction per 4-tile group: 4×16 = 64 vertical reads + 64 svmax = ~128 cycles.
90
+ *
91
+ * 1-tile remainder: uses ZA0 only, with predicated loads for partial tiles.
92
+ */
93
+ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streaming_( //
94
+ void const *query_packed, void const *document_packed, //
95
+ nk_size_t query_count, nk_size_t document_count, //
96
+ nk_size_t depth, nk_f32_t *result) {
97
+
98
+ nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
99
+ nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
100
+ nk_size_t const depth_step_count = query_header->depth_tile_count;
101
+ nk_size_t const query_row_tiles = query_header->column_tile_count;
102
+ nk_size_t const document_col_tiles = document_header->column_tile_count;
103
+
104
+ nk_size_t const tile_dimension = svcntw(); // 16: ZA32 tile dimension
105
+ nk_size_t const vector_elements = svcnth(); // 32: f16 elements per SVE vector
106
+
107
+ nk_f16_t const *query_vecs = (nk_f16_t const *)((char const *)query_packed + sizeof(nk_maxsim_sme_packed_header_t));
108
+ nk_f16_t const *document_vecs = (nk_f16_t const *)((char const *)document_packed +
109
+ sizeof(nk_maxsim_sme_packed_header_t));
110
+
111
+ nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
112
+ nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
113
+ document_header->norms_offset);
114
+
115
+ svbool_t const predicate_all_f16x = svptrue_b16();
116
+ svbool_t const predicate_all_f32x = svptrue_b32();
117
+
118
+ nk_f32_t total_angular_distance = 0.0f;
119
+
120
+ for (nk_size_t row_tile_index = 0; row_tile_index < query_row_tiles; row_tile_index++) {
121
+ nk_size_t const row_start = row_tile_index * tile_dimension;
122
+ nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
123
+ : (query_count - row_start);
124
+ svbool_t const row_predicate_f16x = (rows_remaining == tile_dimension)
125
+ ? svptrue_b16()
126
+ : svwhilelt_b16_u64(0u, rows_remaining * 2);
127
+ svbool_t const row_predicate_f32x = (rows_remaining == tile_dimension) ? svptrue_b32()
128
+ : svwhilelt_b32_u64(0u, rows_remaining);
129
+
130
+ // Running max + argmax vectors for angular distance finalization
131
+ svfloat32_t running_maximum_f32x = svdup_f32(NK_F32_MIN);
132
+ svuint32_t running_argmax_u32x = svdup_u32(0);
133
+
134
+ nk_size_t column_tile_index = 0;
135
+
136
+ // Fast path: 4 doc column tiles at a time using ZA0-ZA3
137
+ for (; column_tile_index + 4 <= document_col_tiles; column_tile_index += 4) {
138
+ svzero_za(); // Zero all 4 tiles
139
+
140
+ // Accumulate: for each depth step, load Q vector and 4 D vectors, issue 4 FMOPAs
141
+ for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
142
+ svfloat16_t query_packed_f16x = svld1_f16(
143
+ row_predicate_f16x,
144
+ (float16_t const *)(query_vecs +
145
+ (row_tile_index * depth_step_count + depth_step) * vector_elements));
146
+ svfloat16_t document_packed_0_f16x = svld1_f16(
147
+ predicate_all_f16x,
148
+ (float16_t const *)(document_vecs +
149
+ ((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
150
+ svfloat16_t document_packed_1_f16x = svld1_f16(
151
+ predicate_all_f16x,
152
+ (float16_t const *)(document_vecs +
153
+ ((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
154
+ svfloat16_t document_packed_2_f16x = svld1_f16(
155
+ predicate_all_f16x,
156
+ (float16_t const *)(document_vecs +
157
+ ((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
158
+ svfloat16_t document_packed_3_f16x = svld1_f16(
159
+ predicate_all_f16x,
160
+ (float16_t const *)(document_vecs +
161
+ ((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
162
+ svmopa_za32_f16_m(0, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_0_f16x);
163
+ svmopa_za32_f16_m(1, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_1_f16x);
164
+ svmopa_za32_f16_m(2, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_2_f16x);
165
+ svmopa_za32_f16_m(3, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_3_f16x);
166
+ }
167
+
168
+ // Vertical column extraction + argmax update (manually unrolled over 4 tiles)
169
+ for (nk_size_t column_within_tile = 0; column_within_tile < tile_dimension; column_within_tile++) {
170
+ // Tile 0
171
+ {
172
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
173
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
174
+ column_within_tile);
175
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
176
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
177
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
178
+ }
179
+ // Tile 1
180
+ {
181
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
182
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 1,
183
+ column_within_tile);
184
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
185
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
186
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
187
+ }
188
+ // Tile 2
189
+ {
190
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
191
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 2,
192
+ column_within_tile);
193
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
194
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
195
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
196
+ }
197
+ // Tile 3
198
+ {
199
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
200
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 3,
201
+ column_within_tile);
202
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
203
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
204
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
205
+ }
206
+ }
207
+ }
208
+
209
+ // Remainder: 1 doc column tile at a time using ZA0 only
210
+ for (; column_tile_index < document_col_tiles; column_tile_index++) {
211
+ nk_size_t const col_start = column_tile_index * tile_dimension;
212
+ nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
213
+ ? tile_dimension
214
+ : (document_count - col_start);
215
+ svbool_t const column_predicate_f16x = (cols_remaining == tile_dimension)
216
+ ? svptrue_b16()
217
+ : svwhilelt_b16_u64(0u, cols_remaining * 2);
218
+
219
+ svzero_mask_za(nk_sme_zero_za32_tile_0_); // Zero ZA0 only
220
+
221
+ for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
222
+ svfloat16_t query_packed_f16x = svld1_f16(
223
+ row_predicate_f16x,
224
+ (float16_t const *)(query_vecs +
225
+ (row_tile_index * depth_step_count + depth_step) * vector_elements));
226
+ svfloat16_t document_packed_f16x = svld1_f16(
227
+ column_predicate_f16x,
228
+ (float16_t const *)(document_vecs +
229
+ (column_tile_index * depth_step_count + depth_step) * vector_elements));
230
+ svmopa_za32_f16_m(0, row_predicate_f16x, column_predicate_f16x, query_packed_f16x,
231
+ document_packed_f16x);
232
+ }
233
+
234
+ // Vertical column extraction from ZA0 + argmax update
235
+ for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
236
+ nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
237
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
238
+ column_within_tile);
239
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
240
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
241
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
242
+ }
243
+ }
244
+
245
+ // Angular distance finalization — SVE-width vector ops
246
+ // Gather document inverse norms via argmax indices (no SVE gather in streaming mode)
247
+ nk_u32_t best_document_indices[64];
248
+ nk_f32_t document_inverse_norms_gathered[64];
249
+ svst1_u32(row_predicate_f32x, best_document_indices, running_argmax_u32x);
250
+ for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++)
251
+ document_inverse_norms_gathered[row_in_tile] = document_inverse_norms[best_document_indices[row_in_tile]];
252
+
253
+ // SVE-width: cosine = dot * inv_norm_q * inv_norm_d, angular = max(1 - cosine, 0)
254
+ svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_f32x, query_inverse_norms + row_start);
255
+ svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_f32x, document_inverse_norms_gathered);
256
+ svfloat32_t cosine_f32x = svmul_f32_x(
257
+ row_predicate_f32x, svmul_f32_x(row_predicate_f32x, running_maximum_f32x, query_inverse_norms_f32x),
258
+ document_inverse_norms_f32x);
259
+ svfloat32_t angular_distance_f32x = svmax_f32_x(
260
+ row_predicate_f32x, svsub_f32_x(row_predicate_f32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
261
+ total_angular_distance += svaddv_f32(row_predicate_f32x, angular_distance_f32x);
262
+ }
263
+
264
+ *result = total_angular_distance;
265
+ }
266
+
267
+ NK_PUBLIC void nk_maxsim_packed_f16_sme( //
268
+ void const *query_packed, void const *document_packed, //
269
+ nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
270
+ nk_f32_t *result) { //
271
+
272
+ nk_maxsim_packed_f16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
273
+ }
274
+
275
+ /**
276
+ * MaxSim bf16 kernel: both Q and D pre-packed, vertical column read extraction.
277
+ *
278
+ * 4-tile fast path: processes 4 doc column tiles simultaneously using ZA0-ZA3.
279
+ * Inner loop per depth_step: 1 Q load + 4 D loads + 4 BFMOPA = 9 ops.
280
+ * Extraction per 4-tile group: 4×16 = 64 vertical reads + 64 svmax = ~128 cycles.
281
+ *
282
+ * 1-tile remainder: uses ZA0 only, with predicated loads for partial tiles.
283
+ */
284
+ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_streaming_( //
285
+ void const *query_packed, void const *document_packed, //
286
+ nk_size_t query_count, nk_size_t document_count, //
287
+ nk_size_t depth, nk_f32_t *result) {
288
+
289
+ nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
290
+ nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
291
+ nk_size_t const depth_step_count = query_header->depth_tile_count;
292
+ nk_size_t const query_row_tiles = query_header->column_tile_count;
293
+ nk_size_t const document_col_tiles = document_header->column_tile_count;
294
+
295
+ nk_size_t const tile_dimension = svcntw(); // 16: ZA32 tile dimension
296
+ nk_size_t const vector_elements = svcnth(); // 32: bf16 elements per SVE vector
297
+
298
+ nk_bf16_t const *query_vecs = (nk_bf16_t const *)((char const *)query_packed +
299
+ sizeof(nk_maxsim_sme_packed_header_t));
300
+ nk_bf16_t const *document_vecs = (nk_bf16_t const *)((char const *)document_packed +
301
+ sizeof(nk_maxsim_sme_packed_header_t));
302
+
303
+ nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
304
+ nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
305
+ document_header->norms_offset);
306
+
307
+ svbool_t const predicate_all_f16x = svptrue_b16();
308
+ svbool_t const predicate_all_f32x = svptrue_b32();
309
+
310
+ nk_f32_t total_angular_distance = 0.0f;
311
+
312
+ for (nk_size_t row_tile_index = 0; row_tile_index < query_row_tiles; row_tile_index++) {
313
+ nk_size_t const row_start = row_tile_index * tile_dimension;
314
+ nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
315
+ : (query_count - row_start);
316
+ svbool_t const row_predicate_f16x = (rows_remaining == tile_dimension)
317
+ ? svptrue_b16()
318
+ : svwhilelt_b16_u64(0u, rows_remaining * 2);
319
+ svbool_t const row_predicate_f32x = (rows_remaining == tile_dimension) ? svptrue_b32()
320
+ : svwhilelt_b32_u64(0u, rows_remaining);
321
+
322
+ // Running max + argmax vectors for angular distance finalization
323
+ svfloat32_t running_maximum_f32x = svdup_f32(NK_F32_MIN);
324
+ svuint32_t running_argmax_u32x = svdup_u32(0);
325
+
326
+ nk_size_t column_tile_index = 0;
327
+
328
+ // Fast path: 4 doc column tiles at a time using ZA0-ZA3
329
+ for (; column_tile_index + 4 <= document_col_tiles; column_tile_index += 4) {
330
+ svzero_za(); // Zero all 4 tiles
331
+
332
+ // Accumulate: for each depth step, load Q vector and 4 D vectors, issue 4 BFMOPAs
333
+ for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
334
+ svbfloat16_t query_packed_bf16x = svld1_bf16(
335
+ row_predicate_f16x,
336
+ (bfloat16_t const *)(query_vecs +
337
+ (row_tile_index * depth_step_count + depth_step) * vector_elements));
338
+ svbfloat16_t document_packed_0_bf16x = svld1_bf16(
339
+ predicate_all_f16x,
340
+ (bfloat16_t const *)(document_vecs +
341
+ ((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
342
+ svbfloat16_t document_packed_1_bf16x = svld1_bf16(
343
+ predicate_all_f16x,
344
+ (bfloat16_t const *)(document_vecs +
345
+ ((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
346
+ svbfloat16_t document_packed_2_bf16x = svld1_bf16(
347
+ predicate_all_f16x,
348
+ (bfloat16_t const *)(document_vecs +
349
+ ((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
350
+ svbfloat16_t document_packed_3_bf16x = svld1_bf16(
351
+ predicate_all_f16x,
352
+ (bfloat16_t const *)(document_vecs +
353
+ ((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
354
+ svmopa_za32_bf16_m(0, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
355
+ document_packed_0_bf16x);
356
+ svmopa_za32_bf16_m(1, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
357
+ document_packed_1_bf16x);
358
+ svmopa_za32_bf16_m(2, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
359
+ document_packed_2_bf16x);
360
+ svmopa_za32_bf16_m(3, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
361
+ document_packed_3_bf16x);
362
+ }
363
+
364
+ // Vertical column extraction + argmax update (manually unrolled over 4 tiles)
365
+ for (nk_size_t column_within_tile = 0; column_within_tile < tile_dimension; column_within_tile++) {
366
+ // Tile 0
367
+ {
368
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
369
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
370
+ column_within_tile);
371
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
372
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
373
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
374
+ }
375
+ // Tile 1
376
+ {
377
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
378
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 1,
379
+ column_within_tile);
380
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
381
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
382
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
383
+ }
384
+ // Tile 2
385
+ {
386
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
387
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 2,
388
+ column_within_tile);
389
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
390
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
391
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
392
+ }
393
+ // Tile 3
394
+ {
395
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
396
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 3,
397
+ column_within_tile);
398
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
399
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
400
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
401
+ }
402
+ }
403
+ }
404
+
405
+ // Remainder: 1 doc column tile at a time using ZA0 only
406
+ for (; column_tile_index < document_col_tiles; column_tile_index++) {
407
+ nk_size_t const col_start = column_tile_index * tile_dimension;
408
+ nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
409
+ ? tile_dimension
410
+ : (document_count - col_start);
411
+ svbool_t const column_predicate_f16x = (cols_remaining == tile_dimension)
412
+ ? svptrue_b16()
413
+ : svwhilelt_b16_u64(0u, cols_remaining * 2);
414
+
415
+ svzero_mask_za(nk_sme_zero_za32_tile_0_); // Zero ZA0 only
416
+
417
+ for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
418
+ svbfloat16_t query_packed_bf16x = svld1_bf16(
419
+ row_predicate_f16x,
420
+ (bfloat16_t const *)(query_vecs +
421
+ (row_tile_index * depth_step_count + depth_step) * vector_elements));
422
+ svbfloat16_t document_packed_bf16x = svld1_bf16(
423
+ column_predicate_f16x,
424
+ (bfloat16_t const *)(document_vecs +
425
+ (column_tile_index * depth_step_count + depth_step) * vector_elements));
426
+ svmopa_za32_bf16_m(0, row_predicate_f16x, column_predicate_f16x, query_packed_bf16x,
427
+ document_packed_bf16x);
428
+ }
429
+
430
+ // Vertical column extraction from ZA0 + argmax update
431
+ for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
432
+ nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
433
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
434
+ column_within_tile);
435
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
436
+ running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
437
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
438
+ }
439
+ }
440
+
441
+ // Angular distance finalization — SVE-width vector ops
442
+ // Gather document inverse norms via argmax indices (no SVE gather in streaming mode)
443
+ nk_u32_t best_document_indices[64];
444
+ nk_f32_t document_inverse_norms_gathered[64];
445
+ svst1_u32(row_predicate_f32x, best_document_indices, running_argmax_u32x);
446
+ for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++)
447
+ document_inverse_norms_gathered[row_in_tile] = document_inverse_norms[best_document_indices[row_in_tile]];
448
+
449
+ // SVE-width: cosine = dot * inv_norm_q * inv_norm_d, angular = max(1 - cosine, 0)
450
+ svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_f32x, query_inverse_norms + row_start);
451
+ svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_f32x, document_inverse_norms_gathered);
452
+ svfloat32_t cosine_f32x = svmul_f32_x(
453
+ row_predicate_f32x, svmul_f32_x(row_predicate_f32x, running_maximum_f32x, query_inverse_norms_f32x),
454
+ document_inverse_norms_f32x);
455
+ svfloat32_t angular_distance_f32x = svmax_f32_x(
456
+ row_predicate_f32x, svsub_f32_x(row_predicate_f32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
457
+ total_angular_distance += svaddv_f32(row_predicate_f32x, angular_distance_f32x);
458
+ }
459
+
460
+ *result = total_angular_distance;
461
+ }
462
+
463
+ NK_PUBLIC void nk_maxsim_packed_bf16_sme( //
464
+ void const *query_packed, void const *document_packed, //
465
+ nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
466
+ nk_f32_t *result) { //
467
+
468
+ nk_maxsim_packed_bf16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
469
+ }
470
+
471
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sme(nk_size_t n, nk_size_t k) { //
472
+ return nk_dots_packed_size_bf16_sme(n, k);
473
+ }
474
+
475
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sme(nk_size_t n, nk_size_t k) { //
476
+ return nk_dots_packed_size_f16_sme(n, k);
477
+ }
478
+
479
+ NK_PUBLIC void nk_maxsim_pack_bf16_sme( //
480
+ nk_bf16_t const *vectors, nk_size_t n, nk_size_t k, nk_size_t stride, void *packed) { //
481
+
482
+ // Delegate tile interleaving and squared norms computation to dots pack.
483
+ // Both headers are 64 bytes with identical layout for the first 6 fields.
484
+ nk_dots_pack_bf16_sme(vectors, n, k, stride, packed);
485
+
486
+ // Set maxsim-specific header fields (overlaps dots reserved area)
487
+ nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
488
+ header->originals_offset = 0; // not used for bf16
489
+ header->original_stride = 0; // not used for bf16
490
+ for (nk_size_t i = 0; i < 8; i++) header->reserved[i] = 0;
491
+
492
+ // Convert squared norms → inverse norms in-place
493
+ nk_f32_t *norms = (nk_f32_t *)((char *)packed + header->norms_offset);
494
+ for (nk_size_t i = 0; i < n; i++) {
495
+ nk_f32_t norm_sq = norms[i];
496
+ norms[i] = (norm_sq > 0.0f) ? (nk_f32_t)nk_f64_rsqrt_neon((nk_f64_t)norm_sq) : 0.0f;
497
+ }
498
+ }
499
+
500
+ NK_PUBLIC void nk_maxsim_pack_f16_sme( //
501
+ nk_f16_t const *vectors, nk_size_t n, nk_size_t k, nk_size_t stride, void *packed) { //
502
+
503
+ // Delegate tile interleaving and squared norms computation to dots pack.
504
+ // Both headers are 64 bytes with identical layout for the first 6 fields.
505
+ nk_dots_pack_f16_sme(vectors, n, k, stride, packed);
506
+
507
+ // Set maxsim-specific header fields (overlaps dots reserved area)
508
+ nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
509
+ header->originals_offset = 0; // not used for f16
510
+ header->original_stride = 0; // not used for f16
511
+ for (nk_size_t i = 0; i < 8; i++) header->reserved[i] = 0;
512
+
513
+ // Convert squared norms → inverse norms in-place
514
+ nk_f32_t *norms = (nk_f32_t *)((char *)packed + header->norms_offset);
515
+ for (nk_size_t i = 0; i < n; i++) {
516
+ nk_f32_t norm_sq = norms[i];
517
+ norms[i] = (norm_sq > 0.0f) ? (nk_f32_t)nk_f64_rsqrt_neon((nk_f64_t)norm_sq) : 0.0f;
518
+ }
519
+ }
520
+
521
+ /**
522
+ * MaxSim f32 kernel: i8 SMOPA screening + f32/f64 refinement + angular distance.
523
+ *
524
+ * Screening: i8 SMOPA has expansion=4, processing 4x more depth per instruction than f32 FMOPA.
525
+ * With 4 ZA tiles the fast path processes 64 document columns per iteration.
526
+ *
527
+ * Refinement: tile-wide interleaved f64 dot products for the winning (query, document) pairs.
528
+ * Angular distance: 1 - dot / sqrt(||q||^2 * ||d||^2), accumulated with f64.
529
+ */
530
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sme(nk_size_t n, nk_size_t k) { //
531
+ nk_size_t const expansion = 4; // i8->i32 SMOPA
532
+ nk_size_t const tile_dimension = svcntsw(); // 16 for SVL=512
533
+ nk_size_t const vector_elements = svcntsb(); // 64 for SVL=512
534
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(n, tile_dimension);
535
+ nk_size_t const depth_step_count = nk_size_divide_round_up_(k, expansion);
536
+ nk_size_t const original_stride = nk_size_round_up_to_multiple_(k * sizeof(nk_f32_t), 64);
537
+
538
+ nk_size_t size = sizeof(nk_maxsim_sme_packed_header_t); // 64 B header
539
+ size += column_tile_count * depth_step_count * vector_elements; // i8 tiles
540
+ size += n * sizeof(nk_f32_t); // f32 squared norms
541
+ size += n * original_stride; // f32 originals
542
+ return size;
543
+ }
544
+
545
+ NK_PUBLIC void nk_maxsim_pack_f32_sme( //
546
+ nk_f32_t const *vectors, nk_size_t n, nk_size_t k, nk_size_t stride, void *packed) { //
547
+
548
+ nk_size_t const expansion = 4; // i8->i32 SMOPA
549
+ nk_size_t const tile_dimension = svcntsw(); // 16 for SVL=512
550
+ nk_size_t const vector_elements = svcntsb(); // 64 for SVL=512
551
+ nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
552
+
553
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(n, tile_dimension);
554
+ nk_size_t const depth_step_count = nk_size_divide_round_up_(k, expansion);
555
+ nk_size_t const total_vectors = column_tile_count * depth_step_count;
556
+ nk_size_t const original_stride = nk_size_round_up_to_multiple_(k * sizeof(nk_f32_t), 64);
557
+
558
+ // Set up header
559
+ nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
560
+ header->column_tile_count = (nk_u32_t)column_tile_count;
561
+ header->depth_tile_count = (nk_u32_t)depth_step_count;
562
+ header->columns = (nk_u32_t)n;
563
+ header->depth = (nk_u32_t)k;
564
+ header->svl_bytes = (nk_u32_t)(svcntsw() * sizeof(nk_f32_t));
565
+
566
+ nk_size_t const tiles_size = total_vectors * vector_elements;
567
+ nk_size_t const norms_offset = sizeof(nk_maxsim_sme_packed_header_t) + tiles_size;
568
+ nk_size_t const originals_offset = norms_offset + n * sizeof(nk_f32_t);
569
+
570
+ header->norms_offset = (nk_u32_t)norms_offset;
571
+ header->originals_offset = (nk_u32_t)originals_offset;
572
+ header->original_stride = (nk_u32_t)original_stride;
573
+ for (nk_size_t i = 0; i < 8; i++) header->reserved[i] = 0;
574
+
575
+ nk_i8_t *tiles = (nk_i8_t *)((char *)packed + sizeof(nk_maxsim_sme_packed_header_t));
576
+ nk_f32_t *norms = (nk_f32_t *)((char *)packed + norms_offset);
577
+ char *originals = (char *)packed + originals_offset;
578
+
579
+ // Zero-initialize tile data (partial vectors stay zero-padded)
580
+ for (nk_size_t i = 0; i < tiles_size; i++) tiles[i] = 0;
581
+
582
+ // For each vector: quantize metadata, quantize+interleave into tiles, copy originals
583
+ for (nk_size_t vector_index = 0; vector_index < n; vector_index++) {
584
+ nk_f32_t const *source = (nk_f32_t const *)((char const *)vectors + vector_index * stride);
585
+
586
+ // Pass 1: Compute absmax and norm_sq simultaneously
587
+ nk_f32_t absmax = 0.0f;
588
+ nk_f32_t norm_sq = 0.0f;
589
+ for (nk_size_t dim = 0; dim < k; dim++) {
590
+ nk_f32_t val = source[dim];
591
+ nk_f32_t abs_val = nk_f32_abs_(val);
592
+ if (abs_val > absmax) absmax = abs_val;
593
+ norm_sq += val * val;
594
+ }
595
+ norms[vector_index] = norm_sq;
596
+
597
+ nk_f32_t scale = absmax / 127.0f;
598
+ if (scale == 0.0f) scale = 1.0f;
599
+
600
+ // Pass 2: Quantize and scatter into tile-interleaved positions
601
+ nk_size_t const column_tile = vector_index / tile_dimension;
602
+ nk_size_t const column_in_tile = vector_index % tile_dimension;
603
+
604
+ for (nk_size_t dim = 0; dim < k; dim++) {
605
+ nk_size_t const depth_step = dim / expansion;
606
+ nk_size_t const sub_element = dim % expansion;
607
+ nk_size_t const vec_index = column_tile * depth_step_count + depth_step;
608
+ nk_size_t const offset = vec_index * vector_elements + expansion * column_in_tile + sub_element;
609
+
610
+ nk_f32_t scaled = source[dim] / scale;
611
+ nk_i32_t quantized;
612
+ if (scaled >= 0.0f) quantized = (nk_i32_t)(scaled + 0.5f);
613
+ else quantized = (nk_i32_t)(scaled - 0.5f);
614
+ if (quantized > 127) quantized = 127;
615
+ if (quantized < -127) quantized = -127;
616
+
617
+ tiles[offset] = (nk_i8_t)quantized;
618
+ }
619
+
620
+ // Pass 3: Copy originals (64B-aligned stride, zero-pad tail)
621
+ char *dest_original = originals + vector_index * original_stride;
622
+ nk_copy_bytes_(dest_original, source, k * sizeof(nk_f32_t));
623
+ for (nk_size_t byte = k * sizeof(nk_f32_t); byte < original_stride; byte++) dest_original[byte] = 0;
624
+ }
625
+ }
626
+
627
+ /**
628
+ * Streaming-compatible f32 dot product with f64 accumulation.
629
+ * Follows the svcntd()-stride + svcvt_f64_f32_x pattern from nk_dots_reduce_sumsq_f32_ssve_.
630
+ */
631
+ NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
632
+ nk_f32_t const *a, nk_f32_t const *b, nk_size_t count) NK_STREAMING_COMPATIBLE_ { //
633
+ svfloat64_t accumulator_f64x = svdup_f64(0.0);
634
+ for (nk_size_t i = 0; i < count; i += svcntd()) {
635
+ svbool_t predicate_f64x = svwhilelt_b64_u64(i, count);
636
+ svfloat64_t a_f64x = svcvt_f64_f32_x(predicate_f64x, svld1_f32(svwhilelt_b32_u64(i, count), a + i));
637
+ svfloat64_t b_f64x = svcvt_f64_f32_x(predicate_f64x, svld1_f32(svwhilelt_b32_u64(i, count), b + i));
638
+ accumulator_f64x = svmla_f64_x(predicate_f64x, accumulator_f64x, a_f64x, b_f64x);
639
+ }
640
+ return svaddv_f64(svptrue_b64(), accumulator_f64x);
641
+ }
642
+
643
+ /**
644
+ * MaxSim f32 kernel: i8 SMOPA screening + f32/f64 refinement + angular distance.
645
+ *
646
+ * Screening: i8 SMOPA has expansion=4, processing 4x more depth per instruction than f32 FMOPA.
647
+ * With 4 ZA tiles the fast path processes 64 document columns per iteration.
648
+ *
649
+ * Refinement: tile-wide interleaved f64 dot products for the winning (query, document) pairs.
650
+ * Angular distance: 1 - dot / sqrt(||q||^2 * ||d||^2), accumulated with f64.
651
+ */
652
+ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streaming_( //
653
+ void const *query_packed, void const *document_packed, //
654
+ nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
655
+ nk_f64_t *result) {
656
+
657
+ nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
658
+ nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
659
+
660
+ nk_size_t const depth_step_count = query_header->depth_tile_count;
661
+ nk_size_t const query_row_tiles = query_header->column_tile_count;
662
+ nk_size_t const document_col_tiles = document_header->column_tile_count;
663
+
664
+ nk_size_t const tile_dimension = svcntw(); // 16: ZA32 tile dimension
665
+ nk_size_t const vector_elements = svcntb(); // 64: i8 elements per SVE vector
666
+
667
+ // Tile data pointers (i8)
668
+ nk_i8_t const *query_tiles = (nk_i8_t const *)((char const *)query_packed + sizeof(nk_maxsim_sme_packed_header_t));
669
+ nk_i8_t const *document_tiles = (nk_i8_t const *)((char const *)document_packed +
670
+ sizeof(nk_maxsim_sme_packed_header_t));
671
+
672
+ // Norms and originals pointers
673
+ nk_f32_t const *query_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
674
+ nk_f32_t const *document_norms = (nk_f32_t const *)((char const *)document_packed + document_header->norms_offset);
675
+ nk_f32_t const *query_originals = (nk_f32_t const *)((char const *)query_packed + query_header->originals_offset);
676
+ nk_f32_t const *document_originals = (nk_f32_t const *)((char const *)document_packed +
677
+ document_header->originals_offset);
678
+ nk_size_t const query_original_stride_elements = query_header->original_stride / sizeof(nk_f32_t);
679
+ nk_size_t const document_original_stride_elements = document_header->original_stride / sizeof(nk_f32_t);
680
+
681
+ nk_size_t const expansion = 4; // i8->i32 SMOPA
682
+
683
+ svbool_t const predicate_all_i8x = svptrue_b8();
684
+ svbool_t const predicate_all_f32x = svptrue_b32();
685
+
686
+ nk_f64_t total_angular_distance_f64 = 0.0;
687
+
688
+ for (nk_size_t row_tile_index = 0; row_tile_index < query_row_tiles; row_tile_index++) {
689
+ nk_size_t const row_start = row_tile_index * tile_dimension;
690
+ nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
691
+ : (query_count - row_start);
692
+ svbool_t const row_predicate_i8x = (rows_remaining == tile_dimension)
693
+ ? svptrue_b8()
694
+ : svwhilelt_b8_u64(0u, rows_remaining * expansion);
695
+ svbool_t const row_predicate_f32x = (rows_remaining == tile_dimension) ? svptrue_b32()
696
+ : svwhilelt_b32_u64(0u, rows_remaining);
697
+
698
+ svint32_t running_max_i32x = svdup_s32(NK_I32_MIN);
699
+ svuint32_t running_argmax_u32x = svdup_u32(0);
700
+
701
+ nk_size_t column_tile_index = 0;
702
+
703
+ // 4-tile fast path: ZA0-ZA3 process 4 document column tiles simultaneously
704
+ for (; column_tile_index + 4 <= document_col_tiles; column_tile_index += 4) {
705
+ svzero_za();
706
+
707
+ for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
708
+ svint8_t query_packed_i8x = svld1_s8(
709
+ row_predicate_i8x,
710
+ (int8_t const *)(query_tiles + (row_tile_index * depth_step_count + depth_step) * vector_elements));
711
+ svint8_t document_packed_0_i8x = svld1_s8(
712
+ predicate_all_i8x,
713
+ (int8_t const *)(document_tiles +
714
+ ((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
715
+ svint8_t document_packed_1_i8x = svld1_s8(
716
+ predicate_all_i8x,
717
+ (int8_t const *)(document_tiles +
718
+ ((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
719
+ svint8_t document_packed_2_i8x = svld1_s8(
720
+ predicate_all_i8x,
721
+ (int8_t const *)(document_tiles +
722
+ ((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
723
+ svint8_t document_packed_3_i8x = svld1_s8(
724
+ predicate_all_i8x,
725
+ (int8_t const *)(document_tiles +
726
+ ((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
727
+ svmopa_za32_s8_m(0, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_0_i8x);
728
+ svmopa_za32_s8_m(1, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_1_i8x);
729
+ svmopa_za32_s8_m(2, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_2_i8x);
730
+ svmopa_za32_s8_m(3, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_3_i8x);
731
+ }
732
+
733
+ // Vertical column extraction + argmax update (manually unrolled over 4 tiles)
734
+ for (nk_size_t column_within_tile = 0; column_within_tile < tile_dimension; column_within_tile++) {
735
+ // Tile 0
736
+ {
737
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
738
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 0,
739
+ column_within_tile);
740
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
741
+ running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
742
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
743
+ }
744
+ // Tile 1
745
+ {
746
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
747
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 1,
748
+ column_within_tile);
749
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
750
+ running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
751
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
752
+ }
753
+ // Tile 2
754
+ {
755
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
756
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 2,
757
+ column_within_tile);
758
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
759
+ running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
760
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
761
+ }
762
+ // Tile 3
763
+ {
764
+ nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
765
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 3,
766
+ column_within_tile);
767
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
768
+ running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
769
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
770
+ }
771
+ }
772
+ }
773
+
774
+ // 1-tile remainder: ZA0 only
775
+ for (; column_tile_index < document_col_tiles; column_tile_index++) {
776
+ nk_size_t const col_start = column_tile_index * tile_dimension;
777
+ nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
778
+ ? tile_dimension
779
+ : (document_count - col_start);
780
+ svbool_t const column_predicate_i8x = (cols_remaining == tile_dimension)
781
+ ? svptrue_b8()
782
+ : svwhilelt_b8_u64(0u, cols_remaining * expansion);
783
+
784
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
785
+
786
+ for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
787
+ svint8_t query_packed_i8x = svld1_s8(
788
+ row_predicate_i8x,
789
+ (int8_t const *)(query_tiles + (row_tile_index * depth_step_count + depth_step) * vector_elements));
790
+ svint8_t document_packed_i8x = svld1_s8(
791
+ column_predicate_i8x,
792
+ (int8_t const *)(document_tiles +
793
+ (column_tile_index * depth_step_count + depth_step) * vector_elements));
794
+ svmopa_za32_s8_m(0, row_predicate_i8x, column_predicate_i8x, query_packed_i8x, document_packed_i8x);
795
+ }
796
+
797
+ for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
798
+ nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
799
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 0,
800
+ column_within_tile);
801
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
802
+ running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
803
+ running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
804
+ }
805
+ }
806
+
807
+ // Refinement: tile-wide interleaved f64 dot products
808
+ nk_u32_t best_document_indices[64]; // max tile_dimension across all SVL values
809
+ svst1_u32(row_predicate_f32x, best_document_indices, running_argmax_u32x);
810
+
811
+ // Pointer setup: one (query, document) pair per row in the tile
812
+ nk_f32_t const *query_original_ptrs[64];
813
+ nk_f32_t const *document_original_ptrs[64];
814
+ for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
815
+ nk_size_t query_index = row_start + row_in_tile;
816
+ nk_u32_t best_document_index = best_document_indices[row_in_tile];
817
+ query_original_ptrs[row_in_tile] = query_originals + query_index * query_original_stride_elements;
818
+ document_original_ptrs[row_in_tile] = document_originals +
819
+ best_document_index * document_original_stride_elements;
820
+ }
821
+
822
+ // Interleaved f64 dot products in batches of 4 (hides MLA 4-cycle latency)
823
+ nk_size_t row_batch_start = 0;
824
+
825
+ // Fast path: 4-wide batches
826
+ for (; row_batch_start + 4 <= rows_remaining; row_batch_start += 4) {
827
+ svfloat64_t accumulator_0_f64x = svdup_f64(0.0);
828
+ svfloat64_t accumulator_1_f64x = svdup_f64(0.0);
829
+ svfloat64_t accumulator_2_f64x = svdup_f64(0.0);
830
+ svfloat64_t accumulator_3_f64x = svdup_f64(0.0);
831
+
832
+ for (nk_size_t depth_index = 0; depth_index < depth; depth_index += svcntd()) {
833
+ svbool_t predicate_depth_f64x = svwhilelt_b64_u64(depth_index, depth);
834
+ svbool_t predicate_depth_f32x = svwhilelt_b32_u64(depth_index, depth);
835
+
836
+ svfloat64_t query_values_0_f64x = svcvt_f64_f32_x(
837
+ predicate_depth_f64x,
838
+ svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 0] + depth_index));
839
+ svfloat64_t document_values_0_f64x = svcvt_f64_f32_x(
840
+ predicate_depth_f64x,
841
+ svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 0] + depth_index));
842
+ accumulator_0_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_0_f64x, query_values_0_f64x,
843
+ document_values_0_f64x);
844
+
845
+ svfloat64_t query_values_1_f64x = svcvt_f64_f32_x(
846
+ predicate_depth_f64x,
847
+ svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 1] + depth_index));
848
+ svfloat64_t document_values_1_f64x = svcvt_f64_f32_x(
849
+ predicate_depth_f64x,
850
+ svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 1] + depth_index));
851
+ accumulator_1_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_1_f64x, query_values_1_f64x,
852
+ document_values_1_f64x);
853
+
854
+ svfloat64_t query_values_2_f64x = svcvt_f64_f32_x(
855
+ predicate_depth_f64x,
856
+ svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 2] + depth_index));
857
+ svfloat64_t document_values_2_f64x = svcvt_f64_f32_x(
858
+ predicate_depth_f64x,
859
+ svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 2] + depth_index));
860
+ accumulator_2_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_2_f64x, query_values_2_f64x,
861
+ document_values_2_f64x);
862
+
863
+ svfloat64_t query_values_3_f64x = svcvt_f64_f32_x(
864
+ predicate_depth_f64x,
865
+ svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 3] + depth_index));
866
+ svfloat64_t document_values_3_f64x = svcvt_f64_f32_x(
867
+ predicate_depth_f64x,
868
+ svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 3] + depth_index));
869
+ accumulator_3_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_3_f64x, query_values_3_f64x,
870
+ document_values_3_f64x);
871
+ }
872
+
873
+ // Reduce accumulators and compute angular distance per row
874
+ svfloat64_t *batch_accumulators[] = {&accumulator_0_f64x, &accumulator_1_f64x, &accumulator_2_f64x,
875
+ &accumulator_3_f64x};
876
+ for (nk_size_t batch_index = 0; batch_index < 4; batch_index++) {
877
+ nk_size_t query_index = row_start + row_batch_start + batch_index;
878
+ nk_u32_t best_document_index = best_document_indices[row_batch_start + batch_index];
879
+ nk_f64_t dot_product_f64 = svaddv_f64(svptrue_b64(), *batch_accumulators[batch_index]);
880
+ nk_f64_t norm_product_f64 = (nk_f64_t)query_norms[query_index] *
881
+ (nk_f64_t)document_norms[best_document_index];
882
+ nk_f64_t cosine_f64 = (norm_product_f64 > 0.0) ? dot_product_f64 * nk_f64_rsqrt_serial(norm_product_f64)
883
+ : 0.0;
884
+ nk_f64_t angular_distance_f64 = 1.0 - cosine_f64;
885
+ if (angular_distance_f64 < 0.0) angular_distance_f64 = 0.0;
886
+ total_angular_distance_f64 += angular_distance_f64;
887
+ }
888
+ }
889
+
890
+ // Remainder: 1 row at a time
891
+ for (; row_batch_start < rows_remaining; row_batch_start++) {
892
+ nk_size_t query_index = row_start + row_batch_start;
893
+ nk_u32_t best_document_index = best_document_indices[row_batch_start];
894
+ nk_f64_t dot_product_f64 = nk_maxsim_reduce_dot_f32_ssve_(query_original_ptrs[row_batch_start],
895
+ document_original_ptrs[row_batch_start], depth);
896
+ nk_f64_t norm_product_f64 = (nk_f64_t)query_norms[query_index] *
897
+ (nk_f64_t)document_norms[best_document_index];
898
+ nk_f64_t cosine_f64 = (norm_product_f64 > 0.0) ? dot_product_f64 * nk_f64_rsqrt_serial(norm_product_f64)
899
+ : 0.0;
900
+ nk_f64_t angular_distance_f64 = 1.0 - cosine_f64;
901
+ if (angular_distance_f64 < 0.0) angular_distance_f64 = 0.0;
902
+ total_angular_distance_f64 += angular_distance_f64;
903
+ }
904
+ }
905
+
906
+ *result = total_angular_distance_f64;
907
+ }
908
+
909
+ NK_PUBLIC void nk_maxsim_packed_f32_sme( //
910
+ void const *query_packed, void const *document_packed, //
911
+ nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
912
+ nk_f64_t *result) { //
913
+
914
+ nk_maxsim_packed_f32_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
915
+ }
916
+
917
+ #if defined(__clang__)
918
+ #pragma clang attribute pop
919
+ #elif defined(__GNUC__)
920
+ #pragma GCC pop_options
921
+ #endif
922
+
923
+ #if defined(__cplusplus)
924
+ } // extern "C"
925
+ #endif
926
+
927
+ #endif // NK_TARGET_SME
928
+ #endif // NK_TARGET_ARM_
929
+ #endif // NK_MAXSIM_SME_H