numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315):
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -25,7 +25,7 @@ extern "C" {
25
25
  #endif
26
26
 
27
27
  #if defined(__clang__)
28
- #pragma clang attribute push(__attribute__((target("sme2,sve2"))), apply_to = function)
28
+ #pragma clang attribute push(__attribute__((target("sme2"))), apply_to = function)
29
29
  #elif defined(__GNUC__)
30
30
  #pragma GCC push_options
31
31
  #pragma GCC target("+sme2")
@@ -50,28 +50,32 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
50
50
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
51
51
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
52
52
  nk_size_t const tile_elements = tile_dim * depth_tile_size;
53
- nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
53
+ // BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
54
+ // handles one u32 (32 bits) across all row×column pairs simultaneously.
55
+ nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
56
+ nk_size_t const depth_bytes = depth_bits / 8;
54
57
 
55
58
  nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
56
59
  nk_u32_t const *b_norms = header->norms_offset ? (nk_u32_t const *)((char const *)b_packed + header->norms_offset)
57
60
  : (nk_u32_t const *)0;
58
61
 
59
- svbool_t const predicate_all_u32x = svptrue_b32();
60
- svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);
61
- nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, 8);
62
+ svbool_t const predicate_all_b32x = svptrue_b32();
63
+ // Use padded depth (depth_words * 32) for BMOPA: zero-padded bits always match in XNOR,
64
+ // so the effective depth for the matching→intersection conversion is the rounded-up bit count.
65
+ svuint32_t const depth_u32x = svdup_u32((nk_u32_t)(depth_words * 32));
62
66
  nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);
63
67
 
64
68
  for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
65
69
  nk_size_t const row_start_a = row_tile_a * tile_dim;
66
70
  nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
67
71
  : (row_count_a - row_start_a);
68
- svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_a_remaining);
72
+ svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_a_remaining);
69
73
 
70
74
  // Compute A row popcounts for this tile
71
75
  nk_u32_t a_popcounts[16];
72
76
  for (nk_size_t r = 0; r < rows_a_remaining; r++) {
73
77
  nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)a + (row_start_a + r) * a_stride_in_bytes);
74
- a_popcounts[r] = nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
78
+ a_popcounts[r] = nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_bytes);
75
79
  }
76
80
 
77
81
  // Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
@@ -81,21 +85,21 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
81
85
 
82
86
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
83
87
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
84
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
88
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
85
89
  ? depth_tile_size
86
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
87
- : 0);
90
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
88
91
  if (u32s_this_tile == 0) break;
89
92
 
90
93
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
91
94
 
92
- svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
95
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
93
96
 
97
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
94
98
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
95
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
96
- (row_start_a + row_in_tile) * a_stride_in_bytes) +
97
- d_start_u32;
98
- svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
99
+ nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
100
+ d_start_u32 * 4;
101
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
102
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
99
103
  }
100
104
 
101
105
  nk_u32_t const *b_tile0 = b_tiles + ((row_tile_b + 0) * depth_tile_count + d_tile) * tile_elements;
@@ -103,47 +107,47 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
103
107
  nk_u32_t const *b_tile2 = b_tiles + ((row_tile_b + 2) * depth_tile_count + d_tile) * tile_elements;
104
108
 
105
109
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
106
- svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);
107
-
108
- svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
109
- svld1_u32(predicate_all_u32x, b_tile0 + step * tile_dim));
110
- svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
111
- svld1_u32(predicate_all_u32x, b_tile1 + step * tile_dim));
112
- svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
113
- svld1_u32(predicate_all_u32x, b_tile2 + step * tile_dim));
110
+ svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
111
+
112
+ svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
113
+ svld1_u32(predicate_all_b32x, b_tile0 + step * tile_dim));
114
+ svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
115
+ svld1_u32(predicate_all_b32x, b_tile1 + step * tile_dim));
116
+ svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_column_u32x,
117
+ svld1_u32(predicate_all_b32x, b_tile2 + step * tile_dim));
114
118
  }
115
119
  }
116
120
 
117
121
  // Extract: dot = (pop_a + pop_b - depth + matching) / 2
118
122
  // matching = ZA[i][j]
