numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315):
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -65,7 +65,7 @@ extern "C" {
65
65
  */
66
66
 
67
67
  #if defined(__clang__)
68
- #pragma clang attribute push(__attribute__((target("sme2,sve2"))), apply_to = function)
68
+ #pragma clang attribute push(__attribute__((target("sme2"))), apply_to = function)
69
69
  #elif defined(__GNUC__)
70
70
  #pragma GCC push_options
71
71
  #pragma GCC target("+sme2")
@@ -93,13 +93,12 @@ typedef struct {
93
93
 
94
94
  /** Count total set bits across a byte vector using streaming SVE.
95
95
  * Accumulates per-byte popcounts into u32 lanes via svdot; single horizontal reduction at end. */
96
- NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data,
97
- nk_size_t n_bytes) NK_STREAMING_COMPATIBLE_ {
96
+ NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data, nk_size_t n_bytes) NK_STREAMING_ {
98
97
  svuint32_t acc_u32x = svdup_u32(0);
99
98
  svuint8_t const ones_u8x = svdup_u8(1);
100
99
  for (nk_size_t offset = 0; offset < n_bytes; offset += svcntb()) {
101
- svbool_t predicate_u8x = svwhilelt_b8_u64(offset, n_bytes);
102
- acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_u8x, svld1_u8(predicate_u8x, data + offset)), ones_u8x);
100
+ svbool_t predicate_b8x = svwhilelt_b8_u64(offset, n_bytes);
101
+ acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_b8x, svld1_u8(predicate_b8x, data + offset)), ones_u8x);
103
102
  }
104
103
  return (nk_u32_t)svaddv_u32(svptrue_b32(), acc_u32x);
105
104
  }
@@ -128,11 +127,13 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
128
127
  nk_size_t const tile_dim = nk_smebi32_tile_dim_(); // 16 rows per tile
129
128
  nk_size_t const depth_tile_size = nk_smebi32_tile_dim_(); // 16 u32 per depth tile
130
129
  nk_size_t const tile_elements = tile_dim * depth_tile_size;
131
- nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, NK_BITS_PER_BYTE);
130
+ nk_size_t const depth_bytes = depth_bits / 8;
132
131
 
133
- nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
132
+ // BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
133
+ // handles one u32 (32 bits) across all row×column pairs simultaneously.
134
+ nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
134
135
  nk_size_t const row_tile_count = nk_size_divide_round_up_(row_count, tile_dim);
135
- nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);
136
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_words, depth_tile_size);
136
137
  nk_size_t const total_tiles = row_tile_count * depth_tile_count;
137
138
  nk_size_t const data_size = total_tiles * tile_elements * sizeof(nk_u32_t);
138
139
 
@@ -160,18 +161,24 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
160
161
  nk_size_t const src_u32_start = depth_tile * depth_tile_size;
161
162
  nk_size_t const rows_to_pack = (src_row_start + tile_dim <= row_count) ? tile_dim
162
163
  : (row_count - src_row_start);
163
- nk_size_t const u32s_to_pack = (src_u32_start + depth_tile_size <= depth_u32_total)
164
+ nk_size_t const u32s_to_pack = (src_u32_start + depth_tile_size <= depth_words)
164
165
  ? depth_tile_size
165
- : (depth_u32_total > src_u32_start ? depth_u32_total - src_u32_start
166
- : 0);
166
+ : (depth_words > src_u32_start ? depth_words - src_u32_start : 0);
167
167
 
168
168
  // Column-major packing: tile_output[col * tile_dim + row]
169
+ // Copy byte-by-byte for the last u32 to avoid garbage bits when depth_bits % 32 != 0
170
+ nk_size_t const tail_bytes = depth_bytes % 4;
171
+ nk_size_t const last_col = u32s_to_pack > 0 ? u32s_to_pack - 1 : 0;
172
+ nk_size_t const is_last_depth_tile = (src_u32_start + u32s_to_pack >= depth_words);
169
173
  for (nk_size_t row = 0; row < rows_to_pack; row++) {
170
174
  nk_u32_t const *src_row = (nk_u32_t const *)((char const *)b +
171
175
  (src_row_start + row) * b_stride_in_bytes);
172
176
  for (nk_size_t col = 0; col < u32s_to_pack; col++) {
173
177
  nk_size_t const dst_idx = col * tile_dim + row; // Column-major!
174
- tile_output[dst_idx] = src_row[src_u32_start + col];
178
+ if (tail_bytes && is_last_depth_tile && col == last_col) {
179
+ nk_copy_bytes_(&tile_output[dst_idx], &src_row[src_u32_start + col], tail_bytes);
180
+ }
181
+ else { tile_output[dst_idx] = src_row[src_u32_start + col]; }
175
182
  }
176
183
  }
177
184
  }
@@ -182,7 +189,7 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
182
189
  nk_u1x8_t const *src_row = (nk_u1x8_t const *)((char const *)b + row * b_stride_in_bytes);
183
190
  {
184
191
  nk_u64_t nk_local_sum_, nk_local_sumsq_;
185
- nk_reduce_moments_u1(src_row, depth_in_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
192
+ nk_reduce_moments_u1(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
186
193
  norms_ptr[row] = (nk_u32_t)nk_local_sum_;
187
194
  }
188
195
  }
@@ -207,19 +214,24 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
207
214
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
208
215
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
209
216
  nk_size_t const tile_elements = tile_dim * depth_tile_size;
210
- nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
217
+ // BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
218
+ // handles one u32 (32 bits) across all row×column pairs simultaneously.
219
+ nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
220
+ nk_size_t const depth_bytes = depth_bits / 8;
211
221
 
212
222
  nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
213
223
 
214
- svbool_t const predicate_all_u32x = svptrue_b32();
215
- svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);
224
+ svbool_t const predicate_all_b32x = svptrue_b32();
225
+ // Use padded depth (depth_words * 32) for BMOPA: zero-padded bits always match in XNOR,
226
+ // so the effective depth for the matching→hamming conversion is the rounded-up bit count.
227
+ svuint32_t const depth_u32x = svdup_u32((nk_u32_t)(depth_words * 32));
216
228
  nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);
217
229
 
218
230
  for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
