numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294):
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1099 @@
1
+ /**
2
+ * @brief SIMD-accelerated Batched Set Distances for SME.
3
+ * @file include/numkong/sets/smebi32.h
4
+ * @author Ash Vardanian
5
+ * @date February 6, 2026
6
+ * @sa include/numkong/sets.h
7
+ *
8
+ * Uses ARM Scalable Matrix Extension (SME) for efficient binary set operations.
9
+ * Leverages streaming mode's wider vectors (512-bit on Apple M4) for fast
10
+ * XOR+POPCNT operations on binary vectors.
11
+ *
12
+ * @section smebi32_math Mathematical Foundation
13
+ *
14
+ * Hamming distance: popcount(a XOR b) = number of differing bits
15
+ *
16
+ * Jaccard distance using intersection:
17
+ * intersection = popcount(a AND b)
18
+ * union = popcount(a) + popcount(b) - intersection
19
+ * jaccard = 1 - intersection / union
20
+ *
21
+ * @section smebi32_tiles SME Dimensions (512-bit SVL)
22
+ *
23
+ * - svcntw(): 16 (number of 32-bit elements per vector)
24
+ * - svcntb(): 64 (number of bytes per SVE vector)
25
+ * - Tile blocking: 16x16 output tiles for cache efficiency
26
+ * - Depth processing: 64 bytes (512 bits) per iteration
27
+ *
28
+ * @section smebi32_perf Performance Characteristics (Apple M4)
29
+ *
30
+ * - SVL: 512 bits (64 bytes)
31
+ * - Streaming mode provides dedicated register file
32
+ * - Streaming mode overhead: ~50-100 cycles for SMSTART/SMSTOP
33
+ */
34
+
35
+ #ifndef NK_SETS_SMEBI32_H
36
+ #define NK_SETS_SMEBI32_H
37
+
38
+ #if NK_TARGET_ARM_
39
+ #if NK_TARGET_SMEBI32
40
+
41
+ #include "numkong/types.h"
42
+ #include "numkong/set/serial.h"
43
+ #include "numkong/sets/serial.h"
44
+ #include "numkong/dots/sme.h" // `nk_sme_zero_za32_*` constants
45
+ #include "numkong/reduce.h" // `nk_reduce_moments_u1`
46
+
47
+ #if defined(__cplusplus)
48
+ extern "C" {
49
+ #endif
50
+
51
+ /*
52
+ * Binary set operations using SME BMOPA instruction.
53
+ *
54
+ * BMOPA computes: ZA[i,j] += popcount(~(Zn[i] ^ Zm[j])) = popcount(XNOR)
55
+ * This counts matching bits. Hamming = depth_bits - matching.
56
+ *
57
+ * Tile layout (SVL=512, Apple M4):
58
+ * - ZA32 output tile: 16 × 16 u32 elements (1 KB)
59
+ * - Input vectors: 16 u32 elements (SVL/32)
60
+ * - Each BMOPA processes 32 bits (one u32) across 16×16 pairs
61
+ * - BMOPA predicates: b32 (u32 input granularity)
62
+ * - Packed kernel: 4-tile path (ZA0-ZA3) for 4 B-column tiles simultaneously
63
+ * - Unpacked kernel: ZA transpose (ZA0.S=staging, ZA1-3.S=accumulation, 3-tile fast path)
64
+ * - Packed format: column-major u32 within each tile
65
+ */
66
+
67
+ #if defined(__clang__)
68
+ #pragma clang attribute push(__attribute__((target("sme2,sve2"))), apply_to = function)
69
+ #elif defined(__GNUC__)
70
+ #pragma GCC push_options
71
+ #pragma GCC target("+sme2")
72
+ #endif
73
+
74
+ /* Read SVL in bytes from non-streaming context using RDSVL instruction. */
75
+ NK_INTERNAL nk_size_t nk_smebi32_svl_bytes_(void) {
76
+ nk_size_t svl_bytes;
77
+ __asm__ volatile("rdsvl %0, #1" : "=r"(svl_bytes));
78
+ return svl_bytes;
79
+ }
80
+
81
+ /* Get ZA32 tile dimension (number of f32/u32 elements per row). */
82
+ NK_INTERNAL nk_size_t nk_smebi32_tile_dim_(void) { return nk_smebi32_svl_bytes_() / sizeof(nk_u32_t); }
83
+
84
/**
 * Fixed 64-byte header written at offset 0 of the buffer produced by
 * `nk_dots_pack_u1_smebi32`. Tile payload follows the header immediately;
 * per-row population counts live at byte offset `norms_offset`.
 * Layout must remain stable: pack and compute kernels read it directly.
 */
typedef struct {
    nk_u32_t row_tile_count;   // ceiling(rows / tile_dim)
    nk_u32_t depth_tile_count; // ceiling(depth_bits / depth_tile_bits)
    nk_u32_t rows;             // actual row count
    nk_u32_t depth_bits;       // actual depth in bits
    nk_u32_t svl_bytes;        // SVL at pack time for validation
    nk_u32_t norms_offset;     // byte offset to norms (0 if none)
    nk_u32_t reserved[10];     // padding to 64 bytes
} nk_sets_smebi32_packed_header_t;
93
+
94
+ /** Count total set bits across a byte vector using streaming SVE.
95
+ * Accumulates per-byte popcounts into u32 lanes via svdot; single horizontal reduction at end. */
96
+ NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data,
97
+ nk_size_t n_bytes) NK_STREAMING_COMPATIBLE_ {
98
+ svuint32_t acc_u32x = svdup_u32(0);
99
+ svuint8_t const ones_u8x = svdup_u8(1);
100
+ for (nk_size_t offset = 0; offset < n_bytes; offset += svcntb()) {
101
+ svbool_t predicate_u8x = svwhilelt_b8_u64(offset, n_bytes);
102
+ acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_u8x, svld1_u8(predicate_u8x, data + offset)), ones_u8x);
103
+ }
104
+ return (nk_u32_t)svaddv_u32(svptrue_b32(), acc_u32x);
105
+ }
106
+
107
+ #pragma region Hamming Distance
108
+
109
+ NK_PUBLIC nk_size_t nk_dots_packed_size_u1_smebi32(nk_size_t row_count, nk_size_t depth_bits) {
110
+ nk_size_t const tile_dim = nk_smebi32_tile_dim_(); // 16 rows per tile
111
+ nk_size_t const depth_tile_size = nk_smebi32_tile_dim_(); // 16 u32 per depth tile = 512 bits
112
+
113
+ nk_size_t const depth_u32 = nk_size_divide_round_up_(depth_bits, 32);
114
+ nk_size_t const row_tile_count = nk_size_divide_round_up_(row_count, tile_dim);
115
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32, depth_tile_size);
116
+
117
+ nk_size_t const tile_elements = tile_dim * depth_tile_size; // 256 u32 per tile
118
+ nk_size_t size = sizeof(nk_sets_smebi32_packed_header_t);
119
+ size += row_tile_count * depth_tile_count * tile_elements * sizeof(nk_u32_t);
120
+ size += row_count * sizeof(nk_u32_t); // per-row population counts
121
+
122
+ return size;
123
+ }
124
+
125
/**
 * Pack binary matrix `b` (`row_count` rows of `depth_bits` bits) into the tiled
 * layout consumed by the SME Hamming kernels: a 64-byte header, then
 * column-major u32 tiles of `tile_dim` x `tile_dim` words (zero-padded), then
 * one u32 population count per row.
 *
 * @param b                  Source bit matrix, one row per `b_stride_in_bytes`.
 * @param row_count          Number of rows in `b`.
 * @param depth_bits         Bits per row.
 * @param b_stride_in_bytes  Byte stride between consecutive rows of `b`.
 * @param b_packed           Output buffer, at least
 *                           `nk_dots_packed_size_u1_smebi32(row_count, depth_bits)` bytes.
 */
NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count, nk_size_t depth_bits,
                                       nk_size_t b_stride_in_bytes, void *b_packed) {
    nk_size_t const svl_bytes = nk_smebi32_svl_bytes_();
    nk_size_t const tile_dim = nk_smebi32_tile_dim_();        // 16 rows per tile
    nk_size_t const depth_tile_size = nk_smebi32_tile_dim_(); // 16 u32 per depth tile
    nk_size_t const tile_elements = tile_dim * depth_tile_size;
    nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, NK_BITS_PER_BYTE);

    nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
    nk_size_t const row_tile_count = nk_size_divide_round_up_(row_count, tile_dim);
    nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);
    nk_size_t const total_tiles = row_tile_count * depth_tile_count;
    nk_size_t const data_size = total_tiles * tile_elements * sizeof(nk_u32_t);

    // Header records geometry plus the SVL it was packed at, so the compute
    // kernels can validate that pack-time and run-time vector lengths match.
    nk_sets_smebi32_packed_header_t *header = (nk_sets_smebi32_packed_header_t *)b_packed;
    header->row_tile_count = (nk_u32_t)row_tile_count;
    header->depth_tile_count = (nk_u32_t)depth_tile_count;
    header->rows = (nk_u32_t)row_count;
    header->depth_bits = (nk_u32_t)depth_bits;
    header->svl_bytes = (nk_u32_t)svl_bytes;
    header->norms_offset = (nk_u32_t)(sizeof(nk_sets_smebi32_packed_header_t) + data_size);

    nk_u32_t *tiles_ptr = (nk_u32_t *)((char *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
    nk_u32_t *norms_ptr = (nk_u32_t *)((char *)b_packed + header->norms_offset);

    // Zero-initialize all tiles (partial tiles stay zero-padded for predicated loads)
    for (nk_size_t i = 0; i < total_tiles * tile_elements; i++) tiles_ptr[i] = 0;

    // Pack tiles: column-major u32 within each tile for efficient SVE loads
    for (nk_size_t row_tile = 0; row_tile < row_tile_count; row_tile++) {
        for (nk_size_t depth_tile = 0; depth_tile < depth_tile_count; depth_tile++) {
            nk_size_t const tile_index = row_tile * depth_tile_count + depth_tile;
            nk_u32_t *tile_output = tiles_ptr + tile_index * tile_elements;

            nk_size_t const src_row_start = row_tile * tile_dim;
            nk_size_t const src_u32_start = depth_tile * depth_tile_size;
            // Clamp both dimensions for the edge tiles of the matrix.
            nk_size_t const rows_to_pack = (src_row_start + tile_dim <= row_count) ? tile_dim
                                                                                   : (row_count - src_row_start);
            nk_size_t const u32s_to_pack = (src_u32_start + depth_tile_size <= depth_u32_total)
                                               ? depth_tile_size
                                               : (depth_u32_total > src_u32_start ? depth_u32_total - src_u32_start
                                                                                  : 0);

            // Column-major packing: tile_output[col * tile_dim + row]
            for (nk_size_t row = 0; row < rows_to_pack; row++) {
                nk_u32_t const *src_row = (nk_u32_t const *)((char const *)b +
                                                             (src_row_start + row) * b_stride_in_bytes);
                // NOTE(review): rows are read as whole u32 words. When depth_bits is not a
                // multiple of 32, the last word read extends past ceil(depth_bits/8) bytes of
                // the row — safe only if rows are padded (e.g. stride rounded up). Confirm
                // the caller contract guarantees this.
                for (nk_size_t col = 0; col < u32s_to_pack; col++) {
                    nk_size_t const dst_idx = col * tile_dim + row; // Column-major!
                    tile_output[dst_idx] = src_row[src_u32_start + col];
                }
            }
        }
    }

    // Compute per-row population counts
    for (nk_size_t row = 0; row < row_count; row++) {
        nk_u1x8_t const *src_row = (nk_u1x8_t const *)((char const *)b + row * b_stride_in_bytes);
        {
            nk_u64_t nk_local_sum_, nk_local_sumsq_;
            // NOTE(review): this counts over depth_in_bytes * 8 bits, which exceeds
            // depth_bits when depth_bits % 8 != 0 — assumes the trailing pad bits of the
            // last byte are zero. Verify against the u1 layout contract.
            nk_reduce_moments_u1(src_row, depth_in_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
            norms_ptr[row] = (nk_u32_t)nk_local_sum_; // only the sum (popcount) is kept
        }
    }
}
190
+
191
/**
 * SME Hamming kernel using ZA transpose for unpacked A.
 * ZA0.S = staging (A rows loaded horizontally, read vertically for BMOPA).
 * ZA1-3.S = BMOPA accumulation (3 B column tiles in fast path).
 *
 * Each ZA0.S batch covers 16 depth u32 steps (one full depth tile).
 * BMOPA expansion=1 for u32: each u32 contributes 32 bits via XNOR+POPCNT,
 * so ZA accumulates the count of MATCHING bits; Hamming distance is recovered
 * as depth_bits - matching on extraction.
 *
 * @param a                   Unpacked A bit matrix (row-major, strided).
 * @param b_packed            B matrix pre-packed by `nk_dots_pack_u1_smebi32`.
 * @param c                   Output distances, row i of A x all rows of B.
 * @param row_count_a         Rows in A.
 * @param row_count_b         Rows in B (must match the packed header's `rows`).
 * @param depth_bits          Bits per row.
 * @param a_stride_in_bytes   Byte stride between A rows.
 * @param c_stride_in_bytes   Byte stride between C rows.
 */
__arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_(
    nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
    nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {

    nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
    nk_size_t const row_tile_count_b = header->row_tile_count;
    nk_size_t const depth_tile_count = header->depth_tile_count;

    nk_size_t const tile_dim = svcntw();        // 16 for 512-bit SVL
    nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
    nk_size_t const tile_elements = tile_dim * depth_tile_size;
    nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);

    nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));

    svbool_t const predicate_all_u32x = svptrue_b32();
    svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);
    nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);

    for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
        nk_size_t const row_start_a = row_tile_a * tile_dim;
        nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
                                                                                   : (row_count_a - row_start_a);
        svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_a_remaining);

        // Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
        nk_size_t row_tile_b = 0;
        for (; row_tile_b + 3 <= row_tile_count_b; row_tile_b += 3) {
            svzero_mask_za(nk_sme_zero_za32_tiles_123_); // reset the three accumulators

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                // Clamp the final (possibly partial) depth tile to the real word count.
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                // ZA0 is staging; clear it so unwritten rows/steps read as zero.
                svzero_mask_za(nk_sme_zero_za32_tile_0_);

                svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                // Load A rows into ZA0.S horizontally as u32 words
                for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
                                                                   (row_start_a + row_in_tile) * a_stride_in_bytes) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
                }

                // B tile pointers for 3 column tiles
                nk_u32_t const *b_tile0 = b_tiles + ((row_tile_b + 0) * depth_tile_count + d_tile) * tile_elements;
                nk_u32_t const *b_tile1 = b_tiles + ((row_tile_b + 1) * depth_tile_count + d_tile) * tile_elements;
                nk_u32_t const *b_tile2 = b_tiles + ((row_tile_b + 2) * depth_tile_count + d_tile) * tile_elements;

                // Vertical read + BMOPA for each depth step: reading ZA0 column `step`
                // yields one u32 word from each of the 16 A rows (the transpose trick).
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);

                    svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
                                       svld1_u32(predicate_all_u32x, b_tile0 + step * tile_dim));
                    svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
                                       svld1_u32(predicate_all_u32x, b_tile1 + step * tile_dim));
                    svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_column_u32x,
                                       svld1_u32(predicate_all_u32x, b_tile2 + step * tile_dim));
                }
            }

            // Extract from ZA1-3: Hamming = depth_bits - matching_bits
            // NOTE(review): these stores write a full tile_dim of columns per tile with an
            // all-true predicate, even if the last of the three B tiles is partial — this
            // assumes C rows have capacity for row_tile_count_b * tile_dim entries
            // (zero-padded B tiles make the extra values deterministic). Confirm the
            // caller allocates C accordingly; the remainder path below does predicate.
            for (nk_size_t row = 0; row < rows_a_remaining; row++) {
                nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);

                svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
                svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
                svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);

                svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 0) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x));
                svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 1) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za2_u32x));
                svst1_u32(predicate_all_u32x, c_row + (row_tile_b + 2) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za3_u32x));
            }
        }

        // Remainder: 1 B column tile at a time using ZA1
        for (; row_tile_b < row_tile_count_b; row_tile_b++) {
            nk_size_t const row_start_b = row_tile_b * tile_dim;
            nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
                                                                                       : (row_count_b - row_start_b);
            svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, rows_b_remaining);

            svzero_mask_za(nk_sme_zero_za32_tile_1_);

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                svzero_mask_za(nk_sme_zero_za32_tile_0_);

                svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                // Load A rows into ZA0.S horizontally
                for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
                                                                   (row_start_a + row_in_tile) * a_stride_in_bytes) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
                }

                nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;

                // Vertical read + BMOPA
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, step);
                    svuint32_t b_u32x = svld1_u32(predicate_all_u32x, b_tile + step * tile_dim);
                    svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_column_u32x, b_u32x);
                }
            }

            // Extract from ZA1: Hamming = depth_bits - matching_bits
            // (stores predicated to rows_b_remaining, so partial tiles stay in bounds)
            for (nk_size_t row = 0; row < rows_a_remaining; row++) {
                svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
                svuint32_t hamming_u32x = svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x);
                nk_u32_t *c_row = (nk_u32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
                svst1_u32(column_predicate_u32x, c_row + row_start_b, hamming_u32x);
            }
        }
    }
}
333
+
334
/** Public entry point for batched Hamming distances against pre-packed B.
 *  Thin forwarder: the callee carries __arm_locally_streaming/__arm_new("za"),
 *  so streaming-mode entry/exit and ZA state management happen there. */
NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c,
                                             nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
                                             nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
    nk_hammings_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
                                             c_stride_in_bytes);
}
340
+
341
/**
 * Symmetric Hamming using ZA0 time-sharing + 3-tile fast path.
 * ZA0.S = staging (A rows loaded horizontally, read vertically for BMOPA).
 * ZA1-3.S = BMOPA accumulators (3 B column tiles in fast path).
 * Mirrors the unpacked kernel nk_hammings_packed_u1_smebi32_streaming_ pattern.
 *
 * Computes full result rows (column tiles start at 0, not at the diagonal) for
 * the row slice [row_start, row_start + row_count): result[i][j] = Hamming(v_i, v_j)
 * derived as depth_bits - popcount(XNOR(v_i, v_j)), where the XNOR popcount comes
 * from BMOPA accumulation.
 *
 * @param vectors       Row-major bit-packed u1 vectors (n_vectors rows).
 * @param depth_bits    Logical vector length in bits.
 * @param stride        Byte stride between consecutive input rows.
 * @param result        Output u32 distance matrix.
 * @param result_stride Byte stride between consecutive output rows.
 * @param row_start     First row of the slice this call fills (enables sharding).
 * @param row_count     Number of rows in the slice.
 */
__arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_(
    nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits, nk_size_t stride, nk_u32_t *result,
    nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {

    nk_size_t const tile_dim = svcntw();        // 16 for 512-bit SVL
    nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
    nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
    nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);

    svbool_t const predicate_all_u32x = svptrue_b32();
    svuint32_t const depth_u32x = svdup_u32((nk_u32_t)depth_bits);

    // Stack buffer for A column save.
    // NOTE(review): sized for a 512-bit SVL (16x16 u32); assumes svcntw() <= 16 at
    // runtime — TODO confirm larger streaming vector lengths cannot reach this kernel.
    NK_ALIGN64 nk_u32_t a_buffer[16][16];

    nk_size_t const row_end = row_start + row_count;
    nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dim);

    for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
         row_tile_start += tile_dim) {
        // Rows left in the requested slice, then clamped to the matrix height.
        nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
        nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
                                                                                      : (n_vectors - row_tile_start);
        svbool_t const row_predicate_u32x = svwhilelt_b32_u64(0u, rows_clamped);

        nk_size_t column_tile_index = 0;

        // Fast path: 3 column tiles using ZA1-ZA3 (ZA0 = staging)
        // NOTE(review): the extraction below stores a full vector of tile_dim lanes per
        // column tile; if n_vectors is not a multiple of tile_dim and the partial last
        // tile lands inside this fast path, the store runs past the valid columns —
        // presumably the result rows are padded to tile_dim; verify against callers.
        for (; column_tile_index + 3 <= column_tile_count; column_tile_index += 3) {
            svzero_mask_za(nk_sme_zero_za32_tiles_123_);

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                // Load A rows into ZA0 horizontally (zero first so stale tile
                // contents from the previous depth step cannot leak into columns).
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
                                                                   (row_tile_start + row_in_tile) * stride) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
                }

                // Save A columns from ZA0 to stack buffer — ZA0 is about to be
                // reused as staging for the three B tiles, so the transposed A
                // columns must survive outside ZA.
                for (nk_size_t s = 0; s < u32s_this_tile; s++)
                    svst1_u32(predicate_all_u32x, a_buffer[s],
                              svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));

                // B column tile 0: stage rows horizontally in ZA0, read vertically
                // (a free transpose), and accumulate XNOR popcounts into ZA1.
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                for (nk_size_t col = 0; col < tile_dim; col++) {
                    nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
                    if (col_abs < n_vectors) {
                        nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
                                                d_start_u32;
                        svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
                    }
                }
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
                    svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
                    svbmopa_za32_u32_m(1, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
                }

                // B column tile 1 — same staging dance, accumulating into ZA2.
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                for (nk_size_t col = 0; col < tile_dim; col++) {
                    nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
                    if (col_abs < n_vectors) {
                        nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
                                                d_start_u32;
                        svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
                    }
                }
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
                    svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
                    svbmopa_za32_u32_m(2, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
                }

                // B column tile 2 — accumulating into ZA3.
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                for (nk_size_t col = 0; col < tile_dim; col++) {
                    nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
                    if (col_abs < n_vectors) {
                        nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
                                                d_start_u32;
                        svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
                    }
                }
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
                    svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_u32x, 0, step);
                    svbmopa_za32_u32_m(3, row_predicate_u32x, predicate_all_u32x, a_u32x, b_u32x);
                }
            }

            // Extract ZA1-3: hamming = depth_bits - ZA[i][j]
            // (ZA holds matching-bit counts from BMOPA's XNOR popcount.)
            for (nk_size_t row = 0; row < rows_clamped; row++) {
                nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);

                svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
                svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 2, row);
                svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 3, row);

                svst1_u32(predicate_all_u32x, c_row + (column_tile_index + 0) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x));
                svst1_u32(predicate_all_u32x, c_row + (column_tile_index + 1) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za2_u32x));
                svst1_u32(predicate_all_u32x, c_row + (column_tile_index + 2) * tile_dim,
                          svsub_u32_x(predicate_all_u32x, depth_u32x, za3_u32x));
            }
        }

        // Remainder: 1 column tile at a time using ZA1 (stores are predicated on
        // cols_remaining, so the partial last tile is safe here).
        for (; column_tile_index < column_tile_count; column_tile_index++) {
            nk_size_t const col_tile_start = column_tile_index * tile_dim;
            nk_size_t const cols_remaining = (col_tile_start + tile_dim <= n_vectors) ? tile_dim
                                                                                      : (n_vectors - col_tile_start);
            svbool_t const column_predicate_u32x = svwhilelt_b32_u64(0u, cols_remaining);

            svzero_mask_za(nk_sme_zero_za32_tile_1_);

            for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
                nk_size_t const d_start_u32 = d_tile * depth_tile_size;
                nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
                                                     ? depth_tile_size
                                                     : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
                                                                                      : 0);
                if (u32s_this_tile == 0) break;

                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                svbool_t const batch_predicate_u32x = svwhilelt_b32_u64(0u, u32s_this_tile);

                // Load A rows into ZA0 horizontally
                for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
                    nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
                                                                   (row_tile_start + row_in_tile) * stride) +
                                                d_start_u32;
                    svld1_hor_za32(0, row_in_tile, batch_predicate_u32x, a_row_u32);
                }

                // Save A columns from ZA0 to stack buffer
                for (nk_size_t s = 0; s < u32s_this_tile; s++)
                    svst1_u32(predicate_all_u32x, a_buffer[s],
                              svread_ver_za32_u32_m(svdup_u32(0), row_predicate_u32x, 0, s));

                // Load B column tile into ZA0
                svzero_mask_za(nk_sme_zero_za32_tile_0_);
                for (nk_size_t col = 0; col < tile_dim; col++) {
                    nk_size_t const col_abs = col_tile_start + col;
                    if (col_abs < n_vectors) {
                        nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
                                                d_start_u32;
                        svld1_hor_za32(0, col, batch_predicate_u32x, b_row);
                    }
                }
                for (nk_size_t step = 0; step < u32s_this_tile; step++) {
                    svuint32_t a_u32x = svld1_u32(predicate_all_u32x, a_buffer[step]);
                    svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_u32x, 0, step);
                    svbmopa_za32_u32_m(1, row_predicate_u32x, column_predicate_u32x, a_u32x, b_u32x);
                }
            }

            // Hamming = depth_bits - matching; predicated store keeps the
            // partial last column tile in bounds.
            for (nk_size_t row = 0; row < rows_clamped; row++) {
                svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_u32x, 1, row);
                svuint32_t hamming_u32x = svsub_u32_x(predicate_all_u32x, depth_u32x, za1_u32x);
                nk_u32_t *c_row = (nk_u32_t *)((char *)result + (row_tile_start + row) * result_stride);
                svst1_u32(column_predicate_u32x, c_row + col_tile_start, hamming_u32x);
            }
        }
    }
}
526
+
527
/**
 * Public entry point: symmetric all-pairs Hamming distances over bit-packed u1
 * vectors, filling the row slice [row_start, row_start + row_count) of `result`.
 *
 * Thin non-streaming shim around the `__arm_locally_streaming` kernel, which owns
 * SME streaming mode and ZA state. The row slice parameters allow callers to
 * shard the output across threads.
 */
NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits,
                                                nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
                                                nk_size_t row_start, nk_size_t row_count) {
    nk_hammings_symmetric_u1_smebi32_streaming_(vectors, n_vectors, depth_bits, stride, result, result_stride,
                                                row_start, row_count);
}
533
+
534
+ #pragma endregion // Hamming Distance
535
+
536
+ /*
537
+ * Jaccard distance via BMOPA matching counts + algebraic normalization.
538
+ *
539
+ * BMOPA gives: matching = popcount(XNOR(a,b))
540
+ * Then:
541
+ * hamming = depth_bits - matching
542
+ * intersection = (norm_a + norm_b - hamming) / 2 = (norm_a + norm_b - depth_bits + matching) / 2
543
+ * union = (norm_a + norm_b + hamming) / 2 = sum_norms - intersection
544
+ * jaccard = 1 - intersection / union (1.0 when union == 0)
545
+ *
546
+ * Inner BMOPA loop is identical to Hamming; only the extraction phase differs.
547
+ * Packed format shares the Hamming tile layout for B operand, plus per-row norms.
548
+ */
549
+
550
+ #pragma region Jaccard Distance
551
+
552
+ /**
553
+ * SME Jaccard kernel using BMOPA for matching-bit counts.
554
+ * Mirrors nk_hammings_packed_u1_smebi32_streaming_ exactly in structure,
555
+ * but derives intersection/union algebraically from the matching counts:
556
+ * matching = popcount(XNOR(a,b)) (from BMOPA)
557
+ * hamming = depth_bits - matching
558
+ * intersection = (norm_a + norm_b - hamming) / 2
559
+ * union = (norm_a + norm_b + hamming) / 2
560
+ * jaccard = 1 - intersection / union (1.0 when union == 0)
561
+ */
562
+ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_(
563
+ nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
564
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
565
+
566
+ nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
567
+ nk_size_t const row_tile_count_b = header->row_tile_count;
568
+ nk_size_t const depth_tile_count = header->depth_tile_count;
569
+
570
+ nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
571
+ nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
572
+ nk_size_t const tile_elements = tile_dim * depth_tile_size;
573
+ nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
574
+
575
+ nk_u32_t const *b_tiles = (nk_u32_t const *)((char const *)b_packed + sizeof(nk_sets_smebi32_packed_header_t));
576
+ nk_u32_t const *b_norms = header->norms_offset ? (nk_u32_t const *)((char const *)b_packed + header->norms_offset)
577
+ : (nk_u32_t const *)0;
578
+
579
+ svbool_t const predicate_all_f32x = svptrue_b32();
580
+ svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)depth_bits);
581
+ svfloat32_t const half_f32x = svdup_f32(0.5f);
582
+ svfloat32_t const one_f32x = svdup_f32(1.0f);
583
+ svfloat32_t const zero_f32x = svdup_f32(0.0f);
584
+ nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, 8);
585
+ nk_size_t const row_tile_count_a = nk_size_divide_round_up_(row_count_a, tile_dim);
586
+
587
+ for (nk_size_t row_tile_a = 0; row_tile_a < row_tile_count_a; row_tile_a++) {
588
+ nk_size_t const row_start_a = row_tile_a * tile_dim;
589
+ nk_size_t const rows_a_remaining = (row_start_a + tile_dim <= row_count_a) ? tile_dim
590
+ : (row_count_a - row_start_a);
591
+ svbool_t const row_predicate_f32x = svwhilelt_b32_u64(0u, rows_a_remaining);
592
+
593
+ // Compute A tile norms using streaming SVE popcount
594
+ NK_ALIGN64 nk_f32_t a_tile_norms[16];
595
+ for (nk_size_t r = 0; r < rows_a_remaining; r++) {
596
+ nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)a + (row_start_a + r) * a_stride_in_bytes);
597
+ a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
598
+ }
599
+
600
+ // Fast path: 3 B column tiles using ZA1-ZA3 (ZA0.S = staging)
601
+ nk_size_t row_tile_b = 0;
602
+ for (; row_tile_b + 3 <= row_tile_count_b; row_tile_b += 3) {
603
+ svzero_mask_za(nk_sme_zero_za32_tiles_123_);
604
+
605
+ for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
606
+ nk_size_t const d_start_u32 = d_tile * depth_tile_size;
607
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
608
+ ? depth_tile_size
609
+ : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
610
+ : 0);
611
+ if (u32s_this_tile == 0) break;
612
+
613
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
614
+
615
+ svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
616
+
617
+ // Load A rows into ZA0.S horizontally as u32 words
618
+ for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
619
+ nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
620
+ (row_start_a + row_in_tile) * a_stride_in_bytes) +
621
+ d_start_u32;
622
+ svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
623
+ }
624
+
625
+ // B tile pointers for 3 column tiles
626
+ nk_u32_t const *b_tile0 = b_tiles + ((row_tile_b + 0) * depth_tile_count + d_tile) * tile_elements;
627
+ nk_u32_t const *b_tile1 = b_tiles + ((row_tile_b + 1) * depth_tile_count + d_tile) * tile_elements;
628
+ nk_u32_t const *b_tile2 = b_tiles + ((row_tile_b + 2) * depth_tile_count + d_tile) * tile_elements;
629
+
630
+ // Vertical read + BMOPA for each depth step
631
+ for (nk_size_t step = 0; step < u32s_this_tile; step++) {
632
+ svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, step);
633
+
634
+ svbmopa_za32_u32_m(1, row_predicate_f32x, predicate_all_f32x, a_column_u32x,
635
+ svld1_u32(predicate_all_f32x, b_tile0 + step * tile_dim));
636
+ svbmopa_za32_u32_m(2, row_predicate_f32x, predicate_all_f32x, a_column_u32x,
637
+ svld1_u32(predicate_all_f32x, b_tile1 + step * tile_dim));
638
+ svbmopa_za32_u32_m(3, row_predicate_f32x, predicate_all_f32x, a_column_u32x,
639
+ svld1_u32(predicate_all_f32x, b_tile2 + step * tile_dim));
640
+ }
641
+ }
642
+
643
+ // Extract from ZA1-3: Jaccard normalization via streaming SVE
644
+ // Hoist B norms outside row loop (same for all A rows in this tile-pair)
645
+ svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(
646
+ predicate_all_f32x, svld1_u32(predicate_all_f32x, b_norms + (row_tile_b + 0) * tile_dim));
647
+ svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(
648
+ predicate_all_f32x, svld1_u32(predicate_all_f32x, b_norms + (row_tile_b + 1) * tile_dim));
649
+ svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(
650
+ predicate_all_f32x, svld1_u32(predicate_all_f32x, b_norms + (row_tile_b + 2) * tile_dim));
651
+
652
+ for (nk_size_t row = 0; row < rows_a_remaining; row++) {
653
+ nk_f32_t *c_row = (nk_f32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
654
+ svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
655
+
656
+ // ZA1
657
+ {
658
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
659
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
660
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_0_f32x);
661
+ svfloat32_t intersection_f32x = svmul_f32_x(
662
+ predicate_all_f32x,
663
+ svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
664
+ matching_f32x),
665
+ half_f32x);
666
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
667
+ svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
668
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
669
+ svfloat32_t jaccard_f32x = svsel_f32(
670
+ nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
671
+ svst1_f32(predicate_all_f32x, c_row + (row_tile_b + 0) * tile_dim, jaccard_f32x);
672
+ }
673
+ // ZA2
674
+ {
675
+ svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 2, row);
676
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za2_u32x);
677
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_1_f32x);
678
+ svfloat32_t intersection_f32x = svmul_f32_x(
679
+ predicate_all_f32x,
680
+ svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
681
+ matching_f32x),
682
+ half_f32x);
683
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
684
+ svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
685
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
686
+ svfloat32_t jaccard_f32x = svsel_f32(
687
+ nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
688
+ svst1_f32(predicate_all_f32x, c_row + (row_tile_b + 1) * tile_dim, jaccard_f32x);
689
+ }
690
+ // ZA3
691
+ {
692
+ svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 3, row);
693
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za3_u32x);
694
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_2_f32x);
695
+ svfloat32_t intersection_f32x = svmul_f32_x(
696
+ predicate_all_f32x,
697
+ svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
698
+ matching_f32x),
699
+ half_f32x);
700
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
701
+ svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
702
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
703
+ svfloat32_t jaccard_f32x = svsel_f32(
704
+ nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
705
+ svst1_f32(predicate_all_f32x, c_row + (row_tile_b + 2) * tile_dim, jaccard_f32x);
706
+ }
707
+ }
708
+ }
709
+
710
+ // Remainder: 1 B column tile at a time using ZA1
711
+ for (; row_tile_b < row_tile_count_b; row_tile_b++) {
712
+ nk_size_t const row_start_b = row_tile_b * tile_dim;
713
+ nk_size_t const rows_b_remaining = (row_start_b + tile_dim <= row_count_b) ? tile_dim
714
+ : (row_count_b - row_start_b);
715
+ svbool_t const column_predicate_f32x = svwhilelt_b32_u64(0u, rows_b_remaining);
716
+
717
+ svzero_mask_za(nk_sme_zero_za32_tile_1_);
718
+
719
+ for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
720
+ nk_size_t const d_start_u32 = d_tile * depth_tile_size;
721
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
722
+ ? depth_tile_size
723
+ : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
724
+ : 0);
725
+ if (u32s_this_tile == 0) break;
726
+
727
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
728
+
729
+ svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
730
+
731
+ // Load A rows into ZA0.S horizontally
732
+ for (nk_size_t row_in_tile = 0; row_in_tile < rows_a_remaining; row_in_tile++) {
733
+ nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)a +
734
+ (row_start_a + row_in_tile) * a_stride_in_bytes) +
735
+ d_start_u32;
736
+ svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
737
+ }
738
+
739
+ nk_u32_t const *b_tile = b_tiles + (row_tile_b * depth_tile_count + d_tile) * tile_elements;
740
+
741
+ // Vertical read + BMOPA
742
+ for (nk_size_t step = 0; step < u32s_this_tile; step++) {
743
+ svuint32_t a_column_u32x = svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, step);
744
+ svuint32_t b_u32x = svld1_u32(predicate_all_f32x, b_tile + step * tile_dim);
745
+ svbmopa_za32_u32_m(1, row_predicate_f32x, column_predicate_f32x, a_column_u32x, b_u32x);
746
+ }
747
+ }
748
+
749
+ // Extract from ZA1: Jaccard normalization
750
+ svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_f32x,
751
+ svld1_u32(predicate_all_f32x, b_norms + row_start_b));
752
+ for (nk_size_t row = 0; row < rows_a_remaining; row++) {
753
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
754
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
755
+ svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
756
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_f32x);
757
+ svfloat32_t intersection_f32x = svmul_f32_x(
758
+ predicate_all_f32x,
759
+ svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
760
+ matching_f32x),
761
+ half_f32x);
762
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
763
+ svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
764
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
765
+ svfloat32_t jaccard_f32x = svsel_f32(nonzero_f32x,
766
+ svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
767
+ nk_f32_t *c_row = (nk_f32_t *)((char *)c + (row_start_a + row) * c_stride_in_bytes);
768
+ svst1_f32(column_predicate_f32x, c_row + row_start_b, jaccard_f32x);
769
+ }
770
+ }
771
+ }
772
+ }
773
+
774
/**
 * Public entry point: one-vs-many Jaccard distances between bit-packed u1 rows of
 * `a` and a pre-packed B operand (`b_packed`), writing f32 distances into `c`.
 *
 * Thin non-streaming shim: all work happens in the `__arm_locally_streaming`
 * helper, which owns SME streaming mode and the ZA state for the call's duration.
 * `b_packed` must come from the matching pack routine (header + tiles + norms).
 */
NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c,
                                             nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
                                             nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
    nk_jaccards_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
                                             c_stride_in_bytes);
}
780
+
781
+ /**
782
+ * Symmetric Jaccard kernel using ZA0 time-sharing + 3-tile fast path.
783
+ * Fills upper triangle only (column_tile >= row_tile); caller sees result[i][j] for j >= i.
784
+ * Norms computed on-the-fly using streaming SVE popcount.
785
+ */
786
+ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_(
787
+ nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits, nk_size_t stride, nk_f32_t *result,
788
+ nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
789
+
790
+ nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
791
+ nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
792
+ nk_size_t const depth_u32_total = nk_size_divide_round_up_(depth_bits, 32);
793
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth_u32_total, depth_tile_size);
794
+ nk_size_t const depth_in_bytes = nk_size_divide_round_up_(depth_bits, NK_BITS_PER_BYTE);
795
+
796
+ svbool_t const predicate_all_f32x = svptrue_b32();
797
+ svfloat32_t const depth_f32x = svdup_f32((nk_f32_t)depth_bits);
798
+ svfloat32_t const half_f32x = svdup_f32(0.5f);
799
+ svfloat32_t const one_f32x = svdup_f32(1.0f);
800
+ svfloat32_t const zero_f32x = svdup_f32(0.0f);
801
+
802
+ NK_ALIGN64 nk_u32_t a_buffer[16][16]; // Stack buffer for A column save
803
+
804
+ nk_size_t const row_end = row_start + row_count;
805
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dim);
806
+
807
+ for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
808
+ row_tile_start += tile_dim) {
809
+ nk_size_t const rows_remaining = (row_tile_start + tile_dim <= row_end) ? tile_dim : (row_end - row_tile_start);
810
+ nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
811
+ : (n_vectors - row_tile_start);
812
+ svbool_t const row_predicate_f32x = svwhilelt_b32_u64(0u, rows_clamped);
813
+
814
+ // Compute A tile norms
815
+ NK_ALIGN64 nk_f32_t a_tile_norms[16];
816
+ for (nk_size_t r = 0; r < rows_clamped; r++) {
817
+ nk_u1x8_t const *a_row = (nk_u1x8_t const *)((char const *)vectors + (row_tile_start + r) * stride);
818
+ a_tile_norms[r] = (nk_f32_t)nk_sets_reduce_sumsq_u1_streaming_(a_row, depth_in_bytes);
819
+ }
820
+ for (nk_size_t r = rows_clamped; r < tile_dim; r++) a_tile_norms[r] = 0.0f;
821
+
822
+ // Upper triangle: start from this row tile's column
823
+ nk_size_t column_tile_index = row_tile_start / tile_dim;
824
+
825
+ // Fast path: 3 column tiles using ZA1-ZA3 (ZA0 = staging)
826
+ for (; column_tile_index + 3 <= column_tile_count; column_tile_index += 3) {
827
+ svzero_mask_za(nk_sme_zero_za32_tiles_123_);
828
+
829
+ for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
830
+ nk_size_t const d_start_u32 = d_tile * depth_tile_size;
831
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
832
+ ? depth_tile_size
833
+ : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
834
+ : 0);
835
+ if (u32s_this_tile == 0) break;
836
+
837
+ // Load A rows into ZA0 horizontally
838
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
839
+ svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
840
+
841
+ for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
842
+ nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
843
+ (row_tile_start + row_in_tile) * stride) +
844
+ d_start_u32;
845
+ svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
846
+ }
847
+
848
+ // Save A columns from ZA0 to stack buffer
849
+ for (nk_size_t s = 0; s < u32s_this_tile; s++)
850
+ svst1_u32(predicate_all_f32x, a_buffer[s],
851
+ svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, s));
852
+
853
+ // B column tile 0
854
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
855
+ for (nk_size_t col = 0; col < tile_dim; col++) {
856
+ nk_size_t const col_abs = (column_tile_index + 0) * tile_dim + col;
857
+ if (col_abs < n_vectors) {
858
+ nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
859
+ d_start_u32;
860
+ svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
861
+ }
862
+ }
863
+ for (nk_size_t step = 0; step < u32s_this_tile; step++) {
864
+ svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
865
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_f32x, 0, step);
866
+ svbmopa_za32_u32_m(1, row_predicate_f32x, predicate_all_f32x, a_u32x, b_u32x);
867
+ }
868
+
869
+ // B column tile 1
870
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
871
+ for (nk_size_t col = 0; col < tile_dim; col++) {
872
+ nk_size_t const col_abs = (column_tile_index + 1) * tile_dim + col;
873
+ if (col_abs < n_vectors) {
874
+ nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
875
+ d_start_u32;
876
+ svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
877
+ }
878
+ }
879
+ for (nk_size_t step = 0; step < u32s_this_tile; step++) {
880
+ svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
881
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_f32x, 0, step);
882
+ svbmopa_za32_u32_m(2, row_predicate_f32x, predicate_all_f32x, a_u32x, b_u32x);
883
+ }
884
+
885
+ // B column tile 2
886
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
887
+ for (nk_size_t col = 0; col < tile_dim; col++) {
888
+ nk_size_t const col_abs = (column_tile_index + 2) * tile_dim + col;
889
+ if (col_abs < n_vectors) {
890
+ nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
891
+ d_start_u32;
892
+ svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
893
+ }
894
+ }
895
+ for (nk_size_t step = 0; step < u32s_this_tile; step++) {
896
+ svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
897
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), predicate_all_f32x, 0, step);
898
+ svbmopa_za32_u32_m(3, row_predicate_f32x, predicate_all_f32x, a_u32x, b_u32x);
899
+ }
900
+ }
901
+
902
+ // Compute B tile norms for 3 column tiles
903
+ NK_ALIGN64 nk_u32_t b_tile_norms_0[16];
904
+ NK_ALIGN64 nk_u32_t b_tile_norms_1[16];
905
+ NK_ALIGN64 nk_u32_t b_tile_norms_2[16];
906
+ for (nk_size_t col = 0; col < tile_dim; col++) {
907
+ nk_size_t const col_abs_0 = (column_tile_index + 0) * tile_dim + col;
908
+ nk_size_t const col_abs_1 = (column_tile_index + 1) * tile_dim + col;
909
+ nk_size_t const col_abs_2 = (column_tile_index + 2) * tile_dim + col;
910
+ b_tile_norms_0[col] = (col_abs_0 < n_vectors)
911
+ ? nk_sets_reduce_sumsq_u1_streaming_(
912
+ (nk_u1x8_t const *)((char const *)vectors + col_abs_0 * stride),
913
+ depth_in_bytes)
914
+ : 0;
915
+ b_tile_norms_1[col] = (col_abs_1 < n_vectors)
916
+ ? nk_sets_reduce_sumsq_u1_streaming_(
917
+ (nk_u1x8_t const *)((char const *)vectors + col_abs_1 * stride),
918
+ depth_in_bytes)
919
+ : 0;
920
+ b_tile_norms_2[col] = (col_abs_2 < n_vectors)
921
+ ? nk_sets_reduce_sumsq_u1_streaming_(
922
+ (nk_u1x8_t const *)((char const *)vectors + col_abs_2 * stride),
923
+ depth_in_bytes)
924
+ : 0;
925
+ }
926
+
927
+ // Extract ZA1-3: Jaccard normalization
928
+ svfloat32_t b_norms_0_f32x = svcvt_f32_u32_x(predicate_all_f32x,
929
+ svld1_u32(predicate_all_f32x, b_tile_norms_0));
930
+ svfloat32_t b_norms_1_f32x = svcvt_f32_u32_x(predicate_all_f32x,
931
+ svld1_u32(predicate_all_f32x, b_tile_norms_1));
932
+ svfloat32_t b_norms_2_f32x = svcvt_f32_u32_x(predicate_all_f32x,
933
+ svld1_u32(predicate_all_f32x, b_tile_norms_2));
934
+
935
+ for (nk_size_t row = 0; row < rows_clamped; row++) {
936
+ nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride);
937
+ svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
938
+
939
+ // ZA1
940
+ {
941
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
942
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
943
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_0_f32x);
944
+ svfloat32_t intersection_f32x = svmul_f32_x(
945
+ predicate_all_f32x,
946
+ svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
947
+ matching_f32x),
948
+ half_f32x);
949
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
950
+ svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
951
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
952
+ svfloat32_t jaccard_f32x = svsel_f32(
953
+ nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
954
+ svst1_f32(predicate_all_f32x, c_row + (column_tile_index + 0) * tile_dim, jaccard_f32x);
955
+ }
956
+ // ZA2
957
+ {
958
+ svuint32_t za2_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 2, row);
959
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za2_u32x);
960
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_1_f32x);
961
+ svfloat32_t intersection_f32x = svmul_f32_x(
962
+ predicate_all_f32x,
963
+ svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
964
+ matching_f32x),
965
+ half_f32x);
966
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
967
+ svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
968
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
969
+ svfloat32_t jaccard_f32x = svsel_f32(
970
+ nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
971
+ svst1_f32(predicate_all_f32x, c_row + (column_tile_index + 1) * tile_dim, jaccard_f32x);
972
+ }
973
+ // ZA3
974
+ {
975
+ svuint32_t za3_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 3, row);
976
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za3_u32x);
977
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_2_f32x);
978
+ svfloat32_t intersection_f32x = svmul_f32_x(
979
+ predicate_all_f32x,
980
+ svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
981
+ matching_f32x),
982
+ half_f32x);
983
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
984
+ svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
985
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
986
+ svfloat32_t jaccard_f32x = svsel_f32(
987
+ nonzero_f32x, svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
988
+ svst1_f32(predicate_all_f32x, c_row + (column_tile_index + 2) * tile_dim, jaccard_f32x);
989
+ }
990
+ }
991
+ }
992
+
993
+ // Remainder: 1 column tile at a time using ZA1
994
+ for (; column_tile_index < column_tile_count; column_tile_index++) {
995
+ nk_size_t const col_tile_start = column_tile_index * tile_dim;
996
+ nk_size_t const cols_remaining = (col_tile_start + tile_dim <= n_vectors) ? tile_dim
997
+ : (n_vectors - col_tile_start);
998
+ svbool_t const column_predicate_f32x = svwhilelt_b32_u64(0u, cols_remaining);
999
+
1000
+ svzero_mask_za(nk_sme_zero_za32_tile_1_);
1001
+
1002
+ for (nk_size_t d_tile = 0; d_tile < depth_tile_count; d_tile++) {
1003
+ nk_size_t const d_start_u32 = d_tile * depth_tile_size;
1004
+ nk_size_t const u32s_this_tile = (d_start_u32 + depth_tile_size <= depth_u32_total)
1005
+ ? depth_tile_size
1006
+ : (depth_u32_total > d_start_u32 ? depth_u32_total - d_start_u32
1007
+ : 0);
1008
+ if (u32s_this_tile == 0) break;
1009
+
1010
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
1011
+ svbool_t const batch_predicate_f32x = svwhilelt_b32_u64(0u, u32s_this_tile);
1012
+
1013
+ // Load A rows into ZA0 horizontally
1014
+ for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
1015
+ nk_u32_t const *a_row_u32 = (nk_u32_t const *)((char const *)vectors +
1016
+ (row_tile_start + row_in_tile) * stride) +
1017
+ d_start_u32;
1018
+ svld1_hor_za32(0, row_in_tile, batch_predicate_f32x, a_row_u32);
1019
+ }
1020
+
1021
+ // Save A columns from ZA0 to stack buffer
1022
+ for (nk_size_t s = 0; s < u32s_this_tile; s++)
1023
+ svst1_u32(predicate_all_f32x, a_buffer[s],
1024
+ svread_ver_za32_u32_m(svdup_u32(0), row_predicate_f32x, 0, s));
1025
+
1026
+ // Load B column tile into ZA0
1027
+ svzero_mask_za(nk_sme_zero_za32_tile_0_);
1028
+ for (nk_size_t col = 0; col < tile_dim; col++) {
1029
+ nk_size_t const col_abs = col_tile_start + col;
1030
+ if (col_abs < n_vectors) {
1031
+ nk_u32_t const *b_row = (nk_u32_t const *)((char const *)vectors + col_abs * stride) +
1032
+ d_start_u32;
1033
+ svld1_hor_za32(0, col, batch_predicate_f32x, b_row);
1034
+ }
1035
+ }
1036
+ for (nk_size_t step = 0; step < u32s_this_tile; step++) {
1037
+ svuint32_t a_u32x = svld1_u32(predicate_all_f32x, a_buffer[step]);
1038
+ svuint32_t b_u32x = svread_ver_za32_u32_m(svdup_u32(0), column_predicate_f32x, 0, step);
1039
+ svbmopa_za32_u32_m(1, row_predicate_f32x, column_predicate_f32x, a_u32x, b_u32x);
1040
+ }
1041
+ }
1042
+
1043
+ // Compute B tile norms for remainder tile
1044
+ NK_ALIGN64 nk_u32_t b_tile_norms[16];
1045
+ for (nk_size_t col = 0; col < tile_dim; col++) {
1046
+ nk_size_t const col_abs = col_tile_start + col;
1047
+ b_tile_norms[col] = (col_abs < n_vectors)
1048
+ ? nk_sets_reduce_sumsq_u1_streaming_(
1049
+ (nk_u1x8_t const *)((char const *)vectors + col_abs * stride),
1050
+ depth_in_bytes)
1051
+ : 0;
1052
+ }
1053
+
1054
+ svfloat32_t b_norms_f32x = svcvt_f32_u32_x(predicate_all_f32x, svld1_u32(predicate_all_f32x, b_tile_norms));
1055
+ for (nk_size_t row = 0; row < rows_clamped; row++) {
1056
+ svuint32_t za1_u32x = svread_hor_za32_u32_m(svdup_u32(0), predicate_all_f32x, 1, row);
1057
+ svfloat32_t matching_f32x = svcvt_f32_u32_x(predicate_all_f32x, za1_u32x);
1058
+ svfloat32_t norm_a_f32x = svdup_f32(a_tile_norms[row]);
1059
+ svfloat32_t sum_norms_f32x = svadd_f32_x(predicate_all_f32x, norm_a_f32x, b_norms_f32x);
1060
+ svfloat32_t intersection_f32x = svmul_f32_x(
1061
+ predicate_all_f32x,
1062
+ svadd_f32_x(predicate_all_f32x, svsub_f32_x(predicate_all_f32x, sum_norms_f32x, depth_f32x),
1063
+ matching_f32x),
1064
+ half_f32x);
1065
+ svfloat32_t union_val_f32x = svsub_f32_x(predicate_all_f32x, sum_norms_f32x, intersection_f32x);
1066
+ svbool_t nonzero_f32x = svcmpne_f32(predicate_all_f32x, union_val_f32x, zero_f32x);
1067
+ svfloat32_t ratio_f32x = svdiv_f32_x(predicate_all_f32x, intersection_f32x, union_val_f32x);
1068
+ svfloat32_t jaccard_f32x = svsel_f32(nonzero_f32x,
1069
+ svsub_f32_x(predicate_all_f32x, one_f32x, ratio_f32x), one_f32x);
1070
+ nk_f32_t *c_row = (nk_f32_t *)((char *)result + (row_tile_start + row) * result_stride);
1071
+ svst1_f32(column_predicate_f32x, c_row + col_tile_start, jaccard_f32x);
1072
+ }
1073
+ }
1074
+ }
1075
+ }
1076
+
1077
+ NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth_bits,
1078
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1079
+ nk_size_t row_start, nk_size_t row_count) {
1080
+ nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, n_vectors, depth_bits, stride, result, result_stride,
1081
+ row_start, row_count);
1082
+ }
1083
+
1084
+ #pragma endregion // Jaccard Distance
1085
+
1086
+ #if defined(__clang__)
1087
+ #pragma clang attribute pop
1088
+ #elif defined(__GNUC__)
1089
+ #pragma GCC pop_options
1090
+ #endif
1091
+
1092
+ #if defined(__cplusplus)
1093
+ } // extern "C"
1094
+ #endif
1095
+
1096
+ #endif // NK_TARGET_SMEBI32
1097
+ #endif // NK_TARGET_ARM_
1098
+
1099
+ #endif // NK_SETS_SMEBI32_H