numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,212 @@
1
+ /**
2
+ * @brief SIMD-accelerated Elementwise Arithmetic for NEON BF16.
3
+ * @file include/numkong/each/neonbfdot.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/each.h
8
+ *
9
+ * @section elementwise_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * A76 M4+/V1+/Oryon
13
+ * vld1_bf16 LD1 (V.4H) 4cy 2/cy 3/cy
14
+ * vst1_bf16 ST1 (V.4H) 2cy 2/cy 3/cy
15
+ * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy 2/cy 4/cy
16
+ * vcvt_bf16_f32 BFCVT (V.4H, V.4S) 3cy 2/cy 4/cy
17
+ * vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
18
+ * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy 2/cy 4/cy
19
+ * vmulq_n_f32 FMUL (V.4S, V.4S, scalar) 3cy 2/cy 4/cy
20
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
21
+ * vfmaq_n_f32 FMLA (V.4S, V.4S, scalar) 4cy 2/cy 4/cy
22
+ * vdupq_n_f32 DUP (V.4S, scalar) 2cy 2/cy 4/cy
23
+ *
24
+ * The ARMv8.6-BF16 extension provides element-wise operations on BF16 data by converting to F32
25
+ * for arithmetic, then back to BF16 for storage. This preserves the dynamic range benefits of BF16
26
+ * (matching F32 exponent) while using F32 precision for intermediate calculations.
27
+ *
28
+ * Operations process 4 BF16 elements at a time, widening to F32 for computation. While this gives
29
+ * lower throughput than native F16 operations, it prevents overflow issues common with FP16's
30
+ * limited exponent range in ML training workloads.
31
+ */
32
+ #ifndef NK_EACH_NEONBFDOT_H
33
+ #define NK_EACH_NEONBFDOT_H
34
+
35
+ #if NK_TARGET_ARM_
36
+ #if NK_TARGET_NEONBFDOT
37
+
38
+ #include "numkong/types.h"
39
+ #include "numkong/cast/serial.h"
40
+
41
+ #if defined(__cplusplus)
42
+ extern "C" {
43
+ #endif
44
+
45
+ #if defined(__clang__)
46
+ #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function)
47
+ #elif defined(__GNUC__)
48
+ #pragma GCC push_options
49
+ #pragma GCC target("arch=armv8.6-a+simd+bf16")
50
+ #endif
51
+
52
+ NK_PUBLIC void nk_each_sum_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result) {
53
+ nk_size_t i = 0;
54
+ for (; i + 4 <= n; i += 4) {
55
+ bfloat16x4_t a_bf16x4 = vld1_bf16((bfloat16_t const *)a + i);
56
+ bfloat16x4_t b_bf16x4 = vld1_bf16((bfloat16_t const *)b + i);
57
+ float32x4_t a_f32x4 = vcvt_f32_bf16(a_bf16x4);
58
+ float32x4_t b_f32x4 = vcvt_f32_bf16(b_bf16x4);
59
+ float32x4_t result_f32x4 = vaddq_f32(a_f32x4, b_f32x4);
60
+ bfloat16x4_t result_bf16x4 = vcvt_bf16_f32(result_f32x4);
61
+ vst1_bf16((bfloat16_t *)result + i, result_bf16x4);
62
+ }
63
+ if (i < n) {
64
+ nk_b64_vec_t a_tail, b_tail;
65
+ nk_partial_load_b16x4_serial_(a + i, &a_tail, n - i);
66
+ nk_partial_load_b16x4_serial_(b + i, &b_tail, n - i);
67
+ bfloat16x4_t a_bf16x4 = vreinterpret_bf16_u16(a_tail.u16x4);
68
+ bfloat16x4_t b_bf16x4 = vreinterpret_bf16_u16(b_tail.u16x4);
69
+ float32x4_t a_f32x4 = vcvt_f32_bf16(a_bf16x4);
70
+ float32x4_t b_f32x4 = vcvt_f32_bf16(b_bf16x4);
71
+ float32x4_t result_f32x4 = vaddq_f32(a_f32x4, b_f32x4);
72
+ bfloat16x4_t result_bf16x4 = vcvt_bf16_f32(result_f32x4);
73
+ nk_b64_vec_t result_vec;
74
+ result_vec.u16x4 = vreinterpret_u16_bf16(result_bf16x4);
75
+ nk_partial_store_b16x4_serial_(result + i, &result_vec, n - i);
76
+ }
77
+ }
78
+
79
+ NK_PUBLIC void nk_each_scale_bf16_neonbfdot(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha,
80
+ nk_f32_t const *beta, nk_bf16_t *result) {
81
+ nk_f32_t alpha_val = *alpha;
82
+ nk_f32_t beta_val = *beta;
83
+ float32x4_t alpha_f32x4 = vdupq_n_f32(alpha_val);
84
+ float32x4_t beta_f32x4 = vdupq_n_f32(beta_val);
85
+ nk_size_t i = 0;
86
+ for (; i + 4 <= n; i += 4) {
87
+ bfloat16x4_t a_bf16x4 = vld1_bf16((bfloat16_t const *)a + i);
88
+ float32x4_t a_f32x4 = vcvt_f32_bf16(a_bf16x4);
89
+ float32x4_t result_f32x4 = vfmaq_f32(beta_f32x4, a_f32x4, alpha_f32x4);
90
+ bfloat16x4_t result_bf16x4 = vcvt_bf16_f32(result_f32x4);
91
+ vst1_bf16((bfloat16_t *)result + i, result_bf16x4);
92
+ }
93
+ if (i < n) {
94
+ nk_b64_vec_t a_tail;
95
+ nk_partial_load_b16x4_serial_(a + i, &a_tail, n - i);
96
+ bfloat16x4_t a_bf16x4 = vreinterpret_bf16_u16(a_tail.u16x4);
97
+ float32x4_t a_f32x4 = vcvt_f32_bf16(a_bf16x4);
98
+ float32x4_t result_f32x4 = vfmaq_f32(beta_f32x4, a_f32x4, alpha_f32x4);
99
+ bfloat16x4_t result_bf16x4 = vcvt_bf16_f32(result_f32x4);
100
+ nk_b64_vec_t result_vec;
101
+ result_vec.u16x4 = vreinterpret_u16_bf16(result_bf16x4);
102
+ nk_partial_store_b16x4_serial_(result + i, &result_vec, n - i);
103
+ }
104
+ }
105
+
106
+ NK_PUBLIC void nk_each_blend_bf16_neonbfdot( //
107
+ nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, //
108
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result) {
109
+
110
+ nk_f32_t alpha_val = *alpha;
111
+ nk_f32_t beta_val = *beta;
112
+
113
+ // There are several special cases we may want to implement:
114
+ // 1. Simple addition, when both weights are equal to 1.0.
115
+ if (alpha_val == 1 && beta_val == 1) {
116
+ // In this case we can avoid expensive multiplications.
117
+ nk_each_sum_bf16_neonbfdot(a, b, n, result);
118
+ return;
119
+ }
120
+ // 2. Just scaling, when one of the weights is equal to zero.
121
+ else if (alpha_val == 0 || beta_val == 0) {
122
+ // In this case we can avoid half of the load instructions.
123
+ nk_f32_t zero = 0;
124
+ if (beta_val == 0) { nk_each_scale_bf16_neonbfdot(a, n, alpha, &zero, result); }
125
+ else { nk_each_scale_bf16_neonbfdot(b, n, beta, &zero, result); }
126
+ return;
127
+ }
128
+
129
+ // The general case.
130
+ nk_size_t i = 0;
131
+ for (; i + 4 <= n; i += 4) {
132
+ bfloat16x4_t a_bf16x4 = vld1_bf16((bfloat16_t const *)a + i);
133
+ bfloat16x4_t b_bf16x4 = vld1_bf16((bfloat16_t const *)b + i);
134
+ float32x4_t a_f32x4 = vcvt_f32_bf16(a_bf16x4);
135
+ float32x4_t b_f32x4 = vcvt_f32_bf16(b_bf16x4);
136
+ float32x4_t a_scaled_f32x4 = vmulq_n_f32(a_f32x4, alpha_val);
137
+ float32x4_t b_scaled_f32x4 = vmulq_n_f32(b_f32x4, beta_val);
138
+ float32x4_t result_f32x4 = vaddq_f32(a_scaled_f32x4, b_scaled_f32x4);
139
+ bfloat16x4_t result_bf16x4 = vcvt_bf16_f32(result_f32x4);
140
+ vst1_bf16((bfloat16_t *)result + i, result_bf16x4);
141
+ }
142
+ if (i < n) {
143
+ nk_b64_vec_t a_tail, b_tail;
144
+ nk_partial_load_b16x4_serial_(a + i, &a_tail, n - i);
145
+ nk_partial_load_b16x4_serial_(b + i, &b_tail, n - i);
146
+ bfloat16x4_t a_bf16x4 = vreinterpret_bf16_u16(a_tail.u16x4);
147
+ bfloat16x4_t b_bf16x4 = vreinterpret_bf16_u16(b_tail.u16x4);
148
+ float32x4_t a_f32x4 = vcvt_f32_bf16(a_bf16x4);
149
+ float32x4_t b_f32x4 = vcvt_f32_bf16(b_bf16x4);
150
+ float32x4_t a_scaled_f32x4 = vmulq_n_f32(a_f32x4, alpha_val);
151
+ float32x4_t b_scaled_f32x4 = vmulq_n_f32(b_f32x4, beta_val);
152
+ float32x4_t result_f32x4 = vaddq_f32(a_scaled_f32x4, b_scaled_f32x4);
153
+ bfloat16x4_t result_bf16x4 = vcvt_bf16_f32(result_f32x4);
154
+ nk_b64_vec_t result_vec;
155
+ result_vec.u16x4 = vreinterpret_u16_bf16(result_bf16x4);
156
+ nk_partial_store_b16x4_serial_(result + i, &result_vec, n - i);
157
+ }
158
+ }
159
+
160
+ NK_PUBLIC void nk_each_fma_bf16_neonbfdot( //
161
+ nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, //
162
+ nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result) {
163
+ nk_f32_t alpha_val = *alpha;
164
+ nk_f32_t beta_val = *beta;
165
+ nk_size_t i = 0;
166
+ for (; i + 4 <= n; i += 4) {
167
+ bfloat16x4_t a_bf16x4 = vld1_bf16((bfloat16_t const *)a + i);
168
+ bfloat16x4_t b_bf16x4 = vld1_bf16((bfloat16_t const *)b + i);
169
+ bfloat16x4_t c_bf16x4 = vld1_bf16((bfloat16_t const *)c + i);
170
+ float32x4_t a_f32x4 = vcvt_f32_bf16(a_bf16x4);
171
+ float32x4_t b_f32x4 = vcvt_f32_bf16(b_bf16x4);
172
+ float32x4_t c_f32x4 = vcvt_f32_bf16(c_bf16x4);
173
+ float32x4_t ab_f32x4 = vmulq_f32(a_f32x4, b_f32x4);
174
+ float32x4_t ab_scaled_f32x4 = vmulq_n_f32(ab_f32x4, alpha_val);
175
+ float32x4_t result_f32x4 = vfmaq_n_f32(ab_scaled_f32x4, c_f32x4, beta_val);
176
+ bfloat16x4_t result_bf16x4 = vcvt_bf16_f32(result_f32x4);
177
+ vst1_bf16((bfloat16_t *)result + i, result_bf16x4);
178
+ }
179
+ if (i < n) {
180
+ nk_b64_vec_t a_tail, b_tail, c_tail;
181
+ nk_partial_load_b16x4_serial_(a + i, &a_tail, n - i);
182
+ nk_partial_load_b16x4_serial_(b + i, &b_tail, n - i);
183
+ nk_partial_load_b16x4_serial_(c + i, &c_tail, n - i);
184
+ bfloat16x4_t a_bf16x4 = vreinterpret_bf16_u16(a_tail.u16x4);
185
+ bfloat16x4_t b_bf16x4 = vreinterpret_bf16_u16(b_tail.u16x4);
186
+ bfloat16x4_t c_bf16x4 = vreinterpret_bf16_u16(c_tail.u16x4);
187
+ float32x4_t a_f32x4 = vcvt_f32_bf16(a_bf16x4);
188
+ float32x4_t b_f32x4 = vcvt_f32_bf16(b_bf16x4);
189
+ float32x4_t c_f32x4 = vcvt_f32_bf16(c_bf16x4);
190
+ float32x4_t ab_f32x4 = vmulq_f32(a_f32x4, b_f32x4);
191
+ float32x4_t ab_scaled_f32x4 = vmulq_n_f32(ab_f32x4, alpha_val);
192
+ float32x4_t result_f32x4 = vfmaq_n_f32(ab_scaled_f32x4, c_f32x4, beta_val);
193
+ bfloat16x4_t result_bf16x4 = vcvt_bf16_f32(result_f32x4);
194
+ nk_b64_vec_t result_vec;
195
+ result_vec.u16x4 = vreinterpret_u16_bf16(result_bf16x4);
196
+ nk_partial_store_b16x4_serial_(result + i, &result_vec, n - i);
197
+ }
198
+ }
199
+
200
+ #if defined(__clang__)
201
+ #pragma clang attribute pop
202
+ #elif defined(__GNUC__)
203
+ #pragma GCC pop_options
204
+ #endif
205
+
206
+ #if defined(__cplusplus)
207
+ } // extern "C"
208
+ #endif
209
+
210
+ #endif // NK_TARGET_NEONBFDOT
211
+ #endif // NK_TARGET_ARM_
212
+ #endif // NK_EACH_NEONBFDOT_H
@@ -0,0 +1,410 @@
1
+ /**
2
+ * @brief SIMD-accelerated Elementwise Arithmetic for NEON FP16.
3
+ * @file include/numkong/each/neonhalf.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/each.h
8
+ *
9
+ * @section elementwise_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * A76 M4+/V1+/Oryon
13
+ * vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
14
+ * vst1q_f16 ST1 (V.8H) 2cy 2/cy 3/cy
15
+ * vaddq_f16 FADD (V.8H, V.8H, V.8H) 2cy 2/cy 4/cy
16
+ * vmulq_f16 FMUL (V.8H, V.8H, V.8H) 3cy 2/cy 4/cy
17
+ * vmulq_n_f16 FMUL (V.8H, V.8H, scalar) 3cy 2/cy 4/cy
18
+ * vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
19
+ * vfmaq_n_f16 FMLA (V.8H, V.8H, scalar) 4cy 2/cy 4/cy
20
+ * vdupq_n_f16 DUP (V.8H, scalar) 2cy 2/cy 4/cy
21
+ * vld1_u8 LD1 (V.8B) 4cy 2/cy 3/cy
22
+ * vld1_s8 LD1 (V.8B) 4cy 2/cy 3/cy
23
+ * vmovl_u8 UXTL (V.8H, V.8B) 2cy 2/cy 4/cy
24
+ * vmovl_s8 SXTL (V.8H, V.8B) 2cy 2/cy 4/cy
25
+ * vcvtq_f16_u16 UCVTF (V.8H, V.8H) 3cy 2/cy 4/cy
26
+ * vcvtq_f16_s16 SCVTF (V.8H, V.8H) 3cy 2/cy 4/cy
27
+ * vcvtnq_u16_f16 FCVTNU (V.8H, V.8H) 3cy 2/cy 4/cy
28
+ * vcvtnq_s16_f16 FCVTNS (V.8H, V.8H) 3cy 2/cy 4/cy
29
+ * vqmovn_u16 UQXTN (V.8B, V.8H) 3cy 2/cy 4/cy
30
+ * vqmovn_s16 SQXTN (V.8B, V.8H) 3cy 2/cy 4/cy
31
+ * vqaddq_u8 UQADD (V.16B, V.16B, V.16B) 2cy 2/cy 4/cy
32
+ * vqaddq_s8 SQADD (V.16B, V.16B, V.16B) 2cy 2/cy 4/cy
33
+ *
34
+ * The ARMv8.2-FP16 extension enables native half-precision element-wise operations, processing 8
35
+ * F16 elements per instruction. Operations like sum, scale, blend, and fma work directly in F16,
36
+ * avoiding conversion overhead while halving memory bandwidth vs F32.
37
+ *
38
+ * For int8 element-wise operations, values are widened to F16 for arithmetic via UCVTF/SCVTF,
39
+ * then narrowed back with saturating conversion (FCVTA + UQXTN/SQXTN) to handle overflow gracefully.
40
+ */
41
+ #ifndef NK_EACH_NEONHALF_H
42
+ #define NK_EACH_NEONHALF_H
43
+
44
+ #if NK_TARGET_ARM_
45
+ #if NK_TARGET_NEONHALF
46
+
47
+ #include "numkong/types.h"
48
+ #include "numkong/cast/serial.h" // `nk_f32_to_i8_serial`
49
+
50
+ #if defined(__cplusplus)
51
+ extern "C" {
52
+ #endif
53
+
54
+ #if defined(__clang__)
55
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
56
+ #elif defined(__GNUC__)
57
+ #pragma GCC push_options
58
+ #pragma GCC target("arch=armv8.2-a+simd+fp16")
59
+ #endif
60
+
61
+ NK_PUBLIC void nk_each_sum_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result) {
62
+ // The main loop:
63
+ nk_size_t i = 0;
64
+ for (; i + 8 <= n; i += 8) {
65
+ float16x8_t a_vec = vld1q_f16((float16_t const *)a + i);
66
+ float16x8_t b_vec = vld1q_f16((float16_t const *)b + i);
67
+ float16x8_t sum_vec = vaddq_f16(a_vec, b_vec);
68
+ vst1q_f16((float16_t *)result + i, sum_vec);
69
+ }
70
+
71
+ // The tail:
72
+ for (; i < n; ++i) ((float16_t *)result)[i] = ((float16_t const *)a)[i] + ((float16_t const *)b)[i];
73
+ }
74
+
75
+ NK_PUBLIC void nk_each_scale_f16_neonhalf(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
76
+ nk_f16_t *result) {
77
+ nk_f32_t alpha_val = *alpha;
78
+ nk_f32_t beta_val = *beta;
79
+ float16_t alpha_f16 = (float16_t)alpha_val;
80
+ float16_t beta_f16 = (float16_t)beta_val;
81
+ float16x8_t alpha_f16x8 = vdupq_n_f16(alpha_f16);
82
+ float16x8_t beta_f16x8 = vdupq_n_f16(beta_f16);
83
+
84
+ // The main loop:
85
+ nk_size_t i = 0;
86
+ for (; i + 8 <= n; i += 8) {
87
+ float16x8_t a_f16x8 = vld1q_f16((float16_t const *)a + i);
88
+ float16x8_t result_f16x8 = vfmaq_f16(beta_f16x8, a_f16x8, alpha_f16x8);
89
+ vst1q_f16((float16_t *)result + i, result_f16x8);
90
+ }
91
+
92
+ // The tail:
93
+ for (; i < n; ++i) ((float16_t *)result)[i] = alpha_f16 * ((float16_t const *)a)[i] + beta_f16;
94
+ }
95
+
96
+ NK_PUBLIC void nk_each_blend_f16_neonhalf( //
97
+ nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, //
98
+ nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result) {
99
+
100
+ nk_f32_t alpha_val = *alpha;
101
+ nk_f32_t beta_val = *beta;
102
+
103
+ // There are several special cases we may want to implement:
104
+ // 1. Simple addition, when both weights are equal to 1.0.
105
+ if (alpha_val == 1 && beta_val == 1) {
106
+ // In this case we can avoid expensive multiplications.
107
+ nk_each_sum_f16_neonhalf(a, b, n, result);
108
+ return;
109
+ }
110
+ // 2. Just scaling, when one of the weights is equal to zero.
111
+ else if (alpha_val == 0 || beta_val == 0) {
112
+ // In this case we can avoid half of the load instructions.
113
+ nk_f32_t zero = 0;
114
+ if (beta_val == 0) { nk_each_scale_f16_neonhalf(a, n, alpha, &zero, result); }
115
+ else { nk_each_scale_f16_neonhalf(b, n, beta, &zero, result); }
116
+ return;
117
+ }
118
+
119
+ // The general case.
120
+ float16_t alpha_f16 = (float16_t)alpha_val;
121
+ float16_t beta_f16 = (float16_t)beta_val;
122
+
123
+ // The main loop:
124
+ nk_size_t i = 0;
125
+ for (; i + 8 <= n; i += 8) {
126
+ float16x8_t a_f16x8 = vld1q_f16((float16_t const *)a + i);
127
+ float16x8_t b_f16x8 = vld1q_f16((float16_t const *)b + i);
128
+ float16x8_t a_scaled_f16x8 = vmulq_n_f16(a_f16x8, alpha_f16);
129
+ float16x8_t result_f16x8 = vfmaq_n_f16(a_scaled_f16x8, b_f16x8, beta_f16);
130
+ vst1q_f16((float16_t *)result + i, result_f16x8);
131
+ }
132
+
133
+ // The tail:
134
+ for (; i < n; ++i)
135
+ ((float16_t *)result)[i] = alpha_f16 * ((float16_t const *)a)[i] + beta_f16 * ((float16_t const *)b)[i];
136
+ }
137
+
138
+ NK_PUBLIC void nk_each_fma_f16_neonhalf( //
139
+ nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, //
140
+ nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result) {
141
+ nk_f32_t alpha_val = *alpha;
142
+ nk_f32_t beta_val = *beta;
143
+ float16_t alpha_f16 = (float16_t)alpha_val;
144
+ float16_t beta_f16 = (float16_t)beta_val;
145
+
146
+ // The main loop:
147
+ nk_size_t i = 0;
148
+ for (; i + 8 <= n; i += 8) {
149
+ float16x8_t a_f16x8 = vld1q_f16((float16_t const *)a + i);
150
+ float16x8_t b_f16x8 = vld1q_f16((float16_t const *)b + i);
151
+ float16x8_t c_f16x8 = vld1q_f16((float16_t const *)c + i);
152
+ float16x8_t ab_f16x8 = vmulq_f16(a_f16x8, b_f16x8);
153
+ float16x8_t ab_scaled_f16x8 = vmulq_n_f16(ab_f16x8, alpha_f16);
154
+ float16x8_t result_f16x8 = vfmaq_n_f16(ab_scaled_f16x8, c_f16x8, beta_f16);
155
+ vst1q_f16((float16_t *)result + i, result_f16x8);
156
+ }
157
+
158
+ // The tail:
159
+ for (; i < n; ++i)
160
+ ((float16_t *)result)[i] = alpha_f16 * ((float16_t const *)a)[i] * ((float16_t const *)b)[i] +
161
+ beta_f16 * ((float16_t const *)c)[i];
162
+ }
163
+
164
/*  Saturating element-wise sum of two u8 vectors: result[i] = sat(a[i] + b[i]).
 *  The SIMD path uses the native saturating add; the tail widens to f32 and
 *  goes through the shared serial f32->u8 converter so both paths clamp alike. */