219
231
  nk_size_t const row_start_a = row_tile_a * tile_dim;
220
232
  nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
221
233
  : (row_count_a - row_start_a);
222
- svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_a_remaining);
234
+ svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_a_remaining);
223
235
 
224
236
  // Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
225
237
  nk_size_t row_tile_b = 0;
@@ -228,22 +240,23 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
228
240
 
229
241
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
230
242
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
231
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
243
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
232
244
  ? depth_tile_size
233
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
234
- : 0);
245
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
235
246
  if (u32s_this_tile == 0) break;
236
247
 
237
248
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
238
249
 
239
- svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
250
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
251
+
252
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
240
253
 
241
- // Load A rows into ZA0.S horizontally as u32 words
254
+ // Load A rows into ZA0.S, byte-predicated to zero garbage bits
242
255
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
243
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
244
- (row_start_a + row_in_tile) * a_stride_in_bytes) +
245
- d_start_u32;
246
- svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
256
+ nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
257
+ d_start_u32 * 4;
258
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
259
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
247
260
  }
248
261
 
249
262
  // B tile pointers for 3 column tiles
@@ -253,14 +266,14 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
253
266
 
254
267
  // Vertical read + BMOPA for each depth step
255
268
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
256
- svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);
257
-
258
- svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
259
- svld1_u32(predicate_all_u32x, b_tile0 + step * tile_dim));
260
- svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
261
- svld1_u32(predicate_all_u32x, b_tile1 + step * tile_dim));
262
- svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
263
- svld1_u32(predicate_all_u32x, b_tile2 + step * tile_dim));
269
+ svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
270
+
271
+ svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
272
+ svld1_u32(predicate_all_b32x, b_tile0 + step * tile_dim));
273
+ svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
274
+ svld1_u32(predicate_all_b32x, b_tile1 + step * tile_dim));
275
+ svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
276
+ svld1_u32(predicate_all_b32x, b_tile2 + step * tile_dim));
264
277
  }
265
278
  }
266
279
 
@@ -268,16 +281,16 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
268
281
  for (nk_size_t row = 0; row < rows_a_remaining; row++) {
269
282
  nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
270
283
 
271
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
272
- svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
273
- svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);
284
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
285
+ svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
286
+ svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
274
287
 
275
- svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 0) * tile_dim,
276
- svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x));
277
- svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 1) * tile_dim,
278
- svsub_u32_x(predicate_all_u32x, depth_u32x, za2_u32x));
279
- svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 2) * tile_dim,
280
- svsub_u32_x(predicate_all_u32x, depth_u32x, za3_u32x));
288
+ svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 0) * tile_dim,
289
+ svsub_u32_x(predicate_all_b32x, depth_u32x, za1_u32x));
290
+ svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 1) * tile_dim,
291
+ svsub_u32_x(predicate_all_b32x, depth_u32x, za2_u32x));
292
+ svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 2) * tile_dim,
293
+ svsub_u32_x(predicate_all_b32x, depth_u32x, za3_u32x));
281
294
  }
282
295
  }
283
296
 
@@ -286,46 +299,46 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
286
299
  nk_size_t const row_start_b = row_tile_b * tile_dim;
287
300
  nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
288
301
  : (row_count_b - row_start_b);
289
- svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, rows_b_remaining);
302
+ svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, rows_b_remaining);
290
303
 
291
304
  svzero_mask_za(nk_sme_zero_za32_tile_1_);
292
305
 
293
306
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
294
307
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
295
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
308
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
296
309
  ? depth_tile_size
297
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
298
- : 0);
310
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
299
311
  if (u32s_this_tile == 0) break;
300
312
 
301
313
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
302
314
 
303
- svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
315
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
304
316
 
305
317
  // Load A rows into ZA0.S horizontally
318
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
306
319
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
307
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
308
- (row_start_a + row_in_tile) * a_stride_in_bytes) +
309
- d_start_u32;
310
- svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
320
+ nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
321
+ d_start_u32 * 4;
322
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
323
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
311
324
  }
312
325
 
313
326
  nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;
314
327
 
315
328
  // Vertical read + BMOPA
316
329
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
317
- svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);
318
- svuint32_t b_u32x = svld1_u32(predicate_all_u32x, b_tile + step * tile_dim);
319
- svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_column_u32x, b_u32x);
330
+ svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
331
+ svuint32_t b_u32x = svld1_u32(predicate_all_b32x, b_tile + step * tile_dim);
332
+ svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_column_u32x, b_u32x);
320
333
  }
321
334
  }
322
335
 
323
336
  // Extract from ZA1: Hamming = depth_bits - matching_bits
324
337
  for (nk_size_t row = 0; row < rows_a_remaining; row++) {
325
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
326
- svuint32_t hamming_u32x = svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x);
338
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
339
+ svuint32_t hamming_u32x = svsub_u32_x(predicate_all_b32x, depth_u32x, za1_u32x);
327
340
  nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
328
- svst1_u32(column_predicate_u32x, c_row + row_start_b, hamming_u32x);
341
+ svst1_u32(column_predicate_b32x, c_row + row_start_b, hamming_u32x);
329
342
  }
330
343
  }
331
344
  }
@@ -345,30 +358,37 @@ NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
345
358
  * Mirrors the unpacked kernel nk_hammings_packed_u1_smebi32_streaming_ pattern.
346
359
  */
347
360
  __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_(
348
- nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits, nk_size_t stride, nk_u32_t *result,
349
- nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
361
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
362
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
350
363
 
351
364
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
352
365
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
353
- nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
354
- nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);
366
+ // BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
367
+ // handles one u32 (32 bits) across all row×column pairs simultaneously.
368
+ nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
369
+ nk_size_t const depth_bytes = depth_bits / 8;
370
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_words, depth_tile_size);
355
371
 
356
- svbool_t const predicate_all_u32x = svptrue_b32();
357
- svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);
372
+ svbool_t const predicate_all_b32x = svptrue_b32();
373
+ // Use padded depth (depth_words * 32) for BMOPA: zero-padded bits always match in XNOR,
374
+ // so the effective depth for the matching→hamming conversion is the rounded-up bit count.
375
+ svuint32_t const depth_u32x = svdup_u32((nk_u32_t)(depth_words * 32));
358
376
 
