numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -61,7 +61,7 @@ extern "C" {
61
61
  #endif
62
62
 
63
63
  #if defined(__clang__)
64
- #pragma clang attribute push(__attribute__((target("sme,sve,sme-f64f64"))), apply_to = function)
64
+ #pragma clang attribute push(__attribute__((target("sme,sme-f64f64"))), apply_to = function)
65
65
  #elif defined(__GNUC__)
66
66
  #pragma GCC push_options
67
67
  #pragma GCC target("+sme+sme-f64f64")
@@ -71,122 +71,123 @@ extern "C" {
71
71
  * @brief SVE Dot2 accumulator: sum += a × b with error compensation.
72
72
  * Uses TwoProd (svneg+svnmls) and TwoSum error-free transformations.
73
73
  */
74
- NK_PUBLIC void nk_dot2_f64_sve_accumulate_(svbool_t predicate_f64x, svfloat64_t *sum, svfloat64_t *comp,
75
- svfloat64_t a_f64x, svfloat64_t b_f64x) NK_STREAMING_COMPATIBLE_ {
76
- svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_f64x, b_f64x);
77
- svfloat64_t product_error_f64x = svneg_f64_x(predicate_f64x,
78
- svnmls_f64_x(predicate_f64x, product_f64x, a_f64x, b_f64x));
79
- svfloat64_t running_sum_f64x = svadd_f64_x(predicate_f64x, *sum, product_f64x);
80
- svfloat64_t recovered_addend_f64x = svsub_f64_x(predicate_f64x, running_sum_f64x, *sum);
74
+ NK_PUBLIC void nk_dot2_f64_sve_accumulate_(svbool_t predicate_b64x, svfloat64_t *sum, svfloat64_t *comp,
75
+ svfloat64_t a_f64x, svfloat64_t b_f64x) NK_STREAMING_ {
76
+ svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_f64x, b_f64x);
77
+ svfloat64_t product_error_f64x = svneg_f64_x(predicate_b64x,
78
+ svnmls_f64_x(predicate_b64x, product_f64x, a_f64x, b_f64x));
79
+ svfloat64_t running_sum_f64x = svadd_f64_m(predicate_b64x, *sum, product_f64x);
80
+ svfloat64_t recovered_addend_f64x = svsub_f64_x(predicate_b64x, running_sum_f64x, *sum);
81
81
  svfloat64_t sum_error_f64x = svadd_f64_x(
82
- predicate_f64x,
83
- svsub_f64_x(predicate_f64x, *sum, svsub_f64_x(predicate_f64x, running_sum_f64x, recovered_addend_f64x)),
84
- svsub_f64_x(predicate_f64x, product_f64x, recovered_addend_f64x));
82
+ predicate_b64x,
83
+ svsub_f64_x(predicate_b64x, *sum, svsub_f64_x(predicate_b64x, running_sum_f64x, recovered_addend_f64x)),
84
+ svsub_f64_x(predicate_b64x, product_f64x, recovered_addend_f64x));
85
85
  *sum = running_sum_f64x;
86
- *comp = svadd_f64_x(predicate_f64x, *comp, svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
86
+ *comp = svadd_f64_m(predicate_b64x, *comp, svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
87
87
  }
88
88
 
89
89
  /**
90
90
  * @brief f32 bilinear: GEMV via FMOPA (widening f32→f64, exact accumulation).
91
91
  * ZA0.D = C staging, ZA1.D = GEMV accumulator.
92
92
  */
93
- __arm_locally_streaming __arm_new("za") static void nk_bilinear_f32_smef64_streaming_(nk_f32_t const *a,
94
- nk_f32_t const *b,
95
- nk_f32_t const *c, nk_size_t n,
96
- nk_f64_t *result) {
97
- svbool_t predicate_body_f64x = svptrue_b64();
93
+ __arm_locally_streaming __arm_new("za") static void nk_bilinear_f32_smef64_streaming_(
94
+ nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions, nk_f64_t *result) {
95
+ svbool_t predicate_body_b64x = svptrue_b64();
98
96
  nk_size_t tile_dimension = svcntd();
99
97
  nk_f64_t outer_sum_f64 = 0.0;
100
98
 
101
- for (nk_size_t row = 0; row < n; row += tile_dimension) {
102
- nk_size_t rows_remaining = (row + tile_dimension <= n) ? tile_dimension : (n - row);
103
- svbool_t row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);
99
+ for (nk_size_t row = 0; row < dimensions; row += tile_dimension) {
100
+ nk_size_t rows_remaining = (row + tile_dimension <= dimensions) ? tile_dimension : (dimensions - row);
101
+ svbool_t row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
104
102
 
105
103
  svzero_mask_za(nk_sme_zero_za64_tile_1_);
106
104
 
107
- for (nk_size_t j = 0; j < n; j += tile_dimension) {
108
- nk_size_t batch_size = (j + tile_dimension <= n) ? tile_dimension : (n - j);
109
- svbool_t batch_predicate_f64x = svwhilelt_b64_u64(0u, batch_size);
105
+ for (nk_size_t j = 0; j < dimensions; j += tile_dimension) {
106
+ nk_size_t batch_size = (j + tile_dimension <= dimensions) ? tile_dimension : (dimensions - j);
107
+ svbool_t batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
110
108
 
111
109
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
112
110
  for (nk_size_t r = 0; r < rows_remaining; r++) {
113
111
  svfloat64_t c_row_f64x = svcvt_f64_f32_x(
114
- batch_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(
115
- batch_predicate_f64x, (nk_u32_t const *)(c + (row + r) * n + j))));
116
- svwrite_hor_za64_f64_m(0, r, batch_predicate_f64x, c_row_f64x);
112
+ batch_predicate_b64x,
113
+ svreinterpret_f32_u64(
114
+ svld1uw_u64(batch_predicate_b64x, (nk_u32_t const *)(c + (row + r) * dimensions + j))));
115
+ svwrite_hor_za64_f64_m(0, r, batch_predicate_b64x, c_row_f64x);
117
116
  }
118
117
 
119
118
  for (nk_size_t k = 0; k < batch_size; k++) {
120
- svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, k);
121
- svmopa_za64_f64_m(1, row_predicate_f64x, row_predicate_f64x, c_col_f64x, svdup_f64((nk_f64_t)b[j + k]));
119
+ svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, k);
120
+ svmopa_za64_f64_m(1, row_predicate_b64x, row_predicate_b64x, c_col_f64x, svdup_f64((nk_f64_t)b[j + k]));
122
121
  }
123
122
  }
124
123
 
125
- svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 1, 0);
124
+ svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 1, 0);
126
125
  svfloat64_t a_f64x = svcvt_f64_f32_x(
127
- row_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_f64x, (nk_u32_t const *)(a + row))));
128
- outer_sum_f64 += svaddv_f64(predicate_body_f64x, svmul_f64_x(row_predicate_f64x, a_f64x, v_f64x));
126
+ row_predicate_b64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_b64x, (nk_u32_t const *)(a + row))));
127
+ outer_sum_f64 += svaddv_f64(predicate_body_b64x, svmul_f64_x(row_predicate_b64x, a_f64x, v_f64x));
129
128
  }
