numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,549 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief AVX-512 VNNI implementations for the redesigned reduction API (moments).
|
|
3
|
+
* @file include/numkong/reduce/icelake.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 12, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/reduce.h
|
|
8
|
+
*
|
|
9
|
+
* @section vnni_advantage VNNI Advantage
|
|
10
|
+
*
|
|
11
|
+
* `_mm512_dpwssd_epi32(acc, a, b)` (VPDPWSSD) fuses `acc + _mm512_madd_epi16(a, b)`
|
|
12
|
+
* into one instruction (5cy @ p0 on Ice Lake, 4cy @ p01 on Genoa), saving one
|
|
13
|
+
* `_mm512_add_epi32` per call vs the Skylake `madd + add` pair.
|
|
14
|
+
*/
|
|
15
|
+
#ifndef NK_REDUCE_ICELAKE_H
|
|
16
|
+
#define NK_REDUCE_ICELAKE_H
|
|
17
|
+
|
|
18
|
+
#if NK_TARGET_X86_
|
|
19
|
+
#if NK_TARGET_ICELAKE
|
|
20
|
+
|
|
21
|
+
#include "numkong/reduce/serial.h"
|
|
22
|
+
|
|
23
|
+
#if defined(__cplusplus)
|
|
24
|
+
extern "C" {
|
|
25
|
+
#endif
|
|
26
|
+
|
|
27
|
+
#if defined(__clang__)
|
|
28
|
+
#pragma clang attribute push( \
|
|
29
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,avx512vbmi,f16c,fma,bmi,bmi2"))), \
|
|
30
|
+
apply_to = function)
|
|
31
|
+
#elif defined(__GNUC__)
|
|
32
|
+
#pragma GCC push_options
|
|
33
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "avx512vbmi", "f16c", "fma", \
|
|
34
|
+
"bmi", "bmi2")
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
NK_INTERNAL void nk_reduce_moments_i8_icelake_contiguous_( //
|
|
38
|
+
nk_i8_t const *data_ptr, nk_size_t count, //
|
|
39
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
40
|
+
__m512i bias_i8x64 = _mm512_set1_epi8((char)0x80);
|
|
41
|
+
__m512i zero_i8x64 = _mm512_setzero_si512();
|
|
42
|
+
__m512i sum_u64x8 = _mm512_setzero_si512();
|
|
43
|
+
__m512i sumsq_low_i32x16 = _mm512_setzero_si512();
|
|
44
|
+
__m512i sumsq_high_i32x16 = _mm512_setzero_si512();
|
|
45
|
+
nk_size_t idx = 0;
|
|
46
|
+
for (; idx + 64 <= count; idx += 64) {
|
|
47
|
+
__m512i data_i8x64 = _mm512_loadu_si512(data_ptr + idx);
|
|
48
|
+
__m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64, bias_i8x64);
|
|
49
|
+
sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
|
|
50
|
+
__m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
|
|
51
|
+
__m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
|
|
52
|
+
sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
|
|
53
|
+
sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
|
|
54
|
+
}
|
|
55
|
+
nk_size_t remaining = count - idx;
|
|
56
|
+
if (remaining > 0) {
|
|
57
|
+
__mmask64 tail_mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
|
|
58
|
+
__m512i data_i8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx);
|
|
59
|
+
__m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64, _mm512_maskz_mov_epi8(tail_mask, bias_i8x64));
|
|
60
|
+
sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
|
|
61
|
+
__m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
|
|
62
|
+
__m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
|
|
63
|
+
sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
|
|
64
|
+
sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
|
|
65
|
+
}
|
|
66
|
+
sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
|
|
67
|
+
__m512i sumsq_i64x8 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
|
|
68
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
|
|
69
|
+
*sum_ptr = (nk_i64_t)nk_reduce_add_u64x8_skylake_(sum_u64x8) - (nk_i64_t)128 * (nk_i64_t)count;
|
|
70
|
+
*sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
NK_INTERNAL void nk_reduce_moments_i8_icelake_strided_( //
|
|
74
|
+
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
75
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
76
|
+
__mmask64 stride_mask_m64 = nk_stride_mask_u1x64_(stride_elements);
|
|
77
|
+
__m512i masked_bias_i8x64 = _mm512_maskz_mov_epi8(stride_mask_m64, _mm512_set1_epi8((char)0x80));
|
|
78
|
+
__m512i zero_i8x64 = _mm512_setzero_si512();
|
|
79
|
+
__m512i sum_u64x8 = _mm512_setzero_si512();
|
|
80
|
+
__m512i sumsq_low_i32x16 = _mm512_setzero_si512();
|
|
81
|
+
__m512i sumsq_high_i32x16 = _mm512_setzero_si512();
|
|
82
|
+
nk_size_t idx_scalars = 0;
|
|
83
|
+
nk_size_t total_scalars = count * stride_elements;
|
|
84
|
+
nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m64) * stride_elements;
|
|
85
|
+
for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
|
|
86
|
+
__m512i data_i8x64 = _mm512_maskz_loadu_epi8(stride_mask_m64, data_ptr + idx_scalars);
|
|
87
|
+
__m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64, masked_bias_i8x64);
|
|
88
|
+
sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
|
|
89
|
+
__m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
|
|
90
|
+
__m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
|
|
91
|
+
sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
|
|
92
|
+
sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
|
|
93
|
+
}
|
|
94
|
+
nk_size_t remaining_scalars = total_scalars - idx_scalars;
|
|
95
|
+
if (remaining_scalars > 0) {
|
|
96
|
+
__mmask64 tail_mask = stride_mask_m64 & _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining_scalars);
|
|
97
|
+
__m512i data_i8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
|
|
98
|
+
__m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64,
|
|
99
|
+
_mm512_maskz_mov_epi8(tail_mask, _mm512_set1_epi8((char)0x80)));
|
|
100
|
+
sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
|
|
101
|
+
__m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
|
|
102
|
+
__m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
|
|
103
|
+
sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
|
|
104
|
+
sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
|
|
105
|
+
}
|
|
106
|
+
sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
|
|
107
|
+
__m512i sumsq_i64x8 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
|
|
108
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
|
|
109
|
+
*sum_ptr = (nk_i64_t)nk_reduce_add_u64x8_skylake_(sum_u64x8) - (nk_i64_t)128 * (nk_i64_t)count;
|
|
110
|
+
*sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
NK_PUBLIC void nk_reduce_moments_i8_icelake( //
|
|
114
|
+
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
115
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
116
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
|
|
117
|
+
int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
|
|
118
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
119
|
+
else if (!aligned) nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
120
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 64) {
|
|
121
|
+
nk_size_t left_count = count / 2;
|
|
122
|
+
nk_i64_t left_sum, right_sum;
|
|
123
|
+
nk_u64_t left_sumsq, right_sumsq;
|
|
124
|
+
nk_reduce_moments_i8_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
125
|
+
nk_reduce_moments_i8_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
126
|
+
&right_sum, &right_sumsq);
|
|
127
|
+
*sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
|
|
128
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
|
|
129
|
+
}
|
|
130
|
+
else if (stride_elements == 1) nk_reduce_moments_i8_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
131
|
+
else if (stride_elements <= 16)
|
|
132
|
+
nk_reduce_moments_i8_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
133
|
+
else nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
NK_INTERNAL void nk_reduce_moments_u8_icelake_contiguous_( //
    nk_u8_t const *data_ptr, nk_size_t count,              //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum and sum-of-squares over `count` contiguous u8 values.
    // Sum: VPSADBW against zero horizontally adds 8 bytes per u64 lane — no bias
    // needed since the data is already unsigned.
    // Sumsq: zero-extend each 32-byte half to i16 and square via VPDPWSSD
    // (values <= 255 so products fit comfortably in the signed i16xi16 domain).
    __m512i zero_u8x64 = _mm512_setzero_si512();
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
    nk_size_t idx = 0;
    for (; idx + 64 <= count; idx += 64) {
        __m512i data_u8x64 = _mm512_loadu_si512(data_ptr + idx);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
        __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
        __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
        sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
        sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Masked tail load: inactive lanes are zero and contribute nothing to either moment.
        __mmask64 tail_mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
        __m512i data_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
        __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
        __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
        sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
        sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
    }
    // The u8 dispatcher caps this kernel at (NK_U8_MAX+1)*64 elements, i.e. 256
    // iterations, so each i32 lane stays far below 2^31 and the 32-bit combine
    // plus unsigned widening below cannot overflow.
    sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
    __m512i sumsq_u64x8 = _mm512_cvtepu32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
    sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
    *sum_ptr = nk_reduce_add_u64x8_skylake_(sum_u64x8);
    *sumsq_ptr = nk_reduce_add_u64x8_skylake_(sumsq_u64x8);
}
|
|
168
|
+
|
|
169
|
+
NK_INTERNAL void nk_reduce_moments_u8_icelake_strided_(                  //
    nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Strided u8 moments: loads full 64-byte windows with a lane mask selecting
    // elements at multiples of `stride_elements` (mask built by the shared
    // nk_stride_mask_u1x64_ helper — presumably one bit per stride-th byte).
    __mmask64 stride_mask_m64 = nk_stride_mask_u1x64_(stride_elements);
    __m512i zero_u8x64 = _mm512_setzero_si512();
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    // Scalars consumed per window: active lane count times the stride.
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m64) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m512i data_u8x64 = _mm512_maskz_loadu_epi8(stride_mask_m64, data_ptr + idx_scalars);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
        __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
        __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
        sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
        sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
    }
    nk_size_t remaining_scalars = total_scalars - idx_scalars;
    if (remaining_scalars > 0) {
        // Tail: intersect the stride mask with a prefix mask over the leftover scalars.
        __mmask64 tail_mask = stride_mask_m64 & _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining_scalars);
        __m512i data_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
        __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
        __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
        sumsq_low_i32x16 = _mm512_dpwssd_epi32(sumsq_low_i32x16, low_i16x32, low_i16x32);
        sumsq_high_i32x16 = _mm512_dpwssd_epi32(sumsq_high_i32x16, high_i16x32, high_i16x32);
    }
    // Dispatcher caps the element count, so i32 lanes stay well below 2^31;
    // combine and zero-extend into u64 lanes, then reduce horizontally.
    sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
    __m512i sumsq_u64x8 = _mm512_cvtepu32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
    sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
    *sum_ptr = nk_reduce_add_u64x8_skylake_(sum_u64x8);
    *sumsq_ptr = nk_reduce_add_u64x8_skylake_(sumsq_u64x8);
}
|
|
204
|
+
|
|
205
|
+
NK_PUBLIC void nk_reduce_moments_u8_icelake( //
|
|
206
|
+
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
207
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
208
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
|
|
209
|
+
int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
|
|
210
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
211
|
+
else if (!aligned) nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
212
|
+
else if (count > (nk_size_t)(NK_U8_MAX + 1) * 64) {
|
|
213
|
+
nk_size_t left_count = count / 2;
|
|
214
|
+
nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
215
|
+
nk_reduce_moments_u8_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
216
|
+
nk_reduce_moments_u8_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
217
|
+
&right_sum, &right_sumsq);
|
|
218
|
+
*sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
|
|
219
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
|
|
220
|
+
}
|
|
221
|
+
else if (stride_elements == 1) nk_reduce_moments_u8_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
222
|
+
else if (stride_elements <= 16)
|
|
223
|
+
nk_reduce_moments_u8_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
224
|
+
else nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
NK_INTERNAL void nk_reduce_moments_i16_icelake_contiguous_( //
|
|
228
|
+
nk_i16_t const *data_ptr, nk_size_t count, //
|
|
229
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
230
|
+
// Sum: VPDPWSSD(acc, data, ones) accumulates in i32 — safe for (NK_I16_MAX+1)*32 elements.
|
|
231
|
+
// Sumsq: VPDPWSSD(zero, data, data) → fresh i32, widen to i64 each iteration.
|
|
232
|
+
__m512i ones_i16x32 = _mm512_set1_epi16(1);
|
|
233
|
+
__m512i sum_i32x16 = _mm512_setzero_si512();
|
|
234
|
+
__m512i sumsq_i64x8 = _mm512_setzero_si512();
|
|
235
|
+
nk_size_t idx = 0;
|
|
236
|
+
for (; idx + 32 <= count; idx += 32) {
|
|
237
|
+
__m512i data_i16x32 = _mm512_loadu_si512(data_ptr + idx);
|
|
238
|
+
sum_i32x16 = _mm512_dpwssd_epi32(sum_i32x16, data_i16x32, ones_i16x32);
|
|
239
|
+
__m512i sq_i32x16 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), data_i16x32, data_i16x32);
|
|
240
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
|
|
241
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
|
|
242
|
+
}
|
|
243
|
+
nk_size_t remaining = count - idx;
|
|
244
|
+
if (remaining > 0) {
|
|
245
|
+
__mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
|
|
246
|
+
__m512i data_i16x32 = _mm512_maskz_loadu_epi16(tail_mask, data_ptr + idx);
|
|
247
|
+
sum_i32x16 = _mm512_dpwssd_epi32(sum_i32x16, data_i16x32, ones_i16x32);
|
|
248
|
+
__m512i sq_i32x16 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), data_i16x32, data_i16x32);
|
|
249
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
|
|
250
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
|
|
251
|
+
}
|
|
252
|
+
__m512i sum_i64x8 = _mm512_add_epi64( //
|
|
253
|
+
_mm512_cvtepi32_epi64(_mm512_castsi512_si256(sum_i32x16)), //
|
|
254
|
+
_mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sum_i32x16, 1))); //
|
|
255
|
+
*sum_ptr = nk_reduce_add_i64x8_skylake_(sum_i64x8);
|
|
256
|
+
*sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
NK_INTERNAL void nk_reduce_moments_i16_icelake_strided_( //
|
|
260
|
+
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
261
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
262
|
+
__mmask32 stride_mask_m32 = nk_stride_mask_b16x32_(stride_elements);
|
|
263
|
+
__m512i ones_i16x32 = _mm512_set1_epi16(1);
|
|
264
|
+
__m512i sum_i32x16 = _mm512_setzero_si512();
|
|
265
|
+
__m512i sumsq_i64x8 = _mm512_setzero_si512();
|
|
266
|
+
nk_size_t idx_scalars = 0;
|
|
267
|
+
nk_size_t total_scalars = count * stride_elements;
|
|
268
|
+
nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m32) * stride_elements;
|
|
269
|
+
for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
|
|
270
|
+
__m512i data_i16x32 = _mm512_maskz_loadu_epi16(stride_mask_m32, data_ptr + idx_scalars);
|
|
271
|
+
sum_i32x16 = _mm512_dpwssd_epi32(sum_i32x16, data_i16x32, ones_i16x32);
|
|
272
|
+
__m512i sq_i32x16 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), data_i16x32, data_i16x32);
|
|
273
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
|
|
274
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
|
|
275
|
+
}
|
|
276
|
+
nk_size_t remaining_scalars = total_scalars - idx_scalars;
|
|
277
|
+
if (remaining_scalars > 0) {
|
|
278
|
+
__mmask32 tail_mask = stride_mask_m32 & (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)(remaining_scalars));
|
|
279
|
+
__m512i data_i16x32 = _mm512_maskz_loadu_epi16(tail_mask, data_ptr + idx_scalars);
|
|
280
|
+
sum_i32x16 = _mm512_dpwssd_epi32(sum_i32x16, data_i16x32, ones_i16x32);
|
|
281
|
+
__m512i sq_i32x16 = _mm512_dpwssd_epi32(_mm512_setzero_si512(), data_i16x32, data_i16x32);
|
|
282
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
|
|
283
|
+
sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
|
|
284
|
+
}
|
|
285
|
+
__m512i sum_i64x8 = _mm512_add_epi64( //
|
|
286
|
+
_mm512_cvtepi32_epi64(_mm512_castsi512_si256(sum_i32x16)), //
|
|
287
|
+
_mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sum_i32x16, 1))); //
|
|
288
|
+
*sum_ptr = nk_reduce_add_i64x8_skylake_(sum_i64x8);
|
|
289
|
+
*sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
NK_PUBLIC void nk_reduce_moments_i16_icelake( //
|
|
293
|
+
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
294
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
295
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
|
|
296
|
+
int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
|
|
297
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
298
|
+
else if (!aligned) nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
299
|
+
else if (count > (nk_size_t)(NK_I16_MAX + 1) * 32) {
|
|
300
|
+
nk_size_t left_count = count / 2;
|
|
301
|
+
nk_i64_t left_sum, right_sum;
|
|
302
|
+
nk_u64_t left_sumsq, right_sumsq;
|
|
303
|
+
nk_reduce_moments_i16_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
304
|
+
nk_reduce_moments_i16_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
305
|
+
&right_sum, &right_sumsq);
|
|
306
|
+
*sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
|
|
307
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
|
|
308
|
+
}
|
|
309
|
+
else if (stride_elements == 1) nk_reduce_moments_i16_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
310
|
+
else if (stride_elements <= 16)
|
|
311
|
+
nk_reduce_moments_i16_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
312
|
+
else nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
/**
 *  @brief  Sum and sum-of-squares over a contiguous `e2m3` buffer using
 *          AVX-512 VBMI byte permutes and VNNI dot products (Ice Lake).
 *
 *  Each byte carries the value in its low 6 bits: bit 5 is the sign, bits 0-4
 *  the magnitude. Magnitudes decode through a LUT that stores `value * 16` so
 *  all arithmetic stays integral; the fixed-point scaling is undone at the end
 *  (sum / 16, sum-of-squares / 256).
 *
 *  @param data_ptr   Contiguous input, one e2m3 scalar per byte.
 *  @param count      Number of scalars to reduce.
 *  @param sum_ptr    Output: sum of all values, as f32.
 *  @param sumsq_ptr  Output: sum of squared values, as f32.
 *
 *  NOTE(review): the i32 lane accumulators are never widened here, so the
 *  caller must cap `count` (see the dispatcher's recursive split) — confirm.
 */