NK_PUBLIC void nk_each_sum_u8_neonhalf(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result) {
    nk_size_t idx = 0;
    // 16 lanes per iteration with hardware saturation.
    while (idx + 16 <= n) {
        uint8x16_t lhs = vld1q_u8(a + idx);
        uint8x16_t rhs = vld1q_u8(b + idx);
        vst1q_u8(result + idx, vqaddq_u8(lhs, rhs));
        idx += 16;
    }
    // Scalar tail: widen to f32, then clamp back down to u8.
    while (idx < n) {
        nk_f32_t widened = (nk_f32_t)a[idx] + b[idx];
        nk_f32_to_u8_serial(&widened, result + idx);
        ++idx;
    }
}
180
+
181
/*  Affine transform of a u8 vector: result[i] = sat(round(alpha * a[i] + beta)).
 *  Arithmetic runs in half precision: u8 -> u16 -> f16, one FMA per lane,
 *  then round-to-nearest and a saturating narrow back to u8. */
NK_PUBLIC void nk_each_scale_u8_neonhalf(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_u8_t *result) {
    float16_t const scale_h = (float16_t)*alpha;
    float16_t const shift_h = (float16_t)*beta;
    float16x8_t const scale_vec = vdupq_n_f16(scale_h);
    float16x8_t const shift_vec = vdupq_n_f16(shift_h);

    // Vectorized body: 8 lanes per step.
    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        uint16x8_t widened = vmovl_u8(vld1_u8(a + idx));
        float16x8_t as_f16 = vcvtq_f16_u16(widened);
        float16x8_t scaled = vfmaq_f16(shift_vec, as_f16, scale_vec);
        vst1_u8(result + idx, vqmovn_u16(vcvtnq_u16_f16(scaled)));
        idx += 8;
    }

    // Scalar tail through the shared f32 -> u8 converter.
    while (idx < n) {
        nk_f32_t value = scale_h * a[idx] + shift_h;
        nk_f32_to_u8_serial(&value, result + idx);
        ++idx;
    }
}
204
+
205
/*  Weighted blend of two u8 vectors:
 *  result[i] = sat(round(alpha * a[i] + beta * b[i])).
 *  Dispatches to cheaper kernels when the weights allow it. */
