numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -45,7 +45,7 @@ extern "C" {
45
45
  #endif
46
46
 
47
47
  #if defined(__clang__)
48
- #pragma clang attribute push(__attribute__((target("sme,sve,sme-f64f64"))), apply_to = function)
48
+ #pragma clang attribute push(__attribute__((target("sme,sme-f64f64"))), apply_to = function)
49
49
  #elif defined(__GNUC__)
50
50
  #pragma GCC push_options
51
51
  #pragma GCC target("+sme+sme-f64f64")
@@ -72,11 +72,11 @@ extern "C" {
72
72
  * for higher-than-f32 accumulation precision; replacing it with f32 FMOPA would be
73
73
  * counterproductive. Apple M4 has `hw.optional.arm.SME_F32F32: 1` but we don't use it here.
74
74
  */
75
- #pragma region Single Precision Floats
75
+ #pragma region F32 Floats
76
76
 
77
77
  NK_PUBLIC nk_size_t nk_dots_packed_size_f32_smef64(nk_size_t columns, nk_size_t depth) {
78
- nk_size_t const tile_dimension = svcntsd(); // rows per `ZA64` tile (8 for SVL=512)
79
- nk_size_t const depth_tile_size = svcntsw(); // `f32` depth elements per tile (16 for SVL=512)
78
+ nk_size_t const tile_dimension = nk_sme_cntd_(); // rows per `ZA64` tile (8 for SVL=512)
79
+ nk_size_t const depth_tile_size = nk_sme_cntw_(); // `f32` depth elements per tile (16 for SVL=512)
80
80
 
81
81
  nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
82
82
  nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
@@ -88,13 +88,13 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_smef64(nk_size_t columns, nk_size_t
88
88
  return size;
89
89
  }
90
90
 
91
- NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride,
92
- void *b_packed) {
91
+ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_size_t depth,
92
+ nk_size_t b_stride_in_bytes, void *b_packed) {
93
93
 
94
- nk_size_t const tile_dimension = svcntsd(); // rows per `ZA64` tile (8 for SVL=512)
95
- nk_size_t const depth_tile_size = svcntsw(); // `f32` depth elements per tile (16 for SVL=512)
94
+ nk_size_t const tile_dimension = nk_sme_cntd_(); // rows per `ZA64` tile (8 for SVL=512)
95
+ nk_size_t const depth_tile_size = nk_sme_cntw_(); // `f32` depth elements per tile (16 for SVL=512)
96
96
  nk_size_t const tile_elements = tile_dimension * depth_tile_size; // 128
97
- nk_size_t const b_stride_elements = b_stride / sizeof(nk_f32_t);
97
+ nk_size_t const b_stride_elements = b_stride_in_bytes / sizeof(nk_f32_t);
98
98
 
99
99
  nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
100
100
  nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
@@ -106,7 +106,7 @@ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_
106
106
  header->depth_tile_count = (nk_u32_t)depth_tile_count;
107
107
  header->columns = (nk_u32_t)columns;
108
108
  header->depth = (nk_u32_t)depth;
109
- header->svl_bytes = (nk_u32_t)svcntsb(); // streaming vector length in bytes
109
+ header->svl_bytes = (nk_u32_t)nk_sme_cntb_(); // streaming vector length in bytes
110
110
 
111
111
  nk_f32_t *tiles = (nk_f32_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
112
112
 
@@ -148,7 +148,7 @@ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_
148
148
  header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
149
149
  nk_f64_t *norms_ptr = (nk_f64_t *)((char *)b_packed + header->norms_offset);
150
150
  for (nk_size_t col = 0; col < columns; col++) {
151
- nk_f32_t const *col_data = (nk_f32_t const *)((char const *)b + col * b_stride);
151
+ nk_f32_t const *col_data = (nk_f32_t const *)((char const *)b + col * b_stride_in_bytes);
152
152
  norms_ptr[col] = nk_dots_reduce_sumsq_f32_(col_data, depth);
153
153
  }
154
154
  }
@@ -168,14 +168,14 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
168
168
 
169
169
  nk_f32_t const *b_tiles = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_dots_sme_packed_header_t));
170
170
 
171
- svbool_t const predicate_all_f64x = svptrue_b64();
171
+ svbool_t const predicate_all_b64x = svptrue_b64();
172
172
 
173
173
  // ZA0.D = staging, ZA1-7.D = accumulation (7-tile fast path)
174
174
  for (nk_size_t row_tile_index = 0; row_tile_index < nk_size_divide_round_up_(rows, tile_dimension);
175
175
  row_tile_index++) {
176
176
  nk_size_t const row_start = row_tile_index * tile_dimension;
177
177
  nk_size_t const rows_remaining = (row_start + tile_dimension <= rows) ? tile_dimension : (rows - row_start);
178
- svbool_t const row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);
178
+ svbool_t const row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
179
179
 
180
180
  nk_size_t column_tile_index = 0;
181
181
 
@@ -200,18 +200,17 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
200
200
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
201
201
 
202
202
  // Load A rows into ZA0.D: extending load f32→u64 + convert to f64
203
- svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, (uint64_t)batch_size);
204
- svbool_t const a_depth_predicate_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
205
- (uint64_t)depth);
203
+ svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
204
+ svbool_t const a_depth_predicate_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start, depth);
206
205
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
207
206
  nk_size_t const a_row = row_start + row_in_tile;
208
207
  // Extending load: svld1uw_u64 loads f32 bits into lower 32 of each u64 lane
209
208
  svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
210
- batch_predicate_f64x,
209
+ batch_predicate_b64x,
211
210
  svreinterpret_f32_u64(svld1uw_u64(
212
- a_depth_predicate_f64x,
211
+ a_depth_predicate_b64x,
213
212
  (nk_u32_t const *)&a[a_row * a_stride_elements + depth_offset + depth_batch_start])));
214
- svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_f64x, a_row_widened_f64x);
213
+ svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_b64x, a_row_widened_f64x);
215
214
  }
216
215
 
217
216
  // Vertical read + MOPA for each depth step in batch
@@ -219,110 +218,110 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
219
218
  nk_size_t const k_abs = depth_offset + depth_batch_start + step;
220
219
  if (k_abs >= depth) break;
221
220
 
222
- svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, step);
221
+ svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, step);
223
222
 
224
223
  nk_size_t const b_k = depth_batch_start + step;
225
224
 
226
225
  // Extending load f32→u64 + convert to f64: svld1uw_u64 replaces svld1_f32 + svunpklo_u64
227
226
  svfloat64_t b_column_tile_1_f64x = svcvt_f64_f32_x(
228
- predicate_all_f64x,
227
+ predicate_all_b64x,
229
228
  svreinterpret_f32_u64(svld1uw_u64(
230
- predicate_all_f64x,
229
+ predicate_all_b64x,
231
230
  (nk_u32_t const *)(b_tiles +
232
231
  ((column_tile_index + 0) * depth_tile_count + depth_tile_idx) *
233
232
  tile_elements +
234
233
  b_k * tile_dimension))));
235
234
  svfloat64_t b_column_tile_2_f64x = svcvt_f64_f32_x(
236
- predicate_all_f64x,
235
+ predicate_all_b64x,
237
236
  svreinterpret_f32_u64(svld1uw_u64(
238
- predicate_all_f64x,
237
+ predicate_all_b64x,
239
238
  (nk_u32_t const *)(b_tiles +
240
239
  ((column_tile_index + 1) * depth_tile_count + depth_tile_idx) *
241
240
  tile_elements +
242
241
  b_k * tile_dimension))));
243
242
  svfloat64_t b_column_tile_3_f64x = svcvt_f64_f32_x(
244
- predicate_all_f64x,
243
+ predicate_all_b64x,
245
244
  svreinterpret_f32_u64(svld1uw_u64(
246
- predicate_all_f64x,
245
+ predicate_all_b64x,
247
246
  (nk_u32_t const *)(b_tiles +
248
247
  ((column_tile_index + 2) * depth_tile_count + depth_tile_idx) *
249
248
  tile_elements +
250
249
  b_k * tile_dimension))));
251
250
  svfloat64_t b_column_tile_4_f64x = svcvt_f64_f32_x(
252
- predicate_all_f64x,
251
+ predicate_all_b64x,
253
252
  svreinterpret_f32_u64(svld1uw_u64(
254
- predicate_all_f64x,
253
+ predicate_all_b64x,
255
254
  (nk_u32_t const *)(b_tiles +
256
255
  ((column_tile_index + 3) * depth_tile_count + depth_tile_idx) *
257
256
  tile_elements +
258
257
  b_k * tile_dimension))));
259
258
  svfloat64_t b_column_tile_5_f64x = svcvt_f64_f32_x(
260
- predicate_all_f64x,
259
+ predicate_all_b64x,
261
260
  svreinterpret_f32_u64(svld1uw_u64(
262
- predicate_all_f64x,
261
+ predicate_all_b64x,
263
262
  (nk_u32_t const *)(b_tiles +
264
263
  ((column_tile_index + 4) * depth_tile_count + depth_tile_idx) *
265
264
  tile_elements +
266
265
  b_k * tile_dimension))));
