numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,157 @@
1
+ /**
2
+ * @brief NEON FP16 implementations for the redesigned reduction API (moments + minmax).
3
+ * @file include/numkong/reduce/neonhalf.h
4
+ * @author Ash Vardanian
5
+ * @date February 13, 2026
6
+ *
7
+ * @sa include/numkong/reduce.h
8
+ *
9
+ * @section reduce_neonhalf_new_design Design Notes
10
+ *
11
+ * Moments (sum + sum-of-squares) accumulate in f32 via vcvt_f32_f16 widening, giving
12
+ * full f32 precision. The contiguous path processes 8 f16 elements per iteration, widening
13
+ * to two f32x4 halves and using vfmaq_f32 for fused multiply-accumulate of squares.
14
+ *
15
+ * Minmax tracks min/max values as native f16x8 with u16x8 iteration counters (same width
16
+ * as f16). The u16 counters wrap at 65536, so the dispatcher splits arrays larger than
17
+ * 65536 * 8 = 524288 elements via recursive halving.
18
+ */
19
+ #ifndef NK_REDUCE_NEONHALF_H
20
+ #define NK_REDUCE_NEONHALF_H
21
+
22
+ #if NK_TARGET_ARM_
23
+ #if NK_TARGET_NEONHALF
24
+
25
+ #include "numkong/types.h"
26
+ #include "numkong/cast/neon.h"
27
+ #include "numkong/cast/serial.h"
28
+ #include "numkong/reduce/serial.h"
29
+
30
+ #if defined(__cplusplus)
31
+ extern "C" {
32
+ #endif
33
+
34
+ #if defined(__clang__)
35
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
36
+ #elif defined(__GNUC__)
37
+ #pragma GCC push_options
38
+ #pragma GCC target("arch=armv8.2-a+simd+fp16")
39
+ #endif
40
+
41
+ NK_INTERNAL void nk_reduce_moments_f16_neonhalf_contiguous_( //
42
+ nk_f16_t const *data_ptr, nk_size_t count, //
43
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
44
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
45
+ float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
46
+ nk_size_t idx = 0;
47
+
48
+ for (; idx + 8 <= count; idx += 8) {
49
+ float16x8_t data_f16x8 = vld1q_f16((nk_f16_for_arm_simd_t const *)(data_ptr + idx));
50
+ float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
51
+ float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
52
+ sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
53
+ sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
54
+ sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
55
+ sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
56
+ }
57
+
58
+ // Scalar tail
59
+ nk_f32_t sum = vaddvq_f32(sum_f32x4);
60
+ nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
61
+ for (; idx < count; ++idx) {
62
+ nk_f32_t value_f32;
63
+ nk_f16_to_f32_serial(data_ptr + idx, &value_f32);
64
+ sum += value_f32, sumsq += value_f32 * value_f32;
65
+ }
66
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
67
+ }
68
+
69
+ NK_INTERNAL void nk_reduce_moments_f16_neonhalf_strided_( //
70
+ nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
71
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
72
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
73
+ float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
74
+ nk_size_t idx = 0;
75
+
76
+ if (stride_elements == 2) {
77
+ for (; idx + 8 <= count; idx += 8) {
78
+ uint16x8x2_t loaded_u16x8x2 = vld2q_u16((uint16_t const *)(data_ptr + idx * 2));
79
+ float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x2.val[0]);
80
+ float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
81
+ float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
82
+ sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
83
+ sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
84
+ sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
85
+ sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
86
+ }
87
+ }
88
+ else if (stride_elements == 3) {
89
+ for (; idx + 8 <= count; idx += 8) {
90
+ uint16x8x3_t loaded_u16x8x3 = vld3q_u16((uint16_t const *)(data_ptr + idx * 3));
91
+ float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x3.val[0]);
92
+ float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
93
+ float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
94
+ sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
95
+ sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
96
+ sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
97
+ sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
98
+ }
99
+ }
100
+ else if (stride_elements == 4) {
101
+ for (; idx + 8 <= count; idx += 8) {
102
+ uint16x8x4_t loaded_u16x8x4 = vld4q_u16((uint16_t const *)(data_ptr + idx * 4));
103
+ float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x4.val[0]);
104
+ float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
105
+ float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
106
+ sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
107
+ sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
108
+ sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
109
+ sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
110
+ }
111
+ }
112
+
113
+ // Scalar tail for remaining elements
114
+ nk_f32_t sum = vaddvq_f32(sum_f32x4);
115
+ nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
116
+ for (; idx < count; ++idx) {
117
+ nk_f32_t value_f32;
118
+ nk_f16_to_f32_serial((nk_f16_t const *)(data_ptr + idx * stride_elements), &value_f32);
119
+ sum += value_f32, sumsq += value_f32 * value_f32;
120
+ }
121
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
122
+ }
123
+
124
+ NK_PUBLIC void nk_reduce_moments_f16_neonhalf( //
125
+ nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
126
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
127
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
128
+ int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
129
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
130
+ else if (!aligned) nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
131
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
132
+ nk_size_t left_count = count / 2;
133
+ nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
134
+ nk_reduce_moments_f16_neonhalf(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
135
+ nk_reduce_moments_f16_neonhalf(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
136
+ &right_sum_value, &right_sumsq_value);
137
+ *sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
138
+ }
139
+ else if (stride_elements == 1) nk_reduce_moments_f16_neonhalf_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
140
+ else if (stride_elements <= 4)
141
+ nk_reduce_moments_f16_neonhalf_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
142
+ else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
143
+ }
144
+
145
+ #if defined(__clang__)
146
+ #pragma clang attribute pop
147
+ #elif defined(__GNUC__)
148
+ #pragma GCC pop_options
149
+ #endif
150
+
151
+ #if defined(__cplusplus)
152
+ } // extern "C"
153
+ #endif
154
+
155
+ #endif // NK_TARGET_NEONHALF
156
+ #endif // NK_TARGET_ARM_
157
+ #endif // NK_REDUCE_NEONHALF_H
@@ -0,0 +1,357 @@
1
+ /**
2
+ * @brief ARMv8.4-DotProd implementations for the redesigned reduction API (moments).
3
+ * @file include/numkong/reduce/neonsdot.h
4
+ * @author Ash Vardanian
5
+ * @date February 13, 2026
6
+ *
7
+ * @sa include/numkong/reduce.h
8
+ */
9
+ #ifndef NK_REDUCE_NEONSDOT_H
10
+ #define NK_REDUCE_NEONSDOT_H
11
+
12
+ #if NK_TARGET_ARM_
13
+ #if NK_TARGET_NEONSDOT
14
+
15
+ #include "numkong/types.h"
16
+ #include "numkong/cast/serial.h"
17
+ #include "numkong/reduce/serial.h"
18
+
19
+ #if defined(__cplusplus)
20
+ extern "C" {
21
+ #endif
22
+
23
+ #if defined(__clang__)
24
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+dotprod"))), apply_to = function)
25
+ #elif defined(__GNUC__)
26
+ #pragma GCC push_options
27
+ #pragma GCC target("arch=armv8.2-a+dotprod")
28
+ #endif
29
+
30
+ NK_INTERNAL void nk_reduce_moments_i8_neonsdot_contiguous_( //
31
+ nk_i8_t const *data_ptr, nk_size_t count, //
32
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
33
+ int8x16_t ones_i8x16 = vdupq_n_s8(1);
34
+ int32x4_t sum_i32x4 = vdupq_n_s32(0);
35
+ int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
36
+ nk_size_t idx = 0;
37
+ for (; idx + 16 <= count; idx += 16) {
38
+ int8x16_t data_i8x16 = vld1q_s8(data_ptr + idx);
39
+ sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
40
+ sumsq_i32x4 = vdotq_s32(sumsq_i32x4, data_i8x16, data_i8x16);
41
+ }
42
+ // Widen i32 -> i64 and horizontal reduce
43
+ int64x2_t sum_i64x2 = vpaddlq_s32(sum_i32x4);
44
+ nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
45
+ uint64x2_t sumsq_u64x2 = vpaddlq_u32(vreinterpretq_u32_s32(sumsq_i32x4));
46
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
47
+ for (; idx < count; ++idx) {
48
+ nk_i64_t value = (nk_i64_t)data_ptr[idx];
49
+ sum += value, sumsq += (nk_u64_t)(value * value);
50
+ }
51
+ *sum_ptr = sum;
52
+ *sumsq_ptr = sumsq;
53
+ }
54
+
55
+ NK_INTERNAL void nk_reduce_moments_i8_neonsdot_strided_( //
56
+ nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
57
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
58
+ int8x16_t ones_i8x16 = vdupq_n_s8(1);
59
+ int32x4_t sum_i32x4 = vdupq_n_s32(0);
60
+ int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
61
+ nk_size_t idx = 0;
62
+ if (stride_elements == 2) {
63
+ for (; idx + 16 <= count; idx += 16) {
64
+ int8x16x2_t loaded = vld2q_s8(data_ptr + idx * 2);
65
+ int8x16_t data_i8x16 = loaded.val[0];
66
+ sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
67
+ sumsq_i32x4 = vdotq_s32(sumsq_i32x4, data_i8x16, data_i8x16);
68
+ }
69
+ }
70
+ else if (stride_elements == 3) {
71
+ for (; idx + 16 <= count; idx += 16) {
72
+ int8x16x3_t loaded = vld3q_s8(data_ptr + idx * 3);
73
+ int8x16_t data_i8x16 = loaded.val[0];
74
+ sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
75
+ sumsq_i32x4 = vdotq_s32(sumsq_i32x4, data_i8x16, data_i8x16);
76
+ }
77
+ }
78
+ else if (stride_elements == 4) {
79
+ for (; idx + 16 <= count; idx += 16) {
80
+ int8x16x4_t loaded = vld4q_s8(data_ptr + idx * 4);
81
+ int8x16_t data_i8x16 = loaded.val[0];
82
+ sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
83
+ sumsq_i32x4 = vdotq_s32(sumsq_i32x4, data_i8x16, data_i8x16);
84
+ }
85
+ }
86
+ // Widen i32 -> i64 and horizontal reduce
87
+ int64x2_t sum_i64x2 = vpaddlq_s32(sum_i32x4);
88
+ nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
89
+ uint64x2_t sumsq_u64x2 = vpaddlq_u32(vreinterpretq_u32_s32(sumsq_i32x4));
90
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
91
+ for (; idx < count; ++idx) {
92
+ nk_i64_t value = (nk_i64_t)data_ptr[idx * stride_elements];
93
+ sum += value, sumsq += (nk_u64_t)(value * value);
94
+ }
95
+ *sum_ptr = sum;
96
+ *sumsq_ptr = sumsq;
97
+ }
98
+
99
+ NK_PUBLIC void nk_reduce_moments_i8_neonsdot( //
100
+ nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
101
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
102
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
103
+ int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
104
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
105
+ else if (!aligned) nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
106
+ else if (count > (nk_size_t)32768 * 16) {
107
+ nk_size_t left_count = count / 2;
108
+ nk_i64_t left_sum_value, right_sum_value;
109
+ nk_u64_t left_sumsq_value, right_sumsq_value;
110
+ nk_reduce_moments_i8_neonsdot(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
111
+ nk_reduce_moments_i8_neonsdot(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
112
+ &right_sum_value, &right_sumsq_value);
113
+ *sum_ptr = nk_i64_saturating_add_serial(left_sum_value, right_sum_value);
114
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq_value, right_sumsq_value);
115
+ }
116
+ else if (stride_elements == 1) nk_reduce_moments_i8_neonsdot_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
117
+ else if (stride_elements <= 4)
118
+ nk_reduce_moments_i8_neonsdot_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
119
+ else nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
120
+ }
121
+
122
+ NK_INTERNAL void nk_reduce_moments_u8_neonsdot_contiguous_( //
123
+ nk_u8_t const *data_ptr, nk_size_t count, //
124
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
125
+ uint8x16_t ones_u8x16 = vdupq_n_u8(1);
126
+ uint32x4_t sum_u32x4 = vdupq_n_u32(0);
127
+ uint32x4_t sumsq_u32x4 = vdupq_n_u32(0);
128
+ nk_size_t idx = 0;
129
+ for (; idx + 16 <= count; idx += 16) {
130
+ uint8x16_t data_u8x16 = vld1q_u8(data_ptr + idx);
131
+ sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
132
+ sumsq_u32x4 = vdotq_u32(sumsq_u32x4, data_u8x16, data_u8x16);
133
+ }
134
+ uint64x2_t sum_u64x2 = vpaddlq_u32(sum_u32x4);
135
+ nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
136
+ uint64x2_t sumsq_u64x2 = vpaddlq_u32(sumsq_u32x4);
137
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
138
+ for (; idx < count; ++idx) {
139
+ nk_u64_t value = (nk_u64_t)data_ptr[idx];
140
+ sum += value, sumsq += value * value;
141
+ }
142
+ *sum_ptr = sum;
143
+ *sumsq_ptr = sumsq;
144
+ }
145
+
146
+ NK_INTERNAL void nk_reduce_moments_u8_neonsdot_strided_( //
147
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
148
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
149
+ uint8x16_t ones_u8x16 = vdupq_n_u8(1);
150
+ uint32x4_t sum_u32x4 = vdupq_n_u32(0);
151
+ uint32x4_t sumsq_u32x4 = vdupq_n_u32(0);
152
+ nk_size_t idx = 0;
153
+ if (stride_elements == 2) {
154
+ for (; idx + 16 <= count; idx += 16) {
155
+ uint8x16x2_t loaded = vld2q_u8(data_ptr + idx * 2);
156
+ uint8x16_t data_u8x16 = loaded.val[0];
157
+ sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
158
+ sumsq_u32x4 = vdotq_u32(sumsq_u32x4, data_u8x16, data_u8x16);
159
+ }
160
+ }
161
+ else if (stride_elements == 3) {
162
+ for (; idx + 16 <= count; idx += 16) {
163
+ uint8x16x3_t loaded = vld3q_u8(data_ptr + idx * 3);
164
+ uint8x16_t data_u8x16 = loaded.val[0];
165
+ sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
166
+ sumsq_u32x4 = vdotq_u32(sumsq_u32x4, data_u8x16, data_u8x16);
167
+ }
168
+ }
169
+ else if (stride_elements == 4) {
170
+ for (; idx + 16 <= count; idx += 16) {
171
+ uint8x16x4_t loaded = vld4q_u8(data_ptr + idx * 4);
172
+ uint8x16_t data_u8x16 = loaded.val[0];
173
+ sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
174
+ sumsq_u32x4 = vdotq_u32(sumsq_u32x4, data_u8x16, data_u8x16);
175
+ }
176
+ }
177
+ uint64x2_t sum_u64x2 = vpaddlq_u32(sum_u32x4);
178
+ nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
179
+ uint64x2_t sumsq_u64x2 = vpaddlq_u32(sumsq_u32x4);
180
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
181
+ for (; idx < count; ++idx) {
182
+ nk_u64_t value = (nk_u64_t)data_ptr[idx * stride_elements];
183
+ sum += value, sumsq += value * value;
184
+ }
185
+ *sum_ptr = sum;
186
+ *sumsq_ptr = sumsq;
187
+ }
188
+
189
+ NK_PUBLIC void nk_reduce_moments_u8_neonsdot( //
190
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
191
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
192
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
193
+ int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
194
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
195
+ else if (!aligned) nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
196
+ else if (count > (nk_size_t)16384 * 16) {
197
+ nk_size_t left_count = count / 2;
198
+ nk_u64_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
199
+ nk_reduce_moments_u8_neonsdot(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
200
+ nk_reduce_moments_u8_neonsdot(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
201
+ &right_sum_value, &right_sumsq_value);
202
+ *sum_ptr = nk_u64_saturating_add_serial(left_sum_value, right_sum_value);
203
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq_value, right_sumsq_value);
204
+ }
205
+ else if (stride_elements == 1) nk_reduce_moments_u8_neonsdot_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
206
+ else if (stride_elements <= 4)
207
+ nk_reduce_moments_u8_neonsdot_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
208
+ else nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
209
+ }
210
+
211
/**
 *  @brief  Builds the 32-entry `vqtbl2q_u8` table mapping an E2M3 magnitude (low 5 bits of a byte)
 *          to 16 × |value|. Row 0 covers magnitudes 0..15, row 1 covers magnitudes 16..31.
 *
 *  Every entry is an even integer no larger than 120, so it also fits in a signed byte.
 */
NK_INTERNAL uint8x16x2_t nk_e2m3_magnitude_lut_neonsdot_(void) {
    uint8x16x2_t lut_e2m3_x16;
    // 0x0E0C0A0806040200 → bytes [0..7]  = 0,2,4,6,8,10,12,14
    // 0x1E1C1A1816141210 → bytes [8..15] = 16,18,20,22,24,26,28,30
    lut_e2m3_x16.val[0] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0E0C0A0806040200ULL)),
                                      vreinterpret_u8_u64(vcreate_u64(0x1E1C1A1816141210ULL)));
    // 0x3C3834302C282420 → bytes [0..7]  = 32,36,40,44,48,52,56,60
    // 0x7870686058504840 → bytes [8..15] = 64,72,80,88,96,104,112,120
    lut_e2m3_x16.val[1] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x3C3834302C282420ULL)),
                                      vreinterpret_u8_u64(vcreate_u64(0x7870686058504840ULL)));
    return lut_e2m3_x16;
}