130
129
 
131
130
  *result = outer_sum_f64;
132
131
  }
133
132
 
134
- NK_PUBLIC void nk_bilinear_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
133
+ NK_PUBLIC void nk_bilinear_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions,
135
134
  nk_f64_t *result) {
136
- nk_bilinear_f32_smef64_streaming_(a, b, c, n, result);
135
+ nk_bilinear_f32_smef64_streaming_(a, b, c, dimensions, result);
137
136
  }
138
137
 
139
138
  /**
140
139
  * @brief f32 Mahalanobis: GEMV v = C×d via FMOPA, where d = a − b (exact in f64).
141
140
  * ZA0.D = C staging, ZA1.D = GEMV accumulator.
142
141
  */
143
- __arm_locally_streaming __arm_new("za") static inline nk_f64_t
144
- nk_mahalanobis_f32_smef64_streaming_(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n) {
142
+ __arm_locally_streaming __arm_new("za") static nk_f64_t
143
+ nk_mahalanobis_f32_smef64_streaming_(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c,
144
+ nk_size_t dimensions) {
145
145
 
146
- svbool_t predicate_body_f64x = svptrue_b64();
146
+ svbool_t predicate_body_b64x = svptrue_b64();
147
147
  nk_size_t tile_dimension = svcntd();
148
148
  nk_f64_t outer_sum_f64 = 0.0;
149
149
 
150
- for (nk_size_t row = 0; row < n; row += tile_dimension) {
151
- nk_size_t rows_remaining = (row + tile_dimension <= n) ? tile_dimension : (n - row);
152
- svbool_t row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);
150
+ for (nk_size_t row = 0; row < dimensions; row += tile_dimension) {
151
+ nk_size_t rows_remaining = (row + tile_dimension <= dimensions) ? tile_dimension : (dimensions - row);
152
+ svbool_t row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
153
153
 
154
154
  svzero_mask_za(nk_sme_zero_za64_tile_1_);
155
155
 
156
- for (nk_size_t j = 0; j < n; j += tile_dimension) {
157
- nk_size_t batch_size = (j + tile_dimension <= n) ? tile_dimension : (n - j);
158
- svbool_t batch_predicate_f64x = svwhilelt_b64_u64(0u, batch_size);
156
+ for (nk_size_t j = 0; j < dimensions; j += tile_dimension) {
157
+ nk_size_t batch_size = (j + tile_dimension <= dimensions) ? tile_dimension : (dimensions - j);
158
+ svbool_t batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
159
159
 
160
160
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
161
161
  for (nk_size_t r = 0; r < rows_remaining; r++) {
162
162
  svfloat64_t c_row_f64x = svcvt_f64_f32_x(
163
- batch_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(
164
- batch_predicate_f64x, (nk_u32_t const *)(c + (row + r) * n + j))));
165
- svwrite_hor_za64_f64_m(0, r, batch_predicate_f64x, c_row_f64x);
163
+ batch_predicate_b64x,
164
+ svreinterpret_f32_u64(
165
+ svld1uw_u64(batch_predicate_b64x, (nk_u32_t const *)(c + (row + r) * dimensions + j))));
166
+ svwrite_hor_za64_f64_m(0, r, batch_predicate_b64x, c_row_f64x);
166
167
  }
167
168
 
168
169
  for (nk_size_t k = 0; k < batch_size; k++) {
169
- svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, k);
170
+ svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, k);
170
171
  nk_f64_t d_k = (nk_f64_t)a[j + k] - (nk_f64_t)b[j + k];
171
- svmopa_za64_f64_m(1, row_predicate_f64x, row_predicate_f64x, c_col_f64x, svdup_f64(d_k));
172
+ svmopa_za64_f64_m(1, row_predicate_b64x, row_predicate_b64x, c_col_f64x, svdup_f64(d_k));
172
173
  }
173
174
  }
174
175
 
175
- svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 1, 0);
176
+ svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 1, 0);
176
177
  svfloat64_t a_f64x = svcvt_f64_f32_x(
177
- row_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_f64x, (nk_u32_t const *)(a + row))));
178
+ row_predicate_b64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_b64x, (nk_u32_t const *)(a + row))));
178
179
  svfloat64_t b_f64x = svcvt_f64_f32_x(
179
- row_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_f64x, (nk_u32_t const *)(b + row))));
180
- svfloat64_t d_f64x = svsub_f64_x(row_predicate_f64x, a_f64x, b_f64x);
181
- outer_sum_f64 += svaddv_f64(predicate_body_f64x, svmul_f64_x(row_predicate_f64x, d_f64x, v_f64x));
180
+ row_predicate_b64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_b64x, (nk_u32_t const *)(b + row))));
181
+ svfloat64_t d_f64x = svsub_f64_x(row_predicate_b64x, a_f64x, b_f64x);
182
+ outer_sum_f64 += svaddv_f64(predicate_body_b64x, svmul_f64_x(row_predicate_b64x, d_f64x, v_f64x));
182
183
  }