119
- svuint32_t b_pop0_u32x = svld1_u32(predicate_all_u32x, b_norms + (row_tile_b + 0) * tile_dim);
120
- svuint32_t b_pop1_u32x = svld1_u32(predicate_all_u32x, b_norms + (row_tile_b + 1) * tile_dim);
121
- svuint32_t b_pop2_u32x = svld1_u32(predicate_all_u32x, b_norms + (row_tile_b + 2) * tile_dim);
123
+ svuint32_t b_pop0_u32x = svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 0) * tile_dim);
124
+ svuint32_t b_pop1_u32x = svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 1) * tile_dim);
125
+ svuint32_t b_pop2_u32x = svld1_u32(predicate_all_b32x, b_norms + (row_tile_b + 2) * tile_dim);
122
126
 
123
127
  for (nk_size_t row = 0; row < rows_a_remaining; row++) {
124
128
  nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
125
129
  svuint32_t pop_a_u32x = svdup_u32(a_popcounts[row]);
126
130
 
127
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
128
- svuint32_t sum_pops0_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_pop0_u32x);
131
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
132
+ svuint32_t sum_pops0_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_pop0_u32x);
129
133
  svuint32_t numerator0_u32x = svadd_u32_x(
130
- predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops0_u32x, depth_u32x), za1_u32x);
131
- svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 0) * tile_dim,
132
- svlsr_n_u32_x(predicate_all_u32x, numerator0_u32x, 1));
134
+ predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops0_u32x, depth_u32x), za1_u32x);
135
+ svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 0) * tile_dim,
136
+ svlsr_n_u32_x(predicate_all_b32x, numerator0_u32x, 1));
133
137
 
134
- svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
135
- svuint32_t sum_pops1_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_pop1_u32x);
138
+ svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
139
+ svuint32_t sum_pops1_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_pop1_u32x);
136
140
  svuint32_t numerator1_u32x = svadd_u32_x(
137
- predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops1_u32x, depth_u32x), za2_u32x);
138
- svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 1) * tile_dim,
139
- svlsr_n_u32_x(predicate_all_u32x, numerator1_u32x, 1));
141
+ predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops1_u32x, depth_u32x), za2_u32x);
142
+ svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 1) * tile_dim,
143
+ svlsr_n_u32_x(predicate_all_b32x, numerator1_u32x, 1));
140
144
 
141
- svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);
142
- svuint32_t sum_pops2_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_pop2_u32x);
145
+ svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
146
+ svuint32_t sum_pops2_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_pop2_u32x);
143
147
  svuint32_t numerator2_u32x = svadd_u32_x(
144
- predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops2_u32x, depth_u32x), za3_u32x);
145
- svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 2) * tile_dim,
146
- svlsr_n_u32_x(predicate_all_u32x, numerator2_u32x, 1));
148
+ predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops2_u32x, depth_u32x), za3_u32x);
149
+ svst1_u32(predicate_all_b32x, c_row + (row_tile_b + 2) * tile_dim,
150
+ svlsr_n_u32_x(predicate_all_b32x, numerator2_u32x, 1));
147
151
  }
148
152
  }
149
153
 
@@ -152,49 +156,49 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
152
156
  nk_size_t const row_start_b = row_tile_b * tile_dim;
153
157
  nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
154
158
  : (row_count_b - row_start_b);
155
- svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, rows_b_remaining);
159
+ svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, rows_b_remaining);
156
160
 
157
161
  svzero_mask_za(nk_sme_zero_za32_tile_1_);
158
162
 
159
163
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
160
164
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
161
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
165
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
162
166
  ? depth_tile_size
163
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
164
- : 0);
167
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
165
168
  if (u32s_this_tile == 0) break;
166
169
 
167
170
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
168
171
 
169
- svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
172
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
170
173
 
174
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
171
175
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
172
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
173
- (row_start_a + row_in_tile) * a_stride_in_bytes) +
174
- d_start_u32;
175
- svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
176
+ nk_u8_t const *a_row = (nk_u8_t const *)a + (row_start_a + row_in_tile) * a_stride_in_bytes +
177
+ d_start_u32 * 4;
178
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
179
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
176
180
  }
177
181
 
178
182
  nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;
179
183
 
180
184
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
181
- svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);
182
- svuint32_t b_u32x = svld1_u32(predicate_all_u32x, b_tile + step * tile_dim);
183
- svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_column_u32x, b_u32x);
185
+ svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, step);
186
+ svuint32_t b_u32x = svld1_u32(predicate_all_b32x, b_tile + step * tile_dim);
187
+ svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_column_u32x, b_u32x);
184
188
  }
185
189
  }
186
190
 
187
191
  // Extract: dot = (pop_a + pop_b - depth + matching) / 2
