numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -20,69 +20,70 @@ extern "C" {
20
20
  #endif
21
21
 
22
22
  #if defined(__clang__)
23
- #pragma clang attribute push(__attribute__((target("sme,sve,sme-f64f64"))), apply_to = function)
23
+ #pragma clang attribute push(__attribute__((target("sme,sme-f64f64"))), apply_to = function)
24
24
  #elif defined(__GNUC__)
25
25
  #pragma GCC push_options
26
26
  #pragma GCC target("+sme+sme-f64f64")
27
27
  #endif
28
28
 
29
29
  NK_PUBLIC nk_f64_t nk_dots_reduce_sumsq_f32_ssve_(nk_f32_t const *data, nk_size_t count) NK_STREAMING_ {
30
- svfloat64_t accumulator_lo_f64x = svdup_f64(0.0);
31
- svfloat64_t accumulator_hi_f64x = svdup_f64(0.0);
30
+ svfloat64_t accumulator_even_f64x = svdup_f64(0.0);
31
+ svfloat64_t accumulator_odd_f64x = svdup_f64(0.0);
32
32
  nk_size_t const vector_length = svcntw();
33
33
  nk_size_t const half_vector_length = svcntd();
34
34
  for (nk_size_t i = 0; i < count; i += vector_length) {
35
- svbool_t predicate_f32x = svwhilelt_b32_u64(i, count);
36
- svfloat32_t values_f32x = svld1_f32(predicate_f32x, data + i);
35
+ svbool_t predicate_b32x = svwhilelt_b32_u64(i, count);
36
+ svfloat32_t values_f32x = svld1_f32(predicate_b32x, data + i);
37
37
 
38
- svbool_t predicate_lo_f64x = svwhilelt_b64_u64(i, count);
39
- svfloat64_t values_lo_f64x = svcvt_f64_f32_x(predicate_lo_f64x, values_f32x);
40
- accumulator_lo_f64x = svmla_f64_x(predicate_lo_f64x, accumulator_lo_f64x, values_lo_f64x, values_lo_f64x);
38
+ svbool_t predicate_even_b64x = svwhilelt_b64_u64(i, count);
39
+ svfloat64_t values_even_f64x = svcvt_f64_f32_x(predicate_even_b64x, values_f32x);
40
+ accumulator_even_f64x = svmla_f64_m(predicate_even_b64x, accumulator_even_f64x, values_even_f64x,
41
+ values_even_f64x);
41
42
 
42
- svbool_t predicate_hi_f64x = svwhilelt_b64_u64(i + half_vector_length, count);
43
- svfloat64_t values_hi_f64x = svcvtlt_f64_f32_x(predicate_hi_f64x, values_f32x);
44
- accumulator_hi_f64x = svmla_f64_x(predicate_hi_f64x, accumulator_hi_f64x, values_hi_f64x, values_hi_f64x);
43
+ svbool_t predicate_odd_b64x = svwhilelt_b64_u64(i + half_vector_length, count);
44
+ svfloat64_t values_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, values_f32x);
45
+ accumulator_odd_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_odd_f64x, values_odd_f64x, values_odd_f64x);
45
46
  }
46
- return svaddv_f64(svptrue_b64(), accumulator_lo_f64x) + svaddv_f64(svptrue_b64(), accumulator_hi_f64x);
47
+ return svaddv_f64(svptrue_b64(), accumulator_even_f64x) + svaddv_f64(svptrue_b64(), accumulator_odd_f64x);
47
48
  }
48
49
 
