numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,244 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for NEON BF16.
3
+ * @file include/numkong/dot/neonbfdot.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * A76 M4+/V1+/Oryon
13
+ * vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy 2/cy 4/cy
14
+ * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy 2/cy 4/cy
15
+ * vld1q_bf16 LD1 (V.8H) 4cy 2/cy 3/cy
16
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
17
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
18
+ * vfmsq_f32 FMLS (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
19
+ *
20
+ * The ARMv8.6-BF16 extension provides the BFDOT instruction for accelerated BF16 dot products,
21
+ * targeting machine learning inference workloads. BF16 trades mantissa precision (7 bits vs 10 in
22
+ * FP16) for a larger exponent range matching FP32, eliminating overflow concerns during training.
23
+ *
24
+ * BFDOT computes two BF16 dot products per lane, accumulating directly into FP32 without explicit
25
+ * conversion. This provides higher throughput than FP16 convert-then-FMA sequences for ML inference
26
+ * where the reduced precision is acceptable.
27
+ *
28
+ * @section dot_neonbfdot_stateful Stateful Streaming Logic
29
+ *
30
+ * To build memory-optimal tiled algorithms, this file defines following structures and force-inlined
31
+ * `NK_INTERNAL` functions:
32
+ *
33
+ * - nk_dot_bf16x8 state with native BFDOT bf16 dot-products.
34
+ *
35
+ * @code{c}
36
+ * nk_dot_bf16x8_state_neonbfdot_t state_first, state_second, state_third, state_fourth;
37
+ * bfloat16x8_t query_bf16x8, target_first_bf16x8, target_second_bf16x8, target_third_bf16x8, target_fourth_bf16x8;
38
+ * nk_dot_bf16x8_init_neonbfdot(&state_first);
39
+ * nk_dot_bf16x8_init_neonbfdot(&state_second);
40
+ * nk_dot_bf16x8_init_neonbfdot(&state_third);
41
+ * nk_dot_bf16x8_init_neonbfdot(&state_fourth);
42
+ * for (nk_size_t idx = 0; idx + 8 <= depth; idx += 8) {
43
+ * query_bf16x8 = vld1q_bf16(query_ptr + idx);
44
+ * target_first_bf16x8 = vld1q_bf16(target_first_ptr + idx);
45
+ * target_second_bf16x8 = vld1q_bf16(target_second_ptr + idx);
46
+ * target_third_bf16x8 = vld1q_bf16(target_third_ptr + idx);
47
+ * target_fourth_bf16x8 = vld1q_bf16(target_fourth_ptr + idx);
48
+ * nk_dot_bf16x8_update_neonbfdot(&state_first, query_bf16x8, target_first_bf16x8, idx, 8);
49
+ * nk_dot_bf16x8_update_neonbfdot(&state_second, query_bf16x8, target_second_bf16x8, idx, 8);
50
+ * nk_dot_bf16x8_update_neonbfdot(&state_third, query_bf16x8, target_third_bf16x8, idx, 8);
51
+ * nk_dot_bf16x8_update_neonbfdot(&state_fourth, query_bf16x8, target_fourth_bf16x8, idx, 8);
52
+ * }
53
+ * float32x4_t results_f32x4;
54
+ * nk_dot_bf16x8_finalize_neonbfdot(&state_first, &state_second, &state_third, &state_fourth, depth, &results_f32x4);
55
+ * @endcode
56
+ */
57
+ #ifndef NK_DOT_NEONBFDOT_H
58
+ #define NK_DOT_NEONBFDOT_H
59
+
60
+ #if NK_TARGET_ARM_
61
+ #if NK_TARGET_NEONBFDOT
62
+
63
+ #include "numkong/types.h"
64
+ #include "numkong/cast/serial.h" // `nk_partial_load_b8x8_serial_`
65
+ #include "numkong/cast/neon.h" // `nk_e4m3x8_to_bf16x8_neon_`
66
+
67
+ #if defined(__cplusplus)
68
+ extern "C" {
69
+ #endif
70
+
71
+ #if defined(__clang__)
72
+ #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function)
73
+ #elif defined(__GNUC__)
74
+ #pragma GCC push_options
75
+ #pragma GCC target("arch=armv8.6-a+simd+bf16")
76
+ #endif
77
+
78
/**
 *  @brief Dot product of two bf16 vectors, accumulated in f32 via the BFDOT instruction.
 *  @param a_scalars First input vector of bf16 scalars.
 *  @param b_scalars Second input vector of bf16 scalars.
 *  @param count_scalars Number of scalars in each input vector.
 *  @param result Output scalar receiving the f32 dot product.
 */
NK_PUBLIC void nk_dot_bf16_neonbfdot(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
                                     nk_f32_t *result) {
    float32x4_t accumulator_f32x4 = vdupq_n_f32(0);
    bfloat16x8_t lhs_bf16x8, rhs_bf16x8;
    do {
        if (count_scalars < 8) {
            // Tail: zero-pad through a serial partial load so a full BFDOT can still run.
            nk_b128_vec_t lhs_vec, rhs_vec;
            nk_partial_load_b16x8_serial_(a_scalars, &lhs_vec, count_scalars);
            nk_partial_load_b16x8_serial_(b_scalars, &rhs_vec, count_scalars);
            lhs_bf16x8 = vreinterpretq_bf16_u16(lhs_vec.u16x8);
            rhs_bf16x8 = vreinterpretq_bf16_u16(rhs_vec.u16x8);
            count_scalars = 0;
        }
        else {
            lhs_bf16x8 = vld1q_bf16((nk_bf16_for_arm_simd_t const *)a_scalars);
            rhs_bf16x8 = vld1q_bf16((nk_bf16_for_arm_simd_t const *)b_scalars);
            a_scalars += 8, b_scalars += 8, count_scalars -= 8;
        }
        // BFDOT accumulates pairwise bf16 products directly into f32 lanes.
        accumulator_f32x4 = vbfdotq_f32(accumulator_f32x4, lhs_bf16x8, rhs_bf16x8);
    } while (count_scalars);
    *result = vaddvq_f32(accumulator_f32x4);
}
100
+
101
/**
 *  @brief Dot product of two complex bf16 vectors, accumulated in f32.
 *
 *  Computes sum of (a_r + i*a_i) * (b_r + i*b_i), so:
 *  real += a_r*b_r - a_i*b_i, imag += a_r*b_i + a_i*b_r.
 *
 *  @param a_pairs First input vector of (real, imag) bf16 pairs.
 *  @param b_pairs Second input vector of (real, imag) bf16 pairs.
 *  @param count_pairs Number of complex pairs in each input vector.
 *  @param result Output complex f32 scalar.
 */
NK_PUBLIC void nk_dot_bf16c_neonbfdot(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
                                      nk_f32c_t *result) {
    float32x4_t acc_real_f32x4 = vdupq_n_f32(0);
    float32x4_t acc_imag_f32x4 = vdupq_n_f32(0);
    for (; count_pairs >= 4; count_pairs -= 4, a_pairs += 4, b_pairs += 4) {
        // De-interleave real/imaginary parts. MSVC doesn't provide `vld2_bf16`,
        // so load as same-width signed integers and `vreinterpret_bf16_s16` after.
        int16x4x2_t lhs_i16x4x2 = vld2_s16((short const *)a_pairs);
        int16x4x2_t rhs_i16x4x2 = vld2_s16((short const *)b_pairs);
        float32x4_t lhs_real_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(lhs_i16x4x2.val[0]));
        float32x4_t lhs_imag_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(lhs_i16x4x2.val[1]));
        float32x4_t rhs_real_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(rhs_i16x4x2.val[0]));
        float32x4_t rhs_imag_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(rhs_i16x4x2.val[1]));
        // real: + a_r*b_r - a_i*b_i
        acc_real_f32x4 = vfmaq_f32(acc_real_f32x4, lhs_real_f32x4, rhs_real_f32x4);
        acc_real_f32x4 = vfmsq_f32(acc_real_f32x4, lhs_imag_f32x4, rhs_imag_f32x4);
        // imag: + a_r*b_i + a_i*b_r
        acc_imag_f32x4 = vfmaq_f32(acc_imag_f32x4, lhs_real_f32x4, rhs_imag_f32x4);
        acc_imag_f32x4 = vfmaq_f32(acc_imag_f32x4, lhs_imag_f32x4, rhs_real_f32x4);
    }
    // Handle the remaining (< 4) pairs serially, then merge with horizontal sums:
    nk_f32c_t tail_result;
    nk_dot_bf16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
    result->real = tail_result.real + vaddvq_f32(acc_real_f32x4);
    result->imag = tail_result.imag + vaddvq_f32(acc_imag_f32x4);
}
127
+
128
/**
 *  @brief Conjugated dot product of two complex bf16 vectors, accumulated in f32.
 *
 *  The FMA/FMS placement implements sum of conj(a) * b:
 *  real += a_r*b_r + a_i*b_i, imag += a_r*b_i - a_i*b_r.
 *  Note the sign flip relative to `nk_dot_bf16c_neonbfdot`.
 *
 *  @param a_pairs First input vector of (real, imag) bf16 pairs (conjugated operand).
 *  @param b_pairs Second input vector of (real, imag) bf16 pairs.
 *  @param count_pairs Number of complex pairs in each input vector.
 *  @param result Output complex f32 scalar.
 */