/**
 *  @brief  Decodes 16 raw E2M3 bytes into signed 16 × value and folds them into the running
 *          sum and sum-of-squares dot-product lanes.
 *
 *  Bit 5 (0x20) of each byte selects the sign; bits 0..4 index the magnitude table.
 */
NK_INTERNAL void nk_e2m3x16_accumulate_moments_neonsdot_( //
    uint8x16x2_t lut_e2m3_x16, uint8x16_t raw_u8x16,      //
    int32x4_t *sum_i32x4, int32x4_t *sumsq_i32x4) {
    uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
    uint8x16_t unsigned_u8x16 = vqtbl2q_u8(lut_e2m3_x16, magnitude_u8x16);
    uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
    int8x16_t positive_i8x16 = vreinterpretq_s8_u8(unsigned_u8x16);
    int8x16_t negative_i8x16 = vnegq_s8(positive_i8x16);
    int8x16_t scaled_i8x16 = vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
    // Dot with all-ones gives the plain sum; dot with itself gives the sum of squares.
    *sum_i32x4 = vdotq_s32(*sum_i32x4, scaled_i8x16, vdupq_n_s8(1));
    *sumsq_i32x4 = vdotq_s32(*sumsq_i32x4, scaled_i8x16, scaled_i8x16);
}