NK_PUBLIC void nk_each_blend_u8_neonhalf(    //
    nk_u8_t const *a, nk_u8_t const *b, nk_size_t n,  //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {

    nk_f32_t const w_a = *alpha;
    nk_f32_t const w_b = *beta;

    // Special case 1: both weights are 1 — a plain saturating sum, no multiplies.
    if (w_a == 1 && w_b == 1) {
        nk_each_sum_u8_neonhalf(a, b, n, result);
        return;
    }
    // Special case 2: one weight is 0 — reduce to a single-input scale,
    // saving half of the load instructions.
    else if (w_a == 0 || w_b == 0) {
        nk_f32_t zero = 0;
        if (w_b == 0) { nk_each_scale_u8_neonhalf(a, n, alpha, &zero, result); }
        else { nk_each_scale_u8_neonhalf(b, n, beta, &zero, result); }
        return;
    }

    // General case: both weights matter; blend in half precision.
    float16_t const w_a_h = (float16_t)w_a;
    float16_t const w_b_h = (float16_t)w_b;

    // Vectorized body: widen each side to f16, scale, fuse, round, narrow.
    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        float16x8_t va = vcvtq_f16_u16(vmovl_u8(vld1_u8(a + idx)));
        float16x8_t vb = vcvtq_f16_u16(vmovl_u8(vld1_u8(b + idx)));
        float16x8_t blended = vfmaq_n_f16(vmulq_n_f16(va, w_a_h), vb, w_b_h);
        vst1_u8(result + idx, vqmovn_u16(vcvtnq_u16_f16(blended)));
        idx += 8;
    }

    // Scalar tail through the shared f32 -> u8 converter.
    while (idx < n) {
        nk_f32_t value = w_a_h * a[idx] + w_b_h * b[idx];
        nk_f32_to_u8_serial(&value, result + idx);
        ++idx;
    }
}
251
+
252
/*  Fused multiply-add over u8 vectors:
 *  result[i] = sat(round(alpha * a[i] * b[i] + beta * c[i])).
 *  All arithmetic is performed in half precision after widening. */