359
377
  NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save
360
378
 
361
379
  nk_size_t const row_end = row_start + row_count;
362
- nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dim);
380
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dim);
363
381
 
364
- for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
382
+ for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
365
383
  row_tile_start += tile_dim) {
366
384
  nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
367
- nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
368
- : (n_vectors - row_tile_start);
369
- svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_clamped);
385
+ nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= vectors_count)
386
+ ? rows_remaining
387
+ : (vectors_count - row_tile_start);
388
+ svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_clamped);
370
389
 
371
- nk_size_t column_tile_index = 0;
390
+ // Upper triangle: start from this row tile's column
391
+ nk_size_t column_tile_index = row_tile_start / tile_dim;
372
392
 
373
393
  // Fast path: 3 column tiles using ZA1-ZA3 (ZA0 = staging)
374
394
  for (; column_tile_index + 3 <= column_tile_count; column_tile_index += 3) {
@@ -376,162 +396,164 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_sme
376
396
 
377
397
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
378
398
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
379
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
399
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
380
400
  ? depth_tile_size
381
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
382
- : 0);
401
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
383
402
  if (u32s_this_tile == 0) break;
384
403
 
385
404
  // Load A rows into ZA0 horizontally
386
405
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
387
- svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
406
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
388
407
 
408
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
389
409
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
390
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
391
- (row_tile_start + row_in_tile) * stride) +
392
- d_start_u32;
393
- svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
410
+ nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
411
+ d_start_u32 * 4;
412
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
413
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
394
414
  }
395
415
 
396
416
  // Save A columns from ZA0 to stack buffer
397
417
  for (nk_size_t s = 0; s < u32s_this_tile; s++)
398
- svst1_u32(predicate_all_u32x, a_buffer[s],
399
- svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));
418
+ svst1_u32(predicate_all_b32x, a_buffer[s],
419
+ svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
400
420
 
401
421
  // B column tile 0
402
422
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
403
423
  for (nk_size_t col = 0; col < tile_dim; col++) {
404
424
  nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
405
- if (col_abs < n_vectors) {
406
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
407
- d_start_u32;
408
- svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
425
+ if (col_abs < vectors_count) {
426
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
427
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
428
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
409
429
  }
410
430
  }
411
431
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
412
- svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
413
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
414
- svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
432
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
433
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
434
+ svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
415
435
  }
416
436
 
417
437
  // B column tile 1
418
438
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
419
439
  for (nk_size_t col = 0; col < tile_dim; col++) {
420
440
  nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
421
- if (col_abs < n_vectors) {
422
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
423
- d_start_u32;
424
- svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
441
+ if (col_abs < vectors_count) {
442
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
443
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
444
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
425
445
  }
426
446
  }
427
447
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
428
- svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
429
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
430
- svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
448
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
449
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
450
+ svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
431
451
  }
432
452
 
433
453
  // B column tile 2
434
454
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
435
455
  for (nk_size_t col = 0; col < tile_dim; col++) {
436
456
  nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
437
- if (col_abs < n_vectors) {
438
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
439
- d_start_u32;
440
- svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
457
+ if (col_abs < vectors_count) {
458
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
459
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
460
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
441
461
  }
442
462
  }
443
463
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
444
- svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
445
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
446
- svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
464
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
465
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
466
+ svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
447
467
  }
448
468
  }
449
469
 
450
470
  // Extract ZA1-3: hamming = depth_bits - ZA[i][j]
451
471
  for (nk_size_t row = 0; row < rows_clamped; row++) {
452
- nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);
453
-
454
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
455
- svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
456
- svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);
457
-
458
- svst1_u32(predicate_all_u32x, c_row + (column_tile_index + 0) * tile_dim,
459
- svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x));
460
- svst1_u32(predicate_all_u32x, c_row + (column_tile_index + 1) * tile_dim,
461
- svsub_u32_x(predicate_all_u32x, depth_u32x, za2_u32x));
462
- svst1_u32(predicate_all_u32x, c_row + (column_tile_index + 2) * tile_dim,
463
- svsub_u32_x(predicate_all_u32x, depth_u32x, za3_u32x));
472
+ nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
473
+
474
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
475
+ svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
476
+ svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
477
+
478
+ svst1_u32(predicate_all_b32x, c_row + (column_tile_index + 0) * tile_dim,
479
+ svsub_u32_x(predicate_all_b32x, depth_u32x, za1_u32x));
480
+ svst1_u32(predicate_all_b32x, c_row + (column_tile_index + 1) * tile_dim,
481
+ svsub_u32_x(predicate_all_b32x, depth_u32x, za2_u32x));
482
+ svst1_u32(predicate_all_b32x, c_row + (column_tile_index + 2) * tile_dim,
483
+ svsub_u32_x(predicate_all_b32x, depth_u32x, za3_u32x));
464
484
  }
465
485
  }
466
486
 
467
487
  // Remainder: 1 column tile at a time using ZA1
468
488
  for (; column_tile_index < column_tile_count; column_tile_index++) {
469
489
  nk_size_t const col_tile_start = column_tile_index * tile_dim;
470
- nk_size_t const cols_remaining = (col_tile_start + tile_dim <= n_vectors) ? tile_dim
471
- : (n_vectors - col_tile_start);
472
- svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, cols_remaining);
490
+ nk_size_t const cols_remaining = (col_tile_start + tile_dim <= vectors_count)
491
+ ? tile_dim
492
+ : (vectors_count - col_tile_start);
493
+ svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, cols_remaining);
473
494
 
474
495
  svzero_mask_za(nk_sme_zero_za32_tile_1_);
475
496
 
476
497
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
477
498
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
478
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
499
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
479
500
  ? depth_tile_size
480
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
481
- : 0);
501
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
482
502
  if (u32s_this_tile == 0) break;
483
503
 
484
504
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
485
- svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
505
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
486
506
 
487
507
  // Load A rows into ZA0 horizontally