NK_PUBLIC void nk_vdot_bf16c_neonbfdot(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
                                       nk_f32c_t *result) {
    float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
    float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
    while (count_pairs >= 4) {
        // Unpack the input arrays into real and imaginary parts.
        // MSVC sadly doesn't recognize the `vld2_bf16`, so we load the data as signed
        // integers of the same size and reinterpret with `vreinterpret_bf16_s16` afterwards.
        int16x4x2_t a_i16x4x2 = vld2_s16((short const *)a_pairs);
        int16x4x2_t b_i16x4x2 = vld2_s16((short const *)b_pairs);
        float32x4_t a_real_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(a_i16x4x2.val[0]));
        float32x4_t a_imag_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(a_i16x4x2.val[1]));
        float32x4_t b_real_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(b_i16x4x2.val[0]));
        float32x4_t b_imag_f32x4 = vcvt_f32_bf16(vreinterpret_bf16_s16(b_i16x4x2.val[1]));
        // real: + a_r*b_r + a_i*b_i (conjugation turns the subtraction into addition)
        sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
        sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
        // imag: + a_r*b_i - a_i*b_r
        sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
        sum_imag_f32x4 = vfmsq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
        count_pairs -= 4, a_pairs += 4, b_pairs += 4;
    }
    // Reduce horizontal sums and aggregate with the tail:
    nk_f32c_t tail_result;
    nk_vdot_bf16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
    result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
    result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
}
154
+
155
/**
 *  @brief Dot product of two e4m3 (FP8) vectors, upcast to bf16 and accumulated in f32 via BFDOT.
 *  @param a_scalars First input vector of e4m3 scalars.
 *  @param b_scalars Second input vector of e4m3 scalars.
 *  @param count_scalars Number of scalars in each input vector.
 *  @param result Output scalar receiving the f32 dot product.
 */
