numkong 7.0.0 → 7.4.1

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -53,7 +53,7 @@ extern "C" {
53
53
  #endif
54
54
 
55
55
  #if defined(__clang__)
56
- #pragma clang attribute push(__attribute__((target("sme,sve"))), apply_to = function)
56
+ #pragma clang attribute push(__attribute__((target("sme"))), apply_to = function)
57
57
  #elif defined(__GNUC__)
58
58
  #pragma GCC push_options
59
59
  #pragma GCC target("+sme")
@@ -112,8 +112,8 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
112
112
  nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
113
113
  document_header->norms_offset);
114
114
 
115
- svbool_t const predicate_all_f16x = svptrue_b16();
116
- svbool_t const predicate_all_f32x = svptrue_b32();
115
+ svbool_t const predicate_all_b16x = svptrue_b16();
116
+ svbool_t const predicate_all_b32x = svptrue_b32();
117
117
 
118
118
  nk_f32_t total_angular_distance = 0.0f;
119
119
 
@@ -121,10 +121,10 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
121
121
  nk_size_t const row_start = row_tile_index * tile_dimension;
122
122
  nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
123
123
  : (query_count - row_start);
124
- svbool_t const row_predicate_f16x = (rows_remaining == tile_dimension)
124
+ svbool_t const row_predicate_b16x = (rows_remaining == tile_dimension)
125
125
  ? svptrue_b16()
126
126
  : svwhilelt_b16_u64(0u, rows_remaining * 2);
127
- svbool_t const row_predicate_f32x = (rows_remaining == tile_dimension) ? svptrue_b32()
127
+ svbool_t const row_predicate_b32x = (rows_remaining == tile_dimension) ? svptrue_b32()
128
128
  : svwhilelt_b32_u64(0u, rows_remaining);
129
129
 
130
130
  // Running max + argmax vectors for angular distance finalization
@@ -140,29 +140,29 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
140
140
  // Accumulate: for each depth step, load Q vector and 4 D vectors, issue 4 FMOPAs
141
141
  for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
142
142
  svfloat16_t query_packed_f16x = svld1_f16(
143
- row_predicate_f16x,
143
+ row_predicate_b16x,
144
144
  (float16_t const *)(query_vecs +
145
145
  (row_tile_index * depth_step_count + depth_step) * vector_elements));