NK_PUBLIC void nk_each_fma_u8_neonhalf(                    //
    nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c,  //
    nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {
    float16_t const w_ab = (float16_t)*alpha;
    float16_t const w_c = (float16_t)*beta;

    // Vectorized body: widen to f16, multiply, scale, accumulate, round, narrow.
    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        float16x8_t va = vcvtq_f16_u16(vmovl_u8(vld1_u8(a + idx)));
        float16x8_t vb = vcvtq_f16_u16(vmovl_u8(vld1_u8(b + idx)));
        float16x8_t vc = vcvtq_f16_u16(vmovl_u8(vld1_u8(c + idx)));
        float16x8_t scaled_ab = vmulq_n_f16(vmulq_f16(va, vb), w_ab);
        float16x8_t fused = vfmaq_n_f16(scaled_ab, vc, w_c);
        vst1_u8(result + idx, vqmovn_u16(vcvtnq_u16_f16(fused)));
        idx += 8;
    }

    // Scalar tail through the shared f32 -> u8 converter.
    while (idx < n) {
        nk_f32_t value = w_ab * a[idx] * b[idx] + w_c * c[idx];
        nk_f32_to_u8_serial(&value, result + idx);
        ++idx;
    }
}
280
+
281
/*  Saturating element-wise sum of two i8 vectors: result[i] = sat(a[i] + b[i]).
 *  The SIMD path uses the native saturating add; the tail widens to f32 and
 *  goes through the shared serial f32->i8 converter so both paths clamp alike. */