NK_PUBLIC void nk_dot_e4m3_neonbfdot(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
                                     nk_f32_t *result) {
    bfloat16x8_t a_bf16x8, b_bf16x8;
    float32x4_t sum_f32x4 = vdupq_n_f32(0);
nk_dot_e4m3_neonbfdot_cycle:
    if (count_scalars < 8) {
        // Tail: zero-pad 8-bit scalars through a serial partial load before widening.
        nk_b64_vec_t a_vec, b_vec;
        nk_partial_load_b8x8_serial_(a_scalars, &a_vec, count_scalars);
        nk_partial_load_b8x8_serial_(b_scalars, &b_vec, count_scalars);
        a_bf16x8 = vreinterpretq_bf16_u16(nk_e4m3x8_to_bf16x8_neon_(a_vec.u8x8));
        b_bf16x8 = vreinterpretq_bf16_u16(nk_e4m3x8_to_bf16x8_neon_(b_vec.u8x8));
        count_scalars = 0;
    }
    else {
        // NOTE(review): `vld1_u8` is passed `nk_e4m3_t const *` without a cast —
        // presumably `nk_e4m3_t` is a `uint8_t` typedef; confirm against types.h.
        a_bf16x8 = vreinterpretq_bf16_u16(nk_e4m3x8_to_bf16x8_neon_(vld1_u8(a_scalars)));
        b_bf16x8 = vreinterpretq_bf16_u16(nk_e4m3x8_to_bf16x8_neon_(vld1_u8(b_scalars)));
        a_scalars += 8, b_scalars += 8, count_scalars -= 8;
    }
    sum_f32x4 = vbfdotq_f32(sum_f32x4, a_bf16x8, b_bf16x8);
    if (count_scalars) goto nk_dot_e4m3_neonbfdot_cycle;
    *result = vaddvq_f32(sum_f32x4);
}
177
+
178
/**
 *  @brief Dot product of two e5m2 (FP8) vectors, upcast to bf16 and accumulated in f32 via BFDOT.
 *  @param a_scalars First input vector of e5m2 scalars.
 *  @param b_scalars Second input vector of e5m2 scalars.
 *  @param count_scalars Number of scalars in each input vector.
 *  @param result Output scalar receiving the f32 dot product.
 */
NK_PUBLIC void nk_dot_e5m2_neonbfdot(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
                                     nk_f32_t *result) {
    float32x4_t accumulator_f32x4 = vdupq_n_f32(0);
    bfloat16x8_t lhs_bf16x8, rhs_bf16x8;
    do {
        if (count_scalars < 8) {
            // Tail: zero-pad 8-bit scalars through a serial partial load before widening.
            nk_b64_vec_t lhs_vec, rhs_vec;
            nk_partial_load_b8x8_serial_(a_scalars, &lhs_vec, count_scalars);
            nk_partial_load_b8x8_serial_(b_scalars, &rhs_vec, count_scalars);
            lhs_bf16x8 = vreinterpretq_bf16_u16(nk_e5m2x8_to_bf16x8_neon_(lhs_vec.u8x8));
            rhs_bf16x8 = vreinterpretq_bf16_u16(nk_e5m2x8_to_bf16x8_neon_(rhs_vec.u8x8));
            count_scalars = 0;
        }
        else {
            lhs_bf16x8 = vreinterpretq_bf16_u16(nk_e5m2x8_to_bf16x8_neon_(vld1_u8(a_scalars)));
            rhs_bf16x8 = vreinterpretq_bf16_u16(nk_e5m2x8_to_bf16x8_neon_(vld1_u8(b_scalars)));
            a_scalars += 8, b_scalars += 8, count_scalars -= 8;
        }
        accumulator_f32x4 = vbfdotq_f32(accumulator_f32x4, lhs_bf16x8, rhs_bf16x8);
    } while (count_scalars);
    *result = vaddvq_f32(accumulator_f32x4);
}
200
+
201
/**
 *  @brief Running state for 128-bit dot accumulation over bf16 scalars on NEON.
 *
 *  Holds a single f32x4 partial-sum register fed by BFDOT; initialize with
 *  `nk_dot_bf16x8_init_neonbfdot`, stream with `nk_dot_bf16x8_update_neonbfdot`,
 *  and reduce with `nk_dot_bf16x8_finalize_neonbfdot`.
 */