49
- NK_PUBLIC nk_f64_t nk_dots_reduce_sumsq_f64_ssve_(nk_f64_t const *data, nk_size_t count) NK_STREAMING_COMPATIBLE_ {
50
+ NK_PUBLIC nk_f64_t nk_dots_reduce_sumsq_f64_ssve_(nk_f64_t const *data, nk_size_t count) NK_STREAMING_ {
50
51
  svfloat64_t accumulator_f64x = svdup_f64(0.0);
51
52
  nk_size_t const vector_length = svcntd();
52
53
  for (nk_size_t i = 0; i < count; i += vector_length) {
53
- svbool_t predicate_f64x = svwhilelt_b64_u64(i, count);
54
- svfloat64_t values_f64x = svld1_f64(predicate_f64x, data + i);
55
- accumulator_f64x = svmla_f64_x(predicate_f64x, accumulator_f64x, values_f64x, values_f64x);
54
+ svbool_t predicate_b64x = svwhilelt_b64_u64(i, count);
55
+ svfloat64_t values_f64x = svld1_f64(predicate_b64x, data + i);
56
+ accumulator_f64x = svmla_f64_m(predicate_b64x, accumulator_f64x, values_f64x, values_f64x);
56
57
  }
57
58
  return svaddv_f64(svptrue_b64(), accumulator_f64x);
58
59
  }
59
60
 
60
- NK_PUBLIC svfloat64_t nk_angulars_from_dot_f64x_ssvef64_(svbool_t predicate_f64x, svfloat64_t dots_f64x,
61
+ NK_PUBLIC svfloat64_t nk_angulars_from_dot_f64x_ssvef64_(svbool_t predicate_b64x, svfloat64_t dots_f64x,
61
62
  svfloat64_t query_norm_sq_f64x,
62
- svfloat64_t target_norms_sq_f64x) NK_STREAMING_COMPATIBLE_ {
63
- svfloat64_t norms_product_f64x = svmul_f64_x(predicate_f64x, query_norm_sq_f64x, target_norms_sq_f64x);
64
- svbool_t positive_norms_f64x = svcmpgt_n_f64(predicate_f64x, norms_product_f64x, 0.0);
65
- svfloat64_t denom_f64x = svsqrt_f64_x(positive_norms_f64x, norms_product_f64x);
66
- svfloat64_t safe_denom_f64x = svsel_f64(positive_norms_f64x, denom_f64x, svdup_n_f64(1.0));
67
- svfloat64_t normalized_f64x = svdiv_f64_x(predicate_f64x, dots_f64x, safe_denom_f64x);
68
- svfloat64_t angular_f64x = svsub_f64_x(predicate_f64x, svdup_n_f64(1.0), normalized_f64x);
63
+ svfloat64_t target_norms_sq_f64x) NK_STREAMING_ {
64
+ svfloat64_t norms_product_f64x = svmul_f64_x(predicate_b64x, query_norm_sq_f64x, target_norms_sq_f64x);
65
+ svbool_t positive_norms_b64x = svcmpgt_n_f64(predicate_b64x, norms_product_f64x, 0.0);
66
+ svfloat64_t denom_f64x = svsqrt_f64_x(positive_norms_b64x, norms_product_f64x);
67
+ svfloat64_t safe_denom_f64x = svsel_f64(positive_norms_b64x, denom_f64x, svdup_n_f64(1.0));
68
+ svfloat64_t normalized_f64x = svdiv_f64_x(predicate_b64x, dots_f64x, safe_denom_f64x);
69
+ svfloat64_t angular_f64x = svsub_f64_x(predicate_b64x, svdup_n_f64(1.0), normalized_f64x);
69
70
  angular_f64x = svsel_f64(
70
- positive_norms_f64x, angular_f64x,
71
- svsel_f64(svcmpeq_n_f64(predicate_f64x, dots_f64x, 0.0), svdup_n_f64(0.0), svdup_n_f64(1.0)));
72
- return svmax_f64_x(predicate_f64x, angular_f64x, svdup_n_f64(0.0));
71
+ positive_norms_b64x, angular_f64x,
72
+ svsel_f64(svcmpeq_n_f64(predicate_b64x, dots_f64x, 0.0), svdup_n_f64(0.0), svdup_n_f64(1.0)));
73
+ return svmax_f64_x(predicate_b64x, angular_f64x, svdup_n_f64(0.0));
73
74
  }
74
75
 
75
- NK_PUBLIC svfloat64_t nk_euclideans_from_dot_f64x_ssvef64_(svbool_t predicate_f64x, svfloat64_t dots_f64x,
76
+ NK_PUBLIC svfloat64_t nk_euclideans_from_dot_f64x_ssvef64_(svbool_t predicate_b64x, svfloat64_t dots_f64x,
76
77
  svfloat64_t query_norm_sq_f64x,
77
- svfloat64_t target_norms_sq_f64x) NK_STREAMING_COMPATIBLE_ {
78
- svfloat64_t sum_sq_f64x = svadd_f64_x(predicate_f64x, query_norm_sq_f64x, target_norms_sq_f64x);
79
- svfloat64_t dist_sq_f64x = svsub_f64_x(predicate_f64x, sum_sq_f64x,
80
- svmul_f64_x(predicate_f64x, svdup_n_f64(2.0), dots_f64x));
81
- dist_sq_f64x = svmax_f64_x(predicate_f64x, dist_sq_f64x, svdup_n_f64(0.0));
82
- return svsqrt_f64_x(predicate_f64x, dist_sq_f64x);
78
+ svfloat64_t target_norms_sq_f64x) NK_STREAMING_ {
79
+ svfloat64_t sum_sq_f64x = svadd_f64_x(predicate_b64x, query_norm_sq_f64x, target_norms_sq_f64x);
80
+ svfloat64_t dist_sq_f64x = svsub_f64_x(predicate_b64x, sum_sq_f64x,
81
+ svmul_f64_x(predicate_b64x, svdup_n_f64(2.0), dots_f64x));
82
+ dist_sq_f64x = svmax_f64_x(predicate_b64x, dist_sq_f64x, svdup_n_f64(0.0));
83
+ return svsqrt_f64_x(predicate_b64x, dist_sq_f64x);
83
84
  }
84
85
 
85
- #pragma region Single Precision Packed Angular
86
+ #pragma region F32 Packed Angular
86
87
 
87
88
  __arm_locally_streaming static void nk_angulars_packed_f32_smef64_finalize_streaming_( //
88
89
  nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
@@ -99,11 +100,11 @@ __arm_locally_streaming static void nk_angulars_packed_f32_smef64_finalize_strea
99
100
  svfloat64_t query_norm_sq_f64x = svdup_n_f64(query_norm_sq_f64);
100
101
 
101
102
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntd()) {
102
- svbool_t predicate_f64x = svwhilelt_b64_u64(col_index, columns);
103
- svfloat64_t dots_f64x = svld1_f64(predicate_f64x, c_row + col_index);
104
- svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_f64x, b_norms + col_index);
105
- svst1_f64(predicate_f64x, c_row + col_index,
106
- nk_angulars_from_dot_f64x_ssvef64_(predicate_f64x, dots_f64x, query_norm_sq_f64x,
103
+ svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, columns);
104
+ svfloat64_t dots_f64x = svld1_f64(predicate_b64x, c_row + col_index);
105
+ svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, b_norms + col_index);
106
+ svst1_f64(predicate_b64x, c_row + col_index,
107
+ nk_angulars_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
107
108
  target_norms_sq_f64x));
108
109
  }
