numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,182 @@
1
+ /**
2
+ * @brief SIMD-accelerated Curved Space Similarity for Genoa.
3
+ * @file include/numkong/curved/genoa.h
4
+ * @author Ash Vardanian
5
+ * @date January 14, 2026
6
+ *
7
+ * @sa include/numkong/curved.h
8
+ *
9
+ * Implements bf16 bilinear forms using AVX-512 with BF16 extensions.
10
+ */
11
+ #ifndef NK_CURVED_GENOA_H
12
+ #define NK_CURVED_GENOA_H
13
+
14
+ #if NK_TARGET_X86_
15
+ #if NK_TARGET_GENOA
16
+
17
+ #include "numkong/types.h"
18
+ #include "numkong/spatial/genoa.h" // `nk_substract_bf16x32_genoa_`
19
+ #include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
20
+
21
+ #if defined(__cplusplus)
22
+ extern "C" {
23
+ #endif
24
+
25
+ #if defined(__clang__)
26
+ #pragma clang attribute push( \
27
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512bf16,f16c,fma,bmi,bmi2"))), \
28
+ apply_to = function)
29
+ #elif defined(__GNUC__)
30
+ #pragma GCC push_options
31
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512bf16", "f16c", "fma", "bmi", "bmi2")
32
+ #endif
33
+
34
+ NK_PUBLIC void nk_bilinear_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
35
+ nk_f32_t *result) {
36
+ nk_size_t const tail_length = n % 32;
37
+ nk_size_t const tail_start = n - tail_length;
38
+ __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);
39
+ __m512 sum_f32x16 = _mm512_setzero_ps();
40
+
41
+ for (nk_size_t i = 0; i != n; ++i) {
42
+ nk_f32_t a_f32;
43
+ nk_bf16_to_f32_serial(a + i, &a_f32);
44
+ __m512 a_f32x16 = _mm512_set1_ps(a_f32);
45
+ __m512 cb_j_f32x16 = _mm512_setzero_ps();
46
+ __m512i b_bf16x32, c_bf16x32;
47
+ nk_size_t j = 0;
48
+
49
+ nk_bilinear_bf16_genoa_cycle:
50
+ if (j + 32 <= n) {
51
+ b_bf16x32 = _mm512_loadu_epi16(b + j);
52
+ c_bf16x32 = _mm512_loadu_epi16(c + i * n + j);
53
+ }
54
+ else {
55
+ b_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start);
56
+ c_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start);
57
+ }
58
+ cb_j_f32x16 = _mm512_dpbf16_ps(cb_j_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(c_bf16x32));
59
+ j += 32;
60
+ if (j < n) goto nk_bilinear_bf16_genoa_cycle;
61
+ sum_f32x16 = _mm512_fmadd_ps(a_f32x16, cb_j_f32x16, sum_f32x16);
62
+ }
63
+
64
+ *result = _mm512_reduce_add_ps(sum_f32x16);
65
+ }
66
+
67
+ NK_PUBLIC void nk_mahalanobis_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
68
+ nk_f32_t *result) {
69
+ nk_size_t const tail_length = n % 32;
70
+ nk_size_t const tail_start = n - tail_length;
71
+ __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);
72
+ __m512 sum_f32x16 = _mm512_setzero_ps();
73
+
74
+ for (nk_size_t i = 0; i != n; ++i) {
75
+ nk_f32_t a_i, b_i;
76
+ nk_bf16_to_f32_serial(a + i, &a_i);
77
+ nk_bf16_to_f32_serial(b + i, &b_i);
78
+ __m512 diff_i_f32x16 = _mm512_set1_ps(a_i - b_i);
79
+ __m512 cdiff_j_f32x16 = _mm512_setzero_ps();
80
+ __m512i a_j_bf16x32, b_j_bf16x32, diff_j_bf16x32, c_bf16x32;
81
+ nk_size_t j = 0;
82
+
83
+ // The nested loop is cleaner to implement with a `goto` in this case:
84
+ nk_mahalanobis_bf16_genoa_cycle:
85
+ if (j + 32 <= n) {
86
+ a_j_bf16x32 = _mm512_loadu_epi16(a + j);
87
+ b_j_bf16x32 = _mm512_loadu_epi16(b + j);
88
+ c_bf16x32 = _mm512_loadu_epi16(c + i * n + j);
89
+ }
90
+ else {
91
+ a_j_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, a + tail_start);
92
+ b_j_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, b + tail_start);
93
+ c_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, c + i * n + tail_start);
94
+ }
95
+ diff_j_bf16x32 = nk_substract_bf16x32_genoa_(a_j_bf16x32, b_j_bf16x32);
96
+ cdiff_j_f32x16 = _mm512_dpbf16_ps(cdiff_j_f32x16, nk_m512bh_from_m512i_(diff_j_bf16x32),
97
+ nk_m512bh_from_m512i_(c_bf16x32));
98
+ j += 32;
99
+ if (j < n) goto nk_mahalanobis_bf16_genoa_cycle;
100
+ sum_f32x16 = _mm512_fmadd_ps(diff_i_f32x16, cdiff_j_f32x16, sum_f32x16);
101
+ }
102
+
103
+ nk_f32_t quadratic = _mm512_reduce_add_ps(sum_f32x16);
104
+ *result = nk_f32_sqrt_haswell(quadratic > 0 ? quadratic : 0);
105
+ }
106
+
107
+ NK_PUBLIC void nk_bilinear_bf16c_genoa(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_bf16c_t const *c, nk_size_t n,
108
+ nk_f32c_t *results) {
109
+
110
+ // We take into account, that FMS is the same as FMA with a negative multiplier.
111
+ // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
112
+ // This way we can avoid the shuffling and the need for separate real and imaginary parts.
113
+ // For the imaginary part of the product, we would need to swap the real and imaginary parts of
114
+ // one of the vectors.
115
+ __m512i const sign_flip_i32x16 = _mm512_set1_epi32(0x80000000);
116
+ __m512i const swap_adjacent_i8x64 = _mm512_set_epi8( //
117
+ 61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane
118
+ 45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane
119
+ 29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane
120
+ 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane
121
+ );
122
+
123
+ // Default case for arbitrary size `n`
124
+ nk_size_t const tail_length = n % 16;
125
+ nk_size_t const tail_start = n - tail_length;
126
+ __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length * 2);
127
+ nk_f32_t sum_real = 0;
128
+ nk_f32_t sum_imag = 0;
129
+
130
+ for (nk_size_t i = 0; i != n; ++i) {
131
+ nk_f32_t a_i_real, a_i_imag;
132
+ nk_bf16_to_f32_serial(&a[i].real, &a_i_real);
133
+ nk_bf16_to_f32_serial(&a[i].imag, &a_i_imag);
134
+ __m512 cb_j_real_f32x16 = _mm512_setzero_ps();
135
+ __m512 cb_j_imag_f32x16 = _mm512_setzero_ps();
136
+ __m512i b_bf16x32, c_bf16x32;
137
+ nk_size_t j = 0;
138
+
139
+ nk_bilinear_bf16c_skylake_cycle:
140
+ if (j + 16 <= n) {
141
+ b_bf16x32 = _mm512_loadu_epi16((nk_i16_t const *)(b + j));
142
+ c_bf16x32 = _mm512_loadu_epi16((nk_i16_t const *)(c + i * n + j));
143
+ }
144
+ else {
145
+ b_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (nk_i16_t const *)(b + tail_start));
146
+ c_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (nk_i16_t const *)(c + i * n + tail_start));
147
+ }
148
+ cb_j_real_f32x16 = _mm512_dpbf16_ps( //
149
+ cb_j_real_f32x16, //
150
+ nk_m512bh_from_m512i_(_mm512_xor_si512(c_bf16x32, sign_flip_i32x16)), //
151
+ nk_m512bh_from_m512i_(b_bf16x32));
152
+ cb_j_imag_f32x16 = _mm512_dpbf16_ps( //
153
+ cb_j_imag_f32x16, //
154
+ nk_m512bh_from_m512i_(_mm512_shuffle_epi8(c_bf16x32, swap_adjacent_i8x64)), //
155
+ nk_m512bh_from_m512i_(b_bf16x32));
156
+ j += 16;
157
+ if (j < n) goto nk_bilinear_bf16c_skylake_cycle;
158
+ // Horizontal sums are the expensive part of the computation:
159
+ nk_f32_t const cb_j_real = nk_reduce_add_f32x16_skylake_(cb_j_real_f32x16);
160
+ nk_f32_t const cb_j_imag = nk_reduce_add_f32x16_skylake_(cb_j_imag_f32x16);
161
+ sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag;
162
+ sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real;
163
+ }
164
+
165
+ // Reduce horizontal sums:
166
+ results->real = sum_real;
167
+ results->imag = sum_imag;
168
+ }
169
+
170
+ #if defined(__clang__)
171
+ #pragma clang attribute pop
172
+ #elif defined(__GNUC__)
173
+ #pragma GCC pop_options
174
+ #endif
175
+
176
+ #if defined(__cplusplus)
177
+ } // extern "C"
178
+ #endif
179
+
180
+ #endif // NK_TARGET_GENOA
181
+ #endif // NK_TARGET_X86_
182
+ #endif // NK_CURVED_GENOA_H
@@ -0,0 +1,276 @@
1
+ /**
2
+ * @brief SIMD-accelerated Curved Space Similarity for Haswell.
3
+ * @file include/numkong/curved/haswell.h
4
+ * @author Ash Vardanian
5
+ * @date January 14, 2026
6
+ *
7
+ * @sa include/numkong/curved.h
8
+ *
9
+ * Implements f16 and bf16 bilinear forms using AVX2 with F16C conversion.
10
+ */
11
+ #ifndef NK_CURVED_HASWELL_H
12
+ #define NK_CURVED_HASWELL_H
13
+
14
+ #if NK_TARGET_X86_
15
+ #if NK_TARGET_HASWELL
16
+
17
+ #include "numkong/types.h"
18
+ #include "numkong/reduce/haswell.h" // `nk_reduce_add_f32x8_haswell_`
19
+ #include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`
20
+
21
+ #if defined(__cplusplus)
22
+ extern "C" {
23
+ #endif
24
+
25
+ #if defined(__clang__)
26
+ #pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
27
+ #elif defined(__GNUC__)
28
+ #pragma GCC push_options
29
+ #pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
30
+ #endif
31
+
32
+ NK_PUBLIC void nk_bilinear_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
33
+ nk_f64_t *result) {
34
+ nk_size_t const tail_length = n % 4;
35
+ nk_size_t const tail_start = n - tail_length;
36
+ __m256d sum_f64x4 = _mm256_setzero_pd();
37
+
38
+ for (nk_size_t i = 0; i != n; ++i) {
39
+ __m256d a_f64x4 = _mm256_set1_pd((nk_f64_t)a[i]);
40
+ __m256d cb_j_f64x4 = _mm256_setzero_pd();
41
+ for (nk_size_t j = 0; j + 4 <= n; j += 4) {
42
+ __m256d b_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(b + j));
43
+ __m256d c_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(c + i * n + j));
44
+ cb_j_f64x4 = _mm256_fmadd_pd(b_f64x4, c_f64x4, cb_j_f64x4);
45
+ }
46
+ sum_f64x4 = _mm256_fmadd_pd(a_f64x4, cb_j_f64x4, sum_f64x4);
47
+ }
48
+
49
+ nk_f64_t sum = nk_reduce_add_f64x4_haswell_(sum_f64x4);
50
+ if (tail_length) {
51
+ nk_b128_vec_t b_tail_vec;
52
+ nk_partial_load_b32x4_haswell_(b + tail_start, &b_tail_vec, tail_length);
53
+ __m256d b_tail_f64x4 = _mm256_cvtps_pd(b_tail_vec.xmm_ps);
54
+ for (nk_size_t i = 0; i != n; ++i) {
55
+ nk_f64_t a_i = (nk_f64_t)a[i];
56
+ nk_b128_vec_t c_tail_vec;
57
+ nk_partial_load_b32x4_haswell_(c + i * n + tail_start, &c_tail_vec, tail_length);
58
+ __m256d c_tail_f64x4 = _mm256_cvtps_pd(c_tail_vec.xmm_ps);
59
+ sum += a_i * nk_reduce_add_f64x4_haswell_(_mm256_mul_pd(b_tail_f64x4, c_tail_f64x4));
60
+ }
61
+ }
62
+
63
+ *result = sum;
64
+ }
65
+
66
+ NK_PUBLIC void nk_mahalanobis_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
67
+ nk_f64_t *result) {
68
+ nk_size_t const tail_length = n % 4;
69
+ nk_size_t const tail_start = n - tail_length;
70
+ __m256d sum_f64x4 = _mm256_setzero_pd();
71
+
72
+ for (nk_size_t i = 0; i != n; ++i) {
73
+ __m256d diff_i_f64x4 = _mm256_set1_pd((nk_f64_t)a[i] - (nk_f64_t)b[i]);
74
+ __m256d cdiff_j_f64x4 = _mm256_setzero_pd();
75
+ for (nk_size_t j = 0; j + 4 <= n; j += 4) {
76
+ __m256d diff_j_f64x4 = _mm256_sub_pd( //
77
+ _mm256_cvtps_pd(_mm_loadu_ps(a + j)), _mm256_cvtps_pd(_mm_loadu_ps(b + j)));
78
+ __m256d c_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(c + i * n + j));
79
+ cdiff_j_f64x4 = _mm256_fmadd_pd(diff_j_f64x4, c_f64x4, cdiff_j_f64x4);
80
+ }
81
+ sum_f64x4 = _mm256_fmadd_pd(diff_i_f64x4, cdiff_j_f64x4, sum_f64x4);
82
+ }
83
+
84
+ nk_f64_t sum = nk_reduce_add_f64x4_haswell_(sum_f64x4);
85
+ if (tail_length) {
86
+ nk_b128_vec_t a_tail_vec, b_tail_vec;
87
+ nk_partial_load_b32x4_haswell_(a + tail_start, &a_tail_vec, tail_length);
88
+ nk_partial_load_b32x4_haswell_(b + tail_start, &b_tail_vec, tail_length);
89
+ __m256d diff_tail_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(a_tail_vec.xmm_ps), _mm256_cvtps_pd(b_tail_vec.xmm_ps));
90
+ for (nk_size_t i = 0; i != n; ++i) {
91
+ nk_f64_t diff_i = (nk_f64_t)a[i] - (nk_f64_t)b[i];
92
+ nk_b128_vec_t c_tail_vec;
93
+ nk_partial_load_b32x4_haswell_(c + i * n + tail_start, &c_tail_vec, tail_length);
94
+ __m256d c_tail_f64x4 = _mm256_cvtps_pd(c_tail_vec.xmm_ps);
95
+ sum += diff_i * nk_reduce_add_f64x4_haswell_(_mm256_mul_pd(diff_tail_f64x4, c_tail_f64x4));
96
+ }
97
+ }
98
+
99
+ *result = nk_f64_sqrt_haswell(sum > 0 ? sum : 0);
100
+ }
101
+
102
+ NK_PUBLIC void nk_bilinear_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
103
+ nk_f32_t *result) {
104
+ __m256 sum_f32x8 = _mm256_setzero_ps();
105
+ for (nk_size_t i = 0; i != n; ++i) {
106
+ __m256 a_f32x8 = _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i)));
107
+ __m256 cb_j_f32x8 = _mm256_setzero_ps();
108
+ for (nk_size_t j = 0; j + 8 <= n; j += 8) {
109
+ __m256 b_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(b + j)));
110
+ __m256 c_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(c + i * n + j)));
111
+ cb_j_f32x8 = _mm256_fmadd_ps(b_f32x8, c_f32x8, cb_j_f32x8);
112
+ }
113
+ sum_f32x8 = _mm256_fmadd_ps(a_f32x8, cb_j_f32x8, sum_f32x8);
114
+ }
115
+
116
+ // Handle the tail of every row
117
+ nk_f32_t sum = nk_reduce_add_f32x8_haswell_(sum_f32x8);
118
+ nk_size_t const tail_length = n % 8;
119
+ nk_size_t const tail_start = n - tail_length;
120
+ if (tail_length) {
121
+ for (nk_size_t i = 0; i != n; ++i) {
122
+ nk_f32_t a_i = _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))));
123
+ nk_b256_vec_t b_vec;
124
+ nk_partial_load_f16x8_to_f32x8_haswell_(b + tail_start, &b_vec, tail_length);
125
+ __m256 b_f32x8 = b_vec.ymm_ps;
126
+ nk_b256_vec_t c_vec;
127
+ nk_partial_load_f16x8_to_f32x8_haswell_(c + i * n + tail_start, &c_vec, tail_length);
128
+ __m256 c_f32x8 = c_vec.ymm_ps;
129
+ nk_f32_t cb_j = nk_reduce_add_f32x8_haswell_(_mm256_mul_ps(b_f32x8, c_f32x8));
130
+ sum += a_i * cb_j;
131
+ }
132
+ }
133
+
134
+ *result = sum;
135
+ }
136
+
137
+ NK_PUBLIC void nk_mahalanobis_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
138
+ nk_f32_t *result) {
139
+ __m256 sum_f32x8 = _mm256_setzero_ps();
140
+ for (nk_size_t i = 0; i != n; ++i) {
141
+ __m256 diff_i_f32x8 = _mm256_sub_ps( //
142
+ _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))), //
143
+ _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(b + i))));
144
+ __m256 cdiff_j_f32x8 = _mm256_setzero_ps();
145
+ for (nk_size_t j = 0; j + 8 <= n; j += 8) {
146
+ __m256 diff_j_f32x8 = _mm256_sub_ps( //
147
+ _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(a + j))),
148
+ _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(b + j))));
149
+ __m256 c_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(c + i * n + j)));
150
+ cdiff_j_f32x8 = _mm256_fmadd_ps(diff_j_f32x8, c_f32x8, cdiff_j_f32x8);
151
+ }
152
+ sum_f32x8 = _mm256_fmadd_ps(diff_i_f32x8, cdiff_j_f32x8, sum_f32x8);
153
+ }
154
+
155
+ // Handle the tail of every row
156
+ nk_f32_t sum = nk_reduce_add_f32x8_haswell_(sum_f32x8);
157
+ nk_size_t const tail_length = n % 8;
158
+ nk_size_t const tail_start = n - tail_length;
159
+ if (tail_length) {
160
+ for (nk_size_t i = 0; i != n; ++i) {
161
+ nk_f32_t diff_i = _mm256_cvtss_f32(_mm256_sub_ps( //
162
+ _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(a + i))), //
163
+ _mm256_cvtph_ps(_mm_set1_epi16(*(short const *)(b + i)))));
164
+ nk_b256_vec_t a_tail_vec, b_tail_vec;
165
+ nk_partial_load_f16x8_to_f32x8_haswell_(a + tail_start, &a_tail_vec, tail_length);
166
+ nk_partial_load_f16x8_to_f32x8_haswell_(b + tail_start, &b_tail_vec, tail_length);
167
+ __m256 diff_j_f32x8 = _mm256_sub_ps(a_tail_vec.ymm_ps, b_tail_vec.ymm_ps);
168
+ nk_b256_vec_t c_vec;
169
+ nk_partial_load_f16x8_to_f32x8_haswell_(c + i * n + tail_start, &c_vec, tail_length);
170
+ __m256 c_f32x8 = c_vec.ymm_ps;
171
+ nk_f32_t cdiff_j = nk_reduce_add_f32x8_haswell_(_mm256_mul_ps(diff_j_f32x8, c_f32x8));
172
+ sum += diff_i * cdiff_j;
173
+ }
174
+ }
175
+
176
+ *result = nk_f32_sqrt_haswell(sum > 0 ? sum : 0);
177
+ }
178
+
179
+ NK_PUBLIC void nk_bilinear_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
180
+ nk_f32_t *result) {
181
+ __m256 sum_f32x8 = _mm256_setzero_ps();
182
+ for (nk_size_t i = 0; i != n; ++i) {
183
+ // The `nk_bf16_to_f32_serial` is cheaper than `nk_bf16x8_to_f32x8_haswell_`
184
+ nk_f32_t a_f32;
185
+ nk_bf16_to_f32_serial(a + i, &a_f32);
186
+ __m256 a_f32x8 = _mm256_set1_ps(a_f32);
187
+ __m256 cb_j_f32x8 = _mm256_setzero_ps();
188
+ for (nk_size_t j = 0; j + 8 <= n; j += 8) {
189
+ __m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(b + j)));
190
+ __m256 c_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(c + i * n + j)));
191
+ cb_j_f32x8 = _mm256_fmadd_ps(b_f32x8, c_f32x8, cb_j_f32x8);
192
+ }
193
+ sum_f32x8 = _mm256_fmadd_ps(a_f32x8, cb_j_f32x8, sum_f32x8);
194
+ }
195
+
196
+ // Handle the tail of every row
197
+ nk_f32_t sum = nk_reduce_add_f32x8_haswell_(sum_f32x8);
198
+ nk_size_t const tail_length = n % 8;
199
+ nk_size_t const tail_start = n - tail_length;
200
+ if (tail_length) {
201
+ for (nk_size_t i = 0; i != n; ++i) {
202
+ nk_f32_t a_i;
203
+ nk_bf16_to_f32_serial(a + i, &a_i);
204
+ nk_b256_vec_t b_vec;
205
+ nk_partial_load_bf16x8_to_f32x8_haswell_(b + tail_start, &b_vec, tail_length);
206
+ __m256 b_f32x8 = b_vec.ymm_ps;
207
+ nk_b256_vec_t c_vec;
208
+ nk_partial_load_bf16x8_to_f32x8_haswell_(c + i * n + tail_start, &c_vec, tail_length);
209
+ __m256 c_f32x8 = c_vec.ymm_ps;
210
+ nk_f32_t cb_j = nk_reduce_add_f32x8_haswell_(_mm256_mul_ps(b_f32x8, c_f32x8));
211
+ sum += a_i * cb_j;
212
+ }
213
+ }
214
+
215
+ *result = sum;
216
+ }
217
+
218
+ NK_PUBLIC void nk_mahalanobis_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
219
+ nk_f32_t *result) {
220
+ __m256 sum_f32x8 = _mm256_setzero_ps();
221
+ for (nk_size_t i = 0; i != n; ++i) {
222
+ nk_f32_t a_i, b_i;
223
+ nk_bf16_to_f32_serial(a + i, &a_i);
224
+ nk_bf16_to_f32_serial(b + i, &b_i);
225
+ __m256 diff_i_f32x8 = _mm256_sub_ps( //
226
+ _mm256_set1_ps(a_i), //
227
+ _mm256_set1_ps(b_i));
228
+ __m256 cdiff_j_f32x8 = _mm256_setzero_ps();
229
+ for (nk_size_t j = 0; j + 8 <= n; j += 8) {
230
+ __m256 diff_j_f32x8 = _mm256_sub_ps( //
231
+ nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(a + j))), //
232
+ nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(b + j))));
233
+ __m256 c_f32x8 = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)(c + i * n + j)));
234
+ cdiff_j_f32x8 = _mm256_fmadd_ps(diff_j_f32x8, c_f32x8, cdiff_j_f32x8);
235
+ }
236
+ sum_f32x8 = _mm256_fmadd_ps(diff_i_f32x8, cdiff_j_f32x8, sum_f32x8);
237
+ }
238
+
239
+ // Handle the tail of every row
240
+ nk_f32_t sum = nk_reduce_add_f32x8_haswell_(sum_f32x8);
241
+ nk_size_t const tail_length = n % 8;
242
+ nk_size_t const tail_start = n - tail_length;
243
+ if (tail_length) {
244
+ for (nk_size_t i = 0; i != n; ++i) {
245
+ nk_f32_t a_i, b_i;
246
+ nk_bf16_to_f32_serial(a + i, &a_i);
247
+ nk_bf16_to_f32_serial(b + i, &b_i);
248
+ nk_f32_t diff_i = a_i - b_i;
249
+ nk_b256_vec_t a_tail_vec, b_tail_vec;
250
+ nk_partial_load_bf16x8_to_f32x8_haswell_(a + tail_start, &a_tail_vec, tail_length);
251
+ nk_partial_load_bf16x8_to_f32x8_haswell_(b + tail_start, &b_tail_vec, tail_length);
252
+ __m256 diff_j_f32x8 = _mm256_sub_ps(a_tail_vec.ymm_ps, b_tail_vec.ymm_ps);
253
+ nk_b256_vec_t c_vec;
254
+ nk_partial_load_bf16x8_to_f32x8_haswell_(c + i * n + tail_start, &c_vec, tail_length);
255
+ __m256 c_f32x8 = c_vec.ymm_ps;
256
+ nk_f32_t cdiff_j = nk_reduce_add_f32x8_haswell_(_mm256_mul_ps(diff_j_f32x8, c_f32x8));
257
+ sum += diff_i * cdiff_j;
258
+ }
259
+ }
260
+
261
+ *result = nk_f32_sqrt_haswell(sum > 0 ? sum : 0);
262
+ }
263
+
264
+ #if defined(__clang__)
265
+ #pragma clang attribute pop
266
+ #elif defined(__GNUC__)
267
+ #pragma GCC pop_options
268
+ #endif
269
+
270
+ #if defined(__cplusplus)
271
+ } // extern "C"
272
+ #endif
273
+
274
+ #endif // NK_TARGET_HASWELL
275
+ #endif // NK_TARGET_X86_
276
+ #endif // NK_CURVED_HASWELL_H
@@ -0,0 +1,205 @@
1
+ /**
2
+ * @brief SIMD-accelerated Curved Space Similarity for NEON.
3
+ * @file include/numkong/curved/neon.h
4
+ * @author Ash Vardanian
5
+ * @date January 14, 2026
6
+ *
7
+ * @sa include/numkong/curved.h
8
+ *
9
+ * Implements f32 bilinear forms and Mahalanobis distance using ARM NEON SIMD.
10
+ * Accumulates f32 inputs in f64 precision to avoid catastrophic cancellation.
11
+ *
12
+ * @section neon_curved_instructions Key NEON Instructions
13
+ *
14
+ * Intrinsic Instruction Latency Throughput
15
+ * A76 M4+/V1+/Oryon
16
+ * vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy 2/cy 4/cy
17
+ * vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy 2/cy 2/cy
18
+ * vaddvq_f64 FADDP (V.2D to scalar) 3cy 1/cy 1/cy
19
+ * vld1_f32 LD1 ({Vt.2S}, [Xn]) 4cy 2/cy 2/cy
20
+ * vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy 1/cy 1/cy
21
+ *
22
+ * For f32 bilinear and Mahalanobis, we upcast to f64 for accumulation to preserve
23
+ * precision and avoid catastrophic cancellation in large-magnitude sums.
24
+ */
25
+ #ifndef NK_CURVED_NEON_H
26
+ #define NK_CURVED_NEON_H
27
+
28
+ #if NK_TARGET_ARM_
29
+ #if NK_TARGET_NEON
30
+
31
+ #include "numkong/types.h"
32
+ #include "numkong/spatial/neon.h" // nk_f64_sqrt_neon
33
+
34
+ #if defined(__cplusplus)
35
+ extern "C" {
36
+ #endif
37
+
38
+ #if defined(__clang__)
39
+ #pragma clang attribute push(__attribute__((target("arch=armv8-a+simd"))), apply_to = function)
40
+ #elif defined(__GNUC__)
41
+ #pragma GCC push_options
42
+ #pragma GCC target("arch=armv8-a+simd")
43
+ #endif
44
+
45
+ NK_PUBLIC void nk_bilinear_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
46
+ nk_f64_t *result) {
47
+ nk_f64_t outer_sum_f64 = 0;
48
+
49
+ for (nk_size_t i = 0; i != n; ++i) {
50
+ // Convert a[i] to f64 for precision
51
+ nk_f64_t a_i_f64 = (nk_f64_t)a[i];
52
+
53
+ // Inner loop: accumulate Σⱼ cᵢⱼ × bⱼ in f64
54
+ float64x2_t inner_sum_f64x2 = vdupq_n_f64(0);
55
+ nk_size_t j = 0;
56
+
57
+ // Vectorized inner loop: process 2 elements at a time
58
+ for (; j + 2 <= n; j += 2) {
59
+ // Load b[j:j+2] as f32, upcast to f64
60
+ float32x2_t b_f32x2 = vld1_f32(b + j);
61
+ float64x2_t b_f64x2 = vcvt_f64_f32(b_f32x2);
62
+
63
+ // Load c[i*n+j : i*n+j+2] as f32, upcast to f64
64
+ float32x2_t c_f32x2 = vld1_f32(c + i * n + j);
65
+ float64x2_t c_f64x2 = vcvt_f64_f32(c_f32x2);
66
+
67
+ // FMA: inner_sum += c × b
68
+ inner_sum_f64x2 = vfmaq_f64(inner_sum_f64x2, c_f64x2, b_f64x2);
69
+ }
70
+
71
+ // Reduce the f64x2 accumulator to scalar
72
+ nk_f64_t inner_sum_f64 = vaddvq_f64(inner_sum_f64x2);
73
+
74
+ // Handle tail elements
75
+ for (; j < n; ++j) { inner_sum_f64 += (nk_f64_t)c[i * n + j] * (nk_f64_t)b[j]; }
76
+
77
+ // Outer accumulation: outer_sum += aᵢ × inner_sum
78
+ outer_sum_f64 += a_i_f64 * inner_sum_f64;
79
+ }
80
+
81
+ *result = outer_sum_f64;
82
+ }
83
+
84
+ NK_PUBLIC void nk_mahalanobis_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
85
+ nk_f64_t *result) {
86
+ nk_f64_t outer_sum_f64 = 0;
87
+
88
+ for (nk_size_t i = 0; i != n; ++i) {
89
+ // Compute difference (aᵢ - bᵢ) in f64 for precision
90
+ nk_f64_t diff_i_f64 = (nk_f64_t)a[i] - (nk_f64_t)b[i];
91
+
92
+ // Inner loop: accumulate Σⱼ cᵢⱼ × (aⱼ - bⱼ) in f64
93
+ float64x2_t inner_sum_f64x2 = vdupq_n_f64(0);
94
+ nk_size_t j = 0;
95
+
96
+ // Vectorized inner loop: process 2 elements at a time
97
+ for (; j + 2 <= n; j += 2) {
98
+ // Load a[j:j+2] and b[j:j+2] as f32
99
+ float32x2_t a_f32x2 = vld1_f32(a + j);
100
+ float32x2_t b_f32x2 = vld1_f32(b + j);
101
+
102
+ // Compute difference in f32, then upcast to f64
103
+ float32x2_t diff_f32x2 = vsub_f32(a_f32x2, b_f32x2);
104
+ float64x2_t diff_f64x2 = vcvt_f64_f32(diff_f32x2);
105
+
106
+ // Load c[i*n+j : i*n+j+2] as f32, upcast to f64
107
+ float32x2_t c_f32x2 = vld1_f32(c + i * n + j);
108
+ float64x2_t c_f64x2 = vcvt_f64_f32(c_f32x2);
109
+
110
+ // FMA: inner_sum += c × diff
111
+ inner_sum_f64x2 = vfmaq_f64(inner_sum_f64x2, c_f64x2, diff_f64x2);
112
+ }
113
+
114
+ // Reduce the f64x2 accumulator to scalar
115
+ nk_f64_t inner_sum_f64 = vaddvq_f64(inner_sum_f64x2);
116
+
117
+ // Handle tail elements
118
+ for (; j < n; ++j) {
119
+ nk_f64_t diff_j_f64 = (nk_f64_t)a[j] - (nk_f64_t)b[j];
120
+ inner_sum_f64 += (nk_f64_t)c[i * n + j] * diff_j_f64;
121
+ }
122
+
123
+ // Outer accumulation: outer_sum += diff_i × inner_sum
124
+ outer_sum_f64 += diff_i_f64 * inner_sum_f64;
125
+ }
126
+
127
+ // Take square root of the result (clamp to 0 for numerical stability)
128
+ *result = nk_f64_sqrt_neon(outer_sum_f64 > 0 ? outer_sum_f64 : 0);
129
+ }
130
+
131
+ NK_PUBLIC void nk_bilinear_f32c_neon(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_f32c_t const *c_pairs,
132
+ nk_size_t n, nk_f64c_t *results) {
133
+ // ARMv8.3-A FCMLA (`vcmlaq_f32`) was benchmarked for this complex inner loop.
134
+ // The deinterleave+4FMA pattern is 2.3x faster on Apple M4 — see `dot/neon.h` comment.
135
+ nk_f64_t outer_sum_real_f64 = 0;
136
+ nk_f64_t outer_sum_imag_f64 = 0;
137
+
138
+ for (nk_size_t i = 0; i != n; ++i) {
139
+ // Convert a[i] to f64 for precision
140
+ nk_f64_t a_real_f64 = (nk_f64_t)a_pairs[i].real;
141
+ nk_f64_t a_imag_f64 = (nk_f64_t)a_pairs[i].imag;
142
+
143
+ // Inner loop: accumulate Σⱼ cᵢⱼ × bⱼ in f64
144
+ float64x2_t inner_sum_real_f64x2 = vdupq_n_f64(0);
145
+ float64x2_t inner_sum_imag_f64x2 = vdupq_n_f64(0);
146
+ nk_size_t j = 0;
147
+
148
+ // Vectorized inner loop: process 2 complex elements at a time
149
+ for (; j + 2 <= n; j += 2) {
150
+ // Load b[j:j+2] as interleaved complex pairs (real, imag, real, imag)
151
+ float32x2x2_t b_f32x2x2 = vld2_f32((nk_f32_t const *)(b_pairs + j));
152
+ float64x2_t b_real_f64x2 = vcvt_f64_f32(b_f32x2x2.val[0]);
153
+ float64x2_t b_imag_f64x2 = vcvt_f64_f32(b_f32x2x2.val[1]);
154
+
155
+ // Load c[i*n+j : i*n+j+2] as interleaved complex pairs
156
+ float32x2x2_t c_f32x2x2 = vld2_f32((nk_f32_t const *)(c_pairs + i * n + j));
157
+ float64x2_t c_real_f64x2 = vcvt_f64_f32(c_f32x2x2.val[0]);
158
+ float64x2_t c_imag_f64x2 = vcvt_f64_f32(c_f32x2x2.val[1]);
159
+
160
+ // Complex multiply
161
+ inner_sum_real_f64x2 = vfmaq_f64(inner_sum_real_f64x2, c_real_f64x2, b_real_f64x2);
162
+ inner_sum_real_f64x2 = vfmsq_f64(inner_sum_real_f64x2, c_imag_f64x2, b_imag_f64x2);
163
+
164
+ // Imaginary part: c_real×b_imag + c_imag×b_real
165
+ inner_sum_imag_f64x2 = vfmaq_f64(inner_sum_imag_f64x2, c_real_f64x2, b_imag_f64x2);
166
+ inner_sum_imag_f64x2 = vfmaq_f64(inner_sum_imag_f64x2, c_imag_f64x2, b_real_f64x2);
167
+ }
168
+
169
+ // Reduce the f64x2 accumulators to scalars
170
+ nk_f64_t inner_sum_real_f64 = vaddvq_f64(inner_sum_real_f64x2);
171
+ nk_f64_t inner_sum_imag_f64 = vaddvq_f64(inner_sum_imag_f64x2);
172
+
173
+ // Handle tail elements
174
+ for (; j < n; ++j) {
175
+ nk_f64_t b_real = (nk_f64_t)b_pairs[j].real;
176
+ nk_f64_t b_imag = (nk_f64_t)b_pairs[j].imag;
177
+ nk_f64_t c_real = (nk_f64_t)c_pairs[i * n + j].real;
178
+ nk_f64_t c_imag = (nk_f64_t)c_pairs[i * n + j].imag;
179
+ // Complex multiply: c × b
180
+ inner_sum_real_f64 += c_real * b_real - c_imag * b_imag;
181
+ inner_sum_imag_f64 += c_real * b_imag + c_imag * b_real;
182
+ }
183
+
184
+ // Outer accumulation
185
+ outer_sum_real_f64 += a_real_f64 * inner_sum_real_f64 - a_imag_f64 * inner_sum_imag_f64;
186
+ outer_sum_imag_f64 += a_real_f64 * inner_sum_imag_f64 + a_imag_f64 * inner_sum_real_f64;
187
+ }
188
+
189
+ results->real = outer_sum_real_f64;
190
+ results->imag = outer_sum_imag_f64;
191
+ }
192
+
193
+ #if defined(__clang__)
194
+ #pragma clang attribute pop
195
+ #elif defined(__GNUC__)
196
+ #pragma GCC pop_options
197
+ #endif
198
+
199
+ #if defined(__cplusplus)
200
+ } // extern "C"
201
+ #endif
202
+
203
+ #endif // NK_TARGET_NEON
204
+ #endif // NK_TARGET_ARM_
205
+ #endif // NK_CURVED_NEON_H