/**
 *  @brief  Reduces the SIMD lanes to scalars, adds elements `idx..count-1` through the serial
 *          converter, and writes the unscaled sum and sum of squares.
 *
 *  @param[in]  data_ptr         First element of the whole input.
 *  @param[in]  idx              First element not yet covered by the SIMD lanes.
 *  @param[in]  count            Total number of elements.
 *  @param[in]  stride_elements  Distance between consecutive elements, in elements.
 */
NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_finish_(                                  //
    nk_e2m3_t const *data_ptr, nk_size_t idx, nk_size_t count, nk_size_t stride_elements, //
    int32x4_t sum_i32x4, int32x4_t sumsq_i32x4, nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    int64x2_t sum_i64x2 = vpaddlq_s32(sum_i32x4);
    nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
    // Squares are non-negative, so the lanes are widened as unsigned.
    uint64x2_t sumsq_u64x2 = vpaddlq_u32(vreinterpretq_u32_s32(sumsq_i32x4));
    nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
    for (; idx < count; ++idx) {
        nk_f32_t value;
        nk_e2m3_to_f32_serial(data_ptr + idx * stride_elements, &value);
        // Same fixed-point scale as the lookup table: 16 × value and 256 × value² are exact integers.
        sum += (nk_i64_t)(value * 16.0f), sumsq += (nk_u64_t)(nk_i64_t)(value * value * 256.0f);
    }
    *sum_ptr = (nk_f32_t)sum / 16.0f, *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
}

