numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,305 @@
1
+ /**
2
+ * @brief SIMD-accelerated Curved Space Distances for RISC-V.
3
+ * @file include/numkong/curved/rvv.h
4
+ * @author Ash Vardanian
5
+ * @date February 6, 2026
6
+ *
7
+ * @sa include/numkong/curved.h
8
+ *
9
+ * Implements bilinear forms and Mahalanobis distance using RVV 1.0:
10
+ * - f32 inputs are widened and accumulated in f64 SIMD lanes, then folded with the unordered vfredusum reduction
11
+ * - f64 inputs use compensated f64 SIMD accumulation, then are folded with the unordered vfredusum reduction
12
+ * - f16/bf16 inputs are converted to f32 via cast helpers, then accumulated in f32
13
+ * - Complex bilinear forms delegate to serial implementations
14
+ */
15
+ #ifndef NK_CURVED_RVV_H
16
+ #define NK_CURVED_RVV_H
17
+
18
+ #if NK_TARGET_RISCV_
19
+ #if NK_TARGET_RVV
20
+
21
+ #include "numkong/types.h"
22
+ #include "numkong/curved/serial.h"
23
+ #include "numkong/cast/rvv.h"
24
+ #include "numkong/spatial/rvv.h" // `nk_f64_sqrt_rvv`
25
+
26
+ #if defined(__clang__)
27
+ #pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC push_options
30
+ #pragma GCC target("arch=+v")
31
+ #endif
32
+
33
+ #if defined(__cplusplus)
34
+ extern "C" {
35
+ #endif
36
+
37
+ NK_PUBLIC void nk_bilinear_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
38
+ nk_f64_t *result) {
39
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
40
+ nk_f64_t outer_sum = 0;
41
+ for (nk_size_t i = 0; i < n; ++i) {
42
+ vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
43
+ nk_f32_t const *c_row = c + i * n;
44
+ nk_size_t remaining = n;
45
+ for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
46
+ vector_length = __riscv_vsetvl_e32m2(remaining);
47
+ vfloat32m2_t c_f32m2 = __riscv_vle32_v_f32m2(c_row, vector_length);
48
+ vfloat32m2_t b_f32m2 = __riscv_vle32_v_f32m2(b + (n - remaining), vector_length);
49
+ inner_f64m4 = __riscv_vfwmacc_vv_f64m4_tu(inner_f64m4, c_f32m2, b_f32m2, vector_length);
50
+ }
51
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
52
+ nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
53
+ __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
54
+ outer_sum += (nk_f64_t)a[i] * inner_val;
55
+ }
56
+ *result = outer_sum;
57
+ }
58
+
59
+ NK_PUBLIC void nk_bilinear_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
60
+ nk_f64_t *result) {
61
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
62
+ vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
63
+ nk_f64_t outer_compensation = 0;
64
+ for (nk_size_t i = 0; i < n; ++i) {
65
+ vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
66
+ vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
67
+ nk_f64_t const *c_row = c + i * n;
68
+ nk_size_t remaining = n;
69
+ for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
70
+ vector_length = __riscv_vsetvl_e64m4(remaining);
71
+ vfloat64m4_t vc_f64m4 = __riscv_vle64_v_f64m4(c_row, vector_length);
72
+ vfloat64m4_t vb_f64m4 = __riscv_vle64_v_f64m4(b + (n - remaining), vector_length);
73
+ vfloat64m4_t product_f64m4 = __riscv_vfmul_vv_f64m4(vc_f64m4, vb_f64m4, vector_length);
74
+ vfloat64m4_t corrected_term_f64m4 = __riscv_vfsub_vv_f64m4(product_f64m4, compensation_f64m4,
75
+ vector_length);
76
+ vfloat64m4_t running_sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(inner_f64m4, inner_f64m4, corrected_term_f64m4,
77
+ vector_length);
78
+ compensation_f64m4 = __riscv_vfsub_vv_f64m4_tu(
79
+ compensation_f64m4, __riscv_vfsub_vv_f64m4(running_sum_f64m4, inner_f64m4, vector_length),
80
+ corrected_term_f64m4, vector_length);
81
+ inner_f64m4 = running_sum_f64m4;
82
+ }
83
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
84
+ nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
85
+ __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
86
+ nk_f64_t product_outer = a[i] * inner_val;
87
+ nk_f64_t old_sum = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1);
88
+ nk_f64_t new_sum = old_sum + product_outer;
89
+ if (nk_f64_abs_(old_sum) >= nk_f64_abs_(product_outer))
90
+ outer_compensation += (old_sum - new_sum) + product_outer;
91
+ else outer_compensation += (product_outer - new_sum) + old_sum;
92
+ sum_f64m1 = __riscv_vfmv_v_f_f64m1(new_sum, 1);
93
+ }
94
+ *result = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1) + outer_compensation;
95
+ }
96
+
97
+ NK_PUBLIC void nk_bilinear_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
98
+ nk_f32_t *result) {
99
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
100
+ vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
101
+ for (nk_size_t i = 0; i < n; ++i) {
102
+ // Convert a[i] from f16 to f32
103
+ nk_f32_t a_i;
104
+ nk_f16_to_f32_serial(a + i, &a_i);
105
+
106
+ vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
107
+ nk_f16_t const *c_row = c + i * n;
108
+ nk_size_t remaining = n;
109
+ for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
110
+ vector_length = __riscv_vsetvl_e16m1(remaining);
111
+ // Load f16 as u16 bits and convert to f32
112
+ vuint16m1_t vc_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)c_row, vector_length);
113
+ vuint16m1_t vb_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(b + (n - remaining)), vector_length);
114
+ vfloat32m2_t vc_f32m2 = nk_f16m1_to_f32m2_rvv_(vc_u16m1, vector_length);
115
+ vfloat32m2_t vb_f32m2 = nk_f16m1_to_f32m2_rvv_(vb_u16m1, vector_length);
116
+ inner_f32m2 = __riscv_vfmacc_vv_f32m2_tu(inner_f32m2, vc_f32m2, vb_f32m2, vector_length);
117
+ }
118
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
119
+ nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
120
+ __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
121
+ sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + a_i * inner_val, 1);
122
+ }
123
+ *result = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
124
+ }
125
+
126
+ NK_PUBLIC void nk_bilinear_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
127
+ nk_f32_t *result) {
128
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
129
+ vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
130
+ for (nk_size_t i = 0; i < n; ++i) {
131
+ // Convert a[i] from bf16 to f32
132
+ nk_f32_t a_i;
133
+ nk_bf16_to_f32_serial(a + i, &a_i);
134
+
135
+ vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
136
+ nk_bf16_t const *c_row = c + i * n;
137
+ nk_size_t remaining = n;
138
+ for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
139
+ vector_length = __riscv_vsetvl_e16m1(remaining);
140
+ // Load bf16 as u16 bits and convert to f32
141
+ vuint16m1_t vc_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)c_row, vector_length);
142
+ vuint16m1_t vb_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(b + (n - remaining)), vector_length);
143
+ vfloat32m2_t vc_f32m2 = nk_bf16m1_to_f32m2_rvv_(vc_u16m1, vector_length);
144
+ vfloat32m2_t vb_f32m2 = nk_bf16m1_to_f32m2_rvv_(vb_u16m1, vector_length);
145
+ inner_f32m2 = __riscv_vfmacc_vv_f32m2_tu(inner_f32m2, vc_f32m2, vb_f32m2, vector_length);
146
+ }
147
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
148
+ nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
149
+ __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
150
+ sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + a_i * inner_val, 1);
151
+ }
152
+ *result = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
153
+ }
154
+
155
+ NK_PUBLIC void nk_mahalanobis_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
156
+ nk_f64_t *result) {
157
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
158
+ nk_f64_t outer_sum = 0;
159
+ for (nk_size_t i = 0; i < n; ++i) {
160
+ nk_f64_t diff_i = (nk_f64_t)a[i] - (nk_f64_t)b[i];
161
+ vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
162
+ nk_f32_t const *c_row = c + i * n;
163
+ nk_size_t remaining = n;
164
+ for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
165
+ vector_length = __riscv_vsetvl_e32m2(remaining);
166
+ nk_size_t j = n - remaining;
167
+ vfloat32m2_t c_f32m2 = __riscv_vle32_v_f32m2(c_row, vector_length);
168
+ vfloat32m2_t a_f32m2 = __riscv_vle32_v_f32m2(a + j, vector_length);
169
+ vfloat32m2_t b_f32m2 = __riscv_vle32_v_f32m2(b + j, vector_length);
170
+ vfloat64m4_t diff_f64m4 = __riscv_vfwsub_vv_f64m4(a_f32m2, b_f32m2, vector_length);
171
+ vfloat64m4_t c_f64m4 = __riscv_vfwcvt_f_f_v_f64m4(c_f32m2, vector_length);
172
+ inner_f64m4 = __riscv_vfmacc_vv_f64m4_tu(inner_f64m4, c_f64m4, diff_f64m4, vector_length);
173
+ }
174
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
175
+ nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
176
+ __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
177
+ outer_sum += diff_i * inner_val;
178
+ }
179
+ *result = nk_f64_sqrt_rvv(outer_sum > 0 ? outer_sum : 0);
180
+ }
181
+
182
+ NK_PUBLIC void nk_mahalanobis_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
183
+ nk_f64_t *result) {
184
+ nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
185
+ vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
186
+ nk_f64_t outer_compensation = 0;
187
+ for (nk_size_t i = 0; i < n; ++i) {
188
+ nk_f64_t diff_i = a[i] - b[i];
189
+ vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
190
+ vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
191
+ nk_f64_t const *c_row = c + i * n;
192
+ nk_size_t remaining = n;
193
+ for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
194
+ vector_length = __riscv_vsetvl_e64m4(remaining);
195
+ nk_size_t j = n - remaining;
196
+ vfloat64m4_t vc_f64m4 = __riscv_vle64_v_f64m4(c_row, vector_length);
197
+ vfloat64m4_t va_f64m4 = __riscv_vle64_v_f64m4(a + j, vector_length);
198
+ vfloat64m4_t vb_f64m4 = __riscv_vle64_v_f64m4(b + j, vector_length);
199
+ vfloat64m4_t diff_j_f64m4 = __riscv_vfsub_vv_f64m4(va_f64m4, vb_f64m4, vector_length);
200
+ vfloat64m4_t product_f64m4 = __riscv_vfmul_vv_f64m4(vc_f64m4, diff_j_f64m4, vector_length);
201
+ vfloat64m4_t corrected_term_f64m4 = __riscv_vfsub_vv_f64m4(product_f64m4, compensation_f64m4,
202
+ vector_length);
203
+ vfloat64m4_t running_sum_f64m4 = __riscv_vfadd_vv_f64m4_tu(inner_f64m4, inner_f64m4, corrected_term_f64m4,
204
+ vector_length);
205
+ compensation_f64m4 = __riscv_vfsub_vv_f64m4_tu(
206
+ compensation_f64m4, __riscv_vfsub_vv_f64m4(running_sum_f64m4, inner_f64m4, vector_length),
207
+ corrected_term_f64m4, vector_length);
208
+ inner_f64m4 = running_sum_f64m4;
209
+ }
210
+ vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
211
+ nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
212
+ __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
213
+ nk_f64_t product_outer = diff_i * inner_val;
214
+ nk_f64_t old_sum = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1);
215
+ nk_f64_t new_sum = old_sum + product_outer;
216
+ if (nk_f64_abs_(old_sum) >= nk_f64_abs_(product_outer))
217
+ outer_compensation += (old_sum - new_sum) + product_outer;
218
+ else outer_compensation += (product_outer - new_sum) + old_sum;
219
+ sum_f64m1 = __riscv_vfmv_v_f_f64m1(new_sum, 1);
220
+ }
221
+ nk_f64_t quadratic = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1) + outer_compensation;
222
+ *result = nk_f64_sqrt_rvv(quadratic > 0 ? quadratic : 0);
223
+ }
224
+
225
+ NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
226
+ nk_f32_t *result) {
227
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
228
+ vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
229
+ for (nk_size_t i = 0; i < n; ++i) {
230
+ nk_f32_t a_i, b_i;
231
+ nk_f16_to_f32_serial(a + i, &a_i);
232
+ nk_f16_to_f32_serial(b + i, &b_i);
233
+ nk_f32_t diff_i = a_i - b_i;
234
+
235
+ vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
236
+ nk_f16_t const *c_row = c + i * n;
237
+ nk_size_t remaining = n;
238
+ for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
239
+ vector_length = __riscv_vsetvl_e16m1(remaining);
240
+ nk_size_t j = n - remaining;
241
+ vuint16m1_t vc_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)c_row, vector_length);
242
+ vuint16m1_t va_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(a + j), vector_length);
243
+ vuint16m1_t vb_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(b + j), vector_length);
244
+ vfloat32m2_t vc_f32m2 = nk_f16m1_to_f32m2_rvv_(vc_u16m1, vector_length);
245
+ vfloat32m2_t va_f32m2 = nk_f16m1_to_f32m2_rvv_(va_u16m1, vector_length);
246
+ vfloat32m2_t vb_f32m2 = nk_f16m1_to_f32m2_rvv_(vb_u16m1, vector_length);
247
+ vfloat32m2_t diff_j_f32m2 = __riscv_vfsub_vv_f32m2(va_f32m2, vb_f32m2, vector_length);
248
+ inner_f32m2 = __riscv_vfmacc_vv_f32m2_tu(inner_f32m2, vc_f32m2, diff_j_f32m2, vector_length);
249
+ }
250
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
251
+ nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
252
+ __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
253
+ sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + diff_i * inner_val, 1);
254
+ }
255
+ nk_f32_t quadratic_f16 = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
256
+ *result = nk_f32_sqrt_rvv(quadratic_f16 > 0 ? quadratic_f16 : 0);
257
+ }
258
+
259
+ NK_PUBLIC void nk_mahalanobis_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
260
+ nk_f32_t *result) {
261
+ nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
262
+ vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
263
+ for (nk_size_t i = 0; i < n; ++i) {
264
+ nk_f32_t a_i, b_i;
265
+ nk_bf16_to_f32_serial(a + i, &a_i);
266
+ nk_bf16_to_f32_serial(b + i, &b_i);
267
+ nk_f32_t diff_i = a_i - b_i;
268
+
269
+ vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
270
+ nk_bf16_t const *c_row = c + i * n;
271
+ nk_size_t remaining = n;
272
+ for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
273
+ vector_length = __riscv_vsetvl_e16m1(remaining);
274
+ nk_size_t j = n - remaining;
275
+ vuint16m1_t vc_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)c_row, vector_length);
276
+ vuint16m1_t va_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(a + j), vector_length);
277
+ vuint16m1_t vb_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)(b + j), vector_length);
278
+ vfloat32m2_t vc_f32m2 = nk_bf16m1_to_f32m2_rvv_(vc_u16m1, vector_length);
279
+ vfloat32m2_t va_f32m2 = nk_bf16m1_to_f32m2_rvv_(va_u16m1, vector_length);
280
+ vfloat32m2_t vb_f32m2 = nk_bf16m1_to_f32m2_rvv_(vb_u16m1, vector_length);
281
+ vfloat32m2_t diff_j_f32m2 = __riscv_vfsub_vv_f32m2(va_f32m2, vb_f32m2, vector_length);
282
+ inner_f32m2 = __riscv_vfmacc_vv_f32m2_tu(inner_f32m2, vc_f32m2, diff_j_f32m2, vector_length);
283
+ }
284
+ vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
285
+ nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
286
+ __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
287
+ sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + diff_i * inner_val, 1);
288
+ }
289
+ nk_f32_t quadratic_bf16 = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
290
+ *result = nk_f32_sqrt_rvv(quadratic_bf16 > 0 ? quadratic_bf16 : 0);
291
+ }
292
+
293
+ #if defined(__cplusplus)
294
+ } // extern "C"
295
+ #endif
296
+
297
+ #if defined(__clang__)
298
+ #pragma clang attribute pop
299
+ #elif defined(__GNUC__)
300
+ #pragma GCC pop_options
301
+ #endif
302
+
303
+ #endif // NK_TARGET_RVV
304
+ #endif // NK_TARGET_RISCV_
305
+ #endif // NK_CURVED_RVV_H
@@ -0,0 +1,207 @@
1
+ /**
2
+ * @brief SWAR-accelerated Curved Space Similarity for SIMD-free CPUs.
3
+ * @file include/numkong/curved/serial.h
4
+ * @author Ash Vardanian
5
+ * @date January 14, 2026
6
+ *
7
+ * @sa include/numkong/curved.h
8
+ *
9
+ * Implements bilinear forms and Mahalanobis distance with precision-appropriate strategies:
10
+ * - f64 inputs use Dot2 (Ogita-Rump-Oishi 2005) for error-free transformations
11
+ * - f32/f16/bf16 inputs upcast to wider accumulators (f64/f32), providing sufficient
12
+ * precision headroom without compensation overhead
13
+ *
14
+ * Bilinear form: aᵀ × C × b = Σᵢ aᵢ × (Σⱼ cᵢⱼ × bⱼ)
15
+ *
16
+ * The nested loop structure has two accumulation levels:
17
+ * - Inner: Σⱼ cᵢⱼ × bⱼ (O(n) terms per row)
18
+ * - Outer: Σᵢ aᵢ × inner_result (O(n) terms total)
19
+ *
20
+ * For f64→f64 (no upcast headroom): Dot2 uses TwoProd and TwoSum error-free
21
+ * transformations at both levels, capturing rounding errors in compensation terms.
22
+ *
23
+ * For upcasted types (f32→f64, f16→f32, bf16→f32): the wider accumulator provides
24
+ * enough extra mantissa bits that simple accumulation suffices.
25
+ *
26
+ * @see Ogita, T., Rump, S.M., Oishi, S. (2005). "Accurate Sum and Dot Product"
27
+ */
28
+ #ifndef NK_CURVED_SERIAL_H
29
+ #define NK_CURVED_SERIAL_H
30
+
31
+ #include "numkong/types.h"
32
+ #include "numkong/spatial/serial.h" // `nk_f64_sqrt_serial`
33
+
34
+ #if defined(__cplusplus)
35
+ extern "C" {
36
+ #endif
37
+
38
/**
 *  @brief Generator for `nk_bilinear_<input_type>_serial`: evaluates aᵀ × C × b with plain accumulation.
 *
 *  The emitted function walks the `n × n` matrix `c` one row at a time (element (i, j) is read
 *  from `c + i * n + j`), dots that row with `b`, scales the dot product by the matching entry
 *  of `a`, and adds it to a running total. Every element is passed through `load_and_convert`
 *  into the accumulator type before any arithmetic; the total is cast to the output type on exit.
 *  When `n` is zero nothing is read and zero is stored through `result`.
 *
 *  No error compensation is applied, so this is intended for upcasted pairs (f32→f64, f16→f32,
 *  bf16→f32) where the wider accumulator already provides sufficient precision headroom.
 *
 *  @param input_type Type suffix of the elements of `a`, `b` and `c` (`nk_<input_type>_t`).
 *  @param accumulator_type Type suffix of every intermediate sum (`nk_<accumulator_type>_t`).
 *  @param output_type Type suffix of the value stored through `result` (`nk_<output_type>_t`).
 *  @param load_and_convert Callable invoked as (source pointer, destination pointer) that loads
 *         one input element and stores it in accumulator precision.
 */
#define nk_define_bilinear_(input_type, accumulator_type, output_type, load_and_convert) \
    NK_PUBLIC void nk_bilinear_##input_type##_serial( \
        nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_##input_type##_t const *c, nk_size_t n, \
        nk_##output_type##_t *result) { \
        nk_##accumulator_type##_t total = 0; \
        nk_##input_type##_t const *c_row = c; /* first element of the current matrix row */ \
        for (nk_size_t row_index = 0; row_index < n; ++row_index, c_row += n) { \
            nk_##accumulator_type##_t a_entry; \
            nk_##accumulator_type##_t row_dot = 0; \
            load_and_convert(a + row_index, &a_entry); \
            for (nk_size_t col_index = 0; col_index < n; ++col_index) { \
                nk_##accumulator_type##_t b_entry, c_entry; \
                load_and_convert(b + col_index, &b_entry); \
                load_and_convert(c_row + col_index, &c_entry); \
                row_dot += c_entry * b_entry; \
            } \
            total += a_entry * row_dot; \
        } \
        *result = (nk_##output_type##_t)(total); \
    }
62
+
63
/**
 *  @brief Generator for complex `nk_bilinear_<input_type>_serial`: evaluates aᵀ × C × b with plain accumulation.
 *
 *  Inputs are arrays of structures exposing `real` and `imag` members; `c_pairs` holds `n × n`
 *  entries with element (i, j) at `c_pairs + i * n + j`. Both members of every element are passed
 *  through `load_and_convert` into the accumulator type. For each row the complex dot product of
 *  that row with `b_pairs` is accumulated, then multiplied by the matching entry of `a_pairs`
 *  and added to the running complex total. Ordinary complex products are used throughout —
 *  no operand is conjugated.
 *
 *  No error compensation is applied; intended for upcasted types where the wider accumulator
 *  provides sufficient precision headroom.
 *
 *  @param input_type Type suffix of the complex input elements (`nk_<input_type>_t`).
 *  @param accumulator_type Type suffix of the scalar intermediate sums (`nk_<accumulator_type>_t`).
 *  @param output_type Scalar type suffix of the result; the function writes to `nk_<output_type>c_t`.
 *  @param load_and_convert Callable invoked as (source pointer, destination pointer) that loads
 *         one scalar component and stores it in accumulator precision.
 */
#define nk_define_bilinear_complex_(input_type, accumulator_type, output_type, load_and_convert) \
    NK_PUBLIC void nk_bilinear_##input_type##_serial( \
        nk_##input_type##_t const *a_pairs, nk_##input_type##_t const *b_pairs, nk_##input_type##_t const *c_pairs, \
        nk_size_t n, nk_##output_type##c_t *results) { \
        nk_##accumulator_type##_t total_re = 0, total_im = 0; \
        nk_##input_type##_t const *c_row = c_pairs; /* first element of the current matrix row */ \
        for (nk_size_t row_index = 0; row_index < n; ++row_index, c_row += n) { \
            nk_##input_type##_t const *a_pair = a_pairs + row_index; \
            nk_##accumulator_type##_t a_re, a_im; \
            nk_##accumulator_type##_t row_re = 0, row_im = 0; \
            load_and_convert(&a_pair->real, &a_re); \
            load_and_convert(&a_pair->imag, &a_im); \
            for (nk_size_t col_index = 0; col_index < n; ++col_index) { \
                nk_##input_type##_t const *b_pair = b_pairs + col_index; \
                nk_##input_type##_t const *c_pair = c_row + col_index; \
                nk_##accumulator_type##_t b_re, b_im, c_re, c_im; \
                load_and_convert(&b_pair->real, &b_re); \
                load_and_convert(&b_pair->imag, &b_im); \
                load_and_convert(&c_pair->real, &c_re); \
                load_and_convert(&c_pair->imag, &c_im); \
                /* row += c_ij * b_j (complex product) */ \
                row_re += c_re * b_re - c_im * b_im; \
                row_im += c_re * b_im + c_im * b_re; \
            } \
            /* total += a_i * row (complex product) */ \
            total_re += a_re * row_re - a_im * row_im; \
            total_im += a_re * row_im + a_im * row_re; \
        } \
        results->real = total_re; \
        results->imag = total_im; \
    }