508
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
488
509
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
489
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
490
- (row_tile_start + row_in_tile) * stride) +
491
- d_start_u32;
492
- svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
510
+ nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
511
+ d_start_u32 * 4;
512
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
513
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
493
514
  }
494
515
 
495
516
  // Save A columns from ZA0 to stack buffer
496
517
  for (nk_size_t s = 0; s < u32s_this_tile; s++)
497
- svst1_u32(predicate_all_u32x, a_buffer[s],
498
- svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));
518
+ svst1_u32(predicate_all_b32x, a_buffer[s],
519
+ svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
499
520
 
500
521
  // Load B column tile into ZA0
501
522
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
502
523
  for (nk_size_t col = 0; col < tile_dim; col++) {
503
524
  nk_size_t const col_abs = col_tile_start + col;
504
- if (col_abs < n_vectors) {
505
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
506
- d_start_u32;
507
- svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
525
+ if (col_abs < vectors_count) {
526
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
527
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
528
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
508
529
  }
509
530
  }
510
531
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
511
- svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
512
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_u32x, 0, step);
513
- svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_u32x, b_u32x);
532
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
533
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_b32x, 0, step);
534
+ svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_u32x, b_u32x);
514
535
  }
515
536
  }
516
537
 
517
538
  for (nk_size_t row = 0; row < rows_clamped; row++) {
518
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
519
- svuint32_t hamming_u32x = svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x);
520
- nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);
521
- svst1_u32(column_predicate_u32x, c_row + col_tile_start, hamming_u32x);
539
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
540
+ svuint32_t hamming_u32x = svsub_u32_x(predicate_all_b32x, depth_u32x, za1_u32x);
541
+ nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
542
+ svst1_u32(column_predicate_b32x, c_row + col_tile_start, hamming_u32x);
522
543
  }
523
544
  }
524
545
  }
525
546
  }
526
547
 
527
- NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits,
528
- nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
529
- nk_size_t row_start, nk_size_t row_count) {
530
- nk_hammings_symmetric_u1_smebi32_streaming_(vectors, n_vectors, depth_bits, stride, result, result_stride,
531
- row_start, row_count);
548
+ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
549
+ nk_size_t stride_in_bytes, nk_u32_t *result,
550
+ nk_size_t result_stride_in_bytes, nk_size_t row_start,
551
+ nk_size_t row_count) {
552
+ nk_hammings_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
553
+ result_stride_in_bytes, row_start, row_count);
532
554
  }
533
555
 
534
- #pragma endregion // Hamming Distance
556
+ #pragma endregion Hamming Distance
535
557
 