/**
 *  @brief  Sum and sum of squares of `count` contiguous E2M3 values, 16 per SIMD step.
 */
NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_contiguous_( //
    nk_e2m3_t const *data_ptr, nk_size_t count,               //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    uint8x16x2_t const lut_e2m3_x16 = nk_e2m3_magnitude_lut_neonsdot_();
    int32x4_t sum_i32x4 = vdupq_n_s32(0);
    int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
        nk_e2m3x16_accumulate_moments_neonsdot_(lut_e2m3_x16, raw_u8x16, &sum_i32x4, &sumsq_i32x4);
    }
    nk_reduce_moments_e2m3_neonsdot_finish_(data_ptr, idx, count, 1, sum_i32x4, sumsq_i32x4, sum_ptr, sumsq_ptr);
}

/**
 *  @brief  Sum and sum of squares of `count` E2M3 values spaced `stride_elements` apart.
 *          Strides 2, 3 and 4 use de-interleaving structure loads; any other stride is handled
 *          entirely by the serial tail.
 *
 *  A structure load for elements `idx..idx+15` reads 16 × stride_elements bytes, which is up to
 *  `stride_elements - 1` bytes past element `idx+15`. Requiring `idx + 16 < count` guarantees that
 *  element `idx+16` exists, so those trailing bytes lie before it and inside the caller's range.
 *  The final group is left to the exact serial tail, so the result is unchanged.
 */
NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_strided_(                 //
    nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    uint8x16x2_t const lut_e2m3_x16 = nk_e2m3_magnitude_lut_neonsdot_();
    int32x4_t sum_i32x4 = vdupq_n_s32(0);
    int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
    nk_size_t idx = 0;
    if (stride_elements == 2) {
        for (; idx + 16 < count; idx += 16) {
            uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
            nk_e2m3x16_accumulate_moments_neonsdot_(lut_e2m3_x16, loaded_u8x16x2.val[0], &sum_i32x4, &sumsq_i32x4);
        }
    }
    else if (stride_elements == 3) {
        for (; idx + 16 < count; idx += 16) {
            uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
            nk_e2m3x16_accumulate_moments_neonsdot_(lut_e2m3_x16, loaded_u8x16x3.val[0], &sum_i32x4, &sumsq_i32x4);
        }
    }
    else if (stride_elements == 4) {
        for (; idx + 16 < count; idx += 16) {
            uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
            nk_e2m3x16_accumulate_moments_neonsdot_(lut_e2m3_x16, loaded_u8x16x4.val[0], &sum_i32x4, &sumsq_i32x4);
        }
    }
    nk_reduce_moments_e2m3_neonsdot_finish_(data_ptr, idx, count, stride_elements, sum_i32x4, sumsq_i32x4, sum_ptr,
                                            sumsq_ptr);
}
323
+
324
+ NK_PUBLIC void nk_reduce_moments_e2m3_neonsdot( //
325
+ nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
326
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
327
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
328
+ int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
329
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
330
+ else if (!aligned) nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
331
+ else if (count > (nk_size_t)(NK_I16_MAX + 1) * 16) {
332
+ nk_size_t left_count = count / 2;
333
+ nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
334
+ nk_reduce_moments_e2m3_neonsdot(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
335
+ nk_reduce_moments_e2m3_neonsdot(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
336
+ &right_sum_value, &right_sumsq_value);
337
+ *sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
338
+ }
339
+ else if (stride_elements == 1) nk_reduce_moments_e2m3_neonsdot_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
340
+ else if (stride_elements <= 4)
341
+ nk_reduce_moments_e2m3_neonsdot_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
342
+ else nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
343
+ }
344
+
345
+ #if defined(__clang__)
346
+ #pragma clang attribute pop
347
+ #elif defined(__GNUC__)
348
+ #pragma GCC pop_options
349
+ #endif
350
+
351
+ #if defined(__cplusplus)
352
+ } // extern "C"
353
+ #endif
354
+
355
+ #endif // NK_TARGET_NEONSDOT
356
+ #endif // NK_TARGET_ARM_
357
+ #endif // NK_REDUCE_NEONSDOT_H