numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
Binary file
@@ -1,212 +0,0 @@
1
- /**
2
- * @brief SIMD-accelerated Curved Space Similarity for NEON FP16.
3
- * @file include/numkong/curved/neonhalf.h
4
- * @author Ash Vardanian
5
- * @date January 14, 2026
6
- *
7
- * @sa include/numkong/curved.h
8
- *
9
- * Implements f16 bilinear forms and Mahalanobis distance using ARM NEON with FP16 extensions.
10
- *
11
- * @section curved_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
12
- *
13
- * Intrinsic Instruction Latency Throughput
14
- * A76 M4+/V1+/Oryon
15
- * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
16
- * vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
17
- * vld1_f16 LD1 (V.4H) 4cy 2/cy 3/cy
18
- * vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
19
- * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
20
- *
21
- * Bilinear forms involve nested summation O(n^2) operations. For numerical stability,
22
- * f16 inputs are widened to f32 for accumulation. The matrix C is accessed row-by-row
23
- * to maintain cache locality.
24
- *
25
- * Mathematical definitions:
26
- * - Bilinear: result = ∑ᵢ ∑ⱼ aᵢ × cᵢⱼ × bⱼ
27
- * - Mahalanobis: result = √((a - b)ᵀ × C × (a - b))
28
- */
29
- #ifndef NK_CURVED_NEONHALF_H
30
- #define NK_CURVED_NEONHALF_H
31
-
32
- #if NK_TARGET_ARM_
33
- #if NK_TARGET_NEONHALF
34
-
35
- #include "numkong/types.h"
36
- #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
37
- #include "numkong/cast/serial.h" // `nk_f16_to_f32_serial`
38
-
39
- #if defined(__cplusplus)
40
- extern "C" {
41
- #endif
42
-
43
- #if defined(__clang__)
44
- #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
45
- #elif defined(__GNUC__)
46
- #pragma GCC push_options
47
- #pragma GCC target("arch=armv8.2-a+simd+fp16")
48
- #endif
49
-
50
- NK_PUBLIC void nk_bilinear_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
51
- nk_f32_t *result) {
52
- nk_f32_t outer_sum = 0;
53
-
54
- // Process rows of the matrix
55
- for (nk_size_t row = 0; row != n; ++row) {
56
- nk_f16_t const *c_row = c + row * n;
57
-
58
- // Load a[row] as f32
59
- nk_f32_t a_row;
60
- nk_f16_to_f32_serial(a + row, &a_row);
61
-
62
- // Compute inner sum
63
- float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
64
- nk_size_t column = 0;
65
-
66
- // Process 4 elements at a time
67
- for (; column + 4 <= n; column += 4) {
68
- float32x4_t b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(b + column)));
69
- float32x4_t c_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(c_row + column)));
70
- inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, b_f32x4);
71
- }
72
-
73
- // Reduce SIMD accumulator
74
- nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
75
-
76
- // Handle tail elements with scalar code
77
- for (; column < n; ++column) {
78
- nk_f32_t b_val, c_val;
79
- nk_f16_to_f32_serial(b + column, &b_val);
80
- nk_f16_to_f32_serial(c_row + column, &c_val);
81
- inner_sum += c_val * b_val;
82
- }
83
-
84
- // Multiply by a[row] and accumulate
85
- outer_sum += a_row * inner_sum;
86
- }
87
-
88
- *result = outer_sum;
89
- }
90
-
91
- NK_PUBLIC void nk_mahalanobis_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
92
- nk_f32_t *result) {
93
- nk_f32_t outer_sum = 0;
94
-
95
- // Process rows of the matrix
96
- for (nk_size_t row = 0; row != n; ++row) {
97
- nk_f16_t const *c_row = c + row * n;
98
-
99
- // Compute diff_row = a[row] - b[row] in f32
100
- nk_f32_t a_row, b_row;
101
- nk_f16_to_f32_serial(a + row, &a_row);
102
- nk_f16_to_f32_serial(b + row, &b_row);
103
- nk_f32_t diff_row = a_row - b_row;
104
-
105
- // Compute inner sum
106
- float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
107
- nk_size_t column = 0;
108
-
109
- // Process 4 elements at a time
110
- for (; column + 4 <= n; column += 4) {
111
- float32x4_t a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(a + column)));
112
- float32x4_t b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(b + column)));
113
- float32x4_t c_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(c_row + column)));
114
- float32x4_t diff_column_f32x4 = vsubq_f32(a_f32x4, b_f32x4);
115
- inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, diff_column_f32x4);
116
- }
117
-
118
- // Reduce SIMD accumulator
119
- nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
120
-
121
- // Handle tail elements with scalar code
122
- for (; column < n; ++column) {
123
- nk_f32_t a_val, b_val, c_val;
124
- nk_f16_to_f32_serial(a + column, &a_val);
125
- nk_f16_to_f32_serial(b + column, &b_val);
126
- nk_f16_to_f32_serial(c_row + column, &c_val);
127
- inner_sum += c_val * (a_val - b_val);
128
- }
129
-
130
- // Multiply by diff_row and accumulate
131
- outer_sum += diff_row * inner_sum;
132
- }
133
-
134
- nk_f32_t quadratic = outer_sum;
135
- *result = nk_f32_sqrt_neon(quadratic > 0 ? quadratic : 0);
136
- }
137
-
138
- NK_PUBLIC void nk_bilinear_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_f16c_t const *c_pairs,
139
- nk_size_t n, nk_f32c_t *results) {
140
- nk_f32_t outer_sum_real = 0;
141
- nk_f32_t outer_sum_imag = 0;
142
-
143
- // Process rows of the matrix
144
- for (nk_size_t row = 0; row != n; ++row) {
145
- nk_f16c_t const *c_row = c_pairs + row * n;
146
-
147
- // Load a[row] complex value
148
- nk_f32_t a_real, a_imag;
149
- nk_f16_to_f32_serial(&(a_pairs + row)->real, &a_real);
150
- nk_f16_to_f32_serial(&(a_pairs + row)->imag, &a_imag);
151
-
152
- // Compute inner sum
153
- float32x4_t inner_sum_real_f32x4 = vdupq_n_f32(0);
154
- float32x4_t inner_sum_imag_f32x4 = vdupq_n_f32(0);
155
- nk_size_t column = 0;
156
-
157
- // Process 4 complex pairs at a time using deinterleaved loads
158
- for (; column + 4 <= n; column += 4) {
159
- // Deinterleave real/imaginary using vld2_s16 pattern from dot/neonhalf.h
160
- int16x4x2_t b_i16x4x2 = vld2_s16((short const *)(b_pairs + column));
161
- int16x4x2_t c_i16x4x2 = vld2_s16((short const *)(c_row + column));
162
- float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
163
- float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
164
- float32x4_t c_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(c_i16x4x2.val[0]));
165
- float32x4_t c_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(c_i16x4x2.val[1]));
166
-
167
- // Complex multiply
168
- inner_sum_real_f32x4 = vfmaq_f32(inner_sum_real_f32x4, c_real_f32x4, b_real_f32x4);
169
- inner_sum_real_f32x4 = vfmsq_f32(inner_sum_real_f32x4, c_imag_f32x4, b_imag_f32x4);
170
- inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_real_f32x4, b_imag_f32x4);
171
- inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_imag_f32x4, b_real_f32x4);
172
- }
173
-
174
- // Reduce SIMD accumulators
175
- nk_f32_t inner_sum_real = vaddvq_f32(inner_sum_real_f32x4);
176
- nk_f32_t inner_sum_imag = vaddvq_f32(inner_sum_imag_f32x4);
177
-
178
- // Handle tail elements with scalar code
179
- for (; column < n; ++column) {
180
- nk_f32_t b_real, b_imag, c_real, c_imag;
181
- nk_f16_to_f32_serial(&(b_pairs + column)->real, &b_real);
182
- nk_f16_to_f32_serial(&(b_pairs + column)->imag, &b_imag);
183
- nk_f16_to_f32_serial(&(c_row + column)->real, &c_real);
184
- nk_f16_to_f32_serial(&(c_row + column)->imag, &c_imag);
185
-
186
- // Complex multiply
187
- inner_sum_real += c_real * b_real - c_imag * b_imag;
188
- inner_sum_imag += c_real * b_imag + c_imag * b_real;
189
- }
190
-
191
- // Complex multiply
192
- outer_sum_real += a_real * inner_sum_real - a_imag * inner_sum_imag;
193
- outer_sum_imag += a_real * inner_sum_imag + a_imag * inner_sum_real;
194
- }
195
-
196
- results->real = outer_sum_real;
197
- results->imag = outer_sum_imag;
198
- }
199
-
200
- #if defined(__clang__)
201
- #pragma clang attribute pop
202
- #elif defined(__GNUC__)
203
- #pragma GCC pop_options
204
- #endif
205
-
206
- #if defined(__cplusplus)
207
- } // extern "C"
208
- #endif
209
-
210
- #endif // NK_TARGET_NEONHALF
211
- #endif // NK_TARGET_ARM_
212
- #endif // NK_CURVED_NEONHALF_H
@@ -1,198 +0,0 @@
1
- /**
2
- * @brief SIMD-accelerated Dot Products for NEON FP16.
3
- * @file include/numkong/dot/neonhalf.h
4
- * @author Ash Vardanian
5
- * @date December 27, 2025
6
- *
7
- * @sa include/numkong/dot.h
8
- *
9
- * @section dot_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
10
- *
11
- * Intrinsic Instruction Latency Throughput
12
- * A76 M4+/V1+/Oryon
13
- * vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
14
- * vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
15
- * vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
16
- * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
17
- * vfmsq_f16 FMLS (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
18
- *
19
- * The ARMv8.2-FP16 extension enables native half-precision arithmetic, doubling the element count
20
- * per vector register (8x F16 vs 4x F32). This doubles theoretical throughput for bandwidth-bound
21
- * workloads while halving memory footprint.
22
- *
23
- * For dot products, inputs are widened from F16 to F32 for accumulation to preserve numerical
24
- * precision. The FCVTL instruction handles this widening, allowing the FMA operations
25
- * to maintain full F32 precision in the accumulator.
26
- *
27
- * @section dot_neonhalf_stateful Stateful Streaming Logic
28
- *
29
- * To build memory-optimal tiled algorithms, this file defines the following structures and force-inlined
30
- * `NK_INTERNAL` functions:
31
- *
32
- * - nk_dot_f16x4 state with f16 inputs widened to f32 for accumulation.
33
- *
34
- * @code{c}
35
- * nk_dot_f16x4_state_neonhalf_t state_first, state_second, state_third, state_fourth;
36
- * nk_b64_vec_t query_vec, target_first_vec, target_second_vec, target_third_vec, target_fourth_vec;
37
- * nk_dot_f16x4_init_neonhalf(&state_first);
38
- * nk_dot_f16x4_init_neonhalf(&state_second);
39
- * nk_dot_f16x4_init_neonhalf(&state_third);
40
- * nk_dot_f16x4_init_neonhalf(&state_fourth);
41
- * for (nk_size_t idx = 0; idx + 4 <= depth; idx += 4) {
42
- * query_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(query_ptr + idx));
43
- * target_first_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(target_first_ptr + idx));
44
- * target_second_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(target_second_ptr + idx));
45
- * target_third_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(target_third_ptr + idx));
46
- * target_fourth_vec.u16x4 = vreinterpret_u16_f16(vld1_f16(target_fourth_ptr + idx));
47
- * nk_dot_f16x4_update_neonhalf(&state_first, query_vec, target_first_vec, idx, 4);
48
- * nk_dot_f16x4_update_neonhalf(&state_second, query_vec, target_second_vec, idx, 4);
49
- * nk_dot_f16x4_update_neonhalf(&state_third, query_vec, target_third_vec, idx, 4);
50
- * nk_dot_f16x4_update_neonhalf(&state_fourth, query_vec, target_fourth_vec, idx, 4);
51
- * }
52
- * nk_b128_vec_t results_vec;
53
- * nk_dot_f16x4_finalize_neonhalf(&state_first, &state_second, &state_third, &state_fourth, depth, &results_vec);
54
- * @endcode
55
- */
56
- #ifndef NK_DOT_NEONHALF_H
57
- #define NK_DOT_NEONHALF_H
58
-
59
- #if NK_TARGET_ARM_
60
- #if NK_TARGET_NEONHALF
61
-
62
- #include "numkong/types.h"
63
- #include "numkong/cast/serial.h" // `nk_partial_load_b16x4_serial_`
64
-
65
- #if defined(__cplusplus)
66
- extern "C" {
67
- #endif
68
-
69
- #if defined(__clang__)
70
- #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
71
- #elif defined(__GNUC__)
72
- #pragma GCC push_options
73
- #pragma GCC target("arch=armv8.2-a+simd+fp16")
74
- #endif
75
-
76
- NK_PUBLIC void nk_dot_f16_neonhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
77
- nk_f32_t *result) {
78
- float32x4_t a_f32x4, b_f32x4;
79
- float32x4_t sum_f32x4 = vdupq_n_f32(0);
80
- nk_dot_f16_neonhalf_cycle:
81
- if (count_scalars < 4) {
82
- nk_b64_vec_t a_vec, b_vec;
83
- nk_partial_load_b16x4_serial_(a_scalars, &a_vec, count_scalars);
84
- nk_partial_load_b16x4_serial_(b_scalars, &b_vec, count_scalars);
85
- a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
86
- b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
87
- count_scalars = 0;
88
- }
89
- else {
90
- a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a_scalars));
91
- b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b_scalars));
92
- a_scalars += 4, b_scalars += 4, count_scalars -= 4;
93
- }
94
- sum_f32x4 = vfmaq_f32(sum_f32x4, a_f32x4, b_f32x4);
95
- if (count_scalars) goto nk_dot_f16_neonhalf_cycle;
96
- *result = vaddvq_f32(sum_f32x4);
97
- }
98
-
99
- NK_PUBLIC void nk_dot_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
100
- nk_f32c_t *result) {
101
- float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
102
- float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
103
- while (count_pairs >= 4) {
104
- // Unpack the input arrays into real and imaginary parts.
105
- // MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed
106
- // integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards.
107
- int16x4x2_t a_i16x4x2 = vld2_s16((short *)a_pairs);
108
- int16x4x2_t b_i16x4x2 = vld2_s16((short *)b_pairs);
109
- float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
110
- float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
111
- float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
112
- float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
113
- sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
114
- sum_real_f32x4 = vfmsq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
115
- sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
116
- sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
117
- count_pairs -= 4, a_pairs += 4, b_pairs += 4;
118
- }
119
- // Reduce horizontal sums and aggregate with the tail:
120
- nk_f32c_t tail_result;
121
- nk_dot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
122
- result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
123
- result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
124
- }
125
-
126
- NK_PUBLIC void nk_vdot_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
127
- nk_f32c_t *result) {
128
- float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
129
- float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
130
- while (count_pairs >= 4) {
131
- // Unpack the input arrays into real and imaginary parts.
132
- // MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed
133
- // integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards.
134
- int16x4x2_t a_i16x4x2 = vld2_s16((short *)a_pairs);
135
- int16x4x2_t b_i16x4x2 = vld2_s16((short *)b_pairs);
136
- float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
137
- float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
138
- float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
139
- float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
140
- sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
141
- sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
142
- sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
143
- sum_imag_f32x4 = vfmsq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
144
- count_pairs -= 4, a_pairs += 4, b_pairs += 4;
145
- }
146
- // Reduce horizontal sums and aggregate with the tail:
147
- nk_f32c_t tail_result;
148
- nk_vdot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
149
- result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
150
- result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
151
- }
152
-
153
- /**
154
- * @brief Running state for 64-bit dot accumulation over f16 scalars on NEON with FP16 extension.
155
- *
156
- * Each update consumes 4 f16 values (64 bits) per operand and converts them directly to f32,
157
- * avoiding the vget_low/vget_high split that 128-bit (8 x f16) inputs would require.
158
- */
159
- typedef struct nk_dot_f16x4_state_neonhalf_t {
160
- float32x4_t sum_f32x4; // Four f32 partial sums; zeroed by init, horizontally added by finalize.
161
- } nk_dot_f16x4_state_neonhalf_t;
162
-
163
- NK_INTERNAL void nk_dot_f16x4_init_neonhalf(nk_dot_f16x4_state_neonhalf_t *state) { state->sum_f32x4 = vdupq_n_f32(0); }
164
-
165
- NK_INTERNAL void nk_dot_f16x4_update_neonhalf(nk_dot_f16x4_state_neonhalf_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
166
- nk_size_t depth_offset, nk_size_t active_dimensions) {
167
- nk_unused_(depth_offset);
168
- nk_unused_(active_dimensions);
169
- // 4 f16s = 64 bits, direct conversion without low/high split
170
- float16x4_t a_f16x4 = vreinterpret_f16_u16(a.u16x4);
171
- float16x4_t b_f16x4 = vreinterpret_f16_u16(b.u16x4);
172
- state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, vcvt_f32_f16(a_f16x4), vcvt_f32_f16(b_f16x4));
173
- }
174
-
175
- NK_INTERNAL void nk_dot_f16x4_finalize_neonhalf( //
176
- nk_dot_f16x4_state_neonhalf_t const *state_a, nk_dot_f16x4_state_neonhalf_t const *state_b, //
177
- nk_dot_f16x4_state_neonhalf_t const *state_c, nk_dot_f16x4_state_neonhalf_t const *state_d, //
178
- nk_size_t total_dimensions, nk_b128_vec_t *result) {
179
- nk_unused_(total_dimensions);
180
- result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
181
- result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
182
- result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
183
- result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
184
- }
185
-
186
- #if defined(__clang__)
187
- #pragma clang attribute pop
188
- #elif defined(__GNUC__)
189
- #pragma GCC pop_options
190
- #endif
191
-
192
- #if defined(__cplusplus)
193
- } // extern "C"
194
- #endif
195
-
196
- #endif // NK_TARGET_NEONHALF
197
- #endif // NK_TARGET_ARM_
198
- #endif // NK_DOT_NEONHALF_H
@@ -1,57 +0,0 @@
1
- /**
2
- * @brief SIMD-accelerated Batched Dot Products for NEON FP16.
3
- * @file include/numkong/dots/neonhalf.h
4
- * @author Ash Vardanian
5
- * @date December 27, 2025
6
- *
7
- * @sa include/numkong/dots.h
8
- */
9
- #ifndef NK_DOTS_NEONHALF_H
10
- #define NK_DOTS_NEONHALF_H
11
-
12
- #if NK_TARGET_ARM_
13
- #if NK_TARGET_NEONHALF
14
-
15
- #include "numkong/dot/neonhalf.h"
16
-
17
- #if defined(__cplusplus)
18
- extern "C" {
19
- #endif
20
-
21
- #if defined(__clang__)
22
- #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
23
- #elif defined(__GNUC__)
24
- #pragma GCC push_options
25
- #pragma GCC target("arch=armv8.2-a+simd+fp16")
26
- #endif
27
-
28
- /* F16 GEMM: depth_simd_dimensions=4 (4 f16s = 8 bytes = 64-bit input for direct f32 conversion) */
29
- nk_define_cross_pack_size_(dots, f16, neonhalf, f16, f16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/4,
30
- /*dimensions_per_value=*/1)
31
- nk_define_cross_pack_(dots, f16, neonhalf, f16, f16, nk_assign_from_to_, /*norm_value_type=*/f32,
32
- nk_dots_reduce_sumsq_f16_, /*depth_simd_dimensions=*/4,
33
- /*dimensions_per_value=*/1)
34
- nk_define_cross_symmetric_(dots, f16, neonhalf, f16, f32, nk_b64_vec_t, nk_dot_f16x4_state_neonhalf_t, nk_b128_vec_t,
35
- nk_dot_f16x4_init_neonhalf, nk_load_b64_neon_, nk_partial_load_b16x4_serial_,
36
- nk_dot_f16x4_update_neonhalf, nk_dot_f16x4_finalize_neonhalf, nk_store_b128_neon_,
37
- nk_partial_store_b32x4_serial_,
38
- /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
39
- nk_define_cross_packed_(dots, f16, neonhalf, f16, f16, f32, nk_b64_vec_t, nk_dot_f16x4_state_neonhalf_t, nk_b128_vec_t,
40
- nk_dot_f16x4_init_neonhalf, nk_load_b64_neon_, nk_partial_load_b16x4_serial_, nk_load_b64_neon_,
41
- nk_partial_load_b16x4_serial_, nk_dot_f16x4_update_neonhalf, nk_dot_f16x4_finalize_neonhalf,
42
- nk_store_b128_neon_, nk_partial_store_b32x4_serial_,
43
- /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
44
-
45
- #if defined(__clang__)
46
- #pragma clang attribute pop
47
- #elif defined(__GNUC__)
48
- #pragma GCC pop_options
49
- #endif
50
-
51
- #if defined(__cplusplus)
52
- } // extern "C"
53
- #endif
54
-
55
- #endif // NK_TARGET_NEONHALF
56
- #endif // NK_TARGET_ARM_
57
- #endif // NK_DOTS_NEONHALF_H