94
+
95
/**
 *  @brief Generator for `nk_mahalanobis_<input_type>_serial`: evaluates √((a−b)ᵀ × C × (a−b)) with plain accumulation.
 *
 *  Elements of `a`, `b` and the `n × n` matrix `c` (element (i, j) at `c + i * n + j`) are passed
 *  through `load_and_convert` into the accumulator type, and the differences `a − b` are formed
 *  in that precision. Column differences are recomputed for every row, so no scratch buffer
 *  is needed. For each row, the dot product of the matrix row with the difference vector is
 *  scaled by that row's own difference and added to the running quadratic form.
 *
 *  Before the square root, a quadratic form that does not compare greater than zero
 *  (negative, or NaN) is replaced by zero.
 *
 *  No error compensation is applied; intended for upcasted types where the wider accumulator
 *  provides sufficient precision headroom.
 *
 *  @param input_type Type suffix of the elements of `a`, `b` and `c` (`nk_<input_type>_t`).
 *  @param accumulator_type Type suffix of the intermediate sums; also selects
 *         `nk_<accumulator_type>_sqrt_serial` for the final square root.
 *  @param output_type Type suffix of the value stored through `result` (`nk_<output_type>_t`).
 *  @param load_and_convert Callable invoked as (source pointer, destination pointer) that loads
 *         one input element and stores it in accumulator precision.
 */