267
266
  svfloat64_t b_column_tile_6_f64x = svcvt_f64_f32_x(
268
- predicate_all_f64x,
267
+ predicate_all_b64x,
269
268
  svreinterpret_f32_u64(svld1uw_u64(
270
- predicate_all_f64x,
269
+ predicate_all_b64x,
271
270
  (nk_u32_t const *)(b_tiles +
272
271
  ((column_tile_index + 5) * depth_tile_count + depth_tile_idx) *
273
272
  tile_elements +
274
273
  b_k * tile_dimension))));
275
274
  svfloat64_t b_column_tile_7_f64x = svcvt_f64_f32_x(
276
- predicate_all_f64x,
275
+ predicate_all_b64x,
277
276
  svreinterpret_f32_u64(svld1uw_u64(
278
- predicate_all_f64x,
277
+ predicate_all_b64x,
279
278
  (nk_u32_t const *)(b_tiles +
280
279
  ((column_tile_index + 6) * depth_tile_count + depth_tile_idx) *
281
280
  tile_elements +
282
281
  b_k * tile_dimension))));
283
282
 
284
- svmopa_za64_f64_m(1, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_1_f64x);
285
- svmopa_za64_f64_m(2, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_2_f64x);
286
- svmopa_za64_f64_m(3, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_3_f64x);
287
- svmopa_za64_f64_m(4, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_4_f64x);
288
- svmopa_za64_f64_m(5, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_5_f64x);
289
- svmopa_za64_f64_m(6, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_6_f64x);
290
- svmopa_za64_f64_m(7, row_predicate_f64x, predicate_all_f64x, a_f64x, b_column_tile_7_f64x);
283
+ svmopa_za64_f64_m(1, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_1_f64x);
284
+ svmopa_za64_f64_m(2, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_2_f64x);
285
+ svmopa_za64_f64_m(3, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_3_f64x);
286
+ svmopa_za64_f64_m(4, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_4_f64x);
287
+ svmopa_za64_f64_m(5, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_5_f64x);
288
+ svmopa_za64_f64_m(6, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_6_f64x);
289
+ svmopa_za64_f64_m(7, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_7_f64x);
291
290
  }
292
291
  }
293
292
  }
294
293
 
295
294
  // Extract from ZA1-7 and store native f64 outputs.
296
- svbool_t const predicate_tile_f64x = svwhilelt_b64_u64(0u, tile_dimension);
295
+ svbool_t const predicate_tile_b64x = svwhilelt_b64_u64(0u, tile_dimension);
297
296
  // The 7th tile (index 6) may be partial when it's the last column tile
298
297
  nk_size_t const last_fast_col_start = (column_tile_index + 6) * tile_dimension;
299
298
  nk_size_t const last_fast_cols = (last_fast_col_start + tile_dimension <= columns)
300
299
  ? tile_dimension
301
300
  : (columns - last_fast_col_start);
302
- svbool_t const last_tile_pred_f64x = svwhilelt_b64_u64(0u, last_fast_cols);
301
+ svbool_t const last_tile_pred_b64x = svwhilelt_b64_u64(0u, last_fast_cols);
303
302
  for (nk_size_t row_idx = 0; row_idx < rows_remaining; row_idx++) {
304
303
  nk_f64_t *c_row = c + (row_start + row_idx) * c_stride_elements;
305
304
 
306
- svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 1, row_idx);
307
- svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 0) * tile_dimension, za_row_f64x);
305
+ svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 1, row_idx);
306
+ svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 0) * tile_dimension, za_row_f64x);
308
307
 
309
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 2, row_idx);
310
- svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 1) * tile_dimension, za_row_f64x);
308
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 2, row_idx);
309
+ svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 1) * tile_dimension, za_row_f64x);
311
310
 
312
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 3, row_idx);
313
- svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 2) * tile_dimension, za_row_f64x);
311
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 3, row_idx);
312
+ svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 2) * tile_dimension, za_row_f64x);
314
313
 
315
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 4, row_idx);
316
- svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 3) * tile_dimension, za_row_f64x);
314
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 4, row_idx);
315
+ svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 3) * tile_dimension, za_row_f64x);
317
316
 
318
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 5, row_idx);
319
- svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 4) * tile_dimension, za_row_f64x);
317
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 5, row_idx);
318
+ svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 4) * tile_dimension, za_row_f64x);
320
319
 
321
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 6, row_idx);
322
- svst1_f64(predicate_tile_f64x, c_row + (column_tile_index + 5) * tile_dimension, za_row_f64x);
320
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 6, row_idx);
321
+ svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 5) * tile_dimension, za_row_f64x);
323
322
 
324
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 7, row_idx);
325
- svst1_f64(last_tile_pred_f64x, c_row + (column_tile_index + 6) * tile_dimension, za_row_f64x);
323
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 7, row_idx);
324
+ svst1_f64(last_tile_pred_b64x, c_row + (column_tile_index + 6) * tile_dimension, za_row_f64x);
326
325
  }
327
326
  }
328
327
 
@@ -331,7 +330,7 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
331
330
  nk_size_t const column_start = column_tile_index * tile_dimension;
332
331
  nk_size_t const columns_remaining = (column_start + tile_dimension <= columns) ? tile_dimension
333
332
  : (columns - column_start);
334
- svbool_t const column_predicate_f64x = svwhilelt_b64_u64(0u, columns_remaining);
333
+ svbool_t const column_predicate_b64x = svwhilelt_b64_u64(0u, columns_remaining);
335
334
 
336
335
  svzero_mask_za(nk_sme_zero_za64_tile_1_);
337
336
 
@@ -349,54 +348,54 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
349
348
 
350
349
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
351
350
 
352
- svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, (uint64_t)batch_size);
353
- svbool_t const a_depth_pred_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
354
- (uint64_t)depth);
351
+ svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
352
+ svbool_t const a_depth_pred_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start, depth);
355
353
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
356
354
  nk_size_t const a_row = row_start + row_in_tile;
357
355
  svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
358
- batch_predicate_f64x,
356
+ batch_predicate_b64x,
359
357
  svreinterpret_f32_u64(svld1uw_u64(
360
- a_depth_pred_f64x,
358
+ a_depth_pred_b64x,
361
359
  (nk_u32_t const *)&a[a_row * a_stride_elements + depth_offset + depth_batch_start])));
362
- svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_f64x, a_row_widened_f64x);
360
+ svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_b64x, a_row_widened_f64x);
363
361
  }
364
362
 
365
363
  for (nk_size_t step = 0; step < batch_size; step++) {
366
364
  nk_size_t const k_abs = depth_offset + depth_batch_start + step;
367
365
  if (k_abs >= depth) break;
368
366
 
369
- svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, step);
367
+ svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, step);
370
368
 
371
369
  nk_size_t const b_k = depth_batch_start + step;
372
370
  nk_f32_t const *b_tile = b_tiles + (column_tile_index * depth_tile_count + depth_tile_idx) *
373
371
  tile_elements;
374
372
  // Extending load f32→u64 + convert to f64
375
373
  svfloat64_t b_f64x = svcvt_f64_f32_x(
376
- predicate_all_f64x,
374
+ predicate_all_b64x,
377
375
  svreinterpret_f32_u64(
378
- svld1uw_u64(predicate_all_f64x, (nk_u32_t const *)(b_tile + b_k * tile_dimension))));
376
+ svld1uw_u64(predicate_all_b64x, (nk_u32_t const *)(b_tile + b_k * tile_dimension))));
379
377
 
380
- svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_f64x, a_f64x, b_f64x);
378
+ svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_b64x, a_f64x, b_f64x);
381
379
  }
382
380
  }
383
381
  }
384
382
 
385
383
  // Store native f64 outputs for the tail column tile.
386
384
  for (nk_size_t row_idx = 0; row_idx < rows_remaining; row_idx++) {
387
- svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 1, row_idx);
385
+ svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 1, row_idx);
388
386
  nk_f64_t *c_row = c + (row_start + row_idx) * c_stride_elements + column_start;
389
- svst1_f64(column_predicate_f64x, c_row, za_row_f64x);
387
+ svst1_f64(column_predicate_b64x, c_row, za_row_f64x);
390
388
  }
391
389
  }
392
390
  }
393
391
  }
394
392
 
395
393
  NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
396
- nk_size_t columns, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
394
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
395
+ nk_size_t c_stride_in_bytes) {
397
396
 
398
- nk_size_t const a_stride_elements = a_stride / sizeof(nk_f32_t);
399
- nk_size_t const c_stride_elements = c_stride / sizeof(nk_f64_t);
397
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
398
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
400
399
 
401
400
  nk_dots_packed_f32_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
402
401
  }
@@ -408,30 +407,32 @@ NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed
408
407
  * per column tile. Eliminates all scalar B-packing loops.