109
110
  }
@@ -122,7 +123,8 @@ NK_PUBLIC void nk_angulars_packed_f32_smef64( //
122
123
  c_stride_elements);
123
124
  }
124
125
 
125
- #pragma region Single Precision Packed Euclidean
126
+ #pragma endregion F32 Packed Angular
127
+ #pragma region F32 Packed Euclidean
126
128
 
127
129
  __arm_locally_streaming static void nk_euclideans_packed_f32_smef64_finalize_streaming_( //
128
130
  nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
@@ -139,11 +141,11 @@ __arm_locally_streaming static void nk_euclideans_packed_f32_smef64_finalize_str
139
141
  svfloat64_t query_norm_sq_f64x = svdup_n_f64(query_norm_sq_f64);
140
142
 
141
143
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntd()) {
142
- svbool_t predicate_f64x = svwhilelt_b64_u64(col_index, columns);
143
- svfloat64_t dots_f64x = svld1_f64(predicate_f64x, c_row + col_index);
144
- svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_f64x, b_norms + col_index);
145
- svst1_f64(predicate_f64x, c_row + col_index,
146
- nk_euclideans_from_dot_f64x_ssvef64_(predicate_f64x, dots_f64x, query_norm_sq_f64x,
144
+ svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, columns);
145
+ svfloat64_t dots_f64x = svld1_f64(predicate_b64x, c_row + col_index);
146
+ svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, b_norms + col_index);
147
+ svst1_f64(predicate_b64x, c_row + col_index,
148
+ nk_euclideans_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
147
149
  target_norms_sq_f64x));
148
150
  }
149
151
  }
@@ -162,10 +164,11 @@ NK_PUBLIC void nk_euclideans_packed_f32_smef64( //
162
164
  c_stride_elements);
163
165
  }
164
166
 
165
- #pragma region Single Precision Symmetric Angular
167
+ #pragma endregion F32 Packed Euclidean
168
+ #pragma region F32 Symmetric Angular
166
169
 