#define nk_define_mahalanobis_(input_type, accumulator_type, output_type, load_and_convert) \
    NK_PUBLIC void nk_mahalanobis_##input_type##_serial( \
        nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_##input_type##_t const *c, nk_size_t n, \
        nk_##output_type##_t *result) { \
        nk_##accumulator_type##_t total = 0; \
        nk_##input_type##_t const *c_row = c; /* first element of the current matrix row */ \
        for (nk_size_t row_index = 0; row_index < n; ++row_index, c_row += n) { \
            nk_##accumulator_type##_t a_row, b_row; \
            nk_##accumulator_type##_t row_dot = 0; \
            load_and_convert(a + row_index, &a_row); \
            load_and_convert(b + row_index, &b_row); \
            nk_##accumulator_type##_t const delta_row = a_row - b_row; \
            for (nk_size_t col_index = 0; col_index < n; ++col_index) { \
                nk_##accumulator_type##_t a_col, b_col, c_entry; \
                load_and_convert(a + col_index, &a_col); \
                load_and_convert(b + col_index, &b_col); \
                load_and_convert(c_row + col_index, &c_entry); \
                nk_##accumulator_type##_t const delta_col = a_col - b_col; \
                row_dot += c_entry * delta_col; \
            } \
            total += delta_row * row_dot; \
        } \
        /* Clamp anything that is not strictly positive to zero before the square root */ \
        nk_##accumulator_type##_t clamped = 0; \
        if (total > 0) clamped = total; \
        *result = nk_##accumulator_type##_sqrt_serial(clamped); \
    }