NK_PUBLIC void nk_each_sum_i8_neonhalf(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result) {
    nk_size_t idx = 0;
    // 16 lanes per iteration with hardware saturation.
    while (idx + 16 <= n) {
        int8x16_t lhs = vld1q_s8(a + idx);
        int8x16_t rhs = vld1q_s8(b + idx);
        vst1q_s8(result + idx, vqaddq_s8(lhs, rhs));
        idx += 16;
    }
    // Scalar tail: widen to f32, then clamp back down to i8.
    while (idx < n) {
        nk_f32_t widened = (nk_f32_t)a[idx] + b[idx];
        nk_f32_to_i8_serial(&widened, result + idx);
        ++idx;
    }
}
297
+
298
/*  Affine transform of an i8 vector: result[i] = sat(round(alpha * a[i] + beta)).
 *  Arithmetic runs in half precision: i8 -> i16 -> f16, one FMA per lane,
 *  then round-to-nearest and a saturating narrow back to i8. */
NK_PUBLIC void nk_each_scale_i8_neonhalf(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                         nk_i8_t *result) {
    float16_t const scale_h = (float16_t)*alpha;
    float16_t const shift_h = (float16_t)*beta;
    float16x8_t const scale_vec = vdupq_n_f16(scale_h);
    float16x8_t const shift_vec = vdupq_n_f16(shift_h);

    // Vectorized body: 8 lanes per step.
    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        int16x8_t widened = vmovl_s8(vld1_s8(a + idx));
        float16x8_t as_f16 = vcvtq_f16_s16(widened);
        float16x8_t scaled = vfmaq_f16(shift_vec, as_f16, scale_vec);
        vst1_s8(result + idx, vqmovn_s16(vcvtnq_s16_f16(scaled)));
        idx += 8;
    }

    // Scalar tail through the shared f32 -> i8 converter.
    while (idx < n) {
        nk_f32_t value = scale_h * a[idx] + shift_h;
        nk_f32_to_i8_serial(&value, result + idx);
        ++idx;
    }
}
321
+
322
/*  Weighted blend of two i8 vectors:
 *  result[i] = sat(round(alpha * a[i] + beta * b[i])).
 *  Dispatches to cheaper kernels when the weights allow it. */