NK_INTERNAL void nk_reduce_moments_e2m3_icelake_contiguous_( //
    nk_e2m3_t const *data_ptr, nk_size_t count,              //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    // 64-byte LUT: maps 5-bit unsigned magnitude -> value*16 as u8 (0..120)
    // Entries 0-31 replicated in upper 32 bytes (VPERMB indexes mod 64)
    __m512i const lut_magnitude_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0,
                                                        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F); // low 5 bits: magnitude
    __m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);      // bit 5: sign
    __m512i const ones_u8x64 = _mm512_set1_epi8(1);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_i32x16 = _mm512_setzero_si512();
    nk_size_t idx = 0;
    // Main loop: 64 scalars per iteration.
    for (; idx + 64 <= count; idx += 64) {
        __m512i data_u8x64 = _mm512_loadu_si512(data_ptr + idx);
        // Extract 5-bit magnitude, LUT lookup
        __m512i magnitude_u8x64 = _mm512_and_si512(data_u8x64, magnitude_mask_u8x64);
        __m512i unsigned_mag_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, lut_magnitude_u8x64);
        // Apply sign for sum: negate where bit 5 is set
        __mmask64 sign_mask = _mm512_test_epi8_mask(data_u8x64, sign_mask_u8x64);
        __m512i signed_mag_i8x64 = _mm512_mask_sub_epi8(unsigned_mag_u8x64, sign_mask, _mm512_setzero_si512(),
                                                        unsigned_mag_u8x64);
        // Sum: VPDPBUSD(acc, ones_u8, signed_i8) = acc + sum(1 * signed_val) per 4-byte group
        sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, ones_u8x64, signed_mag_i8x64);
        // Sumsq: VPDPBUSD(acc, unsigned_mag, unsigned_mag) = acc + sum(mag^2) per 4-byte group
        // magnitude is 0-120, fits in both u8 and i8 interpretations
        sumsq_i32x16 = _mm512_dpbusd_epi32(sumsq_i32x16, unsigned_mag_u8x64, unsigned_mag_u8x64);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Masked tail: load only the `remaining` live bytes; zero-filled lanes
        // decode to magnitude 0 and contribute nothing to either accumulator.
        // `remaining < 64` here, so the BZHI index is always in range.
        __mmask64 tail_mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
        __m512i data_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx);
        __m512i magnitude_u8x64 = _mm512_and_si512(data_u8x64, magnitude_mask_u8x64);
        __m512i unsigned_mag_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, lut_magnitude_u8x64);
        __mmask64 sign_mask = _mm512_test_epi8_mask(data_u8x64, sign_mask_u8x64);
        __m512i signed_mag_i8x64 = _mm512_mask_sub_epi8(unsigned_mag_u8x64, sign_mask, _mm512_setzero_si512(),
                                                        unsigned_mag_u8x64);
        sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, ones_u8x64, signed_mag_i8x64);
        sumsq_i32x16 = _mm512_dpbusd_epi32(sumsq_i32x16, unsigned_mag_u8x64, unsigned_mag_u8x64);
    }
    // Undo the fixed-point scaling: values were stored *16, squares are *256.
    *sum_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 16.0f;
    *sumsq_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sumsq_i32x16) / 256.0f;
}
|
|
360
|
+
|
|
361
|
+
/**
 *  @brief  Sum and sum-of-squares over a strided `e2m3` buffer (stride given in
 *          elements) using masked 64-byte loads instead of gathers (Ice Lake).
 *
 *  Each 64-byte load keeps only the lanes selected by `stride_mask_m64` and
 *  zero-fills the rest; zeroed lanes decode to magnitude 0 and contribute
 *  nothing. The cursor then advances by `popcount(mask) * stride` scalars —
 *  exactly the span the masked load consumed. The decode path (LUT, sign,
 *  VPDPBUSD) is identical to the contiguous kernel.
 *
 *  NOTE(review): assumes `nk_stride_mask_u1x64_(s)` sets bit i when i % s == 0
 *  (defined elsewhere in this file) — confirm against its definition.
 */