536
558
  /*
537
559
  * Jaccard distance via BMOPA matching counts + algebraic normalization.
@@ -570,31 +592,33 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
570
592
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
571
593
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
572
594
  nk_size_t const tile_elements = tile_dim * depth_tile_size;
573
- nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
595
+ // BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
596
+ // handles one u32 (32 bits) across all row×column pairs simultaneously.
597
+ nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
598
+ nk_size_t const depth_bytes = depth_bits / 8;
574
599
 
575
600
  nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
576
601
  nk_u32_t const *b_norms = header->norms_offset ? (nk_u32_t const *)((char const *)b_packed + header->norms_offset)
577
602
  : (nk_u32_t const *)0;
578
603
 
579
- svbool_t const predicate_all_f32x = svptrue_b32();
580
- svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)depth_bits);
604
+ svbool_t const predicate_all_b32x = svptrue_b32();
605
+ svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)(depth_words * 32));
581
606
  svfloat32_t const half_f32x = svdup_f32(0.5f);
582
607
  svfloat32_t const one_f32x = svdup_f32(1.0f);
583
608
  svfloat32_t const zero_f32x = svdup_f32(0.0f);
584
- nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, 8);
585
609
  nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);
586
610
 
587
611
  for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
588
612
  nk_size_t const row_start_a = row_tile_a * tile_dim;
589
613
  nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
590
614
  : (row_count_a - row_start_a);
591
- svbool_t const row_predicate_f32x = svwhilelt_b32_u64(0u, rows_a_remaining);
615
+ svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_a_remaining);
592
616
 
593
617
  // Compute A tile norms using streaming SVE popcount
594
618
  NK_ALIGN64 nk_f32_t a_tile_norms[16];
595
619
  for (nk_size_t r = 0; r < rows_a_remaining; r++) {
596
620
  nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)a + (row_start_a + r) * a_stride_in_bytes);
597
- a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
621
+ a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_bytes);
598
622
  }
599
623
 
600
624
  // Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
@@ -604,22 +628,23 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
604
628
 
605
629
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
606
630
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
607
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
631
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
608
632
  ? depth_tile_size
609
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
610
- : 0);
633
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
611
634
  if (u32s_this_tile == 0) break;
612
635
 
613
636
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
614
637
 
615
- svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
638
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
639
+
640
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
616
641
 
617
- // Load A rows into ZA0.S horizontally as u32 words
642
+ // Load A rows into ZA0.S, byte-predicated to zero garbage bits
618
643
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
619
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
620
- (row_start_a + row_in_tile) * a_stride_in_bytes) +
621
- d_start_u32;
622
- svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
644
+ nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
645
+ d_start_u32 * 4;
646
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
647
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
623
648
  }
624
649
 
625
650
  // B tile pointers for 3 column tiles
@@ -629,25 +654,25 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
629
654
 
630
655
  // Vertical read + BMOPA for each depth step
631
656
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
632
- svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, step);
633
-
634
- svbmopa_za32_u32_m(1, row_predicate_f32x, predicate_all_f32x, a_column_u32x,
635
- svld1_u32(predicate_all_f32x, b_tile0 + step * tile_dim));
636
- svbmopa_za32_u32_m(2, row_predicate_f32x, predicate_all_f32x, a_column_u32x,
637
- svld1_u32(predicate_all_f32x, b_tile1 + step * tile_dim));
638
- svbmopa_za32_u32_m(3, row_predicate_f32x, predicate_all_f32x, a_column_u32x,
639
- svld1_u32(predicate_all_f32x, b_tile2 + step * tile_dim));
657
+ svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
658
+
659
+ svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
660
+ svld1_u32(predicate_all_b32x, b_tile0 + step * tile_dim));
661
+ svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
662
+ svld1_u32(predicate_all_b32x, b_tile1 + step * tile_dim));
663
+ svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
664
+ svld1_u32(predicate_all_b32x, b_tile2 + step * tile_dim));
640
665
  }
641
666
  }
642
667
 
643
668
  // Extract from ZA1-3: Jaccard normalization via streaming SVE
644
669
  // Hoist B norms outside row loop (same for all A rows in this tile-pair)
645
670
  svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(
646
- predicate_all_f32x, svld1_u32(predicate_all_f32x, b_norms + (row_tile_b + 0) * tile_dim));
671
+ predicate_all_b32x, svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 0) * tile_dim));
647
672
  svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(
648
- predicate_all_f32x, svld1_u32(predicate_all_f32x, b_norms + (row_tile_b + 1) * tile_dim));
673
+ predicate_all_b32x, svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 1) * tile_dim));
649
674
  svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(
650
- predicate_all_f32x, svld1_u32(predicate_all_f32x, b_norms + (row_tile_b + 2) * tile_dim));
675
+ predicate_all_b32x, svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 2) * tile_dim));
651
676
 
652
677
  for (nk_size_t row = 0; row < rows_a_remaining; row++) {
653
678
  nk_f32_t *c_row = (nk_f32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
@@ -655,54 +680,54 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
655
680
 
656
681
  // ZA1
657
682
  {
658
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
659
- svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
660
- svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_0_f32x);
683
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
684
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za1_u32x);
685
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_0_f32x);
661
686
  svfloat32_t intersection_f32x = svmul_f32_x(
662
- predicate_all_f32x,
663
- svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
687
+ predicate_all_b32x,
688
+ svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
664
689
  matching_f32x),
665
690
  half_f32x);
666
- svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
667
- svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
668
- svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
691
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
692
+ svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
693
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
669
694
  svfloat32_t jaccard_f32x = svsel_f32(
670
- nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
671
- svst1_f32(predicate_all_f32x, c_row + (row_tile_b + 0) * tile_dim, jaccard_f32x);
695
+ nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
696
+ svst1_f32(predicate_all_b32x, c_row + (row_tile_b + 0) * tile_dim, jaccard_f32x);
672
697
  }
673
698
  // ZA2
674
699
  {
675
- svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 2, row);
676
- svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za2_u32x);
677
- svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_1_f32x);
700
+ svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
701
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za2_u32x);
702
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_1_f32x);
678
703
  svfloat32_t intersection_f32x = svmul_f32_x(
679
- predicate_all_f32x,
680
- svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
704
+ predicate_all_b32x,
705
+ svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
681
706
  matching_f32x),
682
707
  half_f32x);
683
- svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
684
- svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
685
- svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
708
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
709
+ svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
710
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
686
711
  svfloat32_t jaccard_f32x = svsel_f32(
687
- nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
688
- svst1_f32(predicate_all_f32x, c_row + (row_tile_b + 1) * tile_dim, jaccard_f32x);
712
+ nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
713
+ svst1_f32(predicate_all_b32x, c_row + (row_tile_b + 1) * tile_dim, jaccard_f32x);
689
714
  }
690
715
  // ZA3
691
716
  {
692
- svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 3, row);
693
- svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za3_u32x);
694
- svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_2_f32x);
717
+ svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
718
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za3_u32x);
719
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_2_f32x);
695
720
  svfloat32_t intersection_f32x = svmul_f32_x(
696
- predicate_all_f32x,
697
- svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
721
+ predicate_all_b32x,
722
+ svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
698
723
  matching_f32x),
699
724
  half_f32x);
700
- svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
701
- svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
702
- svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
725
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
726
+ svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
727
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
703
728
  svfloat32_t jaccard_f32x = svsel_f32(
704
- nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
705
- svst1_f32(predicate_all_f32x, c_row + (row_tile_b + 2) * tile_dim, jaccard_f32x);
729
+ nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
730
+ svst1_f32(predicate_all_b32x, c_row + (row_tile_b + 2) * tile_dim, jaccard_f32x);
706
731
  }
707
732
  }
708
733
  }
@@ -712,60 +737,60 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
712
737
  nk_size_t const row_start_b = row_tile_b * tile_dim;
713
738
  nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
714
739
  : (row_count_b - row_start_b);
715
- svbool_t const column_predicate_f32x = svwhilelt_b32_u64(0u, rows_b_remaining);
740
+ svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, rows_b_remaining);
716
741
 
717
742
  svzero_mask_za(nk_sme_zero_za32_tile_1_);
718
743
 
719
744
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
720
745
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
721
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
746
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
722
747
  ? depth_tile_size
723
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
724
- : 0);
748
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
725
749
  if (u32s_this_tile == 0) break;
726
750
 
727
751
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
728
752
 
729
- svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
753
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
730
754
 
731
755
  // Load A rows into ZA0.S horizontally
756
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
732
757
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
733
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
734
- (row_start_a + row_in_tile) * a_stride_in_bytes) +
735
- d_start_u32;
736
- svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
758
+ nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
759
+ d_start_u32 * 4;
760
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
761
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
737
762
  }
738
763
 
739
764
  nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;
740
765
 
741
766
  // Vertical read + BMOPA
742
767
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
743
- svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, step);
744
- svuint32_t b_u32x = svld1_u32(predicate_all_f32x, b_tile + step * tile_dim);
745
- svbmopa_za32_u32_m(1, row_predicate_f32x, column_predicate_f32x, a_column_u32x, b_u32x);
768
+ svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
769
+ svuint32_t b_u32x = svld1_u32(predicate_all_b32x, b_tile + step * tile_dim);
770
+ svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_column_u32x, b_u32x);
746
771
  }
747
772
  }
748
773
 
749
774
  // Extract from ZA1: Jaccard normalization
750
- svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_f32x,
751
- svld1_u32(predicate_all_f32x, b_norms + row_start_b));
775
+ svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_b32x,
776
+ svld1_u32(predicate_all_b32x, b_norms + row_start_b));
752
777
  for (nk_size_t row = 0; row < rows_a_remaining; row++) {
753
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
754
- svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
778
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
779
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za1_u32x);
755
780
  svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
756
- svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_f32x);
781
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_f32x);
757
782
  svfloat32_t intersection_f32x = svmul_f32_x(
758
- predicate_all_f32x,
759
- svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
783
+ predicate_all_b32x,
784
+ svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
760
785
  matching_f32x),
761
786
  half_f32x);
762
- svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
763
- svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
764
- svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
765
- svfloat32_t jaccard_f32x = svsel_f32(nonzero_f32x,
766
- svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
787
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
788
+ svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
789
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
790
+ svfloat32_t jaccard_f32x = svsel_f32(nonzero_b32x,
791
+ svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
767
792
  nk_f32_t *c_row = (nk_f32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
768
- svst1_f32(column_predicate_f32x, c_row + row_start_b, jaccard_f32x);
793
+ svst1_f32(column_predicate_b32x, c_row + row_start_b, jaccard_f32x);
769
794
  }
770
795
  }
771
796
  }
@@ -784,17 +809,19 @@ NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
784
809
  * Norms computed on-the-fly using streaming SVE popcount.
785
810
  */