183
184
 
184
185
  return outer_sum_f64;
185
186
  }
186
187
 
187
- NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
188
+ NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions,
188
189
  nk_f64_t *result) {
189
- nk_f64_t quadratic = nk_mahalanobis_f32_smef64_streaming_(a, b, c, n);
190
+ nk_f64_t quadratic = nk_mahalanobis_f32_smef64_streaming_(a, b, c, dimensions);
190
191
  *result = nk_f64_sqrt_neon(quadratic > 0 ? quadratic : 0);
191
192
  }
192
193
 
@@ -195,84 +196,84 @@ NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, n
195
196
  * 4-row fast path shares b_f64x loads; 1-row tail for remainder.
196
197
  */
197
198
  __arm_locally_streaming static void nk_bilinear_f64_smef64_streaming_(nk_f64_t const *a, nk_f64_t const *b,
198
- nk_f64_t const *c, nk_size_t n,
199
+ nk_f64_t const *c, nk_size_t dimensions,
199
200
  nk_f64_t *result) {
200
- svbool_t predicate_all_f64x = svptrue_b64();
201
+ svbool_t predicate_all_b64x = svptrue_b64();
201
202
  nk_f64_t outer_sum = 0.0, outer_comp = 0.0;
202
203
  nk_size_t row = 0;
203
204
 
204
205
  // 4-row fast path: share b_f64x load across 4 rows
205
- for (; row + 4 <= n; row += 4) {
206
+ for (; row + 4 <= dimensions; row += 4) {
206
207
  nk_f64_t a0 = a[row + 0], a1 = a[row + 1], a2 = a[row + 2], a3 = a[row + 3];
207
208
  svfloat64_t sum_0_f64x = svdup_f64(0), compensation_0_f64x = svdup_f64(0);
208
209
  svfloat64_t sum_1_f64x = svdup_f64(0), compensation_1_f64x = svdup_f64(0);
209
210
  svfloat64_t sum_2_f64x = svdup_f64(0), compensation_2_f64x = svdup_f64(0);
210
211
  svfloat64_t sum_3_f64x = svdup_f64(0), compensation_3_f64x = svdup_f64(0);
211
212
  nk_size_t j = 0;
212
- svbool_t predicate_f64x = svwhilelt_b64(j, n);
213
-
214
- while (svptest_first(predicate_all_f64x, predicate_f64x)) {
215
- svfloat64_t b_f64x = svld1_f64(predicate_f64x, b + j);
216
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_0_f64x, &compensation_0_f64x,
217
- svld1_f64(predicate_f64x, c + (row + 0) * n + j), b_f64x);
218
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_1_f64x, &compensation_1_f64x,
219
- svld1_f64(predicate_f64x, c + (row + 1) * n + j), b_f64x);
220
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_2_f64x, &compensation_2_f64x,
221
- svld1_f64(predicate_f64x, c + (row + 2) * n + j), b_f64x);
222
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_3_f64x, &compensation_3_f64x,
223
- svld1_f64(predicate_f64x, c + (row + 3) * n + j), b_f64x);
213
+ svbool_t predicate_b64x = svwhilelt_b64(j, dimensions);
214
+
215
+ while (svptest_first(predicate_all_b64x, predicate_b64x)) {
216
+ svfloat64_t b_f64x = svld1_f64(predicate_b64x, b + j);
217
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_0_f64x, &compensation_0_f64x,
218
+ svld1_f64(predicate_b64x, c + (row + 0) * dimensions + j), b_f64x);
219
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_1_f64x, &compensation_1_f64x,
220
+ svld1_f64(predicate_b64x, c + (row + 1) * dimensions + j), b_f64x);
221
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_2_f64x, &compensation_2_f64x,
222
+ svld1_f64(predicate_b64x, c + (row + 2) * dimensions + j), b_f64x);
223
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_3_f64x, &compensation_3_f64x,
224
+ svld1_f64(predicate_b64x, c + (row + 3) * dimensions + j), b_f64x);
224
225
  j += svcntd();
225
- predicate_f64x = svwhilelt_b64(j, n);
226
+ predicate_b64x = svwhilelt_b64(j, dimensions);
226
227
  }
227
228
 
228
- nk_f64_t cb[4] = {
229
- svaddv_f64(predicate_all_f64x, sum_0_f64x) + svaddv_f64(predicate_all_f64x, compensation_0_f64x),
230
- svaddv_f64(predicate_all_f64x, sum_1_f64x) + svaddv_f64(predicate_all_f64x, compensation_1_f64x),
231
- svaddv_f64(predicate_all_f64x, sum_2_f64x) + svaddv_f64(predicate_all_f64x, compensation_2_f64x),
232
- svaddv_f64(predicate_all_f64x, sum_3_f64x) + svaddv_f64(predicate_all_f64x, compensation_3_f64x),
233
- };
234
- nk_f64_t av[4] = {a0, a1, a2, a3};
235
- for (int r = 0; r < 4; ++r) nk_f64_dot2_(&outer_sum, &outer_comp, av[r], cb[r]);
229
+ nk_f64_dot2_(&outer_sum, &outer_comp, a0,
230
+ svaddv_f64(predicate_all_b64x, sum_0_f64x) + svaddv_f64(predicate_all_b64x, compensation_0_f64x));
231
+ nk_f64_dot2_(&outer_sum, &outer_comp, a1,
232
+ svaddv_f64(predicate_all_b64x, sum_1_f64x) + svaddv_f64(predicate_all_b64x, compensation_1_f64x));
233
+ nk_f64_dot2_(&outer_sum, &outer_comp, a2,
234
+ svaddv_f64(predicate_all_b64x, sum_2_f64x) + svaddv_f64(predicate_all_b64x, compensation_2_f64x));
235
+ nk_f64_dot2_(&outer_sum, &outer_comp, a3,
236
+ svaddv_f64(predicate_all_b64x, sum_3_f64x) + svaddv_f64(predicate_all_b64x, compensation_3_f64x));
236
237
  }