167
- __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_streaming_( //
168
- nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
170
+ __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_streaming_( //
171
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
169
172
  nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
170
173
  // Phase 1: cache row norms on diagonal
171
174
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -175,8 +178,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_st
175
178
  }
176
179
  // Phase 2: column-chunked post-processing
177
180
  nk_f64_t column_norms[256];
178
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
179
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
181
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
182
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
180
183
  for (nk_size_t col = chunk_start; col < chunk_end; ++col) {
181
184
  nk_f32_t const *col_vector = vectors + col * stride_elements;
182
185
  column_norms[col - chunk_start] = nk_dots_reduce_sumsq_f32_ssve_(col_vector, depth);
@@ -187,11 +190,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_st
187
190
  nk_f64_t *result_row = result + row_index * result_stride_elements;
188
191
  svfloat64_t query_norm_sq_f64x = svdup_n_f64(result_row[row_index]);
189
192
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntd()) {
190
- svbool_t predicate_f64x = svwhilelt_b64_u64(col_index, chunk_end);
191
- svfloat64_t dots_f64x = svld1_f64(predicate_f64x, result_row + col_index);
192
- svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_f64x, column_norms + (col_index - chunk_start));
193
- svst1_f64(predicate_f64x, result_row + col_index,
194
- nk_angulars_from_dot_f64x_ssvef64_(predicate_f64x, dots_f64x, query_norm_sq_f64x,
193
+ svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, chunk_end);
194
+ svfloat64_t dots_f64x = svld1_f64(predicate_b64x, result_row + col_index);
195
+ svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, column_norms + (col_index - chunk_start));
196
+ svst1_f64(predicate_b64x, result_row + col_index,
197
+ nk_angulars_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
195
198
  target_norms_sq_f64x));
196
199
  }
197
200
  }
@@ -201,23 +204,24 @@ __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_st
201
204
  result[row_index * result_stride_elements + row_index] = 0;
202
205
  }
203
206
 
204
- NK_PUBLIC void nk_angulars_symmetric_f32_smef64( //
205
- nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
206
- nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
207
+ NK_PUBLIC void nk_angulars_symmetric_f32_smef64( //
208
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
209
+ nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
207
210
 
208
- nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
209
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
211
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
212
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
210
213
 
211
- nk_dots_symmetric_f32_smef64_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
212
- row_start, row_count);
213
- nk_angulars_symmetric_f32_smef64_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
214
+ nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
215
+ result_stride_elements, row_start, row_count);
216
+ nk_angulars_symmetric_f32_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
214
217
  result_stride_elements, row_start, row_count);
215
218
  }
216
219
 
217
- #pragma region Single Precision Symmetric Euclidean
220
+ #pragma endregion F32 Symmetric Angular
221
+ #pragma region F32 Symmetric Euclidean
218
222
 
219
- __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_streaming_( //
220
- nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
223
+ __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_streaming_( //
224
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
221
225
  nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
222
226
  // Phase 1: cache row norms on diagonal
223
227
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -227,8 +231,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_
227
231
  }
228
232
  // Phase 2: column-chunked post-processing
229
233
  nk_f64_t column_norms[256];
230
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
231
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
234
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
235
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
232
236
  for (nk_size_t col = chunk_start; col < chunk_end; ++col) {
233
237
  nk_f32_t const *col_vector = vectors + col * stride_elements;
234
238
  column_norms[col - chunk_start] = nk_dots_reduce_sumsq_f32_ssve_(col_vector, depth);
@@ -239,11 +243,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_
239
243
  nk_f64_t *result_row = result + row_index * result_stride_elements;
240
244
  svfloat64_t query_norm_sq_f64x = svdup_n_f64(result_row[row_index]);
241
245
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntd()) {
242
- svbool_t predicate_f64x = svwhilelt_b64_u64(col_index, chunk_end);
243
- svfloat64_t dots_f64x = svld1_f64(predicate_f64x, result_row + col_index);
244
- svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_f64x, column_norms + (col_index - chunk_start));
245
- svst1_f64(predicate_f64x, result_row + col_index,
246
- nk_euclideans_from_dot_f64x_ssvef64_(predicate_f64x, dots_f64x, query_norm_sq_f64x,
246
+ svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, chunk_end);
247
+ svfloat64_t dots_f64x = svld1_f64(predicate_b64x, result_row + col_index);
248
+ svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, column_norms + (col_index - chunk_start));
249
+ svst1_f64(predicate_b64x, result_row + col_index,
250
+ nk_euclideans_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
247
251
  target_norms_sq_f64x));
