numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,3021 @@
1
+ /**
2
+ * @brief SIMD-accelerated Batched Spatial Distances (Angular & Euclidean).
3
+ * @file include/numkong/spatials.h
4
+ * @author Ash Vardanian
5
+ * @date February 22, 2026
6
+ *
7
+ * This module provides efficient batched computation of angular and euclidean distances
8
+ * via a two-pass approach: compute dot products first, then post-process with spatial
9
+ * distance formulas using pre-computed norms stored in the packed buffer.
10
+ *
11
+ * For dtypes:
12
+ *
13
+ * - f64: 64-bit IEEE floating point numbers → 64-bit floats
14
+ * - f32: 32-bit IEEE floating point numbers → 64-bit floats
15
+ * - f16: 16-bit IEEE floating point numbers → 32-bit floats
16
+ * - bf16: 16-bit brain floating point numbers → 32-bit floats
17
+ * - e4m3: 8-bit e4m3 floating point numbers → 32-bit floats
18
+ * - e5m2: 8-bit e5m2 floating point numbers → 32-bit floats
19
+ * - e2m3: 8-bit e2m3 floating point numbers (MX) → 32-bit floats
20
+ * - e3m2: 8-bit e3m2 floating point numbers (MX) → 32-bit floats
21
+ *
22
+ * For hardware architectures:
23
+ *
24
+ * - Arm: NEON, NEON+HALF, NEON+FHM, NEON+BF16, NEON+SDOT, SME, SME+F64
25
+ * - x86: Haswell, Skylake, Ice Lake, Alder Lake, Genoa, Sapphire Rapids (AMX), Sierra Forest
26
+ * - RISC-V: RVV
27
+ *
28
+ * @section numerical_stability Numerical Stability
29
+ *
30
+ * Inherits dot-product precision from nk_dots_packed_* and keeps packed payloads narrow. The `f32` batched spatial
31
+ * kernels normalize from widened `f64` dots and norms and store `f64` results directly.
32
+ *
33
+ * @section approach Two-Pass Approach
34
+ *
35
+ * 1. Pack B matrix using nk_dots_pack_* (norms are stored in the packed buffer footer)
36
+ * 2. Compute nk_angulars_packed_* or nk_euclideans_packed_*:
37
+ * a. Internally calls nk_dots_packed_* to fill result buffer with dot products
38
+ * b. Post-processes each result cell using angular/euclidean formula with pre-computed norms
39
+ *
40
+ * @section math Mathematical Foundation
41
+ *
42
+ * Angular distance: 1 - dot(a,b) / sqrt(sumsq(a) * sumsq(b))
43
+ * Euclidean distance: sqrt(max(0, sumsq(a) + sumsq(b) - 2*dot(a,b)))
44
+ *
45
+ * @section packing Packing
46
+ *
47
+ * Uses the SAME pack functions as dot products (nk_dots_packed_size_*, nk_dots_pack_*).
48
+ * The packed buffer includes norms appended after the data.
49
+ */
50
+
51
+ #ifndef NK_SPATIALS_H
52
+ #define NK_SPATIALS_H
53
+
54
+ #include "numkong/dots.h"
55
+ #include "numkong/types.h"
56
+
57
+ #if defined(__cplusplus)
58
+ extern "C" {
59
+ #endif
60
+
61
+ /**
62
+ * @brief Computes batched angular distances using a packed second matrix.
63
+ * @param[in] a Input A matrix in row-major order.
64
+ * @param[in] b_packed Packed B matrix (produced by nk_dots_pack_*), with norms in footer.
65
+ * @param[out] result Output matrix (rows x cols) of angular distances.
66
+ * @param[in] rows Number of rows in A.
67
+ * @param[in] cols Number of columns in B (packed).
68
+ * @param[in] depth Shared inner dimension (vector length).
69
+ * @param[in] a_stride_in_bytes Row stride in bytes for A.
70
+ * @param[in] r_stride_in_bytes Row stride in bytes for the result matrix.
71
+ */
72
+ NK_DYNAMIC void nk_angulars_packed_f32(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
73
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
74
+ nk_size_t r_stride_in_bytes);
75
+
76
+ /**
77
+ * @brief Computes symmetric angular distance matrix (Gram-style) for a set of vectors.
78
+ * @param[in] vectors Input matrix of row vectors in row-major order.
79
+ * @param[in] n_vectors Number of vectors (rows) in the input matrix.
80
+ * @param[in] depth Dimension of each vector (columns).
81
+ * @param[in] stride Row stride in bytes for the input matrix.
82
+ * @param[out] result Output symmetric matrix (n_vectors x n_vectors).
83
+ * @param[in] result_stride Row stride in bytes for the result matrix.
84
+ * @param[in] row_start Starting row offset of results to compute (for parallelism).
85
+ * @param[in] row_count Number of rows of results to compute (for parallelism).
86
+ */
87
+ NK_DYNAMIC void nk_angulars_symmetric_f32(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
88
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
89
+ nk_size_t row_start, nk_size_t row_count);
90
+
91
+ /**
92
+ * @brief Computes batched euclidean distances using a packed second matrix.
93
+ * @param[in] a Input A matrix in row-major order.
94
+ * @param[in] b_packed Packed B matrix (produced by nk_dots_pack_*), with norms in footer.
95
+ * @param[out] result Output matrix (rows x cols) of euclidean distances.
96
+ * @param[in] rows Number of rows in A.
97
+ * @param[in] cols Number of columns in B (packed).
98
+ * @param[in] depth Shared inner dimension (vector length).
99
+ * @param[in] a_stride_in_bytes Row stride in bytes for A.
100
+ * @param[in] r_stride_in_bytes Row stride in bytes for the result matrix.
101
+ */
102
+ NK_DYNAMIC void nk_euclideans_packed_f32(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
103
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
104
+ nk_size_t r_stride_in_bytes);
105
+
106
+ /**
107
+ * @brief Computes symmetric euclidean distance matrix (Gram-style) for a set of vectors.
108
+ * @param[in] vectors Input matrix of row vectors in row-major order.
109
+ * @param[in] n_vectors Number of vectors (rows) in the input matrix.
110
+ * @param[in] depth Dimension of each vector (columns).
111
+ * @param[in] stride Row stride in bytes for the input matrix.
112
+ * @param[out] result Output symmetric matrix (n_vectors x n_vectors).
113
+ * @param[in] result_stride Row stride in bytes for the result matrix.
114
+ * @param[in] row_start Starting row offset of results to compute (for parallelism).
115
+ * @param[in] row_count Number of rows of results to compute (for parallelism).
116
+ */
117
+ NK_DYNAMIC void nk_euclideans_symmetric_f32(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
118
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
119
+ nk_size_t row_start, nk_size_t row_count);
120
+
121
+ /** @copydoc nk_angulars_packed_f32 */
122
+ NK_DYNAMIC void nk_angulars_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
123
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
124
+ nk_size_t r_stride_in_bytes);
125
+ /** @copydoc nk_angulars_symmetric_f32 */
126
+ NK_DYNAMIC void nk_angulars_symmetric_f64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
127
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
128
+ nk_size_t row_start, nk_size_t row_count);
129
+ /** @copydoc nk_euclideans_packed_f32 */
130
+ NK_DYNAMIC void nk_euclideans_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
131
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
132
+ nk_size_t r_stride_in_bytes);
133
+ /** @copydoc nk_euclideans_symmetric_f32 */
134
+ NK_DYNAMIC void nk_euclideans_symmetric_f64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
135
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
136
+ nk_size_t row_start, nk_size_t row_count);
137
+
138
+ /** @copydoc nk_angulars_packed_f32 */
139
+ NK_DYNAMIC void nk_angulars_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
140
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
141
+ nk_size_t r_stride_in_bytes);
142
+ /** @copydoc nk_angulars_symmetric_f32 */
143
+ NK_DYNAMIC void nk_angulars_symmetric_f16(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
144
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
145
+ nk_size_t row_start, nk_size_t row_count);
146
+ /** @copydoc nk_euclideans_packed_f32 */
147
+ NK_DYNAMIC void nk_euclideans_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
148
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
149
+ nk_size_t r_stride_in_bytes);
150
+ /** @copydoc nk_euclideans_symmetric_f32 */
151
+ NK_DYNAMIC void nk_euclideans_symmetric_f16(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
152
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
153
+ nk_size_t row_start, nk_size_t row_count);
154
+
155
+ /** @copydoc nk_angulars_packed_f32 */
156
+ NK_DYNAMIC void nk_angulars_packed_bf16(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
157
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
158
+ nk_size_t r_stride_in_bytes);
159
+ /** @copydoc nk_angulars_symmetric_f32 */
160
+ NK_DYNAMIC void nk_angulars_symmetric_bf16(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
161
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
162
+ nk_size_t row_start, nk_size_t row_count);
163
+ /** @copydoc nk_euclideans_packed_f32 */
164
+ NK_DYNAMIC void nk_euclideans_packed_bf16(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
165
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
166
+ nk_size_t r_stride_in_bytes);
167
+ /** @copydoc nk_euclideans_symmetric_f32 */
168
+ NK_DYNAMIC void nk_euclideans_symmetric_bf16(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
169
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
170
+ nk_size_t row_start, nk_size_t row_count);
171
+
172
+ /** @copydoc nk_angulars_packed_f32 */
173
+ NK_DYNAMIC void nk_angulars_packed_e4m3(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
174
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
175
+ nk_size_t r_stride_in_bytes);
176
+ /** @copydoc nk_angulars_symmetric_f32 */
177
+ NK_DYNAMIC void nk_angulars_symmetric_e4m3(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
178
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
179
+ nk_size_t row_start, nk_size_t row_count);
180
+ /** @copydoc nk_euclideans_packed_f32 */
181
+ NK_DYNAMIC void nk_euclideans_packed_e4m3(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
182
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
183
+ nk_size_t r_stride_in_bytes);
184
+ /** @copydoc nk_euclideans_symmetric_f32 */
185
+ NK_DYNAMIC void nk_euclideans_symmetric_e4m3(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
186
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
187
+ nk_size_t row_start, nk_size_t row_count);
188
+
189
+ /** @copydoc nk_angulars_packed_f32 */
190
+ NK_DYNAMIC void nk_angulars_packed_e5m2(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
191
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
192
+ nk_size_t r_stride_in_bytes);
193
+ /** @copydoc nk_angulars_symmetric_f32 */
194
+ NK_DYNAMIC void nk_angulars_symmetric_e5m2(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
195
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
196
+ nk_size_t row_start, nk_size_t row_count);
197
+ /** @copydoc nk_euclideans_packed_f32 */
198
+ NK_DYNAMIC void nk_euclideans_packed_e5m2(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
199
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
200
+ nk_size_t r_stride_in_bytes);
201
+ /** @copydoc nk_euclideans_symmetric_f32 */
202
+ NK_DYNAMIC void nk_euclideans_symmetric_e5m2(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
203
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
204
+ nk_size_t row_start, nk_size_t row_count);
205
+
206
+ /** @copydoc nk_angulars_packed_f32 */
207
+ NK_DYNAMIC void nk_angulars_packed_e2m3(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
208
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
209
+ nk_size_t r_stride_in_bytes);
210
+ /** @copydoc nk_angulars_symmetric_f32 */
211
+ NK_DYNAMIC void nk_angulars_symmetric_e2m3(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
212
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
213
+ nk_size_t row_start, nk_size_t row_count);
214
+ /** @copydoc nk_euclideans_packed_f32 */
215
+ NK_DYNAMIC void nk_euclideans_packed_e2m3(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
216
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
217
+ nk_size_t r_stride_in_bytes);
218
+ /** @copydoc nk_euclideans_symmetric_f32 */
219
+ NK_DYNAMIC void nk_euclideans_symmetric_e2m3(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
220
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
221
+ nk_size_t row_start, nk_size_t row_count);
222
+
223
+ /** @copydoc nk_angulars_packed_f32 */
224
+ NK_DYNAMIC void nk_angulars_packed_e3m2(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
225
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
226
+ nk_size_t r_stride_in_bytes);
227
+ /** @copydoc nk_angulars_symmetric_f32 */
228
+ NK_DYNAMIC void nk_angulars_symmetric_e3m2(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
229
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
230
+ nk_size_t row_start, nk_size_t row_count);
231
+ /** @copydoc nk_euclideans_packed_f32 */
232
+ NK_DYNAMIC void nk_euclideans_packed_e3m2(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
233
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
234
+ nk_size_t r_stride_in_bytes);
235
+ /** @copydoc nk_euclideans_symmetric_f32 */
236
+ NK_DYNAMIC void nk_euclideans_symmetric_e3m2(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
237
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
238
+ nk_size_t row_start, nk_size_t row_count);
239
+
240
+ /** @copydoc nk_angulars_packed_f32 */
241
+ NK_DYNAMIC void nk_angulars_packed_i8(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
242
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
243
+ nk_size_t r_stride_in_bytes);
244
+ /** @copydoc nk_angulars_symmetric_f32 */
245
+ NK_DYNAMIC void nk_angulars_symmetric_i8(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
246
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
247
+ nk_size_t row_count);
248
+ /** @copydoc nk_euclideans_packed_f32 */
249
+ NK_DYNAMIC void nk_euclideans_packed_i8(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
250
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
251
+ nk_size_t r_stride_in_bytes);
252
+ /** @copydoc nk_euclideans_symmetric_f32 */
253
+ NK_DYNAMIC void nk_euclideans_symmetric_i8(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
254
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
255
+ nk_size_t row_start, nk_size_t row_count);
256
+
257
+ /** @copydoc nk_angulars_packed_f32 */
258
+ NK_DYNAMIC void nk_angulars_packed_u8(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
259
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
260
+ nk_size_t r_stride_in_bytes);
261
+ /** @copydoc nk_angulars_symmetric_f32 */
262
+ NK_DYNAMIC void nk_angulars_symmetric_u8(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
263
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
264
+ nk_size_t row_count);
265
+ /** @copydoc nk_euclideans_packed_f32 */
266
+ NK_DYNAMIC void nk_euclideans_packed_u8(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
267
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
268
+ nk_size_t r_stride_in_bytes);
269
+ /** @copydoc nk_euclideans_symmetric_f32 */
270
+ NK_DYNAMIC void nk_euclideans_symmetric_u8(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
271
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
272
+ nk_size_t row_start, nk_size_t row_count);
273
+
274
+ /** @copydoc nk_angulars_packed_f32
+  *  @note `nk_i4x2_t` inputs, `nk_f32_t` results. Judging by the type name, each `nk_i4x2_t` presumably holds
+  *        two signed 4-bit values; TODO confirm whether `depth` counts 4-bit values or packed elements.
+  */
275
+ NK_DYNAMIC void nk_angulars_packed_i4(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
276
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
277
+ nk_size_t r_stride_in_bytes);
278
+ /** @copydoc nk_angulars_symmetric_f32 */
279
+ NK_DYNAMIC void nk_angulars_symmetric_i4(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
280
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
281
+ nk_size_t row_start, nk_size_t row_count);
282
+ /** @copydoc nk_euclideans_packed_f32 */
283
+ NK_DYNAMIC void nk_euclideans_packed_i4(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
284
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
285
+ nk_size_t r_stride_in_bytes);
286
+ /** @copydoc nk_euclideans_symmetric_f32 */
287
+ NK_DYNAMIC void nk_euclideans_symmetric_i4(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
288
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
289
+ nk_size_t row_start, nk_size_t row_count);
290
+
291
+ /** @copydoc nk_angulars_packed_f32
+  *  @note `nk_u4x2_t` inputs, `nk_f32_t` results. Judging by the type name, each `nk_u4x2_t` presumably holds
+  *        two unsigned 4-bit values; TODO confirm whether `depth` counts 4-bit values or packed elements.
+  */
292
+ NK_DYNAMIC void nk_angulars_packed_u4(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
293
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
294
+ nk_size_t r_stride_in_bytes);
295
+ /** @copydoc nk_angulars_symmetric_f32 */
296
+ NK_DYNAMIC void nk_angulars_symmetric_u4(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
297
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
298
+ nk_size_t row_start, nk_size_t row_count);
299
+ /** @copydoc nk_euclideans_packed_f32 */
300
+ NK_DYNAMIC void nk_euclideans_packed_u4(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
301
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
302
+ nk_size_t r_stride_in_bytes);
303
+ /** @copydoc nk_euclideans_symmetric_f32 */
304
+ NK_DYNAMIC void nk_euclideans_symmetric_u4(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
305
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
306
+ nk_size_t row_start, nk_size_t row_count);
307
+
308
+ /** @copydoc nk_angulars_packed_f32
+  *  @note `_serial` backend for `nk_f32_t` inputs. The `result` buffer is typed `nk_f64_t` here, whereas the
+  *        f16/bf16/FP8/integer `_serial` declarations in this header use `nk_f32_t` results; size the
+  *        output accordingly. (`_serial` presumably denotes the non-SIMD fallback - TODO confirm.)
+  */
309
+ NK_PUBLIC void nk_angulars_packed_f32_serial(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
310
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
311
+ nk_size_t r_stride_in_bytes);
312
+ /** @copydoc nk_angulars_symmetric_f32 */
313
+ NK_PUBLIC void nk_angulars_symmetric_f32_serial(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
314
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
315
+ nk_size_t row_start, nk_size_t row_count);
316
+ /** @copydoc nk_euclideans_packed_f32 */
317
+ NK_PUBLIC void nk_euclideans_packed_f32_serial(nk_f32_t const *a, void const *b_packed, nk_f64_t *result,
318
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
319
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
320
+ /** @copydoc nk_euclideans_symmetric_f32 */
321
+ NK_PUBLIC void nk_euclideans_symmetric_f32_serial(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
322
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
323
+ nk_size_t row_start, nk_size_t row_count);
324
+
325
+ /** @copydoc nk_angulars_packed_f64
+  *  @note `_serial` backend for `nk_f64_t` inputs; `result` is also typed `nk_f64_t`.
+  *        `b_packed` is an opaque `void const *`; its layout is not visible from these declarations.
+  */
326
+ NK_PUBLIC void nk_angulars_packed_f64_serial(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
327
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
328
+ nk_size_t r_stride_in_bytes);
329
+ /** @copydoc nk_angulars_symmetric_f64 */
330
+ NK_PUBLIC void nk_angulars_symmetric_f64_serial(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
331
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
332
+ nk_size_t row_start, nk_size_t row_count);
333
+ /** @copydoc nk_euclideans_packed_f64 */
334
+ NK_PUBLIC void nk_euclideans_packed_f64_serial(nk_f64_t const *a, void const *b_packed, nk_f64_t *result,
335
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
336
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
337
+ /** @copydoc nk_euclideans_symmetric_f64 */
338
+ NK_PUBLIC void nk_euclideans_symmetric_f64_serial(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
339
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
340
+ nk_size_t row_start, nk_size_t row_count);
341
+
342
+ /** @copydoc nk_angulars_packed_f16
+  *  @note `_serial` backend for `nk_f16_t` inputs; `result` is typed `nk_f32_t`.
+  *        Packed variants take byte strides (`a_stride_in_bytes`, `r_stride_in_bytes`); the symmetric
+  *        variants use `stride`/`result_stride` with no unit in the name - TODO confirm they are bytes too.
+  */
343
+ NK_PUBLIC void nk_angulars_packed_f16_serial(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
344
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
345
+ nk_size_t r_stride_in_bytes);
346
+ /** @copydoc nk_angulars_symmetric_f16 */
347
+ NK_PUBLIC void nk_angulars_symmetric_f16_serial(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
348
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
349
+ nk_size_t row_start, nk_size_t row_count);
350
+ /** @copydoc nk_euclideans_packed_f16 */
351
+ NK_PUBLIC void nk_euclideans_packed_f16_serial(nk_f16_t const *a, void const *b_packed, nk_f32_t *result,
352
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
353
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
354
+ /** @copydoc nk_euclideans_symmetric_f16 */
355
+ NK_PUBLIC void nk_euclideans_symmetric_f16_serial(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
356
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
357
+ nk_size_t row_start, nk_size_t row_count);
358
+
359
+ /** @copydoc nk_angulars_packed_bf16
+  *  @note `_serial` backend for `nk_bf16_t` inputs; `result` is typed `nk_f32_t`.
+  *        Same parameter list as the `_genoa`, `_sapphireamx` and `_sme` bf16 declarations in this header.
+  */
360
+ NK_PUBLIC void nk_angulars_packed_bf16_serial(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
361
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
362
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
363
+ /** @copydoc nk_angulars_symmetric_bf16 */
364
+ NK_PUBLIC void nk_angulars_symmetric_bf16_serial(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
365
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
366
+ nk_size_t row_start, nk_size_t row_count);
367
+ /** @copydoc nk_euclideans_packed_bf16 */
368
+ NK_PUBLIC void nk_euclideans_packed_bf16_serial(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
369
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
370
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
371
+ /** @copydoc nk_euclideans_symmetric_bf16 */
372
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_serial(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
373
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
374
+ nk_size_t row_start, nk_size_t row_count);
375
+
376
+ /** @copydoc nk_angulars_packed_e4m3
+  *  @note `_serial` backend for `nk_e4m3_t` inputs; `result` is typed `nk_f32_t`.
+  *        Same parameter list as the `_genoa`, `_sapphireamx` and `_sme` e4m3 declarations in this header.
+  */
377
+ NK_PUBLIC void nk_angulars_packed_e4m3_serial(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
378
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
379
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
380
+ /** @copydoc nk_angulars_symmetric_e4m3 */
381
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_serial(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
382
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
383
+ nk_size_t row_start, nk_size_t row_count);
384
+ /** @copydoc nk_euclideans_packed_e4m3 */
385
+ NK_PUBLIC void nk_euclideans_packed_e4m3_serial(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
386
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
387
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
388
+ /** @copydoc nk_euclideans_symmetric_e4m3 */
389
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_serial(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
390
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
391
+ nk_size_t row_start, nk_size_t row_count);
392
+
393
+ /** @copydoc nk_angulars_packed_e5m2
+  *  @note `_serial` backend for `nk_e5m2_t` inputs; `result` is typed `nk_f32_t`.
+  *        Same parameter list as the `_genoa`, `_sapphireamx` and `_sme` e5m2 declarations in this header.
+  */
394
+ NK_PUBLIC void nk_angulars_packed_e5m2_serial(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
395
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
396
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
397
+ /** @copydoc nk_angulars_symmetric_e5m2 */
398
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_serial(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
399
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
400
+ nk_size_t row_start, nk_size_t row_count);
401
+ /** @copydoc nk_euclideans_packed_e5m2 */
402
+ NK_PUBLIC void nk_euclideans_packed_e5m2_serial(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
403
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
404
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
405
+ /** @copydoc nk_euclideans_symmetric_e5m2 */
406
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_serial(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
407
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
408
+ nk_size_t row_start, nk_size_t row_count);
409
+
410
+ /** @copydoc nk_angulars_packed_e2m3
+  *  @note `_serial` backend for `nk_e2m3_t` inputs; `result` is typed `nk_f32_t`.
+  *        Same parameter list as the `_sapphireamx` and `_sme` e2m3 declarations in this header.
+  */
411
+ NK_PUBLIC void nk_angulars_packed_e2m3_serial(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
412
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
413
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
414
+ /** @copydoc nk_angulars_symmetric_e2m3 */
415
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_serial(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
416
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
417
+ nk_size_t row_start, nk_size_t row_count);
418
+ /** @copydoc nk_euclideans_packed_e2m3 */
419
+ NK_PUBLIC void nk_euclideans_packed_e2m3_serial(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
420
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
421
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
422
+ /** @copydoc nk_euclideans_symmetric_e2m3 */
423
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_serial(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
424
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
425
+ nk_size_t row_start, nk_size_t row_count);
426
+
427
+ /** @copydoc nk_angulars_packed_e3m2
+  *  @note `_serial` backend for `nk_e3m2_t` inputs; `result` is typed `nk_f32_t`.
+  *        Same parameter list as the `_sapphireamx` e3m2 declarations in this header.
+  */
428
+ NK_PUBLIC void nk_angulars_packed_e3m2_serial(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result,
429
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
430
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
431
+ /** @copydoc nk_angulars_symmetric_e3m2 */
432
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_serial(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
433
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
434
+ nk_size_t row_start, nk_size_t row_count);
435
+ /** @copydoc nk_euclideans_packed_e3m2 */
436
+ NK_PUBLIC void nk_euclideans_packed_e3m2_serial(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result,
437
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
438
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
439
+ /** @copydoc nk_euclideans_symmetric_e3m2 */
440
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_serial(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
441
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
442
+ nk_size_t row_start, nk_size_t row_count);
443
+
444
+ /** @copydoc nk_angulars_packed_i8
+  *  @note `_serial` backend for `nk_i8_t` inputs; `result` is typed `nk_f32_t`.
+  *        Same parameter list as the un-suffixed `NK_DYNAMIC` i8 declarations and the `_sapphireamx` ones.
+  */
445
+ NK_PUBLIC void nk_angulars_packed_i8_serial(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
446
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
447
+ nk_size_t r_stride_in_bytes);
448
+ /** @copydoc nk_angulars_symmetric_i8 */
449
+ NK_PUBLIC void nk_angulars_symmetric_i8_serial(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
450
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
451
+ nk_size_t row_start, nk_size_t row_count);
452
+ /** @copydoc nk_euclideans_packed_i8 */
453
+ NK_PUBLIC void nk_euclideans_packed_i8_serial(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
454
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
455
+ nk_size_t r_stride_in_bytes);
456
+ /** @copydoc nk_euclideans_symmetric_i8 */
457
+ NK_PUBLIC void nk_euclideans_symmetric_i8_serial(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
458
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
459
+ nk_size_t row_start, nk_size_t row_count);
460
+
461
+ /** @copydoc nk_angulars_packed_u8
+  *  @note `_serial` backend for `nk_u8_t` inputs; `result` is typed `nk_f32_t`.
+  *        Same parameter list as the un-suffixed `NK_DYNAMIC` u8 declarations and the `_sapphireamx` ones.
+  */
462
+ NK_PUBLIC void nk_angulars_packed_u8_serial(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
463
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
464
+ nk_size_t r_stride_in_bytes);
465
+ /** @copydoc nk_angulars_symmetric_u8 */
466
+ NK_PUBLIC void nk_angulars_symmetric_u8_serial(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
467
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
468
+ nk_size_t row_start, nk_size_t row_count);
469
+ /** @copydoc nk_euclideans_packed_u8 */
470
+ NK_PUBLIC void nk_euclideans_packed_u8_serial(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
471
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
472
+ nk_size_t r_stride_in_bytes);
473
+ /** @copydoc nk_euclideans_symmetric_u8 */
474
+ NK_PUBLIC void nk_euclideans_symmetric_u8_serial(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
475
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
476
+ nk_size_t row_start, nk_size_t row_count);
477
+
478
+ /** @copydoc nk_angulars_packed_i4
+  *  @note `_serial` backend for `nk_i4x2_t` inputs; `result` is typed `nk_f32_t`.
+  *        Judging by the type name, each `nk_i4x2_t` presumably holds two signed 4-bit values;
+  *        TODO confirm whether `depth` counts 4-bit values or packed elements.
+  */
479
+ NK_PUBLIC void nk_angulars_packed_i4_serial(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
480
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
481
+ nk_size_t r_stride_in_bytes);
482
+ /** @copydoc nk_angulars_symmetric_i4 */
483
+ NK_PUBLIC void nk_angulars_symmetric_i4_serial(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
484
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
485
+ nk_size_t row_start, nk_size_t row_count);
486
+ /** @copydoc nk_euclideans_packed_i4 */
487
+ NK_PUBLIC void nk_euclideans_packed_i4_serial(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result,
488
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
489
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
490
+ /** @copydoc nk_euclideans_symmetric_i4 */
491
+ NK_PUBLIC void nk_euclideans_symmetric_i4_serial(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
492
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
493
+ nk_size_t row_start, nk_size_t row_count);
494
+
495
+ /** @copydoc nk_angulars_packed_u4
+  *  @note `_serial` backend for `nk_u4x2_t` inputs; `result` is typed `nk_f32_t`.
+  *        Judging by the type name, each `nk_u4x2_t` presumably holds two unsigned 4-bit values;
+  *        TODO confirm whether `depth` counts 4-bit values or packed elements.
+  */
496
+ NK_PUBLIC void nk_angulars_packed_u4_serial(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
497
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
498
+ nk_size_t r_stride_in_bytes);
499
+ /** @copydoc nk_angulars_symmetric_u4 */
500
+ NK_PUBLIC void nk_angulars_symmetric_u4_serial(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
501
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
502
+ nk_size_t row_start, nk_size_t row_count);
503
+ /** @copydoc nk_euclideans_packed_u4 */
504
+ NK_PUBLIC void nk_euclideans_packed_u4_serial(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result,
505
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
506
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
507
+ /** @copydoc nk_euclideans_symmetric_u4 */
508
+ NK_PUBLIC void nk_euclideans_symmetric_u4_serial(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
509
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
510
+ nk_size_t row_start, nk_size_t row_count);
511
+
512
+ /* Genoa backends using AVX-512 with BF16 extensions.
513
+ * These use VDPBF16PS for BF16 dot products.
514
+ * Packing interleaves elements for SIMD broadcast patterns.
515
+ */
516
+ #if NK_TARGET_GENOA
517
+ /** @copydoc nk_angulars_packed_bf16
+  *  @note `_genoa` backend for `nk_bf16_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_GENOA` evaluates non-zero (enclosing `#if`).
+  */
518
+ NK_PUBLIC void nk_angulars_packed_bf16_genoa(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
519
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
520
+ nk_size_t r_stride_in_bytes);
521
+ /** @copydoc nk_angulars_symmetric_bf16 */
522
+ NK_PUBLIC void nk_angulars_symmetric_bf16_genoa(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
523
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
524
+ nk_size_t row_start, nk_size_t row_count);
525
+ /** @copydoc nk_euclideans_packed_bf16 */
526
+ NK_PUBLIC void nk_euclideans_packed_bf16_genoa(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
527
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
528
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
529
+ /** @copydoc nk_euclideans_symmetric_bf16 */
530
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_genoa(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
531
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
532
+ nk_size_t row_start, nk_size_t row_count);
533
+
534
+ /** @copydoc nk_angulars_packed_e4m3
+  *  @note `_genoa` backend for `nk_e4m3_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_GENOA` evaluates non-zero (enclosing `#if`).
+  */
535
+ NK_PUBLIC void nk_angulars_packed_e4m3_genoa(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
536
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
537
+ nk_size_t r_stride_in_bytes);
538
+ /** @copydoc nk_angulars_symmetric_e4m3 */
539
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_genoa(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
540
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
541
+ nk_size_t row_start, nk_size_t row_count);
542
+ /** @copydoc nk_euclideans_packed_e4m3 */
543
+ NK_PUBLIC void nk_euclideans_packed_e4m3_genoa(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
544
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
545
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
546
+ /** @copydoc nk_euclideans_symmetric_e4m3 */
547
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_genoa(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
548
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
549
+ nk_size_t row_start, nk_size_t row_count);
550
+
551
+ /** @copydoc nk_angulars_packed_e5m2
+  *  @note `_genoa` backend for `nk_e5m2_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_GENOA` evaluates non-zero (enclosing `#if`).
+  */
552
+ NK_PUBLIC void nk_angulars_packed_e5m2_genoa(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
553
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
554
+ nk_size_t r_stride_in_bytes);
555
+ /** @copydoc nk_angulars_symmetric_e5m2 */
556
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_genoa(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
557
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
558
+ nk_size_t row_start, nk_size_t row_count);
559
+ /** @copydoc nk_euclideans_packed_e5m2 */
560
+ NK_PUBLIC void nk_euclideans_packed_e5m2_genoa(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
561
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
562
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
563
+ /** @copydoc nk_euclideans_symmetric_e5m2 */
564
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_genoa(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
565
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
566
+ nk_size_t row_start, nk_size_t row_count);
567
+
568
+ #endif // NK_TARGET_GENOA
569
+
570
+ /* Sapphire Rapids backends using Intel AMX (Advanced Matrix Extensions).
571
+ * AMX provides 8 tile registers (TMM0-TMM7), each holding up to 1KB of data.
572
+ * Tiles are configured as 16 rows x 64 bytes, enabling (16 x 32) BF16 or (16 x 64) INT8 tiles.
573
+ * Packing arranges data into AMX-native tile layout with pair interleaving for TDPBF16PS.
574
+ */
575
+ #if NK_TARGET_SAPPHIREAMX
576
+ /** @copydoc nk_angulars_packed_bf16
+  *  @note `_sapphireamx` backend for `nk_bf16_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SAPPHIREAMX` evaluates non-zero (enclosing `#if`).
+  */
577
+ NK_PUBLIC void nk_angulars_packed_bf16_sapphireamx(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
578
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
579
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
580
+ /** @copydoc nk_angulars_symmetric_bf16 */
581
+ NK_PUBLIC void nk_angulars_symmetric_bf16_sapphireamx(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
582
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
583
+ nk_size_t row_start, nk_size_t row_count);
584
+ /** @copydoc nk_euclideans_packed_bf16 */
585
+ NK_PUBLIC void nk_euclideans_packed_bf16_sapphireamx(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
586
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
587
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
588
+ /** @copydoc nk_euclideans_symmetric_bf16 */
589
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_sapphireamx(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
590
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
591
+ nk_size_t row_start, nk_size_t row_count);
592
+
593
+ /** @copydoc nk_angulars_packed_e4m3
+  *  @note `_sapphireamx` backend for `nk_e4m3_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SAPPHIREAMX` evaluates non-zero (enclosing `#if`).
+  */
594
+ NK_PUBLIC void nk_angulars_packed_e4m3_sapphireamx(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
595
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
596
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
597
+ /** @copydoc nk_angulars_symmetric_e4m3 */
598
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_sapphireamx(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
599
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
600
+ nk_size_t row_start, nk_size_t row_count);
601
+ /** @copydoc nk_euclideans_packed_e4m3 */
602
+ NK_PUBLIC void nk_euclideans_packed_e4m3_sapphireamx(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
603
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
604
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
605
+ /** @copydoc nk_euclideans_symmetric_e4m3 */
606
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_sapphireamx(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
607
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
608
+ nk_size_t row_start, nk_size_t row_count);
609
+
610
+ /** @copydoc nk_angulars_packed_e5m2
+  *  @note `_sapphireamx` backend for `nk_e5m2_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SAPPHIREAMX` evaluates non-zero (enclosing `#if`).
+  */
611
+ NK_PUBLIC void nk_angulars_packed_e5m2_sapphireamx(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
612
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
613
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
614
+ /** @copydoc nk_angulars_symmetric_e5m2 */
615
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_sapphireamx(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
616
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
617
+ nk_size_t row_start, nk_size_t row_count);
618
+ /** @copydoc nk_euclideans_packed_e5m2 */
619
+ NK_PUBLIC void nk_euclideans_packed_e5m2_sapphireamx(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
620
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
621
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
622
+ /** @copydoc nk_euclideans_symmetric_e5m2 */
623
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_sapphireamx(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
624
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
625
+ nk_size_t row_start, nk_size_t row_count);
626
+
627
+ /** @copydoc nk_angulars_packed_e2m3
+  *  @note `_sapphireamx` backend for `nk_e2m3_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SAPPHIREAMX` evaluates non-zero (enclosing `#if`).
+  */
628
+ NK_PUBLIC void nk_angulars_packed_e2m3_sapphireamx(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
629
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
630
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
631
+ /** @copydoc nk_angulars_symmetric_e2m3 */
632
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_sapphireamx(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
633
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
634
+ nk_size_t row_start, nk_size_t row_count);
635
+ /** @copydoc nk_euclideans_packed_e2m3 */
636
+ NK_PUBLIC void nk_euclideans_packed_e2m3_sapphireamx(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
637
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
638
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
639
+ /** @copydoc nk_euclideans_symmetric_e2m3 */
640
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_sapphireamx(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
641
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
642
+ nk_size_t row_start, nk_size_t row_count);
643
+
644
+ /** @copydoc nk_angulars_packed_e3m2
+  *  @note `_sapphireamx` backend for `nk_e3m2_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SAPPHIREAMX` evaluates non-zero (enclosing `#if`).
+  */
645
+ NK_PUBLIC void nk_angulars_packed_e3m2_sapphireamx(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result,
646
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
647
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
648
+ /** @copydoc nk_angulars_symmetric_e3m2 */
649
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_sapphireamx(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
650
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
651
+ nk_size_t row_start, nk_size_t row_count);
652
+ /** @copydoc nk_euclideans_packed_e3m2 */
653
+ NK_PUBLIC void nk_euclideans_packed_e3m2_sapphireamx(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result,
654
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
655
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
656
+ /** @copydoc nk_euclideans_symmetric_e3m2 */
657
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_sapphireamx(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
658
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
659
+ nk_size_t row_start, nk_size_t row_count);
660
+
661
+ /** @copydoc nk_angulars_packed_i8
+  *  @note `_sapphireamx` backend for `nk_i8_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SAPPHIREAMX` evaluates non-zero (enclosing `#if`).
+  */
662
+ NK_PUBLIC void nk_angulars_packed_i8_sapphireamx(nk_i8_t const *a, void const *b_packed, nk_f32_t *result,
663
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
664
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
665
+ /** @copydoc nk_angulars_symmetric_i8 */
666
+ NK_PUBLIC void nk_angulars_symmetric_i8_sapphireamx(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
667
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
668
+ nk_size_t row_start, nk_size_t row_count);
669
+ /** @copydoc nk_euclideans_packed_i8 */
670
+ NK_PUBLIC void nk_euclideans_packed_i8_sapphireamx(nk_i8_t const *a, void const *b_packed, nk_f32_t *result,
671
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
672
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
673
+ /** @copydoc nk_euclideans_symmetric_i8 */
674
+ NK_PUBLIC void nk_euclideans_symmetric_i8_sapphireamx(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
675
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
676
+ nk_size_t row_start, nk_size_t row_count);
677
+
678
+ /** @copydoc nk_angulars_packed_u8
+  *  @note `_sapphireamx` backend for `nk_u8_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SAPPHIREAMX` evaluates non-zero (enclosing `#if`).
+  */
679
+ NK_PUBLIC void nk_angulars_packed_u8_sapphireamx(nk_u8_t const *a, void const *b_packed, nk_f32_t *result,
680
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
681
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
682
+ /** @copydoc nk_angulars_symmetric_u8 */
683
+ NK_PUBLIC void nk_angulars_symmetric_u8_sapphireamx(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
684
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
685
+ nk_size_t row_start, nk_size_t row_count);
686
+ /** @copydoc nk_euclideans_packed_u8 */
687
+ NK_PUBLIC void nk_euclideans_packed_u8_sapphireamx(nk_u8_t const *a, void const *b_packed, nk_f32_t *result,
688
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
689
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
690
+ /** @copydoc nk_euclideans_symmetric_u8 */
691
+ NK_PUBLIC void nk_euclideans_symmetric_u8_sapphireamx(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
692
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
693
+ nk_size_t row_start, nk_size_t row_count);
694
+ #endif // NK_TARGET_SAPPHIREAMX
695
+
696
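+ /* ARM SME backends using the Scalable Matrix Extension.
697
+ * SME provides ZA tile registers for outer-product operations.
698
+ * F16/BF16/I8/U8/E4M3 use ZA32 tiles, F32/F64 use ZA64 tiles (FEAT_SME_F64F64).
+ * NOTE(review): E5M2, E2M3 and E3M2 SME kernels are also declared below but are not listed here;
+ * presumably they use ZA32 tiles like E4M3 - TODO confirm and extend the list.
699
+ */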
700
+ #if NK_TARGET_SME
701
+ /** @copydoc nk_angulars_packed_f16
+  *  @note `_sme` backend for `nk_f16_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SME` evaluates non-zero (enclosing `#if`).
+  */
702
+ NK_PUBLIC void nk_angulars_packed_f16_sme(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
703
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
704
+ nk_size_t r_stride_in_bytes);
705
+ /** @copydoc nk_angulars_symmetric_f16 */
706
+ NK_PUBLIC void nk_angulars_symmetric_f16_sme(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
707
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
708
+ nk_size_t row_start, nk_size_t row_count);
709
+ /** @copydoc nk_euclideans_packed_f16 */
710
+ NK_PUBLIC void nk_euclideans_packed_f16_sme(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
711
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
712
+ nk_size_t r_stride_in_bytes);
713
+ /** @copydoc nk_euclideans_symmetric_f16 */
714
+ NK_PUBLIC void nk_euclideans_symmetric_f16_sme(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
715
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
716
+ nk_size_t row_start, nk_size_t row_count);
717
+
718
+ /** @copydoc nk_angulars_packed_bf16
+  *  @note `_sme` backend for `nk_bf16_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SME` evaluates non-zero (enclosing `#if`).
+  */
719
+ NK_PUBLIC void nk_angulars_packed_bf16_sme(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
720
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
721
+ nk_size_t r_stride_in_bytes);
722
+ /** @copydoc nk_angulars_symmetric_bf16 */
723
+ NK_PUBLIC void nk_angulars_symmetric_bf16_sme(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
724
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
725
+ nk_size_t row_start, nk_size_t row_count);
726
+ /** @copydoc nk_euclideans_packed_bf16 */
727
+ NK_PUBLIC void nk_euclideans_packed_bf16_sme(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
728
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
729
+ nk_size_t r_stride_in_bytes);
730
+ /** @copydoc nk_euclideans_symmetric_bf16 */
731
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_sme(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
732
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
733
+ nk_size_t row_start, nk_size_t row_count);
734
+
735
+ /** @copydoc nk_angulars_packed_e4m3
+  *  @note `_sme` backend for `nk_e4m3_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SME` evaluates non-zero (enclosing `#if`).
+  */
736
+ NK_PUBLIC void nk_angulars_packed_e4m3_sme(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
737
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
738
+ nk_size_t r_stride_in_bytes);
739
+ /** @copydoc nk_angulars_symmetric_e4m3 */
740
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_sme(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
741
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
742
+ nk_size_t row_start, nk_size_t row_count);
743
+ /** @copydoc nk_euclideans_packed_e4m3 */
744
+ NK_PUBLIC void nk_euclideans_packed_e4m3_sme(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
745
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
746
+ nk_size_t r_stride_in_bytes);
747
+ /** @copydoc nk_euclideans_symmetric_e4m3 */
748
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
749
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
750
+ nk_size_t row_start, nk_size_t row_count);
751
+
752
+ /** @copydoc nk_angulars_packed_e5m2
+  *  @note `_sme` backend for `nk_e5m2_t` inputs; `result` is typed `nk_f32_t`.
+  *        Declared only when `NK_TARGET_SME` evaluates non-zero (enclosing `#if`).
+  */
753
+ NK_PUBLIC void nk_angulars_packed_e5m2_sme(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
754
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
755
+ nk_size_t r_stride_in_bytes);
756
+ /** @copydoc nk_angulars_symmetric_e5m2 */
757
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_sme(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
758
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
759
+ nk_size_t row_start, nk_size_t row_count);
760
+ /** @copydoc nk_euclideans_packed_e5m2 */
761
+ NK_PUBLIC void nk_euclideans_packed_e5m2_sme(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
762
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
763
+ nk_size_t r_stride_in_bytes);
764
+ /** @copydoc nk_euclideans_symmetric_e5m2 */
765
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
766
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
767
+ nk_size_t row_start, nk_size_t row_count);
768
+
769
+ /** @copydoc nk_angulars_packed_e2m3 */
770
+ NK_PUBLIC void nk_angulars_packed_e2m3_sme(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
771
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
772
+ nk_size_t r_stride_in_bytes);
773
+ /** @copydoc nk_angulars_symmetric_e2m3 */
774
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_sme(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
775
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
776
+ nk_size_t row_start, nk_size_t row_count);
777
+ /** @copydoc nk_euclideans_packed_e2m3 */
778
+ NK_PUBLIC void nk_euclideans_packed_e2m3_sme(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
779
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
780
+ nk_size_t r_stride_in_bytes);
781
+ /** @copydoc nk_euclideans_symmetric_e2m3 */
782
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
783
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
784
+ nk_size_t row_start, nk_size_t row_count);
785
+
786
+ /** @copydoc nk_angulars_packed_e3m2 */
787
+ NK_PUBLIC void nk_angulars_packed_e3m2_sme(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
788
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
789
+ nk_size_t r_stride_in_bytes);
790
+ /** @copydoc nk_angulars_symmetric_e3m2 */
791
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_sme(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
792
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
793
+ nk_size_t row_start, nk_size_t row_count);
794
+ /** @copydoc nk_euclideans_packed_e3m2 */
795
+ NK_PUBLIC void nk_euclideans_packed_e3m2_sme(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
796
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
797
+ nk_size_t r_stride_in_bytes);
798
+ /** @copydoc nk_euclideans_symmetric_e3m2 */
799
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
800
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
801
+ nk_size_t row_start, nk_size_t row_count);
802
+
803
+ /** @copydoc nk_angulars_packed_i8 */
804
+ NK_PUBLIC void nk_angulars_packed_i8_sme(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
805
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
806
+ nk_size_t r_stride_in_bytes);
807
+ /** @copydoc nk_angulars_symmetric_i8 */
808
+ NK_PUBLIC void nk_angulars_symmetric_i8_sme(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
809
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
810
+ nk_size_t row_start, nk_size_t row_count);
811
+ /** @copydoc nk_euclideans_packed_i8 */
812
+ NK_PUBLIC void nk_euclideans_packed_i8_sme(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
813
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
814
+ nk_size_t r_stride_in_bytes);
815
+ /** @copydoc nk_euclideans_symmetric_i8 */
816
+ NK_PUBLIC void nk_euclideans_symmetric_i8_sme(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
817
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
818
+ nk_size_t row_start, nk_size_t row_count);
819
+
820
+ /** @copydoc nk_angulars_packed_u8 */
821
+ NK_PUBLIC void nk_angulars_packed_u8_sme(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
822
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
823
+ nk_size_t r_stride_in_bytes);
824
+ /** @copydoc nk_angulars_symmetric_u8 */
825
+ NK_PUBLIC void nk_angulars_symmetric_u8_sme(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
826
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
827
+ nk_size_t row_start, nk_size_t row_count);
828
+ /** @copydoc nk_euclideans_packed_u8 */
829
+ NK_PUBLIC void nk_euclideans_packed_u8_sme(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
830
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
831
+ nk_size_t r_stride_in_bytes);
832
+ /** @copydoc nk_euclideans_symmetric_u8 */
833
+ NK_PUBLIC void nk_euclideans_symmetric_u8_sme(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
834
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
835
+ nk_size_t row_start, nk_size_t row_count);
836
+
837
+ /** @copydoc nk_angulars_packed_i4 */
838
+ NK_PUBLIC void nk_angulars_packed_i4_sme(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
839
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
840
+ nk_size_t r_stride_in_bytes);
841
+ /** @copydoc nk_angulars_symmetric_i4 */
842
+ NK_PUBLIC void nk_angulars_symmetric_i4_sme(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
843
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
844
+ nk_size_t row_start, nk_size_t row_count);
845
+ /** @copydoc nk_euclideans_packed_i4 */
846
+ NK_PUBLIC void nk_euclideans_packed_i4_sme(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
847
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
848
+ nk_size_t r_stride_in_bytes);
849
+ /** @copydoc nk_euclideans_symmetric_i4 */
850
+ NK_PUBLIC void nk_euclideans_symmetric_i4_sme(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
851
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
852
+ nk_size_t row_start, nk_size_t row_count);
853
+
854
+ /** @copydoc nk_angulars_packed_u4 */
855
+ NK_PUBLIC void nk_angulars_packed_u4_sme(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
856
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
857
+ nk_size_t r_stride_in_bytes);
858
+ /** @copydoc nk_angulars_symmetric_u4 */
859
+ NK_PUBLIC void nk_angulars_symmetric_u4_sme(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
860
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
861
+ nk_size_t row_start, nk_size_t row_count);
862
+ /** @copydoc nk_euclideans_packed_u4 */
863
+ NK_PUBLIC void nk_euclideans_packed_u4_sme(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
864
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
865
+ nk_size_t r_stride_in_bytes);
866
+ /** @copydoc nk_euclideans_symmetric_u4 */
867
+ NK_PUBLIC void nk_euclideans_symmetric_u4_sme(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
868
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
869
+ nk_size_t row_start, nk_size_t row_count);
870
+ #endif // NK_TARGET_SME
871
+
872
+ /* ARM SME with FEAT_SME_F64F64 (F32/F64 with F64 accumulators).
873
+ * Requires Apple M4 or equivalent with F64 outer product support.
874
+ */
875
+ #if NK_TARGET_SMEF64
876
+ /** @copydoc nk_angulars_packed_f32 */
877
+ NK_PUBLIC void nk_angulars_packed_f32_smef64(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
878
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
879
+ nk_size_t r_stride_in_bytes);
880
+ /** @copydoc nk_angulars_symmetric_f32 */
881
+ NK_PUBLIC void nk_angulars_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
882
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
883
+ nk_size_t row_start, nk_size_t row_count);
884
+ /** @copydoc nk_euclideans_packed_f32 */
885
+ NK_PUBLIC void nk_euclideans_packed_f32_smef64(nk_f32_t const *a, void const *b_packed, nk_f64_t *result,
886
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
887
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
888
+ /** @copydoc nk_euclideans_symmetric_f32 */
889
+ NK_PUBLIC void nk_euclideans_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
890
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
891
+ nk_size_t row_start, nk_size_t row_count);
892
+
893
+ /** @copydoc nk_angulars_packed_f64 */
894
+ NK_PUBLIC void nk_angulars_packed_f64_smef64(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
895
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
896
+ nk_size_t r_stride_in_bytes);
897
+ /** @copydoc nk_angulars_symmetric_f64 */
898
+ NK_PUBLIC void nk_angulars_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
899
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
900
+ nk_size_t row_start, nk_size_t row_count);
901
+ /** @copydoc nk_euclideans_packed_f64 */
902
+ NK_PUBLIC void nk_euclideans_packed_f64_smef64(nk_f64_t const *a, void const *b_packed, nk_f64_t *result,
903
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
904
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
905
+ /** @copydoc nk_euclideans_symmetric_f64 */
906
+ NK_PUBLIC void nk_euclideans_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
907
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
908
+ nk_size_t row_start, nk_size_t row_count);
909
+ #endif // NK_TARGET_SMEF64
910
+
911
+ /* Haswell backends using AVX2 (Intel Core 4th gen).
912
+ * Supports F32/F64 via FMA, F16/BF16/FP8 via software emulation, I8/U8 via VPMADDUBSW+VPADDD.
913
+ */
914
+ #if NK_TARGET_HASWELL
915
+ /** @copydoc nk_angulars_packed_f32 */
916
+ NK_PUBLIC void nk_angulars_packed_f32_haswell(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
917
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
918
+ nk_size_t r_stride_in_bytes);
919
+ /** @copydoc nk_angulars_symmetric_f32 */
920
+ NK_PUBLIC void nk_angulars_symmetric_f32_haswell(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
921
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
922
+ nk_size_t row_start, nk_size_t row_count);
923
+ /** @copydoc nk_euclideans_packed_f32 */
924
+ NK_PUBLIC void nk_euclideans_packed_f32_haswell(nk_f32_t const *a, void const *b_packed, nk_f64_t *result,
925
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
926
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
927
+ /** @copydoc nk_euclideans_symmetric_f32 */
928
+ NK_PUBLIC void nk_euclideans_symmetric_f32_haswell(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
929
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
930
+ nk_size_t row_start, nk_size_t row_count);
931
+
932
+ /** @copydoc nk_angulars_packed_f64 */
933
+ NK_PUBLIC void nk_angulars_packed_f64_haswell(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
934
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
935
+ nk_size_t r_stride_in_bytes);
936
+ /** @copydoc nk_angulars_symmetric_f64 */
937
+ NK_PUBLIC void nk_angulars_symmetric_f64_haswell(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
938
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
939
+ nk_size_t row_start, nk_size_t row_count);
940
+ /** @copydoc nk_euclideans_packed_f64 */
941
+ NK_PUBLIC void nk_euclideans_packed_f64_haswell(nk_f64_t const *a, void const *b_packed, nk_f64_t *result,
942
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
943
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
944
+ /** @copydoc nk_euclideans_symmetric_f64 */
945
+ NK_PUBLIC void nk_euclideans_symmetric_f64_haswell(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
946
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
947
+ nk_size_t row_start, nk_size_t row_count);
948
+
949
+ /** @copydoc nk_angulars_packed_f16 */
950
+ NK_PUBLIC void nk_angulars_packed_f16_haswell(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
951
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
952
+ nk_size_t r_stride_in_bytes);
953
+ /** @copydoc nk_angulars_symmetric_f16 */
954
+ NK_PUBLIC void nk_angulars_symmetric_f16_haswell(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
955
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
956
+ nk_size_t row_start, nk_size_t row_count);
957
+ /** @copydoc nk_euclideans_packed_f16 */
958
+ NK_PUBLIC void nk_euclideans_packed_f16_haswell(nk_f16_t const *a, void const *b_packed, nk_f32_t *result,
959
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
960
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
961
+ /** @copydoc nk_euclideans_symmetric_f16 */
962
+ NK_PUBLIC void nk_euclideans_symmetric_f16_haswell(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
963
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
964
+ nk_size_t row_start, nk_size_t row_count);
965
+
966
+ /** @copydoc nk_angulars_packed_bf16 */
967
+ NK_PUBLIC void nk_angulars_packed_bf16_haswell(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
968
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
969
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
970
+ /** @copydoc nk_angulars_symmetric_bf16 */
971
+ NK_PUBLIC void nk_angulars_symmetric_bf16_haswell(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
972
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
973
+ nk_size_t row_start, nk_size_t row_count);
974
+ /** @copydoc nk_euclideans_packed_bf16 */
975
+ NK_PUBLIC void nk_euclideans_packed_bf16_haswell(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
976
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
977
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
978
+ /** @copydoc nk_euclideans_symmetric_bf16 */
979
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_haswell(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
980
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
981
+ nk_size_t row_start, nk_size_t row_count);
982
+
983
+ /** @copydoc nk_angulars_packed_e4m3 */
984
+ NK_PUBLIC void nk_angulars_packed_e4m3_haswell(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
985
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
986
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
987
+ /** @copydoc nk_angulars_symmetric_e4m3 */
988
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_haswell(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
989
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
990
+ nk_size_t row_start, nk_size_t row_count);
991
+ /** @copydoc nk_euclideans_packed_e4m3 */
992
+ NK_PUBLIC void nk_euclideans_packed_e4m3_haswell(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
993
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
994
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
995
+ /** @copydoc nk_euclideans_symmetric_e4m3 */
996
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_haswell(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
997
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
998
+ nk_size_t row_start, nk_size_t row_count);
999
+
1000
+ /** @copydoc nk_angulars_packed_e5m2 */
1001
+ NK_PUBLIC void nk_angulars_packed_e5m2_haswell(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
1002
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1003
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1004
+ /** @copydoc nk_angulars_symmetric_e5m2 */
1005
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_haswell(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1006
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1007
+ nk_size_t row_start, nk_size_t row_count);
1008
+ /** @copydoc nk_euclideans_packed_e5m2 */
1009
+ NK_PUBLIC void nk_euclideans_packed_e5m2_haswell(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
1010
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1011
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1012
+ /** @copydoc nk_euclideans_symmetric_e5m2 */
1013
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_haswell(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1014
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1015
+ nk_size_t row_start, nk_size_t row_count);
1016
+
1017
+ /** @copydoc nk_angulars_packed_e2m3 */
1018
+ NK_PUBLIC void nk_angulars_packed_e2m3_haswell(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
1019
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1020
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1021
+ /** @copydoc nk_angulars_symmetric_e2m3 */
1022
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_haswell(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1023
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1024
+ nk_size_t row_start, nk_size_t row_count);
1025
+ /** @copydoc nk_euclideans_packed_e2m3 */
1026
+ NK_PUBLIC void nk_euclideans_packed_e2m3_haswell(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
1027
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1028
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1029
+ /** @copydoc nk_euclideans_symmetric_e2m3 */
1030
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_haswell(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1031
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1032
+ nk_size_t row_start, nk_size_t row_count);
1033
+
1034
+ /** @copydoc nk_angulars_packed_e3m2 */
1035
+ NK_PUBLIC void nk_angulars_packed_e3m2_haswell(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result,
1036
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1037
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1038
+ /** @copydoc nk_angulars_symmetric_e3m2 */
1039
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_haswell(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1040
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1041
+ nk_size_t row_start, nk_size_t row_count);
1042
+ /** @copydoc nk_euclideans_packed_e3m2 */
1043
+ NK_PUBLIC void nk_euclideans_packed_e3m2_haswell(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result,
1044
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1045
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1046
+ /** @copydoc nk_euclideans_symmetric_e3m2 */
1047
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_haswell(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1048
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1049
+ nk_size_t row_start, nk_size_t row_count);
1050
+ /** @copydoc nk_angulars_packed_i8 */
1051
+ NK_PUBLIC void nk_angulars_packed_i8_haswell(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1052
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1053
+ nk_size_t r_stride_in_bytes);
1054
+ /** @copydoc nk_angulars_symmetric_i8 */
1055
+ NK_PUBLIC void nk_angulars_symmetric_i8_haswell(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1056
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1057
+ nk_size_t row_start, nk_size_t row_count);
1058
+ /** @copydoc nk_euclideans_packed_i8 */
1059
+ NK_PUBLIC void nk_euclideans_packed_i8_haswell(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1060
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1061
+ nk_size_t r_stride_in_bytes);
1062
+ /** @copydoc nk_euclideans_symmetric_i8 */
1063
+ NK_PUBLIC void nk_euclideans_symmetric_i8_haswell(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1064
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1065
+ nk_size_t row_start, nk_size_t row_count);
1066
+ /** @copydoc nk_angulars_packed_u8 */
1067
+ NK_PUBLIC void nk_angulars_packed_u8_haswell(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1068
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1069
+ nk_size_t r_stride_in_bytes);
1070
+ /** @copydoc nk_angulars_symmetric_u8 */
1071
+ NK_PUBLIC void nk_angulars_symmetric_u8_haswell(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1072
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1073
+ nk_size_t row_start, nk_size_t row_count);
1074
+ /** @copydoc nk_euclideans_packed_u8 */
1075
+ NK_PUBLIC void nk_euclideans_packed_u8_haswell(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1076
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1077
+ nk_size_t r_stride_in_bytes);
1078
+ /** @copydoc nk_euclideans_symmetric_u8 */
1079
+ NK_PUBLIC void nk_euclideans_symmetric_u8_haswell(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1080
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1081
+ nk_size_t row_start, nk_size_t row_count);
1082
+ #endif // NK_TARGET_HASWELL
1083
+
1084
+ /* Skylake backends using AVX-512 (Intel Skylake-X/Skylake-SP and newer, not client 6th-gen Core).
1085
+ * Provides 512-bit vectors (16x f32, 8x f64), supporting F32/F64/F16/BF16/FP8 with FMA.
1086
+ */
1087
+ #if NK_TARGET_SKYLAKE
1088
+ /** @copydoc nk_angulars_packed_f32 */
1089
+ NK_PUBLIC void nk_angulars_packed_f32_skylake(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1090
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1091
+ nk_size_t r_stride_in_bytes);
1092
+ /** @copydoc nk_angulars_symmetric_f32 */
1093
+ NK_PUBLIC void nk_angulars_symmetric_f32_skylake(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1094
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1095
+ nk_size_t row_start, nk_size_t row_count);
1096
+ /** @copydoc nk_euclideans_packed_f32 */
1097
+ NK_PUBLIC void nk_euclideans_packed_f32_skylake(nk_f32_t const *a, void const *b_packed, nk_f64_t *result,
1098
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1099
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1100
+ /** @copydoc nk_euclideans_symmetric_f32 */
1101
+ NK_PUBLIC void nk_euclideans_symmetric_f32_skylake(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1102
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1103
+ nk_size_t row_start, nk_size_t row_count);
1104
+
1105
+ /** @copydoc nk_angulars_packed_f64 */
1106
+ NK_PUBLIC void nk_angulars_packed_f64_skylake(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1107
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1108
+ nk_size_t r_stride_in_bytes);
1109
+ /** @copydoc nk_angulars_symmetric_f64 */
1110
+ NK_PUBLIC void nk_angulars_symmetric_f64_skylake(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1111
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1112
+ nk_size_t row_start, nk_size_t row_count);
1113
+ /** @copydoc nk_euclideans_packed_f64 */
1114
+ NK_PUBLIC void nk_euclideans_packed_f64_skylake(nk_f64_t const *a, void const *b_packed, nk_f64_t *result,
1115
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1116
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1117
+ /** @copydoc nk_euclideans_symmetric_f64 */
1118
+ NK_PUBLIC void nk_euclideans_symmetric_f64_skylake(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1119
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1120
+ nk_size_t row_start, nk_size_t row_count);
1121
+
1122
+ /** @copydoc nk_angulars_packed_f16 */
1123
+ NK_PUBLIC void nk_angulars_packed_f16_skylake(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1124
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1125
+ nk_size_t r_stride_in_bytes);
1126
+ /** @copydoc nk_angulars_symmetric_f16 */
1127
+ NK_PUBLIC void nk_angulars_symmetric_f16_skylake(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1128
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1129
+ nk_size_t row_start, nk_size_t row_count);
1130
+ /** @copydoc nk_euclideans_packed_f16 */
1131
+ NK_PUBLIC void nk_euclideans_packed_f16_skylake(nk_f16_t const *a, void const *b_packed, nk_f32_t *result,
1132
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1133
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1134
+ /** @copydoc nk_euclideans_symmetric_f16 */
1135
+ NK_PUBLIC void nk_euclideans_symmetric_f16_skylake(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1136
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1137
+ nk_size_t row_start, nk_size_t row_count);
1138
+
1139
+ /** @copydoc nk_angulars_packed_bf16 */
1140
+ NK_PUBLIC void nk_angulars_packed_bf16_skylake(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
1141
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1142
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1143
+ /** @copydoc nk_angulars_symmetric_bf16 */
1144
+ NK_PUBLIC void nk_angulars_symmetric_bf16_skylake(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1145
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1146
+ nk_size_t row_start, nk_size_t row_count);
1147
+ /** @copydoc nk_euclideans_packed_bf16 */
1148
+ NK_PUBLIC void nk_euclideans_packed_bf16_skylake(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
1149
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1150
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1151
+ /** @copydoc nk_euclideans_symmetric_bf16 */
1152
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_skylake(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1153
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1154
+ nk_size_t row_start, nk_size_t row_count);
1155
+
1156
+ /** @copydoc nk_angulars_packed_e4m3 */
1157
+ NK_PUBLIC void nk_angulars_packed_e4m3_skylake(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
1158
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1159
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1160
+ /** @copydoc nk_angulars_symmetric_e4m3 */
1161
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_skylake(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1162
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1163
+ nk_size_t row_start, nk_size_t row_count);
1164
+ /** @copydoc nk_euclideans_packed_e4m3 */
1165
+ NK_PUBLIC void nk_euclideans_packed_e4m3_skylake(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
1166
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1167
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1168
+ /** @copydoc nk_euclideans_symmetric_e4m3 */
1169
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_skylake(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1170
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1171
+ nk_size_t row_start, nk_size_t row_count);
1172
+
1173
+ /** @copydoc nk_angulars_packed_e5m2 */
1174
+ NK_PUBLIC void nk_angulars_packed_e5m2_skylake(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
1175
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1176
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1177
+ /** @copydoc nk_angulars_symmetric_e5m2 */
1178
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_skylake(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1179
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1180
+ nk_size_t row_start, nk_size_t row_count);
1181
+ /** @copydoc nk_euclideans_packed_e5m2 */
1182
+ NK_PUBLIC void nk_euclideans_packed_e5m2_skylake(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
1183
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1184
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1185
+ /** @copydoc nk_euclideans_symmetric_e5m2 */
1186
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_skylake(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1187
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1188
+ nk_size_t row_start, nk_size_t row_count);
1189
+
1190
+ /** @copydoc nk_angulars_packed_e2m3 */
1191
+ NK_PUBLIC void nk_angulars_packed_e2m3_skylake(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
1192
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1193
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1194
+ /** @copydoc nk_angulars_symmetric_e2m3 */
1195
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_skylake(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1196
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1197
+ nk_size_t row_start, nk_size_t row_count);
1198
+ /** @copydoc nk_euclideans_packed_e2m3 */
1199
+ NK_PUBLIC void nk_euclideans_packed_e2m3_skylake(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
1200
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1201
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1202
+ /** @copydoc nk_euclideans_symmetric_e2m3 */
1203
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_skylake(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1204
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1205
+ nk_size_t row_start, nk_size_t row_count);
1206
+
1207
+ /** @copydoc nk_angulars_packed_e3m2 */
1208
+ NK_PUBLIC void nk_angulars_packed_e3m2_skylake(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result,
1209
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1210
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1211
+ /** @copydoc nk_angulars_symmetric_e3m2 */
1212
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_skylake(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1213
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1214
+ nk_size_t row_start, nk_size_t row_count);
1215
+ /** @copydoc nk_euclideans_packed_e3m2 */
1216
+ NK_PUBLIC void nk_euclideans_packed_e3m2_skylake(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result,
1217
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1218
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1219
+ /** @copydoc nk_euclideans_symmetric_e3m2 */
1220
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_skylake(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1221
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1222
+ nk_size_t row_start, nk_size_t row_count);
1223
+ #endif // NK_TARGET_SKYLAKE
1224
+
1225
+ /* Ice Lake backends using AVX-512 with VNNI (Vector Neural Network Instructions).
1226
+ * Adds VPDPBUSD-based dot products for I8/U8 and VPDPWSSD-based dot products for I4/U4.
1227
+ */
1228
+ #if NK_TARGET_ICELAKE
1229
+ /** @copydoc nk_angulars_packed_i8 */
1230
+ NK_PUBLIC void nk_angulars_packed_i8_icelake(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1231
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1232
+ nk_size_t r_stride_in_bytes);
1233
+ /** @copydoc nk_angulars_symmetric_i8 */
1234
+ NK_PUBLIC void nk_angulars_symmetric_i8_icelake(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1235
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1236
+ nk_size_t row_start, nk_size_t row_count);
1237
+ /** @copydoc nk_euclideans_packed_i8 */
1238
+ NK_PUBLIC void nk_euclideans_packed_i8_icelake(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1239
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1240
+ nk_size_t r_stride_in_bytes);
1241
+ /** @copydoc nk_euclideans_symmetric_i8 */
1242
+ NK_PUBLIC void nk_euclideans_symmetric_i8_icelake(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1243
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1244
+ nk_size_t row_start, nk_size_t row_count);
1245
+
1246
+ /** @copydoc nk_angulars_packed_u8 */
1247
+ NK_PUBLIC void nk_angulars_packed_u8_icelake(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1248
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1249
+ nk_size_t r_stride_in_bytes);
1250
+ /** @copydoc nk_angulars_symmetric_u8 */
1251
+ NK_PUBLIC void nk_angulars_symmetric_u8_icelake(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1252
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1253
+ nk_size_t row_start, nk_size_t row_count);
1254
+ /** @copydoc nk_euclideans_packed_u8 */
1255
+ NK_PUBLIC void nk_euclideans_packed_u8_icelake(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1256
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1257
+ nk_size_t r_stride_in_bytes);
1258
+ /** @copydoc nk_euclideans_symmetric_u8 */
1259
+ NK_PUBLIC void nk_euclideans_symmetric_u8_icelake(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1260
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1261
+ nk_size_t row_start, nk_size_t row_count);
1262
+
1263
+ /** @copydoc nk_angulars_packed_i4 */
1264
+ NK_PUBLIC void nk_angulars_packed_i4_icelake(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1265
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1266
+ nk_size_t r_stride_in_bytes);
1267
+ /** @copydoc nk_angulars_symmetric_i4 */
1268
+ NK_PUBLIC void nk_angulars_symmetric_i4_icelake(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1269
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1270
+ nk_size_t row_start, nk_size_t row_count);
1271
+ /** @copydoc nk_euclideans_packed_i4 */
1272
+ NK_PUBLIC void nk_euclideans_packed_i4_icelake(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result,
1273
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1274
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1275
+ /** @copydoc nk_euclideans_symmetric_i4 */
1276
+ NK_PUBLIC void nk_euclideans_symmetric_i4_icelake(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1277
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1278
+ nk_size_t row_start, nk_size_t row_count);
1279
+
1280
+ /** @copydoc nk_angulars_packed_u4 */
1281
+ NK_PUBLIC void nk_angulars_packed_u4_icelake(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1282
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1283
+ nk_size_t r_stride_in_bytes);
1284
+ /** @copydoc nk_angulars_symmetric_u4 */
1285
+ NK_PUBLIC void nk_angulars_symmetric_u4_icelake(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1286
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1287
+ nk_size_t row_start, nk_size_t row_count);
1288
+ /** @copydoc nk_euclideans_packed_u4 */
1289
+ NK_PUBLIC void nk_euclideans_packed_u4_icelake(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result,
1290
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1291
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1292
+ /** @copydoc nk_euclideans_symmetric_u4 */
1293
+ NK_PUBLIC void nk_euclideans_symmetric_u4_icelake(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1294
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1295
+ nk_size_t row_start, nk_size_t row_count);
1296
+ #endif // NK_TARGET_ICELAKE
1297
+
1298
+ #if NK_TARGET_ALDER
1299
+ /** @copydoc nk_angulars_packed_i8 */
1300
+ NK_PUBLIC void nk_angulars_packed_i8_alder(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1301
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1302
+ nk_size_t r_stride_in_bytes);
1303
+ /** @copydoc nk_angulars_symmetric_i8 */
1304
+ NK_PUBLIC void nk_angulars_symmetric_i8_alder(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1305
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1306
+ nk_size_t row_start, nk_size_t row_count);
1307
+ /** @copydoc nk_euclideans_packed_i8 */
1308
+ NK_PUBLIC void nk_euclideans_packed_i8_alder(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1309
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1310
+ nk_size_t r_stride_in_bytes);
1311
+ /** @copydoc nk_euclideans_symmetric_i8 */
1312
+ NK_PUBLIC void nk_euclideans_symmetric_i8_alder(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1313
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1314
+ nk_size_t row_start, nk_size_t row_count);
1315
+ /** @copydoc nk_angulars_packed_u8 */
1316
+ NK_PUBLIC void nk_angulars_packed_u8_alder(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1317
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1318
+ nk_size_t r_stride_in_bytes);
1319
+ /** @copydoc nk_angulars_symmetric_u8 */
1320
+ NK_PUBLIC void nk_angulars_symmetric_u8_alder(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1321
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1322
+ nk_size_t row_start, nk_size_t row_count);
1323
+ /** @copydoc nk_euclideans_packed_u8 */
1324
+ NK_PUBLIC void nk_euclideans_packed_u8_alder(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1325
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1326
+ nk_size_t r_stride_in_bytes);
1327
+ /** @copydoc nk_euclideans_symmetric_u8 */
1328
+ NK_PUBLIC void nk_euclideans_symmetric_u8_alder(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1329
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1330
+ nk_size_t row_start, nk_size_t row_count);
1331
+ /** @copydoc nk_angulars_packed_e2m3 */
1332
+ NK_PUBLIC void nk_angulars_packed_e2m3_alder(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1333
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1334
+ nk_size_t r_stride_in_bytes);
1335
+ /** @copydoc nk_angulars_symmetric_e2m3 */
1336
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_alder(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1337
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1338
+ nk_size_t row_start, nk_size_t row_count);
1339
+ /** @copydoc nk_euclideans_packed_e2m3 */
1340
+ NK_PUBLIC void nk_euclideans_packed_e2m3_alder(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
1341
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1342
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1343
+ /** @copydoc nk_euclideans_symmetric_e2m3 */
1344
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_alder(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1345
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1346
+ nk_size_t row_start, nk_size_t row_count);
1347
+ #endif // NK_TARGET_ALDER
1348
+
1349
+ /* Sierra backends using AVX10.2 with VMPSADBW.
1350
+ * Optimized for I8/U8 via VMPSADBW (multiple packed sums of absolute differences); E2M3 kernels are declared here too.
1351
+ */
1352
+ #if NK_TARGET_SIERRA
1353
+ /** @copydoc nk_angulars_packed_i8 */
1354
+ NK_PUBLIC void nk_angulars_packed_i8_sierra(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1355
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1356
+ nk_size_t r_stride_in_bytes);
1357
+ /** @copydoc nk_angulars_symmetric_i8 */
1358
+ NK_PUBLIC void nk_angulars_symmetric_i8_sierra(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1359
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1360
+ nk_size_t row_start, nk_size_t row_count);
1361
+ /** @copydoc nk_euclideans_packed_i8 */
1362
+ NK_PUBLIC void nk_euclideans_packed_i8_sierra(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1363
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1364
+ nk_size_t r_stride_in_bytes);
1365
+ /** @copydoc nk_euclideans_symmetric_i8 */
1366
+ NK_PUBLIC void nk_euclideans_symmetric_i8_sierra(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1367
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1368
+ nk_size_t row_start, nk_size_t row_count);
1369
+ /** @copydoc nk_angulars_packed_u8 */
1370
+ NK_PUBLIC void nk_angulars_packed_u8_sierra(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1371
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1372
+ nk_size_t r_stride_in_bytes);
1373
+ /** @copydoc nk_angulars_symmetric_u8 */
1374
+ NK_PUBLIC void nk_angulars_symmetric_u8_sierra(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1375
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1376
+ nk_size_t row_start, nk_size_t row_count);
1377
+ /** @copydoc nk_euclideans_packed_u8 */
1378
+ NK_PUBLIC void nk_euclideans_packed_u8_sierra(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1379
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1380
+ nk_size_t r_stride_in_bytes);
1381
+ /** @copydoc nk_euclideans_symmetric_u8 */
1382
+ NK_PUBLIC void nk_euclideans_symmetric_u8_sierra(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1383
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1384
+ nk_size_t row_start, nk_size_t row_count);
1385
+ /** @copydoc nk_angulars_packed_e2m3 */
1386
+ NK_PUBLIC void nk_angulars_packed_e2m3_sierra(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
1387
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1388
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1389
+ /** @copydoc nk_angulars_symmetric_e2m3 */
1390
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_sierra(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1391
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1392
+ nk_size_t row_start, nk_size_t row_count);
1393
+ /** @copydoc nk_euclideans_packed_e2m3 */
1394
+ NK_PUBLIC void nk_euclideans_packed_e2m3_sierra(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
1395
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1396
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1397
+ /** @copydoc nk_euclideans_symmetric_e2m3 */
1398
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_sierra(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1399
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1400
+ nk_size_t row_start, nk_size_t row_count);
1401
+ #endif // NK_TARGET_SIERRA
1402
+
1403
+ /* WASM Relaxed SIMD backends for angular/euclidean distances.
1404
+ * Covers I8/U8/E2M3/E4M3/E5M2/BF16/F32/F64 spatial distance operations.
1405
+ */
1406
+ #if NK_TARGET_V128RELAXED
1407
+ /** @copydoc nk_angulars_packed_i8 */
1408
+ NK_PUBLIC void nk_angulars_packed_i8_v128relaxed(nk_i8_t const *a, void const *b_packed, nk_f32_t *result,
1409
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1410
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1411
+ /** @copydoc nk_angulars_symmetric_i8 */
1412
+ NK_PUBLIC void nk_angulars_symmetric_i8_v128relaxed(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1413
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1414
+ nk_size_t row_start, nk_size_t row_count);
1415
+ /** @copydoc nk_euclideans_packed_i8 */
1416
+ NK_PUBLIC void nk_euclideans_packed_i8_v128relaxed(nk_i8_t const *a, void const *b_packed, nk_f32_t *result,
1417
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1418
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1419
+ /** @copydoc nk_euclideans_symmetric_i8 */
1420
+ NK_PUBLIC void nk_euclideans_symmetric_i8_v128relaxed(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1421
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1422
+ nk_size_t row_start, nk_size_t row_count);
1423
+ /** @copydoc nk_angulars_packed_u8 */
1424
+ NK_PUBLIC void nk_angulars_packed_u8_v128relaxed(nk_u8_t const *a, void const *b_packed, nk_f32_t *result,
1425
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1426
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1427
+ /** @copydoc nk_angulars_symmetric_u8 */
1428
+ NK_PUBLIC void nk_angulars_symmetric_u8_v128relaxed(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1429
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1430
+ nk_size_t row_start, nk_size_t row_count);
1431
+ /** @copydoc nk_euclideans_packed_u8 */
1432
+ NK_PUBLIC void nk_euclideans_packed_u8_v128relaxed(nk_u8_t const *a, void const *b_packed, nk_f32_t *result,
1433
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1434
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1435
+ /** @copydoc nk_euclideans_symmetric_u8 */
1436
+ NK_PUBLIC void nk_euclideans_symmetric_u8_v128relaxed(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1437
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1438
+ nk_size_t row_start, nk_size_t row_count);
1439
+ /** @copydoc nk_angulars_packed_e2m3 */
1440
+ NK_PUBLIC void nk_angulars_packed_e2m3_v128relaxed(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
1441
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1442
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1443
+ /** @copydoc nk_angulars_symmetric_e2m3 */
1444
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_v128relaxed(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1445
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1446
+ nk_size_t row_start, nk_size_t row_count);
1447
+ /** @copydoc nk_euclideans_packed_e2m3 */
1448
+ NK_PUBLIC void nk_euclideans_packed_e2m3_v128relaxed(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result,
1449
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1450
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1451
+ /** @copydoc nk_euclideans_symmetric_e2m3 */
1452
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_v128relaxed(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1453
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1454
+ nk_size_t row_start, nk_size_t row_count);
1455
+ /** @copydoc nk_angulars_packed_e4m3 */
1456
+ NK_PUBLIC void nk_angulars_packed_e4m3_v128relaxed(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
1457
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1458
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1459
+ /** @copydoc nk_angulars_symmetric_e4m3 */
1460
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_v128relaxed(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1461
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1462
+ nk_size_t row_start, nk_size_t row_count);
1463
+ /** @copydoc nk_euclideans_packed_e4m3 */
1464
+ NK_PUBLIC void nk_euclideans_packed_e4m3_v128relaxed(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
1465
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1466
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1467
+ /** @copydoc nk_euclideans_symmetric_e4m3 */
1468
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_v128relaxed(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1469
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1470
+ nk_size_t row_start, nk_size_t row_count);
1471
+ /** @copydoc nk_angulars_packed_e5m2 */
1472
+ NK_PUBLIC void nk_angulars_packed_e5m2_v128relaxed(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
1473
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1474
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1475
+ /** @copydoc nk_angulars_symmetric_e5m2 */
1476
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_v128relaxed(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1477
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1478
+ nk_size_t row_start, nk_size_t row_count);
1479
+ /** @copydoc nk_euclideans_packed_e5m2 */
1480
+ NK_PUBLIC void nk_euclideans_packed_e5m2_v128relaxed(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
1481
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1482
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1483
+ /** @copydoc nk_euclideans_symmetric_e5m2 */
1484
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_v128relaxed(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1485
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1486
+ nk_size_t row_start, nk_size_t row_count);
1487
+ /** @copydoc nk_angulars_packed_bf16 */
1488
+ NK_PUBLIC void nk_angulars_packed_bf16_v128relaxed(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
1489
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1490
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1491
+ /** @copydoc nk_angulars_symmetric_bf16 */
1492
+ NK_PUBLIC void nk_angulars_symmetric_bf16_v128relaxed(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1493
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1494
+ nk_size_t row_start, nk_size_t row_count);
1495
+ /** @copydoc nk_euclideans_packed_bf16 */
1496
+ NK_PUBLIC void nk_euclideans_packed_bf16_v128relaxed(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
1497
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1498
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1499
+ /** @copydoc nk_euclideans_symmetric_bf16 */
1500
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_v128relaxed(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1501
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1502
+ nk_size_t row_start, nk_size_t row_count);
1503
+ /** @copydoc nk_angulars_packed_f32 */
1504
+ NK_PUBLIC void nk_angulars_packed_f32_v128relaxed(nk_f32_t const *a, void const *b_packed, nk_f64_t *result,
1505
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1506
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1507
+ /** @copydoc nk_angulars_symmetric_f32 */
1508
+ NK_PUBLIC void nk_angulars_symmetric_f32_v128relaxed(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1509
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1510
+ nk_size_t row_start, nk_size_t row_count);
1511
+ /** @copydoc nk_euclideans_packed_f32 */
1512
+ NK_PUBLIC void nk_euclideans_packed_f32_v128relaxed(nk_f32_t const *a, void const *b_packed, nk_f64_t *result,
1513
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1514
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1515
+ /** @copydoc nk_euclideans_symmetric_f32 */
1516
+ NK_PUBLIC void nk_euclideans_symmetric_f32_v128relaxed(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1517
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1518
+ nk_size_t row_start, nk_size_t row_count);
1519
+ /** @copydoc nk_angulars_packed_f64 */
1520
+ NK_PUBLIC void nk_angulars_packed_f64_v128relaxed(nk_f64_t const *a, void const *b_packed, nk_f64_t *result,
1521
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1522
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1523
+ /** @copydoc nk_angulars_symmetric_f64 */
1524
+ NK_PUBLIC void nk_angulars_symmetric_f64_v128relaxed(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1525
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1526
+ nk_size_t row_start, nk_size_t row_count);
1527
+ /** @copydoc nk_euclideans_packed_f64 */
1528
+ NK_PUBLIC void nk_euclideans_packed_f64_v128relaxed(nk_f64_t const *a, void const *b_packed, nk_f64_t *result,
1529
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1530
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1531
+ /** @copydoc nk_euclideans_symmetric_f64 */
1532
+ NK_PUBLIC void nk_euclideans_symmetric_f64_v128relaxed(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1533
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1534
+ nk_size_t row_start, nk_size_t row_count);
1535
+ #endif // NK_TARGET_V128RELAXED
1536
+
1537
+ /* ARM NEON backends (base NEON with F32/F64 support).
1538
+ * Uses FMLA for F32 dots, FMLA (scalar) for F64; BF16 and F16 kernels are declared here as well.
1539
+ */
1540
+ #if NK_TARGET_NEON
1541
+ /** @copydoc nk_angulars_packed_f32 */
1542
+ NK_PUBLIC void nk_angulars_packed_f32_neon(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1543
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1544
+ nk_size_t r_stride_in_bytes);
1545
+ /** @copydoc nk_angulars_symmetric_f32 */
1546
+ NK_PUBLIC void nk_angulars_symmetric_f32_neon(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1547
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1548
+ nk_size_t row_start, nk_size_t row_count);
1549
+ /** @copydoc nk_euclideans_packed_f32 */
1550
+ NK_PUBLIC void nk_euclideans_packed_f32_neon(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1551
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1552
+ nk_size_t r_stride_in_bytes);
1553
+ /** @copydoc nk_euclideans_symmetric_f32 */
1554
+ NK_PUBLIC void nk_euclideans_symmetric_f32_neon(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1555
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1556
+ nk_size_t row_start, nk_size_t row_count);
1557
+
1558
+ /** @copydoc nk_angulars_packed_f64 */
1559
+ NK_PUBLIC void nk_angulars_packed_f64_neon(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1560
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1561
+ nk_size_t r_stride_in_bytes);
1562
+ /** @copydoc nk_angulars_symmetric_f64 */
1563
+ NK_PUBLIC void nk_angulars_symmetric_f64_neon(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1564
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1565
+ nk_size_t row_start, nk_size_t row_count);
1566
+ /** @copydoc nk_euclideans_packed_f64 */
1567
+ NK_PUBLIC void nk_euclideans_packed_f64_neon(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1568
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1569
+ nk_size_t r_stride_in_bytes);
1570
+ /** @copydoc nk_euclideans_symmetric_f64 */
1571
+ NK_PUBLIC void nk_euclideans_symmetric_f64_neon(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1572
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1573
+ nk_size_t row_start, nk_size_t row_count);
1574
+ /** @copydoc nk_angulars_packed_bf16 */
1575
+ NK_PUBLIC void nk_angulars_packed_bf16_neon(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1576
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1577
+ nk_size_t r_stride_in_bytes);
1578
+ /** @copydoc nk_angulars_symmetric_bf16 */
1579
+ NK_PUBLIC void nk_angulars_symmetric_bf16_neon(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1580
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1581
+ nk_size_t row_start, nk_size_t row_count);
1582
+ /** @copydoc nk_euclideans_packed_bf16 */
1583
+ NK_PUBLIC void nk_euclideans_packed_bf16_neon(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
1584
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1585
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1586
+ /** @copydoc nk_euclideans_symmetric_bf16 */
1587
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_neon(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1588
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1589
+ nk_size_t row_start, nk_size_t row_count);
1590
+ /** @copydoc nk_angulars_packed_f16 */
1591
+ NK_PUBLIC void nk_angulars_packed_f16_neon(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1592
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1593
+ nk_size_t r_stride_in_bytes);
1594
+ /** @copydoc nk_angulars_symmetric_f16 */
1595
+ NK_PUBLIC void nk_angulars_symmetric_f16_neon(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1596
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1597
+ nk_size_t row_start, nk_size_t row_count);
1598
+ /** @copydoc nk_euclideans_packed_f16 */
1599
+ NK_PUBLIC void nk_euclideans_packed_f16_neon(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1600
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1601
+ nk_size_t r_stride_in_bytes);
1602
+ /** @copydoc nk_euclideans_symmetric_f16 */
1603
+ NK_PUBLIC void nk_euclideans_symmetric_f16_neon(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1604
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1605
+ nk_size_t row_start, nk_size_t row_count);
1606
+ #endif // NK_TARGET_NEON
1607
+
1608
+ /* ARM NEON with F16 arithmetic (ARMv8.2-A FP16).
1609
+ * Provides native F16 FMLA for half-precision dot products.
1610
+ */
1611
+ #if NK_TARGET_NEONHALF
1612
+ /** @copydoc nk_angulars_packed_f16 */
1613
+ NK_PUBLIC void nk_angulars_packed_f16_neonhalf(nk_f16_t const *a, void const *b_packed, nk_f32_t *result,
1614
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1615
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1616
+ /** @copydoc nk_angulars_symmetric_f16 */
1617
+ NK_PUBLIC void nk_angulars_symmetric_f16_neonhalf(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1618
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1619
+ nk_size_t row_start, nk_size_t row_count);
1620
+ /** @copydoc nk_euclideans_packed_f16 */
1621
+ NK_PUBLIC void nk_euclideans_packed_f16_neonhalf(nk_f16_t const *a, void const *b_packed, nk_f32_t *result,
1622
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1623
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1624
+ /** @copydoc nk_euclideans_symmetric_f16 */
1625
+ NK_PUBLIC void nk_euclideans_symmetric_f16_neonhalf(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1626
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1627
+ nk_size_t row_start, nk_size_t row_count);
1628
+ #endif // NK_TARGET_NEONHALF
1629
+
1630
+ /* ARM NEON with BF16 dot product (ARMv8.6-A BF16).
1631
+ * Uses BFDOT/BFMMLA for efficient BF16 matrix operations.
1632
+ */
1633
+ #if NK_TARGET_NEONBFDOT
1634
+ /** @copydoc nk_angulars_packed_bf16 */
1635
+ NK_PUBLIC void nk_angulars_packed_bf16_neonbfdot(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
1636
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1637
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1638
+ /** @copydoc nk_angulars_symmetric_bf16 */
1639
+ NK_PUBLIC void nk_angulars_symmetric_bf16_neonbfdot(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1640
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1641
+ nk_size_t row_start, nk_size_t row_count);
1642
+ /** @copydoc nk_euclideans_packed_bf16 */
1643
+ NK_PUBLIC void nk_euclideans_packed_bf16_neonbfdot(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result,
1644
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1645
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1646
+ /** @copydoc nk_euclideans_symmetric_bf16 */
1647
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_neonbfdot(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1648
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1649
+ nk_size_t row_start, nk_size_t row_count);
1650
+ #endif // NK_TARGET_NEONBFDOT
1651
+
1652
+ /* ARM NEON with signed/unsigned dot product (ARMv8.2-A DotProd).
1653
+ * Provides SDOT/UDOT for I8/U8 vector dot products; the I4/U4 kernels are declared here as well.
1654
+ */
1655
+ #if NK_TARGET_NEONSDOT
1656
+ /** @copydoc nk_angulars_packed_i8 */
1657
+ NK_PUBLIC void nk_angulars_packed_i8_neonsdot(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1658
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1659
+ nk_size_t r_stride_in_bytes);
1660
+ /** @copydoc nk_angulars_symmetric_i8 */
1661
+ NK_PUBLIC void nk_angulars_symmetric_i8_neonsdot(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1662
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1663
+ nk_size_t row_start, nk_size_t row_count);
1664
+ /** @copydoc nk_euclideans_packed_i8 */
1665
+ NK_PUBLIC void nk_euclideans_packed_i8_neonsdot(nk_i8_t const *a, void const *b_packed, nk_f32_t *result,
1666
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1667
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1668
+ /** @copydoc nk_euclideans_symmetric_i8 */
1669
+ NK_PUBLIC void nk_euclideans_symmetric_i8_neonsdot(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1670
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1671
+ nk_size_t row_start, nk_size_t row_count);
1672
+
1673
+ /** @copydoc nk_angulars_packed_u8 */
1674
+ NK_PUBLIC void nk_angulars_packed_u8_neonsdot(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1675
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1676
+ nk_size_t r_stride_in_bytes);
1677
+ /** @copydoc nk_angulars_symmetric_u8 */
1678
+ NK_PUBLIC void nk_angulars_symmetric_u8_neonsdot(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1679
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1680
+ nk_size_t row_start, nk_size_t row_count);
1681
+ /** @copydoc nk_euclideans_packed_u8 */
1682
+ NK_PUBLIC void nk_euclideans_packed_u8_neonsdot(nk_u8_t const *a, void const *b_packed, nk_f32_t *result,
1683
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1684
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1685
+ /** @copydoc nk_euclideans_symmetric_u8 */
1686
+ NK_PUBLIC void nk_euclideans_symmetric_u8_neonsdot(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1687
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1688
+ nk_size_t row_start, nk_size_t row_count);
1689
+
1690
+ /** @copydoc nk_angulars_packed_i4 */
1691
+ NK_PUBLIC void nk_angulars_packed_i4_neonsdot(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result,
1692
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1693
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1694
+ /** @copydoc nk_angulars_symmetric_i4 */
1695
+ NK_PUBLIC void nk_angulars_symmetric_i4_neonsdot(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1696
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1697
+ nk_size_t row_start, nk_size_t row_count);
1698
+ /** @copydoc nk_euclideans_packed_i4 */
1699
+ NK_PUBLIC void nk_euclideans_packed_i4_neonsdot(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result,
1700
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1701
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1702
+ /** @copydoc nk_euclideans_symmetric_i4 */
1703
+ NK_PUBLIC void nk_euclideans_symmetric_i4_neonsdot(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1704
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1705
+ nk_size_t row_start, nk_size_t row_count);
1706
+
1707
+ /** @copydoc nk_angulars_packed_u4 */
1708
+ NK_PUBLIC void nk_angulars_packed_u4_neonsdot(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result,
1709
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1710
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1711
+ /** @copydoc nk_angulars_symmetric_u4 */
1712
+ NK_PUBLIC void nk_angulars_symmetric_u4_neonsdot(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1713
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1714
+ nk_size_t row_start, nk_size_t row_count);
1715
+ /** @copydoc nk_euclideans_packed_u4 */
1716
+ NK_PUBLIC void nk_euclideans_packed_u4_neonsdot(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result,
1717
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1718
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
1719
+ /** @copydoc nk_euclideans_symmetric_u4 */
1720
+ NK_PUBLIC void nk_euclideans_symmetric_u4_neonsdot(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1721
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1722
+ nk_size_t row_start, nk_size_t row_count);
1723
+ #endif // NK_TARGET_NEONSDOT
1724
+
1725
+ /* ARM NEON with FP16 FML (fused multiply-long, ARMv8.2-A FP16FML).
1726
+ * Uses FMLAL/FMLSL for F16 and custom FP8 (E4M3/E5M2) operations.
1727
+ */
1728
+ #if NK_TARGET_NEONFHM
1729
+ /** @copydoc nk_angulars_packed_f16 */ // f16 group of the NK_TARGET_NEONFHM prototypes: nk_f16_t inputs, nk_f32_t results.
1730
+ NK_PUBLIC void nk_angulars_packed_f16_neonfhm(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1731
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1732
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1733
+ /** @copydoc nk_angulars_symmetric_f16 */
1734
+ NK_PUBLIC void nk_angulars_symmetric_f16_neonfhm(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1735
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1736
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1737
+ /** @copydoc nk_euclideans_packed_f16 */
1738
+ NK_PUBLIC void nk_euclideans_packed_f16_neonfhm(nk_f16_t const *a, void const *b_packed, nk_f32_t *result,
1739
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1740
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1741
+ /** @copydoc nk_euclideans_symmetric_f16 */
1742
+ NK_PUBLIC void nk_euclideans_symmetric_f16_neonfhm(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1743
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1744
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1745
+
1746
+ /** @copydoc nk_angulars_packed_e4m3 */ // e4m3 group of the NK_TARGET_NEONFHM prototypes: nk_e4m3_t inputs, nk_f32_t results.
1747
+ NK_PUBLIC void nk_angulars_packed_e4m3_neonfhm(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
1748
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1749
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1750
+ /** @copydoc nk_angulars_symmetric_e4m3 */
1751
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_neonfhm(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1752
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1753
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1754
+ /** @copydoc nk_euclideans_packed_e4m3 */
1755
+ NK_PUBLIC void nk_euclideans_packed_e4m3_neonfhm(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result,
1756
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1757
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1758
+ /** @copydoc nk_euclideans_symmetric_e4m3 */
1759
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_neonfhm(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1760
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1761
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1762
+
1763
+ /** @copydoc nk_angulars_packed_e5m2 */ // e5m2 group of the NK_TARGET_NEONFHM prototypes: nk_e5m2_t inputs, nk_f32_t results.
1764
+ NK_PUBLIC void nk_angulars_packed_e5m2_neonfhm(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
1765
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1766
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1767
+ /** @copydoc nk_angulars_symmetric_e5m2 */
1768
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_neonfhm(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1769
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1770
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1771
+ /** @copydoc nk_euclideans_packed_e5m2 */
1772
+ NK_PUBLIC void nk_euclideans_packed_e5m2_neonfhm(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
1773
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
1774
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1775
+ /** @copydoc nk_euclideans_symmetric_e5m2 */
1776
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_neonfhm(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1777
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1778
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1779
+
1780
+ #endif // NK_TARGET_NEONFHM
1781
+
1782
+ #if NK_TARGET_RVV
1783
+ /** @copydoc nk_angulars_packed_f32 */ // f32 group of the NK_TARGET_RVV prototypes: nk_f32_t inputs, but results are nk_f64_t.
1784
+ NK_PUBLIC void nk_angulars_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1785
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1786
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1787
+ /** @copydoc nk_angulars_symmetric_f32 */
1788
+ NK_PUBLIC void nk_angulars_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1789
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1790
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1791
+ /** @copydoc nk_euclideans_packed_f32 */
1792
+ NK_PUBLIC void nk_euclideans_packed_f32_rvv(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1793
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1794
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1795
+ /** @copydoc nk_euclideans_symmetric_f32 */
1796
+ NK_PUBLIC void nk_euclideans_symmetric_f32_rvv(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1797
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1798
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1799
+
1800
+ /** @copydoc nk_angulars_packed_f64 */ // f64 group of the NK_TARGET_RVV prototypes: nk_f64_t inputs and nk_f64_t results.
1801
+ NK_PUBLIC void nk_angulars_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1802
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1803
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1804
+ /** @copydoc nk_angulars_symmetric_f64 */
1805
+ NK_PUBLIC void nk_angulars_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1806
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1807
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1808
+ /** @copydoc nk_euclideans_packed_f64 */
1809
+ NK_PUBLIC void nk_euclideans_packed_f64_rvv(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1810
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1811
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1812
+ /** @copydoc nk_euclideans_symmetric_f64 */
1813
+ NK_PUBLIC void nk_euclideans_symmetric_f64_rvv(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1814
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
1815
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1816
+
1817
+ /** @copydoc nk_angulars_packed_f16 */ // f16 group of the NK_TARGET_RVV prototypes: nk_f16_t inputs, nk_f32_t results.
1818
+ NK_PUBLIC void nk_angulars_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1819
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1820
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1821
+ /** @copydoc nk_angulars_symmetric_f16 */
1822
+ NK_PUBLIC void nk_angulars_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1823
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1824
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1825
+ /** @copydoc nk_euclideans_packed_f16 */
1826
+ NK_PUBLIC void nk_euclideans_packed_f16_rvv(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1827
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1828
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1829
+ /** @copydoc nk_euclideans_symmetric_f16 */
1830
+ NK_PUBLIC void nk_euclideans_symmetric_f16_rvv(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1831
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1832
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1833
+
1834
+ /** @copydoc nk_angulars_packed_bf16 */ // bf16 group of the NK_TARGET_RVV prototypes: nk_bf16_t inputs, nk_f32_t results.
1835
+ NK_PUBLIC void nk_angulars_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1836
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1837
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1838
+ /** @copydoc nk_angulars_symmetric_bf16 */
1839
+ NK_PUBLIC void nk_angulars_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1840
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1841
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1842
+ /** @copydoc nk_euclideans_packed_bf16 */
1843
+ NK_PUBLIC void nk_euclideans_packed_bf16_rvv(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1844
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1845
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1846
+ /** @copydoc nk_euclideans_symmetric_bf16 */
1847
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_rvv(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1848
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1849
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1850
+
1851
+ /** @copydoc nk_angulars_packed_e4m3 */ // e4m3 group of the NK_TARGET_RVV prototypes: nk_e4m3_t inputs, nk_f32_t results.
1852
+ NK_PUBLIC void nk_angulars_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1853
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1854
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1855
+ /** @copydoc nk_angulars_symmetric_e4m3 */
1856
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1857
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1858
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1859
+ /** @copydoc nk_euclideans_packed_e4m3 */
1860
+ NK_PUBLIC void nk_euclideans_packed_e4m3_rvv(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1861
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1862
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1863
+ /** @copydoc nk_euclideans_symmetric_e4m3 */
1864
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_rvv(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1865
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1866
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1867
+
1868
+ /** @copydoc nk_angulars_packed_e5m2 */ // e5m2 group of the NK_TARGET_RVV prototypes: nk_e5m2_t inputs, nk_f32_t results.
1869
+ NK_PUBLIC void nk_angulars_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1870
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1871
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1872
+ /** @copydoc nk_angulars_symmetric_e5m2 */
1873
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1874
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1875
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1876
+ /** @copydoc nk_euclideans_packed_e5m2 */
1877
+ NK_PUBLIC void nk_euclideans_packed_e5m2_rvv(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1878
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1879
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1880
+ /** @copydoc nk_euclideans_symmetric_e5m2 */
1881
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_rvv(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1882
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1883
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1884
+
1885
+ /** @copydoc nk_angulars_packed_e2m3 */ // e2m3 group of the NK_TARGET_RVV prototypes: nk_e2m3_t inputs, nk_f32_t results.
1886
+ NK_PUBLIC void nk_angulars_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1887
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1888
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1889
+ /** @copydoc nk_angulars_symmetric_e2m3 */
1890
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1891
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1892
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1893
+ /** @copydoc nk_euclideans_packed_e2m3 */
1894
+ NK_PUBLIC void nk_euclideans_packed_e2m3_rvv(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1895
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1896
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1897
+ /** @copydoc nk_euclideans_symmetric_e2m3 */
1898
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_rvv(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1899
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1900
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1901
+
1902
+ /** @copydoc nk_angulars_packed_e3m2 */ // e3m2 group of the NK_TARGET_RVV prototypes: nk_e3m2_t inputs, nk_f32_t results.
1903
+ NK_PUBLIC void nk_angulars_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1904
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1905
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1906
+ /** @copydoc nk_angulars_symmetric_e3m2 */
1907
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1908
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1909
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1910
+ /** @copydoc nk_euclideans_packed_e3m2 */
1911
+ NK_PUBLIC void nk_euclideans_packed_e3m2_rvv(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1912
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1913
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1914
+ /** @copydoc nk_euclideans_symmetric_e3m2 */
1915
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_rvv(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1916
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1917
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1918
+
1919
+ /** @copydoc nk_angulars_packed_i8 */ // i8 group of the NK_TARGET_RVV prototypes: nk_i8_t inputs, nk_f32_t results.
1920
+ NK_PUBLIC void nk_angulars_packed_i8_rvv(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1921
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1922
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1923
+ /** @copydoc nk_angulars_symmetric_i8 */
1924
+ NK_PUBLIC void nk_angulars_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1925
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1926
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1927
+ /** @copydoc nk_euclideans_packed_i8 */
1928
+ NK_PUBLIC void nk_euclideans_packed_i8_rvv(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1929
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1930
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1931
+ /** @copydoc nk_euclideans_symmetric_i8 */
1932
+ NK_PUBLIC void nk_euclideans_symmetric_i8_rvv(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1933
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1934
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1935
+
1936
+ /** @copydoc nk_angulars_packed_u8 */ // u8 group of the NK_TARGET_RVV prototypes: nk_u8_t inputs, nk_f32_t results.
1937
+ NK_PUBLIC void nk_angulars_packed_u8_rvv(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1938
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1939
+ nk_size_t r_stride_in_bytes); // b_packed is untyped (void const *); presumably built by a matching pack routine - TODO confirm
1940
+ /** @copydoc nk_angulars_symmetric_u8 */
1941
+ NK_PUBLIC void nk_angulars_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1942
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1943
+ nk_size_t row_start, nk_size_t row_count); // row_start/row_count presumably pick the result rows to fill - TODO confirm
1944
+ /** @copydoc nk_euclideans_packed_u8 */
1945
+ NK_PUBLIC void nk_euclideans_packed_u8_rvv(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
1946
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1947
+ nk_size_t r_stride_in_bytes); // same parameter list as the angular packed prototype
1948
+ /** @copydoc nk_euclideans_symmetric_u8 */
1949
+ NK_PUBLIC void nk_euclideans_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
1950
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
1951
+ nk_size_t row_start, nk_size_t row_count); // same parameter list as the angular symmetric prototype
1952
+ #endif // NK_TARGET_RVV
1953
+
1954
+ #if defined(__cplusplus)
1955
+ } // extern "C"
1956
+ #endif
1957
+
1958
+ #include "numkong/spatials/serial.h"
1959
+ #include "numkong/spatials/neon.h"
1960
+ #include "numkong/spatials/neonhalf.h"
1961
+ #include "numkong/spatials/neonfhm.h"
1962
+ #include "numkong/spatials/neonbfdot.h"
1963
+ #include "numkong/spatials/neonsdot.h"
1964
+ #include "numkong/spatials/haswell.h"
1965
+ #include "numkong/spatials/skylake.h"
1966
+ #include "numkong/spatials/genoa.h"
1967
+ #include "numkong/spatials/icelake.h"
1968
+ #include "numkong/spatials/alder.h"
1969
+ #include "numkong/spatials/sierra.h"
1970
+ #include "numkong/spatials/sapphireamx.h"
1971
+ #include "numkong/spatials/rvv.h"
1972
+ #include "numkong/spatials/v128relaxed.h"
1973
+ #include "numkong/spatials/sme.h"
1974
+ #include "numkong/spatials/smef64.h"
1975
+
1976
+ #if defined(__cplusplus)
1977
+ extern "C" {
1978
+ #endif
1979
+
1980
+ #if !NK_DYNAMIC_DISPATCH
1981
+
1982
+ NK_PUBLIC void nk_angulars_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
1983
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
1984
+ nk_size_t r_stride_in_bytes) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one f64 backend.
1985
+ #if NK_TARGET_SMEF64 // branches are tested top to bottom; the first enabled NK_TARGET_* wins
1986
+ nk_angulars_packed_f64_smef64(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
1987
+ #elif NK_TARGET_NEON
1988
+ nk_angulars_packed_f64_neon(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
1989
+ #elif NK_TARGET_SKYLAKE
1990
+ nk_angulars_packed_f64_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
1991
+ #elif NK_TARGET_HASWELL
1992
+ nk_angulars_packed_f64_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
1993
+ #elif NK_TARGET_RVV
1994
+ nk_angulars_packed_f64_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
1995
+ #elif NK_TARGET_V128RELAXED
1996
+ nk_angulars_packed_f64_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
1997
+ #else // none of the targets above is enabled: use the _serial kernel
1998
+ nk_angulars_packed_f64_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
1999
+ #endif
2000
+ }
2001
+ NK_PUBLIC void nk_angulars_symmetric_f64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2002
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
2003
+ nk_size_t row_start, nk_size_t row_count) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one f64 backend.
2004
+ #if NK_TARGET_SMEF64 // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2005
+ nk_angulars_symmetric_f64_smef64(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2006
+ #elif NK_TARGET_NEON
2007
+ nk_angulars_symmetric_f64_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2008
+ #elif NK_TARGET_SKYLAKE
2009
+ nk_angulars_symmetric_f64_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2010
+ #elif NK_TARGET_HASWELL
2011
+ nk_angulars_symmetric_f64_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2012
+ #elif NK_TARGET_RVV
2013
+ nk_angulars_symmetric_f64_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2014
+ #elif NK_TARGET_V128RELAXED
2015
+ nk_angulars_symmetric_f64_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2016
+ row_count);
2017
+ #else // none of the targets above is enabled: use the _serial kernel
2018
+ nk_angulars_symmetric_f64_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2019
+ #endif
2020
+ }
2021
+ NK_PUBLIC void nk_euclideans_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
2022
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2023
+ nk_size_t r_stride_in_bytes) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one f64 backend.
2024
+ #if NK_TARGET_SMEF64 // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2025
+ nk_euclideans_packed_f64_smef64(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2026
+ #elif NK_TARGET_NEON
2027
+ nk_euclideans_packed_f64_neon(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2028
+ #elif NK_TARGET_SKYLAKE
2029
+ nk_euclideans_packed_f64_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2030
+ #elif NK_TARGET_HASWELL
2031
+ nk_euclideans_packed_f64_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2032
+ #elif NK_TARGET_RVV
2033
+ nk_euclideans_packed_f64_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2034
+ #elif NK_TARGET_V128RELAXED
2035
+ nk_euclideans_packed_f64_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2036
+ #else // none of the targets above is enabled: use the _serial kernel
2037
+ nk_euclideans_packed_f64_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2038
+ #endif
2039
+ }
2040
+ NK_PUBLIC void nk_euclideans_symmetric_f64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2041
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
2042
+ nk_size_t row_start, nk_size_t row_count) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one f64 backend.
2043
+ #if NK_TARGET_SMEF64 // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2044
+ nk_euclideans_symmetric_f64_smef64(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2045
+ #elif NK_TARGET_NEON
2046
+ nk_euclideans_symmetric_f64_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2047
+ #elif NK_TARGET_SKYLAKE
2048
+ nk_euclideans_symmetric_f64_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2049
+ #elif NK_TARGET_HASWELL
2050
+ nk_euclideans_symmetric_f64_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2051
+ #elif NK_TARGET_RVV
2052
+ nk_euclideans_symmetric_f64_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2053
+ #elif NK_TARGET_V128RELAXED
2054
+ nk_euclideans_symmetric_f64_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2055
+ row_count);
2056
+ #else // none of the targets above is enabled: use the _serial kernel
2057
+ nk_euclideans_symmetric_f64_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2058
+ #endif
2059
+ }
2060
+
2061
+ NK_PUBLIC void nk_angulars_packed_f32(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
2062
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2063
+ nk_size_t r_stride_in_bytes) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one backend. Note: f32 inputs, nk_f64_t results.
2064
+ #if NK_TARGET_SMEF64 // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2065
+ nk_angulars_packed_f32_smef64(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2066
+ #elif NK_TARGET_NEON
2067
+ nk_angulars_packed_f32_neon(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2068
+ #elif NK_TARGET_SKYLAKE
2069
+ nk_angulars_packed_f32_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2070
+ #elif NK_TARGET_HASWELL
2071
+ nk_angulars_packed_f32_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2072
+ #elif NK_TARGET_RVV
2073
+ nk_angulars_packed_f32_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2074
+ #elif NK_TARGET_V128RELAXED
2075
+ nk_angulars_packed_f32_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2076
+ #else // none of the targets above is enabled: use the _serial kernel
2077
+ nk_angulars_packed_f32_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2078
+ #endif
2079
+ }
2080
+ NK_PUBLIC void nk_angulars_symmetric_f32(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2081
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
2082
+ nk_size_t row_start, nk_size_t row_count) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one backend. Note: f32 inputs, nk_f64_t results.
2083
+ #if NK_TARGET_SMEF64 // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2084
+ nk_angulars_symmetric_f32_smef64(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2085
+ #elif NK_TARGET_NEON
2086
+ nk_angulars_symmetric_f32_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2087
+ #elif NK_TARGET_SKYLAKE
2088
+ nk_angulars_symmetric_f32_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2089
+ #elif NK_TARGET_HASWELL
2090
+ nk_angulars_symmetric_f32_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2091
+ #elif NK_TARGET_RVV
2092
+ nk_angulars_symmetric_f32_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2093
+ #elif NK_TARGET_V128RELAXED
2094
+ nk_angulars_symmetric_f32_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2095
+ row_count);
2096
+ #else // none of the targets above is enabled: use the _serial kernel
2097
+ nk_angulars_symmetric_f32_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2098
+ #endif
2099
+ }
2100
+ NK_PUBLIC void nk_euclideans_packed_f32(nk_f32_t const *a, void const *b_packed, nk_f64_t *result, nk_size_t rows,
2101
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2102
+ nk_size_t r_stride_in_bytes) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one backend. Note: f32 inputs, nk_f64_t results.
2103
+ #if NK_TARGET_SMEF64 // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2104
+ nk_euclideans_packed_f32_smef64(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2105
+ #elif NK_TARGET_NEON
2106
+ nk_euclideans_packed_f32_neon(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2107
+ #elif NK_TARGET_SKYLAKE
2108
+ nk_euclideans_packed_f32_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2109
+ #elif NK_TARGET_HASWELL
2110
+ nk_euclideans_packed_f32_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2111
+ #elif NK_TARGET_RVV
2112
+ nk_euclideans_packed_f32_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2113
+ #elif NK_TARGET_V128RELAXED
2114
+ nk_euclideans_packed_f32_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2115
+ #else // none of the targets above is enabled: use the _serial kernel
2116
+ nk_euclideans_packed_f32_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2117
+ #endif
2118
+ }
2119
+ NK_PUBLIC void nk_euclideans_symmetric_f32(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2120
+ nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
2121
+ nk_size_t row_start, nk_size_t row_count) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one backend. Note: f32 inputs, nk_f64_t results.
2122
+ #if NK_TARGET_SMEF64 // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2123
+ nk_euclideans_symmetric_f32_smef64(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2124
+ #elif NK_TARGET_NEON
2125
+ nk_euclideans_symmetric_f32_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2126
+ #elif NK_TARGET_SKYLAKE
2127
+ nk_euclideans_symmetric_f32_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2128
+ #elif NK_TARGET_HASWELL
2129
+ nk_euclideans_symmetric_f32_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2130
+ #elif NK_TARGET_RVV
2131
+ nk_euclideans_symmetric_f32_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2132
+ #elif NK_TARGET_V128RELAXED
2133
+ nk_euclideans_symmetric_f32_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2134
+ row_count);
2135
+ #else // none of the targets above is enabled: use the _serial kernel
2136
+ nk_euclideans_symmetric_f32_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2137
+ #endif
2138
+ }
2139
+
2140
+ NK_PUBLIC void nk_angulars_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2141
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2142
+ nk_size_t r_stride_in_bytes) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one f16 backend; results are nk_f32_t.
2143
+ #if NK_TARGET_SME // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2144
+ nk_angulars_packed_f16_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2145
+ #elif NK_TARGET_NEONFHM
2146
+ nk_angulars_packed_f16_neonfhm(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2147
+ #elif NK_TARGET_NEONHALF
2148
+ nk_angulars_packed_f16_neonhalf(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2149
+ #elif NK_TARGET_NEON
2150
+ nk_angulars_packed_f16_neon(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2151
+ #elif NK_TARGET_SKYLAKE
2152
+ nk_angulars_packed_f16_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2153
+ #elif NK_TARGET_HASWELL
2154
+ nk_angulars_packed_f16_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2155
+ #elif NK_TARGET_RVV
2156
+ nk_angulars_packed_f16_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2157
+ #else // none of the targets above is enabled: use the _serial kernel. NOTE(review): no NK_TARGET_V128RELAXED branch here, unlike the f32/f64 dispatchers - confirm intentional
2158
+ nk_angulars_packed_f16_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2159
+ #endif
2160
+ }
2161
+ NK_PUBLIC void nk_angulars_symmetric_f16(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2162
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2163
+ nk_size_t row_start, nk_size_t row_count) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one f16 backend; results are nk_f32_t.
2164
+ #if NK_TARGET_SME // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2165
+ nk_angulars_symmetric_f16_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2166
+ #elif NK_TARGET_NEONFHM
2167
+ nk_angulars_symmetric_f16_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2168
+ #elif NK_TARGET_NEONHALF
2169
+ nk_angulars_symmetric_f16_neonhalf(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2170
+ #elif NK_TARGET_NEON
2171
+ nk_angulars_symmetric_f16_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2172
+ #elif NK_TARGET_SKYLAKE
2173
+ nk_angulars_symmetric_f16_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2174
+ #elif NK_TARGET_HASWELL
2175
+ nk_angulars_symmetric_f16_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2176
+ #elif NK_TARGET_RVV
2177
+ nk_angulars_symmetric_f16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2178
+ #else // none of the targets above is enabled: use the _serial kernel. NOTE(review): no NK_TARGET_V128RELAXED branch here, unlike the f32/f64 dispatchers - confirm intentional
2179
+ nk_angulars_symmetric_f16_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2180
+ #endif
2181
+ }
2182
+ NK_PUBLIC void nk_euclideans_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2183
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2184
+ nk_size_t r_stride_in_bytes) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one f16 backend; results are nk_f32_t.
2185
+ #if NK_TARGET_SME // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2186
+ nk_euclideans_packed_f16_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2187
+ #elif NK_TARGET_NEONFHM
2188
+ nk_euclideans_packed_f16_neonfhm(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2189
+ #elif NK_TARGET_NEONHALF
2190
+ nk_euclideans_packed_f16_neonhalf(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2191
+ #elif NK_TARGET_NEON
2192
+ nk_euclideans_packed_f16_neon(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2193
+ #elif NK_TARGET_SKYLAKE
2194
+ nk_euclideans_packed_f16_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2195
+ #elif NK_TARGET_HASWELL
2196
+ nk_euclideans_packed_f16_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2197
+ #elif NK_TARGET_RVV
2198
+ nk_euclideans_packed_f16_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2199
+ #else // none of the targets above is enabled: use the _serial kernel. NOTE(review): no NK_TARGET_V128RELAXED branch here, unlike the f32/f64 dispatchers - confirm intentional
2200
+ nk_euclideans_packed_f16_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2201
+ #endif
2202
+ }
2203
+ NK_PUBLIC void nk_euclideans_symmetric_f16(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2204
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2205
+ nk_size_t row_start, nk_size_t row_count) { // Static dispatcher (built only when NK_DYNAMIC_DISPATCH is off): forwards every argument unchanged to one f16 backend; results are nk_f32_t.
2206
+ #if NK_TARGET_SME // branches are tested top to bottom; the first enabled NK_TARGET_* wins
2207
+ nk_euclideans_symmetric_f16_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2208
+ #elif NK_TARGET_NEONFHM
2209
+ nk_euclideans_symmetric_f16_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2210
+ #elif NK_TARGET_NEONHALF
2211
+ nk_euclideans_symmetric_f16_neonhalf(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2212
+ row_count);
2213
+ #elif NK_TARGET_NEON
2214
+ nk_euclideans_symmetric_f16_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2215
+ #elif NK_TARGET_SKYLAKE
2216
+ nk_euclideans_symmetric_f16_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2217
+ #elif NK_TARGET_HASWELL
2218
+ nk_euclideans_symmetric_f16_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2219
+ #elif NK_TARGET_RVV
2220
+ nk_euclideans_symmetric_f16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2221
+ #else // none of the targets above is enabled: use the _serial kernel. NOTE(review): no NK_TARGET_V128RELAXED branch here, unlike the f32/f64 dispatchers - confirm intentional
2222
+ nk_euclideans_symmetric_f16_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2223
+ #endif
2224
+ }
2225
+
2226
/**
 * Compile-time dispatcher for the packed "angulars" kernel with `nk_bf16_t` rows in `a`.
 *
 * Passes every argument through untouched to a single backend, chosen as the first
 * NK_TARGET_* macro that evaluates non-zero in this order: SME, NEONBFDOT, SAPPHIREAMX,
 * GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED. With no target enabled, the `_serial`
 * variant is called.
 *
 * NOTE(review): `b_packed` is an opaque `void const *`; its layout is presumably
 * specific to the selected backend's packing routine. Verify that callers pack with
 * the matching backend.
 */
+ NK_PUBLIC void nk_angulars_packed_bf16(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2227
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2228
+ nk_size_t r_stride_in_bytes) {
2229
+ #if NK_TARGET_SME
2230
+ nk_angulars_packed_bf16_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2231
+ #elif NK_TARGET_NEONBFDOT
2232
+ nk_angulars_packed_bf16_neonbfdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2233
+ #elif NK_TARGET_SAPPHIREAMX
2234
+ nk_angulars_packed_bf16_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2235
+ #elif NK_TARGET_GENOA
2236
+ nk_angulars_packed_bf16_genoa(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2237
+ #elif NK_TARGET_SKYLAKE
2238
+ nk_angulars_packed_bf16_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2239
+ #elif NK_TARGET_HASWELL
2240
+ nk_angulars_packed_bf16_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2241
+ #elif NK_TARGET_RVV
2242
+ nk_angulars_packed_bf16_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2243
+ #elif NK_TARGET_V128RELAXED
2244
+ nk_angulars_packed_bf16_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2245
+ #else
2246
+ nk_angulars_packed_bf16_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2247
+ #endif
2248
+ }
2249
/**
 * Compile-time dispatcher for the symmetric "angulars" kernel over `nk_bf16_t` vectors.
 *
 * Exactly one backend call is compiled in: the first target whose NK_TARGET_* macro is
 * non-zero, tested as SME, NEONBFDOT, SAPPHIREAMX, GENOA, SKYLAKE, HASWELL, RVV,
 * V128RELAXED, otherwise `_serial`. All eight arguments are forwarded as received.
 *
 * NOTE(review): the roles of `row_start` / `row_count` (presumably a row sub-range of
 * the output) are defined by the backends; confirm there.
 */
+ NK_PUBLIC void nk_angulars_symmetric_bf16(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2250
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2251
+ nk_size_t row_start, nk_size_t row_count) {
2252
+ #if NK_TARGET_SME
2253
+ nk_angulars_symmetric_bf16_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2254
+ #elif NK_TARGET_NEONBFDOT
2255
+ nk_angulars_symmetric_bf16_neonbfdot(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2256
+ row_count);
2257
+ #elif NK_TARGET_SAPPHIREAMX
2258
+ nk_angulars_symmetric_bf16_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2259
+ row_count);
2260
+ #elif NK_TARGET_GENOA
2261
+ nk_angulars_symmetric_bf16_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2262
+ #elif NK_TARGET_SKYLAKE
2263
+ nk_angulars_symmetric_bf16_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2264
+ #elif NK_TARGET_HASWELL
2265
+ nk_angulars_symmetric_bf16_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2266
+ #elif NK_TARGET_RVV
2267
+ nk_angulars_symmetric_bf16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2268
+ #elif NK_TARGET_V128RELAXED
2269
+ nk_angulars_symmetric_bf16_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2270
+ row_count);
2271
+ #else
2272
+ nk_angulars_symmetric_bf16_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2273
+ #endif
2274
+ }
2275
/**
 * Compile-time dispatcher for the packed "euclideans" kernel with `nk_bf16_t` rows in `a`.
 *
 * Selects one backend by the first non-zero NK_TARGET_* macro in this order: SME,
 * NEONBFDOT, SAPPHIREAMX, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED. It falls back to
 * `_serial` when none applies. Arguments are handed over verbatim; nothing is
 * validated or transformed at this level.
 *
 * NOTE(review): `b_packed` is opaque here; presumably it must come from the packing
 * routine of the same backend. Confirm with the packing API.
 */
+ NK_PUBLIC void nk_euclideans_packed_bf16(nk_bf16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2276
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2277
+ nk_size_t r_stride_in_bytes) {
2278
+ #if NK_TARGET_SME
2279
+ nk_euclideans_packed_bf16_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2280
+ #elif NK_TARGET_NEONBFDOT
2281
+ nk_euclideans_packed_bf16_neonbfdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2282
+ #elif NK_TARGET_SAPPHIREAMX
2283
+ nk_euclideans_packed_bf16_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2284
+ #elif NK_TARGET_GENOA
2285
+ nk_euclideans_packed_bf16_genoa(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2286
+ #elif NK_TARGET_SKYLAKE
2287
+ nk_euclideans_packed_bf16_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2288
+ #elif NK_TARGET_HASWELL
2289
+ nk_euclideans_packed_bf16_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2290
+ #elif NK_TARGET_RVV
2291
+ nk_euclideans_packed_bf16_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2292
+ #elif NK_TARGET_V128RELAXED
2293
+ nk_euclideans_packed_bf16_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2294
+ #else
2295
+ nk_euclideans_packed_bf16_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2296
+ #endif
2297
+ }
2298
/**
 * Compile-time dispatcher for the symmetric "euclideans" kernel over `nk_bf16_t` vectors.
 *
 * The preprocessor chain keeps a single backend call, namely the first enabled target
 * among SME, NEONBFDOT, SAPPHIREAMX, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED, with
 * `_serial` as the final fallback. The eight arguments are forwarded unchanged.
 *
 * NOTE(review): the output layout behind `result` / `result_stride` is determined by
 * the backends and cannot be confirmed from this wrapper.
 */
+ NK_PUBLIC void nk_euclideans_symmetric_bf16(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2299
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2300
+ nk_size_t row_start, nk_size_t row_count) {
2301
+ #if NK_TARGET_SME
2302
+ nk_euclideans_symmetric_bf16_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2303
+ #elif NK_TARGET_NEONBFDOT
2304
+ nk_euclideans_symmetric_bf16_neonbfdot(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2305
+ row_count);
2306
+ #elif NK_TARGET_SAPPHIREAMX
2307
+ nk_euclideans_symmetric_bf16_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2308
+ row_count);
2309
+ #elif NK_TARGET_GENOA
2310
+ nk_euclideans_symmetric_bf16_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2311
+ #elif NK_TARGET_SKYLAKE
2312
+ nk_euclideans_symmetric_bf16_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2313
+ row_count);
2314
+ #elif NK_TARGET_HASWELL
2315
+ nk_euclideans_symmetric_bf16_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2316
+ row_count);
2317
+ #elif NK_TARGET_RVV
2318
+ nk_euclideans_symmetric_bf16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2319
+ #elif NK_TARGET_V128RELAXED
2320
+ nk_euclideans_symmetric_bf16_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2321
+ row_count);
2322
+ #else
2323
+ nk_euclideans_symmetric_bf16_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2324
+ #endif
2325
+ }
2326
+
2327
/**
 * Compile-time dispatcher for the packed "angulars" kernel with `nk_e4m3_t` rows in `a`.
 *
 * Calls the first backend whose NK_TARGET_* macro is non-zero, checked in the order
 * SME, NEONFHM, SAPPHIREAMX, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED; otherwise the
 * `_serial` implementation. Every argument is forwarded exactly as received.
 *
 * NOTE(review): `b_packed` is an untyped pointer; its expected format is presumably
 * tied to the chosen backend. Verify against the packing routines.
 */
+ NK_PUBLIC void nk_angulars_packed_e4m3(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2328
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2329
+ nk_size_t r_stride_in_bytes) {
2330
+ #if NK_TARGET_SME
2331
+ nk_angulars_packed_e4m3_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2332
+ #elif NK_TARGET_NEONFHM
2333
+ nk_angulars_packed_e4m3_neonfhm(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2334
+ #elif NK_TARGET_SAPPHIREAMX
2335
+ nk_angulars_packed_e4m3_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2336
+ #elif NK_TARGET_GENOA
2337
+ nk_angulars_packed_e4m3_genoa(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2338
+ #elif NK_TARGET_SKYLAKE
2339
+ nk_angulars_packed_e4m3_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2340
+ #elif NK_TARGET_HASWELL
2341
+ nk_angulars_packed_e4m3_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2342
+ #elif NK_TARGET_RVV
2343
+ nk_angulars_packed_e4m3_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2344
+ #elif NK_TARGET_V128RELAXED
2345
+ nk_angulars_packed_e4m3_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2346
+ #else
2347
+ nk_angulars_packed_e4m3_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2348
+ #endif
2349
+ }
2350
/**
 * Compile-time dispatcher for the symmetric "angulars" kernel over `nk_e4m3_t` vectors.
 *
 * Only one branch survives preprocessing: the first enabled target among SME, NEONFHM,
 * SAPPHIREAMX, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED, or `_serial` if none is
 * enabled. The wrapper adds no logic of its own; all eight arguments pass straight
 * through.
 *
 * NOTE(review): semantics of `row_start` / `row_count` live in the backends; confirm
 * there before relying on partial-range behaviour.
 */
+ NK_PUBLIC void nk_angulars_symmetric_e4m3(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2351
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2352
+ nk_size_t row_start, nk_size_t row_count) {
2353
+ #if NK_TARGET_SME
2354
+ nk_angulars_symmetric_e4m3_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2355
+ #elif NK_TARGET_NEONFHM
2356
+ nk_angulars_symmetric_e4m3_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2357
+ #elif NK_TARGET_SAPPHIREAMX
2358
+ nk_angulars_symmetric_e4m3_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2359
+ row_count);
2360
+ #elif NK_TARGET_GENOA
2361
+ nk_angulars_symmetric_e4m3_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2362
+ #elif NK_TARGET_SKYLAKE
2363
+ nk_angulars_symmetric_e4m3_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2364
+ #elif NK_TARGET_HASWELL
2365
+ nk_angulars_symmetric_e4m3_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2366
+ #elif NK_TARGET_RVV
2367
+ nk_angulars_symmetric_e4m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2368
+ #elif NK_TARGET_V128RELAXED
2369
+ nk_angulars_symmetric_e4m3_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2370
+ row_count);
2371
+ #else
2372
+ nk_angulars_symmetric_e4m3_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2373
+ #endif
2374
+ }
2375
/**
 * Compile-time dispatcher for the packed "euclideans" kernel with `nk_e4m3_t` rows in `a`.
 *
 * Backend priority, first non-zero NK_TARGET_* macro wins: SME, NEONFHM, SAPPHIREAMX,
 * GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED, then `_serial`. The selected backend
 * receives the caller's arguments unmodified.
 *
 * NOTE(review): `b_packed` is opaque at this level; presumably it must match the
 * selected backend's packed layout. Confirm with the packing API.
 */
+ NK_PUBLIC void nk_euclideans_packed_e4m3(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2376
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2377
+ nk_size_t r_stride_in_bytes) {
2378
+ #if NK_TARGET_SME
2379
+ nk_euclideans_packed_e4m3_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2380
+ #elif NK_TARGET_NEONFHM
2381
+ nk_euclideans_packed_e4m3_neonfhm(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2382
+ #elif NK_TARGET_SAPPHIREAMX
2383
+ nk_euclideans_packed_e4m3_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2384
+ #elif NK_TARGET_GENOA
2385
+ nk_euclideans_packed_e4m3_genoa(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2386
+ #elif NK_TARGET_SKYLAKE
2387
+ nk_euclideans_packed_e4m3_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2388
+ #elif NK_TARGET_HASWELL
2389
+ nk_euclideans_packed_e4m3_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2390
+ #elif NK_TARGET_RVV
2391
+ nk_euclideans_packed_e4m3_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2392
+ #elif NK_TARGET_V128RELAXED
2393
+ nk_euclideans_packed_e4m3_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2394
+ #else
2395
+ nk_euclideans_packed_e4m3_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2396
+ #endif
2397
+ }
2398
/**
 * Compile-time dispatcher for the symmetric "euclideans" kernel over `nk_e4m3_t` vectors.
 *
 * Forwards its eight arguments, unchanged, to the first enabled backend in this order:
 * SME, NEONFHM, SAPPHIREAMX, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED. The `_serial`
 * implementation is the fallback when no NK_TARGET_* macro is non-zero.
 *
 * NOTE(review): stride units and the row-range contract are set by the backends and
 * are not visible in this wrapper.
 */
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2399
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2400
+ nk_size_t row_start, nk_size_t row_count) {
2401
+ #if NK_TARGET_SME
2402
+ nk_euclideans_symmetric_e4m3_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2403
+ #elif NK_TARGET_NEONFHM
2404
+ nk_euclideans_symmetric_e4m3_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2405
+ row_count);
2406
+ #elif NK_TARGET_SAPPHIREAMX
2407
+ nk_euclideans_symmetric_e4m3_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2408
+ row_count);
2409
+ #elif NK_TARGET_GENOA
2410
+ nk_euclideans_symmetric_e4m3_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2411
+ #elif NK_TARGET_SKYLAKE
2412
+ nk_euclideans_symmetric_e4m3_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2413
+ row_count);
2414
+ #elif NK_TARGET_HASWELL
2415
+ nk_euclideans_symmetric_e4m3_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2416
+ row_count);
2417
+ #elif NK_TARGET_RVV
2418
+ nk_euclideans_symmetric_e4m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2419
+ #elif NK_TARGET_V128RELAXED
2420
+ nk_euclideans_symmetric_e4m3_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2421
+ row_count);
2422
+ #else
2423
+ nk_euclideans_symmetric_e4m3_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2424
+ #endif
2425
+ }
2426
+
2427
/**
 * Compile-time dispatcher for the packed "angulars" kernel with `nk_e5m2_t` rows in `a`.
 *
 * Chooses one backend at preprocessing time: the first of SME, NEONFHM, SAPPHIREAMX,
 * GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED whose NK_TARGET_* macro is non-zero, else
 * `_serial`. It then calls it with the caller's arguments untouched.
 *
 * NOTE(review): `b_packed` is a `void const *` with no layout information here;
 * presumably produced by a backend-matching pack step. Verify.
 */
+ NK_PUBLIC void nk_angulars_packed_e5m2(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2428
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2429
+ nk_size_t r_stride_in_bytes) {
2430
+ #if NK_TARGET_SME
2431
+ nk_angulars_packed_e5m2_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2432
+ #elif NK_TARGET_NEONFHM
2433
+ nk_angulars_packed_e5m2_neonfhm(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2434
+ #elif NK_TARGET_SAPPHIREAMX
2435
+ nk_angulars_packed_e5m2_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2436
+ #elif NK_TARGET_GENOA
2437
+ nk_angulars_packed_e5m2_genoa(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2438
+ #elif NK_TARGET_SKYLAKE
2439
+ nk_angulars_packed_e5m2_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2440
+ #elif NK_TARGET_HASWELL
2441
+ nk_angulars_packed_e5m2_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2442
+ #elif NK_TARGET_RVV
2443
+ nk_angulars_packed_e5m2_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2444
+ #elif NK_TARGET_V128RELAXED
2445
+ nk_angulars_packed_e5m2_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2446
+ #else
2447
+ nk_angulars_packed_e5m2_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2448
+ #endif
2449
+ }
2450
/**
 * Compile-time dispatcher for the symmetric "angulars" kernel over `nk_e5m2_t` vectors.
 *
 * A pure pass-through: the eight arguments go unchanged to the first enabled backend
 * in the order SME, NEONFHM, SAPPHIREAMX, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED.
 * The `_serial` variant is used when no NK_TARGET_* macro is non-zero.
 *
 * NOTE(review): what is written through `result` and how `result_stride` is
 * interpreted is defined by the backends; confirm there.
 */
+ NK_PUBLIC void nk_angulars_symmetric_e5m2(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2451
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2452
+ nk_size_t row_start, nk_size_t row_count) {
2453
+ #if NK_TARGET_SME
2454
+ nk_angulars_symmetric_e5m2_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2455
+ #elif NK_TARGET_NEONFHM
2456
+ nk_angulars_symmetric_e5m2_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2457
+ #elif NK_TARGET_SAPPHIREAMX
2458
+ nk_angulars_symmetric_e5m2_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2459
+ row_count);
2460
+ #elif NK_TARGET_GENOA
2461
+ nk_angulars_symmetric_e5m2_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2462
+ #elif NK_TARGET_SKYLAKE
2463
+ nk_angulars_symmetric_e5m2_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2464
+ #elif NK_TARGET_HASWELL
2465
+ nk_angulars_symmetric_e5m2_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2466
+ #elif NK_TARGET_RVV
2467
+ nk_angulars_symmetric_e5m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2468
+ #elif NK_TARGET_V128RELAXED
2469
+ nk_angulars_symmetric_e5m2_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2470
+ row_count);
2471
+ #else
2472
+ nk_angulars_symmetric_e5m2_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2473
+ #endif
2474
+ }
2475
/**
 * Compile-time dispatcher for the packed "euclideans" kernel with `nk_e5m2_t` rows in `a`.
 *
 * The first non-zero NK_TARGET_* macro decides the backend, tested as SME, NEONFHM,
 * SAPPHIREAMX, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED; `_serial` is the default.
 * No argument is inspected or altered before the call.
 *
 * NOTE(review): `b_packed` is opaque; presumably its format depends on the backend
 * that packed it. Confirm with the packing routines.
 */
+ NK_PUBLIC void nk_euclideans_packed_e5m2(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2476
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2477
+ nk_size_t r_stride_in_bytes) {
2478
+ #if NK_TARGET_SME
2479
+ nk_euclideans_packed_e5m2_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2480
+ #elif NK_TARGET_NEONFHM
2481
+ nk_euclideans_packed_e5m2_neonfhm(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2482
+ #elif NK_TARGET_SAPPHIREAMX
2483
+ nk_euclideans_packed_e5m2_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2484
+ #elif NK_TARGET_GENOA
2485
+ nk_euclideans_packed_e5m2_genoa(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2486
+ #elif NK_TARGET_SKYLAKE
2487
+ nk_euclideans_packed_e5m2_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2488
+ #elif NK_TARGET_HASWELL
2489
+ nk_euclideans_packed_e5m2_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2490
+ #elif NK_TARGET_RVV
2491
+ nk_euclideans_packed_e5m2_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2492
+ #elif NK_TARGET_V128RELAXED
2493
+ nk_euclideans_packed_e5m2_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2494
+ #else
2495
+ nk_euclideans_packed_e5m2_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2496
+ #endif
2497
+ }
2498
/**
 * Compile-time dispatcher for the symmetric "euclideans" kernel over `nk_e5m2_t` vectors.
 *
 * Compiles down to a single call: the first enabled backend among SME, NEONFHM,
 * SAPPHIREAMX, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED, or `_serial` when none is
 * enabled. All eight arguments are forwarded in their original order.
 *
 * NOTE(review): the contract for `row_start` / `row_count` is defined by the backends
 * and is not observable from this wrapper.
 */
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2499
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2500
+ nk_size_t row_start, nk_size_t row_count) {
2501
+ #if NK_TARGET_SME
2502
+ nk_euclideans_symmetric_e5m2_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2503
+ #elif NK_TARGET_NEONFHM
2504
+ nk_euclideans_symmetric_e5m2_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2505
+ row_count);
2506
+ #elif NK_TARGET_SAPPHIREAMX
2507
+ nk_euclideans_symmetric_e5m2_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2508
+ row_count);
2509
+ #elif NK_TARGET_GENOA
2510
+ nk_euclideans_symmetric_e5m2_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2511
+ #elif NK_TARGET_SKYLAKE
2512
+ nk_euclideans_symmetric_e5m2_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2513
+ row_count);
2514
+ #elif NK_TARGET_HASWELL
2515
+ nk_euclideans_symmetric_e5m2_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2516
+ row_count);
2517
+ #elif NK_TARGET_RVV
2518
+ nk_euclideans_symmetric_e5m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2519
+ #elif NK_TARGET_V128RELAXED
2520
+ nk_euclideans_symmetric_e5m2_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2521
+ row_count);
2522
+ #else
2523
+ nk_euclideans_symmetric_e5m2_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2524
+ #endif
2525
+ }
2526
+
2527
/**
 * Compile-time dispatcher for the packed "angulars" kernel with `nk_e2m3_t` rows in `a`.
 *
 * Backend priority, first non-zero NK_TARGET_* macro wins: SME, SAPPHIREAMX, SKYLAKE,
 * SIERRA, ALDER, HASWELL, RVV, V128RELAXED, then `_serial`. Note that SKYLAKE is tested
 * before SIERRA and ALDER, so those two are reached only when SKYLAKE is disabled.
 * Arguments are forwarded verbatim.
 *
 * NOTE(review): `b_packed` is opaque here; presumably it must match the selected
 * backend's packed layout. Confirm.
 */
+ NK_PUBLIC void nk_angulars_packed_e2m3(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2528
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2529
+ nk_size_t r_stride_in_bytes) {
2530
+ #if NK_TARGET_SME
2531
+ nk_angulars_packed_e2m3_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2532
+ #elif NK_TARGET_SAPPHIREAMX
2533
+ nk_angulars_packed_e2m3_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2534
+ #elif NK_TARGET_SKYLAKE
2535
+ nk_angulars_packed_e2m3_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2536
+ #elif NK_TARGET_SIERRA
2537
+ nk_angulars_packed_e2m3_sierra(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2538
+ #elif NK_TARGET_ALDER
2539
+ nk_angulars_packed_e2m3_alder(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2540
+ #elif NK_TARGET_HASWELL
2541
+ nk_angulars_packed_e2m3_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2542
+ #elif NK_TARGET_RVV
2543
+ nk_angulars_packed_e2m3_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2544
+ #elif NK_TARGET_V128RELAXED
2545
+ nk_angulars_packed_e2m3_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2546
+ #else
2547
+ nk_angulars_packed_e2m3_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2548
+ #endif
2549
+ }
2550
/**
 * Compile-time dispatcher for the symmetric "angulars" kernel over `nk_e2m3_t` vectors.
 *
 * Forwards the eight arguments unchanged to the first enabled backend, tested in the
 * order SME, SAPPHIREAMX, SKYLAKE, SIERRA, ALDER, HASWELL, RVV, V128RELAXED. The
 * `_serial` implementation is used when every NK_TARGET_* macro here is zero.
 *
 * NOTE(review): interpretation of `stride`, `result_stride` and the row range belongs
 * to the backends; verify there.
 */
+ NK_PUBLIC void nk_angulars_symmetric_e2m3(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2551
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2552
+ nk_size_t row_start, nk_size_t row_count) {
2553
+ #if NK_TARGET_SME
2554
+ nk_angulars_symmetric_e2m3_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2555
+ #elif NK_TARGET_SAPPHIREAMX
2556
+ nk_angulars_symmetric_e2m3_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2557
+ row_count);
2558
+ #elif NK_TARGET_SKYLAKE
2559
+ nk_angulars_symmetric_e2m3_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2560
+ #elif NK_TARGET_SIERRA
2561
+ nk_angulars_symmetric_e2m3_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2562
+ #elif NK_TARGET_ALDER
2563
+ nk_angulars_symmetric_e2m3_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2564
+ #elif NK_TARGET_HASWELL
2565
+ nk_angulars_symmetric_e2m3_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2566
+ #elif NK_TARGET_RVV
2567
+ nk_angulars_symmetric_e2m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2568
+ #elif NK_TARGET_V128RELAXED
2569
+ nk_angulars_symmetric_e2m3_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2570
+ row_count);
2571
+ #else
2572
+ nk_angulars_symmetric_e2m3_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2573
+ #endif
2574
+ }
2575
/**
 * Compile-time dispatcher for the packed "euclideans" kernel with `nk_e2m3_t` rows in `a`.
 *
 * One backend call is kept after preprocessing: the first of SME, SAPPHIREAMX,
 * SKYLAKE, SIERRA, ALDER, HASWELL, RVV, V128RELAXED whose NK_TARGET_* macro is
 * non-zero, else `_serial`. The call receives the caller's arguments as-is.
 *
 * NOTE(review): `b_packed` carries no type information at this level; presumably it
 * was produced by a pack step for the same backend. Confirm.
 */
+ NK_PUBLIC void nk_euclideans_packed_e2m3(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2576
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2577
+ nk_size_t r_stride_in_bytes) {
2578
+ #if NK_TARGET_SME
2579
+ nk_euclideans_packed_e2m3_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2580
+ #elif NK_TARGET_SAPPHIREAMX
2581
+ nk_euclideans_packed_e2m3_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2582
+ #elif NK_TARGET_SKYLAKE
2583
+ nk_euclideans_packed_e2m3_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2584
+ #elif NK_TARGET_SIERRA
2585
+ nk_euclideans_packed_e2m3_sierra(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2586
+ #elif NK_TARGET_ALDER
2587
+ nk_euclideans_packed_e2m3_alder(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2588
+ #elif NK_TARGET_HASWELL
2589
+ nk_euclideans_packed_e2m3_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2590
+ #elif NK_TARGET_RVV
2591
+ nk_euclideans_packed_e2m3_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2592
+ #elif NK_TARGET_V128RELAXED
2593
+ nk_euclideans_packed_e2m3_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2594
+ #else
2595
+ nk_euclideans_packed_e2m3_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2596
+ #endif
2597
+ }
2598
/**
 * Compile-time dispatcher for the symmetric "euclideans" kernel over `nk_e2m3_t` vectors.
 *
 * Selects the first enabled backend in the order SME, SAPPHIREAMX, SKYLAKE, SIERRA,
 * ALDER, HASWELL, RVV, V128RELAXED and falls back to `_serial`. The wrapper performs
 * no work beyond forwarding its eight arguments unchanged.
 *
 * NOTE(review): output layout and row-range semantics are defined by the backends,
 * which are outside this view.
 */
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2599
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2600
+ nk_size_t row_start, nk_size_t row_count) {
2601
+ #if NK_TARGET_SME
2602
+ nk_euclideans_symmetric_e2m3_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2603
+ #elif NK_TARGET_SAPPHIREAMX
2604
+ nk_euclideans_symmetric_e2m3_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2605
+ row_count);
2606
+ #elif NK_TARGET_SKYLAKE
2607
+ nk_euclideans_symmetric_e2m3_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2608
+ row_count);
2609
+ #elif NK_TARGET_SIERRA
2610
+ nk_euclideans_symmetric_e2m3_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2611
+ #elif NK_TARGET_ALDER
2612
+ nk_euclideans_symmetric_e2m3_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2613
+ #elif NK_TARGET_HASWELL
2614
+ nk_euclideans_symmetric_e2m3_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2615
+ row_count);
2616
+ #elif NK_TARGET_RVV
2617
+ nk_euclideans_symmetric_e2m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2618
+ #elif NK_TARGET_V128RELAXED
2619
+ nk_euclideans_symmetric_e2m3_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2620
+ row_count);
2621
+ #else
2622
+ nk_euclideans_symmetric_e2m3_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2623
+ #endif
2624
+ }
2625
+
2626
/**
 * Compile-time dispatcher for the packed "angulars" kernel with `nk_e3m2_t` rows in `a`.
 *
 * Tests NK_TARGET_* macros in the order SME, SAPPHIREAMX, SKYLAKE, HASWELL, RVV and
 * calls the first enabled backend; otherwise `_serial`. Unlike the sibling e2m3
 * dispatchers, this chain has no SIERRA, ALDER or V128RELAXED branch, so such builds
 * reach a later branch or the serial fallback. Arguments are forwarded unchanged.
 *
 * NOTE(review): `b_packed` is opaque here; presumably backend-specific. Confirm.
 */
+ NK_PUBLIC void nk_angulars_packed_e3m2(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2627
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2628
+ nk_size_t r_stride_in_bytes) {
2629
+ #if NK_TARGET_SME
2630
+ nk_angulars_packed_e3m2_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2631
+ #elif NK_TARGET_SAPPHIREAMX
2632
+ nk_angulars_packed_e3m2_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2633
+ #elif NK_TARGET_SKYLAKE
2634
+ nk_angulars_packed_e3m2_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2635
+ #elif NK_TARGET_HASWELL
2636
+ nk_angulars_packed_e3m2_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2637
+ #elif NK_TARGET_RVV
2638
+ nk_angulars_packed_e3m2_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2639
+ #else
2640
+ nk_angulars_packed_e3m2_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2641
+ #endif
2642
+ }
2643
/**
 * Compile-time dispatcher for the symmetric "angulars" kernel over `nk_e3m2_t` vectors.
 *
 * The first non-zero NK_TARGET_* macro among SME, SAPPHIREAMX, SKYLAKE, HASWELL, RVV
 * picks the backend; `_serial` is the fallback. All eight arguments pass through
 * without modification.
 *
 * NOTE(review): the meaning of `row_start` / `row_count` is defined by the backends;
 * verify there.
 */
+ NK_PUBLIC void nk_angulars_symmetric_e3m2(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2644
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2645
+ nk_size_t row_start, nk_size_t row_count) {
2646
+ #if NK_TARGET_SME
2647
+ nk_angulars_symmetric_e3m2_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2648
+ #elif NK_TARGET_SAPPHIREAMX
2649
+ nk_angulars_symmetric_e3m2_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2650
+ row_count);
2651
+ #elif NK_TARGET_SKYLAKE
2652
+ nk_angulars_symmetric_e3m2_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2653
+ #elif NK_TARGET_HASWELL
2654
+ nk_angulars_symmetric_e3m2_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2655
+ #elif NK_TARGET_RVV
2656
+ nk_angulars_symmetric_e3m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2657
+ #else
2658
+ nk_angulars_symmetric_e3m2_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2659
+ #endif
2660
+ }
2661
/**
 * Compile-time dispatcher for the packed "euclideans" kernel with `nk_e3m2_t` rows in `a`.
 *
 * Forwards the caller's arguments, untouched, to the first enabled backend in the
 * order SME, SAPPHIREAMX, SKYLAKE, HASWELL, RVV. When none of these NK_TARGET_*
 * macros is non-zero, the `_serial` implementation runs instead.
 *
 * NOTE(review): `b_packed` is a `void const *` with no visible layout; presumably it
 * must come from the matching backend's packing routine. Confirm.
 */
+ NK_PUBLIC void nk_euclideans_packed_e3m2(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2662
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2663
+ nk_size_t r_stride_in_bytes) {
2664
+ #if NK_TARGET_SME
2665
+ nk_euclideans_packed_e3m2_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2666
+ #elif NK_TARGET_SAPPHIREAMX
2667
+ nk_euclideans_packed_e3m2_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2668
+ #elif NK_TARGET_SKYLAKE
2669
+ nk_euclideans_packed_e3m2_skylake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2670
+ #elif NK_TARGET_HASWELL
2671
+ nk_euclideans_packed_e3m2_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2672
+ #elif NK_TARGET_RVV
2673
+ nk_euclideans_packed_e3m2_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2674
+ #else
2675
+ nk_euclideans_packed_e3m2_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2676
+ #endif
2677
+ }
2678
/**
 * Compile-time dispatcher for the symmetric "euclideans" kernel over `nk_e3m2_t` vectors.
 *
 * Exactly one backend call remains after preprocessing: the first enabled target
 * among SME, SAPPHIREAMX, SKYLAKE, HASWELL, RVV, or `_serial` if all are disabled.
 * The eight arguments are passed along in their original order.
 *
 * NOTE(review): how `result`, `result_stride`, `row_start` and `row_count` interact is
 * defined by the backends and cannot be confirmed from this wrapper.
 */
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2679
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2680
+ nk_size_t row_start, nk_size_t row_count) {
2681
+ #if NK_TARGET_SME
2682
+ nk_euclideans_symmetric_e3m2_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2683
+ #elif NK_TARGET_SAPPHIREAMX
2684
+ nk_euclideans_symmetric_e3m2_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2685
+ row_count);
2686
+ #elif NK_TARGET_SKYLAKE
2687
+ nk_euclideans_symmetric_e3m2_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2688
+ row_count);
2689
+ #elif NK_TARGET_HASWELL
2690
+ nk_euclideans_symmetric_e3m2_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2691
+ row_count);
2692
+ #elif NK_TARGET_RVV
2693
+ nk_euclideans_symmetric_e3m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2694
+ #else
2695
+ nk_euclideans_symmetric_e3m2_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2696
+ #endif
2697
+ }
2698
+
2699
/**
 * Compile-time dispatcher for the packed "angulars" kernel with `nk_i8_t` rows in `a`.
 *
 * Backend priority, first non-zero NK_TARGET_* macro wins: SME, NEONSDOT, SAPPHIREAMX,
 * ICELAKE, SIERRA, ALDER, HASWELL, RVV, V128RELAXED, then `_serial`. The chosen
 * backend is called with the caller's arguments unchanged; results go to an
 * `nk_f32_t` buffer even though the inputs are 8-bit integers.
 *
 * NOTE(review): `b_packed` is opaque at this level; presumably it must match the
 * packed layout of the selected backend. Confirm with the packing API.
 */
+ NK_PUBLIC void nk_angulars_packed_i8(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2700
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2701
+ nk_size_t r_stride_in_bytes) {
2702
+ #if NK_TARGET_SME
2703
+ nk_angulars_packed_i8_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2704
+ #elif NK_TARGET_NEONSDOT
2705
+ nk_angulars_packed_i8_neonsdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2706
+ #elif NK_TARGET_SAPPHIREAMX
2707
+ nk_angulars_packed_i8_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2708
+ #elif NK_TARGET_ICELAKE
2709
+ nk_angulars_packed_i8_icelake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2710
+ #elif NK_TARGET_SIERRA
2711
+ nk_angulars_packed_i8_sierra(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2712
+ #elif NK_TARGET_ALDER
2713
+ nk_angulars_packed_i8_alder(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2714
+ #elif NK_TARGET_HASWELL
2715
+ nk_angulars_packed_i8_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2716
+ #elif NK_TARGET_RVV
2717
+ nk_angulars_packed_i8_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2718
+ #elif NK_TARGET_V128RELAXED
2719
+ nk_angulars_packed_i8_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2720
+ #else
2721
+ nk_angulars_packed_i8_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2722
+ #endif
2723
+ }
2724
+ NK_PUBLIC void nk_angulars_symmetric_i8(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2725
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
2726
+ nk_size_t row_count) {
2727
+ #if NK_TARGET_SME
2728
+ nk_angulars_symmetric_i8_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2729
+ #elif NK_TARGET_NEONSDOT
2730
+ nk_angulars_symmetric_i8_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2731
+ #elif NK_TARGET_SAPPHIREAMX
2732
+ nk_angulars_symmetric_i8_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2733
+ row_count);
2734
+ #elif NK_TARGET_ICELAKE
2735
+ nk_angulars_symmetric_i8_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2736
+ #elif NK_TARGET_SIERRA
2737
+ nk_angulars_symmetric_i8_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2738
+ #elif NK_TARGET_ALDER
2739
+ nk_angulars_symmetric_i8_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2740
+ #elif NK_TARGET_HASWELL
2741
+ nk_angulars_symmetric_i8_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2742
+ #elif NK_TARGET_RVV
2743
+ nk_angulars_symmetric_i8_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2744
+ #elif NK_TARGET_V128RELAXED
2745
+ nk_angulars_symmetric_i8_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2746
+ row_count);
2747
+ #else
2748
+ nk_angulars_symmetric_i8_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2749
+ #endif
2750
+ }
2751
+ NK_PUBLIC void nk_euclideans_packed_i8(nk_i8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2752
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2753
+ nk_size_t r_stride_in_bytes) {
2754
+ #if NK_TARGET_SME
2755
+ nk_euclideans_packed_i8_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2756
+ #elif NK_TARGET_NEONSDOT
2757
+ nk_euclideans_packed_i8_neonsdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2758
+ #elif NK_TARGET_SAPPHIREAMX
2759
+ nk_euclideans_packed_i8_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2760
+ #elif NK_TARGET_ICELAKE
2761
+ nk_euclideans_packed_i8_icelake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2762
+ #elif NK_TARGET_SIERRA
2763
+ nk_euclideans_packed_i8_sierra(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2764
+ #elif NK_TARGET_ALDER
2765
+ nk_euclideans_packed_i8_alder(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2766
+ #elif NK_TARGET_HASWELL
2767
+ nk_euclideans_packed_i8_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2768
+ #elif NK_TARGET_RVV
2769
+ nk_euclideans_packed_i8_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2770
+ #elif NK_TARGET_V128RELAXED
2771
+ nk_euclideans_packed_i8_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2772
+ #else
2773
+ nk_euclideans_packed_i8_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2774
+ #endif
2775
+ }
2776
+ NK_PUBLIC void nk_euclideans_symmetric_i8(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2777
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2778
+ nk_size_t row_start, nk_size_t row_count) {
2779
+ #if NK_TARGET_SME
2780
+ nk_euclideans_symmetric_i8_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2781
+ #elif NK_TARGET_NEONSDOT
2782
+ nk_euclideans_symmetric_i8_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2783
+ #elif NK_TARGET_SAPPHIREAMX
2784
+ nk_euclideans_symmetric_i8_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2785
+ row_count);
2786
+ #elif NK_TARGET_ICELAKE
2787
+ nk_euclideans_symmetric_i8_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2788
+ #elif NK_TARGET_SIERRA
2789
+ nk_euclideans_symmetric_i8_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2790
+ #elif NK_TARGET_ALDER
2791
+ nk_euclideans_symmetric_i8_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2792
+ #elif NK_TARGET_HASWELL
2793
+ nk_euclideans_symmetric_i8_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2794
+ #elif NK_TARGET_RVV
2795
+ nk_euclideans_symmetric_i8_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2796
+ #elif NK_TARGET_V128RELAXED
2797
+ nk_euclideans_symmetric_i8_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2798
+ row_count);
2799
+ #else
2800
+ nk_euclideans_symmetric_i8_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2801
+ #endif
2802
+ }
2803
+
2804
+ NK_PUBLIC void nk_angulars_packed_u8(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2805
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2806
+ nk_size_t r_stride_in_bytes) {
2807
+ #if NK_TARGET_SME
2808
+ nk_angulars_packed_u8_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2809
+ #elif NK_TARGET_NEONSDOT
2810
+ nk_angulars_packed_u8_neonsdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2811
+ #elif NK_TARGET_SAPPHIREAMX
2812
+ nk_angulars_packed_u8_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2813
+ #elif NK_TARGET_ICELAKE
2814
+ nk_angulars_packed_u8_icelake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2815
+ #elif NK_TARGET_SIERRA
2816
+ nk_angulars_packed_u8_sierra(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2817
+ #elif NK_TARGET_ALDER
2818
+ nk_angulars_packed_u8_alder(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2819
+ #elif NK_TARGET_HASWELL
2820
+ nk_angulars_packed_u8_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2821
+ #elif NK_TARGET_RVV
2822
+ nk_angulars_packed_u8_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2823
+ #elif NK_TARGET_V128RELAXED
2824
+ nk_angulars_packed_u8_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2825
+ #else
2826
+ nk_angulars_packed_u8_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2827
+ #endif
2828
+ }
2829
+ NK_PUBLIC void nk_angulars_symmetric_u8(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
2830
+ nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
2831
+ nk_size_t row_count) {
2832
+ #if NK_TARGET_SME
2833
+ nk_angulars_symmetric_u8_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2834
+ #elif NK_TARGET_NEONSDOT
2835
+ nk_angulars_symmetric_u8_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2836
+ #elif NK_TARGET_SAPPHIREAMX
2837
+ nk_angulars_symmetric_u8_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2838
+ row_count);
2839
+ #elif NK_TARGET_ICELAKE
2840
+ nk_angulars_symmetric_u8_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2841
+ #elif NK_TARGET_SIERRA
2842
+ nk_angulars_symmetric_u8_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2843
+ #elif NK_TARGET_ALDER
2844
+ nk_angulars_symmetric_u8_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2845
+ #elif NK_TARGET_HASWELL
2846
+ nk_angulars_symmetric_u8_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2847
+ #elif NK_TARGET_RVV
2848
+ nk_angulars_symmetric_u8_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2849
+ #elif NK_TARGET_V128RELAXED
2850
+ nk_angulars_symmetric_u8_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2851
+ row_count);
2852
+ #else
2853
+ nk_angulars_symmetric_u8_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2854
+ #endif
2855
+ }
2856
+ NK_PUBLIC void nk_euclideans_packed_u8(nk_u8_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2857
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2858
+ nk_size_t r_stride_in_bytes) {
2859
+ #if NK_TARGET_SME
2860
+ nk_euclideans_packed_u8_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2861
+ #elif NK_TARGET_NEONSDOT
2862
+ nk_euclideans_packed_u8_neonsdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2863
+ #elif NK_TARGET_SAPPHIREAMX
2864
+ nk_euclideans_packed_u8_sapphireamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2865
+ #elif NK_TARGET_ICELAKE
2866
+ nk_euclideans_packed_u8_icelake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2867
+ #elif NK_TARGET_SIERRA
2868
+ nk_euclideans_packed_u8_sierra(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2869
+ #elif NK_TARGET_ALDER
2870
+ nk_euclideans_packed_u8_alder(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2871
+ #elif NK_TARGET_HASWELL
2872
+ nk_euclideans_packed_u8_haswell(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2873
+ #elif NK_TARGET_RVV
2874
+ nk_euclideans_packed_u8_rvv(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2875
+ #elif NK_TARGET_V128RELAXED
2876
+ nk_euclideans_packed_u8_v128relaxed(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2877
+ #else
2878
+ nk_euclideans_packed_u8_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2879
+ #endif
2880
+ }
2881
+ NK_PUBLIC void nk_euclideans_symmetric_u8(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2882
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2883
+ nk_size_t row_start, nk_size_t row_count) {
2884
+ #if NK_TARGET_SME
2885
+ nk_euclideans_symmetric_u8_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2886
+ #elif NK_TARGET_NEONSDOT
2887
+ nk_euclideans_symmetric_u8_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2888
+ #elif NK_TARGET_SAPPHIREAMX
2889
+ nk_euclideans_symmetric_u8_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2890
+ row_count);
2891
+ #elif NK_TARGET_ICELAKE
2892
+ nk_euclideans_symmetric_u8_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2893
+ #elif NK_TARGET_SIERRA
2894
+ nk_euclideans_symmetric_u8_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2895
+ #elif NK_TARGET_ALDER
2896
+ nk_euclideans_symmetric_u8_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2897
+ #elif NK_TARGET_HASWELL
2898
+ nk_euclideans_symmetric_u8_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2899
+ #elif NK_TARGET_RVV
2900
+ nk_euclideans_symmetric_u8_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2901
+ #elif NK_TARGET_V128RELAXED
2902
+ nk_euclideans_symmetric_u8_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start,
2903
+ row_count);
2904
+ #else
2905
+ nk_euclideans_symmetric_u8_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2906
+ #endif
2907
+ }
2908
+
2909
+ NK_PUBLIC void nk_angulars_packed_i4(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2910
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2911
+ nk_size_t r_stride_in_bytes) {
2912
+ #if NK_TARGET_SME
2913
+ nk_angulars_packed_i4_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2914
+ #elif NK_TARGET_NEONSDOT
2915
+ nk_angulars_packed_i4_neonsdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2916
+ #elif NK_TARGET_ICELAKE
2917
+ nk_angulars_packed_i4_icelake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2918
+ #else
2919
+ nk_angulars_packed_i4_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2920
+ #endif
2921
+ }
2922
+ NK_PUBLIC void nk_angulars_symmetric_i4(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2923
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2924
+ nk_size_t row_start, nk_size_t row_count) {
2925
+ #if NK_TARGET_SME
2926
+ nk_angulars_symmetric_i4_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2927
+ #elif NK_TARGET_NEONSDOT
2928
+ nk_angulars_symmetric_i4_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2929
+ #elif NK_TARGET_ICELAKE
2930
+ nk_angulars_symmetric_i4_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2931
+ #else
2932
+ nk_angulars_symmetric_i4_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2933
+ #endif
2934
+ }
2935
+ NK_PUBLIC void nk_euclideans_packed_i4(nk_i4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2936
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2937
+ nk_size_t r_stride_in_bytes) {
2938
+ #if NK_TARGET_SME
2939
+ nk_euclideans_packed_i4_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2940
+ #elif NK_TARGET_NEONSDOT
2941
+ nk_euclideans_packed_i4_neonsdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2942
+ #elif NK_TARGET_ICELAKE
2943
+ nk_euclideans_packed_i4_icelake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2944
+ #else
2945
+ nk_euclideans_packed_i4_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2946
+ #endif
2947
+ }
2948
+ NK_PUBLIC void nk_euclideans_symmetric_i4(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2949
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2950
+ nk_size_t row_start, nk_size_t row_count) {
2951
+ #if NK_TARGET_SME
2952
+ nk_euclideans_symmetric_i4_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2953
+ #elif NK_TARGET_NEONSDOT
2954
+ nk_euclideans_symmetric_i4_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2955
+ #elif NK_TARGET_ICELAKE
2956
+ nk_euclideans_symmetric_i4_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2957
+ #else
2958
+ nk_euclideans_symmetric_i4_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2959
+ #endif
2960
+ }
2961
+
2962
+ NK_PUBLIC void nk_angulars_packed_u4(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2963
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2964
+ nk_size_t r_stride_in_bytes) {
2965
+ #if NK_TARGET_SME
2966
+ nk_angulars_packed_u4_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2967
+ #elif NK_TARGET_NEONSDOT
2968
+ nk_angulars_packed_u4_neonsdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2969
+ #elif NK_TARGET_ICELAKE
2970
+ nk_angulars_packed_u4_icelake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2971
+ #else
2972
+ nk_angulars_packed_u4_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2973
+ #endif
2974
+ }
2975
+ NK_PUBLIC void nk_angulars_symmetric_u4(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
2976
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2977
+ nk_size_t row_start, nk_size_t row_count) {
2978
+ #if NK_TARGET_SME
2979
+ nk_angulars_symmetric_u4_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2980
+ #elif NK_TARGET_NEONSDOT
2981
+ nk_angulars_symmetric_u4_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2982
+ #elif NK_TARGET_ICELAKE
2983
+ nk_angulars_symmetric_u4_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2984
+ #else
2985
+ nk_angulars_symmetric_u4_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
2986
+ #endif
2987
+ }
2988
+ NK_PUBLIC void nk_euclideans_packed_u4(nk_u4x2_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2989
+ nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2990
+ nk_size_t r_stride_in_bytes) {
2991
+ #if NK_TARGET_SME
2992
+ nk_euclideans_packed_u4_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2993
+ #elif NK_TARGET_NEONSDOT
2994
+ nk_euclideans_packed_u4_neonsdot(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2995
+ #elif NK_TARGET_ICELAKE
2996
+ nk_euclideans_packed_u4_icelake(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2997
+ #else
2998
+ nk_euclideans_packed_u4_serial(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2999
+ #endif
3000
+ }
3001
+ NK_PUBLIC void nk_euclideans_symmetric_u4(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
3002
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
3003
+ nk_size_t row_start, nk_size_t row_count) {
3004
+ #if NK_TARGET_SME
3005
+ nk_euclideans_symmetric_u4_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
3006
+ #elif NK_TARGET_NEONSDOT
3007
+ nk_euclideans_symmetric_u4_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
3008
+ #elif NK_TARGET_ICELAKE
3009
+ nk_euclideans_symmetric_u4_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
3010
+ #else
3011
+ nk_euclideans_symmetric_u4_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
3012
+ #endif
3013
+ }
3014
+
3015
+ #endif // !NK_DYNAMIC_DISPATCH
3016
+
3017
+ #if defined(__cplusplus)
3018
+ } // extern "C"
3019
+ #endif
3020
+
3021
+ #endif // NK_SPATIALS_H