237
238
 
238
239
  // 1-row tail
239
- for (; row < n; ++row) {
240
+ for (; row < dimensions; ++row) {
240
241
  svfloat64_t sum_f64x = svdup_f64(0.0), compensation_f64x = svdup_f64(0.0);
241
242
  nk_size_t j = 0;
242
- svbool_t predicate_f64x = svwhilelt_b64(j, n);
243
+ svbool_t predicate_b64x = svwhilelt_b64(j, dimensions);
243
244
 
244
- while (svptest_first(predicate_all_f64x, predicate_f64x)) {
245
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_f64x, &compensation_f64x,
246
- svld1_f64(predicate_f64x, c + row * n + j), svld1_f64(predicate_f64x, b + j));
245
+ while (svptest_first(predicate_all_b64x, predicate_b64x)) {
246
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_f64x, &compensation_f64x,
247
+ svld1_f64(predicate_b64x, c + row * dimensions + j),
248
+ svld1_f64(predicate_b64x, b + j));
247
249
  j += svcntd();
248
- predicate_f64x = svwhilelt_b64(j, n);
250
+ predicate_b64x = svwhilelt_b64(j, dimensions);
249
251
  }
250
252
 
251
- nk_f64_t cb_j = svaddv_f64(predicate_all_f64x, sum_f64x) + svaddv_f64(predicate_all_f64x, compensation_f64x);
253
+ nk_f64_t cb_j = svaddv_f64(predicate_all_b64x, sum_f64x) + svaddv_f64(predicate_all_b64x, compensation_f64x);
252
254
  nk_f64_dot2_(&outer_sum, &outer_comp, a[row], cb_j);
253
255
  }
254
256
 
255
257
  *result = outer_sum + outer_comp;
256
258
  }
257
259
 
258
- NK_PUBLIC void nk_bilinear_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
260
+ NK_PUBLIC void nk_bilinear_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t dimensions,
259
261
  nk_f64_t *result) {
260
- nk_bilinear_f64_smef64_streaming_(a, b, c, n, result);
262
+ nk_bilinear_f64_smef64_streaming_(a, b, c, dimensions, result);
261
263
  }
262
264
 
263
265
  /**
264
266
  * @brief f64 Mahalanobis: row-by-row streaming SVE with Dot2 compensation.
265
267
  * 4-row fast path shares (a−b) column vector; 1-row tail for remainder.
266
268
  */
267
- __arm_locally_streaming static inline nk_f64_t nk_mahalanobis_f64_smef64_streaming_(nk_f64_t const *a,
268
- nk_f64_t const *b,
269
- nk_f64_t const *c, nk_size_t n) {
270
- svbool_t predicate_all_f64x = svptrue_b64();
269
+ __arm_locally_streaming static nk_f64_t nk_mahalanobis_f64_smef64_streaming_(nk_f64_t const *a, nk_f64_t const *b,
270
+ nk_f64_t const *c, nk_size_t dimensions) {
271
+ svbool_t predicate_all_b64x = svptrue_b64();
271
272
  nk_f64_t outer_sum = 0.0, outer_comp = 0.0;
272
273
  nk_size_t row = 0;
273
274
 
274
275
  // 4-row fast path: share (a−b) column vector across 4 rows
275
- for (; row + 4 <= n; row += 4) {
276
+ for (; row + 4 <= dimensions; row += 4) {
276
277
  nk_f64_t d0 = a[row + 0] - b[row + 0], d1 = a[row + 1] - b[row + 1];
277
278
  nk_f64_t d2 = a[row + 2] - b[row + 2], d3 = a[row + 3] - b[row + 3];
278
279
  svfloat64_t sum_0_f64x = svdup_f64(0), compensation_0_f64x = svdup_f64(0);
@@ -280,59 +281,59 @@ __arm_locally_streaming static inline nk_f64_t nk_mahalanobis_f64_smef64_streami
280
281
  svfloat64_t sum_2_f64x = svdup_f64(0), compensation_2_f64x = svdup_f64(0);
281
282
  svfloat64_t sum_3_f64x = svdup_f64(0), compensation_3_f64x = svdup_f64(0);
282
283
  nk_size_t j = 0;
283
- svbool_t predicate_f64x = svwhilelt_b64(j, n);
284
-
285
- while (svptest_first(predicate_all_f64x, predicate_f64x)) {
286
- svfloat64_t diff_col_f64x = svsub_f64_x(predicate_f64x, svld1_f64(predicate_f64x, a + j),
287
- svld1_f64(predicate_f64x, b + j));
288
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_0_f64x, &compensation_0_f64x,
289
- svld1_f64(predicate_f64x, c + (row + 0) * n + j), diff_col_f64x);
290
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_1_f64x, &compensation_1_f64x,
291
- svld1_f64(predicate_f64x, c + (row + 1) * n + j), diff_col_f64x);
292
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_2_f64x, &compensation_2_f64x,
293
- svld1_f64(predicate_f64x, c + (row + 2) * n + j), diff_col_f64x);
294
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_3_f64x, &compensation_3_f64x,
295
- svld1_f64(predicate_f64x, c + (row + 3) * n + j), diff_col_f64x);
284
+ svbool_t predicate_b64x = svwhilelt_b64(j, dimensions);
285
+
286
+ while (svptest_first(predicate_all_b64x, predicate_b64x)) {
287
+ svfloat64_t diff_col_f64x = svsub_f64_x(predicate_b64x, svld1_f64(predicate_b64x, a + j),
288
+ svld1_f64(predicate_b64x, b + j));
289
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_0_f64x, &compensation_0_f64x,
290
+ svld1_f64(predicate_b64x, c + (row + 0) * dimensions + j), diff_col_f64x);
291
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_1_f64x, &compensation_1_f64x,
292
+ svld1_f64(predicate_b64x, c + (row + 1) * dimensions + j), diff_col_f64x);
293
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_2_f64x, &compensation_2_f64x,
294
+ svld1_f64(predicate_b64x, c + (row + 2) * dimensions + j), diff_col_f64x);
295
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_3_f64x, &compensation_3_f64x,
296
+ svld1_f64(predicate_b64x, c + (row + 3) * dimensions + j), diff_col_f64x);
296
297
  j += svcntd();
