numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,212 @@
1
+ /**
2
+ * @brief SIMD-accelerated Curved Space Similarity for NEON BF16.
3
+ * @file include/numkong/curved/neonbfdot.h
4
+ * @author Ash Vardanian
5
+ * @date January 14, 2026
6
+ *
7
+ * @sa include/numkong/curved.h
8
+ *
9
+ * Implements bf16 bilinear forms and Mahalanobis distance using ARM NEON with BF16 extensions.
10
+ *
11
+ * @section curved_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
12
+ *
13
+ * Intrinsic Instruction Latency Throughput
14
+ * A76 M4+/V1+/Oryon
15
+ * vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy 2/cy 4/cy
16
+ * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy 2/cy 4/cy
17
+ * vld1q_bf16 LD1 (V.8H) 4cy 2/cy 3/cy
18
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
19
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
20
+ *
21
+ * For bilinear forms, BFDOT enables efficient inner-product computation by processing 8 bf16
22
+ * pairs into 4 f32 results per instruction. For Mahalanobis distance, bf16 inputs are converted
23
+ * to f32 for subtraction, then accumulated using FMA for numerical stability.
24
+ */
25
+ #ifndef NK_CURVED_NEONBFDOT_H
26
+ #define NK_CURVED_NEONBFDOT_H
27
+
28
+ #if NK_TARGET_ARM_
29
+ #if NK_TARGET_NEONBFDOT
30
+
31
+ #include "numkong/types.h" // `nk_bf16_t`
32
+ #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
33
+ #include "numkong/cast/serial.h" // `nk_bf16_to_f32_serial`
34
+
35
+ #if defined(__cplusplus)
36
+ extern "C" {
37
+ #endif
38
+
39
+ #if defined(__clang__)
40
+ #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function)
41
+ #elif defined(__GNUC__)
42
+ #pragma GCC push_options
43
+ #pragma GCC target("arch=armv8.6-a+simd+bf16")
44
+ #endif
45
+
46
+ NK_PUBLIC void nk_bilinear_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
47
+ nk_f32_t *result) {
48
+ float32x4_t outer_sum_f32x4 = vdupq_n_f32(0);
49
+
50
+ for (nk_size_t i = 0; i != n; ++i) {
51
+ // Load a[i] and broadcast to f32
52
+ nk_f32_t a_i_f32;
53
+ nk_bf16_to_f32_serial(a + i, &a_i_f32);
54
+ float32x4_t a_i_f32x4 = vdupq_n_f32(a_i_f32);
55
+
56
+ // Inner sum
57
+ float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
58
+ nk_size_t j = 0;
59
+
60
+ // Process 8 elements at a time using BFDOT
61
+ for (; j + 8 <= n; j += 8) {
62
+ bfloat16x8_t b_bf16x8 = vld1q_bf16((nk_bf16_for_arm_simd_t const *)(b + j));
63
+ bfloat16x8_t c_bf16x8 = vld1q_bf16((nk_bf16_for_arm_simd_t const *)(c + i * n + j));
64
+ inner_sum_f32x4 = vbfdotq_f32(inner_sum_f32x4, c_bf16x8, b_bf16x8);
65
+ }
66
+
67
+ // Handle tail elements (less than 8)
68
+ if (j < n) {
69
+ nk_b128_vec_t b_vec, c_vec;
70
+ nk_partial_load_b16x8_serial_(b + j, &b_vec, n - j);
71
+ nk_partial_load_b16x8_serial_(c + i * n + j, &c_vec, n - j);
72
+ bfloat16x8_t b_bf16x8 = vreinterpretq_bf16_u16(b_vec.u16x8);
73
+ bfloat16x8_t c_bf16x8 = vreinterpretq_bf16_u16(c_vec.u16x8);
74
+ inner_sum_f32x4 = vbfdotq_f32(inner_sum_f32x4, c_bf16x8, b_bf16x8);
75
+ }
76
+
77
+ // Accumulate: outer_sum += a[i] * inner_sum
78
+ outer_sum_f32x4 = vfmaq_f32(outer_sum_f32x4, a_i_f32x4, inner_sum_f32x4);
79
+ }
80
+
81
+ *result = vaddvq_f32(outer_sum_f32x4);
82
+ }
83
+
84
+ NK_PUBLIC void nk_mahalanobis_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
85
+ nk_f32_t *result) {
86
+ nk_f32_t outer_sum = 0;
87
+
88
+ for (nk_size_t i = 0; i != n; ++i) {
89
+ // Compute diff_i = a[i] - b[i] in f32
90
+ nk_f32_t a_i_f32, b_i_f32;
91
+ nk_bf16_to_f32_serial(a + i, &a_i_f32);
92
+ nk_bf16_to_f32_serial(b + i, &b_i_f32);
93
+ nk_f32_t diff_i = a_i_f32 - b_i_f32;
94
+
95
+ // Inner sum
96
+ float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
97
+ nk_size_t j = 0;
98
+
99
+ // Process 4 elements at a time (convert bf16->f32, subtract, then FMA)
100
+ for (; j + 4 <= n; j += 4) {
101
+ bfloat16x4_t a_j_bf16x4 = vld1_bf16((nk_bf16_for_arm_simd_t const *)(a + j));
102
+ bfloat16x4_t b_j_bf16x4 = vld1_bf16((nk_bf16_for_arm_simd_t const *)(b + j));
103
+ bfloat16x4_t c_bf16x4 = vld1_bf16((nk_bf16_for_arm_simd_t const *)(c + i * n + j));
104
+
105
+ float32x4_t a_j_f32x4 = vcvt_f32_bf16(a_j_bf16x4);
106
+ float32x4_t b_j_f32x4 = vcvt_f32_bf16(b_j_bf16x4);
107
+ float32x4_t c_f32x4 = vcvt_f32_bf16(c_bf16x4);
108
+
109
+ float32x4_t diff_j_f32x4 = vsubq_f32(a_j_f32x4, b_j_f32x4);
110
+ inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, diff_j_f32x4);
111
+ }
112
+
113
+ // Handle tail elements
114
+ nk_f32_t inner_sum_tail = 0;
115
+ for (; j < n; ++j) {
116
+ nk_f32_t a_j_f32, b_j_f32, c_f32;
117
+ nk_bf16_to_f32_serial(a + j, &a_j_f32);
118
+ nk_bf16_to_f32_serial(b + j, &b_j_f32);
119
+ nk_bf16_to_f32_serial(c + i * n + j, &c_f32);
120
+ inner_sum_tail += c_f32 * (a_j_f32 - b_j_f32);
121
+ }
122
+
123
+ // Reduce inner sum and add tail
124
+ nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4) + inner_sum_tail;
125
+
126
+ // Accumulate: outer_sum += diff_i * inner_sum
127
+ outer_sum += diff_i * inner_sum;
128
+ }
129
+
130
+ nk_f32_t quadratic = outer_sum;
131
+ *result = nk_f32_sqrt_neon(quadratic > 0 ? quadratic : 0);
132
+ }
133
+
134
+ NK_PUBLIC void nk_bilinear_bf16c_neonbfdot(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs,
135
+ nk_bf16c_t const *c_pairs, nk_size_t n, nk_f32c_t *result) {
136
+ // ARMv8.3-A FCMLA was benchmarked for this complex multiply pattern.
137
+ // The deinterleave+4FMA approach is 2.3x faster on Apple M4 — see `dot/neon.h` comment.
138
+ nk_f32_t outer_sum_real = 0;
139
+ nk_f32_t outer_sum_imag = 0;
140
+
141
+ for (nk_size_t i = 0; i != n; ++i) {
142
+ // Load a[i] as complex (real, imag) and convert to f32
143
+ nk_f32_t a_real, a_imag;
144
+ nk_bf16_to_f32_serial(&a_pairs[i].real, &a_real);
145
+ nk_bf16_to_f32_serial(&a_pairs[i].imag, &a_imag);
146
+
147
+ // Inner sums for real and imaginary parts of c[i,j] * b[j]
148
+ float32x4_t inner_sum_real_f32x4 = vdupq_n_f32(0);
149
+ float32x4_t inner_sum_imag_f32x4 = vdupq_n_f32(0);
150
+ nk_size_t j = 0;
151
+
152
+ // Process 4 complex pairs at a time
153
+ for (; j + 4 <= n; j += 4) {
154
+ // Deinterleave load: separate real and imaginary parts
155
+ // MSVC doesn't support vld2_bf16, so load as s16 and reinterpret
156
+ int16x4x2_t b_i16x4x2 = vld2_s16((short const *)(b_pairs + j));
157
+ int16x4x2_t c_i16x4x2 = vld2_s16((short const *)(c_pairs + i * n + j));
158
+
159
+ float32x4_t b_real_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(b_i16x4x2.val[0]));
160
+ float32x4_t b_imag_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(b_i16x4x2.val[1]));
161
+ float32x4_t c_real_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(c_i16x4x2.val[0]));
162
+ float32x4_t c_imag_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(c_i16x4x2.val[1]));
163
+
164
+ // Complex multiply: c * b = (c_real*b_real - c_imag*b_imag) + (c_real*b_imag + c_imag*b_real)*i
165
+ // Real part: c_real*b_real - c_imag*b_imag
166
+ inner_sum_real_f32x4 = vfmaq_f32(inner_sum_real_f32x4, c_real_f32x4, b_real_f32x4);
167
+ inner_sum_real_f32x4 = vfmsq_f32(inner_sum_real_f32x4, c_imag_f32x4, b_imag_f32x4);
168
+ // Imaginary part: c_real*b_imag + c_imag*b_real
169
+ inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_real_f32x4, b_imag_f32x4);
170
+ inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_imag_f32x4, b_real_f32x4);
171
+ }
172
+
173
+ // Handle tail elements
174
+ nk_f32_t inner_sum_real_tail = 0, inner_sum_imag_tail = 0;
175
+ for (; j < n; ++j) {
176
+ nk_f32_t b_real, b_imag, c_real, c_imag;
177
+ nk_bf16_to_f32_serial(&b_pairs[j].real, &b_real);
178
+ nk_bf16_to_f32_serial(&b_pairs[j].imag, &b_imag);
179
+ nk_bf16_to_f32_serial(&c_pairs[i * n + j].real, &c_real);
180
+ nk_bf16_to_f32_serial(&c_pairs[i * n + j].imag, &c_imag);
181
+ // Complex multiply: c * b
182
+ inner_sum_real_tail += c_real * b_real - c_imag * b_imag;
183
+ inner_sum_imag_tail += c_real * b_imag + c_imag * b_real;
184
+ }
185
+
186
+ // Reduce inner sums
187
+ nk_f32_t inner_sum_real = vaddvq_f32(inner_sum_real_f32x4) + inner_sum_real_tail;
188
+ nk_f32_t inner_sum_imag = vaddvq_f32(inner_sum_imag_f32x4) + inner_sum_imag_tail;
189
+
190
+ // Complex multiply: a * inner_sum = (a_real*inner_real - a_imag*inner_imag) + (a_real*inner_imag +
191
+ // a_imag*inner_real)*i
192
+ outer_sum_real += a_real * inner_sum_real - a_imag * inner_sum_imag;
193
+ outer_sum_imag += a_real * inner_sum_imag + a_imag * inner_sum_real;
194
+ }
195
+
196
+ result->real = outer_sum_real;
197
+ result->imag = outer_sum_imag;
198
+ }
199
+
200
+ #if defined(__clang__)
201
+ #pragma clang attribute pop
202
+ #elif defined(__GNUC__)
203
+ #pragma GCC pop_options
204
+ #endif
205
+
206
+ #if defined(__cplusplus)
207
+ } // extern "C"
208
+ #endif
209
+
210
+ #endif // NK_TARGET_NEONBFDOT
211
+ #endif // NK_TARGET_ARM_
212
+ #endif // NK_CURVED_NEONBFDOT_H
@@ -0,0 +1,212 @@
1
+ /**
2
+ * @brief SIMD-accelerated Curved Space Similarity for NEON FP16.
3
+ * @file include/numkong/curved/neonhalf.h
4
+ * @author Ash Vardanian
5
+ * @date January 14, 2026
6
+ *
7
+ * @sa include/numkong/curved.h
8
+ *
9
+ * Implements f16 bilinear forms and Mahalanobis distance using ARM NEON with FP16 extensions.
10
+ *
11
+ * @section curved_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
12
+ *
13
+ * Intrinsic Instruction Latency Throughput
14
+ * A76 M4+/V1+/Oryon
15
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
16
+ * vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
17
+ * vld1_f16 LD1 (V.4H) 4cy 2/cy 3/cy
18
+ * vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
19
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
20
+ *
21
+ * Bilinear forms involve nested summation O(n^2) operations. For numerical stability,
22
+ * f16 inputs are widened to f32 for accumulation. The matrix C is accessed row-by-row
23
+ * to maintain cache locality.
24
+ *
25
+ * Mathematical definitions:
26
+ * - Bilinear: result = ∑ᵢ ∑ⱼ aᵢ × cᵢⱼ × bⱼ
27
+ * - Mahalanobis: result = √((a - b)ᵀ × C × (a - b))
28
+ */
29
+ #ifndef NK_CURVED_NEONHALF_H
30
+ #define NK_CURVED_NEONHALF_H
31
+
32
+ #if NK_TARGET_ARM_
33
+ #if NK_TARGET_NEONHALF
34
+
35
+ #include "numkong/types.h"
36
+ #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
37
+ #include "numkong/cast/serial.h" // `nk_f16_to_f32_serial`
38
+
39
+ #if defined(__cplusplus)
40
+ extern "C" {
41
+ #endif
42
+
43
+ #if defined(__clang__)
44
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
45
+ #elif defined(__GNUC__)
46
+ #pragma GCC push_options
47
+ #pragma GCC target("arch=armv8.2-a+simd+fp16")
48
+ #endif
49
+
50
+ NK_PUBLIC void nk_bilinear_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
51
+ nk_f32_t *result) {
52
+ nk_f32_t outer_sum = 0;
53
+
54
+ // Process rows of the matrix
55
+ for (nk_size_t row = 0; row != n; ++row) {
56
+ nk_f16_t const *c_row = c + row * n;
57
+
58
+ // Load a[row] as f32
59
+ nk_f32_t a_row;
60
+ nk_f16_to_f32_serial(a + row, &a_row);
61
+
62
+ // Compute inner sum
63
+ float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
64
+ nk_size_t column = 0;
65
+
66
+ // Process 4 elements at a time
67
+ for (; column + 4 <= n; column += 4) {
68
+ float32x4_t b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(b + column)));
69
+ float32x4_t c_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(c_row + column)));
70
+ inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, b_f32x4);
71
+ }
72
+
73
+ // Reduce SIMD accumulator
74
+ nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
75
+
76
+ // Handle tail elements with scalar code
77
+ for (; column < n; ++column) {
78
+ nk_f32_t b_val, c_val;
79
+ nk_f16_to_f32_serial(b + column, &b_val);
80
+ nk_f16_to_f32_serial(c_row + column, &c_val);
81
+ inner_sum += c_val * b_val;
82
+ }
83
+
84
+ // Multiply by a[row] and accumulate
85
+ outer_sum += a_row * inner_sum;
86
+ }
87
+
88
+ *result = outer_sum;
89
+ }
90
+
91
+ NK_PUBLIC void nk_mahalanobis_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
92
+ nk_f32_t *result) {
93
+ nk_f32_t outer_sum = 0;
94
+
95
+ // Process rows of the matrix
96
+ for (nk_size_t row = 0; row != n; ++row) {
97
+ nk_f16_t const *c_row = c + row * n;
98
+
99
+ // Compute diff_row = a[row] - b[row] in f32
100
+ nk_f32_t a_row, b_row;
101
+ nk_f16_to_f32_serial(a + row, &a_row);
102
+ nk_f16_to_f32_serial(b + row, &b_row);
103
+ nk_f32_t diff_row = a_row - b_row;
104
+
105
+ // Compute inner sum
106
+ float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
107
+ nk_size_t column = 0;
108
+
109
+ // Process 4 elements at a time
110
+ for (; column + 4 <= n; column += 4) {
111
+ float32x4_t a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(a + column)));
112
+ float32x4_t b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(b + column)));
113
+ float32x4_t c_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(c_row + column)));
114
+ float32x4_t diff_column_f32x4 = vsubq_f32(a_f32x4, b_f32x4);
115
+ inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, diff_column_f32x4);
116
+ }
117
+
118
+ // Reduce SIMD accumulator
119
+ nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
120
+
121
+ // Handle tail elements with scalar code
122
+ for (; column < n; ++column) {
123
+ nk_f32_t a_val, b_val, c_val;
124
+ nk_f16_to_f32_serial(a + column, &a_val);
125
+ nk_f16_to_f32_serial(b + column, &b_val);
126
+ nk_f16_to_f32_serial(c_row + column, &c_val);
127
+ inner_sum += c_val * (a_val - b_val);
128
+ }
129
+
130
+ // Multiply by diff_row and accumulate
131
+ outer_sum += diff_row * inner_sum;
132
+ }
133
+
134
+ nk_f32_t quadratic = outer_sum;
135
+ *result = nk_f32_sqrt_neon(quadratic > 0 ? quadratic : 0);
136
+ }
137
+
138
+ NK_PUBLIC void nk_bilinear_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_f16c_t const *c_pairs,
139
+ nk_size_t n, nk_f32c_t *results) {
140
+ nk_f32_t outer_sum_real = 0;
141
+ nk_f32_t outer_sum_imag = 0;
142
+
143
+ // Process rows of the matrix
144
+ for (nk_size_t row = 0; row != n; ++row) {
145
+ nk_f16c_t const *c_row = c_pairs + row * n;
146
+
147
+ // Load a[row] complex value
148
+ nk_f32_t a_real, a_imag;
149
+ nk_f16_to_f32_serial(&(a_pairs + row)->real, &a_real);
150
+ nk_f16_to_f32_serial(&(a_pairs + row)->imag, &a_imag);
151
+
152
+ // Compute inner sum
153
+ float32x4_t inner_sum_real_f32x4 = vdupq_n_f32(0);
154
+ float32x4_t inner_sum_imag_f32x4 = vdupq_n_f32(0);
155
+ nk_size_t column = 0;
156
+
157
+ // Process 4 complex pairs at a time using deinterleaved loads
158
+ for (; column + 4 <= n; column += 4) {
159
+ // Deinterleave real/imaginary using vld2_s16 pattern from dot/neonhalf.h
160
+ int16x4x2_t b_i16x4x2 = vld2_s16((short const *)(b_pairs + column));
161
+ int16x4x2_t c_i16x4x2 = vld2_s16((short const *)(c_row + column));
162
+ float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
163
+ float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
164
+ float32x4_t c_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(c_i16x4x2.val[0]));
165
+ float32x4_t c_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(c_i16x4x2.val[1]));
166
+
167
+ // Complex multiply
168
+ inner_sum_real_f32x4 = vfmaq_f32(inner_sum_real_f32x4, c_real_f32x4, b_real_f32x4);
169
+ inner_sum_real_f32x4 = vfmsq_f32(inner_sum_real_f32x4, c_imag_f32x4, b_imag_f32x4);
170
+ inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_real_f32x4, b_imag_f32x4);
171
+ inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_imag_f32x4, b_real_f32x4);
172
+ }
173
+
174
+ // Reduce SIMD accumulators
175
+ nk_f32_t inner_sum_real = vaddvq_f32(inner_sum_real_f32x4);
176
+ nk_f32_t inner_sum_imag = vaddvq_f32(inner_sum_imag_f32x4);
177
+
178
+ // Handle tail elements with scalar code
179
+ for (; column < n; ++column) {
180
+ nk_f32_t b_real, b_imag, c_real, c_imag;
181
+ nk_f16_to_f32_serial(&(b_pairs + column)->real, &b_real);
182
+ nk_f16_to_f32_serial(&(b_pairs + column)->imag, &b_imag);
183
+ nk_f16_to_f32_serial(&(c_row + column)->real, &c_real);
184
+ nk_f16_to_f32_serial(&(c_row + column)->imag, &c_imag);
185
+
186
+ // Complex multiply
187
+ inner_sum_real += c_real * b_real - c_imag * b_imag;
188
+ inner_sum_imag += c_real * b_imag + c_imag * b_real;
189
+ }
190
+
191
+ // Complex multiply
192
+ outer_sum_real += a_real * inner_sum_real - a_imag * inner_sum_imag;
193
+ outer_sum_imag += a_real * inner_sum_imag + a_imag * inner_sum_real;
194
+ }
195
+
196
+ results->real = outer_sum_real;
197
+ results->imag = outer_sum_imag;
198
+ }
199
+
200
+ #if defined(__clang__)
201
+ #pragma clang attribute pop
202
+ #elif defined(__GNUC__)
203
+ #pragma GCC pop_options
204
+ #endif
205
+
206
+ #if defined(__cplusplus)
207
+ } // extern "C"
208
+ #endif
209
+
210
+ #endif // NK_TARGET_NEONHALF
211
+ #endif // NK_TARGET_ARM_
212
+ #endif // NK_CURVED_NEONHALF_H