409
408
  */
410
409
  __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64_streaming_(
411
- nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
410
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
412
411
  nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
413
412
 
414
413
  nk_size_t const tile_dimension = svcntd(); // 8 for SVL=512
415
414
  nk_size_t const depth_tile_size = svcntw(); // 16 for SVL=512
416
415
  nk_size_t const depth_steps_per_batch = tile_dimension; // 8
417
416
 
418
- svbool_t const predicate_all_f64x = svptrue_b64();
417
+ svbool_t const predicate_all_b64x = svptrue_b64();
419
418
 
420
419
  NK_ALIGN64 nk_f64_t a_buffer[8][8];
421
420
 
422
421
  nk_size_t const row_end = row_start + row_count;
423
- nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dimension);
422
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dimension);
424
423
  nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
425
424
 
426
- for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
425
+ for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
427
426
  row_tile_start += tile_dimension) {
428
427
  nk_size_t const rows_clamped = (row_tile_start + tile_dimension <= row_end) ? tile_dimension
429
428
  : (row_end - row_tile_start);
430
- nk_size_t const rows_actual = (row_tile_start + rows_clamped <= n_vectors) ? rows_clamped
431
- : (n_vectors - row_tile_start);
432
- svbool_t const row_predicate_f64x = svwhilelt_b64_u64(0u, rows_actual);
429
+ nk_size_t const rows_actual = (row_tile_start + rows_clamped <= vectors_count)
430
+ ? rows_clamped
431
+ : (vectors_count - row_tile_start);
432
+ svbool_t const row_predicate_b64x = svwhilelt_b64_u64(0u, rows_actual);
433
433
 
434
- nk_size_t column_tile_index = 0;
434
+ // Upper triangle: start from this row tile's column
435
+ nk_size_t column_tile_index = row_tile_start / tile_dimension;
435
436
 
436
437
  // Fast path: 7 column tiles at a time
437
438
  for (; column_tile_index + 7 <= column_tile_count; column_tile_index += 7) {
@@ -451,209 +452,208 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64
451
452
  if (depth_offset + depth_batch_start >= depth) break;
452
453
 
453
454
  // ZA transpose for A rows: extending load f32→f64, MOVA directly into ZA0
454
- svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, (uint64_t)batch_size);
455
- svbool_t const a_depth_predicate_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
456
- (uint64_t)depth);
455
+ svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
456
+ svbool_t const a_depth_predicate_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start, depth);
457
457
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
458
458
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_actual; row_in_tile++) {
459
459
  nk_size_t const row_abs = row_tile_start + row_in_tile;
460
460
  svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
461
- batch_predicate_f64x,
461
+ batch_predicate_b64x,
462
462
  svreinterpret_f32_u64(svld1uw_u64(
463
- a_depth_predicate_f64x, (nk_u32_t const *)&vectors[row_abs * stride_elements +
463
+ a_depth_predicate_b64x, (nk_u32_t const *)&vectors[row_abs * stride_elements +
464
464
  depth_offset + depth_batch_start])));
465
- svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_f64x, a_row_widened_f64x);
465
+ svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_b64x, a_row_widened_f64x);
466
466
  }
467
467
 
468
468
  // Save A columns from ZA0 to stack buffer
469
469
  for (nk_size_t s = 0; s < batch_size; s++)
470
- svst1_f64(predicate_all_f64x, a_buffer[s],
471
- svread_ver_za64_f64_m(svdup_f64(0), row_predicate_f64x, 0, s));
470
+ svst1_f64(predicate_all_b64x, a_buffer[s],
471
+ svread_ver_za64_f64_m(svdup_f64(0), row_predicate_b64x, 0, s));
472
472
 
473
473
  // Column tile 0 → ZA1 via MOVA