124
+
125
+ // f32 → f64 accumulator → f64 output: f64 provides ample headroom for f32 (~7 vs ~16 decimal digits)
126
+ nk_define_bilinear_(f32, f64, f64, nk_assign_from_to_) // defines `nk_bilinear_f32_serial`
127
+ nk_define_bilinear_complex_(f32c, f64, f64, nk_assign_from_to_) // defines `nk_bilinear_f32c_serial`, writes `nk_f64c_t`
128
+ nk_define_mahalanobis_(f32, f64, f64, nk_assign_from_to_) // defines `nk_mahalanobis_f32_serial`
129
+
130
+ // f16 → f32 accumulator → f32 output: f32 provides ample headroom for f16 (~3 vs ~7 decimal digits)
131
+ nk_define_bilinear_(f16, f32, f32, nk_f16_to_f32_serial) // defines `nk_bilinear_f16_serial`
132
+ nk_define_bilinear_complex_(f16c, f32, f32, nk_f16_to_f32_serial) // defines `nk_bilinear_f16c_serial`, writes `nk_f32c_t`
133
+ nk_define_mahalanobis_(f16, f32, f32, nk_f16_to_f32_serial) // defines `nk_mahalanobis_f16_serial`
134
+
135
+ // bf16 → f32 accumulator → f32 output: f32 provides ample headroom for bf16 (~2 vs ~7 decimal digits)
136
+ nk_define_bilinear_(bf16, f32, f32, nk_bf16_to_f32_serial) // defines `nk_bilinear_bf16_serial`
137
+ nk_define_bilinear_complex_(bf16c, f32, f32, nk_bf16_to_f32_serial) // defines `nk_bilinear_bf16c_serial`, writes `nk_f32c_t`
138
+ nk_define_mahalanobis_(bf16, f32, f32, nk_bf16_to_f32_serial) // defines `nk_mahalanobis_bf16_serial`
139
+
140
+ #undef nk_define_bilinear_ // the generator macros are dropped once the instantiations above exist
141
+ #undef nk_define_bilinear_complex_
142
+ #undef nk_define_mahalanobis_
143
+
144
+ NK_PUBLIC void nk_bilinear_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
145
+ nk_f64_t *result) {
146
+ nk_f64_t outer_sum = 0, outer_comp = 0;
147
+ for (nk_size_t row = 0; row != n; ++row) {
148
+ nk_f64_t inner_sum = 0, inner_comp = 0;
149
+ for (nk_size_t col = 0; col != n; ++col) nk_f64_dot2_(&inner_sum, &inner_comp, c[row * n + col], b[col]);
150
+ nk_f64_t cb_j = inner_sum + inner_comp;
151
+ nk_f64_dot2_(&outer_sum, &outer_comp, a[row], cb_j);
152
+ }
153
+ *result = outer_sum + outer_comp;
154
+ }
155
+
156
+ NK_PUBLIC void nk_bilinear_f64c_serial(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_f64c_t const *c_pairs,
157
+ nk_size_t n, nk_f64c_t *results) {
158
+ nk_f64_t outer_sum_real = 0, outer_comp_real = 0;
159
+ nk_f64_t outer_sum_imag = 0, outer_comp_imag = 0;
160
+ for (nk_size_t row = 0; row != n; ++row) {
161
+ nk_f64_t a_real = a_pairs[row].real;
162
+ nk_f64_t a_imag = a_pairs[row].imag;
163
+ // 4 Dot2 accumulators for inner cross-terms
164
+ nk_f64_t sum_rr = 0, comp_rr = 0;
165
+ nk_f64_t sum_ii = 0, comp_ii = 0;
166
+ nk_f64_t sum_ri = 0, comp_ri = 0;
167
+ nk_f64_t sum_ir = 0, comp_ir = 0;
168
+ for (nk_size_t col = 0; col != n; ++col) {
169
+ nk_f64_t b_real = b_pairs[col].real, b_imag = b_pairs[col].imag;
170
+ nk_f64_t c_real = c_pairs[row * n + col].real, c_imag = c_pairs[row * n + col].imag;
171
+ nk_f64_dot2_(&sum_rr, &comp_rr, c_real, b_real);
172
+ nk_f64_dot2_(&sum_ii, &comp_ii, c_imag, b_imag);
173
+ nk_f64_dot2_(&sum_ri, &comp_ri, c_real, b_imag);
174
+ nk_f64_dot2_(&sum_ir, &comp_ir, c_imag, b_real);
175
+ }
176
+ nk_f64_t inner_real = (sum_rr + comp_rr) - (sum_ii + comp_ii);
177
+ nk_f64_t inner_imag = (sum_ri + comp_ri) + (sum_ir + comp_ir);
178
+ // Outer Dot2 complex multiply: a × inner
179
+ nk_f64_dot2_(&outer_sum_real, &outer_comp_real, a_real, inner_real);
180
+ nk_f64_dot2_(&outer_sum_real, &outer_comp_real, -a_imag, inner_imag);
181
+ nk_f64_dot2_(&outer_sum_imag, &outer_comp_imag, a_real, inner_imag);
182
+ nk_f64_dot2_(&outer_sum_imag, &outer_comp_imag, a_imag, inner_real);
183
+ }
184
+ results->real = outer_sum_real + outer_comp_real;
185
+ results->imag = outer_sum_imag + outer_comp_imag;
186
+ }
187
+
188
+ NK_PUBLIC void nk_mahalanobis_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
189
+ nk_f64_t *result) {
190
+ nk_f64_t outer_sum = 0, outer_comp = 0;
191
+ for (nk_size_t row = 0; row != n; ++row) {
192
+ nk_f64_t diff_row = a[row] - b[row];
193
+ nk_f64_t inner_sum = 0, inner_comp = 0;
194
+ for (nk_size_t col = 0; col != n; ++col)
195
+ nk_f64_dot2_(&inner_sum, &inner_comp, c[row * n + col], a[col] - b[col]);
196
+ nk_f64_t cb_j = inner_sum + inner_comp;
197
+ nk_f64_dot2_(&outer_sum, &outer_comp, diff_row, cb_j);
198
+ }
199
+ nk_f64_t quadratic = outer_sum + outer_comp;
200
+ *result = nk_f64_sqrt_serial(quadratic > 0 ? quadratic : 0);
201
+ }
202
+
203
+ #if defined(__cplusplus)
204
+ } // extern "C"
205
+ #endif
206
+
207
+ #endif // NK_CURVED_SERIAL_H