786
811
  __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_(
787
- nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits, nk_size_t stride, nk_f32_t *result,
788
- nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
812
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
813
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
789
814
 
790
815
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
791
816
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
792
- nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
793
- nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);
794
- nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, NK_BITS_PER_BYTE);
795
-
796
- svbool_t const predicate_all_f32x = svptrue_b32();
797
- svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)depth_bits);
817
+ // BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
818
+ // handles one u32 (32 bits) across all row×column pairs simultaneously.
819
+ nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
820
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_words, depth_tile_size);
821
+ nk_size_t const depth_bytes = depth_bits / 8;
822
+
823
+ svbool_t const predicate_all_b32x = svptrue_b32();
824
+ svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)(depth_words * 32));
798
825
  svfloat32_t const half_f32x = svdup_f32(0.5f);
799
826
  svfloat32_t const one_f32x = svdup_f32(1.0f);
800
827
  svfloat32_t const zero_f32x = svdup_f32(0.0f);
@@ -802,20 +829,22 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
802
829
  NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save
803
830
 
804
831
  nk_size_t const row_end = row_start + row_count;
805
- nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dim);
832
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dim);
806
833
 
807
- for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
834
+ for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
808
835
  row_tile_start += tile_dim) {
809
836
  nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
810
- nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
811
- : (n_vectors - row_tile_start);
812
- svbool_t const row_predicate_f32x = svwhilelt_b32_u64(0u, rows_clamped);
837
+ nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= vectors_count)
838
+ ? rows_remaining
839
+ : (vectors_count - row_tile_start);
840
+ svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_clamped);
813
841
 
814
842
  // Compute A tile norms
815
843
  NK_ALIGN64 nk_f32_t a_tile_norms[16];
816
844
  for (nk_size_t r = 0; r < rows_clamped; r++) {
817
- nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors + (row_tile_start + r) * stride);
818
- a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
845
+ nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors +
846
+ (row_tile_start + r) * stride_in_bytes);
847
+ a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_bytes);
819
848
  }
820
849
  for (nk_size_t r = rows_clamped; r < tile_dim; r++) a_tile_norms[r] = 0.0f;
821
850
 
@@ -828,74 +857,74 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
828
857
 
829
858
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
830
859
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
831
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
860
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
832
861
  ? depth_tile_size
833
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
834
- : 0);
862
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
835
863
  if (u32s_this_tile == 0) break;
836
864
 
837
865
  // Load A rows into ZA0 horizontally
838
866
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
839
- svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
867
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
840
868
 
869
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
841
870
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
842
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
843
- (row_tile_start + row_in_tile) * stride) +
844
- d_start_u32;
845
- svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
871
+ nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
872
+ d_start_u32 * 4;
873
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
874
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
846
875
  }
847
876
 
848
877
  // Save A columns from ZA0 to stack buffer
849
878
  for (nk_size_t s = 0; s < u32s_this_tile; s++)
850
- svst1_u32(predicate_all_f32x, a_buffer[s],
851
- svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, s));
879
+ svst1_u32(predicate_all_b32x, a_buffer[s],
880
+ svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
852
881
 
853
882
  // B column tile 0
854
883
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
855
884
  for (nk_size_t col = 0; col < tile_dim; col++) {
856
885
  nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
857
- if (col_abs < n_vectors) {
858
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
859
- d_start_u32;
860
- svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
886
+ if (col_abs < vectors_count) {
887
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
888
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
889
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
861
890
  }
862
891
  }
863
892
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
864
- svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
865
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_f32x, 0, step);
866
- svbmopa_za32_u32_m(1, row_predicate_f32x, predicate_all_f32x, a_u32x, b_u32x);
893
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
894
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
895
+ svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
867
896
  }
868
897
 
869
898
  // B column tile 1
870
899
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
871
900
  for (nk_size_t col = 0; col < tile_dim; col++) {
872
901
  nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
873
- if (col_abs < n_vectors) {
874
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
875
- d_start_u32;
876
- svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
902
+ if (col_abs < vectors_count) {
903
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
904
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
905
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
877
906
  }
878
907
  }
879
908
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
880
- svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
881
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_f32x, 0, step);
882
- svbmopa_za32_u32_m(2, row_predicate_f32x, predicate_all_f32x, a_u32x, b_u32x);
909
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
910
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
911
+ svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
883
912
  }
884
913
 
885
914
  // B column tile 2
886
915
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
887
916
  for (nk_size_t col = 0; col < tile_dim; col++) {
888
917
  nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
889
- if (col_abs < n_vectors) {
890
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
891
- d_start_u32;
892
- svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
918
+ if (col_abs < vectors_count) {
919
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
920
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
921
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
893
922
  }
894
923
  }
895
924
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
896
- svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
897
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_f32x, 0, step);
898
- svbmopa_za32_u32_m(3, row_predicate_f32x, predicate_all_f32x, a_u32x, b_u32x);
925
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
926
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
927
+ svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
899
928
  }
900
929
  }
901
930
 