typedef struct nk_dot_bf16x8_state_neonbfdot_t {
    float32x4_t sum_f32x4; // four independent f32 partial sums, reduced at finalize
} nk_dot_bf16x8_state_neonbfdot_t;
207
+
208
/** @brief Resets the streaming bf16 dot-product state to a zero accumulator. */
NK_INTERNAL void nk_dot_bf16x8_init_neonbfdot(nk_dot_bf16x8_state_neonbfdot_t *state) {
    // `vmovq_n_f32` is the NEON broadcast alias for `vdupq_n_f32` — same DUP instruction.
    state->sum_f32x4 = vmovq_n_f32(0.0f);
}
211
+
212
/**
 *  @brief Accumulates one 8-element bf16 chunk of each operand into the running state.
 *  @param state Streaming accumulator to update.
 *  @param a,b 128-bit chunks reinterpreted as bf16x8 lanes.
 *  @param depth_offset,active_dimensions Unused here; kept for a uniform streaming API.
 */
NK_INTERNAL void nk_dot_bf16x8_update_neonbfdot(nk_dot_bf16x8_state_neonbfdot_t *state, nk_b128_vec_t a,
                                                nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    bfloat16x8_t lhs_bf16x8 = vreinterpretq_bf16_u16(a.u16x8);
    bfloat16x8_t rhs_bf16x8 = vreinterpretq_bf16_u16(b.u16x8);
    state->sum_f32x4 = vbfdotq_f32(state->sum_f32x4, lhs_bf16x8, rhs_bf16x8);
}
220
+
221
/**
 *  @brief Horizontally reduces four streaming states into four f32 results.
 *  @param state_a,state_b,state_c,state_d Accumulators for four target vectors.
 *  @param total_dimensions Unused here; kept for a uniform streaming API.
 *  @param result Receives one f32 dot product per state in `f32s[0..3]`.
 */