474
474
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
475
475
  for (nk_size_t column = 0; column < tile_dimension; column++) {
476
476
  nk_size_t const column_abs = (column_tile_index + 0) * tile_dimension + column;
477
- if (column_abs < n_vectors) {
477
+ if (column_abs < vectors_count) {
478
478
  svfloat64_t widened_f64x = svcvt_f64_f32_x(
479
- batch_predicate_f64x,
479
+ batch_predicate_b64x,
480
480
  svreinterpret_f32_u64(svld1uw_u64(
481
- a_depth_predicate_f64x,
481
+ a_depth_predicate_b64x,
482
482
  (nk_u32_t const
483
483
  *)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
484
- svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
484
+ svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
485
485
  }
486
486
  }
487
487
  for (nk_size_t step = 0; step < batch_size; step++) {
488
- svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
489
- svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
490
- svmopa_za64_f64_m(1, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
488
+ svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
489
+ svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
490
+ svmopa_za64_f64_m(1, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
491
491
  }
492
492
 
493
493
  // Column tile 1 → ZA2 via MOVA
494
494
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
495
495
  for (nk_size_t column = 0; column < tile_dimension; column++) {
496
496
  nk_size_t const column_abs = (column_tile_index + 1) * tile_dimension + column;
497
- if (column_abs < n_vectors) {
497
+ if (column_abs < vectors_count) {
498
498
  svfloat64_t widened_f64x = svcvt_f64_f32_x(
499
- batch_predicate_f64x,
499
+ batch_predicate_b64x,
500
500
  svreinterpret_f32_u64(svld1uw_u64(
501
- a_depth_predicate_f64x,
501
+ a_depth_predicate_b64x,
502
502
  (nk_u32_t const
503
503
  *)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
504
- svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
504
+ svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
505
505
  }
506
506
  }
507
507
  for (nk_size_t step = 0; step < batch_size; step++) {
508
- svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
509
- svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
510
- svmopa_za64_f64_m(2, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
508
+ svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
509
+ svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
510
+ svmopa_za64_f64_m(2, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
511
511
  }
512
512
 
513
513
  // Column tile 2 → ZA3 via MOVA
514
514
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
515
515
  for (nk_size_t column = 0; column < tile_dimension; column++) {
516
516
  nk_size_t const column_abs = (column_tile_index + 2) * tile_dimension + column;
517
- if (column_abs < n_vectors) {
517
+ if (column_abs < vectors_count) {
518
518
  svfloat64_t widened_f64x = svcvt_f64_f32_x(
519
- batch_predicate_f64x,
519
+ batch_predicate_b64x,
520
520
  svreinterpret_f32_u64(svld1uw_u64(
521
- a_depth_predicate_f64x,
521
+ a_depth_predicate_b64x,
522
522
  (nk_u32_t const
523
523
  *)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
524
- svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
524
+ svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
525
525
  }
526
526
  }
527
527
  for (nk_size_t step = 0; step < batch_size; step++) {
528
- svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
529
- svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
530
- svmopa_za64_f64_m(3, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
528
+ svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
529
+ svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
530
+ svmopa_za64_f64_m(3, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
531
531
  }
532
532
 
533
533
  // Column tile 3 → ZA4 via MOVA
534
534
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
535
535
  for (nk_size_t column = 0; column < tile_dimension; column++) {
536
536
  nk_size_t const column_abs = (column_tile_index + 3) * tile_dimension + column;
537
- if (column_abs < n_vectors) {
537
+ if (column_abs < vectors_count) {
538
538
  svfloat64_t widened_f64x = svcvt_f64_f32_x(
539
- batch_predicate_f64x,
539
+ batch_predicate_b64x,
540
540
  svreinterpret_f32_u64(svld1uw_u64(
541
- a_depth_predicate_f64x,
541
+ a_depth_predicate_b64x,
542
542
  (nk_u32_t const
543
543
  *)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
544
- svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
544
+ svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
545
545
  }
546
546
  }
547
547
  for (nk_size_t step = 0; step < batch_size; step++) {
548
- svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
549
- svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
550
- svmopa_za64_f64_m(4, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
548
+ svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
549
+ svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
550
+ svmopa_za64_f64_m(4, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
551
551
  }
552
552
 
553
553
  // Column tile 4 → ZA5 via MOVA
554
554
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
555
555
  for (nk_size_t column = 0; column < tile_dimension; column++) {
556
556
  nk_size_t const column_abs = (column_tile_index + 4) * tile_dimension + column;
557
- if (column_abs < n_vectors) {
557
+ if (column_abs < vectors_count) {
558
558
  svfloat64_t widened_f64x = svcvt_f64_f32_x(
559
- batch_predicate_f64x,
559
+ batch_predicate_b64x,
560
560
  svreinterpret_f32_u64(svld1uw_u64(
561
- a_depth_predicate_f64x,
561
+ a_depth_predicate_b64x,
562
562
  (nk_u32_t const
563
563
  *)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
564
- svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
564
+ svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
565
565
  }
566
566
  }
567
567
  for (nk_size_t step = 0; step < batch_size; step++) {
568
- svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
569
- svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
570
- svmopa_za64_f64_m(5, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
568
+ svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
569
+ svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
570
+ svmopa_za64_f64_m(5, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
571
571
  }
572
572
 
573
573
  // Column tile 5 → ZA6 via MOVA
574
574
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
575
575
  for (nk_size_t column = 0; column < tile_dimension; column++) {
576
576
  nk_size_t const column_abs = (column_tile_index + 5) * tile_dimension + column;
577
- if (column_abs < n_vectors) {
577
+ if (column_abs < vectors_count) {
578
578
  svfloat64_t widened_f64x = svcvt_f64_f32_x(
579
- batch_predicate_f64x,
579
+ batch_predicate_b64x,
580
580
  svreinterpret_f32_u64(svld1uw_u64(
581
- a_depth_predicate_f64x,
581
+ a_depth_predicate_b64x,
582
582
  (nk_u32_t const
583
583
  *)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
584
- svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
584
+ svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
585
585
  }
586
586
  }
587
587
  for (nk_size_t step = 0; step < batch_size; step++) {
588
- svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
589
- svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
590
- svmopa_za64_f64_m(6, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
588
+ svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
589
+ svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
590
+ svmopa_za64_f64_m(6, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
591
591
  }
592
592
 
593
593
  // Column tile 6 → ZA7 via MOVA
594
594
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
595
595
  for (nk_size_t column = 0; column < tile_dimension; column++) {
596
596
  nk_size_t const column_abs = (column_tile_index + 6) * tile_dimension + column;
597
- if (column_abs < n_vectors) {
597
+ if (column_abs < vectors_count) {
598
598
  svfloat64_t widened_f64x = svcvt_f64_f32_x(
599
- batch_predicate_f64x,
599
+ batch_predicate_b64x,
600
600
  svreinterpret_f32_u64(svld1uw_u64(
601
- a_depth_predicate_f64x,
601
+ a_depth_predicate_b64x,
602
602
  (nk_u32_t const
603
603
  *)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
604
- svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
604
+ svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
605
605
  }
606
606
  }
607
607
  for (nk_size_t step = 0; step < batch_size; step++) {
608
- svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
609
- svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 0, step);
610
- svmopa_za64_f64_m(7, row_predicate_f64x, predicate_all_f64x, a_f64x, b_f64x);
608
+ svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
609
+ svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
610
+ svmopa_za64_f64_m(7, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
611
611
  }
612
612
  }
613
613
  }
614
614
 
615
615
  // Extract results and store native f64 outputs.
616
- svbool_t const predicate_tile_f64x = svwhilelt_b64_u64(0u, tile_dimension);
616
+ svbool_t const predicate_tile_b64x = svwhilelt_b64_u64(0u, tile_dimension);
617
617
  // The 7th tile (index 6) may be partial when it's the last column tile
618
618
  nk_size_t const last_fast_col_start = (column_tile_index + 6) * tile_dimension;
619
- nk_size_t const last_fast_cols = (last_fast_col_start + tile_dimension <= n_vectors)
619
+ nk_size_t const last_fast_cols = (last_fast_col_start + tile_dimension <= vectors_count)
620
620
  ? tile_dimension
621
- : (n_vectors - last_fast_col_start);
622
- svbool_t const last_tile_pred_f64x = svwhilelt_b64_u64(0u, last_fast_cols);
621
+ : (vectors_count - last_fast_col_start);
622
+ svbool_t const last_tile_pred_b64x = svwhilelt_b64_u64(0u, last_fast_cols);
623
623
  for (nk_size_t row = 0; row < rows_actual; row++) {
624
624
  nk_size_t const row_abs = row_tile_start + row;
625
625
  nk_f64_t *result_row = result + row_abs * result_stride_elements;
626
626
 
627
- svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 1, row);
628
- svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 0) * tile_dimension, za_row_f64x);
627
+ svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 1, row);
628
+ svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 0) * tile_dimension, za_row_f64x);
629
629
 
630
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 2, row);
631
- svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 1) * tile_dimension, za_row_f64x);
630
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 2, row);
631
+ svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 1) * tile_dimension, za_row_f64x);
632
632
 
633
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 3, row);
634
- svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 2) * tile_dimension, za_row_f64x);
633
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 3, row);
634
+ svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 2) * tile_dimension, za_row_f64x);
635
635
 
636
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 4, row);
637
- svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 3) * tile_dimension, za_row_f64x);
636
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 4, row);
637
+ svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 3) * tile_dimension, za_row_f64x);
638
638
 
639
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 5, row);
640
- svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 4) * tile_dimension, za_row_f64x);
639
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 5, row);
640
+ svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 4) * tile_dimension, za_row_f64x);
641
641
 
642
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 6, row);
643
- svst1_f64(predicate_tile_f64x, result_row + (column_tile_index + 5) * tile_dimension, za_row_f64x);
642
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 6, row);
643
+ svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 5) * tile_dimension, za_row_f64x);
644
644
 
645
- za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 7, row);
646
- svst1_f64(last_tile_pred_f64x, result_row + (column_tile_index + 6) * tile_dimension, za_row_f64x);
645
+ za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 7, row);
646
+ svst1_f64(last_tile_pred_b64x, result_row + (column_tile_index + 6) * tile_dimension, za_row_f64x);
647
647
  }
648
648
  }
649
649
 
650
650
  // Remainder: 1 column tile at a time
651
651
  for (; column_tile_index < column_tile_count; column_tile_index++) {
652
652
  nk_size_t const column_tile_start = column_tile_index * tile_dimension;
653
- nk_size_t const columns_remaining = (column_tile_start + tile_dimension <= n_vectors)
653
+ nk_size_t const columns_remaining = (column_tile_start + tile_dimension <= vectors_count)
654
654
  ? tile_dimension
655
- : (n_vectors - column_tile_start);
656
- svbool_t const column_predicate_f64x = svwhilelt_b64_u64(0u, columns_remaining);
655
+ : (vectors_count - column_tile_start);
656
+ svbool_t const column_predicate_b64x = svwhilelt_b64_u64(0u, columns_remaining);
657
657
 
658
658
  svzero_mask_za(nk_sme_zero_za64_tile_1_);
659
659
 
@@ -669,44 +669,43 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64
669
669
 
670
670
  if (depth_offset + depth_batch_start >= depth) break;
671
671
 
672
- svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, (uint64_t)batch_size);
673
- svbool_t const a_depth_pred_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
674
- (uint64_t)depth);
672
+ svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
673
+ svbool_t const a_depth_pred_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start, depth);
675
674
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
676
675
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_actual; row_in_tile++) {
677
676
  nk_size_t const row_abs = row_tile_start + row_in_tile;
678
677
  svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
679
- batch_predicate_f64x,
678
+ batch_predicate_b64x,
680
679
  svreinterpret_f32_u64(svld1uw_u64(
681
- a_depth_pred_f64x, (nk_u32_t const *)&vectors[row_abs * stride_elements + depth_offset +
680
+ a_depth_pred_b64x, (nk_u32_t const *)&vectors[row_abs * stride_elements + depth_offset +
682
681
  depth_batch_start])));
683
- svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_f64x, a_row_widened_f64x);
682
+ svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_b64x, a_row_widened_f64x);
684
683
  }
685
684
 
686
685
  // Save A columns from ZA0 to stack buffer
687
686
  for (nk_size_t s = 0; s < batch_size; s++)
688
- svst1_f64(predicate_all_f64x, a_buffer[s],
689
- svread_ver_za64_f64_m(svdup_f64(0), row_predicate_f64x, 0, s));
687
+ svst1_f64(predicate_all_b64x, a_buffer[s],
688
+ svread_ver_za64_f64_m(svdup_f64(0), row_predicate_b64x, 0, s));
690
689
 
691
690
  // Load B column tile into ZA0 via MOVA, vertical read + FMOPA into ZA1
692
691
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
693
692
  for (nk_size_t column = 0; column < tile_dimension; column++) {
694
693
  nk_size_t const column_abs = column_tile_start + column;
695
- if (column_abs < n_vectors) {
694
+ if (column_abs < vectors_count) {
696
695
  svfloat64_t widened_f64x = svcvt_f64_f32_x(
697
- batch_predicate_f64x,
696
+ batch_predicate_b64x,
698
697
  svreinterpret_f32_u64(svld1uw_u64(
699
- a_depth_pred_f64x, (nk_u32_t const *)&vectors[column_abs * stride_elements +
698
+ a_depth_pred_b64x, (nk_u32_t const *)&vectors[column_abs * stride_elements +
700
699
  depth_offset + depth_batch_start])));
701
- svwrite_hor_za64_f64_m(0, column, batch_predicate_f64x, widened_f64x);
700
+ svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
702
701
  }
703
702
  }
704
703
  for (nk_size_t step = 0; step < batch_size; step++) {
705
704
  nk_size_t const k_abs = depth_offset + depth_batch_start + step;
706
705
  if (k_abs >= depth) break;
707
- svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
708
- svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), column_predicate_f64x, 0, step);
709
- svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_f64x, a_f64x, b_f64x);
706
+ svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
707
+ svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), column_predicate_b64x, 0, step);
708
+ svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_b64x, a_f64x, b_f64x);
710
709
  }