297
- predicate_f64x = svwhilelt_b64(j, n);
298
+ predicate_b64x = svwhilelt_b64(j, dimensions);
298
299
  }
299
300
 
300
- nk_f64_t cb[4] = {
301
- svaddv_f64(predicate_all_f64x, sum_0_f64x) + svaddv_f64(predicate_all_f64x, compensation_0_f64x),
302
- svaddv_f64(predicate_all_f64x, sum_1_f64x) + svaddv_f64(predicate_all_f64x, compensation_1_f64x),
303
- svaddv_f64(predicate_all_f64x, sum_2_f64x) + svaddv_f64(predicate_all_f64x, compensation_2_f64x),
304
- svaddv_f64(predicate_all_f64x, sum_3_f64x) + svaddv_f64(predicate_all_f64x, compensation_3_f64x),
305
- };
306
- nk_f64_t dv[4] = {d0, d1, d2, d3};
307
- for (int r = 0; r < 4; ++r) nk_f64_dot2_(&outer_sum, &outer_comp, dv[r], cb[r]);
301
+ nk_f64_dot2_(&outer_sum, &outer_comp, d0,
302
+ svaddv_f64(predicate_all_b64x, sum_0_f64x) + svaddv_f64(predicate_all_b64x, compensation_0_f64x));
303
+ nk_f64_dot2_(&outer_sum, &outer_comp, d1,
304
+ svaddv_f64(predicate_all_b64x, sum_1_f64x) + svaddv_f64(predicate_all_b64x, compensation_1_f64x));
305
+ nk_f64_dot2_(&outer_sum, &outer_comp, d2,
306
+ svaddv_f64(predicate_all_b64x, sum_2_f64x) + svaddv_f64(predicate_all_b64x, compensation_2_f64x));
307
+ nk_f64_dot2_(&outer_sum, &outer_comp, d3,
308
+ svaddv_f64(predicate_all_b64x, sum_3_f64x) + svaddv_f64(predicate_all_b64x, compensation_3_f64x));
308
309
  }
309
310
 
310
311
  // 1-row tail
311
- for (; row < n; ++row) {
312
+ for (; row < dimensions; ++row) {
312
313
  nk_f64_t diff_row = a[row] - b[row];
313
314
  svfloat64_t sum_f64x = svdup_f64(0.0), compensation_f64x = svdup_f64(0.0);
314
315
  nk_size_t j = 0;
315
- svbool_t predicate_f64x = svwhilelt_b64(j, n);
316
+ svbool_t predicate_b64x = svwhilelt_b64(j, dimensions);
316
317
 
317
- while (svptest_first(predicate_all_f64x, predicate_f64x)) {
318
- svfloat64_t diff_col_f64x = svsub_f64_x(predicate_f64x, svld1_f64(predicate_f64x, a + j),
319
- svld1_f64(predicate_f64x, b + j));
320
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_f64x, &compensation_f64x,
321
- svld1_f64(predicate_f64x, c + row * n + j), diff_col_f64x);
318
+ while (svptest_first(predicate_all_b64x, predicate_b64x)) {
319
+ svfloat64_t diff_col_f64x = svsub_f64_x(predicate_b64x, svld1_f64(predicate_b64x, a + j),
320
+ svld1_f64(predicate_b64x, b + j));
321
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_f64x, &compensation_f64x,
322
+ svld1_f64(predicate_b64x, c + row * dimensions + j), diff_col_f64x);
322
323
  j += svcntd();
323
- predicate_f64x = svwhilelt_b64(j, n);
324
+ predicate_b64x = svwhilelt_b64(j, dimensions);
324
325
  }
325
326
 
326
- nk_f64_t cb_j = svaddv_f64(predicate_all_f64x, sum_f64x) + svaddv_f64(predicate_all_f64x, compensation_f64x);
327
+ nk_f64_t cb_j = svaddv_f64(predicate_all_b64x, sum_f64x) + svaddv_f64(predicate_all_b64x, compensation_f64x);
327
328
  nk_f64_dot2_(&outer_sum, &outer_comp, diff_row, cb_j);
328
329
  }
329
330
 
330
331
  return outer_sum + outer_comp;
331
332
  }
332
333
 
333
- NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
334
+ NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t dimensions,
334
335
  nk_f64_t *result) {
335
- nk_f64_t quadratic = nk_mahalanobis_f64_smef64_streaming_(a, b, c, n);
336
+ nk_f64_t quadratic = nk_mahalanobis_f64_smef64_streaming_(a, b, c, dimensions);
336
337
  *result = nk_f64_sqrt_neon(quadratic > 0 ? quadratic : 0);
337
338
  }
338
339
 
@@ -340,75 +341,78 @@ NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, n
340
341
  * @brief f32c bilinear: complex GEMV via FMOPA (widening f32→f64).
