numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,379 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for SVE.
3
+ * @file include/numkong/dot/sve.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_sve_instructions ARM SVE Instructions
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * svld1_f32 LD1W (Z.S, P/Z, [Xn]) 4-6cy 2/cy
13
+ * svld2_f32 LD2W (Z.S, P/Z, [Xn]) 6-8cy 1/cy
14
+ * svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy 2/cy
15
+ * svmls_f32_x FMLS (Z.S, P/M, Z.S, Z.S) 4cy 2/cy
16
+ * svaddv_f32 FADDV (S, P, Z.S) 6cy 1/cy
17
+ * svdup_f32 DUP (Z.S, #imm) 1cy 2/cy
18
+ * svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy 1/cy
19
+ * svptrue_b32 PTRUE (P.S, pattern) 1cy 2/cy
20
+ * svcntw CNTW (Xd) 1cy 2/cy
21
+ * svcntd CNTD (Xd) 1cy 2/cy
22
+ * svld1_f64 LD1D (Z.D, P/Z, [Xn]) 4-6cy 2/cy
23
+ * svld2_f64 LD2D (Z.D, P/Z, [Xn]) 6-8cy 1/cy
24
+ * svmla_f64_x FMLA (Z.D, P/M, Z.D, Z.D) 4cy 2/cy
25
+ * svmls_f64_x FMLS (Z.D, P/M, Z.D, Z.D) 4cy 2/cy
26
+ * svaddv_f64 FADDV (D, P, Z.D) 6cy 1/cy
27
+ *
28
+ * SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
29
+ * and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
30
+ * process more elements per iteration with identical latencies.
31
+ *
32
+ * The FADDV horizontal reduction has higher latency (6cy) compared to vertical operations,
33
+ * making it beneficial to accumulate in vector registers and reduce only at the end.
34
+ */
35
+ #ifndef NK_DOT_SVE_H
36
+ #define NK_DOT_SVE_H
37
+
38
+ #if NK_TARGET_ARM_
39
+ #if NK_TARGET_SVE
40
+
41
+ #include "numkong/types.h" // `nk_f32_t`
42
+ #include "numkong/dot/serial.h" // `nk_u1x8_popcount_`
43
+
44
+ #if defined(__cplusplus)
45
+ extern "C" {
46
+ #endif
47
+
48
+ #if defined(__clang__)
49
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function)
50
+ #elif defined(__GNUC__)
51
+ #pragma GCC push_options
52
+ #pragma GCC target("arch=armv8.2-a+sve")
53
+ #endif
54
+
55
/** @brief Compensated horizontal sum of SVE f64 lanes via TwoSum tree reduction.
 *
 *  Merges a running @p sum vector with its @p compensation vector, then folds the
 *  vector down to one scalar while capturing the rounding error of every addition,
 *  so the result is substantially more accurate than a plain `svaddv_f64`.
 *
 *  Uses svtbl to extract the upper half at each tree level. Out-of-range indices
 *  return 0 (SVE spec), which is harmless since only the lower half is meaningful
 *  after each halving stage.
 *
 *  @param predicate Governing predicate; callers pass `svptrue_b64()`.
 *  @param sum Per-lane partial sums.
 *  @param compensation Per-lane accumulated error terms for @p sum.
 *  @return sum-of-lanes of (sum + compensation), error-compensated.
 */
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64_sve_(svbool_t predicate, svfloat64_t sum, svfloat64_t compensation) {
    // Stage 0: TwoSum merge of sum + compensation (parallel across all active lanes).
    // Knuth TwoSum: t = a + b; v = t - a; err = (a - (t - v)) + (b - v).
    svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate, sum, compensation);
    svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate, tentative_sum_f64x, sum);
    svfloat64_t accumulated_error_f64x = svadd_f64_x(
        predicate, svsub_f64_x(predicate, sum, svsub_f64_x(predicate, tentative_sum_f64x, virtual_addend_f64x)),
        svsub_f64_x(predicate, compensation, virtual_addend_f64x));

    // Tree reduction: TwoSum halving at each level, log2(VL) iterations.
    // Each pass adds lane (k + half) into lane k, carrying the rounding error forward.
    for (unsigned int half = (unsigned int)svcntd() / 2; half > 0; half >>= 1) {
        // Indices k + half; beyond the vector length svtbl yields zeros, which only
        // lands in lanes above `half` that the next (smaller) pass no longer reads.
        svuint64_t upper_indices_u64x = svadd_n_u64_x(predicate, svindex_u64(0, 1), half);
        svfloat64_t upper_sum_f64x = svtbl_f64(tentative_sum_f64x, upper_indices_u64x);
        svfloat64_t upper_error_f64x = svtbl_f64(accumulated_error_f64x, upper_indices_u64x);
        // TwoSum: lower_half + upper_half
        svfloat64_t halved_tentative_sum_f64x = svadd_f64_x(predicate, tentative_sum_f64x, upper_sum_f64x);
        svfloat64_t halved_virtual_addend_f64x = svsub_f64_x(predicate, halved_tentative_sum_f64x, tentative_sum_f64x);
        svfloat64_t rounding_error_f64x = svadd_f64_x(
            predicate,
            svsub_f64_x(predicate, tentative_sum_f64x,
                        svsub_f64_x(predicate, halved_tentative_sum_f64x, halved_virtual_addend_f64x)),
            svsub_f64_x(predicate, upper_sum_f64x, halved_virtual_addend_f64x));
        tentative_sum_f64x = halved_tentative_sum_f64x;
        // Errors of both halves plus the fresh TwoSum error all travel in the error vector.
        accumulated_error_f64x = svadd_f64_x(
            predicate, svadd_f64_x(predicate, accumulated_error_f64x, upper_error_f64x), rounding_error_f64x);
    }
    // Result is in lane 0; a single-lane predicate + svlastb extracts it.
    svbool_t predicate_first_f64x = svwhilelt_b64_u64(0u, 1);
    return svlastb_f64(predicate_first_f64x, tentative_sum_f64x) +
           svlastb_f64(predicate_first_f64x, accumulated_error_f64x);
}
91
+
92
/** @brief Dot product of two f32 vectors, accumulated in f64 for extra precision.
 *
 *  @param a_scalars First input vector.
 *  @param b_scalars Second input vector.
 *  @param count_scalars Number of scalars in each input.
 *  @param result Output: sum of element-wise products, widened to f64.
 */