711
710
  }
712
711
  }
@@ -714,25 +713,26 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64
714
713
  // Store native f64 outputs for the tail column tile.
715
714
  for (nk_size_t row = 0; row < rows_actual; row++) {
716
715
  nk_size_t const row_abs = row_tile_start + row;
717
- svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_f64x, 1, row);
718
- svst1_f64(column_predicate_f64x, result + row_abs * result_stride_elements + column_tile_start,
716
+ svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 1, row);
717
+ svst1_f64(column_predicate_b64x, result + row_abs * result_stride_elements + column_tile_start,
719
718
  za_row_f64x);
720
719
  }
721
720
  }
722
721
  }
723
722
  }
724
723
 
725
- NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
726
- nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
727
- nk_size_t row_start, nk_size_t row_count) {
724
+ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
725
+ nk_size_t stride_in_bytes, nk_f64_t *result,
726
+ nk_size_t result_stride_in_bytes, nk_size_t row_start,
727
+ nk_size_t row_count) {
728
728
 
729
- nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
730
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
731
- nk_dots_symmetric_f32_smef64_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
732
- row_start, row_count);
729
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
730
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
731
+ nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
732
+ result_stride_elements, row_start, row_count);
733
733
  }
734
734
 
735
- #pragma endregion // Single Precision Floats
735
+ #pragma endregion F32 Floats
736
736
 
737
737
  /*
738
738
  * f64 GEMM via 3-way Ozaki splitting using FMOPA with ZA64 tiles.
@@ -768,7 +768,7 @@ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n
768
768
  * - f64 input vectors: 8 elements (SVL/64)
769
769
  * - FMOPA predicates: b64 (native f64 granularity)
770
770
  */
771
- #pragma region Double Precision Floats
771
+ #pragma region F64 Floats
772
772
 
773
773
  /* Mantissa bit masks for 3-way Ozaki splitting of f64 values.
774
774
  *
@@ -783,17 +783,17 @@ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n
783
783
  *
784
784
  * All slices fit in f32 (24-bit significand). Products: max 19+19 = 38 ≤ 53, exact in f64.
785
785
  */
786
- NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void) NK_STREAMING_COMPATIBLE_ {
786
+ NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void) NK_STREAMING_ {
787
787
  return 0xFFFFFFFC00000000ULL; // keep top 19 sig bits
788
788
  }
789
- NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void) NK_STREAMING_COMPATIBLE_ {
789
+ NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void) NK_STREAMING_ {
790
790
  return 0xFFFFFFF000000000ULL; // keep top 17 sig bits
791
791
  }
792
792
 
793
793
  /* Split a scalar f64 into 3 non-overlapping Ozaki slices (19+17+17 mantissa bits).
794
794
  * Each slice fits in f32. Outputs stored via pointers. */
795
795
  NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, nk_f64_t *slice_1,
796
- nk_f64_t *slice_2) NK_STREAMING_COMPATIBLE_ {
796
+ nk_f64_t *slice_2) NK_STREAMING_ {
797
797
  nk_fui64_t pun;
798
798
  pun.f = val;
799
799
  pun.u &= nk_f64_smef64_ozaki_mask_19_bits_();
@@ -806,36 +806,39 @@ NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, n
806
806
  }
807
807
 
808
808
  __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64_streaming_(
809
- nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
809
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
810
810
  nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
811
811
 
812
812
  nk_size_t const tile_dimension = svcntd();
813
813
  nk_size_t const depth_steps_per_batch = tile_dimension;
814
814
 
815
- svbool_t const predicate_all_f64x = svptrue_b64();
815
+ svbool_t const predicate_all_b64x = svptrue_b64();
816
816
  svuint64_t const ozaki_mask_19_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_19_bits_());
817
817
  svuint64_t const ozaki_mask_17_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_17_bits_());
818
818
 
819
819
  NK_ALIGN64 nk_f64_t a_buffer[8][8]; // save A columns before reusing ZA0 for B
820
820
 
821
821
  nk_size_t const row_end = row_start + row_count;
822
- nk_size_t const column_tile_count = nk_size_divide_round_up_(n_vectors, tile_dimension);
822
+ nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dimension);
823
823
 
824
824
  // ZA0.D = staging (A then B), ZA1-3.D = merged Ozaki accumulators (i+j=0,1,2)
825
- for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < n_vectors;
825
+ for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
826
826
  row_tile_start += tile_dimension) {
827
827
  nk_size_t const rows_remaining = (row_tile_start + tile_dimension <= row_end) ? tile_dimension
828
828
  : (row_end - row_tile_start);
829
- nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= n_vectors) ? rows_remaining
830
- : (n_vectors - row_tile_start);
831
- svbool_t const row_predicate_f64x = svwhilelt_b64_u64(0u, rows_clamped);
832
-
833
- for (nk_size_t column_tile_index = 0; column_tile_index < column_tile_count; column_tile_index++) {
829
+ nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= vectors_count)
830
+ ? rows_remaining
831
+ : (vectors_count - row_tile_start);
832
+ svbool_t const row_predicate_b64x = svwhilelt_b64_u64(0u, rows_clamped);
833
+
834
+ // Upper triangle: start from this row tile's column
835
+ for (nk_size_t column_tile_index = row_tile_start / tile_dimension; column_tile_index < column_tile_count;
836
+ column_tile_index++) {
834
837
  nk_size_t const column_tile_start = column_tile_index * tile_dimension;
835
- nk_size_t const columns_remaining = (column_tile_start + tile_dimension <= n_vectors)
838
+ nk_size_t const columns_remaining = (column_tile_start + tile_dimension <= vectors_count)
836
839
  ? tile_dimension
837
- : (n_vectors - column_tile_start);
838
- svbool_t const column_predicate_f64x = svwhilelt_b64_u64(0u, columns_remaining);
840
+ : (vectors_count - column_tile_start);
841
+ svbool_t const column_predicate_b64x = svwhilelt_b64_u64(0u, columns_remaining);
839
842
 
840
843
  // Zero ZA1-3 (3 merged Ozaki accumulators)
841
844
  svzero_mask_za(nk_sme_zero_za64_tiles_1_3_);
@@ -846,67 +849,67 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64
846
849
  ? depth_batch_start + depth_steps_per_batch
847
850
  : depth;
848
851
  nk_size_t const batch_size = depth_batch_end - depth_batch_start;
849
- svbool_t const batch_predicate_f64x = svwhilelt_b64_u64(0u, batch_size);
852
+ svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
850
853
 
851
854
  // Load A rows into ZA0
852
855
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
853
856
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
854
857
  nk_size_t const row_abs = row_tile_start + row_in_tile;
855
- svld1_hor_za64(0, row_in_tile, batch_predicate_f64x,
858
+ svld1_hor_za64(0, row_in_tile, batch_predicate_b64x,
856
859
  vectors + row_abs * stride_elements + depth_batch_start);
857
860
  }
858
861
 
859
862
  // Save A columns to buffer before reusing ZA0 for B
860
863
  for (nk_size_t s = 0; s < batch_size; s++)
861
- svst1_f64(predicate_all_f64x, a_buffer[s],
862
- svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, s));
864
+ svst1_f64(predicate_all_b64x, a_buffer[s],
865
+ svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, s));
863
866
 
864
867
  // Load B columns into ZA0 (reuse)
865
868
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
866
869
  for (nk_size_t column = 0; column < tile_dimension; column++) {
867
870
  nk_size_t const column_abs = column_tile_start + column;
868
- if (column_abs < n_vectors)
869
- svld1_hor_za64(0, column, batch_predicate_f64x,
871
+ if (column_abs < vectors_count)
872
+ svld1_hor_za64(0, column, batch_predicate_b64x,
870
873
  vectors + column_abs * stride_elements + depth_batch_start);
871
874
  }
872
875
 
873
876
  // Split both A and B into 3 Ozaki slices, 6 FMOPAs per step
874
877
  for (nk_size_t step = 0; step < batch_size; step++) {
875
- svfloat64_t a_f64x = svld1_f64(predicate_all_f64x, a_buffer[step]);
878
+ svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
876
879
  svuint64_t a_bits_u64x = svreinterpret_u64_f64(a_f64x);
877
880
  svfloat64_t a_slice_0_f64x = svreinterpret_f64_u64(
878
- svand_u64_x(predicate_all_f64x, a_bits_u64x, ozaki_mask_19_u64x));
879
- svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_f64x, a_f64x, a_slice_0_f64x);
881
+ svand_u64_x(predicate_all_b64x, a_bits_u64x, ozaki_mask_19_u64x));
882
+ svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_b64x, a_f64x, a_slice_0_f64x);
880
883
  svuint64_t residual_a_bits_u64x = svreinterpret_u64_f64(residual_a_f64x);