341
342
  * ZA0.D = C staging, ZA1.D = v_real accumulator, ZA2.D = v_imag accumulator.
342
343
  */
343
- __arm_locally_streaming __arm_new("za") static void nk_bilinear_f32c_smef64_streaming_(
344
- nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_f32c_t const *c_pairs, nk_size_t n, nk_f64c_t *results) {
345
- svbool_t predicate_body_f64x = svptrue_b64();
344
+ __arm_locally_streaming __arm_new("za") static void nk_bilinear_f32c_smef64_streaming_(nk_f32c_t const *a_pairs,
345
+ nk_f32c_t const *b_pairs,
346
+ nk_f32c_t const *c_pairs,
347
+ nk_size_t dimensions,
348
+ nk_f64c_t *results) {
349
+ svbool_t predicate_body_b64x = svptrue_b64();
346
350
  nk_size_t tile_dimension = svcntd();
347
351
  nk_f64_t outer_sum_real_f64 = 0.0, outer_sum_imag_f64 = 0.0;
348
352
 
349
- for (nk_size_t row = 0; row < n; row += tile_dimension) {
350
- nk_size_t rows_remaining = (row + tile_dimension <= n) ? tile_dimension : (n - row);
351
- svbool_t row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);
353
+ for (nk_size_t row = 0; row < dimensions; row += tile_dimension) {
354
+ nk_size_t rows_remaining = (row + tile_dimension <= dimensions) ? tile_dimension : (dimensions - row);
355
+ svbool_t row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
352
356
 
353
357
  svzero_mask_za(nk_sme_zero_za64_tile_1_);
354
358
  svzero_mask_za(nk_sme_zero_za64_tile_2_);
355
359
 
356
- for (nk_size_t j = 0; j < n; j += tile_dimension) {
357
- nk_size_t batch_size = (j + tile_dimension <= n) ? tile_dimension : (n - j);
358
- svbool_t batch_predicate_f64x = svwhilelt_b64_u64(0u, batch_size);
359
- svbool_t batch_predicate_f32x = svwhilelt_b32_u64(0u, batch_size + batch_size);
360
+ for (nk_size_t j = 0; j < dimensions; j += tile_dimension) {
361
+ nk_size_t batch_size = (j + tile_dimension <= dimensions) ? tile_dimension : (dimensions - j);
362
+ svbool_t batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
363
+ svbool_t batch_predicate_b32x = svwhilelt_b32_u64(0u, batch_size + batch_size);
360
364
 
361
365
  // Pass 1: Stage C_real into ZA0
362
366
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
363
367
  for (nk_size_t r = 0; r < rows_remaining; r++) {
364
- svfloat32_t c_f32x = svld1_f32(batch_predicate_f32x,
365
- (nk_f32_t const *)c_pairs + ((row + r) * n + j) * 2);
366
- svfloat64_t c_real_f64x = svcvt_f64_f32_x(batch_predicate_f64x, svtrn1_f32(c_f32x, c_f32x));
367
- svwrite_hor_za64_f64_m(0, r, batch_predicate_f64x, c_real_f64x);
368
+ svfloat32_t c_f32x = svld1_f32(batch_predicate_b32x,
369
+ (nk_f32_t const *)c_pairs + ((row + r) * dimensions + j) * 2);
370
+ svfloat64_t c_real_f64x = svcvt_f64_f32_x(batch_predicate_b64x, svtrn1_f32(c_f32x, c_f32x));
371
+ svwrite_hor_za64_f64_m(0, r, batch_predicate_b64x, c_real_f64x);
368
372
  }
369
373
 
370
374
  for (nk_size_t k = 0; k < batch_size; k++) {
371
- svfloat64_t c_re_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, k);
372
- svmopa_za64_f64_m(1, row_predicate_f64x, row_predicate_f64x, c_re_col_f64x,
375
+ svfloat64_t c_re_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, k);
376
+ svmopa_za64_f64_m(1, row_predicate_b64x, row_predicate_b64x, c_re_col_f64x,
373
377
  svdup_f64((nk_f64_t)b_pairs[j + k].real)); // v_real += c_real × b_real
374
- svmopa_za64_f64_m(2, row_predicate_f64x, row_predicate_f64x, c_re_col_f64x,
378
+ svmopa_za64_f64_m(2, row_predicate_b64x, row_predicate_b64x, c_re_col_f64x,
375
379
  svdup_f64((nk_f64_t)b_pairs[j + k].imag)); // v_imag += c_real × b_imag
376
380
  }
377
381
 
378
382
  // Pass 2: Stage C_imag into ZA0
379
383
  svzero_mask_za(nk_sme_zero_za64_tile_0_);
380
384
  for (nk_size_t r = 0; r < rows_remaining; r++) {
381
- svfloat32_t c_f32x = svld1_f32(batch_predicate_f32x,
382
- (nk_f32_t const *)c_pairs + ((row + r) * n + j) * 2);
383
- svfloat64_t c_imag_f64x = svcvt_f64_f32_x(batch_predicate_f64x, svtrn2_f32(c_f32x, c_f32x));
384
- svwrite_hor_za64_f64_m(0, r, batch_predicate_f64x, c_imag_f64x);
385
+ svfloat32_t c_f32x = svld1_f32(batch_predicate_b32x,
386
+ (nk_f32_t const *)c_pairs + ((row + r) * dimensions + j) * 2);
387
+ svfloat64_t c_imag_f64x = svcvt_f64_f32_x(batch_predicate_b64x, svtrn2_f32(c_f32x, c_f32x));
388
+ svwrite_hor_za64_f64_m(0, r, batch_predicate_b64x, c_imag_f64x);
385
389
  }
386
390
 
387
391
  for (nk_size_t k = 0; k < batch_size; k++) {
388
- svfloat64_t c_im_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, k);
389
- svmopa_za64_f64_m(2, row_predicate_f64x, row_predicate_f64x, c_im_col_f64x,
392
+ svfloat64_t c_im_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, k);
393
+ svmopa_za64_f64_m(2, row_predicate_b64x, row_predicate_b64x, c_im_col_f64x,
390
394
  svdup_f64((nk_f64_t)b_pairs[j + k].real)); // v_imag += c_imag × b_real
391
- svmops_za64_f64_m(1, row_predicate_f64x, row_predicate_f64x, c_im_col_f64x,
395
+ svmops_za64_f64_m(1, row_predicate_b64x, row_predicate_b64x, c_im_col_f64x,
392
396
  svdup_f64((nk_f64_t)b_pairs[j + k].imag)); // v_real -= c_imag × b_imag
393
397
  }
394
398
  }
395
399
 
396
- svfloat64_t v_re_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 1, 0);
397
- svfloat64_t v_im_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 2, 0);
400
+ svfloat64_t v_re_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 1, 0);
401
+ svfloat64_t v_im_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 2, 0);
398
402
 
