numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
package/include/numkong/reduce/alder.h

@@ -0,0 +1,632 @@
/**
 * @brief Alder Lake (AVX2+VNNI) implementations for the redesigned reduction API (moments).
 * @file include/numkong/reduce/alder.h
 * @author Ash Vardanian
 * @date March 4, 2026
 *
 * @sa include/numkong/reduce.h
 *
 * Uses AVX-VNNI (256-bit) for efficient widening dot-products:
 * - `_mm256_dpwssd_epi32`: i16 x i16 -> i32 accumulation (AVXVNNI, used for i16 and e3m2)
 * - `_mm256_sad_epu8` + `_mm256_madd_epi16`: pure AVX2 SAD/MADD (used for u8)
 * - `_mm256_cvtepu16_epi32` + `_mm256_mullo_epi32`: pure AVX2 (used for u16)
 */
#ifndef NK_REDUCE_ALDER_H
#define NK_REDUCE_ALDER_H

#if NK_TARGET_X86_
#if NK_TARGET_ALDER

#include "numkong/types.h"
#include "numkong/dot/alder.h"      // VEX compat macros (_mm256_dpwssd_avx_epi32)
#include "numkong/reduce/serial.h"
#include "numkong/reduce/haswell.h" // `nk_reduce_add_i32x8_haswell_`

#if defined(__cplusplus)
extern "C" {
#endif

#if defined(__clang__)
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2,avxvnni"))), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2", "avxvnni")
#endif

NK_INTERNAL void nk_reduce_moments_u8_alder_contiguous_( //
    nk_u8_t const *data, nk_size_t count,                //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __m256i zero_u8x32 = _mm256_setzero_si256();
    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_low_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_high_i32x8 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx));
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(data_u8x32, zero_u8x32));
        __m256i low_u16x16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(data_u8x32));
        __m256i high_u16x16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(data_u8x32, 1));
        sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, _mm256_madd_epi16(low_u16x16, low_u16x16));
        sumsq_high_i32x8 = _mm256_add_epi32(sumsq_high_i32x8, _mm256_madd_epi16(high_u16x16, high_u16x16));
    }

    // Handle tail with partial load
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b8x32_serial_(data + idx, &tail_vec, remaining);
        __m256i data_u8x32 = tail_vec.ymm;
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(data_u8x32, zero_u8x32));
        __m256i low_u16x16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(data_u8x32));
        __m256i high_u16x16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(data_u8x32, 1));
        sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, _mm256_madd_epi16(low_u16x16, low_u16x16));
        sumsq_high_i32x8 = _mm256_add_epi32(sumsq_high_i32x8, _mm256_madd_epi16(high_u16x16, high_u16x16));
    }

    sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, sumsq_high_i32x8);
    __m256i sumsq_u64x4 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sumsq_low_i32x8));
    sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sumsq_low_i32x8, 1)));
    *sum_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);
}

NK_INTERNAL void nk_reduce_moments_u8_alder_strided_(                //
    nk_u8_t const *data, nk_size_t count, nk_size_t stride_elements, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
    __m256i zero_u8x32 = _mm256_setzero_si256();
    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_low_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_high_i32x8 = _mm256_setzero_si256();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
        data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(data_u8x32, zero_u8x32));
        __m256i low_u16x16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(data_u8x32));
        __m256i high_u16x16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(data_u8x32, 1));
        sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, _mm256_madd_epi16(low_u16x16, low_u16x16));
        sumsq_high_i32x8 = _mm256_add_epi32(sumsq_high_i32x8, _mm256_madd_epi16(high_u16x16, high_u16x16));
    }

    sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, sumsq_high_i32x8);
    __m256i sumsq_u64x4 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sumsq_low_i32x8));
    sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sumsq_low_i32x8, 1)));
    nk_u64_t sum = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);

    // Scalar tail for remaining elements
    nk_u8_t const *ptr = data + idx_scalars;
    nk_size_t remaining_elements = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining_elements; ++i, ptr += stride_elements) {
        nk_u64_t val = (nk_u64_t)*ptr;
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_u8_alder(                        //
    nk_u8_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
    int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_u8_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)16384 * 32) {
        nk_size_t left_count = count / 2;
        nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_u8_alder(data, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_u8_alder(data + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
                                   &right_sumsq);
        *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
        *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
    }
    else if (stride_elements == 1) nk_reduce_moments_u8_alder_contiguous_(data, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 8)
        nk_reduce_moments_u8_alder_strided_(data, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_u8_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
}
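For reference, the two AVX2 idioms the u8 kernels above combine are `_mm256_sad_epu8` against a zero vector (four u64 partial sums of 8 bytes each) and `_mm256_madd_epi16` on zero-extended bytes (pairwise-summed squares in i32 lanes). A minimal standalone sketch, not part of the package, assuming a count that is a multiple of 32 and small enough that the i32 square accumulators cannot overflow (the real kernel above widens to u64 and splits large inputs for exactly that reason):

#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

static void moments_u8_avx2_demo(uint8_t const *data, size_t count, // count % 32 == 0, modest size
                                 uint64_t *sum, uint64_t *sumsq) {
    __m256i zero = _mm256_setzero_si256();
    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_i32x8 = _mm256_setzero_si256();
    for (size_t i = 0; i + 32 <= count; i += 32) {
        __m256i v = _mm256_loadu_si256((__m256i const *)(data + i));
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(v, zero));       // 4 x u64 byte sums
        __m256i lo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(v));            // bytes 0..15 -> u16
        __m256i hi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v, 1));       // bytes 16..31 -> u16
        sumsq_i32x8 = _mm256_add_epi32(sumsq_i32x8, _mm256_madd_epi16(lo, lo));  // pairwise x*x sums
        sumsq_i32x8 = _mm256_add_epi32(sumsq_i32x8, _mm256_madd_epi16(hi, hi));
    }
    uint64_t s = 0, q = 0;
    uint64_t lanes64[4];
    int32_t lanes32[8];
    _mm256_storeu_si256((__m256i *)lanes64, sum_u64x4);
    _mm256_storeu_si256((__m256i *)lanes32, sumsq_i32x8);
    for (int k = 0; k < 4; ++k) s += lanes64[k];
    for (int k = 0; k < 8; ++k) q += (uint64_t)(uint32_t)lanes32[k];
    *sum = s, *sumsq = q;
}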
NK_INTERNAL void nk_reduce_moments_i16_alder_contiguous_( //
    nk_i16_t const *data, nk_size_t count,                //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __m256i ones_i16x16 = _mm256_set1_epi16(1);
    __m256i sum_i64x4 = _mm256_setzero_si256();
    __m256i sumsq_i64x4 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m256i data_i16x16 = _mm256_loadu_si256((__m256i const *)(data + idx));
        __m256i sum_i32x8 = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), data_i16x16, ones_i16x16);
        sum_i64x4 = _mm256_add_epi64(sum_i64x4, _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum_i32x8)));
        sum_i64x4 = _mm256_add_epi64(sum_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum_i32x8, 1)));
        __m256i sq_i32x8 = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), data_i16x16, data_i16x16);
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sq_i32x8)));
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sq_i32x8, 1)));
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b16x16_serial_(data + idx, &tail_vec, remaining);
        __m256i data_i16x16 = tail_vec.ymm;
        __m256i sum_i32x8 = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), data_i16x16, ones_i16x16);
        sum_i64x4 = _mm256_add_epi64(sum_i64x4, _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum_i32x8)));
        sum_i64x4 = _mm256_add_epi64(sum_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum_i32x8, 1)));
        __m256i sq_i32x8 = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), data_i16x16, data_i16x16);
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sq_i32x8)));
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sq_i32x8, 1)));
    }
    *sum_ptr = (nk_i64_t)nk_reduce_add_i64x4_haswell_(sum_i64x4);
    *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_i64x4);
}

NK_INTERNAL void nk_reduce_moments_i16_alder_strided_(                //
    nk_i16_t const *data, nk_size_t count, nk_size_t stride_elements, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __m256i stride_mask_i16x16 = nk_stride_blend_b16x16_(stride_elements);
    __m256i ones_i16x16 = _mm256_set1_epi16(1);
    __m256i sum_i64x4 = _mm256_setzero_si256();
    __m256i sumsq_i64x4 = _mm256_setzero_si256();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = nk_size_round_up_to_multiple_(16, stride_elements);
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_i16x16 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
        data_i16x16 = _mm256_and_si256(data_i16x16, stride_mask_i16x16);
        __m256i sum_i32x8 = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), data_i16x16, ones_i16x16);
        sum_i64x4 = _mm256_add_epi64(sum_i64x4, _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum_i32x8)));
        sum_i64x4 = _mm256_add_epi64(sum_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum_i32x8, 1)));
        __m256i sq_i32x8 = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), data_i16x16, data_i16x16);
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sq_i32x8)));
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sq_i32x8, 1)));
    }
    nk_i64_t sum = (nk_i64_t)nk_reduce_add_i64x4_haswell_(sum_i64x4);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_i64x4);
    nk_i16_t const *ptr = data + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_i64_t val = (nk_i64_t)*ptr;
        sum += val, sumsq += (nk_u64_t)(val * val);
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_i16_alder(                        //
    nk_i16_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
    int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_i16_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
        nk_size_t left_count = count / 2;
        nk_i64_t left_sum, right_sum;
        nk_u64_t left_sumsq, right_sumsq;
        nk_reduce_moments_i16_alder(data, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_i16_alder(data + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
                                    &right_sumsq);
        *sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
        *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
    }
    else if (stride_elements == 1) nk_reduce_moments_i16_alder_contiguous_(data, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 8)
        nk_reduce_moments_i16_alder_strided_(data, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_i16_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
}
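The i16 kernels above accumulate both the sum and the sum of squares through dot-products against a vector of ones. The same idiom exists in plain AVX2 without VNNI; a minimal standalone sketch, not part of the package, using `_mm256_madd_epi16` (which sums adjacent i16 products into i32 lanes, the operation `_mm256_dpwssd_avx_epi32` fuses with accumulation):

#include <immintrin.h>
#include <stdint.h>

static int32_t sum16_i16_avx2_demo(int16_t const values[16]) {
    __m256i data = _mm256_loadu_si256((__m256i const *)values);
    __m256i ones = _mm256_set1_epi16(1);
    __m256i pairs_i32x8 = _mm256_madd_epi16(data, ones); // (values[2k] + values[2k+1]) per i32 lane
    int32_t lanes[8];
    _mm256_storeu_si256((__m256i *)lanes, pairs_i32x8);
    int32_t total = 0;
    for (int k = 0; k < 8; ++k) total += lanes[k]; // horizontal reduction of the 8 partial sums
    return total;
}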
NK_INTERNAL void nk_reduce_moments_u16_alder_contiguous_( //
    nk_u16_t const *data, nk_size_t count,                //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __m256i sum_u32x8 = _mm256_setzero_si256();
    __m256i sumsq_u64x4 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m256i data_u32x8 = _mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const *)(data + idx)));
        sum_u32x8 = _mm256_add_epi32(sum_u32x8, data_u32x8);
        __m256i sq_u32x8 = _mm256_mullo_epi32(data_u32x8, data_u32x8);
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sq_u32x8)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sq_u32x8, 1)));
    }

    // Handle tail with partial load
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b16x16_serial_(data + idx, &tail_vec, remaining);
        __m256i data_u32x8 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(tail_vec.ymm));
        sum_u32x8 = _mm256_add_epi32(sum_u32x8, data_u32x8);
        __m256i sq_u32x8 = _mm256_mullo_epi32(data_u32x8, data_u32x8);
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sq_u32x8)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sq_u32x8, 1)));
    }

    __m256i sum_u64x4 = _mm256_add_epi64(                            //
        _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum_u32x8)),    //
        _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum_u32x8, 1)));
    *sum_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);
}

NK_INTERNAL void nk_reduce_moments_u16_alder_strided_(                //
    nk_u16_t const *data, nk_size_t count, nk_size_t stride_elements, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __m256i stride_mask_i16x16 = nk_stride_blend_b16x16_(stride_elements);
    __m256i sum_u32x8 = _mm256_setzero_si256();
    __m256i sumsq_u64x4 = _mm256_setzero_si256();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = nk_size_round_up_to_multiple_(16, stride_elements);
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_u16x16 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
        data_u16x16 = _mm256_and_si256(data_u16x16, stride_mask_i16x16);
        __m256i low_u32x8 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(data_u16x16));
        __m256i high_u32x8 = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(data_u16x16, 1));
        sum_u32x8 = _mm256_add_epi32(sum_u32x8, low_u32x8);
        sum_u32x8 = _mm256_add_epi32(sum_u32x8, high_u32x8);
        __m256i low_sq_u32x8 = _mm256_mullo_epi32(low_u32x8, low_u32x8);
        __m256i high_sq_u32x8 = _mm256_mullo_epi32(high_u32x8, high_u32x8);
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(low_sq_u32x8)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(low_sq_u32x8, 1)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(high_sq_u32x8)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(high_sq_u32x8, 1)));
    }

    __m256i sum_u64x4 = _mm256_add_epi64(                            //
        _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum_u32x8)),    //
        _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum_u32x8, 1)));
    nk_u64_t sum = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);

    // Scalar tail for remaining elements
    nk_u16_t const *ptr = data + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_u64_t val = (nk_u64_t)*ptr;
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_u16_alder(                        //
    nk_u16_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
    int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_u16_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
        nk_size_t left_count = count / 2;
        nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_u16_alder(data, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_u16_alder(data + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
                                    &right_sumsq);
        *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
        *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
    }
    else if (stride_elements == 1) nk_reduce_moments_u16_alder_contiguous_(data, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 8)
        nk_reduce_moments_u16_alder_strided_(data, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_u16_serial(data, count, stride_bytes, sum_ptr, sumsq_ptr);
}
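The u16 path above has no dot-product shortcut: it zero-extends to 32-bit lanes before squaring, because a u16 square (at most 65535 * 65535) still fits exactly in a 32-bit lane. A minimal standalone sketch of that widen-then-square step, not part of the package:

#include <immintrin.h>
#include <stdint.h>

static void square_u16x8_demo(uint16_t const values[8], uint32_t squares[8]) {
    __m256i wide_u32x8 = _mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const *)values)); // u16 -> u32
    __m256i sq_u32x8 = _mm256_mullo_epi32(wide_u32x8, wide_u32x8);                        // exact squares
    _mm256_storeu_si256((__m256i *)squares, sq_u32x8);
}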
/**
 * @section e3m2 moments via integer VNNI (dpwssd)
 *
 * Every e3m2 value × 16 is an exact integer in [-448, +448] (i16 range).
 * We use dual-VPSHUFB for the low byte + threshold compare for the high byte,
 * then UNPACKLO/HI to form unsigned i16, apply sign via `_mm256_sign_epi16`,
 * and accumulate with `_mm256_dpwssd_epi32` (signed i16 x signed i16 -> i32).
 * Final: sum = i32_sum / 16, sumsq = i32_sumsq / 256.
 */
NK_INTERNAL void nk_reduce_moments_e3m2_alder_contiguous_( //
    nk_e3m2_t const *data, nk_size_t count,                //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m256i const lut_lo_lower_u8x32 = _mm256_set_epi8(         //
        28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0,  //
        28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0); //
    __m256i const lut_lo_upper_u8x32 = _mm256_set_epi8(                                                           //
        (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32,  //
        (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32); //
    __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
    __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
    __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
    __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
    __m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
    __m256i const ones_u8x32 = _mm256_set1_epi8(1);
    __m256i const ones_i16x16 = _mm256_set1_epi16(1);
    __m256i sum_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_i32x8 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx));
        __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
        __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
        __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
                                                       half_select_u8x32);
        __m256i lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, shuffle_idx_u8x32),
                                                    _mm256_shuffle_epi8(lut_lo_upper_u8x32, shuffle_idx_u8x32),
                                                    upper_select_u8x32);
        __m256i hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
        __m256i unsigned_lo_i16x16 = _mm256_unpacklo_epi8(lo_bytes_u8x32, hi_bytes_u8x32);
        __m256i unsigned_hi_i16x16 = _mm256_unpackhi_epi8(lo_bytes_u8x32, hi_bytes_u8x32);
        // Sign handling: extract sign bit, widen to i16, create +1/-1, apply via VPSIGNW
        __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
        __m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
        __m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
        __m256i signed_lo_i16x16 = _mm256_sign_epi16(unsigned_lo_i16x16,
                                                     _mm256_or_si256(negate_lo_i16x16, ones_i16x16));
        __m256i signed_hi_i16x16 = _mm256_sign_epi16(unsigned_hi_i16x16,
                                                     _mm256_or_si256(negate_hi_i16x16, ones_i16x16));
        // VNNI accumulation: dpwssd (signed i16 x signed i16 -> i32)
        sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8, signed_lo_i16x16, ones_i16x16);
        sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8, signed_hi_i16x16, ones_i16x16);
        sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8, signed_lo_i16x16, signed_lo_i16x16);
        sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8, signed_hi_i16x16, signed_hi_i16x16);
    }
    nk_i32_t sum = nk_reduce_add_i32x8_haswell_(sum_i32x8);
    nk_i32_t sumsq = nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b8x32_serial_(data + idx, &tail_vec, remaining);
        __m256i data_u8x32 = tail_vec.ymm;
        __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
        __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
        __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
                                                       half_select_u8x32);
        __m256i lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, shuffle_idx_u8x32),
                                                    _mm256_shuffle_epi8(lut_lo_upper_u8x32, shuffle_idx_u8x32),
                                                    upper_select_u8x32);
        __m256i hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
        __m256i unsigned_lo_i16x16 = _mm256_unpacklo_epi8(lo_bytes_u8x32, hi_bytes_u8x32);
        __m256i unsigned_hi_i16x16 = _mm256_unpackhi_epi8(lo_bytes_u8x32, hi_bytes_u8x32);
        __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
        __m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
        __m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
        __m256i signed_lo_i16x16 = _mm256_sign_epi16(unsigned_lo_i16x16,
                                                     _mm256_or_si256(negate_lo_i16x16, ones_i16x16));
        __m256i signed_hi_i16x16 = _mm256_sign_epi16(unsigned_hi_i16x16,
                                                     _mm256_or_si256(negate_hi_i16x16, ones_i16x16));
        __m256i tail_sum = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), signed_lo_i16x16, ones_i16x16);
        tail_sum = _mm256_dpwssd_avx_epi32(tail_sum, signed_hi_i16x16, ones_i16x16);
        __m256i tail_sumsq = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), signed_lo_i16x16, signed_lo_i16x16);
        tail_sumsq = _mm256_dpwssd_avx_epi32(tail_sumsq, signed_hi_i16x16, signed_hi_i16x16);
        sum += nk_reduce_add_i32x8_haswell_(tail_sum);
        sumsq += nk_reduce_add_i32x8_haswell_(tail_sumsq);
    }
    *sum_ptr = (nk_f32_t)sum / 16.0f;
    *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
}

NK_INTERNAL void nk_reduce_moments_e3m2_alder_strided_(                //
    nk_e3m2_t const *data, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
    __m256i const lut_lo_lower_u8x32 = _mm256_set_epi8(         //
        28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0,  //
        28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0); //
    __m256i const lut_lo_upper_u8x32 = _mm256_set_epi8(                                                           //
        (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32,  //
        (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32); //
    __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
    __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
    __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
    __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
    __m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
    __m256i const ones_u8x32 = _mm256_set1_epi8(1);
    __m256i const ones_i16x16 = _mm256_set1_epi16(1);
    __m256i sum_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_i32x8 = _mm256_setzero_si256();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
        data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
        __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
        __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
        __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
                                                       half_select_u8x32);
        __m256i lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, shuffle_idx_u8x32),
                                                    _mm256_shuffle_epi8(lut_lo_upper_u8x32, shuffle_idx_u8x32),
                                                    upper_select_u8x32);
        __m256i hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
        __m256i unsigned_lo_i16x16 = _mm256_unpacklo_epi8(lo_bytes_u8x32, hi_bytes_u8x32);
        __m256i unsigned_hi_i16x16 = _mm256_unpackhi_epi8(lo_bytes_u8x32, hi_bytes_u8x32);
        __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
        __m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
        __m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
        __m256i signed_lo_i16x16 = _mm256_sign_epi16(unsigned_lo_i16x16,
                                                     _mm256_or_si256(negate_lo_i16x16, ones_i16x16));
        __m256i signed_hi_i16x16 = _mm256_sign_epi16(unsigned_hi_i16x16,
                                                     _mm256_or_si256(negate_hi_i16x16, ones_i16x16));
        sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8, signed_lo_i16x16, ones_i16x16);
        sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8, signed_hi_i16x16, ones_i16x16);
        sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8, signed_lo_i16x16, signed_lo_i16x16);
        sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8, signed_hi_i16x16, signed_hi_i16x16);
    }
    nk_i32_t sum = nk_reduce_add_i32x8_haswell_(sum_i32x8);
    nk_i32_t sumsq = nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
    nk_e3m2_t const *ptr = data + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_f32_t val;
        nk_e3m2_to_f32_serial(ptr, &val);
        nk_i32_t ival = (nk_i32_t)(val * 16.0f);
        sum += ival;
        sumsq += ival * ival;
    }
    *sum_ptr = (nk_f32_t)sum / 16.0f;
    *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
}

NK_PUBLIC void nk_reduce_moments_e3m2_alder(                        //
    nk_e3m2_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum, nk_f32_t *sumsq) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
    int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
    if (count == 0) *sum = 0, *sumsq = 0;
    else if (!aligned) nk_reduce_moments_e3m2_serial(data, count, stride_bytes, sum, sumsq);
    else if (count > (nk_size_t)2048 * 32) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_e3m2_alder(data, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_e3m2_alder(data + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
                                     &right_sumsq);
        *sum = left_sum + right_sum, *sumsq = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_e3m2_alder_contiguous_(data, count, sum, sumsq);
    else if (stride_elements <= 8) nk_reduce_moments_e3m2_alder_strided_(data, count, stride_elements, sum, sumsq);
    else nk_reduce_moments_e3m2_serial(data, count, stride_bytes, sum, sumsq);
}
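The fixed-point bookkeeping behind the e3m2 kernels above is easiest to see in scalar form: scaling every value by 16 turns each one into an exact integer, so sums accumulate exactly in integer registers and a single rescale at the end recovers the floating-point moments. A minimal standalone sketch, not part of the package, assuming the inputs have already been decoded to f32 (the vector code decodes the packed encoding with LUTs instead):

#include <stddef.h>
#include <stdint.h>

static void moments_scaled_by_16_demo(float const *values, size_t count, float *sum, float *sumsq) {
    int64_t sum_scaled = 0, sumsq_scaled = 0;
    for (size_t i = 0; i < count; ++i) {
        int32_t scaled = (int32_t)(values[i] * 16.0f); // exact for e3m2-representable inputs
        sum_scaled += scaled;
        sumsq_scaled += (int64_t)scaled * scaled;
    }
    *sum = (float)sum_scaled / 16.0f;      // undo one factor of 16
    *sumsq = (float)sumsq_scaled / 256.0f; // squares carry 16 * 16
}

// Worked example: {1.25f, -3.5f, 0.75f} scale to {20, -56, 12}; sum_scaled = -24 -> -1.5,
// sumsq_scaled = 400 + 3136 + 144 = 3680 -> 14.375, matching 1.25^2 + 3.5^2 + 0.75^2.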
/**
 * @section e2m3 moments via integer VNNI (dpbusd)
 *
 * Every e2m3 value × 16 is an exact integer in [-120, +120] (i8 range).
 * We use a dual-VPSHUFB LUT to map 5-bit magnitude → unsigned i8, create a
 * sign vector (+1/-1), then accumulate with `_mm256_dpbusd_avx_epi32` (u8 × i8 → i32).
 * For sumsq, magnitude ≤ 120 < 128 so it's safe as both u8 and i8.
 * Final: sum = i32_sum / 16, sumsq = i32_sumsq / 256.
 */
NK_INTERNAL void nk_reduce_moments_e2m3_alder_contiguous_( //
    nk_e2m3_t const *data, nk_size_t count,                //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
                                                    30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
                                                    120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
    __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
    __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
    __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
    __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
    __m256i const ones_i8x32 = _mm256_set1_epi8(1);
    __m256i const neg_ones_i8x32 = _mm256_set1_epi8(-1);
    __m256i sum_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_i32x8 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx));
        __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
        __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
        __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
                                                       half_select_u8x32);
        __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, shuffle_idx_u8x32),
                                                    _mm256_shuffle_epi8(lut_upper_u8x32, shuffle_idx_u8x32),
                                                    upper_select_u8x32);
        // Sign vector: +1 for positive, -1 for negative (i8)
        __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
        __m256i sign_i8x32 = _mm256_blendv_epi8(ones_i8x32, neg_ones_i8x32, negate_mask_u8x32);
        // DPBUSD: unsigned_magnitude (u8) × sign (+1/-1 i8) → signed sum
        sum_i32x8 = _mm256_dpbusd_avx_epi32(sum_i32x8, unsigned_u8x32, sign_i8x32);
        // DPBUSD: unsigned_magnitude (u8) × unsigned_magnitude (as i8, safe since ≤120) → sumsq
        sumsq_i32x8 = _mm256_dpbusd_avx_epi32(sumsq_i32x8, unsigned_u8x32, unsigned_u8x32);
    }
    nk_i32_t sum = nk_reduce_add_i32x8_haswell_(sum_i32x8);
    nk_i32_t sumsq = nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b8x32_serial_(data + idx, &tail_vec, remaining);
        __m256i data_u8x32 = tail_vec.ymm;
        __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
        __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
        __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
                                                       half_select_u8x32);
        __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, shuffle_idx_u8x32),
                                                    _mm256_shuffle_epi8(lut_upper_u8x32, shuffle_idx_u8x32),
                                                    upper_select_u8x32);
        __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
        __m256i sign_i8x32 = _mm256_blendv_epi8(ones_i8x32, neg_ones_i8x32, negate_mask_u8x32);
        sum += nk_reduce_add_i32x8_haswell_(
            _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), unsigned_u8x32, sign_i8x32));
        sumsq += nk_reduce_add_i32x8_haswell_(
            _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), unsigned_u8x32, unsigned_u8x32));
    }
    *sum_ptr = (nk_f32_t)sum / 16.0f;
    *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
}

NK_INTERNAL void nk_reduce_moments_e2m3_alder_strided_(                //
    nk_e2m3_t const *data, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
    __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
                                                    30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
                                                    120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
    __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
    __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
    __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
    __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
    __m256i const ones_i8x32 = _mm256_set1_epi8(1);
    __m256i const neg_ones_i8x32 = _mm256_set1_epi8(-1);
    __m256i sum_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_i32x8 = _mm256_setzero_si256();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
        data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
        __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
        __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
        __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
                                                       half_select_u8x32);
        __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, shuffle_idx_u8x32),
                                                    _mm256_shuffle_epi8(lut_upper_u8x32, shuffle_idx_u8x32),
                                                    upper_select_u8x32);
        __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
        __m256i sign_i8x32 = _mm256_blendv_epi8(ones_i8x32, neg_ones_i8x32, negate_mask_u8x32);
        sum_i32x8 = _mm256_dpbusd_avx_epi32(sum_i32x8, unsigned_u8x32, sign_i8x32);
        sumsq_i32x8 = _mm256_dpbusd_avx_epi32(sumsq_i32x8, unsigned_u8x32, unsigned_u8x32);
    }
    nk_i32_t sum = nk_reduce_add_i32x8_haswell_(sum_i32x8);
    nk_i32_t sumsq = nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
    nk_e2m3_t const *ptr = data + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_f32_t val;
        nk_e2m3_to_f32_serial(ptr, &val);
        nk_i32_t ival = (nk_i32_t)(val * 16.0f);
        sum += ival;
        sumsq += ival * ival;
    }
    *sum_ptr = (nk_f32_t)sum / 16.0f;
    *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
}

NK_PUBLIC void nk_reduce_moments_e2m3_alder(                        //
    nk_e2m3_t const *data, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum, nk_f32_t *sumsq) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
    int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
    if (count == 0) *sum = 0, *sumsq = 0;
    else if (!aligned) nk_reduce_moments_e2m3_serial(data, count, stride_bytes, sum, sumsq);
    else if (count > (nk_size_t)(NK_I16_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_e2m3_alder(data, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_e2m3_alder(data + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
                                     &right_sumsq);
        *sum = left_sum + right_sum, *sumsq = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_e2m3_alder_contiguous_(data, count, sum, sumsq);
    else if (stride_elements <= 8) nk_reduce_moments_e2m3_alder_strided_(data, count, stride_elements, sum, sumsq);
    else nk_reduce_moments_e2m3_serial(data, count, stride_bytes, sum, sumsq);
}
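The e2m3 kernels above rely on the asymmetric operand types of VPDPBUSD: the first source is read as unsigned bytes and the second as signed bytes, so passing +1/-1 magnitudes in the signed slot applies the sign to the unsigned magnitudes without a separate negation step. A scalar model of one 32-bit lane, for illustration only (products of u8 and i8 fit in 16 bits, so summing them in 32-bit arithmetic matches the instruction's non-saturating accumulation):

#include <stdint.h>

static int32_t dpbusd_lane_model(int32_t acc, uint8_t const a[4], int8_t const b[4]) {
    for (int k = 0; k < 4; ++k) acc += (int32_t)a[k] * (int32_t)b[k]; // u8 * i8 products, summed into i32
    return acc;
}

// Example: magnitudes {120, 8, 0, 40} with signs {+1, -1, +1, -1} accumulate 120 - 8 + 0 - 40 = 72.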
#if defined(__clang__)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif

#if defined(__cplusplus)
} // extern "C"
#endif

#endif // NK_TARGET_ALDER
#endif // NK_TARGET_X86_
#endif // NK_REDUCE_ALDER_H
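Every public dispatcher in this header returns the raw sum and sum of squares, which is all a caller needs to finish mean and variance in a single pass over the data. A hypothetical usage sketch, assuming the header above is included and count is non-zero (the buffer and wrapper function names are illustrative, not part of the package):

#include <numkong/reduce.h>

void mean_and_variance_u8(nk_u8_t const *pixels, nk_size_t count, double *mean, double *variance) {
    nk_u64_t sum = 0, sumsq = 0;
    // Contiguous layout: the stride between consecutive elements is one byte.
    nk_reduce_moments_u8_alder(pixels, count, sizeof(nk_u8_t), &sum, &sumsq);
    *mean = (double)sum / (double)count;
    *variance = (double)sumsq / (double)count - (*mean) * (*mean); // E[x^2] - (E[x])^2
}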