248
252
  }
249
253
  }
@@ -253,20 +257,21 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_
253
257
  result[row_index * result_stride_elements + row_index] = 0;
254
258
  }
255
259
 
256
- NK_PUBLIC void nk_euclideans_symmetric_f32_smef64( //
257
- nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
258
- nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
260
+ NK_PUBLIC void nk_euclideans_symmetric_f32_smef64( //
261
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
262
+ nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
259
263
 
260
- nk_size_t const stride_elements = stride / sizeof(nk_f32_t);
261
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
264
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
265
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
262
266
 
263
- nk_dots_symmetric_f32_smef64_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
264
- row_start, row_count);
265
- nk_euclideans_symmetric_f32_smef64_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
267
+ nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
268
+ result_stride_elements, row_start, row_count);
269
+ nk_euclideans_symmetric_f32_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
266
270
  result_stride_elements, row_start, row_count);
267
271
  }
268
272
 
269
- #pragma region Double Precision Packed Angular
273
+ #pragma endregion F32 Symmetric Euclidean
274
+ #pragma region F64 Packed Angular
270
275
 
271
276
  __arm_locally_streaming static void nk_angulars_packed_f64_smef64_finalize_streaming_( //
272
277
  nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
@@ -283,11 +288,11 @@ __arm_locally_streaming static void nk_angulars_packed_f64_smef64_finalize_strea
283
288
  svfloat64_t query_norm_sq_f64x = svdup_n_f64(query_norm_sq_f64);
284
289
 
285
290
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntd()) {
286
- svbool_t predicate_f64x = svwhilelt_b64_u64(col_index, columns);
287
- svfloat64_t dots_f64x = svld1_f64(predicate_f64x, c_row + col_index);
288
- svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_f64x, b_norms + col_index);
289
- svst1_f64(predicate_f64x, c_row + col_index,
290
- nk_angulars_from_dot_f64x_ssvef64_(predicate_f64x, dots_f64x, query_norm_sq_f64x,
291
+ svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, columns);
292
+ svfloat64_t dots_f64x = svld1_f64(predicate_b64x, c_row + col_index);
293
+ svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, b_norms + col_index);
294
+ svst1_f64(predicate_b64x, c_row + col_index,
295
+ nk_angulars_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
291
296
  target_norms_sq_f64x));
292
297
  }
293
298
  }
@@ -306,7 +311,8 @@ NK_PUBLIC void nk_angulars_packed_f64_smef64( //
306
311
  c_stride_elements);
307
312
  }
308
313
 
309
- #pragma region Double Precision Packed Euclidean
314
+ #pragma endregion F64 Packed Angular
315
+ #pragma region F64 Packed Euclidean
310
316
 
311
317
  __arm_locally_streaming static void nk_euclideans_packed_f64_smef64_finalize_streaming_( //
312
318
  nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
@@ -323,11 +329,11 @@ __arm_locally_streaming static void nk_euclideans_packed_f64_smef64_finalize_str
323
329
  svfloat64_t query_norm_sq_f64x = svdup_n_f64(query_norm_sq_f64);
324
330
 
325
331
  for (nk_size_t col_index = 0; col_index < columns; col_index += svcntd()) {
326
- svbool_t predicate_f64x = svwhilelt_b64_u64(col_index, columns);
327
- svfloat64_t dots_f64x = svld1_f64(predicate_f64x, c_row + col_index);
328
- svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_f64x, b_norms + col_index);
329
- svst1_f64(predicate_f64x, c_row + col_index,
330
- nk_euclideans_from_dot_f64x_ssvef64_(predicate_f64x, dots_f64x, query_norm_sq_f64x,
332
+ svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, columns);
333
+ svfloat64_t dots_f64x = svld1_f64(predicate_b64x, c_row + col_index);
334
+ svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, b_norms + col_index);
335
+ svst1_f64(predicate_b64x, c_row + col_index,
336
+ nk_euclideans_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
331
337
  target_norms_sq_f64x));
332
338
  }
333
339
  }
@@ -346,10 +352,11 @@ NK_PUBLIC void nk_euclideans_packed_f64_smef64( //
346
352
  c_stride_elements);
347
353
  }