NK_INTERNAL void nk_reduce_moments_e2m3_icelake_strided_(               //
    nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __mmask64 stride_mask_m64 = nk_stride_mask_u1x64_(stride_elements);
    // 64-byte LUT: 5-bit magnitude -> value*16 as u8 (0..120); halves replicated for VPERMB.
    __m512i const lut_magnitude_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0,
                                                        120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36,
                                                        32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
    __m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F); // low 5 bits: magnitude
    __m512i const sign_mask_u8x64 = _mm512_set1_epi8(0x20);      // bit 5: sign
    __m512i const ones_u8x64 = _mm512_set1_epi8(1);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_i32x16 = _mm512_setzero_si512();
    nk_size_t idx_scalars = 0;                                   // cursor in scalars, not elements
    nk_size_t total_scalars = count * stride_elements;
    // Scalars consumed per masked load: one per set mask bit, spanning `step` bytes.
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m64) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m512i data_u8x64 = _mm512_maskz_loadu_epi8(stride_mask_m64, data_ptr + idx_scalars);
        __m512i magnitude_u8x64 = _mm512_and_si512(data_u8x64, magnitude_mask_u8x64);
        __m512i unsigned_mag_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, lut_magnitude_u8x64);
        __mmask64 sign_mask = _mm512_test_epi8_mask(data_u8x64, sign_mask_u8x64);
        __m512i signed_mag_i8x64 = _mm512_mask_sub_epi8(unsigned_mag_u8x64, sign_mask, _mm512_setzero_si512(),
                                                        unsigned_mag_u8x64);
        sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, ones_u8x64, signed_mag_i8x64);
        sumsq_i32x16 = _mm512_dpbusd_epi32(sumsq_i32x16, unsigned_mag_u8x64, unsigned_mag_u8x64);
    }
    nk_size_t remaining_scalars = total_scalars - idx_scalars;
    if (remaining_scalars > 0) {
        // Tail: clip the stride mask to the remaining scalars. BZHI with an index
        // >= 64 leaves the source unchanged — intended here, since `step` (and thus
        // `remaining_scalars`) can exceed 64 when the stride doesn't divide 64.
        __mmask64 tail_mask = stride_mask_m64 & _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining_scalars);
        __m512i data_u8x64 = _mm512_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
        __m512i magnitude_u8x64 = _mm512_and_si512(data_u8x64, magnitude_mask_u8x64);
        __m512i unsigned_mag_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, lut_magnitude_u8x64);
        __mmask64 sign_mask = _mm512_test_epi8_mask(data_u8x64, sign_mask_u8x64);
        __m512i signed_mag_i8x64 = _mm512_mask_sub_epi8(unsigned_mag_u8x64, sign_mask, _mm512_setzero_si512(),
                                                        unsigned_mag_u8x64);
        sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, ones_u8x64, signed_mag_i8x64);
        sumsq_i32x16 = _mm512_dpbusd_epi32(sumsq_i32x16, unsigned_mag_u8x64, unsigned_mag_u8x64);
    }
    // Undo the fixed-point scaling: values were stored *16, squares are *256.
    *sum_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 16.0f;
    *sumsq_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sumsq_i32x16) / 256.0f;
}
|
|
402
|
+
|
|
403
|
+
NK_PUBLIC void nk_reduce_moments_e2m3_icelake( //
|
|
404
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
405
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
406
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
|
|
407
|
+
int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
|
|
408
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
409
|
+
else if (!aligned) nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
410
|
+
else if (count > (nk_size_t)(NK_I16_MAX + 1) * 64) {
|
|
411
|
+
nk_size_t left_count = count / 2;
|
|
412
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
413
|
+
nk_reduce_moments_e2m3_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
414
|
+
nk_reduce_moments_e2m3_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
415
|
+
&right_sum, &right_sumsq);
|
|
416
|
+
*sum_ptr = left_sum + right_sum;
|
|
417
|
+
*sumsq_ptr = left_sumsq + right_sumsq;
|
|
418
|
+
}
|
|
419
|
+
else if (stride_elements == 1) nk_reduce_moments_e2m3_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
420
|
+
else if (stride_elements <= 16)
|
|
421
|
+
nk_reduce_moments_e2m3_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
422
|
+
else nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
/**
 *  @brief  Sum and sum-of-squares over a contiguous `e3m2` buffer using a
 *          16-bit VPERMW lookup and VPMADDWD accumulation (Ice Lake).
 *
 *  Each byte carries the value in its low 6 bits: bit 5 is the sign, bits 0-4
 *  the magnitude. Bytes are widened to u16, decoded through a 32-entry i16 LUT
 *  that stores `value * 16` (0..448 — too wide for the u8 LUT used by e2m3),
 *  and accumulated as integers; the fixed-point scaling is undone at the end.
 *
 *  @param data_ptr   Contiguous input, one e3m2 scalar per byte.
 *  @param count      Number of scalars to reduce.
 *  @param sum_ptr    Output: sum of all values, as f32.
 *  @param sumsq_ptr  Output: sum of squared values, as f32.
 *
 *  NOTE(review): the i32 lane accumulators are never widened here, so the
 *  caller must cap `count` (see the dispatcher's `2048 * 64` split) — confirm.
 */