188
- svuint32_t b_pop_u32x = svld1_u32(predicate_all_u32x, b_norms + row_start_b);
192
+ svuint32_t b_pop_u32x = svld1_u32(predicate_all_b32x, b_norms + row_start_b);
189
193
  for (nk_size_t row = 0; row < rows_a_remaining; row++) {
190
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
194
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
191
195
  svuint32_t pop_a_u32x = svdup_u32(a_popcounts[row]);
192
- svuint32_t sum_pops_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_pop_u32x);
196
+ svuint32_t sum_pops_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_pop_u32x);
193
197
  svuint32_t numerator_u32x = svadd_u32_x(
194
- predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops_u32x, depth_u32x), za1_u32x);
198
+ predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops_u32x, depth_u32x), za1_u32x);
195
199
  nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
196
- svst1_u32(column_predicate_u32x, c_row + row_start_b,
197
- svlsr_n_u32_x(predicate_all_u32x, numerator_u32x, 1));
200
+ svst1_u32(column_predicate_b32x, c_row + row_start_b,
201
+ svlsr_n_u32_x(predicate_all_b32x, numerator_u32x, 1));
198
202
  }
199
203
  }
200
204
  }
@@ -212,39 +216,46 @@ NK_PUBLIC void nk_dots_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packe
212
216
  * Same ZA transpose pattern as hammings_symmetric, but with dot extraction.
213
217
  */
214
218
  __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32_streaming_(
215
- nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits, nk_size_t stride, nk_u32_t *result,
216
- nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
219
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
220
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
217
221
 
218
222
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
219
223
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
220
- nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
221
- nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);
222
- nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, 8);
224
+ // BMOPA processes binary data in 32-bit words: each svbmopa_za32_u32_m step
225
+ // handles one u32 (32 bits) across all row×column pairs simultaneously.
226
+ nk_size_t const depth_words = nk_size_divide_round_up_(depth_bits, 32);
227
+ nk_size_t const depth_bytes = depth_bits / 8;
228
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_words, depth_tile_size);
223
229
 
224
- svbool_t const predicate_all_u32x = svptrue_b32();
225
- svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);
230
+ svbool_t const predicate_all_b32x = svptrue_b32();
231
+ // Use padded depth (depth_words * 32) for BMOPA: zero-padded bits always match in XNOR,
232
+ // so the effective depth for the matching→intersection conversion is the rounded-up bit count.
233
+ svuint32_t const depth_u32x = svdup_u32((nk_u32_t)(depth_words * 32));
226
234
 
227
235
  NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save
228
236
 
229
237
  nk_size_t const row_end = row_start + row_count;
230
- nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dim);
238
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dim);
231
239
 
232
- for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
240
+ for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
233
241
  row_tile_start += tile_dim) {
234
242
  nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
235
- nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
236
- : (n_vectors - row_tile_start);
237
- svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_clamped);
243
+ nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= vectors_count)
244
+ ? rows_remaining
245
+ : (vectors_count - row_tile_start);
246
+ svbool_t const row_predicate_b32x = svwhilelt_b32_u64(0u, rows_clamped);
238
247
 
239
248
  // Compute A tile popcounts
240
249
  NK_ALIGN64 nk_u32_t a_tile_pops[16];
241
250
  for (nk_size_t r = 0; r < rows_clamped; r++) {
242
- nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors + (row_tile_start + r) * stride);
243
- a_tile_pops[r] = nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
251
+ nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors +
252
+ (row_tile_start + r) * stride_in_bytes);
253
+ a_tile_pops[r] = nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_bytes);
244
254
  }
245
255
  for (nk_size_t r = rows_clamped; r < tile_dim; r++) a_tile_pops[r] = 0;
246
256
 
247
- nk_size_t column_tile_index = 0;
257
+ // Upper triangle: start from this row tile's column
258
+ nk_size_t column_tile_index = row_tile_start / tile_dim;
248
259
 
249
260
  // Fast path: 3 column tiles using ZA1-ZA3 (ZA0 = staging)
250
261
  for (; column_tile_index + 3 <= column_tile_count; column_tile_index += 3) {
@@ -252,73 +263,73 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32
252
263
 
253
264
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
254
265
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
255
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
266
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
256
267
  ? depth_tile_size
257
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
258
- : 0);
268
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
259
269
  if (u32s_this_tile == 0) break;
260
270
 
261
271
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
262
- svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
272
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
263
273
 