348
354
 
349
- #pragma region Double Precision Symmetric Angular
355
+ #pragma endregion F64 Packed Euclidean
356
+ #pragma region F64 Symmetric Angular
350
357
 
351
- __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_streaming_( //
352
- nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
358
+ __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_streaming_( //
359
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
353
360
  nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
354
361
  // Phase 1: cache row norms on diagonal
355
362
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -359,8 +366,8 @@ __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_st
359
366
  }
360
367
  // Phase 2: column-chunked post-processing
361
368
  nk_f64_t column_norms[256];
362
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
363
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
369
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
370
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
364
371
  for (nk_size_t col = chunk_start; col < chunk_end; ++col) {
365
372
  nk_f64_t const *col_vector = vectors + col * stride_elements;
366
373
  column_norms[col - chunk_start] = nk_dots_reduce_sumsq_f64_ssve_(col_vector, depth);
@@ -371,11 +378,11 @@ __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_st
371
378
  nk_f64_t *result_row = result + row_index * result_stride_elements;
372
379
  svfloat64_t query_norm_sq_f64x = svdup_n_f64(result_row[row_index]);
373
380
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntd()) {
374
- svbool_t predicate_f64x = svwhilelt_b64_u64(col_index, chunk_end);
375
- svfloat64_t dots_f64x = svld1_f64(predicate_f64x, result_row + col_index);
376
- svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_f64x, column_norms + (col_index - chunk_start));
377
- svst1_f64(predicate_f64x, result_row + col_index,
378
- nk_angulars_from_dot_f64x_ssvef64_(predicate_f64x, dots_f64x, query_norm_sq_f64x,
381
+ svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, chunk_end);
382
+ svfloat64_t dots_f64x = svld1_f64(predicate_b64x, result_row + col_index);
383
+ svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, column_norms + (col_index - chunk_start));
384
+ svst1_f64(predicate_b64x, result_row + col_index,
385
+ nk_angulars_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
379
386
  target_norms_sq_f64x));
380
387
  }
381
388
  }
@@ -385,23 +392,24 @@ __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_st
385
392
  result[row_index * result_stride_elements + row_index] = 0;
386
393
  }
387
394
 
388
- NK_PUBLIC void nk_angulars_symmetric_f64_smef64( //
389
- nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
390
- nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
395
+ NK_PUBLIC void nk_angulars_symmetric_f64_smef64( //
396
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
397
+ nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
391
398
 
392
- nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
393
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
399
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
400
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
394
401
 
395
- nk_dots_symmetric_f64_smef64_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
396
- row_start, row_count);
397
- nk_angulars_symmetric_f64_smef64_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
402
+ nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
403
+ result_stride_elements, row_start, row_count);
404
+ nk_angulars_symmetric_f64_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
398
405
  result_stride_elements, row_start, row_count);
399
406
  }
400
407
 
401
- #pragma region Double Precision Symmetric Euclidean
408
+ #pragma endregion F64 Symmetric Angular
409
+ #pragma region F64 Symmetric Euclidean
402
410
 