399
403
  // Deinterleave a[row:row+tile]
400
- svbool_t row_predicate_f32x = svwhilelt_b32_u64(0u, rows_remaining + rows_remaining);
401
- svfloat32_t a_f32x = svld1_f32(row_predicate_f32x, (nk_f32_t const *)a_pairs + row * 2);
402
- svfloat64_t a_re_f64x = svcvt_f64_f32_x(row_predicate_f64x, svtrn1_f32(a_f32x, a_f32x));
403
- svfloat64_t a_im_f64x = svcvt_f64_f32_x(row_predicate_f64x, svtrn2_f32(a_f32x, a_f32x));
404
+ svbool_t row_predicate_b32x = svwhilelt_b32_u64(0u, rows_remaining + rows_remaining);
405
+ svfloat32_t a_f32x = svld1_f32(row_predicate_b32x, (nk_f32_t const *)a_pairs + row * 2);
406
+ svfloat64_t a_re_f64x = svcvt_f64_f32_x(row_predicate_b64x, svtrn1_f32(a_f32x, a_f32x));
407
+ svfloat64_t a_im_f64x = svcvt_f64_f32_x(row_predicate_b64x, svtrn2_f32(a_f32x, a_f32x));
404
408
 
405
409
  // Complex dot: a × v
406
410
  outer_sum_real_f64 += svaddv_f64(
407
- predicate_body_f64x, svsub_f64_x(row_predicate_f64x, svmul_f64_x(row_predicate_f64x, a_re_f64x, v_re_f64x),
408
- svmul_f64_x(row_predicate_f64x, a_im_f64x, v_im_f64x)));
411
+ predicate_body_b64x, svsub_f64_x(row_predicate_b64x, svmul_f64_x(row_predicate_b64x, a_re_f64x, v_re_f64x),
412
+ svmul_f64_x(row_predicate_b64x, a_im_f64x, v_im_f64x)));
409
413
  outer_sum_imag_f64 += svaddv_f64(
410
- predicate_body_f64x, svadd_f64_x(row_predicate_f64x, svmul_f64_x(row_predicate_f64x, a_re_f64x, v_im_f64x),
411
- svmul_f64_x(row_predicate_f64x, a_im_f64x, v_re_f64x)));
414
+ predicate_body_b64x, svadd_f64_x(row_predicate_b64x, svmul_f64_x(row_predicate_b64x, a_re_f64x, v_im_f64x),
415
+ svmul_f64_x(row_predicate_b64x, a_im_f64x, v_re_f64x)));
412
416
  }
413
417
 
414
418
  results->real = outer_sum_real_f64;
@@ -416,8 +420,8 @@ __arm_locally_streaming __arm_new("za") static void nk_bilinear_f32c_smef64_stre
416
420
  }
417
421
 
418
422
  NK_PUBLIC void nk_bilinear_f32c_smef64(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_f32c_t const *c_pairs,
419
- nk_size_t n, nk_f64c_t *results) {
420
- nk_bilinear_f32c_smef64_streaming_(a_pairs, b_pairs, c_pairs, n, results);
423
+ nk_size_t dimensions, nk_f64c_t *results) {
424
+ nk_bilinear_f32c_smef64_streaming_(a_pairs, b_pairs, c_pairs, dimensions, results);
421
425
  }
422
426
 
423
427
  /**
@@ -426,20 +430,20 @@ NK_PUBLIC void nk_bilinear_f32c_smef64(nk_f32c_t const *a_pairs, nk_f32c_t const
426
430
  */
427
431
  __arm_locally_streaming static void nk_bilinear_f64c_smef64_streaming_(nk_f64c_t const *a_pairs,
428
432
  nk_f64c_t const *b_pairs,
429
- nk_f64c_t const *c_pairs, nk_size_t n,
433
+ nk_f64c_t const *c_pairs, nk_size_t dimensions,
430
434
  nk_f64c_t *results) {
431
- svbool_t predicate_all_f64x = svptrue_b64();
435
+ svbool_t predicate_all_b64x = svptrue_b64();
432
436
  nk_f64_t outer_sum_real = 0.0, outer_comp_real = 0.0;
433
437
  nk_f64_t outer_sum_imag = 0.0, outer_comp_imag = 0.0;
434
- nk_size_t const n2 = n * 2; // total f64 elements in interleaved layout
438
+ nk_size_t const n2 = dimensions * 2; // total f64 elements in interleaved layout
435
439
 
436
440
  // swap_idx_u64x = [1,0,3,2,5,4,...] — swap adjacent f64 lanes
437
- svuint64_t swap_idx_u64x = sveor_u64_x(predicate_all_f64x, svindex_u64(0, 1), svdup_u64(1));
441
+ svuint64_t swap_idx_u64x = sveor_u64_x(predicate_all_b64x, svindex_u64(0, 1), svdup_u64(1));
438
442
  // sign_mask_u64x = [0, 0x8000..., 0, 0x8000..., ...] — sign bit in odd positions
439
443
  svuint64_t sign_mask_u64x = svlsl_u64_x(
440
- predicate_all_f64x, svand_u64_x(predicate_all_f64x, svindex_u64(0, 1), svdup_u64(1)), svdup_u64(63));
444
+ predicate_all_b64x, svand_u64_x(predicate_all_b64x, svindex_u64(0, 1), svdup_u64(1)), svdup_u64(63));
441
445
 