NK_PUBLIC void nk_each_blend_i8_neonhalf(    //
    nk_i8_t const *a, nk_i8_t const *b, nk_size_t n,  //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {

    nk_f32_t const w_a = *alpha;
    nk_f32_t const w_b = *beta;

    // Special case 1: both weights are 1 — a plain saturating sum, no multiplies.
    if (w_a == 1 && w_b == 1) {
        nk_each_sum_i8_neonhalf(a, b, n, result);
        return;
    }
    // Special case 2: one weight is 0 — reduce to a single-input scale,
    // saving half of the load instructions.
    else if (w_a == 0 || w_b == 0) {
        nk_f32_t zero = 0;
        if (w_b == 0) { nk_each_scale_i8_neonhalf(a, n, alpha, &zero, result); }
        else { nk_each_scale_i8_neonhalf(b, n, beta, &zero, result); }
        return;
    }

    // General case: both weights matter; blend in half precision.
    float16_t const w_a_h = (float16_t)w_a;
    float16_t const w_b_h = (float16_t)w_b;

    // Vectorized body: widen each side to f16, scale, fuse, round, narrow.
    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        float16x8_t va = vcvtq_f16_s16(vmovl_s8(vld1_s8(a + idx)));
        float16x8_t vb = vcvtq_f16_s16(vmovl_s8(vld1_s8(b + idx)));
        float16x8_t blended = vfmaq_n_f16(vmulq_n_f16(va, w_a_h), vb, w_b_h);
        vst1_s8(result + idx, vqmovn_s16(vcvtnq_s16_f16(blended)));
        idx += 8;
    }

    // Scalar tail through the shared f32 -> i8 converter.
    while (idx < n) {
        nk_f32_t value = w_a_h * a[idx] + w_b_h * b[idx];
        nk_f32_to_i8_serial(&value, result + idx);
        ++idx;
    }
}
368
+
369
/*  Fused multiply-add over i8 vectors:
 *  result[i] = sat(round(alpha * a[i] * b[i] + beta * c[i])).
 *  All arithmetic is performed in half precision after widening. */