274
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
264
275
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
265
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
266
- (row_tile_start + row_in_tile) * stride) +
267
- d_start_u32;
268
- svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
276
+ nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
277
+ d_start_u32 * 4;
278
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
279
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
269
280
  }
270
281
 
271
282
  // Save A columns
272
283
  for (nk_size_t s = 0; s < u32s_this_tile; s++)
273
- svst1_u32(predicate_all_u32x, a_buffer[s],
274
- svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));
284
+ svst1_u32(predicate_all_b32x, a_buffer[s],
285
+ svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
275
286
 
276
287
  // B column tile 0
277
288
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
278
289
  for (nk_size_t col = 0; col < tile_dim; col++) {
279
290
  nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
280
- if (col_abs < n_vectors) {
281
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
282
- d_start_u32;
283
- svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
291
+ if (col_abs < vectors_count) {
292
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
293
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
294
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
284
295
  }
285
296
  }
286
297
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
287
- svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
288
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
289
- svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
298
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
299
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
300
+ svbmopa_za32_u32_m(1, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
290
301
  }
291
302
 
292
303
  // B column tile 1
293
304
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
294
305
  for (nk_size_t col = 0; col < tile_dim; col++) {
295
306
  nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
296
- if (col_abs < n_vectors) {
297
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
298
- d_start_u32;
299
- svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
307
+ if (col_abs < vectors_count) {
308
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
309
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
310
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
300
311
  }
301
312
  }
302
313
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
303
- svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
304
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
305
- svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
314
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
315
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
316
+ svbmopa_za32_u32_m(2, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
306
317
  }
307
318
 
308
319
  // B column tile 2
309
320
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
310
321
  for (nk_size_t col = 0; col < tile_dim; col++) {
311
322
  nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
312
- if (col_abs < n_vectors) {
313
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
314
- d_start_u32;
315
- svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
323
+ if (col_abs < vectors_count) {
324
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
325
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
326
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
316
327
  }
317
328
  }
318
329
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
319
- svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
320
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
321
- svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
330
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
331
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_b32x, 0, step);
332
+ svbmopa_za32_u32_m(3, row_predicate_b32x, predicate_all_b32x, a_u32x, b_u32x);
322
333
  }
323
334
  }
324
335
 
@@ -328,88 +339,89 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32
328
339
  for (nk_size_t t = 0; t < 3; t++) {
329
340
  for (nk_size_t col = 0; col < tile_dim; col++) {
330
341
  nk_size_t const col_abs = (column_tile_index + t) * tile_dim + col;
331
- if (col_abs < n_vectors) {
332
- nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs * stride);
333
- b_pops[t][col] = nk_sets_reduce_sumsq_u1_streaming_(b_row, depth_in_bytes);
342
+ if (col_abs < vectors_count) {
343
+ nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs * stride_in_bytes);
344
+ b_pops[t][col] = nk_sets_reduce_sumsq_u1_streaming_(b_row, depth_bytes);
334
345
  }
335
346
  else { b_pops[t][col] = 0; }
336
347
  }
337
348
  }
338
349
 
339
350
  for (nk_size_t row = 0; row < rows_clamped; row++) {
340
- nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);
351
+ nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
341
352
  svuint32_t pop_a_u32x = svdup_u32(a_tile_pops[row]);
342
353
 
343
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
344
- svuint32_t b_popcount_0_u32x = svld1_u32(predicate_all_u32x, b_pops[0]);
345
- svuint32_t sum_pops0_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_popcount_0_u32x);
354
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
355
+ svuint32_t b_popcount_0_u32x = svld1_u32(predicate_all_b32x, b_pops[0]);
356
+ svuint32_t sum_pops0_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_popcount_0_u32x);
346
357
  svuint32_t numerator0_u32x = svadd_u32_x(
347
- predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops0_u32x, depth_u32x), za1_u32x);
348
- svst1_u32(predicate_all_u32x, result_row + (column_tile_index + 0) * tile_dim,
349
- svlsr_n_u32_x(predicate_all_u32x, numerator0_u32x, 1));
358
+ predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops0_u32x, depth_u32x), za1_u32x);
359
+ svst1_u32(predicate_all_b32x, result_row + (column_tile_index + 0) * tile_dim,
360
+ svlsr_n_u32_x(predicate_all_b32x, numerator0_u32x, 1));
350
361
 