881
884
  svfloat64_t a_slice_1_f64x = svreinterpret_f64_u64(
882
- svand_u64_x(predicate_all_f64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
883
- svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_f64x, residual_a_f64x, a_slice_1_f64x);
885
+ svand_u64_x(predicate_all_b64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
886
+ svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_b64x, residual_a_f64x, a_slice_1_f64x);
884
887
 
885
- svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), column_predicate_f64x, 0, step);
888
+ svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), column_predicate_b64x, 0, step);
886
889
  svuint64_t b_bits_u64x = svreinterpret_u64_f64(b_f64x);
887
890
  svfloat64_t b_slice_0_f64x = svreinterpret_f64_u64(
888
- svand_u64_x(predicate_all_f64x, b_bits_u64x, ozaki_mask_19_u64x));
889
- svfloat64_t residual_b_f64x = svsub_f64_x(predicate_all_f64x, b_f64x, b_slice_0_f64x);
891
+ svand_u64_x(predicate_all_b64x, b_bits_u64x, ozaki_mask_19_u64x));
892
+ svfloat64_t residual_b_f64x = svsub_f64_x(predicate_all_b64x, b_f64x, b_slice_0_f64x);
890
893
  svuint64_t residual_b_bits_u64x = svreinterpret_u64_f64(residual_b_f64x);
891
894
  svfloat64_t b_slice_1_f64x = svreinterpret_f64_u64(
892
- svand_u64_x(predicate_all_f64x, residual_b_bits_u64x, ozaki_mask_17_u64x));
893
- svfloat64_t b_slice_2_f64x = svsub_f64_x(predicate_all_f64x, residual_b_f64x, b_slice_1_f64x);
895
+ svand_u64_x(predicate_all_b64x, residual_b_bits_u64x, ozaki_mask_17_u64x));
896
+ svfloat64_t b_slice_2_f64x = svsub_f64_x(predicate_all_b64x, residual_b_f64x, b_slice_1_f64x);
894
897
 
895
898
  // 6 FMOPAs reordered to minimize WAW pipeline stalls on 3 tiles.
896
899
  // Same-tile accumulation order preserved (bit-identical output).
897
900
  // Tile schedule: ZA3(0), ZA2(1), ZA1(2), ZA3(4), ZA2(5), ZA3(8).
898
901
  // 9 cycles vs 15 original (3 unavoidable bubbles with only 3 tiles).
899
- svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
902
+ svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
900
903
  b_slice_2_f64x); // ZA3: i+j=2 (1/3)
901
- svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
904
+ svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
902
905
  b_slice_1_f64x); // ZA2: i+j=1 (1/2)
903
- svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
906
+ svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
904
907
  b_slice_0_f64x); // ZA1: i+j=0
905
- svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_1_f64x,
908
+ svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_1_f64x,
906
909
  b_slice_1_f64x); // ZA3: i+j=2 (2/3)
907
- svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_f64x, a_slice_1_f64x,
910
+ svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_b64x, a_slice_1_f64x,
908
911
  b_slice_0_f64x); // ZA2: i+j=1 (2/2)
909
- svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_2_f64x,
912
+ svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_2_f64x,
910
913
  b_slice_0_f64x); // ZA3: i+j=2 (3/3)
911
914
  }
912
915
  }
@@ -914,31 +917,32 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64
914
917
  // Sum ZA3 + ZA2 + ZA1 (smallest to largest)
915
918
  for (nk_size_t row = 0; row < rows_clamped; row++) {
916
919
  nk_size_t const row_abs = row_tile_start + row;
917
- svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 3, row);
918
- result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
919
- svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 2, row));
920
- result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
921
- svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 1, row));
922
- svst1_f64(column_predicate_f64x, result + row_abs * result_stride_elements + column_tile_start,
920
+ svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 3, row);
921
+ result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
922
+ svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 2, row));
923
+ result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
924
+ svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 1, row));
925
+ svst1_f64(column_predicate_b64x, result + row_abs * result_stride_elements + column_tile_start,
923
926
  result_f64x);
924
927
  }
925
928
  }
926
929
  }
927
930
  }
928
931
 
929
- NK_PUBLIC void nk_dots_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
930
- nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
931
- nk_size_t row_start, nk_size_t row_count) {
932
+ NK_PUBLIC void nk_dots_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
933
+ nk_size_t stride_in_bytes, nk_f64_t *result,
934
+ nk_size_t result_stride_in_bytes, nk_size_t row_start,
935
+ nk_size_t row_count) {
932
936
 
933
- nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
934
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
935
- nk_dots_symmetric_f64_smef64_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
936
- row_start, row_count);
937
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
938
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
939
+ nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
940
+ result_stride_elements, row_start, row_count);
937
941
  }
938
942
 
939
943
  NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t columns, nk_size_t depth) {
940
- nk_size_t const tile_dimension = svcntsd();
941
- nk_size_t const depth_tile_size = svcntsw();
944
+ nk_size_t const tile_dimension = nk_sme_cntd_();
945
+ nk_size_t const depth_tile_size = nk_sme_cntw_();
942
946
  nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
943
947
  nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
944
948
  // Single header + interleaved 3-slice data (3× tile_dimension elements per depth step)
@@ -948,13 +952,13 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t columns, nk_size_t
948
952
  return size;
949
953
  }
950
954
 
951
- NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride,
952
- void *b_packed) {
955
+ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_size_t depth,
956
+ nk_size_t b_stride_in_bytes, void *b_packed) {
953
957
 
954
- nk_size_t const b_stride_elements = b_stride / sizeof(nk_f64_t);
958
+ nk_size_t const b_stride_elements = b_stride_in_bytes / sizeof(nk_f64_t);
955
959
 
956
- nk_size_t const tile_dimension = svcntsd();
957
- nk_size_t const depth_tile_size = svcntsw();
960
+ nk_size_t const tile_dimension = nk_sme_cntd_();
961
+ nk_size_t const depth_tile_size = nk_sme_cntw_();
958
962
  nk_size_t const interleaved_stride = 3 * tile_dimension;
959
963
  nk_size_t const interleaved_tile_elements = depth_tile_size * interleaved_stride;
960
964
 
@@ -968,7 +972,7 @@ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_
968
972
  header->depth_tile_count = (nk_u32_t)depth_tile_count;
969
973
  header->columns = (nk_u32_t)columns;
970
974
  header->depth = (nk_u32_t)depth;
971
- header->svl_bytes = (nk_u32_t)svcntsb();
975
+ header->svl_bytes = (nk_u32_t)nk_sme_cntb_();
972
976
 
973
977
  nk_f32_t *tiles = (nk_f32_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
974
978
 
@@ -1009,7 +1013,7 @@ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_
1009
1013
  header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
1010
1014
  nk_f64_t *norms_ptr = (nk_f64_t *)((char *)b_packed + header->norms_offset);
1011
1015
  for (nk_size_t col = 0; col < columns; col++) {
1012
- nk_f64_t const *col_data = (nk_f64_t const *)((char const *)b + col * b_stride);
1016
+ nk_f64_t const *col_data = (nk_f64_t const *)((char const *)b + col * b_stride_in_bytes);
1013
1017
  norms_ptr[col] = nk_dots_reduce_sumsq_f64_(col_data, depth);
1014
1018
  }
1015
1019
  }
@@ -1032,7 +1036,7 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1032
1036
  // B tile data pointer (f32, interleaved slices)
1033
1037
  nk_f32_t const *b_tiles = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_dots_sme_packed_header_t));
1034
1038
 
1035
- svbool_t const predicate_all_f64x = svptrue_b64();
1039
+ svbool_t const predicate_all_b64x = svptrue_b64();
1036
1040
 
1037
1041
  // Mantissa masks for in-register Ozaki splitting (19+17+17 bits)
1038
1042
  svuint64_t const ozaki_mask_19_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_19_bits_());
@@ -1045,7 +1049,7 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1045
1049
  row_tile_index++) {
1046
1050
  nk_size_t const row_start = row_tile_index * tile_dimension;
1047
1051
  nk_size_t const rows_remaining = (row_start + tile_dimension <= rows) ? tile_dimension : (rows - row_start);
1048
- svbool_t const row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);
1052
+ svbool_t const row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
1049
1053
 
1050
1054
  nk_size_t column_tile_index = 0;
1051
1055
 
@@ -1059,8 +1063,8 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1059
1063
  nk_size_t const columns_remaining_1 = (column_start_1 + tile_dimension <= columns)
1060
1064
  ? tile_dimension
1061
1065
  : (columns - column_start_1);
1062
- svbool_t const column_predicate_0_f64x = svwhilelt_b64_u64(0u, columns_remaining_0);
1063
- svbool_t const column_predicate_1_f64x = svwhilelt_b64_u64(0u, columns_remaining_1);
1066
+ svbool_t const column_predicate_0_b64x = svwhilelt_b64_u64(0u, columns_remaining_0);
1067
+ svbool_t const column_predicate_1_b64x = svwhilelt_b64_u64(0u, columns_remaining_1);
1064
1068
 