NK_PUBLIC void nk_dot_f32_sve(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
                              nk_f64_t *result) {
    // Each f32 is loaded zero-extended into the bottom of a 64-bit container with
    // `svld1uw_u64` (contiguous LD1W into .D elements), which is exactly where the
    // predicated FCVT (Z.D <- Z.S) reads its source, so every iteration consumes
    // svcntd() *consecutive* scalars. The previous packed `svld1_f32` + FCVT form
    // was incorrect: FCVT converts only the even-indexed f32 lanes of a packed
    // vector, so under the svcntd() stride odd elements were skipped and even
    // elements re-read. (The cast assumes nk_u32_t is a 32-bit unsigned type, as
    // `svld1uw_u64` requires.)
    nk_size_t idx_scalars = 0;
    nk_size_t const vector_length = svcntd();
    svfloat64_t ab_f64x = svdup_f64(0.);
    for (; idx_scalars < count_scalars; idx_scalars += vector_length) {
        svbool_t predicate_f64x = svwhilelt_b64_u64(idx_scalars, count_scalars);
        svfloat64_t a_f64x = svcvt_f64_f32_x(
            predicate_f64x,
            svreinterpret_f32_u64(svld1uw_u64(predicate_f64x, (nk_u32_t const *)(a_scalars + idx_scalars))));
        svfloat64_t b_f64x = svcvt_f64_f32_x(
            predicate_f64x,
            svreinterpret_f32_u64(svld1uw_u64(predicate_f64x, (nk_u32_t const *)(b_scalars + idx_scalars))));
        ab_f64x = svmla_f64_x(predicate_f64x, ab_f64x, a_f64x, b_f64x);
    }
    *result = svaddv_f64(svptrue_b64(), ab_f64x);
}
107
+
108
/** @brief Complex dot product over interleaved f32 pairs, accumulated in f64.
 *
 *  real = sum(a.re*b.re - a.im*b.im), imag = sum(a.re*b.im + a.im*b.re).
 *
 *  @param a_pairs First input vector of interleaved (real, imag) pairs.
 *  @param b_pairs Second input vector of interleaved (real, imag) pairs.
 *  @param count_pairs Number of complex pairs in each input.
 *  @param results Output: complex dot product in f64.
 */