351
- svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
352
- svuint32_t b_popcount_1_u32x = svld1_u32(predicate_all_u32x, b_pops[1]);
353
- svuint32_t sum_pops1_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_popcount_1_u32x);
362
+ svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 2, row);
363
+ svuint32_t b_popcount_1_u32x = svld1_u32(predicate_all_b32x, b_pops[1]);
364
+ svuint32_t sum_pops1_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_popcount_1_u32x);
354
365
  svuint32_t numerator1_u32x = svadd_u32_x(
355
- predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops1_u32x, depth_u32x), za2_u32x);
356
- svst1_u32(predicate_all_u32x, result_row + (column_tile_index + 1) * tile_dim,
357
- svlsr_n_u32_x(predicate_all_u32x, numerator1_u32x, 1));
366
+ predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops1_u32x, depth_u32x), za2_u32x);
367
+ svst1_u32(predicate_all_b32x, result_row + (column_tile_index + 1) * tile_dim,
368
+ svlsr_n_u32_x(predicate_all_b32x, numerator1_u32x, 1));
358
369
 
359
- svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);
360
- svuint32_t b_popcount_2_u32x = svld1_u32(predicate_all_u32x, b_pops[2]);
361
- svuint32_t sum_pops2_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_popcount_2_u32x);
370
+ svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 3, row);
371
+ svuint32_t b_popcount_2_u32x = svld1_u32(predicate_all_b32x, b_pops[2]);
372
+ svuint32_t sum_pops2_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_popcount_2_u32x);
362
373
  svuint32_t numerator2_u32x = svadd_u32_x(
363
- predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops2_u32x, depth_u32x), za3_u32x);
364
- svst1_u32(predicate_all_u32x, result_row + (column_tile_index + 2) * tile_dim,
365
- svlsr_n_u32_x(predicate_all_u32x, numerator2_u32x, 1));
374
+ predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops2_u32x, depth_u32x), za3_u32x);
375
+ svst1_u32(predicate_all_b32x, result_row + (column_tile_index + 2) * tile_dim,
376
+ svlsr_n_u32_x(predicate_all_b32x, numerator2_u32x, 1));
366
377
  }
367
378
  }
368
379
 
369
380
  // Remainder: 1 column tile at a time using ZA1
370
381
  for (; column_tile_index < column_tile_count; column_tile_index++) {
371
382
  nk_size_t const col_tile_start = column_tile_index * tile_dim;
372
- nk_size_t const cols_remaining = (col_tile_start + tile_dim <= n_vectors) ? tile_dim
373
- : (n_vectors - col_tile_start);
374
- svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, cols_remaining);
383
+ nk_size_t const cols_remaining = (col_tile_start + tile_dim <= vectors_count)
384
+ ? tile_dim
385
+ : (vectors_count - col_tile_start);
386
+ svbool_t const column_predicate_b32x = svwhilelt_b32_u64(0u, cols_remaining);
375
387
 
376
388
  svzero_mask_za(nk_sme_zero_za32_tile_1_);
377
389
 
378
390
  for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
379
391
  nk_size_t const d_start_u32 = d_tile * depth_tile_size;
380
- nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
392
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_words)
381
393
  ? depth_tile_size
382
- : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
383
- : 0);
394
+ : (depth_words > d_start_u32 ? depth_words - d_start_u32 : 0);
384
395
  if (u32s_this_tile == 0) break;
385
396
 
386
397
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
387
- svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);
398
+ svbool_t const batch_predicate_b32x = svwhilelt_b32_u64(0u, u32s_this_tile);
388
399
 
400
+ svbool_t const depth_predicate_b8x = svwhilelt_b8_u64(d_start_u32 * 4, depth_bytes);
389
401
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
390
- nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
391
- (row_tile_start + row_in_tile) * stride) +
392
- d_start_u32;
393
- svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
402
+ nk_u8_t const *a_row = (nk_u8_t const *)vectors + (row_tile_start + row_in_tile) * stride_in_bytes +
403
+ d_start_u32 * 4;
404
+ svuint8_t row_u8x = svld1_u8(depth_predicate_b8x, a_row);
405
+ svwrite_hor_za32_u32_m(0, row_in_tile, batch_predicate_b32x, svreinterpret_u32_u8(row_u8x));
394
406
  }
395
407
 
396
408
  for (nk_size_t s = 0; s < u32s_this_tile; s++)
397
- svst1_u32(predicate_all_u32x, a_buffer[s],
398
- svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));
409
+ svst1_u32(predicate_all_b32x, a_buffer[s],
410
+ svread_ver_za32_u32_m(svdup_u32(0), row_predicate_b32x, 0, s));
399
411
 