1065
1069
  // Zero ZA1-6 (3 accumulators × 2 column tiles)
1066
1070
  svzero_mask_za(nk_sme_zero_za64_tiles_1_6_);
@@ -1081,9 +1085,9 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1081
1085
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
1082
1086
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
1083
1087
  nk_size_t const a_row = row_start + row_in_tile;
1084
- svbool_t const a_depth_predicate_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
1085
- (uint64_t)depth);
1086
- svld1_hor_za64(0, row_in_tile, a_depth_predicate_f64x,
1088
+ svbool_t const a_depth_predicate_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
1089
+ depth);
1090
+ svld1_hor_za64(0, row_in_tile, a_depth_predicate_b64x,
1087
1091
  &a[a_row * a_stride_elements + depth_offset + depth_batch_start]);
1088
1092
  }
1089
1093
 
@@ -1100,71 +1104,71 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1100
1104
  if (k_abs >= depth) break;
1101
1105
 
1102
1106
  // Read A column from ZA0 and split into 3 Ozaki slices
1103
- svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, step);
1107
+ svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, step);
1104
1108
  svuint64_t a_bits_u64x = svreinterpret_u64_f64(a_f64x);
1105
1109
  svfloat64_t a_slice_0_f64x = svreinterpret_f64_u64(
1106
- svand_u64_x(predicate_all_f64x, a_bits_u64x, ozaki_mask_19_u64x));
1107
- svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_f64x, a_f64x, a_slice_0_f64x);
1110
+ svand_u64_x(predicate_all_b64x, a_bits_u64x, ozaki_mask_19_u64x));
1111
+ svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_b64x, a_f64x, a_slice_0_f64x);
1108
1112
  svuint64_t residual_a_bits_u64x = svreinterpret_u64_f64(residual_a_f64x);
1109
1113
  svfloat64_t a_slice_1_f64x = svreinterpret_f64_u64(
1110
- svand_u64_x(predicate_all_f64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
1111
- svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_f64x, residual_a_f64x, a_slice_1_f64x);
1114
+ svand_u64_x(predicate_all_b64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
1115
+ svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_b64x, residual_a_f64x, a_slice_1_f64x);
1112
1116
 
1113
1117
  // Load all 6 B slices upfront (3 per column tile) for pipeline interleaving
1114
1118
  nk_size_t const b_tile_offset_0 = b_batch_offset_0 + step * interleaved_stride;
1115
1119
  nk_size_t const b_tile_offset_1 = b_batch_offset_1 + step * interleaved_stride;
1116
1120
  svfloat64_t b_column_0_slice_0_f64x = svcvt_f64_f32_x(
1117
- predicate_all_f64x,
1121
+ predicate_all_b64x,
1118
1122
  svreinterpret_f32_u64(
1119
- svld1uw_u64(predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0))));
1123
+ svld1uw_u64(predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0))));
1120
1124
  svfloat64_t b_column_0_slice_1_f64x = svcvt_f64_f32_x(
1121
- predicate_all_f64x,
1125
+ predicate_all_b64x,
1122
1126
  svreinterpret_f32_u64(svld1uw_u64(
1123
- predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0 + tile_dimension))));
1127
+ predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0 + tile_dimension))));
1124
1128
  svfloat64_t b_column_0_slice_2_f64x = svcvt_f64_f32_x(
1125
- predicate_all_f64x, svreinterpret_f32_u64(svld1uw_u64(
1126
- predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0 +
1129
+ predicate_all_b64x, svreinterpret_f32_u64(svld1uw_u64(
1130
+ predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0 +
1127
1131
  2 * tile_dimension))));
1128
1132
  svfloat64_t b_column_1_slice_0_f64x = svcvt_f64_f32_x(
1129
- predicate_all_f64x,
1133
+ predicate_all_b64x,
1130
1134
  svreinterpret_f32_u64(
1131
- svld1uw_u64(predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1))));
1135
+ svld1uw_u64(predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1))));
1132
1136
  svfloat64_t b_column_1_slice_1_f64x = svcvt_f64_f32_x(
1133
- predicate_all_f64x,
1137
+ predicate_all_b64x,
1134
1138
  svreinterpret_f32_u64(svld1uw_u64(
1135
- predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1 + tile_dimension))));
1139
+ predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1 + tile_dimension))));
1136
1140
  svfloat64_t b_column_1_slice_2_f64x = svcvt_f64_f32_x(
1137
- predicate_all_f64x, svreinterpret_f32_u64(svld1uw_u64(
1138
- predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1 +
1141
+ predicate_all_b64x, svreinterpret_f32_u64(svld1uw_u64(
1142
+ predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1 +
1139
1143
  2 * tile_dimension))));
1140
1144
 
1141
1145
  // 12 FMOPAs interleaved across 6 tiles to eliminate WAW pipeline stalls.
1142
1146
  // Same-tile accumulation order preserved (bit-identical output).
1143
1147
  // Tile gaps: ZA3 at 0,6,10 (6,4); ZA6 at 1,7,11 (6,4); ZA2 at 4,8 (4);
1144
1148
  // ZA5 at 5,9 (4); ZA1 at 2; ZA4 at 3. All gaps >= 4-cycle latency.
1145
- svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_0_f64x, a_slice_0_f64x,
1149
+ svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_0_b64x, a_slice_0_f64x,
1146
1150
  b_column_0_slice_2_f64x); // ZA3: i+j=2 (1/3)
1147
- svmopa_za64_f64_m(6, row_predicate_f64x, column_predicate_1_f64x, a_slice_0_f64x,
1151
+ svmopa_za64_f64_m(6, row_predicate_b64x, column_predicate_1_b64x, a_slice_0_f64x,
1148
1152
  b_column_1_slice_2_f64x); // ZA6: i+j=2 (1/3)
1149
- svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_0_f64x, a_slice_0_f64x,
1153
+ svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_0_b64x, a_slice_0_f64x,
1150
1154
  b_column_0_slice_0_f64x); // ZA1: i+j=0
1151
- svmopa_za64_f64_m(4, row_predicate_f64x, column_predicate_1_f64x, a_slice_0_f64x,
1155
+ svmopa_za64_f64_m(4, row_predicate_b64x, column_predicate_1_b64x, a_slice_0_f64x,
1152
1156
  b_column_1_slice_0_f64x); // ZA4: i+j=0
1153
- svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_0_f64x, a_slice_0_f64x,
1157
+ svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_0_b64x, a_slice_0_f64x,
1154
1158
  b_column_0_slice_1_f64x); // ZA2: i+j=1 (1/2)
1155
- svmopa_za64_f64_m(5, row_predicate_f64x, column_predicate_1_f64x, a_slice_0_f64x,
1159
+ svmopa_za64_f64_m(5, row_predicate_b64x, column_predicate_1_b64x, a_slice_0_f64x,
1156
1160
  b_column_1_slice_1_f64x); // ZA5: i+j=1 (1/2)
1157
- svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_0_f64x, a_slice_1_f64x,
1161
+ svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_0_b64x, a_slice_1_f64x,
1158
1162
  b_column_0_slice_1_f64x); // ZA3: i+j=2 (2/3)
1159
- svmopa_za64_f64_m(6, row_predicate_f64x, column_predicate_1_f64x, a_slice_1_f64x,
1163
+ svmopa_za64_f64_m(6, row_predicate_b64x, column_predicate_1_b64x, a_slice_1_f64x,
1160
1164
  b_column_1_slice_1_f64x); // ZA6: i+j=2 (2/3)
1161
- svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_0_f64x, a_slice_1_f64x,
1165
+ svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_0_b64x, a_slice_1_f64x,
1162
1166
  b_column_0_slice_0_f64x); // ZA2: i+j=1 (2/2)
1163
- svmopa_za64_f64_m(5, row_predicate_f64x, column_predicate_1_f64x, a_slice_1_f64x,
1167
+ svmopa_za64_f64_m(5, row_predicate_b64x, column_predicate_1_b64x, a_slice_1_f64x,
1164
1168
  b_column_1_slice_0_f64x); // ZA5: i+j=1 (2/2)
1165
- svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_0_f64x, a_slice_2_f64x,
1169
+ svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_0_b64x, a_slice_2_f64x,
1166
1170
  b_column_0_slice_0_f64x); // ZA3: i+j=2 (3/3)
1167
- svmopa_za64_f64_m(6, row_predicate_f64x, column_predicate_1_f64x, a_slice_2_f64x,
1171
+ svmopa_za64_f64_m(6, row_predicate_b64x, column_predicate_1_b64x, a_slice_2_f64x,
1168
1172
  b_column_1_slice_0_f64x); // ZA6: i+j=2 (3/3)
1169
1173
  }
1170
1174
  }