@@ -907,85 +936,85 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
907
936
  nk_size_t const col_abs_0 = (column_tile_index + 0) * tile_dim + col;
908
937
  nk_size_t const col_abs_1 = (column_tile_index + 1) * tile_dim + col;
909
938
  nk_size_t const col_abs_2 = (column_tile_index + 2) * tile_dim + col;
910
- b_tile_norms_0[col] = (col_abs_0 < n_vectors)
911
- ? nk_sets_reduce_sumsq_u1_streaming_(
912
- (nk_u1x8_t const *)((char const *)vectors + col_abs_0 * stride),
913
- depth_in_bytes)
914
- : 0;
915
- b_tile_norms_1[col] = (col_abs_1 < n_vectors)
916
- ? nk_sets_reduce_sumsq_u1_streaming_(
917
- (nk_u1x8_t const *)((char const *)vectors + col_abs_1 * stride),
918
- depth_in_bytes)
919
- : 0;
920
- b_tile_norms_2[col] = (col_abs_2 < n_vectors)
921
- ? nk_sets_reduce_sumsq_u1_streaming_(
922
- (nk_u1x8_t const *)((char const *)vectors + col_abs_2 * stride),
923
- depth_in_bytes)
924
- : 0;
939
+ b_tile_norms_0[col] =
940
+ (col_abs_0 < vectors_count)
941
+ ? nk_sets_reduce_sumsq_u1_streaming_(
942
+ (nk_u1x8_t const *)((char const *)vectors + col_abs_0 * stride_in_bytes), depth_bytes)
943
+ : 0;
944
+ b_tile_norms_1[col] =
945
+ (col_abs_1 < vectors_count)
946
+ ? nk_sets_reduce_sumsq_u1_streaming_(
947
+ (nk_u1x8_t const *)((char const *)vectors + col_abs_1 * stride_in_bytes), depth_bytes)
948
+ : 0;
949
+ b_tile_norms_2[col] =
950
+ (col_abs_2 < vectors_count)
951
+ ? nk_sets_reduce_sumsq_u1_streaming_(
952
+ (nk_u1x8_t const *)((char const *)vectors + col_abs_2 * stride_in_bytes), depth_bytes)
953
+ : 0;
925
954
  }
926
955
 
927
956
  // Extract ZA1-3: Jaccard normalization
928
- svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(predicate_all_f32x,
929
- svld1_u32(predicate_all_f32x, b_tile_norms_0));
930
- svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(predicate_all_f32x,
931
- svld1_u32(predicate_all_f32x, b_tile_norms_1));
932
- svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(predicate_all_f32x,
933
- svld1_u32(predicate_all_f32x, b_tile_norms_2));
957
+ svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(predicate_all_b32x,
958
+ svld1_u32(predicate_all_b32x, b_tile_norms_0));
959
+ svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(predicate_all_b32x,
960
+ svld1_u32(predicate_all_b32x, b_tile_norms_1));
961
+ svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(predicate_all_b32x,
962
+ svld1_u32(predicate_all_b32x, b_tile_norms_2));
934
963
 
935
964
  for (nk_size_t row = 0; row < rows_clamped; row++) {
936
- nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride);
965
+ nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
937
966
  svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
938
967
 
939
968
  // ZA1
940
969
  {
941
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
942
- svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
943
- svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_0_f32x);
970
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
971
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za1_u32x);
972
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_0_f32x);
944
973
  svfloat32_t intersection_f32x = svmul_f32_x(
945
- predicate_all_f32x,
946
- svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
974
+ predicate_all_b32x,
975
+ svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
947
976
  matching_f32x),
948
977
  half_f32x);
949
- svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
950
- svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
951
- svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
978
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
979
+ svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
980
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
952
981
  svfloat32_t jaccard_f32x = svsel_f32(
953
- nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
954
- svst1_f32(predicate_all_f32x, c_row + (column_tile_index + 0) * tile_dim, jaccard_f32x);
982
+ nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
983
+ svst1_f32(predicate_all_b32x, c_row + (column_tile_index + 0) * tile_dim, jaccard_f32x);
955
984
  }
956
985
  // ZA2
957
986
  {
958
- svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 2, row);
959
- svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za2_u32x);
960
- svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_1_f32x);
987
+ svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
988
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za2_u32x);
989
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_1_f32x);
961
990
  svfloat32_t intersection_f32x = svmul_f32_x(
962
- predicate_all_f32x,
963
- svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
991
+ predicate_all_b32x,
992
+ svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
964
993
  matching_f32x),
965
994
  half_f32x);
966
- svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
967
- svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
968
- svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
995
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
996
+ svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
997
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
969
998
  svfloat32_t jaccard_f32x = svsel_f32(
970
- nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
971
- svst1_f32(predicate_all_f32x, c_row + (column_tile_index + 1) * tile_dim, jaccard_f32x);
999
+ nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
1000
+ svst1_f32(predicate_all_b32x, c_row + (column_tile_index + 1) * tile_dim, jaccard_f32x);
972
1001
  }
973
1002
  // ZA3
974
1003
  {
975
- svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 3, row);
976
- svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za3_u32x);
977
- svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_2_f32x);
1004
+ svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
1005
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za3_u32x);
1006
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_2_f32x);
978
1007
  svfloat32_t intersection_f32x = svmul_f32_x(
979
- predicate_all_f32x,
980
- svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
1008
+ predicate_all_b32x,
1009
+ svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
981
1010
  matching_f32x),
982
1011
  half_f32x);
983
- svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
984
- svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
985
- svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
1012
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
1013
+ svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
1014
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
986
1015
  svfloat32_t jaccard_f32x = svsel_f32(
987
- nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
988
- svst1_f32(predicate_all_f32x, c_row + (column_tile_index + 2) * tile_dim, jaccard_f32x);
1016
+ nonzero_b32x, svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
1017
+ svst1_f32(predicate_all_b32x, c_row + (column_tile_index + 2) * tile_dim, jaccard_f32x);
989
1018
  }
990
1019
  }
991
1020
  }
@@ -993,50 +1022,51 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
993
1022
  // Remainder: 1 column tile at a time using ZA1
