numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,338 @@
1
+ /**
2
+ * @brief Sierra Forest (AVX-VNNI-INT8) implementations for the redesigned reduction API (moments).
3
+ * @file include/numkong/reduce/sierra.h
4
+ * @author Ash Vardanian
5
+ * @date February 13, 2026
6
+ *
7
+ * @sa include/numkong/reduce.h
8
+ *
9
+ * Uses AVX-VNNI-INT8 (256-bit) for efficient widening dot-products on i8, u8, and e2m3:
10
+ * - `_mm256_dpbssd_epi32`: i8 x i8 -> i32 signed dot product (AVXVNNIINT8)
11
+ * - `_mm256_dpbuud_epi32`: u8 x u8 -> u32 unsigned dot product (AVXVNNIINT8)
12
+ */
13
+ #ifndef NK_REDUCE_SIERRA_H
14
+ #define NK_REDUCE_SIERRA_H
15
+
16
+ #if NK_TARGET_X86_
17
+ #if NK_TARGET_SIERRA
18
+
19
+ #include "numkong/types.h"
20
+ #include "numkong/reduce/serial.h"
21
+ #include "numkong/reduce/haswell.h" // `nk_reduce_add_i32x8_haswell_`
22
+
23
+ #if defined(__cplusplus)
24
+ extern "C" {
25
+ #endif
26
+
27
+ #if defined(__clang__)
28
+ #pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2,avxvnni,avxvnniint8"))), apply_to = function)
29
+ #elif defined(__GNUC__)
30
+ #pragma GCC push_options
31
+ #pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2", "avxvnni", "avxvnniint8")
32
+ #endif
33
+
34
+ NK_INTERNAL void nk_reduce_moments_i8_sierra_contiguous_( //
35
+ nk_i8_t const *data, nk_size_t count, //
36
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
37
+ __m256i ones_i8x32 = _mm256_set1_epi8(1);
38
+ __m256i sum_i32x8 = _mm256_setzero_si256();
39
+ __m256i sumsq_i32x8 = _mm256_setzero_si256();
40
+ nk_size_t idx = 0;
41
+ for (; idx + 32 <= count; idx += 32) {
42
+ __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data + idx));
43
+ sum_i32x8 = _mm256_dpbssd_epi32(sum_i32x8, data_i8x32, ones_i8x32);
44
+ sumsq_i32x8 = _mm256_dpbssd_epi32(sumsq_i32x8, data_i8x32, data_i8x32);
45
+ }
46
+ nk_i64_t sum = (nk_i64_t)nk_reduce_add_i32x8_haswell_(sum_i32x8);
47
+ nk_u64_t sumsq = (nk_u64_t)(nk_u32_t)nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
48
+ nk_size_t remaining = count - idx;
49
+ if (remaining > 0) {
50
+ nk_b256_vec_t tail_vec;
51
+ nk_partial_load_b8x32_serial_(data + idx, &tail_vec, remaining);
52
+ __m256i data_i8x32 = tail_vec.ymm;
53
+ __m256i tail_sum_i32x8 = _mm256_dpbssd_epi32(_mm256_setzero_si256(), data_i8x32, ones_i8x32);
54
+ __m256i tail_sumsq_i32x8 = _mm256_dpbssd_epi32(_mm256_setzero_si256(), data_i8x32, data_i8x32);
55
+ sum += (nk_i64_t)nk_reduce_add_i32x8_haswell_(tail_sum_i32x8);
56
+ sumsq += (nk_u64_t)(nk_u32_t)nk_reduce_add_i32x8_haswell_(tail_sumsq_i32x8);
57
+ }
58
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
59
+ }
60
+
61
+ NK_INTERNAL void nk_reduce_moments_i8_sierra_strided_( //
62
+ nk_i8_t const *data, nk_size_t count, nk_size_t stride_elements, //
63
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
64
+ __m256i stride_mask_i8x32 = nk_stride_blend_u1x32_(stride_elements);
65
+ __m256i ones_i8x32 = _mm256_set1_epi8(1);
66
+ __m256i sum_i32x8 = _mm256_setzero_si256();
67
+ __m256i sumsq_i32x8 = _mm256_setzero_si256();
68
+ nk_size_t idx_scalars = 0;
69
+ nk_size_t total_scalars = count * stride_elements;
70
+ nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
71
+ for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
72
+ __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
73
+ data_i8x32 = _mm256_and_si256(data_i8x32, stride_mask_i8x32);
74
+ sum_i32x8 = _mm256_dpbssd_epi32(sum_i32x8, data_i8x32, ones_i8x32);
75
+ sumsq_i32x8 = _mm256_dpbssd_epi32(sumsq_i32x8, data_i8x32, data_i8x32);
76
+ }
77
+ nk_i64_t sum = (nk_i64_t)nk_reduce_add_i32x8_haswell_(sum_i32x8);
78
+ nk_u64_t sumsq = (nk_u64_t)(nk_u32_t)nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
79
+ nk_i8_t const *ptr = data + idx_scalars;
80
+ nk_size_t remaining = count - idx_scalars / stride_elements;
81
+ for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
82
+ nk_i64_t val = (nk_i64_t)*ptr;
83
+ sum += val, sumsq += (nk_u64_t)(val * val);
84
+ }
85
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
86
+ }
87
+
88
+ NK_PUBLIC void nk_reduce_moments_i8_sierra( //
89
+ nk_i8_t const *data, nk_size_t count, nk_size_t stride_bytes, //
90
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
91
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
92
+ int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
93
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
94
+ else if (!aligned) nk_reduce_moments_i8_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
95
+ else if (count > (nk_size_t)32768 * 32) {
96
+ nk_size_t left_count = count / 2;
97
+ nk_i64_t left_sum, right_sum;
98
+ nk_u64_t left_sumsq, right_sumsq;
99
+ nk_reduce_moments_i8_sierra(data, left_count, stride_bytes, &left_sum, &left_sumsq);
100
+ nk_reduce_moments_i8_sierra(data + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
101
+ &right_sumsq);
102
+ *sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
103
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
104
+ }
105
+ else if (stride_elements == 1) nk_reduce_moments_i8_sierra_contiguous_(data, count, sum_ptr, sumsq_ptr);
106
+ else if (stride_elements <= 8)
107
+ nk_reduce_moments_i8_sierra_strided_(data, count, stride_elements, sum_ptr, sumsq_ptr);
108
+ else nk_reduce_moments_i8_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
109
+ }
110
+
111
+ /**
112
+ * @section u8 moments via VPDPBUUD (unsigned u8 x u8 -> u32)
113
+ *
114
+ * Sierra's `_mm256_dpbuud_epi32` provides native u8×u8→u32 dot product, replacing
115
+ * Haswell's 8-instruction SAD+widen+MADD sequence with 3 instructions per 32 elements.
116
+ * - sum: dot(data, ones) via DPBUUD — each group of 4 bytes sums into a u32 lane
117
+ * - sumsq: dot(data, data) via DPBUUD — native u8×u8 squaring and accumulation
118
+ */
119
+ NK_INTERNAL void nk_reduce_moments_u8_sierra_contiguous_( //
120
+ nk_u8_t const *data, nk_size_t count, //
121
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
122
+ __m256i ones_u8x32 = _mm256_set1_epi8(1);
123
+ __m256i sum_i32x8 = _mm256_setzero_si256();
124
+ __m256i sumsq_i32x8 = _mm256_setzero_si256();
125
+ nk_size_t idx = 0;
126
+ for (; idx + 32 <= count; idx += 32) {
127
+ __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx));
128
+ sum_i32x8 = _mm256_dpbuud_epi32(sum_i32x8, data_u8x32, ones_u8x32);
129
+ sumsq_i32x8 = _mm256_dpbuud_epi32(sumsq_i32x8, data_u8x32, data_u8x32);
130
+ }
131
+ nk_u64_t sum = (nk_u64_t)(nk_u32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8);
132
+ nk_u64_t sumsq = (nk_u64_t)(nk_u32_t)nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
133
+ nk_size_t remaining = count - idx;
134
+ if (remaining > 0) {
135
+ nk_b256_vec_t tail_vec;
136
+ nk_partial_load_b8x32_serial_(data + idx, &tail_vec, remaining);
137
+ __m256i data_u8x32 = tail_vec.ymm;
138
+ __m256i tail_sum_i32x8 = _mm256_dpbuud_epi32(_mm256_setzero_si256(), data_u8x32, ones_u8x32);
139
+ __m256i tail_sumsq_i32x8 = _mm256_dpbuud_epi32(_mm256_setzero_si256(), data_u8x32, data_u8x32);
140
+ sum += (nk_u64_t)(nk_u32_t)nk_reduce_add_i32x8_haswell_(tail_sum_i32x8);
141
+ sumsq += (nk_u64_t)(nk_u32_t)nk_reduce_add_i32x8_haswell_(tail_sumsq_i32x8);
142
+ }
143
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
144
+ }
145
+
146
+ NK_INTERNAL void nk_reduce_moments_u8_sierra_strided_( //
147
+ nk_u8_t const *data, nk_size_t count, nk_size_t stride_elements, //
148
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
149
+ __m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
150
+ __m256i ones_u8x32 = _mm256_set1_epi8(1);
151
+ __m256i sum_i32x8 = _mm256_setzero_si256();
152
+ __m256i sumsq_i32x8 = _mm256_setzero_si256();
153
+ nk_size_t idx_scalars = 0;
154
+ nk_size_t total_scalars = count * stride_elements;
155
+ nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
156
+ for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
157
+ __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
158
+ data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
159
+ sum_i32x8 = _mm256_dpbuud_epi32(sum_i32x8, data_u8x32, ones_u8x32);
160
+ sumsq_i32x8 = _mm256_dpbuud_epi32(sumsq_i32x8, data_u8x32, data_u8x32);
161
+ }
162
+ nk_u64_t sum = (nk_u64_t)(nk_u32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8);
163
+ nk_u64_t sumsq = (nk_u64_t)(nk_u32_t)nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
164
+ nk_u8_t const *ptr = data + idx_scalars;
165
+ nk_size_t remaining = count - idx_scalars / stride_elements;
166
+ for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
167
+ nk_u64_t val = (nk_u64_t)*ptr;
168
+ sum += val, sumsq += val * val;
169
+ }
170
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
171
+ }
172
+
173
+ NK_PUBLIC void nk_reduce_moments_u8_sierra( //
174
+ nk_u8_t const *data, nk_size_t count, nk_size_t stride_bytes, //
175
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
176
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
177
+ int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
178
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
179
+ else if (!aligned) nk_reduce_moments_u8_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
180
+ else if (count > (nk_size_t)16384 * 32) {
181
+ nk_size_t left_count = count / 2;
182
+ nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
183
+ nk_reduce_moments_u8_sierra(data, left_count, stride_bytes, &left_sum, &left_sumsq);
184
+ nk_reduce_moments_u8_sierra(data + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
185
+ &right_sumsq);
186
+ *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
187
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
188
+ }
189
+ else if (stride_elements == 1) nk_reduce_moments_u8_sierra_contiguous_(data, count, sum_ptr, sumsq_ptr);
190
+ else if (stride_elements <= 8)
191
+ nk_reduce_moments_u8_sierra_strided_(data, count, stride_elements, sum_ptr, sumsq_ptr);
192
+ else nk_reduce_moments_u8_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
193
+ }
194
+
195
+ /**
196
+ * @section e2m3 moments via integer VNNI (dpbssd)
197
+ *
198
+ * Every e2m3 value × 16 is an exact integer in [-120, +120] (i8 range).
199
+ * We use a dual-VPSHUFB LUT to map 5-bit magnitude → unsigned i8, apply the sign,
200
+ * then accumulate with `_mm256_dpbssd_epi32` (signed i8 × signed i8 → i32).
201
+ * Final: sum = i32_sum / 16, sumsq = i32_sumsq / 256.
202
+ */
203
+ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_contiguous_( //
204
+ nk_e2m3_t const *data, nk_size_t count, //
205
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
206
+ __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
207
+ 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
208
+ __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
209
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
210
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
211
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
212
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
213
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
214
+ __m256i const ones_i8x32 = _mm256_set1_epi8(1);
215
+ __m256i sum_i32x8 = _mm256_setzero_si256();
216
+ __m256i sumsq_i32x8 = _mm256_setzero_si256();
217
+ nk_size_t idx = 0;
218
+ for (; idx + 32 <= count; idx += 32) {
219
+ __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx));
220
+ __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
221
+ __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
222
+ __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
223
+ half_select_u8x32);
224
+ __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, shuffle_idx_u8x32),
225
+ _mm256_shuffle_epi8(lut_upper_u8x32, shuffle_idx_u8x32),
226
+ upper_select_u8x32);
227
+ __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
228
+ __m256i signed_i8x32 = _mm256_blendv_epi8(
229
+ unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), unsigned_u8x32), negate_mask_u8x32);
230
+ sum_i32x8 = _mm256_dpbssd_epi32(sum_i32x8, signed_i8x32, ones_i8x32);
231
+ sumsq_i32x8 = _mm256_dpbssd_epi32(sumsq_i32x8, signed_i8x32, signed_i8x32);
232
+ }
233
+ nk_i32_t sum = nk_reduce_add_i32x8_haswell_(sum_i32x8);
234
+ nk_i32_t sumsq = nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
235
+ nk_size_t remaining = count - idx;
236
+ if (remaining > 0) {
237
+ nk_b256_vec_t tail_vec;
238
+ nk_partial_load_b8x32_serial_(data + idx, &tail_vec, remaining);
239
+ __m256i data_u8x32 = tail_vec.ymm;
240
+ __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
241
+ __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
242
+ __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
243
+ half_select_u8x32);
244
+ __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, shuffle_idx_u8x32),
245
+ _mm256_shuffle_epi8(lut_upper_u8x32, shuffle_idx_u8x32),
246
+ upper_select_u8x32);
247
+ __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
248
+ __m256i signed_i8x32 = _mm256_blendv_epi8(
249
+ unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), unsigned_u8x32), negate_mask_u8x32);
250
+ sum += nk_reduce_add_i32x8_haswell_(_mm256_dpbssd_epi32(_mm256_setzero_si256(), signed_i8x32, ones_i8x32));
251
+ sumsq += nk_reduce_add_i32x8_haswell_(_mm256_dpbssd_epi32(_mm256_setzero_si256(), signed_i8x32, signed_i8x32));
252
+ }
253
+ *sum_ptr = (nk_f32_t)sum / 16.0f;
254
+ *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
255
+ }
256
+
257
+ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_strided_( //
258
+ nk_e2m3_t const *data, nk_size_t count, nk_size_t stride_elements, //
259
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
260
+ __m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
261
+ __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
262
+ 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
263
+ __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
264
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
265
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
266
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
267
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
268
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
269
+ __m256i const ones_i8x32 = _mm256_set1_epi8(1);
270
+ __m256i sum_i32x8 = _mm256_setzero_si256();
271
+ __m256i sumsq_i32x8 = _mm256_setzero_si256();
272
+ nk_size_t idx_scalars = 0;
273
+ nk_size_t total_scalars = count * stride_elements;
274
+ nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
275
+ for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
276
+ __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
277
+ data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
278
+ __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
279
+ __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
280
+ __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
281
+ half_select_u8x32);
282
+ __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, shuffle_idx_u8x32),
283
+ _mm256_shuffle_epi8(lut_upper_u8x32, shuffle_idx_u8x32),
284
+ upper_select_u8x32);
285
+ __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
286
+ __m256i signed_i8x32 = _mm256_blendv_epi8(
287
+ unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), unsigned_u8x32), negate_mask_u8x32);
288
+ sum_i32x8 = _mm256_dpbssd_epi32(sum_i32x8, signed_i8x32, ones_i8x32);
289
+ sumsq_i32x8 = _mm256_dpbssd_epi32(sumsq_i32x8, signed_i8x32, signed_i8x32);
290
+ }
291
+ nk_i32_t sum = nk_reduce_add_i32x8_haswell_(sum_i32x8);
292
+ nk_i32_t sumsq = nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
293
+ nk_e2m3_t const *ptr = data + idx_scalars;
294
+ nk_size_t remaining = count - idx_scalars / stride_elements;
295
+ for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
296
+ nk_f32_t val;
297
+ nk_e2m3_to_f32_serial(ptr, &val);
298
+ nk_i32_t ival = (nk_i32_t)(val * 16.0f);
299
+ sum += ival;
300
+ sumsq += ival * ival;
301
+ }
302
+ *sum_ptr = (nk_f32_t)sum / 16.0f;
303
+ *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
304
+ }
305
+
306
+ NK_PUBLIC void nk_reduce_moments_e2m3_sierra( //
307
+ nk_e2m3_t const *data, nk_size_t count, nk_size_t stride_bytes, //
308
+ nk_f32_t *sum, nk_f32_t *sumsq) {
309
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
310
+ int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
311
+ if (count == 0) *sum = 0, *sumsq = 0;
312
+ else if (!aligned) nk_reduce_moments_e2m3_serial(data, count, stride_bytes, sum, sumsq);
313
+ else if (count > (nk_size_t)(NK_I16_MAX + 1) * 32) {
314
+ nk_size_t left_count = count / 2;
315
+ nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
316
+ nk_reduce_moments_e2m3_sierra(data, left_count, stride_bytes, &left_sum, &left_sumsq);
317
+ nk_reduce_moments_e2m3_sierra(data + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
318
+ &right_sumsq);
319
+ *sum = left_sum + right_sum, *sumsq = left_sumsq + right_sumsq;
320
+ }
321
+ else if (stride_elements == 1) nk_reduce_moments_e2m3_sierra_contiguous_(data, count, sum, sumsq);
322
+ else if (stride_elements <= 8) nk_reduce_moments_e2m3_sierra_strided_(data, count, stride_elements, sum, sumsq);
323
+ else nk_reduce_moments_e2m3_serial(data, count, stride_bytes, sum, sumsq);
324
+ }
325
+
326
+ #if defined(__clang__)
327
+ #pragma clang attribute pop
328
+ #elif defined(__GNUC__)
329
+ #pragma GCC pop_options
330
+ #endif
331
+
332
+ #if defined(__cplusplus)
333
+ } // extern "C"
334
+ #endif
335
+
336
+ #endif // NK_TARGET_SIERRA
337
+ #endif // NK_TARGET_X86_
338
+ #endif // NK_REDUCE_SIERRA_H