403
- __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_streaming_( //
404
- nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride_elements, //
411
+ __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_streaming_( //
412
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
405
413
  nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
406
414
  // Phase 1: cache row norms on diagonal
407
415
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
@@ -411,8 +419,8 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_
411
419
  }
412
420
  // Phase 2: column-chunked post-processing
413
421
  nk_f64_t column_norms[256];
414
- for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
415
- nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
422
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
423
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
416
424
  for (nk_size_t col = chunk_start; col < chunk_end; ++col) {
417
425
  nk_f64_t const *col_vector = vectors + col * stride_elements;
418
426
  column_norms[col - chunk_start] = nk_dots_reduce_sumsq_f64_ssve_(col_vector, depth);
@@ -423,11 +431,11 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_
423
431
  nk_f64_t *result_row = result + row_index * result_stride_elements;
424
432
  svfloat64_t query_norm_sq_f64x = svdup_n_f64(result_row[row_index]);
425
433
  for (nk_size_t col_index = col_start; col_index < chunk_end; col_index += svcntd()) {
426
- svbool_t predicate_f64x = svwhilelt_b64_u64(col_index, chunk_end);
427
- svfloat64_t dots_f64x = svld1_f64(predicate_f64x, result_row + col_index);
428
- svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_f64x, column_norms + (col_index - chunk_start));
429
- svst1_f64(predicate_f64x, result_row + col_index,
430
- nk_euclideans_from_dot_f64x_ssvef64_(predicate_f64x, dots_f64x, query_norm_sq_f64x,
434
+ svbool_t predicate_b64x = svwhilelt_b64_u64(col_index, chunk_end);
435
+ svfloat64_t dots_f64x = svld1_f64(predicate_b64x, result_row + col_index);
436
+ svfloat64_t target_norms_sq_f64x = svld1_f64(predicate_b64x, column_norms + (col_index - chunk_start));
437
+ svst1_f64(predicate_b64x, result_row + col_index,
438
+ nk_euclideans_from_dot_f64x_ssvef64_(predicate_b64x, dots_f64x, query_norm_sq_f64x,
431
439
  target_norms_sq_f64x));
432
440
  }
433
441
  }
@@ -437,19 +445,20 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_
437
445
  result[row_index * result_stride_elements + row_index] = 0;
438
446
  }
439
447
 
440
- NK_PUBLIC void nk_euclideans_symmetric_f64_smef64( //
441
- nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
442
- nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
448
+ NK_PUBLIC void nk_euclideans_symmetric_f64_smef64( //
449
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
450
+ nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
443
451
 
444
- nk_size_t const stride_elements = stride / sizeof(nk_f64_t);
445
- nk_size_t const result_stride_elements = result_stride / sizeof(nk_f64_t);
452
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
453
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
446
454
 
447
- nk_dots_symmetric_f64_smef64_streaming_(vectors, n_vectors, depth, stride_elements, result, result_stride_elements,
448
- row_start, row_count);
449
- nk_euclideans_symmetric_f64_smef64_finalize_streaming_(vectors, n_vectors, depth, stride_elements, result,
455
+ nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
456
+ result_stride_elements, row_start, row_count);
457
+ nk_euclideans_symmetric_f64_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
450
458
  result_stride_elements, row_start, row_count);
451
459
  }
452
460
 
461
+ #pragma endregion F64 Symmetric Euclidean
453
462
  #if defined(__clang__)
454
463
  #pragma clang attribute pop
455
464
  #elif defined(__GNUC__)
@@ -144,7 +144,7 @@ nk_define_cross_normalized_symmetric_(euclidean, e5m2, v128relaxed, e5m2, f32, /
144
144
  nk_load_b128_v128relaxed_, nk_partial_load_b32x4_serial_,
145
145
  nk_store_b128_v128relaxed_, nk_partial_store_b32x4_serial_, 1)
146
146
 
147
- nk_define_cross_normalized_packed_(angular, bf16, v128relaxed, bf16, f32, f32, /*norm_value_type=*/f32, f32,
147
+ nk_define_cross_normalized_packed_(angular, bf16, v128relaxed, bf16, bf16, f32, /*norm_value_type=*/f32, f32,
148
148
  nk_b128_vec_t, nk_dots_packed_bf16_v128relaxed,
149
149
  nk_angular_through_f32_from_dot_v128relaxed_, nk_dots_reduce_sumsq_bf16_,
150
150
  nk_load_b128_v128relaxed_, nk_partial_load_b32x4_serial_, nk_store_b128_v128relaxed_,
@@ -154,7 +154,7 @@ nk_define_cross_normalized_symmetric_(angular, bf16, v128relaxed, bf16, f32, /*n
154
154
  nk_angular_through_f32_from_dot_v128relaxed_, nk_dots_reduce_sumsq_bf16_,
155
155
  nk_load_b128_v128relaxed_, nk_partial_load_b32x4_serial_,
156
156
  nk_store_b128_v128relaxed_, nk_partial_store_b32x4_serial_, 1)
157
- nk_define_cross_normalized_packed_(euclidean, bf16, v128relaxed, bf16, f32, f32, /*norm_value_type=*/f32, f32,
157
+ nk_define_cross_normalized_packed_(euclidean, bf16, v128relaxed, bf16, bf16, f32, /*norm_value_type=*/f32, f32,
158
158
  nk_b128_vec_t, nk_dots_packed_bf16_v128relaxed,
159
159
  nk_euclidean_through_f32_from_dot_v128relaxed_, nk_dots_reduce_sumsq_bf16_,
160
160
  nk_load_b128_v128relaxed_, nk_partial_load_b32x4_serial_, nk_store_b128_v128relaxed_,