numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,549 @@
1
+ /**
2
+ * @brief AVX-512 VNNI implementations for the redesigned reduction API (moments).
3
+ * @file include/numkong/reduce/icelake.h
4
+ * @author Ash Vardanian
5
+ * @date February 12, 2026
6
+ *
7
+ * @sa include/numkong/reduce.h
8
+ *
9
+ * @section vnni_advantage VNNI Advantage
10
+ *
11
+ * `_mm512_dpwssd_epi32(acc, a, b)` (VPDPWSSD) fuses `acc + _mm512_madd_epi16(a, b)`
12
+ * into one instruction (5cy @ p0 on Ice Lake, 4cy @ p01 on Genoa), saving one
13
+ * `_mm512_add_epi32` per call vs the Skylake `madd + add` pair.
14
+ */
15
+ #ifndef NK_REDUCE_ICELAKE_H
16
+ #define NK_REDUCE_ICELAKE_H
17
+
18
+ #if NK_TARGET_X86_
19
+ #if NK_TARGET_ICELAKE
20
+
21
+ #include "numkong/reduce/serial.h"
22
+
23
+ #if defined(__cplusplus)
24
+ extern "C" {
25
+ #endif
26
+
27
+ #if defined(__clang__)
28
+ #pragma clang attribute push( \
29
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,avx512vbmi,f16c,fma,bmi,bmi2"))), \
30
+ apply_to = function)
31
+ #elif defined(__GNUC__)
32
+ #pragma GCC push_options
33
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "avx512vbmi", "f16c", "fma", \
34
+ "bmi", "bmi2")
35
+ #endif
36
+
37
+ NK_INTERNAL void nk_reduce_moments_i8_icelake_contiguous_( //
38
+ nk_i8_t const *data_ptr, nk_size_t count, //
39
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
40
+ __m512i bias_i8x64 = _mm512_set1_epi8((char)0x80);
41
+ __m512i zero_i8x64 = _mm512_setzero_si512();
42
+ __m512i sum_u64x8 = _mm512_setzero_si512();
43
+ __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
44
+ __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
45
+ nk_size_t idx = 0;
46
+ for (; idx + 64 <= count; idx += 64) {
47
+ __m512i data_i8x64 = _mm512_loadu_si512(data_ptr + idx);
48
+ __m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64, bias_i8x64);
49
+ sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
50
+ __m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
51
+ __m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
52
+ sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
53
+ sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
54
+ }
55
+ nk_size_t remaining = count - idx;
56
+ if (remaining > 0) {
57
+ __mmask64 tail_mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
58
+ __m512i data_i8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx);
59
+ __m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64, _mm512_maskz_mov_epi8(tail_mask, bias_i8x64));
60
+ sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
61
+ __m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
62
+ __m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
63
+ sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
64
+ sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
65
+ }
66
+ sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
67
+ __m512i sumsq_i64x8 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
68
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
69
+ *sum_ptr = (nk_i64_t)nk_reduce_add_u64x8_skylake_(sum_u64x8) - (nk_i64_t)128 * (nk_i64_t)count;
70
+ *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
71
+ }
72
+
73
+ NK_INTERNAL void nk_reduce_moments_i8_icelake_strided_( //
74
+ nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
75
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
76
+ __mmask64 stride_mask_m64 = nk_stride_mask_u1x64_(stride_elements);
77
+ __m512i masked_bias_i8x64 = _mm512_maskz_mov_epi8(stride_mask_m64, _mm512_set1_epi8((char)0x80));
78
+ __m512i zero_i8x64 = _mm512_setzero_si512();
79
+ __m512i sum_u64x8 = _mm512_setzero_si512();
80
+ __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
81
+ __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
82
+ nk_size_t idx_scalars = 0;
83
+ nk_size_t total_scalars = count * stride_elements;
84
+ nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m64) * stride_elements;
85
+ for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
86
+ __m512i data_i8x64 = _mm512_maskz_loadu_epi8(stride_mask_m64, data_ptr + idx_scalars);
87
+ __m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64, masked_bias_i8x64);
88
+ sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
89
+ __m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
90
+ __m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
91
+ sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
92
+ sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
93
+ }
94
+ nk_size_t remaining_scalars = total_scalars - idx_scalars;
95
+ if (remaining_scalars > 0) {
96
+ __mmask64 tail_mask = stride_mask_m64 & _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining_scalars);
97
+ __m512i data_i8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
98
+ __m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64,
99
+ _mm512_maskz_mov_epi8(tail_mask, _mm512_set1_epi8((char)0x80)));
100
+ sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
101
+ __m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
102
+ __m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
103
+ sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
104
+ sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
105
+ }
106
+ sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
107
+ __m512i sumsq_i64x8 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
108
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
109
+ *sum_ptr = (nk_i64_t)nk_reduce_add_u64x8_skylake_(sum_u64x8) - (nk_i64_t)128 * (nk_i64_t)count;
110
+ *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
111
+ }
112
+
113
+ NK_PUBLIC void nk_reduce_moments_i8_icelake( //
114
+ nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
115
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
116
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
117
+ int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
118
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
119
+ else if (!aligned) nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
120
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 64) {
121
+ nk_size_t left_count = count / 2;
122
+ nk_i64_t left_sum, right_sum;
123
+ nk_u64_t left_sumsq, right_sumsq;
124
+ nk_reduce_moments_i8_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
125
+ nk_reduce_moments_i8_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
126
+ &right_sum, &right_sumsq);
127
+ *sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
128
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
129
+ }
130
+ else if (stride_elements == 1) nk_reduce_moments_i8_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
131
+ else if (stride_elements <= 16)
132
+ nk_reduce_moments_i8_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
133
+ else nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
134
+ }
135
+
136
+ NK_INTERNAL void nk_reduce_moments_u8_icelake_contiguous_( //
137
+ nk_u8_t const *data_ptr, nk_size_t count, //
138
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
139
+ __m512i zero_u8x64 = _mm512_setzero_si512();
140
+ __m512i sum_u64x8 = _mm512_setzero_si512();
141
+ __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
142
+ __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
143
+ nk_size_t idx = 0;
144
+ for (; idx + 64 <= count; idx += 64) {
145
+ __m512i data_u8x64 = _mm512_loadu_si512(data_ptr + idx);
146
+ sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
147
+ __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
148
+ __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
149
+ sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
150
+ sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
151
+ }
152
+ nk_size_t remaining = count - idx;
153
+ if (remaining > 0) {
154
+ __mmask64 tail_mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
155
+ __m512i data_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx);
156
+ sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
157
+ __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
158
+ __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
159
+ sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
160
+ sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
161
+ }
162
+ sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
163
+ __m512i sumsq_u64x8 = _mm512_cvtepu32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
164
+ sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
165
+ *sum_ptr = nk_reduce_add_u64x8_skylake_(sum_u64x8);
166
+ *sumsq_ptr = nk_reduce_add_u64x8_skylake_(sumsq_u64x8);
167
+ }
168
+
169
+ NK_INTERNAL void nk_reduce_moments_u8_icelake_strided_( //
170
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
171
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
172
+ __mmask64 stride_mask_m64 = nk_stride_mask_u1x64_(stride_elements);
173
+ __m512i zero_u8x64 = _mm512_setzero_si512();
174
+ __m512i sum_u64x8 = _mm512_setzero_si512();
175
+ __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
176
+ __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
177
+ nk_size_t idx_scalars = 0;
178
+ nk_size_t total_scalars = count * stride_elements;
179
+ nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m64) * stride_elements;
180
+ for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
181
+ __m512i data_u8x64 = _mm512_maskz_loadu_epi8(stride_mask_m64, data_ptr + idx_scalars);
182
+ sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
183
+ __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
184
+ __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
185
+ sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
186
+ sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
187
+ }
188
+ nk_size_t remaining_scalars = total_scalars - idx_scalars;
189
+ if (remaining_scalars > 0) {
190
+ __mmask64 tail_mask = stride_mask_m64 & _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining_scalars);
191
+ __m512i data_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
192
+ sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
193
+ __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
194
+ __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
195
+ sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
196
+ sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
197
+ }
198
+ sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
199
+ __m512i sumsq_u64x8 = _mm512_cvtepu32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
200
+ sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
201
+ *sum_ptr = nk_reduce_add_u64x8_skylake_(sum_u64x8);
202
+ *sumsq_ptr = nk_reduce_add_u64x8_skylake_(sumsq_u64x8);
203
+ }
204
+
205
+ NK_PUBLIC void nk_reduce_moments_u8_icelake( //
206
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
207
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
208
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
209
+ int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
210
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
211
+ else if (!aligned) nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
212
+ else if (count > (nk_size_t)(NK_U8_MAX + 1) * 64) {
213
+ nk_size_t left_count = count / 2;
214
+ nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
215
+ nk_reduce_moments_u8_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
216
+ nk_reduce_moments_u8_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
217
+ &right_sum, &right_sumsq);
218
+ *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
219
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
220
+ }
221
+ else if (stride_elements == 1) nk_reduce_moments_u8_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
222
+ else if (stride_elements <= 16)
223
+ nk_reduce_moments_u8_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
224
+ else nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
225
+ }
226
+
227
+ NK_INTERNAL void nk_reduce_moments_i16_icelake_contiguous_( //
228
+ nk_i16_t const *data_ptr, nk_size_t count, //
229
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
230
+ // Sum: VPDPWSSD(acc, data, ones) accumulates in i32 — safe for (NK_I16_MAX+1)*32 elements.
231
+ // Sumsq: VPDPWSSD(zero, data, data) → fresh i32, widen to i64 each iteration.
232
+ __m512i ones_i16x32 = _mm512_set1_epi16(1);
233
+ __m512i sum_i32x16 = _mm512_setzero_si512();
234
+ __m512i sumsq_i64x8 = _mm512_setzero_si512();
235
+ nk_size_t idx = 0;
236
+ for (; idx + 32 <= count; idx += 32) {
237
+ __m512i data_i16x32 = _mm512_loadu_si512(data_ptr + idx);
238
+ sum_i32x16 = _mm512_dpwssd_epi32(sum_i32x16, data_i16x32, ones_i16x32);
239
+ __m512i sq_i32x16 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), data_i16x32, data_i16x32);
240
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
241
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
242
+ }
243
+ nk_size_t remaining = count - idx;
244
+ if (remaining > 0) {
245
+ __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
246
+ __m512i data_i16x32 = _mm512_maskz_loadu_epi16(tail_mask, data_ptr + idx);
247
+ sum_i32x16 = _mm512_dpwssd_epi32(sum_i32x16, data_i16x32, ones_i16x32);
248
+ __m512i sq_i32x16 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), data_i16x32, data_i16x32);
249
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
250
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
251
+ }
252
+ __m512i sum_i64x8 = _mm512_add_epi64( //
253
+ _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sum_i32x16)), //
254
+ _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sum_i32x16, 1))); //
255
+ *sum_ptr = nk_reduce_add_i64x8_skylake_(sum_i64x8);
256
+ *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
257
+ }
258
+
259
+ NK_INTERNAL void nk_reduce_moments_i16_icelake_strided_( //
260
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
261
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
262
+ __mmask32 stride_mask_m32 = nk_stride_mask_b16x32_(stride_elements);
263
+ __m512i ones_i16x32 = _mm512_set1_epi16(1);
264
+ __m512i sum_i32x16 = _mm512_setzero_si512();
265
+ __m512i sumsq_i64x8 = _mm512_setzero_si512();
266
+ nk_size_t idx_scalars = 0;
267
+ nk_size_t total_scalars = count * stride_elements;
268
+ nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m32) * stride_elements;
269
+ for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
270
+ __m512i data_i16x32 = _mm512_maskz_loadu_epi16(stride_mask_m32, data_ptr + idx_scalars);
271
+ sum_i32x16 = _mm512_dpwssd_epi32(sum_i32x16, data_i16x32, ones_i16x32);
272
+ __m512i sq_i32x16 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), data_i16x32, data_i16x32);
273
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
274
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
275
+ }
276
+ nk_size_t remaining_scalars = total_scalars - idx_scalars;
277
+ if (remaining_scalars > 0) {
278
+ __mmask32 tail_mask = stride_mask_m32 & (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)(remaining_scalars));
279
+ __m512i data_i16x32 = _mm512_maskz_loadu_epi16(tail_mask, data_ptr + idx_scalars);
280
+ sum_i32x16 = _mm512_dpwssd_epi32(sum_i32x16, data_i16x32, ones_i16x32);
281
+ __m512i sq_i32x16 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), data_i16x32, data_i16x32);
282
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
283
+ sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
284
+ }
285
+ __m512i sum_i64x8 = _mm512_add_epi64( //
286
+ _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sum_i32x16)), //
287
+ _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sum_i32x16, 1))); //
288
+ *sum_ptr = nk_reduce_add_i64x8_skylake_(sum_i64x8);
289
+ *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
290
+ }
291
+
292
+ NK_PUBLIC void nk_reduce_moments_i16_icelake( //
293
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
294
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
295
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
296
+ int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
297
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
298
+ else if (!aligned) nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
299
+ else if (count > (nk_size_t)(NK_I16_MAX + 1) * 32) {
300
+ nk_size_t left_count = count / 2;
301
+ nk_i64_t left_sum, right_sum;
302
+ nk_u64_t left_sumsq, right_sumsq;
303
+ nk_reduce_moments_i16_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
304
+ nk_reduce_moments_i16_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
305
+ &right_sum, &right_sumsq);
306
+ *sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
307
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
308
+ }
309
+ else if (stride_elements == 1) nk_reduce_moments_i16_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
310
+ else if (stride_elements <= 16)
311
+ nk_reduce_moments_i16_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
312
+ else nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
313
+ }
314
+
315
// Contiguous moments kernel for e2m3 (6-bit float, 1 sign + 5 magnitude bits per byte).
// Decodes |x|*16 through a byte LUT, applies the sign bit, and accumulates
// sum(x*16) and sum((x*16)^2) in i32 lanes; the final scalar divides undo the
// *16 scaling (16 for the sum, 256 for the sum of squares).
// Accumulator bound: the dispatcher caps count at (NK_I16_MAX + 1) * 64, so each
// i32 sumsq lane receives at most 4 * 120^2 = 57600 per iteration over <= 32768
// iterations (~1.89e9 < 2^31) - no overflow.
NK_INTERNAL void nk_reduce_moments_e2m3_icelake_contiguous_( //
    nk_e2m3_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    // 64-byte LUT: maps 5-bit unsigned magnitude -> value*16 as u8 (0..120)
    // Entries 0-31 replicated in upper 32 bytes (VPERMB indexes mod 64)
    __m512i const lut_magnitude_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0,
                                                        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
    __m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);
    __m512i const ones_u8x64 = _mm512_set1_epi8(1);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_i32x16 = _mm512_setzero_si512();
    nk_size_t idx = 0;
    // Main loop: 64 e2m3 values (one per byte) per iteration.
    for (; idx + 64 <= count; idx += 64) {
        __m512i data_u8x64 = _mm512_loadu_si512(data_ptr + idx);
        // Extract 5-bit magnitude, LUT lookup
        __m512i magnitude_u8x64 = _mm512_and_si512(data_u8x64, magnitude_mask_u8x64);
        __m512i unsigned_mag_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, lut_magnitude_u8x64);
        // Apply sign for sum: negate where bit 5 is set
        __mmask64 sign_mask = _mm512_test_epi8_mask(data_u8x64, sign_mask_u8x64);
        __m512i signed_mag_i8x64 = _mm512_mask_sub_epi8(unsigned_mag_u8x64, sign_mask, _mm512_setzero_si512(),
                                                        unsigned_mag_u8x64);
        // Sum: VPDPBUSD(acc, ones_u8, signed_i8) = acc + sum(1 * signed_val) per 4-byte group
        sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, ones_u8x64, signed_mag_i8x64);
        // Sumsq: VPDPBUSD(acc, unsigned_mag, unsigned_mag) = acc + sum(mag^2) per 4-byte group
        // magnitude is 0-120, fits in both u8 and i8 interpretations
        sumsq_i32x16 = _mm512_dpbusd_epi32(sumsq_i32x16, unsigned_mag_u8x64, unsigned_mag_u8x64);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Tail: masked load zero-fills the inactive bytes, which contribute 0 to both moments.
        __mmask64 tail_mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
        __m512i data_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx);
        __m512i magnitude_u8x64 = _mm512_and_si512(data_u8x64, magnitude_mask_u8x64);
        __m512i unsigned_mag_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, lut_magnitude_u8x64);
        __mmask64 sign_mask = _mm512_test_epi8_mask(data_u8x64, sign_mask_u8x64);
        __m512i signed_mag_i8x64 = _mm512_mask_sub_epi8(unsigned_mag_u8x64, sign_mask, _mm512_setzero_si512(),
                                                        unsigned_mag_u8x64);
        sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, ones_u8x64, signed_mag_i8x64);
        sumsq_i32x16 = _mm512_dpbusd_epi32(sumsq_i32x16, unsigned_mag_u8x64, unsigned_mag_u8x64);
    }
    // Undo the fixed-point scaling: values were x*16, squares (x*16)^2 = x^2 * 256.
    *sum_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 16.0f;
    *sumsq_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sumsq_i32x16) / 256.0f;
}
360
+
361
+ NK_INTERNAL void nk_reduce_moments_e2m3_icelake_strided_( //
362
+ nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
363
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
364
+ __mmask64 stride_mask_m64 = nk_stride_mask_u1x64_(stride_elements);
365
+ __m512i const lut_magnitude_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
366
+ 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0,
367
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
368
+ 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
369
+ __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
370
+ __m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);
371
+ __m512i const ones_u8x64 = _mm512_set1_epi8(1);
372
+ __m512i sum_i32x16 = _mm512_setzero_si512();
373
+ __m512i sumsq_i32x16 = _mm512_setzero_si512();
374
+ nk_size_t idx_scalars = 0;
375
+ nk_size_t total_scalars = count * stride_elements;
376
+ nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m64) * stride_elements;
377
+ for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
378
+ __m512i data_u8x64 = _mm512_maskz_loadu_epi8(stride_mask_m64, data_ptr + idx_scalars);
379
+ __m512i magnitude_u8x64 = _mm512_and_si512(data_u8x64, magnitude_mask_u8x64);
380
+ __m512i unsigned_mag_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, lut_magnitude_u8x64);
381
+ __mmask64 sign_mask = _mm512_test_epi8_mask(data_u8x64, sign_mask_u8x64);
382
+ __m512i signed_mag_i8x64 = _mm512_mask_sub_epi8(unsigned_mag_u8x64, sign_mask, _mm512_setzero_si512(),
383
+ unsigned_mag_u8x64);
384
+ sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, ones_u8x64, signed_mag_i8x64);
385
+ sumsq_i32x16 = _mm512_dpbusd_epi32(sumsq_i32x16, unsigned_mag_u8x64, unsigned_mag_u8x64);
386
+ }
387
+ nk_size_t remaining_scalars = total_scalars - idx_scalars;
388
+ if (remaining_scalars > 0) {
389
+ __mmask64 tail_mask = stride_mask_m64 & _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining_scalars);
390
+ __m512i data_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
391
+ __m512i magnitude_u8x64 = _mm512_and_si512(data_u8x64, magnitude_mask_u8x64);
392
+ __m512i unsigned_mag_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, lut_magnitude_u8x64);
393
+ __mmask64 sign_mask = _mm512_test_epi8_mask(data_u8x64, sign_mask_u8x64);
394
+ __m512i signed_mag_i8x64 = _mm512_mask_sub_epi8(unsigned_mag_u8x64, sign_mask, _mm512_setzero_si512(),
395
+ unsigned_mag_u8x64);
396
+ sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, ones_u8x64, signed_mag_i8x64);
397
+ sumsq_i32x16 = _mm512_dpbusd_epi32(sumsq_i32x16, unsigned_mag_u8x64, unsigned_mag_u8x64);
398
+ }
399
+ *sum_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 16.0f;
400
+ *sumsq_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sumsq_i32x16) / 256.0f;
401
+ }
402
+
403
+ NK_PUBLIC void nk_reduce_moments_e2m3_icelake( //
404
+ nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
405
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
406
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
407
+ int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
408
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
409
+ else if (!aligned) nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
410
+ else if (count > (nk_size_t)(NK_I16_MAX + 1) * 64) {
411
+ nk_size_t left_count = count / 2;
412
+ nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
413
+ nk_reduce_moments_e2m3_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
414
+ nk_reduce_moments_e2m3_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
415
+ &right_sum, &right_sumsq);
416
+ *sum_ptr = left_sum + right_sum;
417
+ *sumsq_ptr = left_sumsq + right_sumsq;
418
+ }
419
+ else if (stride_elements == 1) nk_reduce_moments_e2m3_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
420
+ else if (stride_elements <= 16)
421
+ nk_reduce_moments_e2m3_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
422
+ else nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
423
+ }
424
+
425
// Contiguous moments kernel for e3m2 (6-bit float, 1 sign + 5 magnitude bits per byte).
// Widens bytes to 16-bit lanes, decodes |x|*16 through a VPERMW LUT, applies the
// sign, and accumulates sum(x*16) and sum((x*16)^2) in i32 lanes via VPMADDWD;
// final scalar divides undo the *16 scaling (16 for the sum, 256 for sumsq).
// Accumulator bound: the dispatcher caps count at 2048 * 64 = 131072, so each
// i32 sumsq lane receives at most 2 * 448^2 = 401408 per iteration over <= 4096
// iterations (~1.65e9 < 2^31) - no overflow.
NK_INTERNAL void nk_reduce_moments_e3m2_icelake_contiguous_( //
    nk_e3m2_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    // 32-entry i16 LUT: maps 5-bit unsigned magnitude -> value*16 as i16 (0..448)
    __m512i const lut_magnitude_i16x32 = _mm512_set_epi16(448, 384, 320, 256, 224, 192, 160, 128, 112, 96, 80, 64, 56,
                                                          48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2,
                                                          1, 0);
    __m512i const magnitude_mask_i16x32 = _mm512_set1_epi16(0x1F);
    __m512i const sign_mask_i16x32 = _mm512_set1_epi16(0x20);
    __m512i const ones_i16x32 = _mm512_set1_epi16(1);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_i32x16 = _mm512_setzero_si512();
    nk_size_t idx = 0;
    // Main loop: 32 e3m2 values (one per byte) per iteration.
    for (; idx + 32 <= count; idx += 32) {
        // Load 32 bytes, widen u8->u16
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        __m512i data_u16x32 = _mm512_cvtepu8_epi16(data_u8x32);
        // Extract 5-bit magnitude, VPERMW LUT lookup
        __m512i magnitude_u16x32 = _mm512_and_si512(data_u16x32, magnitude_mask_i16x32);
        __m512i unsigned_mag_i16x32 = _mm512_permutexvar_epi16(magnitude_u16x32, lut_magnitude_i16x32);
        // Apply sign for sum: negate where bit 5 is set
        __mmask32 sign_mask = _mm512_test_epi16_mask(data_u16x32, sign_mask_i16x32);
        __m512i signed_mag_i16x32 = _mm512_mask_sub_epi16(unsigned_mag_i16x32, sign_mask, _mm512_setzero_si512(),
                                                          unsigned_mag_i16x32);
        // Sum: VPMADDWD(signed_i16, ones) = sum of pairs -> i32
        sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(signed_mag_i16x32, ones_i16x32));
        // Sumsq: VPMADDWD(unsigned_mag, unsigned_mag) = sum of pairs of squares -> i32
        // max per i32: 2 * 448^2 = 401408, fits in i32
        sumsq_i32x16 = _mm512_add_epi32(sumsq_i32x16, _mm512_madd_epi16(unsigned_mag_i16x32, unsigned_mag_i16x32));
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Tail: masked load zero-fills the inactive bytes, which contribute 0 to both moments.
        __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
        __m256i data_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, data_ptr + idx);
        __m512i data_u16x32 = _mm512_cvtepu8_epi16(data_u8x32);
        __m512i magnitude_u16x32 = _mm512_and_si512(data_u16x32, magnitude_mask_i16x32);
        __m512i unsigned_mag_i16x32 = _mm512_permutexvar_epi16(magnitude_u16x32, lut_magnitude_i16x32);
        __mmask32 sign_mask = _mm512_test_epi16_mask(data_u16x32, sign_mask_i16x32);
        __m512i signed_mag_i16x32 = _mm512_mask_sub_epi16(unsigned_mag_i16x32, sign_mask, _mm512_setzero_si512(),
                                                          unsigned_mag_i16x32);
        sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(signed_mag_i16x32, ones_i16x32));
        sumsq_i32x16 = _mm512_add_epi32(sumsq_i32x16, _mm512_madd_epi16(unsigned_mag_i16x32, unsigned_mag_i16x32));
    }
    // Undo the fixed-point scaling: values were x*16, squares (x*16)^2 = x^2 * 256.
    *sum_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 16.0f;
    *sumsq_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sumsq_i32x16) / 256.0f;
}
471
+
472
+ NK_INTERNAL void nk_reduce_moments_e3m2_icelake_strided_( //
473
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
474
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
475
+ __mmask32 stride_mask_m32 = (__mmask32)nk_stride_mask_u1x64_(stride_elements);
476
+ __m512i const lut_magnitude_i16x32 = _mm512_set_epi16(448, 384, 320, 256, 224, 192, 160, 128, 112, 96, 80, 64, 56,
477
+ 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2,
478
+ 1, 0);
479
+ __m512i const magnitude_mask_i16x32 = _mm512_set1_epi16(0x1F);
480
+ __m512i const sign_mask_i16x32 = _mm512_set1_epi16(0x20);
481
+ __m512i const ones_i16x32 = _mm512_set1_epi16(1);
482
+ __m512i sum_i32x16 = _mm512_setzero_si512();
483
+ __m512i sumsq_i32x16 = _mm512_setzero_si512();
484
+ nk_size_t idx_scalars = 0;
485
+ nk_size_t total_scalars = count * stride_elements;
486
+ nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m32) * stride_elements;
487
+ for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
488
+ __m256i data_u8x32 = _mm256_maskz_loadu_epi8(stride_mask_m32, data_ptr + idx_scalars);
489
+ __m512i data_u16x32 = _mm512_cvtepu8_epi16(data_u8x32);
490
+ __m512i magnitude_u16x32 = _mm512_and_si512(data_u16x32, magnitude_mask_i16x32);
491
+ __m512i unsigned_mag_i16x32 = _mm512_permutexvar_epi16(magnitude_u16x32, lut_magnitude_i16x32);
492
+ __mmask32 sign_mask = _mm512_test_epi16_mask(data_u16x32, sign_mask_i16x32);
493
+ __m512i signed_mag_i16x32 = _mm512_mask_sub_epi16(unsigned_mag_i16x32, sign_mask, _mm512_setzero_si512(),
494
+ unsigned_mag_i16x32);
495
+ sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(signed_mag_i16x32, ones_i16x32));
496
+ sumsq_i32x16 = _mm512_add_epi32(sumsq_i32x16, _mm512_madd_epi16(unsigned_mag_i16x32, unsigned_mag_i16x32));
497
+ }
498
+ nk_size_t remaining_scalars = total_scalars - idx_scalars;
499
+ if (remaining_scalars > 0) {
500
+ __mmask32 tail_mask = stride_mask_m32 & (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining_scalars);
501
+ __m256i data_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
502
+ __m512i data_u16x32 = _mm512_cvtepu8_epi16(data_u8x32);
503
+ __m512i magnitude_u16x32 = _mm512_and_si512(data_u16x32, magnitude_mask_i16x32);
504
+ __m512i unsigned_mag_i16x32 = _mm512_permutexvar_epi16(magnitude_u16x32, lut_magnitude_i16x32);
505
+ __mmask32 sign_mask = _mm512_test_epi16_mask(data_u16x32, sign_mask_i16x32);
506
+ __m512i signed_mag_i16x32 = _mm512_mask_sub_epi16(unsigned_mag_i16x32, sign_mask, _mm512_setzero_si512(),
507
+ unsigned_mag_i16x32);
508
+ sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(signed_mag_i16x32, ones_i16x32));
509
+ sumsq_i32x16 = _mm512_add_epi32(sumsq_i32x16, _mm512_madd_epi16(unsigned_mag_i16x32, unsigned_mag_i16x32));
510
+ }
511
+ *sum_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 16.0f;
512
+ *sumsq_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sumsq_i32x16) / 256.0f;
513
+ }
514
+
515
+ NK_PUBLIC void nk_reduce_moments_e3m2_icelake( //
516
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
517
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
518
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
519
+ int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
520
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
521
+ else if (!aligned) nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
522
+ else if (count > (nk_size_t)2048 * 64) {
523
+ nk_size_t left_count = count / 2;
524
+ nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
525
+ nk_reduce_moments_e3m2_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
526
+ nk_reduce_moments_e3m2_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
527
+ &right_sum, &right_sumsq);
528
+ *sum_ptr = left_sum + right_sum;
529
+ *sumsq_ptr = left_sumsq + right_sumsq;
530
+ }
531
+ else if (stride_elements == 1) nk_reduce_moments_e3m2_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
532
+ else if (stride_elements <= 16)
533
+ nk_reduce_moments_e3m2_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
534
+ else nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
535
+ }
536
+
537
+ #if defined(__clang__)
538
+ #pragma clang attribute pop
539
+ #elif defined(__GNUC__)
540
+ #pragma GCC pop_options
541
+ #endif
542
+
543
+ #if defined(__cplusplus)
544
+ } // extern "C"
545
+ #endif
546
+
547
+ #endif // NK_TARGET_ICELAKE
548
+ #endif // NK_TARGET_X86_
549
+ #endif // NK_REDUCE_ICELAKE_H