442
- for (nk_size_t row = 0; row < n; ++row) {
446
+ for (nk_size_t row = 0; row < dimensions; ++row) {
443
447
  nk_f64_t a_real = a_pairs[row].real;
444
448
  nk_f64_t a_imag = a_pairs[row].imag;
445
449
 
@@ -447,33 +451,33 @@ __arm_locally_streaming static void nk_bilinear_f64c_smef64_streaming_(nk_f64c_t
447
451
  svfloat64_t sum_real_f64x = svdup_f64(0), comp_real_f64x = svdup_f64(0);
448
452
  svfloat64_t sum_imag_f64x = svdup_f64(0), comp_imag_f64x = svdup_f64(0);
449
453
  nk_size_t j = 0;
450
- svbool_t predicate_f64x = svwhilelt_b64(j, n2);
454
+ svbool_t predicate_b64x = svwhilelt_b64(j, n2);
451
455
 
452
- while (svptest_first(predicate_all_f64x, predicate_f64x)) {
456
+ while (svptest_first(predicate_all_b64x, predicate_b64x)) {
453
457
  // Load interleaved [re₀, im₀, re₁, im₁, ...] — no deinterleave needed
454
- svfloat64_t b_f64x = svld1_f64(predicate_f64x, (nk_f64_t const *)b_pairs + j);
455
- svfloat64_t c_f64x = svld1_f64(predicate_f64x, (nk_f64_t const *)c_pairs + row * n2 + j);
458
+ svfloat64_t b_f64x = svld1_f64(predicate_b64x, (nk_f64_t const *)b_pairs + j);
459
+ svfloat64_t c_f64x = svld1_f64(predicate_b64x, (nk_f64_t const *)c_pairs + row * n2 + j);
456
460
  svfloat64_t c_swapped_f64x = svtbl_f64(c_f64x, swap_idx_u64x);
457
461
 
458
462
  // 2 Dot2 accumulators instead of 4:
459
463
  // sum_real_f64x accumulates [c_real×b_real, c_imag×b_imag, ...] (sign-flip deferred)
460
464
  // sum_imag_f64x accumulates [c_imag×b_real, c_real×b_imag, ...] (all positive)
461
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_real_f64x, &comp_real_f64x, c_f64x, b_f64x);
462
- nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_imag_f64x, &comp_imag_f64x, c_swapped_f64x, b_f64x);
465
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_real_f64x, &comp_real_f64x, c_f64x, b_f64x);
466
+ nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_imag_f64x, &comp_imag_f64x, c_swapped_f64x, b_f64x);
463
467
 
464
468
  j += svcntd();
465
- predicate_f64x = svwhilelt_b64(j, n2);
469
+ predicate_b64x = svwhilelt_b64(j, n2);
466
470
  }
467
471
 
468
472
  // Flip sign of odd positions in sum_real_f64x: [c_real×b_real, -(c_imag×b_imag), ...]
469
473
  sum_real_f64x = svreinterpret_f64_u64(
470
- sveor_u64_x(predicate_all_f64x, svreinterpret_u64_f64(sum_real_f64x), sign_mask_u64x));
474
+ sveor_u64_x(predicate_all_b64x, svreinterpret_u64_f64(sum_real_f64x), sign_mask_u64x));
471
475
  comp_real_f64x = svreinterpret_f64_u64(
472
- sveor_u64_x(predicate_all_f64x, svreinterpret_u64_f64(comp_real_f64x), sign_mask_u64x));
473
- nk_f64_t inner_real = svaddv_f64(predicate_all_f64x,
474
- svadd_f64_x(predicate_all_f64x, sum_real_f64x, comp_real_f64x));
475
- nk_f64_t inner_imag = svaddv_f64(predicate_all_f64x,
476
- svadd_f64_x(predicate_all_f64x, sum_imag_f64x, comp_imag_f64x));
476
+ sveor_u64_x(predicate_all_b64x, svreinterpret_u64_f64(comp_real_f64x), sign_mask_u64x));
477
+ nk_f64_t inner_real = svaddv_f64(predicate_all_b64x,
478
+ svadd_f64_x(predicate_all_b64x, sum_real_f64x, comp_real_f64x));
479
+ nk_f64_t inner_imag = svaddv_f64(predicate_all_b64x,
480
+ svadd_f64_x(predicate_all_b64x, sum_imag_f64x, comp_imag_f64x));
477
481
 
478
482
  // Outer Dot2 complex multiply: a × inner
479
483
  nk_f64_dot2_(&outer_sum_real, &outer_comp_real, a_real, inner_real);
@@ -487,8 +491,8 @@ __arm_locally_streaming static void nk_bilinear_f64c_smef64_streaming_(nk_f64c_t
487
491
  }
488
492
 
489
493
  NK_PUBLIC void nk_bilinear_f64c_smef64(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_f64c_t const *c_pairs,
490
- nk_size_t n, nk_f64c_t *results) {
491
- nk_bilinear_f64c_smef64_streaming_(a_pairs, b_pairs, c_pairs, n, results);
494
+ nk_size_t dimensions, nk_f64c_t *results) {
495
+ nk_bilinear_f64c_smef64_streaming_(a_pairs, b_pairs, c_pairs, dimensions, results);
492
496
  }
493
497
 
494
498
  #if defined(__clang__)