NK_PUBLIC void nk_dot_f32c_sve(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
                               nk_f64c_t *results) {
    nk_size_t idx_pairs = 0;
    nk_size_t const vector_length = svcntd(); // Pairs consumed per iteration == f64 lanes
    svfloat64_t ab_real_f64x = svdup_f64(0.);
    svfloat64_t ab_imag_f64x = svdup_f64(0.);
    for (; idx_pairs < count_pairs; idx_pairs += vector_length) {
        svbool_t predicate_f64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
        svbool_t predicate_f32x = svwhilelt_b32_u64(idx_pairs, count_pairs);
        // LD2W deinterleaves (real, imag); more pairs may load than svcntd(), the extras are ignored.
        svfloat32x2_t a_f32x2 = svld2_f32(predicate_f32x, (nk_f32_t const *)(a_pairs + idx_pairs));
        svfloat32x2_t b_f32x2 = svld2_f32(predicate_f32x, (nk_f32_t const *)(b_pairs + idx_pairs));
        svfloat32_t a_real_f32x = svget2_f32(a_f32x2, 0);
        svfloat32_t a_imag_f32x = svget2_f32(a_f32x2, 1);
        svfloat32_t b_real_f32x = svget2_f32(b_f32x2, 0);
        svfloat32_t b_imag_f32x = svget2_f32(b_f32x2, 1);
        // Predicated FCVT (Z.D <- Z.S) converts only the even-indexed f32 lanes of its
        // source. Self-zipping with ZIP1 places f32 lane k into lane 2k, so the first
        // svcntd() deinterleaved values land exactly where FCVT reads them. Feeding the
        // packed `svget2` vectors to FCVT directly (as before) converted lanes 0, 2, 4, ...
        // and, with the svcntd() stride, skipped odd pairs while re-reading even ones.
        svfloat64_t a_real_f64x = svcvt_f64_f32_x(predicate_f64x, svzip1_f32(a_real_f32x, a_real_f32x));
        svfloat64_t a_imag_f64x = svcvt_f64_f32_x(predicate_f64x, svzip1_f32(a_imag_f32x, a_imag_f32x));
        svfloat64_t b_real_f64x = svcvt_f64_f32_x(predicate_f64x, svzip1_f32(b_real_f32x, b_real_f32x));
        svfloat64_t b_imag_f64x = svcvt_f64_f32_x(predicate_f64x, svzip1_f32(b_imag_f32x, b_imag_f32x));
        ab_real_f64x = svmla_f64_x(predicate_f64x, ab_real_f64x, a_real_f64x, b_real_f64x);
        ab_real_f64x = svmls_f64_x(predicate_f64x, ab_real_f64x, a_imag_f64x, b_imag_f64x);
        ab_imag_f64x = svmla_f64_x(predicate_f64x, ab_imag_f64x, a_real_f64x, b_imag_f64x);
        ab_imag_f64x = svmla_f64_x(predicate_f64x, ab_imag_f64x, a_imag_f64x, b_real_f64x);
    }
    results->real = svaddv_f64(svptrue_b64(), ab_real_f64x);
    results->imag = svaddv_f64(svptrue_b64(), ab_imag_f64x);
}
131
+
132
/** @brief Conjugate (Hermitian) complex dot product over interleaved f32 pairs, accumulated in f64.
 *
 *  Computes conj(a) . b: real = sum(a.re*b.re + a.im*b.im),
 *  imag = sum(a.re*b.im - a.im*b.re).
 *
 *  @param a_pairs First input vector of interleaved (real, imag) pairs (conjugated).
 *  @param b_pairs Second input vector of interleaved (real, imag) pairs.
 *  @param count_pairs Number of complex pairs in each input.
 *  @param results Output: complex dot product in f64.
 */
NK_PUBLIC void nk_vdot_f32c_sve(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
                                nk_f64c_t *results) {
    nk_size_t idx_pairs = 0;
    nk_size_t const vector_length = svcntd(); // Pairs consumed per iteration == f64 lanes
    svfloat64_t ab_real_f64x = svdup_f64(0.);
    svfloat64_t ab_imag_f64x = svdup_f64(0.);
    for (; idx_pairs < count_pairs; idx_pairs += vector_length) {
        svbool_t predicate_f64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
        svbool_t predicate_f32x = svwhilelt_b32_u64(idx_pairs, count_pairs);
        // LD2W deinterleaves (real, imag); more pairs may load than svcntd(), the extras are ignored.
        svfloat32x2_t a_f32x2 = svld2_f32(predicate_f32x, (nk_f32_t const *)(a_pairs + idx_pairs));
        svfloat32x2_t b_f32x2 = svld2_f32(predicate_f32x, (nk_f32_t const *)(b_pairs + idx_pairs));
        svfloat32_t a_real_f32x = svget2_f32(a_f32x2, 0);
        svfloat32_t a_imag_f32x = svget2_f32(a_f32x2, 1);
        svfloat32_t b_real_f32x = svget2_f32(b_f32x2, 0);
        svfloat32_t b_imag_f32x = svget2_f32(b_f32x2, 1);
        // Predicated FCVT (Z.D <- Z.S) converts only the even-indexed f32 lanes of its
        // source. Self-zipping with ZIP1 places f32 lane k into lane 2k, so the first
        // svcntd() deinterleaved values land exactly where FCVT reads them. Feeding the
        // packed `svget2` vectors to FCVT directly (as before) converted lanes 0, 2, 4, ...
        // and, with the svcntd() stride, skipped odd pairs while re-reading even ones.
        svfloat64_t a_real_f64x = svcvt_f64_f32_x(predicate_f64x, svzip1_f32(a_real_f32x, a_real_f32x));
        svfloat64_t a_imag_f64x = svcvt_f64_f32_x(predicate_f64x, svzip1_f32(a_imag_f32x, a_imag_f32x));
        svfloat64_t b_real_f64x = svcvt_f64_f32_x(predicate_f64x, svzip1_f32(b_real_f32x, b_real_f32x));
        svfloat64_t b_imag_f64x = svcvt_f64_f32_x(predicate_f64x, svzip1_f32(b_imag_f32x, b_imag_f32x));
        ab_real_f64x = svmla_f64_x(predicate_f64x, ab_real_f64x, a_real_f64x, b_real_f64x);
        ab_real_f64x = svmla_f64_x(predicate_f64x, ab_real_f64x, a_imag_f64x, b_imag_f64x);
        ab_imag_f64x = svmla_f64_x(predicate_f64x, ab_imag_f64x, a_real_f64x, b_imag_f64x);
        ab_imag_f64x = svmls_f64_x(predicate_f64x, ab_imag_f64x, a_imag_f64x, b_real_f64x);
    }
    results->real = svaddv_f64(svptrue_b64(), ab_real_f64x);
    results->imag = svaddv_f64(svptrue_b64(), ab_imag_f64x);
}
155
+
156
/** @brief Compensated (Dot2, Ogita-Rump-Oishi) dot product of two f64 vectors.
 *
 *  Tracks the rounding error of every product (TwoProd via FMA) and every partial
 *  sum (TwoSum), merging the compensation at the end with
 *  `nk_dot_stable_sum_f64_sve_` for near twice-working-precision accuracy.
 *
 *  @param a_scalars First input vector.
 *  @param b_scalars Second input vector.
 *  @param count_scalars Number of scalars in each input.
 *  @param result Output: compensated dot product.
 *
 *  NOTE(review): a zero count still runs one all-inactive iteration; the
 *  predicated loads touch no memory, but `_x` arithmetic leaves inactive lanes
 *  unspecified - confirm `count_scalars > 0` is a documented precondition.
 */