NK_PUBLIC void nk_each_fma_i8_neonhalf(                    //
    nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c,  //
    nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {
    float16_t const w_ab = (float16_t)*alpha;
    float16_t const w_c = (float16_t)*beta;

    // Vectorized body: widen to f16, multiply, scale, accumulate, round, narrow.
    nk_size_t idx = 0;
    while (idx + 8 <= n) {
        float16x8_t va = vcvtq_f16_s16(vmovl_s8(vld1_s8(a + idx)));
        float16x8_t vb = vcvtq_f16_s16(vmovl_s8(vld1_s8(b + idx)));
        float16x8_t vc = vcvtq_f16_s16(vmovl_s8(vld1_s8(c + idx)));
        float16x8_t scaled_ab = vmulq_n_f16(vmulq_f16(va, vb), w_ab);
        float16x8_t fused = vfmaq_n_f16(scaled_ab, vc, w_c);
        vst1_s8(result + idx, vqmovn_s16(vcvtnq_s16_f16(fused)));
        idx += 8;
    }

    // Scalar tail through the shared f32 -> i8 converter.
    while (idx < n) {
        nk_f32_t value = w_ab * a[idx] * b[idx] + w_c * c[idx];
        nk_f32_to_i8_serial(&value, result + idx);
        ++idx;
    }
}
397
+
398
+ #if defined(__clang__)
399
+ #pragma clang attribute pop
400
+ #elif defined(__GNUC__)
401
+ #pragma GCC pop_options
402
+ #endif
403
+
404
+ #if defined(__cplusplus)
405
+ } // extern "C"
406
+ #endif
407
+
408
+ #endif // NK_TARGET_NEONHALF
409
+ #endif // NK_TARGET_ARM_
410
+ #endif // NK_EACH_NEONHALF_H