400
412
  svzero_mask_za(nk_sme_zero_za32_tile_0_);
401
413
  for (nk_size_t col = 0; col < tile_dim; col++) {
402
414
  nk_size_t const col_abs = col_tile_start + col;
403
- if (col_abs < n_vectors) {
404
- nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
405
- d_start_u32;
406
- svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
415
+ if (col_abs < vectors_count) {
416
+ nk_u8_t const *b_row = (nk_u8_t const *)vectors + col_abs * stride_in_bytes + d_start_u32 * 4;
417
+ svuint8_t col_u8x = svld1_u8(depth_predicate_b8x, b_row);
418
+ svwrite_hor_za32_u32_m(0, col, batch_predicate_b32x, svreinterpret_u32_u8(col_u8x));
407
419
  }
408
420
  }
409
421
  for (nk_size_t step = 0; step < u32s_this_tile; step++) {
410
- svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
411
- svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_u32x, 0, step);
412
- svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_u32x, b_u32x);
422
+ svuint32_t a_u32x = svld1_u32(predicate_all_b32x, a_buffer[step]);
423
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_b32x, 0, step);
424
+ svbmopa_za32_u32_m(1, row_predicate_b32x, column_predicate_b32x, a_u32x, b_u32x);
413
425
  }
414
426
  }
415
427
 
@@ -417,33 +429,34 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32
417
429
  NK_ALIGN64 nk_u32_t b_pops_r[16];
418
430
  for (nk_size_t col = 0; col < tile_dim; col++) {
419
431
  nk_size_t const col_abs = col_tile_start + col;
420
- if (col_abs < n_vectors) {
421
- nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs * stride);
422
- b_pops_r[col] = nk_sets_reduce_sumsq_u1_streaming_(b_row, depth_in_bytes);
432
+ if (col_abs < vectors_count) {
433
+ nk_u1x8_t const *b_row = (nk_u1x8_t const *)((char const *)vectors + col_abs * stride_in_bytes);
434
+ b_pops_r[col] = nk_sets_reduce_sumsq_u1_streaming_(b_row, depth_bytes);
423
435
  }
424
436
  else { b_pops_r[col] = 0; }
425
437
  }
426
438
 
427
439
  for (nk_size_t row = 0; row < rows_clamped; row++) {
428
- svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
440
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_b32x, 1, row);
429
441
  svuint32_t pop_a_u32x = svdup_u32(a_tile_pops[row]);
430
- svuint32_t b_popcount_u32x = svld1_u32(predicate_all_u32x, b_pops_r);
431
- svuint32_t sum_pops_u32x = svadd_u32_x(predicate_all_u32x, pop_a_u32x, b_popcount_u32x);
442
+ svuint32_t b_popcount_u32x = svld1_u32(predicate_all_b32x, b_pops_r);
443
+ svuint32_t sum_pops_u32x = svadd_u32_x(predicate_all_b32x, pop_a_u32x, b_popcount_u32x);
432
444
  svuint32_t numerator_u32x = svadd_u32_x(
433
- predicate_all_u32x, svsub_u32_x(predicate_all_u32x, sum_pops_u32x, depth_u32x), za1_u32x);
434
- nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);
435
- svst1_u32(column_predicate_u32x, result_row + col_tile_start,
436
- svlsr_n_u32_x(predicate_all_u32x, numerator_u32x, 1));
445
+ predicate_all_b32x, svsub_u32_x(predicate_all_b32x, sum_pops_u32x, depth_u32x), za1_u32x);
446
+ nk_u32_t *result_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride_in_bytes);
447
+ svst1_u32(column_predicate_b32x, result_row + col_tile_start,
448
+ svlsr_n_u32_x(predicate_all_b32x, numerator_u32x, 1));
437
449
  }
438
450
  }
439
451
  }
440
452
  }
441
453
 
442
- NK_PUBLIC void nk_dots_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits,
443
- nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
444
- nk_size_t row_start, nk_size_t row_count) {
445
- nk_dots_symmetric_u1_smebi32_streaming_(vectors, n_vectors, depth_bits, stride, result, result_stride, row_start,
446
- row_count);
454
+ NK_PUBLIC void nk_dots_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
455
+ nk_size_t stride_in_bytes, nk_u32_t *result,
456
+ nk_size_t result_stride_in_bytes, nk_size_t row_start,
457
+ nk_size_t row_count) {
458
+ nk_dots_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
459
+ result_stride_in_bytes, row_start, row_count);
447
460
  }
448
461
 
449
462
  #if defined(__clang__)