numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,517 @@
1
+ /**
2
+ * @brief SIMD-accelerated Similarity Measures for Curved Spaces.
3
+ * @file include/numkong/curved.h
4
+ * @author Ash Vardanian
5
+ * @date August 27, 2024
6
+ *
7
+ * Contains the following similarity measures:
8
+ *
9
+ * - Mahalanobis distance: √((a-b)ᵀ × C × (a-b))
10
+ * - Bilinear form: aᵀ × C × b
11
+ * - Bilinear form over complex numbers
12
+ *
13
+ * For dtypes:
14
+ *
15
+ * - 64-bit floating point numbers → 64-bit floats
16
+ * - 32-bit floating point numbers → 64-bit floats
17
+ * - 16-bit floating point numbers → 32-bit floats
18
+ * - 16-bit brain-floating point numbers → 32-bit floats
19
+ *
20
+ * For hardware architectures:
21
+ *
22
+ * - Arm: NEON, NEON+F16, NEON+BF16, SME+F64
23
+ * - x86: Haswell, Skylake, Genoa
24
+ * - RISC-V: RVV
25
+ *
26
+ * @section numerical_stability Numerical Stability
27
+ *
28
+ * To minimize catastrophic cancellation in large-magnitude sums:
29
+ * - f32 kernels widen public outputs to f64/f64c and accumulate in f64 precision where possible
30
+ * - f64 kernels use Dot2 algorithm (Ogita-Rump-Oishi 2005) in SIMD paths
31
+ * - Serial kernels use Neumaier compensated summation for all types
32
+ *
33
+ * @section usage Usage and Benefits
34
+ *
35
+ * These kernels target BLAS level 2 patterns where vectors are combined with a metric
36
+ * tensor or covariance matrix. Using raw bilinear and Mahalanobis forms avoids constructing
37
+ * intermediates and keeps memory traffic low, which is often faster than a full GEMM path
38
+ * for small and medium sizes. Complex bilinear forms return a complex scalar as two reals,
39
+ * serving complex-valued signals without extra packing or unpacking.
40
+ *
41
+ * @section references References
42
+ *
43
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
44
+ * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
45
+ * - Neumaier, A. (1974). "Rundungsfehleranalyse einiger Verfahren zur Summation endlicher Summen"
46
+ * - Ogita, T., Rump, S.M., Oishi, S. (2005). "Accurate Sum and Dot Product"
47
+ *
48
+ */
49
+ #ifndef NK_CURVED_H
50
+ #define NK_CURVED_H
51
+
52
+ #include "numkong/types.h"
53
+
54
+ #if defined(__cplusplus)
55
+ extern "C" {
56
+ #endif
57
+
58
+ /**
59
+ * @brief Bilinear form between vectors a and b under metric tensor C.
60
+ *
61
+ * Computes aᵀ × C × b = Σᵢ Σⱼ aᵢ × cᵢⱼ × bⱼ
62
+ *
63
+ * @param[in] a The first vector.
64
+ * @param[in] b The second vector.
65
+ * @param[in] c The metric tensor or covariance matrix, stored row-major as an n×n matrix.
66
+ * @param[in] n The number of dimensions in the vectors.
67
+ * @param[out] result The output bilinear form value.
68
+ *
69
+ * @note The output value can be negative.
70
+ */
71
+ NK_DYNAMIC void nk_bilinear_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n, nk_f64_t *result);
72
+ /** @copydoc nk_bilinear_f64 */
73
+ NK_DYNAMIC void nk_bilinear_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result);
74
+ /** @copydoc nk_bilinear_f64 */
75
+ NK_DYNAMIC void nk_bilinear_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, nk_f32_t *result);
76
+ /** @copydoc nk_bilinear_f64 */
77
+ NK_DYNAMIC void nk_bilinear_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
78
+ nk_f32_t *result);
79
+
80
+ /**
81
+ * @brief Mahalanobis distance between vectors a and b under metric tensor C.
82
+ *
83
+ * Computes √((a-b)ᵀ × C × (a-b)) = √(Σᵢ Σⱼ (aᵢ-bᵢ) × cᵢⱼ × (aⱼ-bⱼ))
84
+ *
85
+ * @param[in] a The first vector.
86
+ * @param[in] b The second vector.
87
+ * @param[in] c The Positive Semi-Definite (PSD) matrix, stored row-major as an n×n matrix.
88
+ * @param[in] n The number of dimensions in the vectors.
89
+ * @param[out] result The output distance value.
90
+ *
91
+ * @note The output value is non-negative when C is PSD.
92
+ * @note The output value is zero when the two vectors are identical; if C is only semi-definite (singular), distinct vectors can also be at zero distance.
93
+ * @note The matrix C must be positive semi-definite. If C is not PSD, the quadratic form
94
+ * (a-b)ᵀ C (a-b) may be negative, and the square root will produce NaN.
95
+ */
96
+ NK_DYNAMIC void nk_mahalanobis_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
97
+ nk_f64_t *result);
98
+ /** @copydoc nk_mahalanobis_f64 */
99
+ NK_DYNAMIC void nk_mahalanobis_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
100
+ nk_f64_t *result);
101
+ /** @copydoc nk_mahalanobis_f64 */
102
+ NK_DYNAMIC void nk_mahalanobis_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
103
+ nk_f32_t *result);
104
+ /** @copydoc nk_mahalanobis_f64 */
105
+ NK_DYNAMIC void nk_mahalanobis_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
106
+ nk_f32_t *result);
107
+
108
+ /**
109
+ * @brief Complex bilinear form between vectors a and b under metric tensor C.
110
+ *
111
+ * @param[in] a The first complex vector.
112
+ * @param[in] b The second complex vector.
113
+ * @param[in] c The complex metric tensor, stored row-major as an n×n matrix.
114
+ * @param[in] n The number of dimensions in the vectors.
115
+ * @param[out] results The output complex value with real and imaginary parts.
116
+ */
117
+ NK_DYNAMIC void nk_bilinear_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
118
+ nk_f64c_t *results);
119
+ /** @copydoc nk_bilinear_f64c */
120
+ NK_DYNAMIC void nk_bilinear_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
121
+ nk_f64c_t *results);
122
+ /** @copydoc nk_bilinear_f64c */
123
+ NK_DYNAMIC void nk_bilinear_f16c(nk_f16c_t const *a, nk_f16c_t const *b, nk_f16c_t const *c, nk_size_t n,
124
+ nk_f32c_t *results);
125
+ /** @copydoc nk_bilinear_f64c */
126
+ NK_DYNAMIC void nk_bilinear_bf16c(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
127
+ nk_f32c_t *results);
128
+
129
+ /** @copydoc nk_bilinear_f64 */
130
+ NK_PUBLIC void nk_bilinear_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
131
+ nk_f64_t *result);
132
+ /** @copydoc nk_bilinear_f64c */
133
+ NK_PUBLIC void nk_bilinear_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
134
+ nk_f64c_t *results);
135
+ /** @copydoc nk_mahalanobis_f64 */
136
+ NK_PUBLIC void nk_mahalanobis_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
137
+ nk_f64_t *result);
138
+ /** @copydoc nk_bilinear_f32 */
139
+ NK_PUBLIC void nk_bilinear_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
140
+ nk_f64_t *result);
141
+ /** @copydoc nk_bilinear_f32c */
142
+ NK_PUBLIC void nk_bilinear_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
143
+ nk_f64c_t *results);
144
+ /** @copydoc nk_mahalanobis_f32 */
145
+ NK_PUBLIC void nk_mahalanobis_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
146
+ nk_f64_t *result);
147
+ /** @copydoc nk_bilinear_f16 */
148
+ NK_PUBLIC void nk_bilinear_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
149
+ nk_f32_t *result);
150
+ /** @copydoc nk_bilinear_f16c */
151
+ NK_PUBLIC void nk_bilinear_f16c_serial(nk_f16c_t const *a, nk_f16c_t const *b, nk_f16c_t const *c, nk_size_t n,
152
+ nk_f32c_t *results);
153
+ /** @copydoc nk_mahalanobis_f16 */
154
+ NK_PUBLIC void nk_mahalanobis_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
155
+ nk_f32_t *result);
156
+ /** @copydoc nk_bilinear_bf16 */
157
+ NK_PUBLIC void nk_bilinear_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
158
+ nk_f32_t *result);
159
+ /** @copydoc nk_bilinear_bf16c */
160
+ NK_PUBLIC void nk_bilinear_bf16c_serial(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
161
+ nk_f32c_t *results);
162
+ /** @copydoc nk_mahalanobis_bf16 */
163
+ NK_PUBLIC void nk_mahalanobis_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
164
+ nk_f32_t *result);
165
+
166
+ #if NK_TARGET_NEON
167
+ /** @copydoc nk_bilinear_f32 */
168
+ NK_PUBLIC void nk_bilinear_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
169
+ nk_f64_t *result);
170
+ /** @copydoc nk_bilinear_f32c */
171
+ NK_PUBLIC void nk_bilinear_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
172
+ nk_f64c_t *results);
173
+ /** @copydoc nk_mahalanobis_f32 */
174
+ NK_PUBLIC void nk_mahalanobis_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
175
+ nk_f64_t *result);
176
+ #endif // NK_TARGET_NEON
177
+
178
+ #if NK_TARGET_NEONHALF
179
+ /** @copydoc nk_bilinear_f16 */
180
+ NK_PUBLIC void nk_bilinear_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
181
+ nk_f32_t *result);
182
+ /** @copydoc nk_bilinear_f16c */
183
+ NK_PUBLIC void nk_bilinear_f16c_neonhalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_f16c_t const *c, nk_size_t n,
184
+ nk_f32c_t *results);
185
+ /** @copydoc nk_mahalanobis_f16 */
186
+ NK_PUBLIC void nk_mahalanobis_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
187
+ nk_f32_t *result);
188
+ #endif // NK_TARGET_NEONHALF
189
+
190
+ #if NK_TARGET_NEONBFDOT
191
+ /** @copydoc nk_bilinear_bf16 */
192
+ NK_PUBLIC void nk_bilinear_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
193
+ nk_f32_t *result);
194
+ /** @copydoc nk_bilinear_bf16c */
195
+ NK_PUBLIC void nk_bilinear_bf16c_neonbfdot(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
196
+ nk_f32c_t *results);
197
+ /** @copydoc nk_mahalanobis_bf16 */
198
+ NK_PUBLIC void nk_mahalanobis_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
199
+ nk_f32_t *result);
200
+ #endif // NK_TARGET_NEONBFDOT
201
+
202
+ #if NK_TARGET_SMEF64
203
+ /** @copydoc nk_bilinear_f32 */
204
+ NK_PUBLIC void nk_bilinear_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
205
+ nk_f64_t *result);
206
+ /** @copydoc nk_bilinear_f32c */
207
+ NK_PUBLIC void nk_bilinear_f32c_smef64(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
208
+ nk_f64c_t *results);
209
+ /** @copydoc nk_mahalanobis_f32 */
210
+ NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
211
+ nk_f64_t *result);
212
+ /** @copydoc nk_bilinear_f64 */
213
+ NK_PUBLIC void nk_bilinear_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
214
+ nk_f64_t *result);
215
+ /** @copydoc nk_bilinear_f64c */
216
+ NK_PUBLIC void nk_bilinear_f64c_smef64(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
217
+ nk_f64c_t *results);
218
+ /** @copydoc nk_mahalanobis_f64 */
219
+ NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
220
+ nk_f64_t *result);
221
+ #endif // NK_TARGET_SMEF64
222
+
223
+ #if NK_TARGET_HASWELL
224
+ /** @copydoc nk_bilinear_f32 */
225
+ NK_PUBLIC void nk_bilinear_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
226
+ nk_f64_t *result);
227
+ /** @copydoc nk_mahalanobis_f32 */
228
+ NK_PUBLIC void nk_mahalanobis_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
229
+ nk_f64_t *result);
230
+ /** @copydoc nk_bilinear_f16 */
231
+ NK_PUBLIC void nk_bilinear_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
232
+ nk_f32_t *result);
233
+ /** @copydoc nk_mahalanobis_f16 */
234
+ NK_PUBLIC void nk_mahalanobis_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
235
+ nk_f32_t *result);
236
+ /** @copydoc nk_bilinear_bf16 */
237
+ NK_PUBLIC void nk_bilinear_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
238
+ nk_f32_t *result);
239
+ /** @copydoc nk_mahalanobis_bf16 */
240
+ NK_PUBLIC void nk_mahalanobis_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
241
+ nk_f32_t *result);
242
+ #endif // NK_TARGET_HASWELL
243
+
244
+ #if NK_TARGET_SKYLAKE
245
+ /** @copydoc nk_bilinear_f64 */
246
+ NK_PUBLIC void nk_bilinear_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
247
+ nk_f64_t *result);
248
+ /** @copydoc nk_bilinear_f64c */
249
+ NK_PUBLIC void nk_bilinear_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
250
+ nk_f64c_t *results);
251
+ /** @copydoc nk_mahalanobis_f64 */
252
+ NK_PUBLIC void nk_mahalanobis_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
253
+ nk_f64_t *result);
254
+ /** @copydoc nk_bilinear_f32 */
255
+ NK_PUBLIC void nk_bilinear_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
256
+ nk_f64_t *result);
257
+ /** @copydoc nk_bilinear_f32c */
258
+ NK_PUBLIC void nk_bilinear_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
259
+ nk_f64c_t *results);
260
+ /** @copydoc nk_mahalanobis_f32 */
261
+ NK_PUBLIC void nk_mahalanobis_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
262
+ nk_f64_t *result);
263
+ #endif // NK_TARGET_SKYLAKE
264
+
265
+ #if NK_TARGET_GENOA
266
+ /** @copydoc nk_bilinear_bf16 */
267
+ NK_PUBLIC void nk_bilinear_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
268
+ nk_f32_t *result);
269
+ /** @copydoc nk_bilinear_bf16c */
270
+ NK_PUBLIC void nk_bilinear_bf16c_genoa(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
271
+ nk_f32c_t *results);
272
+ /** @copydoc nk_mahalanobis_bf16 */
273
+ NK_PUBLIC void nk_mahalanobis_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
274
+ nk_f32_t *result);
275
+ #endif // NK_TARGET_GENOA
276
+
277
+ #if NK_TARGET_RVV
278
+ /** @copydoc nk_bilinear_f64 */
279
+ NK_PUBLIC void nk_bilinear_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
280
+ nk_f64_t *result);
281
+ /** @copydoc nk_mahalanobis_f64 */
282
+ NK_PUBLIC void nk_mahalanobis_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
283
+ nk_f64_t *result);
284
+ /** @copydoc nk_bilinear_f32 */
285
+ NK_PUBLIC void nk_bilinear_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
286
+ nk_f64_t *result);
287
+ /** @copydoc nk_mahalanobis_f32 */
288
+ NK_PUBLIC void nk_mahalanobis_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
289
+ nk_f64_t *result);
290
+ /** @copydoc nk_bilinear_f16 */
291
+ NK_PUBLIC void nk_bilinear_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
292
+ nk_f32_t *result);
293
+ /** @copydoc nk_mahalanobis_f16 */
294
+ NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
295
+ nk_f32_t *result);
296
+ /** @copydoc nk_bilinear_bf16 */
297
+ NK_PUBLIC void nk_bilinear_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
298
+ nk_f32_t *result);
299
+ /** @copydoc nk_mahalanobis_bf16 */
300
+ NK_PUBLIC void nk_mahalanobis_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
301
+ nk_f32_t *result);
302
+ #endif // NK_TARGET_RVV
303
+
304
+ /**
305
+ * @brief Returns the output dtype for bilinear forms: f64 and f32 inputs map to f64, f16 and bf16 to f32, complex inputs to the matching complex type (f64c / f32c); any other dtype yields nk_dtype_unknown_k.
306
+ */
307
+ NK_INTERNAL nk_dtype_t nk_bilinear_output_dtype(nk_dtype_t dtype) {
308
+ switch (dtype) {
309
+ case nk_f64_k: return nk_f64_k; // f64 input keeps f64 output
310
+ case nk_f32_k: return nk_f64_k; // f32 input reports a wider f64 output
311
+ case nk_f16_k: return nk_f32_k; // f16 input reports an f32 output
312
+ case nk_bf16_k: return nk_f32_k; // bf16 input reports an f32 output
313
+ case nk_f64c_k: return nk_f64c_k; // complex f64 keeps complex f64
314
+ case nk_f32c_k: return nk_f64c_k; // complex f32 reports complex f64
315
+ case nk_f16c_k: return nk_f32c_k; // complex f16 reports complex f32
316
+ case nk_bf16c_k: return nk_f32c_k; // complex bf16 reports complex f32
317
+ default: return nk_dtype_unknown_k; // every dtype not listed above
318
+ }
319
+ }
320
+
321
+ /**
322
+ * @brief Returns the output dtype for Mahalanobis metrics: f64 and f32 inputs map to f64, f16 and bf16 to f32; any other dtype yields nk_dtype_unknown_k.
323
+ */
324
+ NK_INTERNAL nk_dtype_t nk_mahalanobis_output_dtype(nk_dtype_t dtype) {
325
+ switch (dtype) {
326
+ case nk_f64_k: return nk_f64_k; // f64 input keeps f64 output
327
+ case nk_f32_k: return nk_f64_k; // f32 input reports a wider f64 output
328
+ case nk_f16_k: return nk_f32_k; // f16 input reports an f32 output
329
+ case nk_bf16_k: return nk_f32_k; // bf16 input reports an f32 output
330
+ default: return nk_dtype_unknown_k; // complex dtypes have no case here, so they also land on the default
331
+ }
332
+ }
333
+
334
+ #if defined(__cplusplus)
335
+ } // extern "C"
336
+ #endif
337
+
338
+ #include "numkong/curved/serial.h"
339
+ #include "numkong/curved/neon.h"
340
+ #include "numkong/curved/neonhalf.h"
341
+ #include "numkong/curved/neonbfdot.h"
342
+ #include "numkong/curved/smef64.h"
343
+ #include "numkong/curved/haswell.h"
344
+ #include "numkong/curved/skylake.h"
345
+ #include "numkong/curved/genoa.h"
346
+ #include "numkong/curved/rvv.h"
347
+
348
+ #if defined(__cplusplus)
349
+ extern "C" {
350
+ #endif
351
+
352
+ #if !NK_DYNAMIC_DISPATCH
353
+
354
+ NK_PUBLIC void nk_bilinear_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n, nk_f64_t *result) {
355
+ #if NK_TARGET_SKYLAKE // forwards all arguments to exactly one backend: the first enabled branch of this chain
356
+ nk_bilinear_f64_skylake(a, b, c, n, result);
357
+ #elif NK_TARGET_SMEF64
358
+ nk_bilinear_f64_smef64(a, b, c, n, result);
359
+ #elif NK_TARGET_RVV
360
+ nk_bilinear_f64_rvv(a, b, c, n, result);
361
+ #else // none of the targets above is enabled: use the serial kernel
362
+ nk_bilinear_f64_serial(a, b, c, n, result);
363
+ #endif
364
+ }
365
+
366
+ NK_PUBLIC void nk_bilinear_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n, nk_f64_t *result) {
367
+ #if NK_TARGET_SKYLAKE // f32 inputs, f64 result; only the first enabled branch of this chain is compiled
368
+ nk_bilinear_f32_skylake(a, b, c, n, result);
369
+ #elif NK_TARGET_SMEF64
370
+ nk_bilinear_f32_smef64(a, b, c, n, result);
371
+ #elif NK_TARGET_HASWELL
372
+ nk_bilinear_f32_haswell(a, b, c, n, result);
373
+ #elif NK_TARGET_NEON
374
+ nk_bilinear_f32_neon(a, b, c, n, result);
375
+ #elif NK_TARGET_RVV
376
+ nk_bilinear_f32_rvv(a, b, c, n, result);
377
+ #else // none of the targets above is enabled: use the serial kernel
378
+ nk_bilinear_f32_serial(a, b, c, n, result);
379
+ #endif
380
+ }
381
+
382
+ NK_PUBLIC void nk_bilinear_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n, nk_f32_t *result) {
383
+ #if NK_TARGET_HASWELL // f16 inputs, f32 result; only the first enabled branch of this chain is compiled
384
+ nk_bilinear_f16_haswell(a, b, c, n, result);
385
+ #elif NK_TARGET_NEONHALF
386
+ nk_bilinear_f16_neonhalf(a, b, c, n, result);
387
+ #elif NK_TARGET_RVV
388
+ nk_bilinear_f16_rvv(a, b, c, n, result);
389
+ #else // none of the targets above is enabled: use the serial kernel
390
+ nk_bilinear_f16_serial(a, b, c, n, result);
391
+ #endif
392
+ }
393
+
394
+ NK_PUBLIC void nk_bilinear_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
395
+ nk_f32_t *result) { // bf16 inputs, f32 result; forwards all arguments to one backend
396
+ #if NK_TARGET_GENOA // only the first enabled branch of this chain is compiled
397
+ nk_bilinear_bf16_genoa(a, b, c, n, result);
398
+ #elif NK_TARGET_HASWELL
399
+ nk_bilinear_bf16_haswell(a, b, c, n, result);
400
+ #elif NK_TARGET_NEONBFDOT
401
+ nk_bilinear_bf16_neonbfdot(a, b, c, n, result);
402
+ #elif NK_TARGET_RVV
403
+ nk_bilinear_bf16_rvv(a, b, c, n, result);
404
+ #else // none of the targets above is enabled: use the serial kernel
405
+ nk_bilinear_bf16_serial(a, b, c, n, result);
406
+ #endif
407
+ }
408
+
409
+ NK_PUBLIC void nk_bilinear_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
410
+ nk_f64c_t *results) { // complex f64c inputs and output; forwards all arguments to one backend
411
+ #if NK_TARGET_SKYLAKE // only the first enabled branch of this chain is compiled
412
+ nk_bilinear_f64c_skylake(a, b, c, n, results);
413
+ #elif NK_TARGET_SMEF64
414
+ nk_bilinear_f64c_smef64(a, b, c, n, results);
415
+ #else // none of the targets above is enabled: use the serial kernel
416
+ nk_bilinear_f64c_serial(a, b, c, n, results);
417
+ #endif
418
+ }
419
+
420
+ NK_PUBLIC void nk_bilinear_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
421
+ nk_f64c_t *results) { // complex f32c inputs, f64c output; forwards all arguments to one backend
422
+ #if NK_TARGET_SKYLAKE // only the first enabled branch of this chain is compiled
423
+ nk_bilinear_f32c_skylake(a, b, c, n, results);
424
+ #elif NK_TARGET_SMEF64
425
+ nk_bilinear_f32c_smef64(a, b, c, n, results);
426
+ #elif NK_TARGET_NEON
427
+ nk_bilinear_f32c_neon(a, b, c, n, results);
428
+ #else // none of the targets above is enabled: use the serial kernel
429
+ nk_bilinear_f32c_serial(a, b, c, n, results);
430
+ #endif
431
+ }
432
+
433
+ NK_PUBLIC void nk_bilinear_f16c(nk_f16c_t const *a, nk_f16c_t const *b, nk_f16c_t const *c, nk_size_t n,
434
+ nk_f32c_t *results) { // complex f16c inputs, f32c output; forwards all arguments to one backend
435
+ #if NK_TARGET_NEONHALF // the only accelerated branch listed for this dtype
436
+ nk_bilinear_f16c_neonhalf(a, b, c, n, results);
437
+ #else // NK_TARGET_NEONHALF is not enabled: use the serial kernel
438
+ nk_bilinear_f16c_serial(a, b, c, n, results);
439
+ #endif
440
+ }
441
+
442
+ NK_PUBLIC void nk_bilinear_bf16c(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
443
+ nk_f32c_t *results) { // complex bf16c inputs, f32c output; forwards all arguments to one backend
444
+ #if NK_TARGET_GENOA // only the first enabled branch of this chain is compiled
445
+ nk_bilinear_bf16c_genoa(a, b, c, n, results);
446
+ #elif NK_TARGET_NEONBFDOT
447
+ nk_bilinear_bf16c_neonbfdot(a, b, c, n, results);
448
+ #else // none of the targets above is enabled: use the serial kernel
449
+ nk_bilinear_bf16c_serial(a, b, c, n, results);
450
+ #endif
451
+ }
452
+
453
+ NK_PUBLIC void nk_mahalanobis_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
454
+ nk_f64_t *result) { // f64 inputs and result; forwards all arguments to one backend
455
+ #if NK_TARGET_SKYLAKE // only the first enabled branch of this chain is compiled
456
+ nk_mahalanobis_f64_skylake(a, b, c, n, result);
457
+ #elif NK_TARGET_SMEF64
458
+ nk_mahalanobis_f64_smef64(a, b, c, n, result);
459
+ #elif NK_TARGET_RVV
460
+ nk_mahalanobis_f64_rvv(a, b, c, n, result);
461
+ #else // none of the targets above is enabled: use the serial kernel
462
+ nk_mahalanobis_f64_serial(a, b, c, n, result);
463
+ #endif
464
+ }
465
+
466
+ NK_PUBLIC void nk_mahalanobis_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
467
+ nk_f64_t *result) { // f32 inputs, f64 result; forwards all arguments to one backend
468
+ #if NK_TARGET_SKYLAKE // only the first enabled branch of this chain is compiled
469
+ nk_mahalanobis_f32_skylake(a, b, c, n, result);
470
+ #elif NK_TARGET_SMEF64
471
+ nk_mahalanobis_f32_smef64(a, b, c, n, result);
472
+ #elif NK_TARGET_HASWELL
473
+ nk_mahalanobis_f32_haswell(a, b, c, n, result);
474
+ #elif NK_TARGET_NEON
475
+ nk_mahalanobis_f32_neon(a, b, c, n, result);
476
+ #elif NK_TARGET_RVV
477
+ nk_mahalanobis_f32_rvv(a, b, c, n, result);
478
+ #else // none of the targets above is enabled: use the serial kernel
479
+ nk_mahalanobis_f32_serial(a, b, c, n, result);
480
+ #endif
481
+ }
482
+
483
+ NK_PUBLIC void nk_mahalanobis_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
484
+ nk_f32_t *result) { // f16 inputs, f32 result; forwards all arguments to one backend
485
+ #if NK_TARGET_HASWELL // only the first enabled branch of this chain is compiled
486
+ nk_mahalanobis_f16_haswell(a, b, c, n, result);
487
+ #elif NK_TARGET_NEONHALF
488
+ nk_mahalanobis_f16_neonhalf(a, b, c, n, result);
489
+ #elif NK_TARGET_RVV
490
+ nk_mahalanobis_f16_rvv(a, b, c, n, result);
491
+ #else // none of the targets above is enabled: use the serial kernel
492
+ nk_mahalanobis_f16_serial(a, b, c, n, result);
493
+ #endif
494
+ }
495
+
496
+ NK_PUBLIC void nk_mahalanobis_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
497
+ nk_f32_t *result) { // bf16 inputs, f32 result; forwards all arguments to one backend
498
+ #if NK_TARGET_GENOA // only the first enabled branch of this chain is compiled
499
+ nk_mahalanobis_bf16_genoa(a, b, c, n, result);
500
+ #elif NK_TARGET_HASWELL
501
+ nk_mahalanobis_bf16_haswell(a, b, c, n, result);
502
+ #elif NK_TARGET_NEONBFDOT
503
+ nk_mahalanobis_bf16_neonbfdot(a, b, c, n, result);
504
+ #elif NK_TARGET_RVV
505
+ nk_mahalanobis_bf16_rvv(a, b, c, n, result);
506
+ #else // none of the targets above is enabled: use the serial kernel
507
+ nk_mahalanobis_bf16_serial(a, b, c, n, result);
508
+ #endif
509
+ }
510
+
511
+ #endif // !NK_DYNAMIC_DISPATCH
512
+
513
+ #if defined(__cplusplus)
514
+ } // extern "C"
515
+ #endif
516
+
517
+ #endif // NK_CURVED_H
@@ -0,0 +1,144 @@
1
+ /**
2
+ * @brief Curved-space kernels: bilinear, mahalanobis.
3
+ * @file include/numkong/curved.hpp
4
+ * @author Ash Vardanian
5
+ * @date February 5, 2026
6
+ */
7
+ #ifndef NK_CURVED_HPP
8
+ #define NK_CURVED_HPP
9
+
10
+ #include <cstdint> // `std::uint32_t`
11
+ #include <type_traits> // `std::is_same_v`
12
+
13
+ #include "numkong/curved.h"
14
+
15
+ #include "numkong/types.hpp"
16
+
17
+ namespace ashvardanian::numkong {
18
+
19
+ /**
20
+ * @brief Evaluates the bilinear form aᵀ × C × b, where C is a row-major d×d matrix
21
+ * @param[in] a,b Input vectors, each holding d elements
22
+ * @param[in] c Row-major d×d matrix; element (i, j) is read from c[i * d + j]
23
+ * @param[in] d Number of dimensions
24
+ * @param[out] r Destination for the single output value
25
+ *
26
+ * @tparam in_type_ Element type of the inputs (real or complex)
27
+ * @tparam result_type_ Accumulator and output type, `in_type_::curved_result_t` by default
28
+ * @tparam allow_simd_ `prefer_simd_k` permits the `nk_bilinear_*` kernels, which also need the default `result_type_`
29
+ *
30
+ * @note Building block for weighted inner products, Mahalanobis distance, etc.
31
+ */
32
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
33
+ allow_simd_t allow_simd_ = prefer_simd_k>
34
+ void bilinear(in_type_ const *a, in_type_ const *b, in_type_ const *c, std::size_t d, result_type_ *r) noexcept {
35
+ constexpr bool use_kernel = std::is_same_v<result_type_, typename in_type_::curved_result_t> &&
36
+ allow_simd_ == prefer_simd_k;
37
+
38
+ // Real-valued inputs
39
+ if constexpr (use_kernel && std::is_same_v<in_type_, f64_t>) nk_bilinear_f64(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
40
+ else if constexpr (use_kernel && std::is_same_v<in_type_, f32_t>)
41
+ nk_bilinear_f32(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
42
+ else if constexpr (use_kernel && std::is_same_v<in_type_, f16_t>)
43
+ nk_bilinear_f16(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
44
+ else if constexpr (use_kernel && std::is_same_v<in_type_, bf16_t>)
45
+ nk_bilinear_bf16(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
46
+ // Complex-valued inputs
47
+ else if constexpr (use_kernel && std::is_same_v<in_type_, f64c_t>)
48
+ nk_bilinear_f64c(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
49
+ else if constexpr (use_kernel && std::is_same_v<in_type_, f32c_t>)
50
+ nk_bilinear_f32c(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
51
+ else if constexpr (use_kernel && std::is_same_v<in_type_, f16c_t>)
52
+ nk_bilinear_f16c(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
53
+ else if constexpr (use_kernel && std::is_same_v<in_type_, bf16c_t>)
54
+ nk_bilinear_bf16c(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
55
+ // Everything else: plain double loop carried out in `result_type_`
56
+ else {
57
+ result_type_ total {};
58
+ for (std::size_t row = 0; row < d; ++row) {
59
+ for (std::size_t col = 0; col < d; ++col) {
60
+ total = total + result_type_(a[row]) * result_type_(c[row * d + col]) * result_type_(b[col]);
61
+ }
62
+ }
63
+ *r = total;
64
+ }
65
+ }
66
+
67
+ /**
68
+ * @brief Mahalanobis-style distance: √((a−b)ᵀ × C × (a−b)), where C is a row-major d×d matrix
69
+ * @param[in] a,b Input vectors, each holding d elements
70
+ * @param[in] c Row-major d×d matrix, applied as-is by the fallback; pass the inverse covariance for the classical distance
71
+ * @param[in] d Number of dimensions
72
+ * @param[out] r Destination for the single distance value
73
+ *
74
+ * @tparam in_type_ Element type of the inputs
75
+ * @tparam result_type_ Accumulator and output type, `in_type_::curved_result_t` by default
76
+ * @tparam allow_simd_ `prefer_simd_k` permits the `nk_mahalanobis_*` kernels, which also need the default `result_type_`
77
+ */
78
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
79
+ allow_simd_t allow_simd_ = prefer_simd_k>
80
+ void mahalanobis(in_type_ const *a, in_type_ const *b, in_type_ const *c, std::size_t d, result_type_ *r) noexcept {
81
+ constexpr bool use_kernel = std::is_same_v<result_type_, typename in_type_::curved_result_t> &&
82
+ allow_simd_ == prefer_simd_k;
83
+
84
+ if constexpr (use_kernel && std::is_same_v<in_type_, f64_t>)
85
+ nk_mahalanobis_f64(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
86
+ else if constexpr (use_kernel && std::is_same_v<in_type_, f32_t>)
87
+ nk_mahalanobis_f32(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
88
+ else if constexpr (use_kernel && std::is_same_v<in_type_, f16_t>)
89
+ nk_mahalanobis_f16(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
90
+ else if constexpr (use_kernel && std::is_same_v<in_type_, bf16_t>)
91
+ nk_mahalanobis_bf16(&a->raw_, &b->raw_, &c->raw_, d, &r->raw_);
92
+ // Everything else: plain double loop carried out in `result_type_`
93
+ else {
94
+ result_type_ quadratic {};
95
+ for (std::size_t row = 0; row < d; ++row) {
96
+ result_type_ delta_row = result_type_(a[row]) - result_type_(b[row]);
97
+ for (std::size_t col = 0; col < d; ++col) {
98
+ result_type_ delta_col = result_type_(a[col]) - result_type_(b[col]);
99
+ quadratic = quadratic + delta_row * result_type_(c[row * d + col]) * delta_col;
100
+ }
101
+ }
102
+ *r = quadratic.sqrt();
103
+ }
104
+ }
105
+
106
+ } // namespace ashvardanian::numkong
107
+
108
+ #include "numkong/tensor.hpp"
109
+
110
+ namespace ashvardanian::numkong {
111
+
112
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
113
+ allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_,
114
+ std::size_t max_rank_c_>
115
+ void bilinear(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b, // tensor_view convenience overload
116
+ tensor_view<in_type_, max_rank_c_> c, std::size_t d, result_type_ *r) noexcept {
117
+ bilinear<in_type_, result_type_, allow_simd_>(a.data(), b.data(), c.data(), d, r); // only data() is read from the views; d is caller-supplied. Presumably binds to the raw-pointer overload - confirm data() yields `in_type_ const *`
118
+ }
119
+
120
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
121
+ allow_simd_t allow_simd_ = prefer_simd_k>
122
+ void bilinear(vector_view<in_type_> a, vector_view<in_type_> b, vector_view<in_type_> c, std::size_t d, // c is a flat view too
123
+ result_type_ *r) noexcept {
124
+ bilinear<in_type_, result_type_, allow_simd_>(a.data(), b.data(), c.data(), d, r); // only data() is read from the views; d is caller-supplied. Presumably binds to the raw-pointer overload - confirm data() yields `in_type_ const *`
125
+ }
126
+
127
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
128
+ allow_simd_t allow_simd_ = prefer_simd_k, std::size_t max_rank_a_, std::size_t max_rank_b_,
129
+ std::size_t max_rank_c_>
130
+ void mahalanobis(tensor_view<in_type_, max_rank_a_> a, tensor_view<in_type_, max_rank_b_> b, // tensor_view convenience overload
131
+ tensor_view<in_type_, max_rank_c_> c, std::size_t d, result_type_ *r) noexcept {
132
+ mahalanobis<in_type_, result_type_, allow_simd_>(a.data(), b.data(), c.data(), d, r); // only data() is read from the views; d is caller-supplied. Presumably binds to the raw-pointer overload - confirm data() yields `in_type_ const *`
133
+ }
134
+
135
+ template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::curved_result_t,
136
+ allow_simd_t allow_simd_ = prefer_simd_k>
137
+ void mahalanobis(vector_view<in_type_> a, vector_view<in_type_> b, vector_view<in_type_> c, std::size_t d, // c is a flat view too
138
+ result_type_ *r) noexcept {
139
+ mahalanobis<in_type_, result_type_, allow_simd_>(a.data(), b.data(), c.data(), d, r); // only data() is read from the views; d is caller-supplied. Presumably binds to the raw-pointer overload - confirm data() yields `in_type_ const *`
140
+ }
141
+
142
+ } // namespace ashvardanian::numkong
143
+
144
+ #endif // NK_CURVED_HPP