994
1023
  for (; column_tile_index < column_tile_count; column_tile_index++) {
995
1024
  nk_size_t const col_tile_start = column_tile_index * tile_dim;
996
- nk_size_t const cols_remaining = (col_tile_start + tile_dim <= n_vectors) ? tile_dim
997
- : (n_vectors - col_tile_start);
998
- svbool_t const column_predicate_f32x = svwhilelt_b32_u64(0u, cols_remaining);
1025
+ nk_size_t const cols_remaining = (col_tile_start + tile_dim <= vectors_count)
1026
+ ? tile_dim
1027
+ : (vectors_count - col_tile_start);
1028
+ svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, cols_remaining);
999
1029
 
1000
1030
  svzero_mask_za(nk_sme_zero_za32_tile_1_);
1001
1031
 
1002
1032
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
1003
1033
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
1004
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
1034
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
1005
1035
  ? depth_tile_size
1006
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
1007
- : 0);
1036
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
1008
1037
  if (u32s_this_tile == 0) break;
1009
1038
 
1010
1039
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
1011
- svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
1040
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
1012
1041
 
1013
1042
  // Load A rows into ZA0 horizontally
1043
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
1014
1044
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
1015
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
1016
- (row_tile_start + row_in_tile) * stride) +
1017
- d_start_u32;
1018
- svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
1045
+ nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
1046
+ d_start_u32 * 4;
1047
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
1048
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
1019
1049
  }
1020
1050
 
1021
1051
  // Save A columns from ZA0 to stack buffer
1022
1052
  for (nk_size_t s = 0; s < u32s_this_tile; s++)
1023
- svst1_u32(predicate_all_f32x, a_buffer[s],
1024
- svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, s));
1053
+ svst1_u32(predicate_all_b32x, a_buffer[s],
1054
+ svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
1025
1055
 
1026
1056
  // Load B column tile into ZA0
1027
1057
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
1028
1058
  for (nk_size_t col = 0; col < tile_dim; col++) {
1029
1059
  nk_size_t const col_abs = col_tile_start + col;
1030
- if (col_abs < n_vectors) {
1031
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
1032
- d_start_u32;
1033
- svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
1060
+ if (col_abs < vectors_count) {
1061
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
1062
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
1063
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
1034
1064
  }
1035
1065
  }
1036
1066
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
1037
- svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
1038
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_f32x, 0, step);
1039
- svbmopa_za32_u32_m(1, row_predicate_f32x, column_predicate_f32x, a_u32x, b_u32x);
1067
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
1068
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_b32x, 0, step);
1069
+ svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_u32x, b_u32x);
1040
1070
  }
1041
1071
  }
1042
1072
 
@@ -1044,44 +1074,45 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
1044
1074
  NK_ALIGN64 nk_u32_t b_tile_norms[16];
1045
1075
  for (nk_size_t col = 0; col < tile_dim; col++) {
1046
1076
  nk_size_t const col_abs = col_tile_start + col;
1047
- b_tile_norms[col] = (col_abs < n_vectors)
1077
+ b_tile_norms[col] = (col_abs < vectors_count)
1048
1078
  ? nk_sets_reduce_sumsq_u1_streaming_(
1049
- (nk_u1x8_t const *)((char const *)vectors + col_abs * stride),
1050
- depth_in_bytes)
1079
+ (nk_u1x8_t const *)((char const *)vectors + col_abs * stride_in_bytes),
1080
+ depth_bytes)
1051
1081
  : 0;
1052
1082
  }
1053
1083
 
1054
- svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_f32x, svld1_u32(predicate_all_f32x, b_tile_norms));
1084
+ svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_b32x, svld1_u32(predicate_all_b32x, b_tile_norms));
1055
1085
  for (nk_size_t row = 0; row < rows_clamped; row++) {
1056
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
1057
- svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
1086
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
1087
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_b32x, za1_u32x);
1058
1088
  svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
1059
- svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_f32x);
1089
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_b32x, norm_a_f32x, b_norms_f32x);
1060
1090
  svfloat32_t intersection_f32x = svmul_f32_x(
1061
- predicate_all_f32x,
1062
- svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
1091
+ predicate_all_b32x,
1092
+ svadd_f32_x(predicate_all_b32x, svsub_f32_x(predicate_all_b32x, sum_norms_f32x, depth_f32x),
1063
1093
  matching_f32x),
1064
1094
  half_f32x);
1065
- svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
1066
- svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
1067
- svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
1068
- svfloat32_t jaccard_f32x = svsel_f32(nonzero_f32x,
1069
- svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
1070
- nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride);
1071
- svst1_f32(column_predicate_f32x, c_row + col_tile_start, jaccard_f32x);
1095
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_b32x, sum_norms_f32x, intersection_f32x);
1096
+ svbool_t nonzero_b32x = svcmpne_f32(predicate_all_b32x, union_val_f32x, zero_f32x);
1097
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_b32x, intersection_f32x, union_val_f32x);
1098
+ svfloat32_t jaccard_f32x = svsel_f32(nonzero_b32x,
1099
+ svsub_f32_x(predicate_all_b32x, one_f32x, ratio_f32x), one_f32x);
1100
+ nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
1101
+ svst1_f32(column_predicate_b32x, c_row + col_tile_start, jaccard_f32x);
1072
1102
  }
1073
1103
  }
1074
1104
  }
1075
1105
  }
1076
1106
 
1077
- NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits,
1078
- nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1079
- nk_size_t row_start, nk_size_t row_count) {
1080
- nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, n_vectors, depth_bits, stride, result, result_stride,
1081
- row_start, row_count);
1107
+ NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
1108
+ nk_size_t stride_in_bytes, nk_f32_t *result,
1109
+ nk_size_t result_stride_in_bytes, nk_size_t row_start,
1110
+ nk_size_t row_count) {
1111
+ nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
1112
+ result_stride_in_bytes, row_start, row_count);
1082
1113
  }
1083
1114
 
1084
- #pragma endregion // Jaccard Distance
1115
+ #pragma endregion Jaccard Distance
1085
1116
 
1086
1117
  #if defined(__clang__)
1087
1118
  #pragma clang attribute pop