numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1318 @@
1
+ /**
2
+ * @brief SIMD-accelerated Batched Dot Products for SME F64.
3
+ * @file include/numkong/dots/smef64.h
4
+ * @author Ash Vardanian
5
+ * @date January 2, 2026
6
+ *
7
+ * @sa include/numkong/dots.h
8
+ *
9
+ * Uses ARM SME with `FEAT_SME_F64F64` for high-precision GEMM.
10
+ * Requires Apple M4 or equivalent with `f64` outer product support.
11
+ *
12
+ * Provides `f32` and `f64` GEMM using `ZA64` tiles:
13
+ * - `f32` inputs with `f64` accumulation: higher precision than `ZA32`
14
+ * - Native `f64` GEMM via 3-way Ozaki splitting (19+17+17 mantissa bits)
15
+ *
16
+ * Ozaki splitting for `f64`:
17
+ * Each `f64` value is decomposed into 3 non-overlapping mantissa-masked slices
18
+ * that each fit in `f32` (max 19 significant bits < 24). All cross-products with
19
+ * index sum i+j <= 2 are accumulated via 6 FMOPAs into 3 merged accumulators.
20
+ * Products are exact in `f64` (max 19+19 = 38 < 53 mantissa bits).
21
+ * B is pre-split at pack time into interleaved `f32` slices; A is split in-register.
22
+ * A 2-column-tile fast path halves A memory traffic.
23
+ *
24
+ * Tile dimensions for SVL=512 (Apple M4):
25
+ * - `ZA64` tile: 8 × 8 `f64` elements (512B)
26
+ * - `f64` vectors: 8 elements per SVE vector
27
+ * - `f32` vectors: 16 elements per SVE vector, converted to `f64`
28
+ *
29
+ * Key instructions:
30
+ * - `svmopa_za64_f64_m` / `FMOPA`: `f64` outer product, 16cy amortized
31
+ * - `svcvt_f64_f32_x` / `FCVT`: `f32` → `f64` conversion
32
+ * - `svwrite_hor_za64_f64_m` / `MOVA`: direct Z → ZA tile write (no bounce buffer)
33
+ */
34
+ #ifndef NK_DOTS_SMEF64_H
35
+ #define NK_DOTS_SMEF64_H
36
+
37
+ #if NK_TARGET_ARM_
38
+ #if NK_TARGET_SME
39
+
40
+ #include "numkong/types.h"
41
+ #include "numkong/dots/sme.h" // `nk_dots_sme_packed_header_t`
42
+
43
+ #if defined(__cplusplus)
44
+ extern "C" {
45
+ #endif
46
+
47
+ #if defined(__clang__)
48
+ #pragma clang attribute push(__attribute__((target("sme,sve,sme-f64f64"))), apply_to = function)
49
+ #elif defined(__GNUC__)
50
+ #pragma GCC push_options
51
+ #pragma GCC target("+sme+sme-f64f64")
52
+ #endif
53
+
54
+ /*
55
+ * f32 → f64 GEMM using FMOPA with ZA64 tiles (FEAT_SME_F64F64).
56
+ *
57
+ * Tile layout (SVL=512, Apple M4):
58
+ * - ZA64 output tile: 8 × 8 f64 elements (512 B)
59
+ * - f32 input vectors: 16 elements (SVL/32), converted to f64 in chunks of 8
60
+ * - Depth sub-loop: processes 8 f32 values per iteration (→ 8 f64)
61
+ * - FMOPA predicates: b64 (f64 output granularity)
62
+ * - f32 load predicates: b32 (f32 input granularity)
63
+ * - 4-tile path: ZA0-ZA3 process 4 column tiles simultaneously
64
+ * - Output: native f64 results written directly from ZA64 tiles
65
+ *
66
+ * Non-widening alternative (FEAT_SME_F32F32, `svmopa_za32_f32_m`): ZA32 tiles are 16×16
67
+ * (4× area vs ZA64 8×8) with no f32↔f64 conversion, offering ~3-4× raw throughput. However,
68
+ * ZA32 and ZA64 tiles alias physically (ZA0.S overlaps ZA0.D+ZA1.D), so a periodic flush to
69
+ * f64 stack accumulators would be needed for precision above f32 — erasing most speedup.
70
+ * Pure f32 accumulation (no flush) provides only f32 precision, which is already served by
71
+ * the f16 → f32 GEMM path for reduced-precision workloads. This f64 path exists specifically
72
+ * for higher-than-f32 accumulation precision; replacing it with f32 FMOPA would be
73
+ * counterproductive. Apple M4 has `hw.optional.arm.SME_F32F32: 1` but we don't use it here.
74
+ */
75
+ #pragma region Single Precision Floats
76
+
77
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_smef64(nk_size_t columns, nk_size_t depth) {
78
+ nk_size_t const tile_dimension = svcntsd(); // rows per `ZA64` tile (8 for SVL=512)
79
+ nk_size_t const depth_tile_size = svcntsw(); // `f32` depth elements per tile (16 for SVL=512)
80
+
81
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
82
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
83
+
84
+ nk_size_t size = sizeof(nk_dots_sme_packed_header_t);
85
+ size += column_tile_count * depth_tile_count * tile_dimension * depth_tile_size * sizeof(nk_f32_t);
86
+ size += columns * sizeof(nk_f64_t); // per-column squared norms
87
+
88
+ return size;
89
+ }
90
+
91
+ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride,
92
+ void *b_packed) {
93
+
94
+ nk_size_t const tile_dimension = svcntsd(); // rows per `ZA64` tile (8 for SVL=512)
95
+ nk_size_t const depth_tile_size = svcntsw(); // `f32` depth elements per tile (16 for SVL=512)
96
+ nk_size_t const tile_elements = tile_dimension * depth_tile_size; // 128
97
+ nk_size_t const b_stride_elements = b_stride / sizeof(nk_f32_t);
98
+
99
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
100
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
101
+ nk_size_t const total_tiles = column_tile_count * depth_tile_count;
102
+
103
+ // Store actual dimensions and tile counts in header
104
+ nk_dots_sme_packed_header_t *header = (nk_dots_sme_packed_header_t *)b_packed;
105
+ header->column_tile_count = (nk_u32_t)column_tile_count;
106
+ header->depth_tile_count = (nk_u32_t)depth_tile_count;
107
+ header->columns = (nk_u32_t)columns;
108
+ header->depth = (nk_u32_t)depth;
109
+ header->svl_bytes = (nk_u32_t)svcntsb(); // streaming vector length in bytes
110
+
111
+ nk_f32_t *tiles = (nk_f32_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
112
+
113
+ // Zero-initialize all tiles (handles partial tile padding)
114
+ nk_size_t const total_elements = total_tiles * tile_elements;
115
+ for (nk_size_t i = 0; i < total_elements; i++) tiles[i] = 0.0f;
116
+
117
+ // Pack data into tiles with depth-major layout within each tile:
118
+ // dst_idx = depth_idx * tile_dimension + column_idx
119
+ // This allows loading one B vector per depth step: svld1(b_tile + k * tile_dimension)
120
+ for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tile_count; column_tile_idx++) {
121
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tile_count; depth_tile_idx++) {
122
+ nk_size_t const tile_index = column_tile_idx * depth_tile_count + depth_tile_idx;
123
+ nk_f32_t *tile_output = tiles + tile_index * tile_elements;
124
+
125
+ nk_size_t const src_row_start = column_tile_idx * tile_dimension;
126
+ nk_size_t const src_column_start = depth_tile_idx * depth_tile_size;
127
+
128
+ // Handle partial tiles at edges
129
+ nk_size_t const rows_to_pack = (src_row_start + tile_dimension <= columns) ? tile_dimension
130
+ : (columns - src_row_start);
131
+ nk_size_t const columns_to_pack = (src_column_start + depth_tile_size <= depth)
132
+ ? depth_tile_size
133
+ : (depth - src_column_start);
134
+
135
+ for (nk_size_t column_idx = 0; column_idx < rows_to_pack; column_idx++) {
136
+ for (nk_size_t depth_idx = 0; depth_idx < columns_to_pack; depth_idx++) {
137
+ nk_size_t const src_idx = (src_row_start + column_idx) * b_stride_elements + src_column_start +
138
+ depth_idx;
139
+ nk_size_t const dst_idx = depth_idx * tile_dimension + column_idx;
140
+ tile_output[dst_idx] = b[src_idx];
141
+ }
142
+ }
143
+ }
144
+ }
145
+
146
+ // Compute per-column squared norms and store after packed data
147
+ nk_size_t const data_size = total_tiles * tile_elements * sizeof(nk_f32_t);
148
+ header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
149
+ nk_f64_t *norms_ptr = (nk_f64_t *)((char *)b_packed + header->norms_offset);
150
+ for (nk_size_t col = 0; col < columns; col++) {
151
+ nk_f32_t const *col_data = (nk_f32_t const *)((char const *)b + col * b_stride);
152
+ norms_ptr[col] = nk_dots_reduce_sumsq_f32_(col_data, depth);
153
+ }
154
+ }
155
+
156
+ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_streaming_(
157
+ nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
158
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
159
+
160
+ nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
161
+ nk_size_t const column_tile_count = header->column_tile_count;
162
+ nk_size_t const depth_tile_count = header->depth_tile_count;
163
+
164
+ nk_size_t const tile_dimension = svcntd(); // 8 for 512-bit SVL
165
+ nk_size_t const depth_tile_size = svcntw(); // 16 for 512-bit SVL
166
+ nk_size_t const tile_elements = tile_dimension * depth_tile_size;
167
+ nk_size_t const depth_steps_per_batch = tile_dimension; // 8 depth steps per ZA0.D load
168
+
169
+ nk_f32_t const *b_tiles = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_dots_sme_packed_header_t));
170
+
171
+ svbool_t const predicate_all_f64x = svptrue_b64();
172
+
173
+ // ZA0.D = staging, ZA1-7.D = accumulation (7-tile fast path)
174
+ for (nk_size_t row_tile_index = 0; row_tile_index < nk_size_divide_round_up_(rows, tile_dimension);
175
+ row_tile_index++) {
176
+ nk_size_t const row_start = row_tile_index * tile_dimension;
177
+ nk_size_t const rows_remaining = (row_start + tile_dimension <= rows) ? tile_dimension : (rows - row_start);
178
+ svbool_t const row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);
179
+
180
+ nk_size_t column_tile_index = 0;
181
+
182
+ // Fast path: 7 column tiles using ZA1-ZA7 (ZA0.D = staging)
183
+ for (; column_tile_index + 7 <= column_tile_count; column_tile_index += 7) {
184
+ svzero_mask_za(nk_sme_zero_za64_tiles_1_7_);
185
+
186
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tile_count; depth_tile_idx++) {
187
+ nk_size_t const depth_offset = depth_tile_idx * depth_tile_size;
188
+
189
+ // Process depth_tile_size elements in batches of tile_dimension (8)
190
+ for (nk_size_t depth_batch_start = 0; depth_batch_start < depth_tile_size;
191
+ depth_batch_start += depth_steps_per_batch) {
192
+ nk_size_t const depth_batch_end = (depth_batch_start + depth_steps_per_batch < depth_tile_size)
193
+ ? depth_batch_start + depth_steps_per_batch
194
+ : depth_tile_size;
195
+ nk_size_t const batch_size = depth_batch_end - depth_batch_start;
196
+
197
+ // Check if any elements in this batch are valid
198
+ if (depth_offset + depth_batch_start >= depth) break;
199
+
200
+ svzero_mask_za(nk_sme_zero_za64_tile_0_);
201
+
202
+ // Load A rows into ZA0.D: extending load f32→u64 + convert to f64
203
+ svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, (uint64_t)batch_size);
204
+ svbool_t const a_depth_predicate_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
205
+ (uint64_t)depth);
206
+ for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
207
+ nk_size_t const a_row = row_start + row_in_tile;
208
+ // Extending load: svld1uw_u64 loads f32 bits into lower 32 of each u64 lane
209
+ svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
210
+ batch_predicate_f64x,
211
+ svreinterpret_f32_u64(svld1uw_u64(
212
+ a_depth_predicate_f64x,
213
+ (nk_u32_t const *)&a[a_row * a_stride_elements + depth_offset + depth_batch_start])));
214
+ svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_f64x, a_row_widened_f64x);
215
+ }
216
+
217
+ // Vertical read + MOPA for each depth step in batch
218
+ for (nk_size_t step = 0; step < batch_size; step++) {
219
+ nk_size_t const k_abs = depth_offset + depth_batch_start + step;
220
+ if (k_abs >= depth) break;
221
+
222
+ svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, step);
223
+
224
+ nk_size_t const b_k = depth_batch_start + step;
225
+
226
+ // Extending load f32→u64 + convert to f64: svld1uw_u64 replaces svld1_f32 + svunpklo_u64
227
+ svfloat64_t b_column_tile_1_f64x = svcvt_f64_f32_x(
228
+ predicate_all_f64x,
229
+ svreinterpret_f32_u64(svld1uw_u64(
230
+ predicate_all_f64x,
231
+ (nk_u32_t const *)(b_tiles +
232
+ ((column_tile_index + 0) * depth_tile_count + depth_tile_idx) *
233
+ tile_elements +
234
+ b_k * tile_dimension))));
235
+ svfloat64_t b_column_tile_2_f64x = svcvt_f64_f32_x(
236
+ predicate_all_f64x,
237
+ svreinterpret_f32_u64(svld1uw_u64(
238
+ predicate_all_f64x,
239
+ (nk_u32_t const *)(b_tiles +
240
+ ((column_tile_index + 1) * depth_tile_count + depth_tile_idx) *
241
+ tile_elements +
242
+ b_k * tile_dimension))));
243
+ svfloat64_t b_column_tile_3_f64x = svcvt_f64_f32_x(
244
+ predicate_all_f64x,
245
+ svreinterpret_f32_u64(svld1uw_u64(
246
+ predicate_all_f64x,
247
+ (nk_u32_t const *)(b_tiles +
248
+ ((column_tile_index + 2) * depth_tile_count + depth_tile_idx) *
249
+ tile_elements +
250
+ b_k * tile_dimension))));
251
+ svfloat64_t b_column_tile_4_f64x = svcvt_f64_f32_x(
252
+ predicate_all_f64x,
253
+ svreinterpret_f32_u64(svld1uw_u64(
254
+ predicate_all_f64x,
255
+ (nk_u32_t const *)(b_tiles +
256
+ ((column_tile_index + 3) * depth_tile_count + depth_tile_idx) *
257
+ tile_elements +
258
+ b_k * tile_dimension))));
259
+ svfloat64_t b_column_tile_5_f64x = svcvt_f64_f32_x(
260
+ predicate_all_f64x,
261
+ svreinterpret_f32_u64(svld1uw_u64(
262
+ predicate_all_f64x,
263
+ (nk_u32_t const *)(b_tiles +
264
+ ((column_tile_index + 4) * depth_tile_count + depth_tile_idx) *
265
+ tile_elements +
266
+ b_k * tile_dimension))));
267
+ svfloat64_t b_column_tile_6_f64x = svcvt_f64_f32_x(
268
+ predicate_all_f64x,
269
+ svreinterpret_f32_u64(svld1uw_u64(
270
+ predicate_all_f64x,
271
+ (nk_u32_t const *)(b_tiles +
272
+ ((column_tile_index + 5) * depth_tile_count + depth_tile_idx) *
273
+ tile_elements +
274
+ b_k * tile_dimension))));
275
+ svfloat64_t b_column_tile_7_f64x = svcvt_f64_f32_x(
276
+ predicate_all_f64x,
277
+ svreinterpret_f32_u64(svld1uw_u64(
278
+ predicate_all_f64x,
279
+ (nk_u32_t const *)(b_tiles +
280
+ ((column_tile_index + 6) * depth_tile_count + depth_tile_idx) *
281
+ tile_elements +
282
+ b_k * tile_dimension))));
283
+
284
+ svmopa_za64_f64_m(1, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_1_f64x);
285
+ svmopa_za64_f64_m(2, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_2_f64x);
286
+ svmopa_za64_f64_m(3, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_3_f64x);
287
+ svmopa_za64_f64_m(4, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_4_f64x);
288
+ svmopa_za64_f64_m(5, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_5_f64x);
289
+ svmopa_za64_f64_m(6, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_6_f64x);
290
+ svmopa_za64_f64_m(7, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_7_f64x);
291
+ }
292
+ }
293
+ }
294
+
295
+ // Extract from ZA1-7 and store native f64 outputs.
296
+ svbool_t const predicate_tile_f64x = svwhilelt_b64_u64(0u, tile_dimension);
297
+ // The 7th tile (index 6) may be partial when it's the last column tile
298
+ nk_size_t const last_fast_col_start = (column_tile_index + 6) * tile_dimension;
299
+ nk_size_t const last_fast_cols = (last_fast_col_start + tile_dimension <= columns)
300
+ ? tile_dimension
301
+ : (columns - last_fast_col_start);
302
+ svbool_t const last_tile_pred_f64x = svwhilelt_b64_u64(0u, last_fast_cols);
303
+ for (nk_size_t row_idx = 0; row_idx < rows_remaining; row_idx++) {
304
+ nk_f64_t *c_row = c + (row_start + row_idx) * c_stride_elements;
305
+
306
+ svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 1, row_idx);
307
+ svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 0) * tile_dimension, za_row_f64x);
308
+
309
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 2, row_idx);
310
+ svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 1) * tile_dimension, za_row_f64x);
311
+
312
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 3, row_idx);
313
+ svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 2) * tile_dimension, za_row_f64x);
314
+
315
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 4, row_idx);
316
+ svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 3) * tile_dimension, za_row_f64x);
317
+
318
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 5, row_idx);
319
+ svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 4) * tile_dimension, za_row_f64x);
320
+
321
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 6, row_idx);
322
+ svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 5) * tile_dimension, za_row_f64x);
323
+
324
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 7, row_idx);
325
+ svst1_f64(last_tile_pred_f64x, c_row + (column_tile_index + 6) * tile_dimension, za_row_f64x);
326
+ }
327
+ }
328
+
329
+ // Remainder: 1 column tile at a time using ZA1
330
+ for (; column_tile_index < column_tile_count; column_tile_index++) {
331
+ nk_size_t const column_start = column_tile_index * tile_dimension;
332
+ nk_size_t const columns_remaining = (column_start + tile_dimension <= columns) ? tile_dimension
333
+ : (columns - column_start);
334
+ svbool_t const column_predicate_f64x = svwhilelt_b64_u64(0u, columns_remaining);
335
+
336
+ svzero_mask_za(nk_sme_zero_za64_tile_1_);
337
+
338
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tile_count; depth_tile_idx++) {
339
+ nk_size_t const depth_offset = depth_tile_idx * depth_tile_size;
340
+
341
+ for (nk_size_t depth_batch_start = 0; depth_batch_start < depth_tile_size;
342
+ depth_batch_start += depth_steps_per_batch) {
343
+ nk_size_t const depth_batch_end = (depth_batch_start + depth_steps_per_batch < depth_tile_size)
344
+ ? depth_batch_start + depth_steps_per_batch
345
+ : depth_tile_size;
346
+ nk_size_t const batch_size = depth_batch_end - depth_batch_start;
347
+
348
+ if (depth_offset + depth_batch_start >= depth) break;
349
+
350
+ svzero_mask_za(nk_sme_zero_za64_tile_0_);
351
+
352
+ svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, (uint64_t)batch_size);
353
+ svbool_t const a_depth_pred_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
354
+ (uint64_t)depth);
355
+ for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
356
+ nk_size_t const a_row = row_start + row_in_tile;
357
+ svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
358
+ batch_predicate_f64x,
359
+ svreinterpret_f32_u64(svld1uw_u64(
360
+ a_depth_pred_f64x,
361
+ (nk_u32_t const *)&a[a_row * a_stride_elements + depth_offset + depth_batch_start])));
362
+ svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_f64x, a_row_widened_f64x);
363
+ }
364
+
365
+ for (nk_size_t step = 0; step < batch_size; step++) {
366
+ nk_size_t const k_abs = depth_offset + depth_batch_start + step;
367
+ if (k_abs >= depth) break;
368
+
369
+ svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, step);
370
+
371
+ nk_size_t const b_k = depth_batch_start + step;
372
+ nk_f32_t const *b_tile = b_tiles + (column_tile_index * depth_tile_count + depth_tile_idx) *
373
+ tile_elements;
374
+ // Extending load f32→u64 + convert to f64
375
+ svfloat64_t b_f64x = svcvt_f64_f32_x(
376
+ predicate_all_f64x,
377
+ svreinterpret_f32_u64(
378
+ svld1uw_u64(predicate_all_f64x, (nk_u32_t const *)(b_tile + b_k * tile_dimension))));
379
+
380
+ svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_f64x, a_f64x, b_f64x);
381
+ }
382
+ }
383
+ }
384
+
385
+ // Store native f64 outputs for the tail column tile.
386
+ for (nk_size_t row_idx = 0; row_idx < rows_remaining; row_idx++) {
387
+ svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 1, row_idx);
388
+ nk_f64_t *c_row = c + (row_start + row_idx) * c_stride_elements + column_start;
389
+ svst1_f64(column_predicate_f64x, c_row, za_row_f64x);
390
+ }
391
+ }
392
+ }
393
+ }
394
+
395
+ NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
396
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
397
+
398
+ nk_size_t const a_stride_elements = a_stride / sizeof(nk_f32_t);
399
+ nk_size_t const c_stride_elements = c_stride / sizeof(nk_f64_t);
400
+
401
+ nk_dots_packed_f32_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
402
+ }
403
+
404
/**
 * `f32` × `f32` → `f64` symmetric kernel using MOPA self-GEMM with f64 accumulation.
 * (Inputs are f32; the ZA64 accumulators and the stored results are f64.)
 * Time-shares ZA0 for both A and B transposition: loads A horizontally,
 * pre-reads A columns into Z registers, then reloads ZA0 with widened B data
 * per column tile. Eliminates all scalar B-packing loops.
 *
 * @param vectors                Row-major `n_vectors × depth` f32 matrix; both operands
 *                               of the self-product come from here.
 * @param stride_elements        Row stride of `vectors`, in elements (not bytes).
 * @param result                 Row-major f64 output; for each row in range, the full row of
 *                               dot products against every column vector is written.
 * @param result_stride_elements Row stride of `result`, in elements.
 * @param row_start, row_count   Slice of output rows this call computes, so callers can
 *                               split the symmetric product across workers.
 */
__arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64_streaming_(
    nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
    nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {

    nk_size_t const tile_dimension = svcntd();              // 8 for SVL=512
    nk_size_t const depth_tile_size = svcntw();             // 16 for SVL=512
    nk_size_t const depth_steps_per_batch = tile_dimension; // 8

    svbool_t const predicate_all_f64x = svptrue_b64();

    // NOTE(review): fixed 8×8 spill buffer assumes tile_dimension = svcntd() ≤ 8,
    // i.e. SVL ≤ 512 bits — confirm for wider SVL targets.
    NK_ALIGN64 nk_f64_t a_buffer[8][8];

    nk_size_t const row_end = row_start + row_count;
    nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dimension);
    nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);

    for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
         row_tile_start += tile_dimension) {
        // Clamp first against the requested row range, then against the matrix height.
        nk_size_t const rows_clamped = (row_tile_start + tile_dimension <= row_end) ? tile_dimension
                                                                                    : (row_end - row_tile_start);
        nk_size_t const rows_actual = (row_tile_start + rows_clamped <= n_vectors) ? rows_clamped
                                                                                   : (n_vectors - row_tile_start);
        svbool_t const row_predicate_f64x = svwhilelt_b64_u64(0u, rows_actual);

        nk_size_t column_tile_index = 0;

        // Fast path: 7 column tiles at a time (one accumulator tile each, ZA1..ZA7)
        for (; column_tile_index + 7 <= column_tile_count; column_tile_index += 7) {
            svzero_mask_za(nk_sme_zero_za64_tiles_1_7_);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tile_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * depth_tile_size;

                // Process depth_tile_size in batches of depth_steps_per_batch (8)
                for (nk_size_t depth_batch_start = 0; depth_batch_start < depth_tile_size;
                     depth_batch_start += depth_steps_per_batch) {
                    nk_size_t const depth_batch_end = (depth_batch_start + depth_steps_per_batch < depth_tile_size)
                                                          ? depth_batch_start + depth_steps_per_batch
                                                          : depth_tile_size;
                    nk_size_t const batch_size = depth_batch_end - depth_batch_start;

                    if (depth_offset + depth_batch_start >= depth) break;

                    // ZA transpose for A rows: extending load f32→f64, MOVA directly into ZA0
                    svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, (uint64_t)batch_size);
                    svbool_t const a_depth_predicate_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
                                                                              (uint64_t)depth);
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t row_in_tile = 0; row_in_tile < rows_actual; row_in_tile++) {
                        nk_size_t const row_abs = row_tile_start + row_in_tile;
                        // svld1uw zero-extends each f32 word into a u64 lane; the reinterpret +
                        // svcvt then widens the f32 payload to f64.
                        svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
                            batch_predicate_f64x,
                            svreinterpret_f32_u64(svld1uw_u64(
                                a_depth_predicate_f64x, (nk_u32_t const *)&vectors[row_abs * stride_elements +
                                                                                   depth_offset + depth_batch_start])));
                        svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_f64x, a_row_widened_f64x);
                    }

                    // Save A columns from ZA0 to stack buffer (ZA0 is about to be reused for B)
                    for (nk_size_t s = 0; s < batch_size; s++)
                        svst1_f64(predicate_all_f64x, a_buffer[s],
                                  svread_ver_za64_f64_m(svdup_f64(0), row_predicate_f64x, 0, s));

                    // Column tile 0 → ZA1 via MOVA
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t column = 0; column < tile_dimension; column++) {
                        nk_size_t const column_abs = (column_tile_index + 0) * tile_dimension + column;
                        if (column_abs < n_vectors) {
                            svfloat64_t widened_f64x = svcvt_f64_f32_x(
                                batch_predicate_f64x,
                                svreinterpret_f32_u64(svld1uw_u64(
                                    a_depth_predicate_f64x,
                                    (nk_u32_t const *)&vectors[column_abs * stride_elements + depth_offset +
                                                               depth_batch_start])));
                            svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
                        }
                    }
                    for (nk_size_t step = 0; step < batch_size; step++) {
                        svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
                        svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
                        svmopa_za64_f64_m(1, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
                    }

                    // Column tile 1 → ZA2 via MOVA
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t column = 0; column < tile_dimension; column++) {
                        nk_size_t const column_abs = (column_tile_index + 1) * tile_dimension + column;
                        if (column_abs < n_vectors) {
                            svfloat64_t widened_f64x = svcvt_f64_f32_x(
                                batch_predicate_f64x,
                                svreinterpret_f32_u64(svld1uw_u64(
                                    a_depth_predicate_f64x,
                                    (nk_u32_t const *)&vectors[column_abs * stride_elements + depth_offset +
                                                               depth_batch_start])));
                            svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
                        }
                    }
                    for (nk_size_t step = 0; step < batch_size; step++) {
                        svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
                        svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
                        svmopa_za64_f64_m(2, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
                    }

                    // Column tile 2 → ZA3 via MOVA
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t column = 0; column < tile_dimension; column++) {
                        nk_size_t const column_abs = (column_tile_index + 2) * tile_dimension + column;
                        if (column_abs < n_vectors) {
                            svfloat64_t widened_f64x = svcvt_f64_f32_x(
                                batch_predicate_f64x,
                                svreinterpret_f32_u64(svld1uw_u64(
                                    a_depth_predicate_f64x,
                                    (nk_u32_t const *)&vectors[column_abs * stride_elements + depth_offset +
                                                               depth_batch_start])));
                            svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
                        }
                    }
                    for (nk_size_t step = 0; step < batch_size; step++) {
                        svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
                        svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
                        svmopa_za64_f64_m(3, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
                    }

                    // Column tile 3 → ZA4 via MOVA
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t column = 0; column < tile_dimension; column++) {
                        nk_size_t const column_abs = (column_tile_index + 3) * tile_dimension + column;
                        if (column_abs < n_vectors) {
                            svfloat64_t widened_f64x = svcvt_f64_f32_x(
                                batch_predicate_f64x,
                                svreinterpret_f32_u64(svld1uw_u64(
                                    a_depth_predicate_f64x,
                                    (nk_u32_t const *)&vectors[column_abs * stride_elements + depth_offset +
                                                               depth_batch_start])));
                            svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
                        }
                    }
                    for (nk_size_t step = 0; step < batch_size; step++) {
                        svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
                        svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
                        svmopa_za64_f64_m(4, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
                    }

                    // Column tile 4 → ZA5 via MOVA
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t column = 0; column < tile_dimension; column++) {
                        nk_size_t const column_abs = (column_tile_index + 4) * tile_dimension + column;
                        if (column_abs < n_vectors) {
                            svfloat64_t widened_f64x = svcvt_f64_f32_x(
                                batch_predicate_f64x,
                                svreinterpret_f32_u64(svld1uw_u64(
                                    a_depth_predicate_f64x,
                                    (nk_u32_t const *)&vectors[column_abs * stride_elements + depth_offset +
                                                               depth_batch_start])));
                            svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
                        }
                    }
                    for (nk_size_t step = 0; step < batch_size; step++) {
                        svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
                        svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
                        svmopa_za64_f64_m(5, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
                    }

                    // Column tile 5 → ZA6 via MOVA
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t column = 0; column < tile_dimension; column++) {
                        nk_size_t const column_abs = (column_tile_index + 5) * tile_dimension + column;
                        if (column_abs < n_vectors) {
                            svfloat64_t widened_f64x = svcvt_f64_f32_x(
                                batch_predicate_f64x,
                                svreinterpret_f32_u64(svld1uw_u64(
                                    a_depth_predicate_f64x,
                                    (nk_u32_t const *)&vectors[column_abs * stride_elements + depth_offset +
                                                               depth_batch_start])));
                            svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
                        }
                    }
                    for (nk_size_t step = 0; step < batch_size; step++) {
                        svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
                        svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
                        svmopa_za64_f64_m(6, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
                    }

                    // Column tile 6 → ZA7 via MOVA
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t column = 0; column < tile_dimension; column++) {
                        nk_size_t const column_abs = (column_tile_index + 6) * tile_dimension + column;
                        if (column_abs < n_vectors) {
                            svfloat64_t widened_f64x = svcvt_f64_f32_x(
                                batch_predicate_f64x,
                                svreinterpret_f32_u64(svld1uw_u64(
                                    a_depth_predicate_f64x,
                                    (nk_u32_t const *)&vectors[column_abs * stride_elements + depth_offset +
                                                               depth_batch_start])));
                            svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
                        }
                    }
                    for (nk_size_t step = 0; step < batch_size; step++) {
                        svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
                        svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
                        svmopa_za64_f64_m(7, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
                    }
                }
            }

            // Extract results and store native f64 outputs.
            // Tiles 0-5 of this group are always complete (only the group's last
            // tile can touch the matrix edge), so a full-tile store predicate is safe.
            svbool_t const predicate_tile_f64x = svwhilelt_b64_u64(0u, tile_dimension);
            // The 7th tile (index 6) may be partial when it's the last column tile
            nk_size_t const last_fast_col_start = (column_tile_index + 6) * tile_dimension;
            nk_size_t const last_fast_cols = (last_fast_col_start + tile_dimension <= n_vectors)
                                                 ? tile_dimension
                                                 : (n_vectors - last_fast_col_start);
            svbool_t const last_tile_pred_f64x = svwhilelt_b64_u64(0u, last_fast_cols);
            for (nk_size_t row = 0; row < rows_actual; row++) {
                nk_size_t const row_abs = row_tile_start + row;
                nk_f64_t *result_row = result + row_abs * result_stride_elements;

                svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 1, row);
                svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 0) * tile_dimension, za_row_f64x);

                za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 2, row);
                svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 1) * tile_dimension, za_row_f64x);

                za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 3, row);
                svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 2) * tile_dimension, za_row_f64x);

                za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 4, row);
                svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 3) * tile_dimension, za_row_f64x);

                za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 5, row);
                svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 4) * tile_dimension, za_row_f64x);

                za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 6, row);
                svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 5) * tile_dimension, za_row_f64x);

                za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 7, row);
                svst1_f64(last_tile_pred_f64x, result_row + (column_tile_index + 6) * tile_dimension, za_row_f64x);
            }
        }

        // Remainder: 1 column tile at a time (accumulated in ZA1 only)
        for (; column_tile_index < column_tile_count; column_tile_index++) {
            nk_size_t const column_tile_start = column_tile_index * tile_dimension;
            nk_size_t const columns_remaining = (column_tile_start + tile_dimension <= n_vectors)
                                                    ? tile_dimension
                                                    : (n_vectors - column_tile_start);
            svbool_t const column_predicate_f64x = svwhilelt_b64_u64(0u, columns_remaining);

            svzero_mask_za(nk_sme_zero_za64_tile_1_);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tile_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * depth_tile_size;

                for (nk_size_t depth_batch_start = 0; depth_batch_start < depth_tile_size;
                     depth_batch_start += depth_steps_per_batch) {
                    nk_size_t const depth_batch_end = (depth_batch_start + depth_steps_per_batch < depth_tile_size)
                                                          ? depth_batch_start + depth_steps_per_batch
                                                          : depth_tile_size;
                    nk_size_t const batch_size = depth_batch_end - depth_batch_start;

                    if (depth_offset + depth_batch_start >= depth) break;

                    svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, (uint64_t)batch_size);
                    svbool_t const a_depth_pred_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
                                                                         (uint64_t)depth);
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t row_in_tile = 0; row_in_tile < rows_actual; row_in_tile++) {
                        nk_size_t const row_abs = row_tile_start + row_in_tile;
                        svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
                            batch_predicate_f64x,
                            svreinterpret_f32_u64(svld1uw_u64(
                                a_depth_pred_f64x, (nk_u32_t const *)&vectors[row_abs * stride_elements + depth_offset +
                                                                              depth_batch_start])));
                        svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_f64x, a_row_widened_f64x);
                    }

                    // Save A columns from ZA0 to stack buffer
                    for (nk_size_t s = 0; s < batch_size; s++)
                        svst1_f64(predicate_all_f64x, a_buffer[s],
                                  svread_ver_za64_f64_m(svdup_f64(0), row_predicate_f64x, 0, s));

                    // Load B column tile into ZA0 via MOVA, vertical read + FMOPA into ZA1
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t column = 0; column < tile_dimension; column++) {
                        nk_size_t const column_abs = column_tile_start + column;
                        if (column_abs < n_vectors) {
                            svfloat64_t widened_f64x = svcvt_f64_f32_x(
                                batch_predicate_f64x,
                                svreinterpret_f32_u64(svld1uw_u64(
                                    a_depth_pred_f64x, (nk_u32_t const *)&vectors[column_abs * stride_elements +
                                                                                  depth_offset + depth_batch_start])));
                            svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
                        }
                    }
                    for (nk_size_t step = 0; step < batch_size; step++) {
                        nk_size_t const k_abs = depth_offset + depth_batch_start + step;
                        if (k_abs >= depth) break;
                        svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
                        svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), column_predicate_f64x, 0, step);
                        svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_f64x, a_f64x, b_f64x);
                    }
                }
            }

            // Store native f64 outputs for the tail column tile.
            for (nk_size_t row = 0; row < rows_actual; row++) {
                nk_size_t const row_abs = row_tile_start + row;
                svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 1, row);
                svst1_f64(column_predicate_f64x, result + row_abs * result_stride_elements + column_tile_start,
                          za_row_f64x);
            }
        }
    }
}
724
+
725
+ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
726
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
727
+ nk_size_t row_start, nk_size_t row_count) {
728
+
729
+ nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
730
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
731
+ nk_dots_symmetric_f32_smef64_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
732
+ row_start, row_count);
733
+ }
734
+
735
+ #pragma endregion // Single Precision Floats
736
+
737
+ /*
738
+ * f64 GEMM via 3-way Ozaki splitting using FMOPA with ZA64 tiles.
739
+ * Uses ZA transpose for A-vector construction (expansion=1, no interleaving needed).
740
+ *
741
+ * Each f64 is split into 3 non-overlapping mantissa-masked slices (19+17+17 bits).
742
+ * All slices fit in f32 (max 19 significant bits < 24-bit significand).
743
+ * Cross-products with index sum i+j ≤ 2 are accumulated via 6 FMOPAs into 3 merged
744
+ * accumulators. Products are exact in f64 (max 19+19 = 38 < 53 mantissa bits).
745
+ * Accumulation is exact for K ≤ 16,384 (ceil(log2(K)) + 38 + 1 ≤ 53).
746
+ *
747
+ * Packed GEMM tile allocation (2-column fast path):
748
+ * - ZA0.D: A-staging (horizontal load, vertical read, shared by both col tiles)
749
+ * - ZA1.D: col0 acc0 — a₀×b₀ (i+j=0, dominant)
750
+ * - ZA2.D: col0 acc1 — a₀×b₁ + a₁×b₀ (i+j=1)
751
+ * - ZA3.D: col0 acc2 — a₀×b₂ + a₁×b₁ + a₂×b₀ (i+j=2, smallest)
752
+ * - ZA4.D: col1 acc0 — a₀×b₀ (i+j=0)
753
+ * - ZA5.D: col1 acc1 — a₀×b₁ + a₁×b₀ (i+j=1)
754
+ * - ZA6.D: col1 acc2 — a₀×b₂ + a₁×b₁ + a₂×b₀ (i+j=2)
755
+ * - ZA7.D: unused
756
+ *
757
+ * 1-column remainder uses ZA1-3 only.
758
+ * B is pre-split at pack time into interleaved f32 slices.
759
+ * A is split in-register per depth step via SVE integer AND.
760
+ *
761
+ * Symmetric GEMM tile allocation:
762
+ * - ZA0.D: staging (A rows, then B columns via horizontal load)
763
+ * - ZA1-3.D: merged Ozaki accumulators (i+j=0,1,2)
764
+ * Both A and B are split in-register per depth step.
765
+ *
766
+ * Tile dimensions for SVL=512 (Apple M4):
767
+ * - ZA64 tile: 8 × 8 f64 elements (512B)
768
+ * - f64 input vectors: 8 elements (SVL/64)
769
+ * - FMOPA predicates: b64 (native f64 granularity)
770
+ */
771
+ #pragma region Double Precision Floats
772
+
773
+ /* Mantissa bit masks for 3-way Ozaki splitting of f64 values.
774
+ *
775
+ * f64 layout: [63]=sign, [62:52]=exponent (11 bits), [51:0]=mantissa (52 bits).
776
+ * Significand = implicit 1 + mantissa = 53 significant bits.
777
+ *
778
+ * Slice 0 (19 significant bits): keep sign + exponent + top 18 mantissa bits.
779
+ * Zeroes mantissa bits [33:0] (34 bits). Mask = 0xFFFFFFFC00000000.
780
+ * Slice 1 (17 significant bits): keep sign + exponent + top 16 mantissa bits of residual.
781
+ * Zeroes mantissa bits [35:0] (36 bits). Mask = 0xFFFFFFF000000000.
782
+ * Slice 2 = residual of residual (at most 17 significant bits, fits f32).
783
+ *
784
+ * All slices fit in f32 (24-bit significand). Products: max 19+19 = 38 ≤ 53, exact in f64.
785
+ */
786
+ NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void) NK_STREAMING_COMPATIBLE_ {
787
+ return 0xFFFFFFFC00000000ULL; // keep top 19 sig bits
788
+ }
789
+ NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void) NK_STREAMING_COMPATIBLE_ {
790
+ return 0xFFFFFFF000000000ULL; // keep top 17 sig bits
791
+ }
792
+
793
+ /* Split a scalar f64 into 3 non-overlapping Ozaki slices (19+17+17 mantissa bits).
794
+ * Each slice fits in f32. Outputs stored via pointers. */
795
+ NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, nk_f64_t *slice_1,
796
+ nk_f64_t *slice_2) NK_STREAMING_COMPATIBLE_ {
797
+ nk_fui64_t pun;
798
+ pun.f = val;
799
+ pun.u &= nk_f64_smef64_ozaki_mask_19_bits_();
800
+ *slice_0 = pun.f;
801
+ nk_f64_t residual = val - *slice_0;
802
+ pun.f = residual;
803
+ pun.u &= nk_f64_smef64_ozaki_mask_17_bits_();
804
+ *slice_1 = pun.f;
805
+ *slice_2 = residual - *slice_1;
806
+ }
807
+
808
/**
 * Symmetric `f64` self-product via 3-way Ozaki splitting on SME FMOPA.
 * ZA0.D stages A rows (then B columns); ZA1-3.D hold the three merged Ozaki
 * accumulators for index sums i+j = 0, 1, 2. Both operands are split in-register
 * per depth step; the three partial products per operand pair are exact in f64.
 *
 * @param vectors                Row-major `n_vectors × depth` f64 matrix.
 * @param stride_elements        Row stride of `vectors`, in elements.
 * @param result                 Row-major f64 output; full rows are written for the
 *                               requested row slice.
 * @param result_stride_elements Row stride of `result`, in elements.
 * @param row_start, row_count   Slice of output rows this call computes.
 */
__arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64_streaming_(
    nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
    nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {

    nk_size_t const tile_dimension = svcntd();
    nk_size_t const depth_steps_per_batch = tile_dimension;

    svbool_t const predicate_all_f64x = svptrue_b64();
    svuint64_t const ozaki_mask_19_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_19_bits_());
    svuint64_t const ozaki_mask_17_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_17_bits_());

    // NOTE(review): fixed 8×8 spill buffer assumes svcntd() ≤ 8 (SVL ≤ 512 bits) — confirm.
    NK_ALIGN64 nk_f64_t a_buffer[8][8]; // save A columns before reusing ZA0 for B

    nk_size_t const row_end = row_start + row_count;
    nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dimension);

    // ZA0.D = staging (A then B), ZA1-3.D = merged Ozaki accumulators (i+j=0,1,2)
    for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
         row_tile_start += tile_dimension) {
        // Clamp first against the requested row range, then against the matrix height.
        nk_size_t const rows_remaining = (row_tile_start + tile_dimension <= row_end) ? tile_dimension
                                                                                      : (row_end - row_tile_start);
        nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
                                                                                      : (n_vectors - row_tile_start);
        svbool_t const row_predicate_f64x = svwhilelt_b64_u64(0u, rows_clamped);

        for (nk_size_t column_tile_index = 0; column_tile_index < column_tile_count; column_tile_index++) {
            nk_size_t const column_tile_start = column_tile_index * tile_dimension;
            nk_size_t const columns_remaining = (column_tile_start + tile_dimension <= n_vectors)
                                                    ? tile_dimension
                                                    : (n_vectors - column_tile_start);
            svbool_t const column_predicate_f64x = svwhilelt_b64_u64(0u, columns_remaining);

            // Zero ZA1-3 (3 merged Ozaki accumulators)
            svzero_mask_za(nk_sme_zero_za64_tiles_1_3_);

            for (nk_size_t depth_batch_start = 0; depth_batch_start < depth;
                 depth_batch_start += depth_steps_per_batch) {
                nk_size_t const depth_batch_end = (depth_batch_start + depth_steps_per_batch < depth)
                                                      ? depth_batch_start + depth_steps_per_batch
                                                      : depth;
                nk_size_t const batch_size = depth_batch_end - depth_batch_start;
                svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, batch_size);

                // Load A rows into ZA0 (native f64, no widening needed)
                svzero_mask_za(nk_sme_zero_za64_tile_0_);
                for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
                    nk_size_t const row_abs = row_tile_start + row_in_tile;
                    svld1_hor_za64(0, row_in_tile, batch_predicate_f64x,
                                   vectors + row_abs * stride_elements + depth_batch_start);
                }

                // Save A columns to buffer before reusing ZA0 for B
                for (nk_size_t s = 0; s < batch_size; s++)
                    svst1_f64(predicate_all_f64x, a_buffer[s],
                              svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, s));

                // Load B columns into ZA0 (reuse)
                svzero_mask_za(nk_sme_zero_za64_tile_0_);
                for (nk_size_t column = 0; column < tile_dimension; column++) {
                    nk_size_t const column_abs = column_tile_start + column;
                    if (column_abs < n_vectors)
                        svld1_hor_za64(0, column, batch_predicate_f64x,
                                       vectors + column_abs * stride_elements + depth_batch_start);
                }

                // Split both A and B into 3 Ozaki slices, 6 FMOPAs per step
                for (nk_size_t step = 0; step < batch_size; step++) {
                    svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
                    svuint64_t a_bits_u64x = svreinterpret_u64_f64(a_f64x);
                    svfloat64_t a_slice_0_f64x = svreinterpret_f64_u64(
                        svand_u64_x(predicate_all_f64x, a_bits_u64x, ozaki_mask_19_u64x));
                    svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_f64x, a_f64x, a_slice_0_f64x);
                    svuint64_t residual_a_bits_u64x = svreinterpret_u64_f64(residual_a_f64x);
                    svfloat64_t a_slice_1_f64x = svreinterpret_f64_u64(
                        svand_u64_x(predicate_all_f64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
                    svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_f64x, residual_a_f64x, a_slice_1_f64x);

                    svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), column_predicate_f64x, 0, step);
                    svuint64_t b_bits_u64x = svreinterpret_u64_f64(b_f64x);
                    svfloat64_t b_slice_0_f64x = svreinterpret_f64_u64(
                        svand_u64_x(predicate_all_f64x, b_bits_u64x, ozaki_mask_19_u64x));
                    svfloat64_t residual_b_f64x = svsub_f64_x(predicate_all_f64x, b_f64x, b_slice_0_f64x);
                    svuint64_t residual_b_bits_u64x = svreinterpret_u64_f64(residual_b_f64x);
                    svfloat64_t b_slice_1_f64x = svreinterpret_f64_u64(
                        svand_u64_x(predicate_all_f64x, residual_b_bits_u64x, ozaki_mask_17_u64x));
                    svfloat64_t b_slice_2_f64x = svsub_f64_x(predicate_all_f64x, residual_b_f64x, b_slice_1_f64x);

                    // 6 FMOPAs reordered to minimize WAW pipeline stalls on 3 tiles.
                    // Same-tile accumulation order preserved (bit-identical output).
                    // Tile schedule: ZA3(0), ZA2(1), ZA1(2), ZA3(4), ZA2(5), ZA3(8).
                    // 9 cycles vs 15 original (3 unavoidable bubbles with only 3 tiles).
                    // Do NOT reorder same-tile FMOPAs: f64 addition is not associative.
                    svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
                                      b_slice_2_f64x); // ZA3: i+j=2 (1/3)
                    svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
                                      b_slice_1_f64x); // ZA2: i+j=1 (1/2)
                    svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
                                      b_slice_0_f64x); // ZA1: i+j=0
                    svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_1_f64x,
                                      b_slice_1_f64x); // ZA3: i+j=2 (2/3)
                    svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_f64x, a_slice_1_f64x,
                                      b_slice_0_f64x); // ZA2: i+j=1 (2/2)
                    svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_2_f64x,
                                      b_slice_0_f64x); // ZA3: i+j=2 (3/3)
                }
            }

            // Sum ZA3 + ZA2 + ZA1 (smallest to largest, to limit rounding error)
            for (nk_size_t row = 0; row < rows_clamped; row++) {
                nk_size_t const row_abs = row_tile_start + row;
                svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 3, row);
                result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
                                          svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 2, row));
                result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
                                          svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 1, row));
                svst1_f64(column_predicate_f64x, result + row_abs * result_stride_elements + column_tile_start,
                          result_f64x);
            }
        }
    }
}
928
+
929
+ NK_PUBLIC void nk_dots_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
930
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
931
+ nk_size_t row_start, nk_size_t row_count) {
932
+
933
+ nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
934
+ nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
935
+ nk_dots_symmetric_f64_smef64_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
936
+ row_start, row_count);
937
+ }
938
+
939
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t columns, nk_size_t depth) {
940
+ nk_size_t const tile_dimension = svcntsd();
941
+ nk_size_t const depth_tile_size = svcntsw();
942
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
943
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
944
+ // Single header + interleaved 3-slice data (3× tile_dimension elements per depth step)
945
+ nk_size_t size = sizeof(nk_dots_sme_packed_header_t);
946
+ size += column_tile_count * depth_tile_count * 3 * tile_dimension * depth_tile_size * sizeof(nk_f32_t);
947
+ size += columns * sizeof(nk_f64_t); // per-column squared norms
948
+ return size;
949
+ }
950
+
951
+ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride,
952
+ void *b_packed) {
953
+
954
+ nk_size_t const b_stride_elements = b_stride / sizeof(nk_f64_t);
955
+
956
+ nk_size_t const tile_dimension = svcntsd();
957
+ nk_size_t const depth_tile_size = svcntsw();
958
+ nk_size_t const interleaved_stride = 3 * tile_dimension;
959
+ nk_size_t const interleaved_tile_elements = depth_tile_size * interleaved_stride;
960
+
961
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
962
+ nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
963
+ nk_size_t const total_tiles = column_tile_count * depth_tile_count;
964
+
965
+ // Write single header
966
+ nk_dots_sme_packed_header_t *header = (nk_dots_sme_packed_header_t *)b_packed;
967
+ header->column_tile_count = (nk_u32_t)column_tile_count;
968
+ header->depth_tile_count = (nk_u32_t)depth_tile_count;
969
+ header->columns = (nk_u32_t)columns;
970
+ header->depth = (nk_u32_t)depth;
971
+ header->svl_bytes = (nk_u32_t)svcntsb();
972
+
973
+ nk_f32_t *tiles = (nk_f32_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
974
+
975
+ // Zero-initialize all tiles (handles partial tile padding)
976
+ nk_size_t const total_elements = total_tiles * interleaved_tile_elements;
977
+ for (nk_size_t i = 0; i < total_elements; i++) tiles[i] = 0.0f;
978
+
979
+ // Inline tiling + 3-way mantissa-mask split with interleaved slice layout.
980
+ // Per depth step depth_idx, 3 slices are stored contiguously:
981
+ // tiles[tile_output + depth_idx * interleaved_stride + slice * tile_dimension + column_idx]
982
+ for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tile_count; column_tile_idx++) {
983
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tile_count; depth_tile_idx++) {
984
+ nk_f32_t *tile_output = tiles +
985
+ (column_tile_idx * depth_tile_count + depth_tile_idx) * interleaved_tile_elements;
986
+
987
+ nk_size_t const column_start = column_tile_idx * tile_dimension;
988
+ nk_size_t const k_start = depth_tile_idx * depth_tile_size;
989
+ nk_size_t const columns_to_pack = (column_start + tile_dimension <= columns) ? tile_dimension
990
+ : (columns - column_start);
991
+ nk_size_t const depth_to_pack = (k_start + depth_tile_size <= depth) ? depth_tile_size : (depth - k_start);
992
+
993
+ for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
994
+ for (nk_size_t depth_idx = 0; depth_idx < depth_to_pack; depth_idx++) {
995
+ nk_f64_t val = b[(column_start + column_idx) * b_stride_elements + k_start + depth_idx];
996
+ nk_f64_t slice_0, slice_1, slice_2;
997
+ nk_f64_smef64_ozaki_split_f64_(val, &slice_0, &slice_1, &slice_2);
998
+
999
+ tile_output[depth_idx * interleaved_stride + 0 * tile_dimension + column_idx] = (nk_f32_t)slice_0;
1000
+ tile_output[depth_idx * interleaved_stride + 1 * tile_dimension + column_idx] = (nk_f32_t)slice_1;
1001
+ tile_output[depth_idx * interleaved_stride + 2 * tile_dimension + column_idx] = (nk_f32_t)slice_2;
1002
+ }
1003
+ }
1004
+ }
1005
+ }
1006
+
1007
+ // Compute per-column squared norms and store after packed data
1008
+ nk_size_t const data_size = total_tiles * interleaved_tile_elements * sizeof(nk_f32_t);
1009
+ header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
1010
+ nk_f64_t *norms_ptr = (nk_f64_t *)((char *)b_packed + header->norms_offset);
1011
+ for (nk_size_t col = 0; col < columns; col++) {
1012
+ nk_f64_t const *col_data = (nk_f64_t const *)((char const *)b + col * b_stride);
1013
+ norms_ptr[col] = nk_dots_reduce_sumsq_f64_(col_data, depth);
1014
+ }
1015
+ }
1016
+
1017
// Packed-B GEMM kernel: consumes the layout produced by nk_dots_pack_f64_smef64
// (single header, interleaved 3-slice f32 tiles, trailing norms) and writes
// c[row][col] = dot(a row, packed B column). Runs in SME streaming mode with a
// fresh ZA context (__arm_new("za")); all strides here are in elements, not
// bytes (converted by the public wrapper nk_dots_packed_f64_smef64).
// Each A value is split in-register into 3 Ozaki mantissa slices (19+17+17
// bits); B slices were split the same way at pack time. Slice products are
// accumulated into merged ZA tiles keyed by significance index i+j (0, 1, 2),
// then summed smallest-to-largest on extraction.
// NOTE(review): slice products with i+j > 2 are never formed — presumably a
// deliberate accuracy/throughput trade-off; confirm against the accuracy
// contract documented for this backend.
__arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_streaming_(
    nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
    nk_size_t a_stride_elements, nk_size_t c_stride_elements) {

    // Read header (tile geometry was fixed at pack time).
    nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
    nk_size_t const column_tile_count = header->column_tile_count;
    nk_size_t const depth_tile_count = header->depth_tile_count;

    nk_size_t const tile_dimension = svcntd();                                        // 8 for 512-bit SVL
    nk_size_t const depth_tile_size = svcntw();                                       // 16 (f32 packing granularity)
    nk_size_t const interleaved_stride = 3 * tile_dimension;                          // 24
    nk_size_t const interleaved_tile_elements = depth_tile_size * interleaved_stride; // 384
    nk_size_t const depth_steps_per_batch = tile_dimension;                           // 8 f64 steps per ZA0 load

    // B tile data pointer (f32, interleaved slices)
    nk_f32_t const *b_tiles = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_dots_sme_packed_header_t));

    svbool_t const predicate_all_f64x = svptrue_b64();

    // Mantissa masks for in-register Ozaki splitting (19+17+17 bits)
    svuint64_t const ozaki_mask_19_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_19_bits_());
    svuint64_t const ozaki_mask_17_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_17_bits_());

    // ZA tile budget:
    // ZA0.D = A staging
    // ZA1-3.D = col0 merged accumulators (i+j=0,1,2)
    // ZA4-6.D = col1 merged accumulators (i+j=0,1,2) [2-col path only]
    for (nk_size_t row_tile_index = 0; row_tile_index < nk_size_divide_round_up_(rows, tile_dimension);
         row_tile_index++) {
        nk_size_t const row_start = row_tile_index * tile_dimension;
        nk_size_t const rows_remaining = (row_start + tile_dimension <= rows) ? tile_dimension : (rows - row_start);
        svbool_t const row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);

        nk_size_t column_tile_index = 0;

        // 2-column fast path: process 2 column tiles per iteration, A staged once
        for (; column_tile_index + 2 <= column_tile_count; column_tile_index += 2) {
            nk_size_t const column_start_0 = column_tile_index * tile_dimension;
            nk_size_t const column_start_1 = (column_tile_index + 1) * tile_dimension;
            nk_size_t const columns_remaining_0 = (column_start_0 + tile_dimension <= columns)
                                                      ? tile_dimension
                                                      : (columns - column_start_0);
            nk_size_t const columns_remaining_1 = (column_start_1 + tile_dimension <= columns)
                                                      ? tile_dimension
                                                      : (columns - column_start_1);
            svbool_t const column_predicate_0_f64x = svwhilelt_b64_u64(0u, columns_remaining_0);
            svbool_t const column_predicate_1_f64x = svwhilelt_b64_u64(0u, columns_remaining_1);

            // Zero ZA1-6 (3 accumulators × 2 column tiles)
            svzero_mask_za(nk_sme_zero_za64_tiles_1_6_);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tile_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * depth_tile_size;

                // A depth tile spans depth_tile_size f32-granularity steps, but
                // ZA0.D stages only tile_dimension f64 values at a time.
                for (nk_size_t depth_batch_start = 0; depth_batch_start < depth_tile_size;
                     depth_batch_start += depth_steps_per_batch) {
                    nk_size_t const depth_batch_end = (depth_batch_start + depth_steps_per_batch < depth_tile_size)
                                                          ? depth_batch_start + depth_steps_per_batch
                                                          : depth_tile_size;
                    nk_size_t const batch_size = depth_batch_end - depth_batch_start;

                    // Stop once the batch starts past the true (unpadded) depth.
                    if (depth_offset + depth_batch_start >= depth) break;

                    // Load A rows into ZA0.D (shared for both column tiles)
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
                        nk_size_t const a_row = row_start + row_in_tile;
                        // Predicate trims the load at the matrix's real depth edge.
                        svbool_t const a_depth_predicate_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
                                                                                 (uint64_t)depth);
                        svld1_hor_za64(0, row_in_tile, a_depth_predicate_f64x,
                                       &a[a_row * a_stride_elements + depth_offset + depth_batch_start]);
                    }

                    // B base offsets for both column tiles
                    nk_size_t const b_batch_offset_0 = (column_tile_index * depth_tile_count + depth_tile_idx) *
                                                           interleaved_tile_elements +
                                                       depth_batch_start * interleaved_stride;
                    nk_size_t const b_batch_offset_1 = ((column_tile_index + 1) * depth_tile_count + depth_tile_idx) *
                                                           interleaved_tile_elements +
                                                       depth_batch_start * interleaved_stride;

                    for (nk_size_t step = 0; step < batch_size; step++) {
                        nk_size_t const k_abs = depth_offset + depth_batch_start + step;
                        if (k_abs >= depth) break;

                        // Read A column from ZA0 and split into 3 Ozaki slices:
                        // slice_0 = top 19 mantissa bits, slice_1 = next 17 bits
                        // of the residual, slice_2 = whatever remains.
                        svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, step);
                        svuint64_t a_bits_u64x = svreinterpret_u64_f64(a_f64x);
                        svfloat64_t a_slice_0_f64x = svreinterpret_f64_u64(
                            svand_u64_x(predicate_all_f64x, a_bits_u64x, ozaki_mask_19_u64x));
                        svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_f64x, a_f64x, a_slice_0_f64x);
                        svuint64_t residual_a_bits_u64x = svreinterpret_u64_f64(residual_a_f64x);
                        svfloat64_t a_slice_1_f64x = svreinterpret_f64_u64(
                            svand_u64_x(predicate_all_f64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
                        svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_f64x, residual_a_f64x, a_slice_1_f64x);

                        // Load all 6 B slices upfront (3 per column tile) for pipeline
                        // interleaving. svld1uw_u64 zero-extends each packed f32 word
                        // into the low half of a 64-bit lane; svcvt_f64_f32_x then
                        // widens those low-half f32 values to f64.
                        nk_size_t const b_tile_offset_0 = b_batch_offset_0 + step * interleaved_stride;
                        nk_size_t const b_tile_offset_1 = b_batch_offset_1 + step * interleaved_stride;
                        svfloat64_t b_column_0_slice_0_f64x = svcvt_f64_f32_x(
                            predicate_all_f64x,
                            svreinterpret_f32_u64(
                                svld1uw_u64(predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0))));
                        svfloat64_t b_column_0_slice_1_f64x = svcvt_f64_f32_x(
                            predicate_all_f64x,
                            svreinterpret_f32_u64(svld1uw_u64(
                                predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0 + tile_dimension))));
                        svfloat64_t b_column_0_slice_2_f64x = svcvt_f64_f32_x(
                            predicate_all_f64x, svreinterpret_f32_u64(svld1uw_u64(
                                                    predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0 +
                                                                                           2 * tile_dimension))));
                        svfloat64_t b_column_1_slice_0_f64x = svcvt_f64_f32_x(
                            predicate_all_f64x,
                            svreinterpret_f32_u64(
                                svld1uw_u64(predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1))));
                        svfloat64_t b_column_1_slice_1_f64x = svcvt_f64_f32_x(
                            predicate_all_f64x,
                            svreinterpret_f32_u64(svld1uw_u64(
                                predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1 + tile_dimension))));
                        svfloat64_t b_column_1_slice_2_f64x = svcvt_f64_f32_x(
                            predicate_all_f64x, svreinterpret_f32_u64(svld1uw_u64(
                                                    predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1 +
                                                                                           2 * tile_dimension))));

                        // 12 FMOPAs interleaved across 6 tiles to eliminate WAW pipeline stalls.
                        // Same-tile accumulation order preserved (bit-identical output).
                        // Tile gaps: ZA3 at 0,6,10 (6,4); ZA6 at 1,7,11 (6,4); ZA2 at 4,8 (4);
                        // ZA5 at 5,9 (4); ZA1 at 2; ZA4 at 3. All gaps >= 4-cycle latency.
                        // Do NOT reorder: same-tile issue order determines rounding.
                        svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_0_f64x, a_slice_0_f64x,
                                          b_column_0_slice_2_f64x); // ZA3: i+j=2 (1/3)
                        svmopa_za64_f64_m(6, row_predicate_f64x, column_predicate_1_f64x, a_slice_0_f64x,
                                          b_column_1_slice_2_f64x); // ZA6: i+j=2 (1/3)
                        svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_0_f64x, a_slice_0_f64x,
                                          b_column_0_slice_0_f64x); // ZA1: i+j=0
                        svmopa_za64_f64_m(4, row_predicate_f64x, column_predicate_1_f64x, a_slice_0_f64x,
                                          b_column_1_slice_0_f64x); // ZA4: i+j=0
                        svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_0_f64x, a_slice_0_f64x,
                                          b_column_0_slice_1_f64x); // ZA2: i+j=1 (1/2)
                        svmopa_za64_f64_m(5, row_predicate_f64x, column_predicate_1_f64x, a_slice_0_f64x,
                                          b_column_1_slice_1_f64x); // ZA5: i+j=1 (1/2)
                        svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_0_f64x, a_slice_1_f64x,
                                          b_column_0_slice_1_f64x); // ZA3: i+j=2 (2/3)
                        svmopa_za64_f64_m(6, row_predicate_f64x, column_predicate_1_f64x, a_slice_1_f64x,
                                          b_column_1_slice_1_f64x); // ZA6: i+j=2 (2/3)
                        svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_0_f64x, a_slice_1_f64x,
                                          b_column_0_slice_0_f64x); // ZA2: i+j=1 (2/2)
                        svmopa_za64_f64_m(5, row_predicate_f64x, column_predicate_1_f64x, a_slice_1_f64x,
                                          b_column_1_slice_0_f64x); // ZA5: i+j=1 (2/2)
                        svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_0_f64x, a_slice_2_f64x,
                                          b_column_0_slice_0_f64x); // ZA3: i+j=2 (3/3)
                        svmopa_za64_f64_m(6, row_predicate_f64x, column_predicate_1_f64x, a_slice_2_f64x,
                                          b_column_1_slice_0_f64x); // ZA6: i+j=2 (3/3)
                    }
                }
            }

            // Simple summation for col tile 0: ZA3 + ZA2 + ZA1 (smallest to largest)
            for (nk_size_t row = 0; row < rows_remaining; row++) {
                nk_f64_t *c_row = c + (row_start + row) * c_stride_elements + column_start_0;
                svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 3, row);
                result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
                                          svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 2, row));
                result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
                                          svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 1, row));
                svst1_f64(column_predicate_0_f64x, c_row, result_f64x);
            }

            // Simple summation for col tile 1: ZA6 + ZA5 + ZA4 (smallest to largest)
            for (nk_size_t row = 0; row < rows_remaining; row++) {
                nk_f64_t *c_row = c + (row_start + row) * c_stride_elements + column_start_1;
                svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 6, row);
                result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
                                          svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 5, row));
                result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
                                          svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 4, row));
                svst1_f64(column_predicate_1_f64x, c_row, result_f64x);
            }
        }

        // 1-column remainder (when column_tile_count is odd)
        for (; column_tile_index < column_tile_count; column_tile_index++) {
            nk_size_t const column_start = column_tile_index * tile_dimension;
            nk_size_t const columns_remaining = (column_start + tile_dimension <= columns) ? tile_dimension
                                                                                           : (columns - column_start);
            svbool_t const column_predicate_f64x = svwhilelt_b64_u64(0u, columns_remaining);

            // Zero ZA1-3 (3 merged accumulators)
            svzero_mask_za(nk_sme_zero_za64_tiles_1_3_);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tile_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * depth_tile_size;

                for (nk_size_t depth_batch_start = 0; depth_batch_start < depth_tile_size;
                     depth_batch_start += depth_steps_per_batch) {
                    nk_size_t const depth_batch_end = (depth_batch_start + depth_steps_per_batch < depth_tile_size)
                                                          ? depth_batch_start + depth_steps_per_batch
                                                          : depth_tile_size;
                    nk_size_t const batch_size = depth_batch_end - depth_batch_start;

                    if (depth_offset + depth_batch_start >= depth) break;

                    // Load A rows into ZA0.D
                    svzero_mask_za(nk_sme_zero_za64_tile_0_);
                    for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
                        nk_size_t const a_row = row_start + row_in_tile;
                        svbool_t const a_depth_predicate_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
                                                                                 (uint64_t)depth);
                        svld1_hor_za64(0, row_in_tile, a_depth_predicate_f64x,
                                       &a[a_row * a_stride_elements + depth_offset + depth_batch_start]);
                    }

                    nk_size_t const b_batch_offset = (column_tile_index * depth_tile_count + depth_tile_idx) *
                                                         interleaved_tile_elements +
                                                     depth_batch_start * interleaved_stride;

                    for (nk_size_t step = 0; step < batch_size; step++) {
                        nk_size_t const k_abs = depth_offset + depth_batch_start + step;
                        if (k_abs >= depth) break;

                        // Read A column from ZA0 and split into 3 Ozaki slices
                        // (same 19/17/17-bit scheme as the 2-column path).
                        svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, step);
                        svuint64_t a_bits_u64x = svreinterpret_u64_f64(a_f64x);
                        svfloat64_t a_slice_0_f64x = svreinterpret_f64_u64(
                            svand_u64_x(predicate_all_f64x, a_bits_u64x, ozaki_mask_19_u64x));
                        svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_f64x, a_f64x, a_slice_0_f64x);
                        svuint64_t residual_a_bits_u64x = svreinterpret_u64_f64(residual_a_f64x);
                        svfloat64_t a_slice_1_f64x = svreinterpret_f64_u64(
                            svand_u64_x(predicate_all_f64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
                        svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_f64x, residual_a_f64x, a_slice_1_f64x);

                        // Load 3 B slices (contiguous in interleaved layout)
                        nk_size_t const b_tile_offset = b_batch_offset + step * interleaved_stride;
                        svfloat64_t b_slice_0_f64x = svcvt_f64_f32_x(
                            predicate_all_f64x, svreinterpret_f32_u64(svld1uw_u64(
                                                    predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset))));
                        svfloat64_t b_slice_1_f64x = svcvt_f64_f32_x(
                            predicate_all_f64x,
                            svreinterpret_f32_u64(svld1uw_u64(
                                predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset + tile_dimension))));
                        svfloat64_t b_slice_2_f64x = svcvt_f64_f32_x(
                            predicate_all_f64x,
                            svreinterpret_f32_u64(svld1uw_u64(
                                predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset + 2 * tile_dimension))));

                        // 6 FMOPAs reordered to minimize WAW pipeline stalls on 3 tiles.
                        // Same-tile accumulation order preserved (bit-identical output).
                        // Tile schedule: ZA3(0), ZA2(1), ZA1(2), ZA3(4), ZA2(5), ZA3(8).
                        // 9 cycles vs 15 original (3 unavoidable bubbles with only 3 tiles).
                        svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
                                          b_slice_2_f64x); // ZA3: i+j=2 (1/3)
                        svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
                                          b_slice_1_f64x); // ZA2: i+j=1 (1/2)
                        svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
                                          b_slice_0_f64x); // ZA1: i+j=0
                        svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_1_f64x,
                                          b_slice_1_f64x); // ZA3: i+j=2 (2/3)
                        svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_f64x, a_slice_1_f64x,
                                          b_slice_0_f64x); // ZA2: i+j=1 (2/2)
                        svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_2_f64x,
                                          b_slice_0_f64x); // ZA3: i+j=2 (3/3)
                    }
                }
            }

            // Simple summation: ZA3 + ZA2 + ZA1 (smallest to largest)
            for (nk_size_t row = 0; row < rows_remaining; row++) {
                nk_f64_t *c_row = c + (row_start + row) * c_stride_elements + column_start;
                svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 3, row);
                result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
                                          svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 2, row));
                result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
                                          svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 1, row));
                svst1_f64(column_predicate_f64x, c_row, result_f64x);
            }
        }
    }
}
1294
+
1295
+ NK_PUBLIC void nk_dots_packed_f64_smef64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
1296
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
1297
+
1298
+ nk_size_t const a_stride_elements = a_stride / sizeof(nk_f64_t);
1299
+ nk_size_t const c_stride_elements = c_stride / sizeof(nk_f64_t);
1300
+
1301
+ nk_dots_packed_f64_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1302
+ }
1303
+
1304
+ #pragma endregion // Double Precision Floats
1305
+
1306
+ #if defined(__clang__)
1307
+ #pragma clang attribute pop
1308
+ #elif defined(__GNUC__)
1309
+ #pragma GCC pop_options
1310
+ #endif
1311
+
1312
+ #if defined(__cplusplus)
1313
+ } // extern "C"
1314
+ #endif
1315
+
1316
+ #endif // NK_TARGET_SME
1317
+ #endif // NK_TARGET_ARM_
1318
+ #endif // NK_DOTS_SMEF64_H