NK_PUBLIC void nk_dot_f64_sve(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                              nk_f64_t *result) {
    // Dot2 (Ogita-Rump-Oishi) compensated accumulation via TwoProd + TwoSum
    nk_size_t idx_scalars = 0;
    svfloat64_t sum_f64x = svdup_f64(0.);
    svfloat64_t compensation_f64x = svdup_f64(0.);
    do {
        svbool_t predicate_f64x = svwhilelt_b64_u64(idx_scalars, count_scalars);
        svfloat64_t a_f64x = svld1_f64(predicate_f64x, a_scalars + idx_scalars);
        svfloat64_t b_f64x = svld1_f64(predicate_f64x, b_scalars + idx_scalars);
        // TwoProd: FNMLS computes -(product - a*b) == a*b - product in one fused op,
        // i.e. the exact rounding error of the product. The previous
        // `svneg(svnmls(...))` form flipped this term's sign, *subtracting* the
        // product errors from the compensation instead of adding them.
        svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_f64x, b_f64x);
        svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_f64x, b_f64x);
        // TwoSum: tentative_sum = sum + product; recover the addition's rounding error
        svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_f64x, product_f64x);
        svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_f64x);
        svfloat64_t sum_error_f64x = svadd_f64_x(
            predicate_f64x,
            svsub_f64_x(predicate_f64x, sum_f64x, svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
            svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
        sum_f64x = tentative_sum_f64x;
        // Both error streams accumulate in the compensation vector
        compensation_f64x = svadd_f64_x(predicate_f64x, compensation_f64x,
                                        svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        idx_scalars += svcntd();
    } while (idx_scalars < count_scalars);
    *result = nk_dot_stable_sum_f64_sve_(svptrue_b64(), sum_f64x, compensation_f64x);
}
184
+
185
/** @brief Compensated (Dot2) complex dot product of two interleaved f64 vectors.
 *
 *  real = sum(a.re*b.re - a.im*b.im), imag = sum(a.re*b.im + a.im*b.re).
 *  Every partial product and partial sum carries a TwoProd / TwoSum error term;
 *  the four accumulators are merged by `nk_dot_stable_sum_f64_sve_` at the end.
 *
 *  @param a_pairs First input vector of interleaved (real, imag) pairs.
 *  @param b_pairs Second input vector of interleaved (real, imag) pairs.
 *  @param count_pairs Number of complex pairs in each input.
 *  @param results Output: compensated complex dot product.
 *
 *  NOTE(review): a zero count still runs one all-inactive iteration; the
 *  predicated loads touch no memory, but `_x` arithmetic leaves inactive lanes
 *  unspecified - confirm `count_pairs > 0` is a documented precondition.
 */
NK_PUBLIC void nk_dot_f64c_sve(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                               nk_f64c_t *results) {
    // Dot2 compensated accumulation for complex dot product: (a_real + i*a_imag)(b_real + i*b_imag)
    // real = a_real*b_real - a_imag*b_imag, imag = a_real*b_imag + a_imag*b_real
    //
    // In each TwoProd below, FNMLS computes -(product - x*y) == x*y - product: the
    // exact rounding error of the product. The previous `svneg(svnmls(...))` form
    // flipped that sign, subtracting product errors from the compensation.
    nk_size_t idx_pairs = 0;
    svfloat64_t sum_real_f64x = svdup_f64(0.);
    svfloat64_t comp_real_f64x = svdup_f64(0.);
    svfloat64_t sum_imag_f64x = svdup_f64(0.);
    svfloat64_t comp_imag_f64x = svdup_f64(0.);
    do {
        svbool_t predicate_f64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
        svfloat64x2_t a_f64x2 = svld2_f64(predicate_f64x, (nk_f64_t const *)(a_pairs + idx_pairs));
        svfloat64x2_t b_f64x2 = svld2_f64(predicate_f64x, (nk_f64_t const *)(b_pairs + idx_pairs));
        svfloat64_t a_real_f64x = svget2_f64(a_f64x2, 0);
        svfloat64_t a_imag_f64x = svget2_f64(a_f64x2, 1);
        svfloat64_t b_real_f64x = svget2_f64(b_f64x2, 0);
        svfloat64_t b_imag_f64x = svget2_f64(b_f64x2, 1);

        // TwoProd + TwoSum for real part: sum_real += a_real*b_real
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_real_f64x, b_real_f64x);
            svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_real_f64x, b_real_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_real_f64x, product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_real_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_real_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
            sum_real_f64x = tentative_sum_f64x;
            comp_real_f64x = svadd_f64_x(predicate_f64x, comp_real_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        }
        // TwoProd + TwoSum for real part: sum_real -= a_imag*b_imag
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_imag_f64x, b_imag_f64x);
            svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_imag_f64x, b_imag_f64x);
            // Negating both the product and its error gives the exact TwoProd of -(a_imag*b_imag)
            svfloat64_t neg_product_f64x = svneg_f64_x(predicate_f64x, product_f64x);
            svfloat64_t neg_product_error_f64x = svneg_f64_x(predicate_f64x, product_error_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_real_f64x, neg_product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_real_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_real_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, neg_product_f64x, virtual_addend_f64x));
            sum_real_f64x = tentative_sum_f64x;
            comp_real_f64x = svadd_f64_x(predicate_f64x, comp_real_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, neg_product_error_f64x));
        }
        // TwoProd + TwoSum for imaginary part: sum_imag += a_real*b_imag
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_real_f64x, b_imag_f64x);
            svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_real_f64x, b_imag_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_imag_f64x, product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_imag_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_imag_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
            sum_imag_f64x = tentative_sum_f64x;
            comp_imag_f64x = svadd_f64_x(predicate_f64x, comp_imag_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        }
        // TwoProd + TwoSum for imaginary part: sum_imag += a_imag*b_real
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_imag_f64x, b_real_f64x);
            svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_imag_f64x, b_real_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_imag_f64x, product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_imag_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_imag_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
            sum_imag_f64x = tentative_sum_f64x;
            comp_imag_f64x = svadd_f64_x(predicate_f64x, comp_imag_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        }
        idx_pairs += svcntd();
    } while (idx_pairs < count_pairs);
    svbool_t predicate_all_f64x = svptrue_b64();
    results->real = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_real_f64x, comp_real_f64x);
    results->imag = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_imag_f64x, comp_imag_f64x);
}
275
+
276
NK_PUBLIC void nk_vdot_f64c_sve(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                nk_f64c_t *results) {
    // Dot2 compensated conjugate dot product: conj(a) · b = (a_real - i*a_imag)(b_real + i*b_imag)
    // real = a_real*b_real + a_imag*b_imag, imag = a_real*b_imag - a_imag*b_real
    //
    // Every product is split into a rounded value plus its exact FMA remainder (TwoProd),
    // and every accumulation into a rounded sum plus its exact remainder (Knuth TwoSum).
    // All remainders are gathered in `comp_*` with a consistent "+error" convention and
    // folded back in by `nk_dot_stable_sum_f64_sve_`. Note: FNMLS (`svnmls`) computes
    // `a*b - product` fused, which is already the TwoProd remainder — negating it (as an
    // earlier revision did) flips the sign and corrupts the compensation.
    nk_size_t idx_pairs = 0;
    svfloat64_t sum_real_f64x = svdup_f64(0.);
    svfloat64_t comp_real_f64x = svdup_f64(0.); // accumulated rounding errors of the real part
    svfloat64_t sum_imag_f64x = svdup_f64(0.);
    svfloat64_t comp_imag_f64x = svdup_f64(0.); // accumulated rounding errors of the imaginary part
    do {
        // Predicated tail handling: inactive lanes load zero and contribute nothing.
        svbool_t predicate_f64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
        // De-interleave (real, imag) pairs into two separate vectors.
        svfloat64x2_t a_f64x2 = svld2_f64(predicate_f64x, (nk_f64_t const *)(a_pairs + idx_pairs));
        svfloat64x2_t b_f64x2 = svld2_f64(predicate_f64x, (nk_f64_t const *)(b_pairs + idx_pairs));
        svfloat64_t a_real_f64x = svget2_f64(a_f64x2, 0);
        svfloat64_t a_imag_f64x = svget2_f64(a_f64x2, 1);
        svfloat64_t b_real_f64x = svget2_f64(b_f64x2, 0);
        svfloat64_t b_imag_f64x = svget2_f64(b_f64x2, 1);

        // TwoProd + TwoSum for real part: sum_real += a_real*b_real
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_real_f64x, b_real_f64x);
            // Exact remainder of the rounded product: a*b - product, computed fused.
            svfloat64_t product_error_f64x =
                svnmls_f64_x(predicate_f64x, product_f64x, a_real_f64x, b_real_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_real_f64x, product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_real_f64x);
            // Knuth TwoSum remainder: sum + product == tentative + sum_error, exactly.
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_real_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
            sum_real_f64x = tentative_sum_f64x;
            comp_real_f64x = svadd_f64_x(predicate_f64x, comp_real_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        }
        // TwoProd + TwoSum for real part: sum_real += a_imag*b_imag (conjugate: + not -)
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_imag_f64x, b_imag_f64x);
            svfloat64_t product_error_f64x =
                svnmls_f64_x(predicate_f64x, product_f64x, a_imag_f64x, b_imag_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_real_f64x, product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_real_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_real_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
            sum_real_f64x = tentative_sum_f64x;
            comp_real_f64x = svadd_f64_x(predicate_f64x, comp_real_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        }
        // TwoProd + TwoSum for imaginary part: sum_imag += a_real*b_imag
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_real_f64x, b_imag_f64x);
            svfloat64_t product_error_f64x =
                svnmls_f64_x(predicate_f64x, product_f64x, a_real_f64x, b_imag_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_imag_f64x, product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_imag_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_imag_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
            sum_imag_f64x = tentative_sum_f64x;
            comp_imag_f64x = svadd_f64_x(predicate_f64x, comp_imag_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        }
        // TwoProd + TwoSum for imaginary part: sum_imag -= a_imag*b_real (conjugate: - not +)
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_imag_f64x, b_real_f64x);
            svfloat64_t product_error_f64x =
                svnmls_f64_x(predicate_f64x, product_f64x, a_imag_f64x, b_real_f64x);
            // Subtracting the product means both it and its remainder enter negated.
            svfloat64_t neg_product_f64x = svneg_f64_x(predicate_f64x, product_f64x);
            svfloat64_t neg_product_error_f64x = svneg_f64_x(predicate_f64x, product_error_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_imag_f64x, neg_product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_imag_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_imag_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, neg_product_f64x, virtual_addend_f64x));
            sum_imag_f64x = tentative_sum_f64x;
            comp_imag_f64x = svadd_f64_x(predicate_f64x, comp_imag_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, neg_product_error_f64x));
        }
        idx_pairs += svcntd();
    } while (idx_pairs < count_pairs);
    svbool_t predicate_all_f64x = svptrue_b64();
    results->real = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_real_f64x, comp_real_f64x);
    results->imag = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_imag_f64x, comp_imag_f64x);
}
366
+
367
+ #if defined(__clang__)
368
+ #pragma clang attribute pop
369
+ #elif defined(__GNUC__)
370
+ #pragma GCC pop_options
371
+ #endif
372
+
373
+ #if defined(__cplusplus)
374
+ } // extern "C"
375
+ #endif
376
+
377
+ #endif // NK_TARGET_SVE
378
+ #endif // NK_TARGET_ARM_
379
+ #endif // NK_DOT_SVE_H
@@ -0,0 +1,74 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for SVE BF16.
3
+ * @file include/numkong/dot/svebfdot.h
4
+ * @author Ash Vardanian
5
+ * @date March 16, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_svebfdot_instructions ARM SVE+BF16 Instructions
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * svld1_bf16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
13
+ * svbfdot_f32 BFDOT (Z.S, Z.H, Z.H) 4cy 2/cy
14
+ * svaddv_f32 FADDV (S, P, Z.S) 6cy 1/cy
15
+ * svdup_f32 DUP (Z.S, #imm) 1cy 2/cy
16
+ * svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy 1/cy
17
+ * svcnth CNTH (Xd) 1cy 2/cy
18
+ *
19
+ * SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
20
+ * and Apple M4+ use 128-bit. Code using svcnth() adapts automatically, but wider vectors
21
+ * process more elements per iteration with identical latencies.
22
+ *
23
+ * The BFDOT instruction fuses two BF16 multiplications with FP32 accumulation per lane,
24
+ * providing 4x the throughput of convert-then-FMA sequences. Each BFDOT processes
25
+ * pairs of BF16 values, accumulating directly into FP32 without explicit conversion.
26
+ */
27
+ #ifndef NK_DOT_SVEBFDOT_H
28
+ #define NK_DOT_SVEBFDOT_H
29
+
30
+ #if NK_TARGET_ARM_
31
+ #if NK_TARGET_SVEBFDOT
32
+
33
+ #include "numkong/types.h"
34
+
35
+ #if defined(__cplusplus)
36
+ extern "C" {
37
+ #endif
38
+
39
+ #if defined(__clang__)
40
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+bf16"))), apply_to = function)
41
+ #elif defined(__GNUC__)
42
+ #pragma GCC push_options
43
+ #pragma GCC target("arch=armv8.2-a+sve+bf16")
44
+ #endif
45
+
46
NK_PUBLIC void nk_dot_bf16_svebfdot(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
                                    nk_f32_t *result) {
    // BF16 dot product: BFDOT multiplies pairs of bf16 lanes and accumulates each
    // pair directly into a fp32 lane, so no explicit up-conversion is needed.
    nk_bf16_for_arm_simd_t const *a_simd = (nk_bf16_for_arm_simd_t const *)(a_scalars);
    nk_bf16_for_arm_simd_t const *b_simd = (nk_bf16_for_arm_simd_t const *)(b_scalars);
    svfloat32_t acc_f32x = svdup_f32(0);
    nk_size_t progress = 0;
    do {
        // Predicated tail: lanes past `count_scalars` load zero and add nothing.
        svbool_t active_bf16x = svwhilelt_b16_u64(progress, count_scalars);
        svbfloat16_t lhs_bf16x = svld1_bf16(active_bf16x, a_simd + progress);
        svbfloat16_t rhs_bf16x = svld1_bf16(active_bf16x, b_simd + progress);
        acc_f32x = svbfdot_f32(acc_f32x, lhs_bf16x, rhs_bf16x);
        progress += svcnth(); // `svcnth()` bf16 scalars consumed per iteration
    } while (progress < count_scalars);
    // Horizontal reduction of the per-lane fp32 partial sums.
    *result = svaddv_f32(svptrue_b32(), acc_f32x);
}
61
+
62
+ #if defined(__clang__)
63
+ #pragma clang attribute pop
64
+ #elif defined(__GNUC__)
65
+ #pragma GCC pop_options
66
+ #endif
67
+
68
+ #if defined(__cplusplus)
69
+ } // extern "C"
70
+ #endif
71
+
72
+ #endif // NK_TARGET_SVEBFDOT
73
+ #endif // NK_TARGET_ARM_
74
+ #endif // NK_DOT_SVEBFDOT_H
@@ -0,0 +1,123 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for SVE FP16.
3
+ * @file include/numkong/dot/svehalf.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_svehalf_instructions ARM SVE+FP16 Instructions
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * svld1_f16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
13
+ * svld2_f16 LD2H (Z.H, P/Z, [Xn]) 6-8cy 1/cy
14
+ * svmla_f16_x FMLA (Z.H, P/M, Z.H, Z.H) 4cy 2/cy
15
+ * svmls_f16_x FMLS (Z.H, P/M, Z.H, Z.H) 4cy 2/cy
16
+ * svaddv_f16 FADDV (H, P, Z.H) 6cy 1/cy
17
+ * svdup_f16 DUP (Z.H, #imm) 1cy 2/cy
18
+ * svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy 1/cy
19
+ * svptrue_b16 PTRUE (P.H, pattern) 1cy 2/cy
20
+ * svcnth CNTH (Xd) 1cy 2/cy
21
+ *
22
+ * SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
23
+ * and Apple M4+ use 128-bit. Code using svcntw() adapts automatically, but wider vectors
24
+ * process more elements per iteration with identical latencies.
25
+ *
26
+ * FP16 operations double the element count per vector compared to FP32, providing higher
27
+ * throughput at the cost of reduced precision. The FADDV reduction remains the bottleneck.
28
+ */
29
+ #ifndef NK_DOT_SVEHALF_H
30
+ #define NK_DOT_SVEHALF_H
31
+
32
+ #if NK_TARGET_ARM_
33
+ #if NK_TARGET_SVEHALF
34
+
35
+ #include "numkong/types.h" // `nk_f16_t`
36
+ #include "numkong/dot/serial.h" // `nk_u1x8_popcount_`
37
+
38
+ #if defined(__cplusplus)
39
+ extern "C" {
40
+ #endif
41
+
42
+ #if defined(__clang__)
43
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+fp16"))), apply_to = function)
44
+ #elif defined(__GNUC__)
45
+ #pragma GCC push_options
46
+ #pragma GCC target("arch=armv8.2-a+sve+fp16")
47
+ #endif
48
+
49
NK_PUBLIC void nk_dot_f16_svehalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
                                  nk_f32_t *result) {
    // Upcasting f16 dot product, accumulating in f32 to limit rounding error.
    // Each iteration consumes `svcntw()` scalars. The widening LD1H (`svld1uh_u32`)
    // places consecutive 16-bit values into the bottom halfword of each 32-bit lane —
    // exactly where FCVT (f16 -> f32) reads its inputs. A plain `svld1_f16` under a
    // B32 `whilelt` predicate would activate only the even halfword lanes and thus
    // silently skip every odd-indexed scalar in memory.
    nk_size_t idx_scalars = 0;
    svfloat32_t ab_f32x = svdup_f32(0);
    do {
        svbool_t predicate_f32x = svwhilelt_b32_u64(idx_scalars, count_scalars);
        svuint32_t a_u32x = svld1uh_u32(predicate_f32x, (nk_u16_t const *)(a_scalars) + idx_scalars);
        svuint32_t b_u32x = svld1uh_u32(predicate_f32x, (nk_u16_t const *)(b_scalars) + idx_scalars);
        svfloat32_t a_f32x = svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(a_u32x));
        svfloat32_t b_f32x = svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(b_u32x));
        ab_f32x = svmla_f32_x(predicate_f32x, ab_f32x, a_f32x, b_f32x);
        idx_scalars += svcntw();
    } while (idx_scalars < count_scalars);
    *result = svaddv_f32(svptrue_b32(), ab_f32x);
}
64
+
65
NK_PUBLIC void nk_dot_f16c_svehalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
                                   nk_f32c_t *results) {
    // Upcasting complex f16 dot product: real = Σ (ar*br - ai*bi), imag = Σ (ar*bi + ai*br).
    // Each (real, imag) f16 pair is one 32-bit word — on little-endian ARM the real part
    // lands in the low halfword, the imaginary part in the high halfword — so a single
    // predicated LD1W fetches `svcntw()` whole pairs per iteration. The previous LD2H
    // under a B32 `whilelt` predicate activated only even halfword lanes, skipping every
    // other complex pair in memory.
    nk_size_t idx_pairs = 0;
    svfloat32_t ab_real_f32x = svdup_f32(0);
    svfloat32_t ab_imag_f32x = svdup_f32(0);
    do {
        svbool_t predicate_f32x = svwhilelt_b32_u64(idx_pairs, count_pairs);
        svuint32_t a_u32x = svld1_u32(predicate_f32x, (nk_u32_t const *)(a_pairs + idx_pairs));
        svuint32_t b_u32x = svld1_u32(predicate_f32x, (nk_u32_t const *)(b_pairs + idx_pairs));
        // Real parts already sit in the even halfword lanes FCVT consumes;
        // a 16-bit logical right shift exposes the imaginary parts there too.
        svfloat32_t a_real_f32x = svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(a_u32x));
        svfloat32_t a_imag_f32x =
            svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(svlsr_n_u32_x(predicate_f32x, a_u32x, 16)));
        svfloat32_t b_real_f32x = svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(b_u32x));
        svfloat32_t b_imag_f32x =
            svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(svlsr_n_u32_x(predicate_f32x, b_u32x, 16)));
        ab_real_f32x = svmla_f32_x(predicate_f32x, ab_real_f32x, a_real_f32x, b_real_f32x);
        ab_real_f32x = svmls_f32_x(predicate_f32x, ab_real_f32x, a_imag_f32x, b_imag_f32x);
        ab_imag_f32x = svmla_f32_x(predicate_f32x, ab_imag_f32x, a_real_f32x, b_imag_f32x);
        ab_imag_f32x = svmla_f32_x(predicate_f32x, ab_imag_f32x, a_imag_f32x, b_real_f32x);
        idx_pairs += svcntw();
    } while (idx_pairs < count_pairs);
    results->real = svaddv_f32(svptrue_b32(), ab_real_f32x);
    results->imag = svaddv_f32(svptrue_b32(), ab_imag_f32x);
}
87
+
88
NK_PUBLIC void nk_vdot_f16c_svehalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
                                    nk_f32c_t *results) {
    // Upcasting conjugate complex f16 dot product, conj(a) · b:
    // real = Σ (ar*br + ai*bi), imag = Σ (ar*bi - ai*br).
    // Each (real, imag) f16 pair is one 32-bit word — on little-endian ARM the real part
    // lands in the low halfword, the imaginary part in the high halfword — so a single
    // predicated LD1W fetches `svcntw()` whole pairs per iteration. The previous LD2H
    // under a B32 `whilelt` predicate activated only even halfword lanes, skipping every
    // other complex pair in memory.
    nk_size_t idx_pairs = 0;
    svfloat32_t ab_real_f32x = svdup_f32(0);
    svfloat32_t ab_imag_f32x = svdup_f32(0);
    do {
        svbool_t predicate_f32x = svwhilelt_b32_u64(idx_pairs, count_pairs);
        svuint32_t a_u32x = svld1_u32(predicate_f32x, (nk_u32_t const *)(a_pairs + idx_pairs));
        svuint32_t b_u32x = svld1_u32(predicate_f32x, (nk_u32_t const *)(b_pairs + idx_pairs));
        // Real parts already sit in the even halfword lanes FCVT consumes;
        // a 16-bit logical right shift exposes the imaginary parts there too.
        svfloat32_t a_real_f32x = svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(a_u32x));
        svfloat32_t a_imag_f32x =
            svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(svlsr_n_u32_x(predicate_f32x, a_u32x, 16)));
        svfloat32_t b_real_f32x = svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(b_u32x));
        svfloat32_t b_imag_f32x =
            svcvt_f32_f16_x(predicate_f32x, svreinterpret_f16_u32(svlsr_n_u32_x(predicate_f32x, b_u32x, 16)));
        ab_real_f32x = svmla_f32_x(predicate_f32x, ab_real_f32x, a_real_f32x, b_real_f32x);
        ab_real_f32x = svmla_f32_x(predicate_f32x, ab_real_f32x, a_imag_f32x, b_imag_f32x);
        ab_imag_f32x = svmla_f32_x(predicate_f32x, ab_imag_f32x, a_real_f32x, b_imag_f32x);
        ab_imag_f32x = svmls_f32_x(predicate_f32x, ab_imag_f32x, a_imag_f32x, b_real_f32x);
        idx_pairs += svcntw();
    } while (idx_pairs < count_pairs);
    results->real = svaddv_f32(svptrue_b32(), ab_real_f32x);
    results->imag = svaddv_f32(svptrue_b32(), ab_imag_f32x);
}
110
+
111
+ #if defined(__clang__)
112
+ #pragma clang attribute pop
113
+ #elif defined(__GNUC__)
114
+ #pragma GCC pop_options
115
+ #endif
116
+
117
+ #if defined(__cplusplus)
118
+ } // extern "C"
119
+ #endif
120
+
121
+ #endif // NK_TARGET_SVEHALF
122
+ #endif // NK_TARGET_ARM_
123
+ #endif // NK_DOT_SVEHALF_H