NK_INTERNAL void nk_reduce_moments_e3m2_icelake_contiguous_( //
    nk_e3m2_t const *data_ptr, nk_size_t count,              //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    // 32-entry i16 LUT: maps 5-bit unsigned magnitude -> value*16 as i16 (0..448)
    __m512i const lut_magnitude_i16x32 = _mm512_set_epi16(448, 384, 320, 256, 224, 192, 160, 128, 112, 96, 80, 64, 56,
                                                          48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2,
                                                          1, 0);
    __m512i const magnitude_mask_i16x32 = _mm512_set1_epi16(0x1F); // low 5 bits: magnitude
    __m512i const sign_mask_i16x32 = _mm512_set1_epi16(0x20);      // bit 5: sign
    __m512i const ones_i16x32 = _mm512_set1_epi16(1);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_i32x16 = _mm512_setzero_si512();
    nk_size_t idx = 0;
    // Main loop: 32 scalars per iteration (widened to one 512-bit vector).
    for (; idx + 32 <= count; idx += 32) {
        // Load 32 bytes, widen u8->u16
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        __m512i data_u16x32 = _mm512_cvtepu8_epi16(data_u8x32);
        // Extract 5-bit magnitude, VPERMW LUT lookup
        __m512i magnitude_u16x32 = _mm512_and_si512(data_u16x32, magnitude_mask_i16x32);
        __m512i unsigned_mag_i16x32 = _mm512_permutexvar_epi16(magnitude_u16x32, lut_magnitude_i16x32);
        // Apply sign for sum: negate where bit 5 is set
        __mmask32 sign_mask = _mm512_test_epi16_mask(data_u16x32, sign_mask_i16x32);
        __m512i signed_mag_i16x32 = _mm512_mask_sub_epi16(unsigned_mag_i16x32, sign_mask, _mm512_setzero_si512(),
                                                          unsigned_mag_i16x32);
        // Sum: VPMADDWD(signed_i16, ones) = sum of pairs -> i32
        sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(signed_mag_i16x32, ones_i16x32));
        // Sumsq: VPMADDWD(unsigned_mag, unsigned_mag) = sum of pairs of squares -> i32
        // max per i32: 2 * 448^2 = 401408, fits in i32
        sumsq_i32x16 = _mm512_add_epi32(sumsq_i32x16, _mm512_madd_epi16(unsigned_mag_i16x32, unsigned_mag_i16x32));
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Masked tail: load only the `remaining` live bytes; zero-filled lanes
        // decode to magnitude 0 and contribute nothing. `remaining < 32` here.
        __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
        __m256i data_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, data_ptr + idx);
        __m512i data_u16x32 = _mm512_cvtepu8_epi16(data_u8x32);
        __m512i magnitude_u16x32 = _mm512_and_si512(data_u16x32, magnitude_mask_i16x32);
        __m512i unsigned_mag_i16x32 = _mm512_permutexvar_epi16(magnitude_u16x32, lut_magnitude_i16x32);
        __mmask32 sign_mask = _mm512_test_epi16_mask(data_u16x32, sign_mask_i16x32);
        __m512i signed_mag_i16x32 = _mm512_mask_sub_epi16(unsigned_mag_i16x32, sign_mask, _mm512_setzero_si512(),
                                                          unsigned_mag_i16x32);
        sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(signed_mag_i16x32, ones_i16x32));
        sumsq_i32x16 = _mm512_add_epi32(sumsq_i32x16, _mm512_madd_epi16(unsigned_mag_i16x32, unsigned_mag_i16x32));
    }
    // Undo the fixed-point scaling: values were stored *16, squares are *256.
    *sum_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 16.0f;
    *sumsq_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sumsq_i32x16) / 256.0f;
}
|
|
471
|
+
|
|
472
|
+
/**
 *  @brief  Sum and sum-of-squares over a strided `e3m2` buffer (stride given in
 *          elements) using masked 32-byte loads instead of gathers (Ice Lake).
 *
 *  Each 32-byte load keeps only the lanes selected by the stride mask and
 *  zero-fills the rest; zeroed lanes decode to magnitude 0 and contribute
 *  nothing. The cursor advances by `popcount(mask) * stride` scalars per
 *  iteration. Decode path matches the contiguous e3m2 kernel (widen, VPERMW
 *  LUT, VPMADDWD). The 64-bit stride mask is truncated to 32 lanes because
 *  only 32 bytes are widened per iteration.
 *
 *  NOTE(review): assumes `nk_stride_mask_u1x64_(s)` sets bit i when i % s == 0
 *  (defined elsewhere in this file) — confirm against its definition.
 */