@@ -1173,23 +1177,23 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1173
1177
  // Simple summation for col tile 0: ZA3 + ZA2 + ZA1 (smallest to largest)
1174
1178
  for (nk_size_t row = 0; row < rows_remaining; row++) {
1175
1179
  nk_f64_t *c_row = c + (row_start + row) * c_stride_elements + column_start_0;
1176
- svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 3, row);
1177
- result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
1178
- svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 2, row));
1179
- result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
1180
- svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 1, row));
1181
- svst1_f64(column_predicate_0_f64x, c_row, result_f64x);
1180
+ svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 3, row);
1181
+ result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
1182
+ svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 2, row));
1183
+ result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
1184
+ svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 1, row));
1185
+ svst1_f64(column_predicate_0_b64x, c_row, result_f64x);
1182
1186
  }
1183
1187
 
1184
1188
  // Simple summation for col tile 1: ZA6 + ZA5 + ZA4 (smallest to largest)
1185
1189
  for (nk_size_t row = 0; row < rows_remaining; row++) {
1186
1190
  nk_f64_t *c_row = c + (row_start + row) * c_stride_elements + column_start_1;
1187
- svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 6, row);
1188
- result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
1189
- svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 5, row));
1190
- result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
1191
- svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 4, row));
1192
- svst1_f64(column_predicate_1_f64x, c_row, result_f64x);
1191
+ svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 6, row);
1192
+ result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
1193
+ svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 5, row));
1194
+ result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
1195
+ svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 4, row));
1196
+ svst1_f64(column_predicate_1_b64x, c_row, result_f64x);
1193
1197
  }
1194
1198
  }
1195
1199
 
@@ -1198,7 +1202,7 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1198
1202
  nk_size_t const column_start = column_tile_index * tile_dimension;
1199
1203
  nk_size_t const columns_remaining = (column_start + tile_dimension <= columns) ? tile_dimension
1200
1204
  : (columns - column_start);
1201
- svbool_t const column_predicate_f64x = svwhilelt_b64_u64(0u, columns_remaining);
1205
+ svbool_t const column_predicate_b64x = svwhilelt_b64_u64(0u, columns_remaining);
1202
1206
 
1203
1207
  // Zero ZA1-3 (3 merged accumulators)
1204
1208
  svzero_mask_za(nk_sme_zero_za64_tiles_1_3_);
@@ -1219,9 +1223,9 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1219
1223
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
1220
1224
  for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
1221
1225
  nk_size_t const a_row = row_start + row_in_tile;
1222
- svbool_t const a_depth_predicate_f64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
1223
- (uint64_t)depth);
1224
- svld1_hor_za64(0, row_in_tile, a_depth_predicate_f64x,
1226
+ svbool_t const a_depth_predicate_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
1227
+ depth);
1228
+ svld1_hor_za64(0, row_in_tile, a_depth_predicate_b64x,
1225
1229
  &a[a_row * a_stride_elements + depth_offset + depth_batch_start]);
1226
1230
  }
1227
1231
 
@@ -1234,45 +1238,45 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1234
1238
  if (k_abs >= depth) break;
1235
1239
 
1236
1240
  // Read A column from ZA0 and split into 3 Ozaki slices
1237
- svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, step);
1241
+ svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, step);
1238
1242
  svuint64_t a_bits_u64x = svreinterpret_u64_f64(a_f64x);
1239
1243
  svfloat64_t a_slice_0_f64x = svreinterpret_f64_u64(
1240
- svand_u64_x(predicate_all_f64x, a_bits_u64x, ozaki_mask_19_u64x));
1241
- svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_f64x, a_f64x, a_slice_0_f64x);
1244
+ svand_u64_x(predicate_all_b64x, a_bits_u64x, ozaki_mask_19_u64x));
1245
+ svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_b64x, a_f64x, a_slice_0_f64x);
1242
1246
  svuint64_t residual_a_bits_u64x = svreinterpret_u64_f64(residual_a_f64x);
1243
1247
  svfloat64_t a_slice_1_f64x = svreinterpret_f64_u64(
1244
- svand_u64_x(predicate_all_f64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
1245
- svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_f64x, residual_a_f64x, a_slice_1_f64x);
1248
+ svand_u64_x(predicate_all_b64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
1249
+ svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_b64x, residual_a_f64x, a_slice_1_f64x);
1246
1250
 
1247
1251
  // Load 3 B slices (contiguous in interleaved layout)
1248
1252
  nk_size_t const b_tile_offset = b_batch_offset + step * interleaved_stride;
1249
1253
  svfloat64_t b_slice_0_f64x = svcvt_f64_f32_x(
1250
- predicate_all_f64x, svreinterpret_f32_u64(svld1uw_u64(
1251
- predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset))));
1254
+ predicate_all_b64x, svreinterpret_f32_u64(svld1uw_u64(
1255
+ predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset))));
1252
1256
  svfloat64_t b_slice_1_f64x = svcvt_f64_f32_x(
1253
- predicate_all_f64x,
1257
+ predicate_all_b64x,
1254
1258
  svreinterpret_f32_u64(svld1uw_u64(
1255
- predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset + tile_dimension))));
1259
+ predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset + tile_dimension))));
1256
1260
  svfloat64_t b_slice_2_f64x = svcvt_f64_f32_x(
1257
- predicate_all_f64x,
1261
+ predicate_all_b64x,
1258
1262
  svreinterpret_f32_u64(svld1uw_u64(
1259
- predicate_all_f64x, (nk_u32_t const *)(b_tiles + b_tile_offset + 2 * tile_dimension))));
1263
+ predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset + 2 * tile_dimension))));
1260
1264
 
1261
1265
  // 6 FMOPAs reordered to minimize WAW pipeline stalls on 3 tiles.
1262
1266
  // Same-tile accumulation order preserved (bit-identical output).
1263
1267
  // Tile schedule: ZA3(0), ZA2(1), ZA1(2), ZA3(4), ZA2(5), ZA3(8).
1264
1268
  // 9 cycles vs 15 original (3 unavoidable bubbles with only 3 tiles).
1265
- svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
1269
+ svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
1266
1270
  b_slice_2_f64x); // ZA3: i+j=2 (1/3)
1267
- svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
1271
+ svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
1268
1272
  b_slice_1_f64x); // ZA2: i+j=1 (1/2)
1269
- svmopa_za64_f64_m(1, row_predicate_f64x, column_predicate_f64x, a_slice_0_f64x,
1273
+ svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
1270
1274
  b_slice_0_f64x); // ZA1: i+j=0
1271
- svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_1_f64x,
1275
+ svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_1_f64x,
1272
1276
  b_slice_1_f64x); // ZA3: i+j=2 (2/3)
1273
- svmopa_za64_f64_m(2, row_predicate_f64x, column_predicate_f64x, a_slice_1_f64x,
1277
+ svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_b64x, a_slice_1_f64x,
1274
1278
  b_slice_0_f64x); // ZA2: i+j=1 (2/2)
1275
- svmopa_za64_f64_m(3, row_predicate_f64x, column_predicate_f64x, a_slice_2_f64x,
1279
+ svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_2_f64x,
1276
1280
  b_slice_0_f64x); // ZA3: i+j=2 (3/3)
1277
1281
  }
1278
1282
  }
@@ -1281,27 +1285,28 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1281
1285
  // Simple summation: ZA3 + ZA2 + ZA1 (smallest to largest)
1282
1286
  for (nk_size_t row = 0; row < rows_remaining; row++) {
1283
1287
  nk_f64_t *c_row = c + (row_start + row) * c_stride_elements + column_start;
1284
- svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 3, row);
1285
- result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
1286
- svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 2, row));
1287
- result_f64x = svadd_f64_x(predicate_all_f64x, result_f64x,
1288
- svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_f64x, 1, row));
1289
- svst1_f64(column_predicate_f64x, c_row, result_f64x);
1288
+ svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 3, row);
1289
+ result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
1290
+ svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 2, row));
1291
+ result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
1292
+ svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 1, row));
1293
+ svst1_f64(column_predicate_b64x, c_row, result_f64x);
1290
1294
  }
1291
1295
  }
1292
1296
  }
1293
1297
  }
1294
1298
 
1295
1299
  NK_PUBLIC void nk_dots_packed_f64_smef64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
1296
- nk_size_t columns, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
1300
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
1301
+ nk_size_t c_stride_in_bytes) {
1297
1302
 
1298
- nk_size_t const a_stride_elements = a_stride / sizeof(nk_f64_t);
1299
- nk_size_t const c_stride_elements = c_stride / sizeof(nk_f64_t);
1303
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f64_t);
1304
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
1300
1305
 
1301
1306
  nk_dots_packed_f64_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1302
1307
  }
1303
1308
 
1304
- #pragma endregion // Double Precision Floats
1309
+ #pragma endregion F64 Floats
1305
1310
 
1306
1311
  #if defined(__clang__)
1307
1312
  #pragma clang attribute pop