146
146
  svfloat16_t document_packed_0_f16x = svld1_f16(
147
- predicate_all_f16x,
147
+ predicate_all_b16x,
148
148
  (float16_t const *)(document_vecs +
149
149
  ((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
150
150
  svfloat16_t document_packed_1_f16x = svld1_f16(
151
- predicate_all_f16x,
151
+ predicate_all_b16x,
152
152
  (float16_t const *)(document_vecs +
153
153
  ((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
154
154
  svfloat16_t document_packed_2_f16x = svld1_f16(
155
- predicate_all_f16x,
155
+ predicate_all_b16x,
156
156
  (float16_t const *)(document_vecs +
157
157
  ((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
158
158
  svfloat16_t document_packed_3_f16x = svld1_f16(
159
- predicate_all_f16x,
159
+ predicate_all_b16x,
160
160
  (float16_t const *)(document_vecs +
161
161
  ((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
162
- svmopa_za32_f16_m(0, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_0_f16x);
163
- svmopa_za32_f16_m(1, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_1_f16x);
164
- svmopa_za32_f16_m(2, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_2_f16x);
165
- svmopa_za32_f16_m(3, row_predicate_f16x, predicate_all_f16x, query_packed_f16x, document_packed_3_f16x);
162
+ svmopa_za32_f16_m(0, row_predicate_b16x, predicate_all_b16x, query_packed_f16x, document_packed_0_f16x);
163
+ svmopa_za32_f16_m(1, row_predicate_b16x, predicate_all_b16x, query_packed_f16x, document_packed_1_f16x);
164
+ svmopa_za32_f16_m(2, row_predicate_b16x, predicate_all_b16x, query_packed_f16x, document_packed_2_f16x);
165
+ svmopa_za32_f16_m(3, row_predicate_b16x, predicate_all_b16x, query_packed_f16x, document_packed_3_f16x);
166
166
  }
167
167
 
168
168
  // Vertical column extraction + argmax update (manually unrolled over 4 tiles)
@@ -170,36 +170,36 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
170
170
  // Tile 0
171
171
  {
172
172
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
173
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
173
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 0,
174
174
  column_within_tile);
175
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
175
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
176
176
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
177
177
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
178
178
  }
179
179
  // Tile 1
180
180
  {
181
181
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
182
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 1,
182
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 1,
183
183
  column_within_tile);
184
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
184
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
185
185
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
186
186
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
187
187
  }
188
188
  // Tile 2
189
189
  {
190
190
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
191
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 2,
191
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 2,
192
192
  column_within_tile);
193
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
193
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
194
194
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
195
195
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
196
196
  }
197
197
  // Tile 3
198
198
  {
199
199
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
200
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 3,
200
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 3,
201
201
  column_within_tile);
202
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
202
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
203
203
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
204
204
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
205
205
  }
@@ -212,7 +212,7 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
212
212
  nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
213
213
  ? tile_dimension
214
214
  : (document_count - col_start);
215
- svbool_t const column_predicate_f16x = (cols_remaining == tile_dimension)
215
+ svbool_t const column_predicate_b16x = (cols_remaining == tile_dimension)
216
216
  ? svptrue_b16()
217
217
  : svwhilelt_b16_u64(0u, cols_remaining * 2);
218
218
 
@@ -220,23 +220,23 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
220
220
 
221
221
  for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
222
222
  svfloat16_t query_packed_f16x = svld1_f16(
223
- row_predicate_f16x,
223
+ row_predicate_b16x,
224
224
  (float16_t const *)(query_vecs +
225
225
  (row_tile_index * depth_step_count + depth_step) * vector_elements));
226
226
  svfloat16_t document_packed_f16x = svld1_f16(
227
- column_predicate_f16x,
227
+ column_predicate_b16x,
228
228
  (float16_t const *)(document_vecs +
229
229
  (column_tile_index * depth_step_count + depth_step) * vector_elements));
230
- svmopa_za32_f16_m(0, row_predicate_f16x, column_predicate_f16x, query_packed_f16x,
230
+ svmopa_za32_f16_m(0, row_predicate_b16x, column_predicate_b16x, query_packed_f16x,
231
231
  document_packed_f16x);
232
232
  }
233
233
 
234
234
  // Vertical column extraction from ZA0 + argmax update
235
235
  for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
236
236
  nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
237
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
237
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 0,
238
238
  column_within_tile);
239
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
239
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
240
240
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
241
241
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
242
242
  }
@@ -246,19 +246,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
246
246
  // Gather document inverse norms via argmax indices (no SVE gather in streaming mode)
247
247
  nk_u32_t best_document_indices[64];
248
248
  nk_f32_t document_inverse_norms_gathered[64];
249
- svst1_u32(row_predicate_f32x, best_document_indices, running_argmax_u32x);
249
+ svst1_u32(row_predicate_b32x, best_document_indices, running_argmax_u32x);
250
250
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++)
251
251
  document_inverse_norms_gathered[row_in_tile] = document_inverse_norms[best_document_indices[row_in_tile]];
252
252
 
253
253
  // SVE-width: cosine = dot * inv_norm_q * inv_norm_d, angular = max(1 - cosine, 0)
254
- svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_f32x, query_inverse_norms + row_start);
255
- svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_f32x, document_inverse_norms_gathered);
254
+ svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_b32x, query_inverse_norms + row_start);
255
+ svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_b32x, document_inverse_norms_gathered);
256
256
  svfloat32_t cosine_f32x = svmul_f32_x(
257
- row_predicate_f32x, svmul_f32_x(row_predicate_f32x, running_maximum_f32x, query_inverse_norms_f32x),
257
+ row_predicate_b32x, svmul_f32_x(row_predicate_b32x, running_maximum_f32x, query_inverse_norms_f32x),
258
258
  document_inverse_norms_f32x);
259
259
  svfloat32_t angular_distance_f32x = svmax_f32_x(
260
- row_predicate_f32x, svsub_f32_x(row_predicate_f32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
261
- total_angular_distance += svaddv_f32(row_predicate_f32x, angular_distance_f32x);
260
+ row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
261
+ total_angular_distance += svaddv_f32(row_predicate_b32x, angular_distance_f32x);
262
262
  }
263
263
 
264
264
  *result = total_angular_distance;
@@ -304,8 +304,8 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
304
304
  nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
305
305
  document_header->norms_offset);
306
306
 
307
- svbool_t const predicate_all_f16x = svptrue_b16();
308
- svbool_t const predicate_all_f32x = svptrue_b32();
307
+ svbool_t const predicate_all_b16x = svptrue_b16();
308
+ svbool_t const predicate_all_b32x = svptrue_b32();
309
309
 
310
310
  nk_f32_t total_angular_distance = 0.0f;
311
311
 
@@ -313,10 +313,10 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
313
313
  nk_size_t const row_start = row_tile_index * tile_dimension;
314
314
  nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
315
315
  : (query_count - row_start);
316
- svbool_t const row_predicate_f16x = (rows_remaining == tile_dimension)
316
+ svbool_t const row_predicate_b16x = (rows_remaining == tile_dimension)
317
317
  ? svptrue_b16()
318
318
  : svwhilelt_b16_u64(0u, rows_remaining * 2);
319
- svbool_t const row_predicate_f32x = (rows_remaining == tile_dimension) ? svptrue_b32()
319
+ svbool_t const row_predicate_b32x = (rows_remaining == tile_dimension) ? svptrue_b32()
320
320
  : svwhilelt_b32_u64(0u, rows_remaining);
321
321
 
322
322
  // Running max + argmax vectors for angular distance finalization
@@ -332,32 +332,32 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
332
332
  // Accumulate: for each depth step, load Q vector and 4 D vectors, issue 4 BFMOPAs
333
333
  for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
334
334
  svbfloat16_t query_packed_bf16x = svld1_bf16(
335
- row_predicate_f16x,
335
+ row_predicate_b16x,
336
336
  (bfloat16_t const *)(query_vecs +
337
337
  (row_tile_index * depth_step_count + depth_step) * vector_elements));
338
338
  svbfloat16_t document_packed_0_bf16x = svld1_bf16(
339
- predicate_all_f16x,
339
+ predicate_all_b16x,
340
340
  (bfloat16_t const *)(document_vecs +
341
341
  ((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
342
342
  svbfloat16_t document_packed_1_bf16x = svld1_bf16(
343
- predicate_all_f16x,
343
+ predicate_all_b16x,
344
344
  (bfloat16_t const *)(document_vecs +
345
345
  ((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
346
346
  svbfloat16_t document_packed_2_bf16x = svld1_bf16(
347
- predicate_all_f16x,
347
+ predicate_all_b16x,
348
348
  (bfloat16_t const *)(document_vecs +
349
349
  ((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
350
350
  svbfloat16_t document_packed_3_bf16x = svld1_bf16(
351
- predicate_all_f16x,
351
+ predicate_all_b16x,
352
352
  (bfloat16_t const *)(document_vecs +
353
353
  ((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
354
- svmopa_za32_bf16_m(0, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
354
+ svmopa_za32_bf16_m(0, row_predicate_b16x, predicate_all_b16x, query_packed_bf16x,
355
355
  document_packed_0_bf16x);
356
- svmopa_za32_bf16_m(1, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
356
+ svmopa_za32_bf16_m(1, row_predicate_b16x, predicate_all_b16x, query_packed_bf16x,
357
357
  document_packed_1_bf16x);
358
- svmopa_za32_bf16_m(2, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
358
+ svmopa_za32_bf16_m(2, row_predicate_b16x, predicate_all_b16x, query_packed_bf16x,
359
359
  document_packed_2_bf16x);
360
- svmopa_za32_bf16_m(3, row_predicate_f16x, predicate_all_f16x, query_packed_bf16x,
360
+ svmopa_za32_bf16_m(3, row_predicate_b16x, predicate_all_b16x, query_packed_bf16x,
361
361
  document_packed_3_bf16x);
362
362
  }
363
363
 
@@ -366,36 +366,36 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
366
366
  // Tile 0
367
367
  {
368
368
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
369
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
369
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 0,
370
370
  column_within_tile);
371
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
371
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
372
372
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
373
373
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
374
374
  }
375
375
  // Tile 1
376
376
  {
377
377
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
378
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 1,
378
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 1,
379
379
  column_within_tile);
380
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
380
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
381
381
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
382
382
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
383
383
  }
384
384
  // Tile 2
385
385
  {
386
386
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
387
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 2,
387
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 2,
388
388
  column_within_tile);
389
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
389
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
390
390
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
391
391
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
392
392
  }
393
393
  // Tile 3
394
394
  {
395
395
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
396
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 3,
396
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 3,
397
397
  column_within_tile);
398
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
398
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
399
399
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
400
400
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
401
401
  }
@@ -408,7 +408,7 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
408
408
  nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
409
409
  ? tile_dimension
410
410
  : (document_count - col_start);
411
- svbool_t const column_predicate_f16x = (cols_remaining == tile_dimension)
411
+ svbool_t const column_predicate_b16x = (cols_remaining == tile_dimension)
412
412
  ? svptrue_b16()
413
413
  : svwhilelt_b16_u64(0u, cols_remaining * 2);
414
414
 
@@ -416,23 +416,23 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
416
416
 
417
417
  for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
418
418
  svbfloat16_t query_packed_bf16x = svld1_bf16(
419
- row_predicate_f16x,
419
+ row_predicate_b16x,
420
420
  (bfloat16_t const *)(query_vecs +
421
421
  (row_tile_index * depth_step_count + depth_step) * vector_elements));
422
422
  svbfloat16_t document_packed_bf16x = svld1_bf16(
423
- column_predicate_f16x,
423
+ column_predicate_b16x,
424
424
  (bfloat16_t const *)(document_vecs +
425
425
  (column_tile_index * depth_step_count + depth_step) * vector_elements));
426
- svmopa_za32_bf16_m(0, row_predicate_f16x, column_predicate_f16x, query_packed_bf16x,
426
+ svmopa_za32_bf16_m(0, row_predicate_b16x, column_predicate_b16x, query_packed_bf16x,
427
427
  document_packed_bf16x);
428
428
  }
429
429
 
430
430
  // Vertical column extraction from ZA0 + argmax update
431
431
  for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
432
432
  nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
433
- svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_f32x, 0,
433
+ svfloat32_t column_dots_f32x = svread_ver_za32_f32_m(svdup_f32(NK_F32_MIN), predicate_all_b32x, 0,
434
434
  column_within_tile);
435
- svbool_t is_better_bx = svcmpgt_f32(predicate_all_f32x, column_dots_f32x, running_maximum_f32x);
435
+ svbool_t is_better_bx = svcmpgt_f32(predicate_all_b32x, column_dots_f32x, running_maximum_f32x);
436
436
  running_maximum_f32x = svsel_f32(is_better_bx, column_dots_f32x, running_maximum_f32x);
437
437
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
438
438
  }
@@ -442,19 +442,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
442
442
  // Gather document inverse norms via argmax indices (no SVE gather in streaming mode)
443
443
  nk_u32_t best_document_indices[64];
444
444
  nk_f32_t document_inverse_norms_gathered[64];
445
- svst1_u32(row_predicate_f32x, best_document_indices, running_argmax_u32x);
445
+ svst1_u32(row_predicate_b32x, best_document_indices, running_argmax_u32x);
446
446
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++)
447
447
  document_inverse_norms_gathered[row_in_tile] = document_inverse_norms[best_document_indices[row_in_tile]];
448
448
 
449
449
  // SVE-width: cosine = dot * inv_norm_q * inv_norm_d, angular = max(1 - cosine, 0)
450
- svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_f32x, query_inverse_norms + row_start);
451
- svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_f32x, document_inverse_norms_gathered);
450
+ svfloat32_t query_inverse_norms_f32x = svld1_f32(row_predicate_b32x, query_inverse_norms + row_start);
451
+ svfloat32_t document_inverse_norms_f32x = svld1_f32(row_predicate_b32x, document_inverse_norms_gathered);
452
452
  svfloat32_t cosine_f32x = svmul_f32_x(
453
- row_predicate_f32x, svmul_f32_x(row_predicate_f32x, running_maximum_f32x, query_inverse_norms_f32x),
453
+ row_predicate_b32x, svmul_f32_x(row_predicate_b32x, running_maximum_f32x, query_inverse_norms_f32x),
454
454
  document_inverse_norms_f32x);
455
455
  svfloat32_t angular_distance_f32x = svmax_f32_x(
456
- row_predicate_f32x, svsub_f32_x(row_predicate_f32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
457
- total_angular_distance += svaddv_f32(row_predicate_f32x, angular_distance_f32x);
456
+ row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
457
+ total_angular_distance += svaddv_f32(row_predicate_b32x, angular_distance_f32x);
458
458
  }
459
459
 
460
460
  *result = total_angular_distance;
@@ -468,20 +468,20 @@ NK_PUBLIC void nk_maxsim_packed_bf16_sme( //
468
468
  nk_maxsim_packed_bf16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
469
469
  }
470
470
 
471
- NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sme(nk_size_t n, nk_size_t k) { //
472
- return nk_dots_packed_size_bf16_sme(n, k);
471
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sme(nk_size_t columns, nk_size_t depth) { //
472
+ return nk_dots_packed_size_bf16_sme(columns, depth);
473
473
  }
474
474
 
475
- NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sme(nk_size_t n, nk_size_t k) { //
476
- return nk_dots_packed_size_f16_sme(n, k);
475
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sme(nk_size_t columns, nk_size_t depth) { //
476
+ return nk_dots_packed_size_f16_sme(columns, depth);
477
477
  }
478
478
 
479
- NK_PUBLIC void nk_maxsim_pack_bf16_sme( //
480
- nk_bf16_t const *vectors, nk_size_t n, nk_size_t k, nk_size_t stride, void *packed) { //
479
+ NK_PUBLIC void nk_maxsim_pack_bf16_sme( //
480
+ nk_bf16_t const *vectors, nk_size_t columns, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) { //
481
481
 
482
482
  // Delegate tile interleaving and squared norms computation to dots pack.
483
483
  // Both headers are 64 bytes with identical layout for the first 6 fields.
484
- nk_dots_pack_bf16_sme(vectors, n, k, stride, packed);
484
+ nk_dots_pack_bf16_sme(vectors, columns, depth, stride_in_bytes, packed);
485
485
 
486
486
  // Set maxsim-specific header fields (overlaps dots reserved area)
487
487
  nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
@@ -491,18 +491,18 @@ NK_PUBLIC void nk_maxsim_pack_bf16_sme(
491
491
 
492
492
  // Convert squared norms → inverse norms in-place
493
493
  nk_f32_t *norms = (nk_f32_t *)((char *)packed + header->norms_offset);
494
- for (nk_size_t i = 0; i < n; i++) {
494
+ for (nk_size_t i = 0; i < columns; i++) {
495
495
  nk_f32_t norm_sq = norms[i];
496
496
  norms[i] = (norm_sq > 0.0f) ? (nk_f32_t)nk_f64_rsqrt_neon((nk_f64_t)norm_sq) : 0.0f;
497
497
  }
498
498
  }
499
499
 
500
- NK_PUBLIC void nk_maxsim_pack_f16_sme( //
501
- nk_f16_t const *vectors, nk_size_t n, nk_size_t k, nk_size_t stride, void *packed) { //
500
+ NK_PUBLIC void nk_maxsim_pack_f16_sme( //
501
+ nk_f16_t const *vectors, nk_size_t columns, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) { //
502
502
 
503
503
  // Delegate tile interleaving and squared norms computation to dots pack.
504
504
  // Both headers are 64 bytes with identical layout for the first 6 fields.
505
- nk_dots_pack_f16_sme(vectors, n, k, stride, packed);
505
+ nk_dots_pack_f16_sme(vectors, columns, depth, stride_in_bytes, packed);
506
506
 
507
507
  // Set maxsim-specific header fields (overlaps dots reserved area)
508
508
  nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
@@ -512,7 +512,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_sme(
512
512
 
513
513
  // Convert squared norms → inverse norms in-place
514
514
  nk_f32_t *norms = (nk_f32_t *)((char *)packed + header->norms_offset);
515
- for (nk_size_t i = 0; i < n; i++) {
515
+ for (nk_size_t i = 0; i < columns; i++) {
516
516
  nk_f32_t norm_sq = norms[i];
517
517
  norms[i] = (norm_sq > 0.0f) ? (nk_f32_t)nk_f64_rsqrt_neon((nk_f64_t)norm_sq) : 0.0f;
518
518
  }
@@ -527,45 +527,45 @@ NK_PUBLIC void nk_maxsim_pack_f16_sme(
527
527
  * Refinement: tile-wide interleaved f64 dot products for the winning (query, document) pairs.
528
528
  * Angular distance: 1 - dot / sqrt(||q||^2 * ||d||^2), accumulated with f64.
529
529
  */
530
- NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sme(nk_size_t n, nk_size_t k) { //
531
- nk_size_t const expansion = 4; // i8->i32 SMOPA
532
- nk_size_t const tile_dimension = svcntsw(); // 16 for SVL=512
533
- nk_size_t const vector_elements = svcntsb(); // 64 for SVL=512
534
- nk_size_t const column_tile_count = nk_size_divide_round_up_(n, tile_dimension);
535
- nk_size_t const depth_step_count = nk_size_divide_round_up_(k, expansion);
536
- nk_size_t const original_stride = nk_size_round_up_to_multiple_(k * sizeof(nk_f32_t), 64);
530
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sme(nk_size_t columns, nk_size_t depth) { //
531
+ nk_size_t const expansion = 4; // i8->i32 SMOPA
532
+ nk_size_t const tile_dimension = nk_sme_cntw_(); // 16 for SVL=512
533
+ nk_size_t const vector_elements = nk_sme_cntb_(); // 64 for SVL=512
534
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
535
+ nk_size_t const depth_step_count = nk_size_divide_round_up_(depth, expansion);
536
+ nk_size_t const original_stride = nk_size_round_up_to_multiple_(depth * sizeof(nk_f32_t), 64);
537
537
 
538
538
  nk_size_t size = sizeof(nk_maxsim_sme_packed_header_t); // 64 B header
539
539
  size += column_tile_count * depth_step_count * vector_elements; // i8 tiles
540
- size += n * sizeof(nk_f32_t); // f32 squared norms
541
- size += n * original_stride; // f32 originals
540
+ size += columns * sizeof(nk_f32_t); // f32 squared norms
541
+ size += columns * original_stride; // f32 originals
542
542
  return size;
543
543
  }
544
544
 
545
- NK_PUBLIC void nk_maxsim_pack_f32_sme( //
546
- nk_f32_t const *vectors, nk_size_t n, nk_size_t k, nk_size_t stride, void *packed) { //
545
+ NK_PUBLIC void nk_maxsim_pack_f32_sme( //
546
+ nk_f32_t const *vectors, nk_size_t columns, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) { //
547
547
 
548
- nk_size_t const expansion = 4; // i8->i32 SMOPA
549
- nk_size_t const tile_dimension = svcntsw(); // 16 for SVL=512
550
- nk_size_t const vector_elements = svcntsb(); // 64 for SVL=512
551
- nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
548
+ nk_size_t const expansion = 4; // i8->i32 SMOPA
549
+ nk_size_t const tile_dimension = nk_sme_cntw_(); // 16 for SVL=512
550
+ nk_size_t const vector_elements = nk_sme_cntb_(); // 64 for SVL=512
551
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
552
552
 
553
- nk_size_t const column_tile_count = nk_size_divide_round_up_(n, tile_dimension);
554
- nk_size_t const depth_step_count = nk_size_divide_round_up_(k, expansion);
553
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
554
+ nk_size_t const depth_step_count = nk_size_divide_round_up_(depth, expansion);
555
555
  nk_size_t const total_vectors = column_tile_count * depth_step_count;
556
- nk_size_t const original_stride = nk_size_round_up_to_multiple_(k * sizeof(nk_f32_t), 64);
556
+ nk_size_t const original_stride = nk_size_round_up_to_multiple_(depth * sizeof(nk_f32_t), 64);
557
557
 
558
558
  // Set up header
559
559
  nk_maxsim_sme_packed_header_t *header = (nk_maxsim_sme_packed_header_t *)packed;
560
560
  header->column_tile_count = (nk_u32_t)column_tile_count;
561
561
  header->depth_tile_count = (nk_u32_t)depth_step_count;
562
- header->columns = (nk_u32_t)n;
563
- header->depth = (nk_u32_t)k;
564
- header->svl_bytes = (nk_u32_t)(svcntsw() * sizeof(nk_f32_t));
562
+ header->columns = (nk_u32_t)columns;
563
+ header->depth = (nk_u32_t)depth;
564
+ header->svl_bytes = (nk_u32_t)(tile_dimension * sizeof(nk_f32_t));
565
565
 
566
566
  nk_size_t const tiles_size = total_vectors * vector_elements;
567
567
  nk_size_t const norms_offset = sizeof(nk_maxsim_sme_packed_header_t) + tiles_size;
568
- nk_size_t const originals_offset = norms_offset + n * sizeof(nk_f32_t);
568
+ nk_size_t const originals_offset = norms_offset + columns * sizeof(nk_f32_t);
569
569
 
570
570
  header->norms_offset = (nk_u32_t)norms_offset;
571
571
  header->originals_offset = (nk_u32_t)originals_offset;
@@ -580,13 +580,13 @@ NK_PUBLIC void nk_maxsim_pack_f32_sme(
580
580
  for (nk_size_t i = 0; i < tiles_size; i++) tiles[i] = 0;
581
581
 
582
582
  // For each vector: quantize metadata, quantize+interleave into tiles, copy originals
583
- for (nk_size_t vector_index = 0; vector_index < n; vector_index++) {
584
- nk_f32_t const *source = (nk_f32_t const *)((char const *)vectors + vector_index * stride);
583
+ for (nk_size_t vector_index = 0; vector_index < columns; vector_index++) {
584
+ nk_f32_t const *source = (nk_f32_t const *)((char const *)vectors + vector_index * stride_in_bytes);
585
585
 
586
586
  // Pass 1: Compute absmax and norm_sq simultaneously
587
587
  nk_f32_t absmax = 0.0f;
588
588
  nk_f32_t norm_sq = 0.0f;
589
- for (nk_size_t dim = 0; dim < k; dim++) {
589
+ for (nk_size_t dim = 0; dim < depth; dim++) {
590
590
  nk_f32_t val = source[dim];
591
591
  nk_f32_t abs_val = nk_f32_abs_(val);
592
592
  if (abs_val > absmax) absmax = abs_val;
@@ -601,7 +601,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_sme(
601
601
  nk_size_t const column_tile = vector_index / tile_dimension;
602
602
  nk_size_t const column_in_tile = vector_index % tile_dimension;
603
603
 
604
- for (nk_size_t dim = 0; dim < k; dim++) {
604
+ for (nk_size_t dim = 0; dim < depth; dim++) {
605
605
  nk_size_t const depth_step = dim / expansion;
606
606
  nk_size_t const sub_element = dim % expansion;
607
607
  nk_size_t const vec_index = column_tile * depth_step_count + depth_step;
@@ -619,8 +619,8 @@ NK_PUBLIC void nk_maxsim_pack_f32_sme(
619
619
 
620
620
  // Pass 3: Copy originals (64B-aligned stride, zero-pad tail)
621
621
  char *dest_original = originals + vector_index * original_stride;
622
- nk_copy_bytes_(dest_original, source, k * sizeof(nk_f32_t));
623
- for (nk_size_t byte = k * sizeof(nk_f32_t); byte < original_stride; byte++) dest_original[byte] = 0;
622
+ nk_copy_bytes_(dest_original, source, depth * sizeof(nk_f32_t));
623
+ for (nk_size_t byte = depth * sizeof(nk_f32_t); byte < original_stride; byte++) dest_original[byte] = 0;
624
624
  }
625
625
  }
626
626
 
@@ -628,16 +628,28 @@ NK_PUBLIC void nk_maxsim_pack_f32_sme(
628
628
  * Streaming-compatible f32 dot product with f64 accumulation.
629
629
  * Follows the svcntd()-stride + svcvt_f64_f32_x pattern from nk_dots_reduce_sumsq_f32_ssve_.
630
630
  */
631
- NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
632
- nk_f32_t const *a, nk_f32_t const *b, nk_size_t count) NK_STREAMING_COMPATIBLE_ { //
633
- svfloat64_t accumulator_f64x = svdup_f64(0.0);
634
- for (nk_size_t i = 0; i < count; i += svcntd()) {
635
- svbool_t predicate_f64x = svwhilelt_b64_u64(i, count);
636
- svfloat64_t a_f64x = svcvt_f64_f32_x(predicate_f64x, svld1_f32(svwhilelt_b32_u64(i, count), a + i));
637
- svfloat64_t b_f64x = svcvt_f64_f32_x(predicate_f64x, svld1_f32(svwhilelt_b32_u64(i, count), b + i));
638
- accumulator_f64x = svmla_f64_x(predicate_f64x, accumulator_f64x, a_f64x, b_f64x);
631
+ NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
632
+ nk_f32_t const *a, nk_f32_t const *b, nk_size_t count) NK_STREAMING_ { //
633
+ svfloat64_t accumulator_even_f64x = svdup_f64(0.0);
634
+ svfloat64_t accumulator_odd_f64x = svdup_f64(0.0);
635
+ nk_size_t const vector_length = svcntw();
636
+ nk_size_t const half_vector_length = svcntd();
637
+ for (nk_size_t i = 0; i < count; i += vector_length) {
638
+ svbool_t predicate_b32x = svwhilelt_b32_u64(i, count);
639
+ svfloat32_t a_f32x = svld1_f32(predicate_b32x, a + i);
640
+ svfloat32_t b_f32x = svld1_f32(predicate_b32x, b + i);
641
+
642
+ svbool_t predicate_even_b64x = svwhilelt_b64_u64(i, count);
643
+ svfloat64_t a_even_f64x = svcvt_f64_f32_x(predicate_even_b64x, a_f32x);
644
+ svfloat64_t b_even_f64x = svcvt_f64_f32_x(predicate_even_b64x, b_f32x);
645
+ accumulator_even_f64x = svmla_f64_m(predicate_even_b64x, accumulator_even_f64x, a_even_f64x, b_even_f64x);
646
+
647
+ svbool_t predicate_odd_b64x = svwhilelt_b64_u64(i + half_vector_length, count);
648
+ svfloat64_t a_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, a_f32x);
649
+ svfloat64_t b_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, b_f32x);
650
+ accumulator_odd_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_odd_f64x, a_odd_f64x, b_odd_f64x);
639
651
  }
640
- return svaddv_f64(svptrue_b64(), accumulator_f64x);
652
+ return svaddv_f64(svptrue_b64(), accumulator_even_f64x) + svaddv_f64(svptrue_b64(), accumulator_odd_f64x);
641
653
  }
642
654
 
643
655
  /**
@@ -680,8 +692,8 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
680
692
 
681
693
  nk_size_t const expansion = 4; // i8->i32 SMOPA
682
694
 
683
- svbool_t const predicate_all_i8x = svptrue_b8();
684
- svbool_t const predicate_all_f32x = svptrue_b32();
695
+ svbool_t const predicate_all_b8x = svptrue_b8();
696
+ svbool_t const predicate_all_b32x = svptrue_b32();
685
697
 
686
698
  nk_f64_t total_angular_distance_f64 = 0.0;
687
699
 
@@ -689,10 +701,10 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
689
701
  nk_size_t const row_start = row_tile_index * tile_dimension;
690
702
  nk_size_t const rows_remaining = (row_start + tile_dimension <= query_count) ? tile_dimension
691
703
  : (query_count - row_start);
692
- svbool_t const row_predicate_i8x = (rows_remaining == tile_dimension)
704
+ svbool_t const row_predicate_b8x = (rows_remaining == tile_dimension)
693
705
  ? svptrue_b8()
694
706
  : svwhilelt_b8_u64(0u, rows_remaining * expansion);
695
- svbool_t const row_predicate_f32x = (rows_remaining == tile_dimension) ? svptrue_b32()
707
+ svbool_t const row_predicate_b32x = (rows_remaining == tile_dimension) ? svptrue_b32()
696
708
  : svwhilelt_b32_u64(0u, rows_remaining);
697
709
 
698
710
  svint32_t running_max_i32x = svdup_s32(NK_I32_MIN);
@@ -706,28 +718,29 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
706
718
 
707
719
  for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
708
720
  svint8_t query_packed_i8x = svld1_s8(
709
- row_predicate_i8x,
710
- (int8_t const *)(query_tiles + (row_tile_index * depth_step_count + depth_step) * vector_elements));
721
+ row_predicate_b8x,
722
+ (nk_i8_t const *)(query_tiles +
723
+ (row_tile_index * depth_step_count + depth_step) * vector_elements));
711
724
  svint8_t document_packed_0_i8x = svld1_s8(
712
- predicate_all_i8x,
713
- (int8_t const *)(document_tiles +
714
- ((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
725
+ predicate_all_b8x,
726
+ (nk_i8_t const *)(document_tiles +
727
+ ((column_tile_index + 0) * depth_step_count + depth_step) * vector_elements));
715
728
  svint8_t document_packed_1_i8x = svld1_s8(
716
- predicate_all_i8x,
717
- (int8_t const *)(document_tiles +
718
- ((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
729
+ predicate_all_b8x,
730
+ (nk_i8_t const *)(document_tiles +
731
+ ((column_tile_index + 1) * depth_step_count + depth_step) * vector_elements));
719
732
  svint8_t document_packed_2_i8x = svld1_s8(
720
- predicate_all_i8x,
721
- (int8_t const *)(document_tiles +
722
- ((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
733
+ predicate_all_b8x,
734
+ (nk_i8_t const *)(document_tiles +
735
+ ((column_tile_index + 2) * depth_step_count + depth_step) * vector_elements));
723
736
  svint8_t document_packed_3_i8x = svld1_s8(
724
- predicate_all_i8x,
725
- (int8_t const *)(document_tiles +
726
- ((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
727
- svmopa_za32_s8_m(0, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_0_i8x);
728
- svmopa_za32_s8_m(1, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_1_i8x);
729
- svmopa_za32_s8_m(2, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_2_i8x);
730
- svmopa_za32_s8_m(3, row_predicate_i8x, predicate_all_i8x, query_packed_i8x, document_packed_3_i8x);
737
+ predicate_all_b8x,
738
+ (nk_i8_t const *)(document_tiles +
739
+ ((column_tile_index + 3) * depth_step_count + depth_step) * vector_elements));
740
+ svmopa_za32_s8_m(0, row_predicate_b8x, predicate_all_b8x, query_packed_i8x, document_packed_0_i8x);
741
+ svmopa_za32_s8_m(1, row_predicate_b8x, predicate_all_b8x, query_packed_i8x, document_packed_1_i8x);
742
+ svmopa_za32_s8_m(2, row_predicate_b8x, predicate_all_b8x, query_packed_i8x, document_packed_2_i8x);
743
+ svmopa_za32_s8_m(3, row_predicate_b8x, predicate_all_b8x, query_packed_i8x, document_packed_3_i8x);
731
744
  }
732
745
 
733
746
  // Vertical column extraction + argmax update (manually unrolled over 4 tiles)
@@ -735,36 +748,36 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
735
748
  // Tile 0
736
749
  {
737
750
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 0) * tile_dimension + column_within_tile);
738
- svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 0,
751
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 0,
739
752
  column_within_tile);
740
- svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
753
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
741
754
  running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
742
755
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
743
756
  }
744
757
  // Tile 1
745
758
  {
746
759
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 1) * tile_dimension + column_within_tile);
747
- svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 1,
760
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 1,
748
761
  column_within_tile);
749
- svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
762
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
750
763
  running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
751
764
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
752
765
  }
753
766
  // Tile 2
754
767
  {
755
768
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 2) * tile_dimension + column_within_tile);
756
- svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 2,
769
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 2,
757
770
  column_within_tile);
758
- svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
771
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
759
772
  running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
760
773
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
761
774
  }
762
775
  // Tile 3
763
776
  {
764
777
  nk_u32_t document_index = (nk_u32_t)((column_tile_index + 3) * tile_dimension + column_within_tile);
765
- svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 3,
778
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 3,
766
779
  column_within_tile);
767
- svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
780
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
768
781
  running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
769
782
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
770
783
  }
@@ -777,7 +790,7 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
777
790
  nk_size_t const cols_remaining = (col_start + tile_dimension <= document_count)
778
791
  ? tile_dimension
779
792
  : (document_count - col_start);
780
- svbool_t const column_predicate_i8x = (cols_remaining == tile_dimension)
793
+ svbool_t const column_predicate_b8x = (cols_remaining == tile_dimension)
781
794
  ? svptrue_b8()
782
795
  : svwhilelt_b8_u64(0u, cols_remaining * expansion);
783
796
 
@@ -785,20 +798,21 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
785
798
 
786
799
  for (nk_size_t depth_step = 0; depth_step < depth_step_count; depth_step++) {
787
800
  svint8_t query_packed_i8x = svld1_s8(
788
- row_predicate_i8x,
789
- (int8_t const *)(query_tiles + (row_tile_index * depth_step_count + depth_step) * vector_elements));
801
+ row_predicate_b8x,
802
+ (nk_i8_t const *)(query_tiles +
803
+ (row_tile_index * depth_step_count + depth_step) * vector_elements));
790
804
  svint8_t document_packed_i8x = svld1_s8(
791
- column_predicate_i8x,
792
- (int8_t const *)(document_tiles +
793
- (column_tile_index * depth_step_count + depth_step) * vector_elements));
794
- svmopa_za32_s8_m(0, row_predicate_i8x, column_predicate_i8x, query_packed_i8x, document_packed_i8x);
805
+ column_predicate_b8x,
806
+ (nk_i8_t const *)(document_tiles +
807
+ (column_tile_index * depth_step_count + depth_step) * vector_elements));
808
+ svmopa_za32_s8_m(0, row_predicate_b8x, column_predicate_b8x, query_packed_i8x, document_packed_i8x);
795
809
  }
796
810
 
797
811
  for (nk_size_t column_within_tile = 0; column_within_tile < cols_remaining; column_within_tile++) {
798
812
  nk_u32_t document_index = (nk_u32_t)(col_start + column_within_tile);
799
- svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_f32x, 0,
813
+ svint32_t column_dots_i32x = svread_ver_za32_s32_m(svdup_s32(NK_I32_MIN), predicate_all_b32x, 0,
800
814
  column_within_tile);
801
- svbool_t is_better_bx = svcmpgt_s32(predicate_all_f32x, column_dots_i32x, running_max_i32x);
815
+ svbool_t is_better_bx = svcmpgt_s32(predicate_all_b32x, column_dots_i32x, running_max_i32x);
802
816
  running_max_i32x = svsel_s32(is_better_bx, column_dots_i32x, running_max_i32x);
803
817
  running_argmax_u32x = svsel_u32(is_better_bx, svdup_u32(document_index), running_argmax_u32x);
804
818
  }
@@ -806,7 +820,7 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
806
820
 
807
821
  // Refinement: tile-wide interleaved f64 dot products
808
822
  nk_u32_t best_document_indices[64]; // max tile_dimension across all SVL values
809
- svst1_u32(row_predicate_f32x, best_document_indices, running_argmax_u32x);
823
+ svst1_u32(row_predicate_b32x, best_document_indices, running_argmax_u32x);
810
824
 
811
825
  // Pointer setup: one (query, document) pair per row in the tile
812
826
  nk_f32_t const *query_original_ptrs[64];
@@ -828,46 +842,57 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
828
842
  svfloat64_t accumulator_1_f64x = svdup_f64(0.0);
829
843
  svfloat64_t accumulator_2_f64x = svdup_f64(0.0);
830
844
  svfloat64_t accumulator_3_f64x = svdup_f64(0.0);
831
-
832
- for (nk_size_t depth_index = 0; depth_index < depth; depth_index += svcntd()) {
833
- svbool_t predicate_depth_f64x = svwhilelt_b64_u64(depth_index, depth);
834
- svbool_t predicate_depth_f32x = svwhilelt_b32_u64(depth_index, depth);
835
-
836
- svfloat64_t query_values_0_f64x = svcvt_f64_f32_x(
837
- predicate_depth_f64x,
838
- svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 0] + depth_index));
839
- svfloat64_t document_values_0_f64x = svcvt_f64_f32_x(
840
- predicate_depth_f64x,
841
- svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 0] + depth_index));
842
- accumulator_0_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_0_f64x, query_values_0_f64x,
843
- document_values_0_f64x);
844
-
845
- svfloat64_t query_values_1_f64x = svcvt_f64_f32_x(
846
- predicate_depth_f64x,
847
- svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 1] + depth_index));
848
- svfloat64_t document_values_1_f64x = svcvt_f64_f32_x(
849
- predicate_depth_f64x,
850
- svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 1] + depth_index));
851
- accumulator_1_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_1_f64x, query_values_1_f64x,
852
- document_values_1_f64x);
853
-
854
- svfloat64_t query_values_2_f64x = svcvt_f64_f32_x(
855
- predicate_depth_f64x,
856
- svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 2] + depth_index));
857
- svfloat64_t document_values_2_f64x = svcvt_f64_f32_x(
858
- predicate_depth_f64x,
859
- svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 2] + depth_index));
860
- accumulator_2_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_2_f64x, query_values_2_f64x,
861
- document_values_2_f64x);
862
-
863
- svfloat64_t query_values_3_f64x = svcvt_f64_f32_x(
864
- predicate_depth_f64x,
865
- svld1_f32(predicate_depth_f32x, query_original_ptrs[row_batch_start + 3] + depth_index));
866
- svfloat64_t document_values_3_f64x = svcvt_f64_f32_x(
867
- predicate_depth_f64x,
868
- svld1_f32(predicate_depth_f32x, document_original_ptrs[row_batch_start + 3] + depth_index));
869
- accumulator_3_f64x = svmla_f64_x(predicate_depth_f64x, accumulator_3_f64x, query_values_3_f64x,
870
- document_values_3_f64x);
845
+ nk_size_t const depth_vector_length = svcntw();
846
+ nk_size_t const depth_half_length = svcntd();
847
+
848
+ for (nk_size_t depth_index = 0; depth_index < depth; depth_index += depth_vector_length) {
849
+ svbool_t predicate_depth_b32x = svwhilelt_b32_u64(depth_index, depth);
850
+ svbool_t predicate_even_b64x = svwhilelt_b64_u64(depth_index, depth);
851
+ svbool_t predicate_odd_b64x = svwhilelt_b64_u64(depth_index + depth_half_length, depth);
852
+
853
+ svfloat32_t query_values_0_f32x = svld1_f32(predicate_depth_b32x,
854
+ query_original_ptrs[row_batch_start + 0] + depth_index);
855
+ svfloat32_t document_values_0_f32x = svld1_f32(
856
+ predicate_depth_b32x, document_original_ptrs[row_batch_start + 0] + depth_index);
857
+ accumulator_0_f64x = svmla_f64_m(predicate_even_b64x, accumulator_0_f64x,
858
+ svcvt_f64_f32_x(predicate_even_b64x, query_values_0_f32x),
859
+ svcvt_f64_f32_x(predicate_even_b64x, document_values_0_f32x));
860
+ accumulator_0_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_0_f64x,
861
+ svcvtlt_f64_f32_x(predicate_odd_b64x, query_values_0_f32x),
862
+ svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_0_f32x));
863
+
864
+ svfloat32_t query_values_1_f32x = svld1_f32(predicate_depth_b32x,
865
+ query_original_ptrs[row_batch_start + 1] + depth_index);
866
+ svfloat32_t document_values_1_f32x = svld1_f32(
867
+ predicate_depth_b32x, document_original_ptrs[row_batch_start + 1] + depth_index);
868
+ accumulator_1_f64x = svmla_f64_m(predicate_even_b64x, accumulator_1_f64x,
869
+ svcvt_f64_f32_x(predicate_even_b64x, query_values_1_f32x),
870
+ svcvt_f64_f32_x(predicate_even_b64x, document_values_1_f32x));
871
+ accumulator_1_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_1_f64x,
872
+ svcvtlt_f64_f32_x(predicate_odd_b64x, query_values_1_f32x),
873
+ svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_1_f32x));
874
+
875
+ svfloat32_t query_values_2_f32x = svld1_f32(predicate_depth_b32x,
876
+ query_original_ptrs[row_batch_start + 2] + depth_index);
877
+ svfloat32_t document_values_2_f32x = svld1_f32(
878
+ predicate_depth_b32x, document_original_ptrs[row_batch_start + 2] + depth_index);
879
+ accumulator_2_f64x = svmla_f64_m(predicate_even_b64x, accumulator_2_f64x,
880
+ svcvt_f64_f32_x(predicate_even_b64x, query_values_2_f32x),
881
+ svcvt_f64_f32_x(predicate_even_b64x, document_values_2_f32x));
882
+ accumulator_2_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_2_f64x,
883
+ svcvtlt_f64_f32_x(predicate_odd_b64x, query_values_2_f32x),
884
+ svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_2_f32x));
885
+
886
+ svfloat32_t query_values_3_f32x = svld1_f32(predicate_depth_b32x,
887
+ query_original_ptrs[row_batch_start + 3] + depth_index);
888
+ svfloat32_t document_values_3_f32x = svld1_f32(
889
+ predicate_depth_b32x, document_original_ptrs[row_batch_start + 3] + depth_index);
890
+ accumulator_3_f64x = svmla_f64_m(predicate_even_b64x, accumulator_3_f64x,
891
+ svcvt_f64_f32_x(predicate_even_b64x, query_values_3_f32x),
892
+ svcvt_f64_f32_x(predicate_even_b64x, document_values_3_f32x));
893
+ accumulator_3_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_3_f64x,
894
+ svcvtlt_f64_f32_x(predicate_odd_b64x, query_values_3_f32x),
895
+ svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_3_f32x));
871
896
  }
872
897
 
873
898
  // Reduce accumulators and compute angular distance per row