NK_INTERNAL void nk_reduce_moments_e3m2_icelake_strided_(               //
    nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __mmask32 stride_mask_m32 = (__mmask32)nk_stride_mask_u1x64_(stride_elements);
    // 32-entry i16 LUT: maps 5-bit unsigned magnitude -> value*16 as i16 (0..448)
    __m512i const lut_magnitude_i16x32 = _mm512_set_epi16(448, 384, 320, 256, 224, 192, 160, 128, 112, 96, 80, 64, 56,
                                                          48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2,
                                                          1, 0);
    __m512i const magnitude_mask_i16x32 = _mm512_set1_epi16(0x1F); // low 5 bits: magnitude
    __m512i const sign_mask_i16x32 = _mm512_set1_epi16(0x20);      // bit 5: sign
    __m512i const ones_i16x32 = _mm512_set1_epi16(1);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_i32x16 = _mm512_setzero_si512();
    nk_size_t idx_scalars = 0;                                     // cursor in scalars, not elements
    nk_size_t total_scalars = count * stride_elements;
    // Scalars consumed per masked load: one per set mask bit, spanning `step` bytes.
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m32) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_u8x32 = _mm256_maskz_loadu_epi8(stride_mask_m32, data_ptr + idx_scalars);
        __m512i data_u16x32 = _mm512_cvtepu8_epi16(data_u8x32);
        __m512i magnitude_u16x32 = _mm512_and_si512(data_u16x32, magnitude_mask_i16x32);
        __m512i unsigned_mag_i16x32 = _mm512_permutexvar_epi16(magnitude_u16x32, lut_magnitude_i16x32);
        __mmask32 sign_mask = _mm512_test_epi16_mask(data_u16x32, sign_mask_i16x32);
        __m512i signed_mag_i16x32 = _mm512_mask_sub_epi16(unsigned_mag_i16x32, sign_mask, _mm512_setzero_si512(),
                                                          unsigned_mag_i16x32);
        sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(signed_mag_i16x32, ones_i16x32));
        sumsq_i32x16 = _mm512_add_epi32(sumsq_i32x16, _mm512_madd_epi16(unsigned_mag_i16x32, unsigned_mag_i16x32));
    }
    nk_size_t remaining_scalars = total_scalars - idx_scalars;
    if (remaining_scalars > 0) {
        // Tail: clip the stride mask to the remaining scalars. BZHI with an index
        // >= 32 leaves the source unchanged — intended, since `remaining_scalars`
        // can exceed 32 when the stride doesn't divide 32.
        __mmask32 tail_mask = stride_mask_m32 & (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining_scalars);
        __m256i data_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
        __m512i data_u16x32 = _mm512_cvtepu8_epi16(data_u8x32);
        __m512i magnitude_u16x32 = _mm512_and_si512(data_u16x32, magnitude_mask_i16x32);
        __m512i unsigned_mag_i16x32 = _mm512_permutexvar_epi16(magnitude_u16x32, lut_magnitude_i16x32);
        __mmask32 sign_mask = _mm512_test_epi16_mask(data_u16x32, sign_mask_i16x32);
        __m512i signed_mag_i16x32 = _mm512_mask_sub_epi16(unsigned_mag_i16x32, sign_mask, _mm512_setzero_si512(),
                                                          unsigned_mag_i16x32);
        sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(signed_mag_i16x32, ones_i16x32));
        sumsq_i32x16 = _mm512_add_epi32(sumsq_i32x16, _mm512_madd_epi16(unsigned_mag_i16x32, unsigned_mag_i16x32));
    }
    // Undo the fixed-point scaling: values were stored *16, squares are *256.
    *sum_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 16.0f;
    *sumsq_ptr = (nk_f32_t)_mm512_reduce_add_epi32(sumsq_i32x16) / 256.0f;
}
|
|
514
|
+
|
|
515
|
+
NK_PUBLIC void nk_reduce_moments_e3m2_icelake( //
|
|
516
|
+
nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
517
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
518
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
|
|
519
|
+
int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
|
|
520
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
521
|
+
else if (!aligned) nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
522
|
+
else if (count > (nk_size_t)2048 * 64) {
|
|
523
|
+
nk_size_t left_count = count / 2;
|
|
524
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
525
|
+
nk_reduce_moments_e3m2_icelake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
526
|
+
nk_reduce_moments_e3m2_icelake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
527
|
+
&right_sum, &right_sumsq);
|
|
528
|
+
*sum_ptr = left_sum + right_sum;
|
|
529
|
+
*sumsq_ptr = left_sumsq + right_sumsq;
|
|
530
|
+
}
|
|
531
|
+
else if (stride_elements == 1) nk_reduce_moments_e3m2_icelake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
532
|
+
else if (stride_elements <= 16)
|
|
533
|
+
nk_reduce_moments_e3m2_icelake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
534
|
+
else nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
535
|
+
}
|
|
536
|
+
|
|
537
|
+
#if defined(__clang__)
|
|
538
|
+
#pragma clang attribute pop
|
|
539
|
+
#elif defined(__GNUC__)
|
|
540
|
+
#pragma GCC pop_options
|
|
541
|
+
#endif
|
|
542
|
+
|
|
543
|
+
#if defined(__cplusplus)
|
|
544
|
+
} // extern "C"
|
|
545
|
+
#endif
|
|
546
|
+
|
|
547
|
+
#endif // NK_TARGET_ICELAKE
|
|
548
|
+
#endif // NK_TARGET_X86_
|
|
549
|
+
#endif // NK_REDUCE_ICELAKE_H
|