NK_INTERNAL void nk_dot_bf16x8_finalize_neonbfdot( //
    nk_dot_bf16x8_state_neonbfdot_t const *state_a, nk_dot_bf16x8_state_neonbfdot_t const *state_b, //
    nk_dot_bf16x8_state_neonbfdot_t const *state_c, nk_dot_bf16x8_state_neonbfdot_t const *state_d, //
    nk_size_t total_dimensions, nk_b128_vec_t *result) {
    nk_unused_(total_dimensions);
    nk_dot_bf16x8_state_neonbfdot_t const *states[4] = {state_a, state_b, state_c, state_d};
    for (int lane = 0; lane < 4; ++lane) result->f32s[lane] = vaddvq_f32(states[lane]->sum_f32x4);
}
231
+
232
+ #if defined(__clang__)
233
+ #pragma clang attribute pop
234
+ #elif defined(__GNUC__)
235
+ #pragma GCC pop_options
236
+ #endif
237
+
238
+ #if defined(__cplusplus)
239
+ } // extern "C"
240
+ #endif
241
+
242
+ #endif // NK_TARGET_NEONBFDOT
243
+ #endif // NK_TARGET_ARM_
244
+ #endif // NK_DOT_NEONBFDOT_H
@@ -0,0 +1,360 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for NEON FHM.
3
+ * @file include/numkong/dot/neonfhm.h
4
+ * @author Ash Vardanian
5
+ * @date December 28, 2025
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_neonfhm_instructions ARM NEON FP16 Matrix Instructions (ARMv8.4-FHM)
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * A76 M4+/V1+/Oryon
13
+ * vfmlalq_low_f16 FMLAL (V.4S, V.8H, V.8H) 4cy 2/cy 4/cy
14
+ * vfmlalq_high_f16 FMLAL2 (V.4S, V.8H, V.8H) 4cy 2/cy 4/cy
15
+ * vfmlslq_low_f16 FMLSL (V.4S, V.8H, V.8H) 4cy 2/cy 4/cy
16
+ * vfmlslq_high_f16 FMLSL2 (V.4S, V.8H, V.8H) 4cy 2/cy 4/cy
17
+ * vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
18
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
19
+ *
20
+ * The ARMv8.4-FHM extension (FEAT_FHM) provides FMLAL/FMLSL instructions that fuse FP16 to FP32
21
+ * widening with multiply-accumulate in a single operation. FMLAL executes as a single fused op
22
+ * (4cy latency, 2/cy throughput on A76, 4/cy on M4+/V1+/Oryon) rather than separate FCVTL + FMLA.
23
+ *
24
+ * FMLAL preserves FP32 accumulator precision while accepting FP16 inputs, ideal for mixed-precision
25
+ * workloads. The _low variants process elements 0-3, _high variants process elements 4-7, enabling
26
+ * processing of 8 FP16 elements per iteration with full precision accumulation.
27
+ *
28
+ * @section dot_neonfhm_stateful Stateful Streaming Logic
29
+ *
30
+ * To build memory-optimal tiled algorithms, this file defines following structures and force-inlined
31
+ * `NK_INTERNAL` functions:
32
+ *
33
+ * - nk_dot_f16x8 state with native FMLAL f16 dot-products.
34
+ *
35
+ * @code{c}
36
+ * nk_dot_f16x8_state_neonfhm_t state_first, state_second, state_third, state_fourth;
37
+ * float16x8_t query_f16x8, target_first_f16x8, target_second_f16x8, target_third_f16x8, target_fourth_f16x8;
38
+ * nk_dot_f16x8_init_neonfhm(&state_first);
39
+ * nk_dot_f16x8_init_neonfhm(&state_second);
40
+ * nk_dot_f16x8_init_neonfhm(&state_third);
41
+ * nk_dot_f16x8_init_neonfhm(&state_fourth);
42
+ * for (nk_size_t idx = 0; idx + 8 <= depth; idx += 8) {
43
+ * query_f16x8 = vld1q_f16(query_ptr + idx);
44
+ * target_first_f16x8 = vld1q_f16(target_first_ptr + idx);
45
+ * target_second_f16x8 = vld1q_f16(target_second_ptr + idx);
46
+ * target_third_f16x8 = vld1q_f16(target_third_ptr + idx);
47
+ * target_fourth_f16x8 = vld1q_f16(target_fourth_ptr + idx);
48
+ * nk_dot_f16x8_update_neonfhm(&state_first, query_f16x8, target_first_f16x8, idx, 8);
49
+ * nk_dot_f16x8_update_neonfhm(&state_second, query_f16x8, target_second_f16x8, idx, 8);
50
+ * nk_dot_f16x8_update_neonfhm(&state_third, query_f16x8, target_third_f16x8, idx, 8);
51
+ * nk_dot_f16x8_update_neonfhm(&state_fourth, query_f16x8, target_fourth_f16x8, idx, 8);
52
+ * }
53
+ * float32x4_t results_f32x4;
54
+ * nk_dot_f16x8_finalize_neonfhm(&state_first, &state_second, &state_third, &state_fourth, depth, &results_f32x4);
55
+ * @endcode
56
+ *
57
+ */
58
+ #ifndef NK_DOT_NEONFHM_H
59
+ #define NK_DOT_NEONFHM_H
60
+
61
+ #if NK_TARGET_ARM_
62
+ #if NK_TARGET_NEONFHM
63
+
64
+ #include "numkong/types.h"
65
+ #include "numkong/cast/serial.h" // `nk_partial_load_b8x8_serial_`
66
+ #include "numkong/cast/neon.h" // `nk_e4m3x8_to_f16x8_neon_`
67
+
68
+ #if defined(__cplusplus)
69
+ extern "C" {
70
+ #endif
71
+
72
+ #if defined(__clang__)
73
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16+fp16fml"))), apply_to = function)
74
+ #elif defined(__GNUC__)
75
+ #pragma GCC push_options
76
+ #pragma GCC target("arch=armv8.2-a+simd+fp16+fp16fml")
77
+ #endif
78
+
79
+ NK_PUBLIC void nk_dot_f16_neonfhm(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
80
+ nk_f32_t *result) {
81
+ float16x8_t a_f16x8, b_f16x8;
82
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
83
+ nk_dot_f16_neonfhm_cycle:
84
+ if (count_scalars < 8) {
85
+ nk_b128_vec_t a_vec, b_vec;
86
+ nk_partial_load_b16x8_serial_(a_scalars, &a_vec, count_scalars);
87
+ nk_partial_load_b16x8_serial_(b_scalars, &b_vec, count_scalars);
88
+ a_f16x8 = vreinterpretq_f16_u16(a_vec.u16x8);
89
+ b_f16x8 = vreinterpretq_f16_u16(b_vec.u16x8);
90
+ count_scalars = 0;
91
+ }
92
+ else {
93
+ a_f16x8 = vld1q_f16((nk_f16_for_arm_simd_t const *)(a_scalars));
94
+ b_f16x8 = vld1q_f16((nk_f16_for_arm_simd_t const *)(b_scalars));
95
+ a_scalars += 8, b_scalars += 8, count_scalars -= 8;
96
+ }
97
+ // FMLAL: widening multiply-accumulate fp16 → f32
98
+ // low: processes elements 0-3, high: processes elements 4-7
99
+ sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_f16x8, b_f16x8);
100
+ sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_f16x8, b_f16x8);
101
+ if (count_scalars) goto nk_dot_f16_neonfhm_cycle;
102
+ *result = vaddvq_f32(sum_f32x4);
103
+ }
104
+
105
+ typedef struct nk_dot_f16x8_state_neonfhm_t {
106
+ float32x4_t sum_f32x4;
107
+ } nk_dot_f16x8_state_neonfhm_t;
108
+
109
+ NK_INTERNAL void nk_dot_f16x8_init_neonfhm(nk_dot_f16x8_state_neonfhm_t *state) { state->sum_f32x4 = vdupq_n_f32(0); }
110
+
111
+ NK_INTERNAL void nk_dot_f16x8_update_neonfhm(nk_dot_f16x8_state_neonfhm_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
112
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
113
+ nk_unused_(depth_offset);
114
+ nk_unused_(active_dimensions);
115
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(a.u16x8);
116
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(b.u16x8);
117
+ // FMLAL: widening multiply-accumulate fp16 → f32
118
+ state->sum_f32x4 = vfmlalq_low_f16(state->sum_f32x4, a_f16x8, b_f16x8);
119
+ state->sum_f32x4 = vfmlalq_high_f16(state->sum_f32x4, a_f16x8, b_f16x8);
120
+ }
121
+
122
+ NK_INTERNAL void nk_dot_f16x8_finalize_neonfhm( //
123
+ nk_dot_f16x8_state_neonfhm_t const *state_a, nk_dot_f16x8_state_neonfhm_t const *state_b, //
124
+ nk_dot_f16x8_state_neonfhm_t const *state_c, nk_dot_f16x8_state_neonfhm_t const *state_d, //
125
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
126
+ nk_unused_(total_dimensions);
127
+ result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
128
+ result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
129
+ result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
130
+ result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
131
+ }
132
+
133
+ NK_PUBLIC void nk_dot_f16c_neonfhm(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
134
+ nk_f32c_t *result) {
135
+ // Accumulate into 4 float32x2_t vectors (low/high for real/imag)
136
+ float32x2_t sum_real_low_f32x2 = vdup_n_f32(0);
137
+ float32x2_t sum_real_high_f32x2 = vdup_n_f32(0);
138
+ float32x2_t sum_imag_low_f32x2 = vdup_n_f32(0);
139
+ float32x2_t sum_imag_high_f32x2 = vdup_n_f32(0);
140
+
141
+ while (count_pairs >= 4) {
142
+ // Load and deinterleave: vld2 loads 4 complex pairs as 2 x float16x4_t
143
+ int16x4x2_t a_i16x4x2 = vld2_s16((short const *)a_pairs);
144
+ int16x4x2_t b_i16x4x2 = vld2_s16((short const *)b_pairs);
145
+
146
+ float16x4_t a_real_f16x4 = vreinterpret_f16_s16(a_i16x4x2.val[0]);
147
+ float16x4_t a_imag_f16x4 = vreinterpret_f16_s16(a_i16x4x2.val[1]);
148
+ float16x4_t b_real_f16x4 = vreinterpret_f16_s16(b_i16x4x2.val[0]);
149
+ float16x4_t b_imag_f16x4 = vreinterpret_f16_s16(b_i16x4x2.val[1]);
150
+
151
+ // Real: aᵣ × bᵣ - aᵢ × bᵢ (FMLAL then FMLSL)
152
+ sum_real_low_f32x2 = vfmlal_low_f16(sum_real_low_f32x2, a_real_f16x4, b_real_f16x4);
153
+ sum_real_low_f32x2 = vfmlsl_low_f16(sum_real_low_f32x2, a_imag_f16x4, b_imag_f16x4);
154
+ sum_real_high_f32x2 = vfmlal_high_f16(sum_real_high_f32x2, a_real_f16x4, b_real_f16x4);
155
+ sum_real_high_f32x2 = vfmlsl_high_f16(sum_real_high_f32x2, a_imag_f16x4, b_imag_f16x4);
156
+
157
+ // Imag: aᵣ × bᵢ + aᵢ × bᵣ (FMLAL for both)
158
+ sum_imag_low_f32x2 = vfmlal_low_f16(sum_imag_low_f32x2, a_real_f16x4, b_imag_f16x4);
159
+ sum_imag_low_f32x2 = vfmlal_low_f16(sum_imag_low_f32x2, a_imag_f16x4, b_real_f16x4);
160
+ sum_imag_high_f32x2 = vfmlal_high_f16(sum_imag_high_f32x2, a_real_f16x4, b_imag_f16x4);
161
+ sum_imag_high_f32x2 = vfmlal_high_f16(sum_imag_high_f32x2, a_imag_f16x4, b_real_f16x4);
162
+
163
+ count_pairs -= 4, a_pairs += 4, b_pairs += 4;
164
+ }
165
+
166
+ // Combine and reduce
167
+ float32x4_t sum_real_f32x4 = vcombine_f32(sum_real_low_f32x2, sum_real_high_f32x2);
168
+ float32x4_t sum_imag_f32x4 = vcombine_f32(sum_imag_low_f32x2, sum_imag_high_f32x2);
169
+
170
+ // Handle tail with serial fallback
171
+ nk_f32c_t tail_result;
172
+ nk_dot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
173
+ result->real = vaddvq_f32(sum_real_f32x4) + tail_result.real;
174
+ result->imag = vaddvq_f32(sum_imag_f32x4) + tail_result.imag;
175
+ }
176
+
177
+ NK_PUBLIC void nk_vdot_f16c_neonfhm(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
178
+ nk_f32c_t *result) {
179
+ // Accumulate into 4 float32x2_t vectors (low/high for real/imag)
180
+ float32x2_t sum_real_low_f32x2 = vdup_n_f32(0);
181
+ float32x2_t sum_real_high_f32x2 = vdup_n_f32(0);
182
+ float32x2_t sum_imag_low_f32x2 = vdup_n_f32(0);
183
+ float32x2_t sum_imag_high_f32x2 = vdup_n_f32(0);
184
+
185
+ while (count_pairs >= 4) {
186
+ // Load and deinterleave: vld2 loads 4 complex pairs as 2 x float16x4_t
187
+ int16x4x2_t a_i16x4x2 = vld2_s16((short const *)a_pairs);
188
+ int16x4x2_t b_i16x4x2 = vld2_s16((short const *)b_pairs);
189
+
190
+ float16x4_t a_real_f16x4 = vreinterpret_f16_s16(a_i16x4x2.val[0]);
191
+ float16x4_t a_imag_f16x4 = vreinterpret_f16_s16(a_i16x4x2.val[1]);
192
+ float16x4_t b_real_f16x4 = vreinterpret_f16_s16(b_i16x4x2.val[0]);
193
+ float16x4_t b_imag_f16x4 = vreinterpret_f16_s16(b_i16x4x2.val[1]);
194
+
195
+ // Real: aᵣ × bᵣ + aᵢ × bᵢ (FMLAL for both)
196
+ sum_real_low_f32x2 = vfmlal_low_f16(sum_real_low_f32x2, a_real_f16x4, b_real_f16x4);
197
+ sum_real_low_f32x2 = vfmlal_low_f16(sum_real_low_f32x2, a_imag_f16x4, b_imag_f16x4);
198
+ sum_real_high_f32x2 = vfmlal_high_f16(sum_real_high_f32x2, a_real_f16x4, b_real_f16x4);
199
+ sum_real_high_f32x2 = vfmlal_high_f16(sum_real_high_f32x2, a_imag_f16x4, b_imag_f16x4);
200
+
201
+ // Imag: aᵣ × bᵢ - aᵢ × bᵣ (FMLAL then FMLSL)
202
+ sum_imag_low_f32x2 = vfmlal_low_f16(sum_imag_low_f32x2, a_real_f16x4, b_imag_f16x4);
203
+ sum_imag_low_f32x2 = vfmlsl_low_f16(sum_imag_low_f32x2, a_imag_f16x4, b_real_f16x4);
204
+ sum_imag_high_f32x2 = vfmlal_high_f16(sum_imag_high_f32x2, a_real_f16x4, b_imag_f16x4);
205
+ sum_imag_high_f32x2 = vfmlsl_high_f16(sum_imag_high_f32x2, a_imag_f16x4, b_real_f16x4);
206
+
207
+ count_pairs -= 4, a_pairs += 4, b_pairs += 4;
208
+ }
209
+
210
+ // Combine and reduce
211
+ float32x4_t sum_real_f32x4 = vcombine_f32(sum_real_low_f32x2, sum_real_high_f32x2);
212
+ float32x4_t sum_imag_f32x4 = vcombine_f32(sum_imag_low_f32x2, sum_imag_high_f32x2);
213
+
214
+ // Handle tail with serial fallback
215
+ nk_f32c_t tail_result;
216
+ nk_vdot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
217
+ result->real = vaddvq_f32(sum_real_f32x4) + tail_result.real;
218
+ result->imag = vaddvq_f32(sum_imag_f32x4) + tail_result.imag;
219
+ }
220
+
221
+ NK_PUBLIC void nk_dot_e4m3_neonfhm(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
222
+ nk_f32_t *result) {
223
+ float16x8_t a_low, a_high, b_low, b_high;
224
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
225
+ nk_dot_e4m3_neonfhm_cycle:
226
+ if (count_scalars < 16) {
227
+ nk_b128_vec_t a_vec, b_vec;
228
+ nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
229
+ nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
230
+ nk_e4m3x16_to_f16x8x2_neon_(a_vec.u8x16, &a_low, &a_high);
231
+ nk_e4m3x16_to_f16x8x2_neon_(b_vec.u8x16, &b_low, &b_high);
232
+ count_scalars = 0;
233
+ }
234
+ else {
235
+ nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(a_scalars), &a_low, &a_high);
236
+ nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(b_scalars), &b_low, &b_high);
237
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
238
+ }
239
+ sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_low, b_low);
240
+ sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_low, b_low);
241
+ sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_high, b_high);
242
+ sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_high, b_high);
243
+ if (count_scalars) goto nk_dot_e4m3_neonfhm_cycle;
244
+ *result = vaddvq_f32(sum_f32x4);
245
+ }
246
+
247
+ NK_PUBLIC void nk_dot_e5m2_neonfhm(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
248
+ nk_f32_t *result) {
249
+ float16x8_t a_low, a_high, b_low, b_high;
250
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
251
+ nk_dot_e5m2_neonfhm_cycle:
252
+ if (count_scalars < 16) {
253
+ nk_b128_vec_t a_vec, b_vec;
254
+ nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
255
+ nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
256
+ a_low = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a_vec.u8x16), 8));
257
+ a_high = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(a_vec.u8x16), 8));
258
+ b_low = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b_vec.u8x16), 8));
259
+ b_high = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(b_vec.u8x16), 8));
260
+ count_scalars = 0;
261
+ }
262
+ else {
263
+ uint8x16_t a_u8x16 = vld1q_u8(a_scalars);
264
+ uint8x16_t b_u8x16 = vld1q_u8(b_scalars);
265
+ a_low = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a_u8x16), 8));
266
+ a_high = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(a_u8x16), 8));
267
+ b_low = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b_u8x16), 8));
268
+ b_high = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(b_u8x16), 8));
269
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
270
+ }
271
+ sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_low, b_low);
272
+ sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_low, b_low);
273
+ sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_high, b_high);
274
+ sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_high, b_high);
275
+ if (count_scalars) goto nk_dot_e5m2_neonfhm_cycle;
276
+ *result = vaddvq_f32(sum_f32x4);
277
+ }
278
+
279
+ typedef struct nk_dot_e4m3x16_state_neonfhm_t {
280
+ float32x4_t sum_f32x4;
281
+ } nk_dot_e4m3x16_state_neonfhm_t;
282
+
283
+ NK_INTERNAL void nk_dot_e4m3x16_init_neonfhm(nk_dot_e4m3x16_state_neonfhm_t *state) {
284
+ state->sum_f32x4 = vdupq_n_f32(0);
285
+ }
286
+
287
+ NK_INTERNAL void nk_dot_e4m3x16_update_neonfhm(nk_dot_e4m3x16_state_neonfhm_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
288
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
289
+ nk_unused_(depth_offset);
290
+ nk_unused_(active_dimensions);
291
+ // Convert e4m3 → f16 using 16-element LUT path (4× VQTBL4)
292
+ float16x8_t a_low_f16x8, a_high_f16x8, b_low_f16x8, b_high_f16x8;
293
+ nk_e4m3x16_to_f16x8x2_neon_(a.u8x16, &a_low_f16x8, &a_high_f16x8);
294
+ nk_e4m3x16_to_f16x8x2_neon_(b.u8x16, &b_low_f16x8, &b_high_f16x8);
295
+ // FMLAL: widening multiply-accumulate fp16 → f32
296
+ state->sum_f32x4 = vfmlalq_low_f16(state->sum_f32x4, a_low_f16x8, b_low_f16x8);
297
+ state->sum_f32x4 = vfmlalq_high_f16(state->sum_f32x4, a_low_f16x8, b_low_f16x8);
298
+ state->sum_f32x4 = vfmlalq_low_f16(state->sum_f32x4, a_high_f16x8, b_high_f16x8);
299
+ state->sum_f32x4 = vfmlalq_high_f16(state->sum_f32x4, a_high_f16x8, b_high_f16x8);
300
+ }
301
+
302
+ NK_INTERNAL void nk_dot_e4m3x16_finalize_neonfhm( //
303
+ nk_dot_e4m3x16_state_neonfhm_t const *state_a, nk_dot_e4m3x16_state_neonfhm_t const *state_b, //
304
+ nk_dot_e4m3x16_state_neonfhm_t const *state_c, nk_dot_e4m3x16_state_neonfhm_t const *state_d, //
305
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
306
+ nk_unused_(total_dimensions);
307
+ result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
308
+ result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
309
+ result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
310
+ result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
311
+ }
312
+
313
+ typedef struct nk_dot_e5m2x16_state_neonfhm_t {
314
+ float32x4_t sum_f32x4;
315
+ } nk_dot_e5m2x16_state_neonfhm_t;
316
+
317
+ NK_INTERNAL void nk_dot_e5m2x16_init_neonfhm(nk_dot_e5m2x16_state_neonfhm_t *state) {
318
+ state->sum_f32x4 = vdupq_n_f32(0);
319
+ }
320
+
321
+ NK_INTERNAL void nk_dot_e5m2x16_update_neonfhm(nk_dot_e5m2x16_state_neonfhm_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
322
+ nk_size_t depth_offset, nk_size_t active_dimensions) {
323
+ nk_unused_(depth_offset);
324
+ nk_unused_(active_dimensions);
325
+ // Convert e5m2 → f16 via SHLL: widen u8→u16 and shift left 8 in one instruction
326
+ float16x8_t a_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a.u8x16), 8));
327
+ float16x8_t a_high_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(a.u8x16), 8));
328
+ float16x8_t b_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b.u8x16), 8));
329
+ float16x8_t b_high_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_high_u8(b.u8x16), 8));
330
+ // FMLAL: widening multiply-accumulate fp16 → f32
331
+ state->sum_f32x4 = vfmlalq_low_f16(state->sum_f32x4, a_low_f16x8, b_low_f16x8);
332
+ state->sum_f32x4 = vfmlalq_high_f16(state->sum_f32x4, a_low_f16x8, b_low_f16x8);
333
+ state->sum_f32x4 = vfmlalq_low_f16(state->sum_f32x4, a_high_f16x8, b_high_f16x8);
334
+ state->sum_f32x4 = vfmlalq_high_f16(state->sum_f32x4, a_high_f16x8, b_high_f16x8);
335
+ }
336
+
337
+ NK_INTERNAL void nk_dot_e5m2x16_finalize_neonfhm( //
338
+ nk_dot_e5m2x16_state_neonfhm_t const *state_a, nk_dot_e5m2x16_state_neonfhm_t const *state_b, //
339
+ nk_dot_e5m2x16_state_neonfhm_t const *state_c, nk_dot_e5m2x16_state_neonfhm_t const *state_d, //
340
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
341
+ nk_unused_(total_dimensions);
342
+ result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
343
+ result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
344
+ result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
345
+ result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
346
+ }
347
+
348
+ #if defined(__clang__)
349
+ #pragma clang attribute pop
350
+ #elif defined(__GNUC__)
351
+ #pragma GCC pop_options
352
+ #endif
353
+
354
+ #if defined(__cplusplus)
355
+ } // extern "C"
356
+ #endif
357
+
358
+ #endif // NK_TARGET_NEONFHM
359
+ #endif // NK_TARGET_ARM_
360
+ #endif // NK_DOT_NEONFHM_H