numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
package/include/numkong/reduce/skylake.h
@@ -0,0 +1,3792 @@
/**
 * @brief AVX-512 implementations for the redesigned reduction API (moments + minmax).
 * @file include/numkong/reduce/skylake.h
 * @author Ash Vardanian
 * @date February 11, 2026
 *
 * @sa include/numkong/reduce.h
 *
 * @section tail_nan_fill Tail Handling via NaN Fill
 *
 * In floating-point minmax contiguous kernels (f32, f64), the tail block fills
 * unloaded lanes with NaN via `_mm512_mask_loadu_ps(nan, mask, ptr)` instead of
 * `_mm512_maskz_loadu_ps(mask, ptr)`. This allows the subsequent `_CMP_LT_OQ` /
 * `_CMP_GT_OQ` comparisons to run without the tail-load mask predicate, because
 * IEEE-754 ordered-quiet comparisons return false for NaN operands.
 *
 * @section reduce_block_caps Block-Cap Overflow Thresholds
 *
 * Dispatch functions use pairwise recursion when `count` exceeds a block cap.
 * The cap is sized so the iteration counter in the contiguous kernel never wraps.
 *
 * Iteration counters start at 0 (initial load) and increment by 1 per SIMD chunk.
 * A u8 counter holds 0..255 → 256 iterations → processes 256 × lanes elements.
 * A u16 counter holds 0..65535 → 65536 iterations → processes 65536 × lanes elements.
 * A u32 counter holds 0..4294967295 → ~4.3 billion iterations.
 *
 * Threshold formula: count > (COUNTER_MAX + 1) × lanes_per_chunk
 * - u8 minmax: (NK_U8_MAX + 1) × lanes (e.g. 256 × 64 = 16384 for i8x64)
 * - u16 minmax: (NK_U16_MAX + 1) × lanes (e.g. 65536 × 32 = 2097152 for i16x32)
 * - u32 minmax: NK_U32_MAX × lanes (no +1: NK_U32_MAX + 1 overflows unsigned)
 *
 * Moments block caps are sized for accumulator overflow, not counter overflow.
 * See individual dispatch functions for type-specific derivations.
 */
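
/* Editorial worked example (not part of the original file): the threshold
 * formula above, evaluated for the f32 moments dispatch defined later in this
 * header, which recurses once
 *
 *     count > (NK_U16_MAX + 1) * 16 = 65536 * 16 = 1048576
 *
 * elements, i.e. once a u16 chunk counter over 16-lane chunks would wrap. */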
#ifndef NK_REDUCE_SKYLAKE_H
#define NK_REDUCE_SKYLAKE_H

#if NK_TARGET_X86_
#if NK_TARGET_SKYLAKE

#include "numkong/types.h"
#include "numkong/cast/skylake.h"
#include "numkong/reduce/serial.h"

#if defined(__cplusplus)
extern "C" {
#endif

#if defined(__clang__)
#pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
                             apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
#endif

/** @brief Horizontal sum of 16 floats in a ZMM register (native f32 precision). */
NK_INTERNAL nk_f32_t nk_reduce_add_f32x16_skylake_(__m512 sum_f32x16) {
    __m256 lo_f32x8 = _mm512_castps512_ps256(sum_f32x16);
    __m256 hi_f32x8 = _mm512_extractf32x8_ps(sum_f32x16, 1);
    __m256 sum_f32x8 = _mm256_add_ps(lo_f32x8, hi_f32x8);
    __m128 lo_f32x4 = _mm256_castps256_ps128(sum_f32x8);
    __m128 hi_f32x4 = _mm256_extractf128_ps(sum_f32x8, 1);
    __m128 sum_f32x4 = _mm_add_ps(lo_f32x4, hi_f32x4);
    sum_f32x4 = _mm_hadd_ps(sum_f32x4, sum_f32x4);
    sum_f32x4 = _mm_hadd_ps(sum_f32x4, sum_f32x4);
    return _mm_cvtss_f32(sum_f32x4);
}
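
/* Editorial sketch (not part of the original file): the halving pattern above
 * folds 512 → 256 → 128 bits, then horizontal-adds within the last 4 lanes.
 * It computes the same result as this scalar loop over the 16 lanes, up to
 * floating-point reassociation: */
NK_INTERNAL nk_f32_t nk_example_reduce_add_f32x16_reference_(nk_f32_t const lanes[16]) {
    nk_f32_t total = 0.0f;
    for (int lane = 0; lane < 16; ++lane) total += lanes[lane];
    return total;
}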

/** @brief Horizontal sum of 8 doubles in a ZMM register. */
NK_INTERNAL nk_f64_t nk_reduce_add_f64x8_skylake_(__m512d sum_f64x8) {
    __m256d lo_f64x4 = _mm512_castpd512_pd256(sum_f64x8);
    __m256d hi_f64x4 = _mm512_extractf64x4_pd(sum_f64x8, 1);
    __m256d sum_f64x4 = _mm256_add_pd(lo_f64x4, hi_f64x4);
    __m128d lo_f64x2 = _mm256_castpd256_pd128(sum_f64x4);
    __m128d hi_f64x2 = _mm256_extractf128_pd(sum_f64x4, 1);
    __m128d sum_f64x2 = _mm_add_pd(lo_f64x2, hi_f64x2);
    sum_f64x2 = _mm_hadd_pd(sum_f64x2, sum_f64x2);
    return _mm_cvtsd_f64(sum_f64x2);
}

/** @brief Horizontal min of 16 floats in a ZMM register. */
NK_INTERNAL nk_f32_t nk_reduce_min_f32x16_skylake_(__m512 min_f32x16) {
    __m256 lo_f32x8 = _mm512_castps512_ps256(min_f32x16);
    __m256 hi_f32x8 = _mm512_extractf32x8_ps(min_f32x16, 1);
    __m256 min_f32x8 = _mm256_min_ps(lo_f32x8, hi_f32x8);
    __m128 lo_f32x4 = _mm256_castps256_ps128(min_f32x8);
    __m128 hi_f32x4 = _mm256_extractf128_ps(min_f32x8, 1);
    __m128 min_f32x4 = _mm_min_ps(lo_f32x4, hi_f32x4);
    min_f32x4 = _mm_min_ps(min_f32x4, _mm_shuffle_ps(min_f32x4, min_f32x4, _MM_SHUFFLE(2, 3, 0, 1)));
    min_f32x4 = _mm_min_ps(min_f32x4, _mm_shuffle_ps(min_f32x4, min_f32x4, _MM_SHUFFLE(1, 0, 3, 2)));
    return _mm_cvtss_f32(min_f32x4);
}

/** @brief Horizontal max of 16 floats in a ZMM register. */
NK_INTERNAL nk_f32_t nk_reduce_max_f32x16_skylake_(__m512 max_f32x16) {
    __m256 lo_f32x8 = _mm512_castps512_ps256(max_f32x16);
    __m256 hi_f32x8 = _mm512_extractf32x8_ps(max_f32x16, 1);
    __m256 max_f32x8 = _mm256_max_ps(lo_f32x8, hi_f32x8);
    __m128 lo_f32x4 = _mm256_castps256_ps128(max_f32x8);
    __m128 hi_f32x4 = _mm256_extractf128_ps(max_f32x8, 1);
    __m128 max_f32x4 = _mm_max_ps(lo_f32x4, hi_f32x4);
    max_f32x4 = _mm_max_ps(max_f32x4, _mm_shuffle_ps(max_f32x4, max_f32x4, _MM_SHUFFLE(2, 3, 0, 1)));
    max_f32x4 = _mm_max_ps(max_f32x4, _mm_shuffle_ps(max_f32x4, max_f32x4, _MM_SHUFFLE(1, 0, 3, 2)));
    return _mm_cvtss_f32(max_f32x4);
}

/** @brief Horizontal min of 8 doubles in a ZMM register. */
NK_INTERNAL nk_f64_t nk_reduce_min_f64x8_skylake_(__m512d min_f64x8) {
    __m256d lo_f64x4 = _mm512_castpd512_pd256(min_f64x8);
    __m256d hi_f64x4 = _mm512_extractf64x4_pd(min_f64x8, 1);
    __m256d min_f64x4 = _mm256_min_pd(lo_f64x4, hi_f64x4);
    __m128d lo_f64x2 = _mm256_castpd256_pd128(min_f64x4);
    __m128d hi_f64x2 = _mm256_extractf128_pd(min_f64x4, 1);
    __m128d min_f64x2 = _mm_min_pd(lo_f64x2, hi_f64x2);
    min_f64x2 = _mm_min_pd(min_f64x2, _mm_shuffle_pd(min_f64x2, min_f64x2, 1));
    return _mm_cvtsd_f64(min_f64x2);
}

/** @brief Horizontal max of 8 doubles in a ZMM register. */
NK_INTERNAL nk_f64_t nk_reduce_max_f64x8_skylake_(__m512d max_f64x8) {
    __m256d lo_f64x4 = _mm512_castpd512_pd256(max_f64x8);
    __m256d hi_f64x4 = _mm512_extractf64x4_pd(max_f64x8, 1);
    __m256d max_f64x4 = _mm256_max_pd(lo_f64x4, hi_f64x4);
    __m128d lo_f64x2 = _mm256_castpd256_pd128(max_f64x4);
    __m128d hi_f64x2 = _mm256_extractf128_pd(max_f64x4, 1);
    __m128d max_f64x2 = _mm_max_pd(lo_f64x2, hi_f64x2);
    max_f64x2 = _mm_max_pd(max_f64x2, _mm_shuffle_pd(max_f64x2, max_f64x2, 1));
    return _mm_cvtsd_f64(max_f64x2);
}

/** @brief Horizontal sum of 16 i32s in a ZMM register. */
NK_INTERNAL nk_i32_t nk_reduce_add_i32x16_skylake_(__m512i sum_i32x16) {
    __m256i lo_i32x8 = _mm512_castsi512_si256(sum_i32x16);
    __m256i hi_i32x8 = _mm512_extracti32x8_epi32(sum_i32x16, 1);
    __m256i sum_i32x8 = _mm256_add_epi32(lo_i32x8, hi_i32x8);
    __m128i lo_i32x4 = _mm256_castsi256_si128(sum_i32x8);
    __m128i hi_i32x4 = _mm256_extracti128_si256(sum_i32x8, 1);
    __m128i sum_i32x4 = _mm_add_epi32(lo_i32x4, hi_i32x4);
    sum_i32x4 = _mm_hadd_epi32(sum_i32x4, sum_i32x4);
    sum_i32x4 = _mm_hadd_epi32(sum_i32x4, sum_i32x4);
    return _mm_cvtsi128_si32(sum_i32x4);
}

/** @brief Horizontal sum of 8 i64s in a ZMM register. */
NK_INTERNAL nk_i64_t nk_reduce_add_i64x8_skylake_(__m512i sum_i64x8) {
    __m256i lo_i64x4 = _mm512_castsi512_si256(sum_i64x8);
    __m256i hi_i64x4 = _mm512_extracti64x4_epi64(sum_i64x8, 1);
    __m256i sum_i64x4 = _mm256_add_epi64(lo_i64x4, hi_i64x4);
    __m128i lo_i64x2 = _mm256_castsi256_si128(sum_i64x4);
    __m128i hi_i64x2 = _mm256_extracti128_si256(sum_i64x4, 1);
    __m128i sum_i64x2 = _mm_add_epi64(lo_i64x2, hi_i64x2);
    sum_i64x2 = _mm_add_epi64(sum_i64x2, _mm_shuffle_epi32(sum_i64x2, _MM_SHUFFLE(1, 0, 3, 2)));
    return _mm_cvtsi128_si64(sum_i64x2);
}

/**
 * @brief Returns AVX-512 mask for strided access of 8-bit elements (64-element register).
 *
 * For column extraction from row-major matrices: stride N means every Nth element.
 * With 64 elements per register, useful for strides 2-16 (yielding 4+ elements per load).
 * Mask bits set to 1 where (position % stride == 0).
 */
NK_INTERNAL __mmask64 nk_stride_mask_u1x64_(nk_size_t stride) {
    switch (stride) {
    case 2: return (__mmask64)0x5555555555555555ull;  // 32 elems
    case 3: return (__mmask64)0x9249249249249249ull;  // 22 elems
    case 4: return (__mmask64)0x1111111111111111ull;  // 16 elems
    case 5: return (__mmask64)0x1084210842108421ull;  // 13 elems
    case 6: return (__mmask64)0x1041041041041041ull;  // 11 elems
    case 7: return (__mmask64)0x0102040810204081ull;  // 9 elems
    case 8: return (__mmask64)0x0101010101010101ull;  // 8 elems
    case 9: return (__mmask64)0x8040201008040201ull;  // 8 elems
    case 10: return (__mmask64)0x1004010040100401ull; // 7 elems
    case 11: return (__mmask64)0x0080100200400801ull; // 6 elems
    case 12: return (__mmask64)0x1001001001001001ull; // 6 elems
    case 13: return (__mmask64)0x0010008004002001ull; // 5 elems
    case 14: return (__mmask64)0x0100040010004001ull; // 5 elems
    case 15: return (__mmask64)0x1000200040008001ull; // 5 elems
    case 16: return (__mmask64)0x0001000100010001ull; // 4 elems
    default: return (__mmask64)0;
    }
}
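
/* Editorial sketch (not part of the original file): the hardcoded constants
 * above follow the documented rule "bit set where position % stride == 0",
 * which this hypothetical reference would generate. (One caveat: the stride-7
 * table entry stops at bit 56 and omits bit 63, one multiple short of that
 * rule, so the reference diverges from the table there.) */
NK_INTERNAL __mmask64 nk_example_stride_mask_u1x64_reference_(nk_size_t stride) {
    nk_u64_t mask = 0;
    for (nk_size_t pos = 0; pos < 64; pos += stride) mask |= (nk_u64_t)1 << pos;
    return (__mmask64)mask;
}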

/**
 * @brief Returns AVX-512 mask for strided access of 32-bit elements (16-element register).
 *
 * For column extraction from row-major matrices: stride N means every Nth element.
 * Example: stride 4 extracts column 0 from a 4-column matrix.
 * Mask bits set to 1 where (position % stride == 0).
 */
NK_INTERNAL __mmask16 nk_stride_mask_b32x16_(nk_size_t stride) {
    switch (stride) {
    case 2: return (__mmask16)0x5555; // 8 elems
    case 3: return (__mmask16)0x9249; // 6 elems
    case 4: return (__mmask16)0x1111; // 4 elems
    case 5: return (__mmask16)0x8421; // 4 elems
    case 6: return (__mmask16)0x1041; // 3 elems
    case 7: return (__mmask16)0x4081; // 3 elems
    case 8: return (__mmask16)0x0101; // 2 elems
    default: return (__mmask16)0;     // Invalid stride - caller should use gather or serial
    }
}

/**
 * @brief Returns AVX-512 mask for strided access of 16-bit elements (32-element register).
 *
 * For column extraction from row-major matrices: stride N means every Nth element.
 * Example: stride 4 extracts column 0 from a 4-column int16 matrix.
 * Mask bits set to 1 where (position % stride == 0).
 */
NK_INTERNAL __mmask32 nk_stride_mask_b16x32_(nk_size_t stride) {
    switch (stride) {
    case 2: return (__mmask32)0x55555555; // 16 elems
    case 3: return (__mmask32)0x49249249; // 11 elems
    case 4: return (__mmask32)0x11111111; // 8 elems
    case 5: return (__mmask32)0x42108421; // 7 elems
    case 6: return (__mmask32)0x41041041; // 6 elems
    case 7: return (__mmask32)0x10204081; // 5 elems
    case 8: return (__mmask32)0x01010101; // 4 elems
    case 9: return (__mmask32)0x08040201; // 4 elems
    case 10: return (__mmask32)0x40100401; // 4 elems
    case 11: return (__mmask32)0x00400801; // 3 elems
    case 12: return (__mmask32)0x01001001; // 3 elems
    case 13: return (__mmask32)0x04002001; // 3 elems
    case 14: return (__mmask32)0x10004001; // 3 elems
    case 15: return (__mmask32)0x40008001; // 3 elems
    case 16: return (__mmask32)0x00010001; // 2 elems
    default: return (__mmask32)0;
    }
}

/**
 * @brief Returns AVX-512 mask for strided access of 64-bit elements (8-element register).
 *
 * For column extraction from row-major matrices: stride N means every Nth element.
 * Example: stride 4 extracts column 0 from a 4-column matrix.
 * Mask bits set to 1 where (position % stride == 0).
 */
NK_INTERNAL __mmask8 nk_stride_mask_b64x8_(nk_size_t stride) {
    switch (stride) {
    case 2: return (__mmask8)0x55; // [1,0,1,0,1,0,1,0] → 4 elems
    case 3: return (__mmask8)0x49; // [1,0,0,1,0,0,1,0] → 3 elems
    case 4: return (__mmask8)0x11; // [1,0,0,0,1,0,0,0] → 2 elems
    case 5: return (__mmask8)0x21; // [1,0,0,0,0,1,0,0] → 2 elems
    case 6: return (__mmask8)0x41; // [1,0,0,0,0,0,1,0] → 2 elems
    case 7: return (__mmask8)0x01; // [1,0,0,0,0,0,0,0] → 1 elem
    case 8: return (__mmask8)0x01; // [1,0,0,0,0,0,0,0] → 1 elem
    default: return (__mmask8)0;
    }
}
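
/* Editorial sketch (not part of the original file): how these stride masks are
 * consumed. With stride 3, a masked 8-double load keeps lanes 0, 3, and 6 —
 * three logical elements per chunk — and the caller then advances its cursor by
 * popcount(mask) * stride scalars, as the strided f32 kernel below does. */
NK_INTERNAL __m512d nk_example_strided_load_f64_(nk_f64_t const *ptr, nk_size_t stride) {
    __mmask8 mask = nk_stride_mask_b64x8_(stride); // valid for strides 2..8 only
    return _mm512_maskz_loadu_pd(mask, ptr);       // untouched lanes are zeroed
}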

/**
 * @brief Returns number of logical elements per 16-scalar chunk for given stride.
 */
NK_INTERNAL nk_size_t nk_stride_elems_b32x16_(nk_size_t stride) {
    switch (stride) {
    case 2: return 8;
    case 3: return 6;
    case 4: return 4;
    case 5: return 4;
    case 6: return 3;
    case 7: return 3;
    case 8: return 2;
    default: return 0;
    }
}

/**
 * @brief Returns number of logical elements per 8-scalar chunk for given stride.
 */
NK_INTERNAL nk_size_t nk_stride_elems_b64x8_(nk_size_t stride) {
    switch (stride) {
    case 2: return 4;
    case 3: return 3;
    case 4: return 2;
    case 5: return 2;
    case 6: return 2;
    case 7: return 1;
    case 8: return 1;
    default: return 0;
    }
}

/** @brief Horizontal min of 64 signed i8s in a ZMM register. */
NK_INTERNAL nk_i8_t nk_reduce_min_i8x64_skylake_(__m512i min_i8x64) {
    __m256i lo_i8x32 = _mm512_castsi512_si256(min_i8x64);
    __m256i hi_i8x32 = _mm512_extracti64x4_epi64(min_i8x64, 1);
    __m256i min_i8x32 = _mm256_min_epi8(lo_i8x32, hi_i8x32);
    __m128i lo_i8x16 = _mm256_castsi256_si128(min_i8x32);
    __m128i hi_i8x16 = _mm256_extracti128_si256(min_i8x32, 1);
    __m128i min_i8x16 = _mm_min_epi8(lo_i8x16, hi_i8x16);
    min_i8x16 = _mm_min_epi8(min_i8x16, _mm_shuffle_epi32(min_i8x16, _MM_SHUFFLE(2, 3, 0, 1)));
    min_i8x16 = _mm_min_epi8(min_i8x16, _mm_shuffle_epi32(min_i8x16, _MM_SHUFFLE(1, 0, 3, 2)));
    min_i8x16 = _mm_min_epi8(min_i8x16, _mm_srli_si128(min_i8x16, 2));
    min_i8x16 = _mm_min_epi8(min_i8x16, _mm_srli_si128(min_i8x16, 1));
    return (nk_i8_t)_mm_cvtsi128_si32(min_i8x16);
}

/** @brief Horizontal max of 64 signed i8s in a ZMM register. */
NK_INTERNAL nk_i8_t nk_reduce_max_i8x64_skylake_(__m512i max_i8x64) {
    __m256i lo_i8x32 = _mm512_castsi512_si256(max_i8x64);
    __m256i hi_i8x32 = _mm512_extracti64x4_epi64(max_i8x64, 1);
    __m256i max_i8x32 = _mm256_max_epi8(lo_i8x32, hi_i8x32);
    __m128i lo_i8x16 = _mm256_castsi256_si128(max_i8x32);
    __m128i hi_i8x16 = _mm256_extracti128_si256(max_i8x32, 1);
    __m128i max_i8x16 = _mm_max_epi8(lo_i8x16, hi_i8x16);
    max_i8x16 = _mm_max_epi8(max_i8x16, _mm_shuffle_epi32(max_i8x16, _MM_SHUFFLE(2, 3, 0, 1)));
    max_i8x16 = _mm_max_epi8(max_i8x16, _mm_shuffle_epi32(max_i8x16, _MM_SHUFFLE(1, 0, 3, 2)));
    max_i8x16 = _mm_max_epi8(max_i8x16, _mm_srli_si128(max_i8x16, 2));
    max_i8x16 = _mm_max_epi8(max_i8x16, _mm_srli_si128(max_i8x16, 1));
    return (nk_i8_t)_mm_cvtsi128_si32(max_i8x16);
}

/** @brief Horizontal min of 64 unsigned u8s in a ZMM register. */
NK_INTERNAL nk_u8_t nk_reduce_min_u8x64_skylake_(__m512i min_u8x64) {
    __m256i lo_u8x32 = _mm512_castsi512_si256(min_u8x64);
    __m256i hi_u8x32 = _mm512_extracti64x4_epi64(min_u8x64, 1);
    __m256i min_u8x32 = _mm256_min_epu8(lo_u8x32, hi_u8x32);
    __m128i lo_u8x16 = _mm256_castsi256_si128(min_u8x32);
    __m128i hi_u8x16 = _mm256_extracti128_si256(min_u8x32, 1);
    __m128i min_u8x16 = _mm_min_epu8(lo_u8x16, hi_u8x16);
    min_u8x16 = _mm_min_epu8(min_u8x16, _mm_shuffle_epi32(min_u8x16, _MM_SHUFFLE(2, 3, 0, 1)));
    min_u8x16 = _mm_min_epu8(min_u8x16, _mm_shuffle_epi32(min_u8x16, _MM_SHUFFLE(1, 0, 3, 2)));
    min_u8x16 = _mm_min_epu8(min_u8x16, _mm_srli_si128(min_u8x16, 2));
    min_u8x16 = _mm_min_epu8(min_u8x16, _mm_srli_si128(min_u8x16, 1));
    return (nk_u8_t)_mm_cvtsi128_si32(min_u8x16);
}

/** @brief Horizontal max of 64 unsigned u8s in a ZMM register. */
NK_INTERNAL nk_u8_t nk_reduce_max_u8x64_skylake_(__m512i max_u8x64) {
    __m256i lo_u8x32 = _mm512_castsi512_si256(max_u8x64);
    __m256i hi_u8x32 = _mm512_extracti64x4_epi64(max_u8x64, 1);
    __m256i max_u8x32 = _mm256_max_epu8(lo_u8x32, hi_u8x32);
    __m128i lo_u8x16 = _mm256_castsi256_si128(max_u8x32);
    __m128i hi_u8x16 = _mm256_extracti128_si256(max_u8x32, 1);
    __m128i max_u8x16 = _mm_max_epu8(lo_u8x16, hi_u8x16);
    max_u8x16 = _mm_max_epu8(max_u8x16, _mm_shuffle_epi32(max_u8x16, _MM_SHUFFLE(2, 3, 0, 1)));
    max_u8x16 = _mm_max_epu8(max_u8x16, _mm_shuffle_epi32(max_u8x16, _MM_SHUFFLE(1, 0, 3, 2)));
    max_u8x16 = _mm_max_epu8(max_u8x16, _mm_srli_si128(max_u8x16, 2));
    max_u8x16 = _mm_max_epu8(max_u8x16, _mm_srli_si128(max_u8x16, 1));
    return (nk_u8_t)_mm_cvtsi128_si32(max_u8x16);
}

/** @brief Horizontal min of 32 signed i16s in a ZMM register. */
NK_INTERNAL nk_i16_t nk_reduce_min_i16x32_skylake_(__m512i min_i16x32) {
    __m256i lo_i16x16 = _mm512_castsi512_si256(min_i16x32);
    __m256i hi_i16x16 = _mm512_extracti64x4_epi64(min_i16x32, 1);
    __m256i min_i16x16 = _mm256_min_epi16(lo_i16x16, hi_i16x16);
    __m128i lo_i16x8 = _mm256_castsi256_si128(min_i16x16);
    __m128i hi_i16x8 = _mm256_extracti128_si256(min_i16x16, 1);
    __m128i min_i16x8 = _mm_min_epi16(lo_i16x8, hi_i16x8);
    min_i16x8 = _mm_min_epi16(min_i16x8, _mm_shuffle_epi32(min_i16x8, _MM_SHUFFLE(2, 3, 0, 1)));
    min_i16x8 = _mm_min_epi16(min_i16x8, _mm_shuffle_epi32(min_i16x8, _MM_SHUFFLE(1, 0, 3, 2)));
    min_i16x8 = _mm_min_epi16(min_i16x8, _mm_srli_si128(min_i16x8, 2));
    return (nk_i16_t)_mm_cvtsi128_si32(min_i16x8);
}

/** @brief Horizontal max of 32 signed i16s in a ZMM register. */
NK_INTERNAL nk_i16_t nk_reduce_max_i16x32_skylake_(__m512i max_i16x32) {
    __m256i lo_i16x16 = _mm512_castsi512_si256(max_i16x32);
    __m256i hi_i16x16 = _mm512_extracti64x4_epi64(max_i16x32, 1);
    __m256i max_i16x16 = _mm256_max_epi16(lo_i16x16, hi_i16x16);
    __m128i lo_i16x8 = _mm256_castsi256_si128(max_i16x16);
    __m128i hi_i16x8 = _mm256_extracti128_si256(max_i16x16, 1);
    __m128i max_i16x8 = _mm_max_epi16(lo_i16x8, hi_i16x8);
    max_i16x8 = _mm_max_epi16(max_i16x8, _mm_shuffle_epi32(max_i16x8, _MM_SHUFFLE(2, 3, 0, 1)));
    max_i16x8 = _mm_max_epi16(max_i16x8, _mm_shuffle_epi32(max_i16x8, _MM_SHUFFLE(1, 0, 3, 2)));
    max_i16x8 = _mm_max_epi16(max_i16x8, _mm_srli_si128(max_i16x8, 2));
    return (nk_i16_t)_mm_cvtsi128_si32(max_i16x8);
}

/** @brief Horizontal min of 32 unsigned u16s in a ZMM register. */
NK_INTERNAL nk_u16_t nk_reduce_min_u16x32_skylake_(__m512i min_u16x32) {
    __m256i lo_u16x16 = _mm512_castsi512_si256(min_u16x32);
    __m256i hi_u16x16 = _mm512_extracti64x4_epi64(min_u16x32, 1);
    __m256i min_u16x16 = _mm256_min_epu16(lo_u16x16, hi_u16x16);
    __m128i lo_u16x8 = _mm256_castsi256_si128(min_u16x16);
    __m128i hi_u16x8 = _mm256_extracti128_si256(min_u16x16, 1);
    __m128i min_u16x8 = _mm_min_epu16(lo_u16x8, hi_u16x8);
    min_u16x8 = _mm_min_epu16(min_u16x8, _mm_shuffle_epi32(min_u16x8, _MM_SHUFFLE(2, 3, 0, 1)));
    min_u16x8 = _mm_min_epu16(min_u16x8, _mm_shuffle_epi32(min_u16x8, _MM_SHUFFLE(1, 0, 3, 2)));
    min_u16x8 = _mm_min_epu16(min_u16x8, _mm_srli_si128(min_u16x8, 2));
    return (nk_u16_t)_mm_cvtsi128_si32(min_u16x8);
}

/** @brief Horizontal max of 32 unsigned u16s in a ZMM register. */
NK_INTERNAL nk_u16_t nk_reduce_max_u16x32_skylake_(__m512i max_u16x32) {
    __m256i lo_u16x16 = _mm512_castsi512_si256(max_u16x32);
    __m256i hi_u16x16 = _mm512_extracti64x4_epi64(max_u16x32, 1);
    __m256i max_u16x16 = _mm256_max_epu16(lo_u16x16, hi_u16x16);
    __m128i lo_u16x8 = _mm256_castsi256_si128(max_u16x16);
    __m128i hi_u16x8 = _mm256_extracti128_si256(max_u16x16, 1);
    __m128i max_u16x8 = _mm_max_epu16(lo_u16x8, hi_u16x8);
    max_u16x8 = _mm_max_epu16(max_u16x8, _mm_shuffle_epi32(max_u16x8, _MM_SHUFFLE(2, 3, 0, 1)));
    max_u16x8 = _mm_max_epu16(max_u16x8, _mm_shuffle_epi32(max_u16x8, _MM_SHUFFLE(1, 0, 3, 2)));
    max_u16x8 = _mm_max_epu16(max_u16x8, _mm_srli_si128(max_u16x8, 2));
    return (nk_u16_t)_mm_cvtsi128_si32(max_u16x8);
}

/** @brief Horizontal min of 16 signed i32s in a ZMM register. */
NK_INTERNAL nk_i32_t nk_reduce_min_i32x16_skylake_(__m512i min_i32x16) {
    __m256i lo_i32x8 = _mm512_castsi512_si256(min_i32x16);
    __m256i hi_i32x8 = _mm512_extracti64x4_epi64(min_i32x16, 1);
    __m256i min_i32x8 = _mm256_min_epi32(lo_i32x8, hi_i32x8);
    __m128i lo_i32x4 = _mm256_castsi256_si128(min_i32x8);
    __m128i hi_i32x4 = _mm256_extracti128_si256(min_i32x8, 1);
    __m128i min_i32x4 = _mm_min_epi32(lo_i32x4, hi_i32x4);
    min_i32x4 = _mm_min_epi32(min_i32x4, _mm_shuffle_epi32(min_i32x4, _MM_SHUFFLE(2, 3, 0, 1)));
    min_i32x4 = _mm_min_epi32(min_i32x4, _mm_shuffle_epi32(min_i32x4, _MM_SHUFFLE(1, 0, 3, 2)));
    return _mm_cvtsi128_si32(min_i32x4);
}

/** @brief Horizontal max of 16 signed i32s in a ZMM register. */
NK_INTERNAL nk_i32_t nk_reduce_max_i32x16_skylake_(__m512i max_i32x16) {
    __m256i lo_i32x8 = _mm512_castsi512_si256(max_i32x16);
    __m256i hi_i32x8 = _mm512_extracti64x4_epi64(max_i32x16, 1);
    __m256i max_i32x8 = _mm256_max_epi32(lo_i32x8, hi_i32x8);
    __m128i lo_i32x4 = _mm256_castsi256_si128(max_i32x8);
    __m128i hi_i32x4 = _mm256_extracti128_si256(max_i32x8, 1);
    __m128i max_i32x4 = _mm_max_epi32(lo_i32x4, hi_i32x4);
    max_i32x4 = _mm_max_epi32(max_i32x4, _mm_shuffle_epi32(max_i32x4, _MM_SHUFFLE(2, 3, 0, 1)));
    max_i32x4 = _mm_max_epi32(max_i32x4, _mm_shuffle_epi32(max_i32x4, _MM_SHUFFLE(1, 0, 3, 2)));
    return _mm_cvtsi128_si32(max_i32x4);
}

/** @brief Horizontal min of 16 unsigned u32s in a ZMM register. */
NK_INTERNAL nk_u32_t nk_reduce_min_u32x16_skylake_(__m512i min_u32x16) {
    __m256i lo_u32x8 = _mm512_castsi512_si256(min_u32x16);
    __m256i hi_u32x8 = _mm512_extracti64x4_epi64(min_u32x16, 1);
    __m256i min_u32x8 = _mm256_min_epu32(lo_u32x8, hi_u32x8);
    __m128i lo_u32x4 = _mm256_castsi256_si128(min_u32x8);
    __m128i hi_u32x4 = _mm256_extracti128_si256(min_u32x8, 1);
    __m128i min_u32x4 = _mm_min_epu32(lo_u32x4, hi_u32x4);
    min_u32x4 = _mm_min_epu32(min_u32x4, _mm_shuffle_epi32(min_u32x4, _MM_SHUFFLE(2, 3, 0, 1)));
    min_u32x4 = _mm_min_epu32(min_u32x4, _mm_shuffle_epi32(min_u32x4, _MM_SHUFFLE(1, 0, 3, 2)));
    return (nk_u32_t)_mm_cvtsi128_si32(min_u32x4);
}

/** @brief Horizontal max of 16 unsigned u32s in a ZMM register. */
NK_INTERNAL nk_u32_t nk_reduce_max_u32x16_skylake_(__m512i max_u32x16) {
    __m256i lo_u32x8 = _mm512_castsi512_si256(max_u32x16);
    __m256i hi_u32x8 = _mm512_extracti64x4_epi64(max_u32x16, 1);
    __m256i max_u32x8 = _mm256_max_epu32(lo_u32x8, hi_u32x8);
    __m128i lo_u32x4 = _mm256_castsi256_si128(max_u32x8);
    __m128i hi_u32x4 = _mm256_extracti128_si256(max_u32x8, 1);
    __m128i max_u32x4 = _mm_max_epu32(lo_u32x4, hi_u32x4);
    max_u32x4 = _mm_max_epu32(max_u32x4, _mm_shuffle_epi32(max_u32x4, _MM_SHUFFLE(2, 3, 0, 1)));
    max_u32x4 = _mm_max_epu32(max_u32x4, _mm_shuffle_epi32(max_u32x4, _MM_SHUFFLE(1, 0, 3, 2)));
    return (nk_u32_t)_mm_cvtsi128_si32(max_u32x4);
}

/** @brief Horizontal min of 8 signed i64s in a ZMM register. */
NK_INTERNAL nk_i64_t nk_reduce_min_i64x8_skylake_(__m512i min_i64x8) {
    __m256i lo_i64x4 = _mm512_castsi512_si256(min_i64x8);
    __m256i hi_i64x4 = _mm512_extracti64x4_epi64(min_i64x8, 1);
    __m256i min_i64x4 = _mm256_min_epi64(lo_i64x4, hi_i64x4);
    __m128i lo_i64x2 = _mm256_castsi256_si128(min_i64x4);
    __m128i hi_i64x2 = _mm256_extracti128_si256(min_i64x4, 1);
    __m128i min_i64x2 = _mm_min_epi64(lo_i64x2, hi_i64x2);
    __m128i hi_lane_i64 = _mm_unpackhi_epi64(min_i64x2, min_i64x2);
    __m128i final_i64 = _mm_min_epi64(min_i64x2, hi_lane_i64);
    return _mm_cvtsi128_si64(final_i64);
}

/** @brief Horizontal max of 8 signed i64s in a ZMM register. */
NK_INTERNAL nk_i64_t nk_reduce_max_i64x8_skylake_(__m512i max_i64x8) {
    __m256i lo_i64x4 = _mm512_castsi512_si256(max_i64x8);
    __m256i hi_i64x4 = _mm512_extracti64x4_epi64(max_i64x8, 1);
    __m256i max_i64x4 = _mm256_max_epi64(lo_i64x4, hi_i64x4);
    __m128i lo_i64x2 = _mm256_castsi256_si128(max_i64x4);
    __m128i hi_i64x2 = _mm256_extracti128_si256(max_i64x4, 1);
    __m128i max_i64x2 = _mm_max_epi64(lo_i64x2, hi_i64x2);
    __m128i hi_lane_i64 = _mm_unpackhi_epi64(max_i64x2, max_i64x2);
    __m128i final_i64 = _mm_max_epi64(max_i64x2, hi_lane_i64);
    return _mm_cvtsi128_si64(final_i64);
}

/** @brief Horizontal min of 8 unsigned u64s in a ZMM register. */
NK_INTERNAL nk_u64_t nk_reduce_min_u64x8_skylake_(__m512i min_u64x8) {
    __m256i lo_u64x4 = _mm512_castsi512_si256(min_u64x8);
    __m256i hi_u64x4 = _mm512_extracti64x4_epi64(min_u64x8, 1);
    __m256i min_u64x4 = _mm256_min_epu64(lo_u64x4, hi_u64x4);
    __m128i lo_u64x2 = _mm256_castsi256_si128(min_u64x4);
    __m128i hi_u64x2 = _mm256_extracti128_si256(min_u64x4, 1);
    __m128i min_u64x2 = _mm_min_epu64(lo_u64x2, hi_u64x2);
    __m128i hi_lane_u64 = _mm_unpackhi_epi64(min_u64x2, min_u64x2);
    __m128i final_u64 = _mm_min_epu64(min_u64x2, hi_lane_u64);
    return (nk_u64_t)_mm_cvtsi128_si64(final_u64);
}

/** @brief Horizontal max of 8 unsigned u64s in a ZMM register. */
NK_INTERNAL nk_u64_t nk_reduce_max_u64x8_skylake_(__m512i max_u64x8) {
    __m256i lo_u64x4 = _mm512_castsi512_si256(max_u64x8);
    __m256i hi_u64x4 = _mm512_extracti64x4_epi64(max_u64x8, 1);
    __m256i max_u64x4 = _mm256_max_epu64(lo_u64x4, hi_u64x4);
    __m128i lo_u64x2 = _mm256_castsi256_si128(max_u64x4);
    __m128i hi_u64x2 = _mm256_extracti128_si256(max_u64x4, 1);
    __m128i max_u64x2 = _mm_max_epu64(lo_u64x2, hi_u64x2);
    __m128i hi_lane_u64 = _mm_unpackhi_epi64(max_u64x2, max_u64x2);
    __m128i final_u64 = _mm_max_epu64(max_u64x2, hi_lane_u64);
    return (nk_u64_t)_mm_cvtsi128_si64(final_u64);
}

/** @brief Horizontal sum of 8 unsigned u64s in a ZMM register. */
NK_INTERNAL nk_u64_t nk_reduce_add_u64x8_skylake_(__m512i sum_u64x8) {
    __m256i lo_u64x4 = _mm512_castsi512_si256(sum_u64x8);
    __m256i hi_u64x4 = _mm512_extracti64x4_epi64(sum_u64x8, 1);
    __m256i sum_u64x4 = _mm256_add_epi64(lo_u64x4, hi_u64x4);
    __m128i lo_u64x2 = _mm256_castsi256_si128(sum_u64x4);
    __m128i hi_u64x2 = _mm256_extracti128_si256(sum_u64x4, 1);
    __m128i sum_u64x2 = _mm_add_epi64(lo_u64x2, hi_u64x2);
    __m128i hi_lane_u64 = _mm_unpackhi_epi64(sum_u64x2, sum_u64x2);
    __m128i final_u64 = _mm_add_epi64(sum_u64x2, hi_lane_u64);
    return (nk_u64_t)_mm_cvtsi128_si64(final_u64);
}

NK_INTERNAL __m512i nk_fp8x64_to_u8x64_comparable_skylake_(__m512i raw_i8x64) {
    __mmask64 neg_m64 = _mm512_test_epi8_mask(raw_i8x64, _mm512_set1_epi8((char)0x80));
    __m512i pos_xor_i8x64 = _mm512_set1_epi8((char)0x80);
    __m512i neg_xor_i8x64 = _mm512_set1_epi8((char)0xFF);
    __m512i xor_i8x64 = _mm512_mask_mov_epi8(pos_xor_i8x64, neg_m64, neg_xor_i8x64);
    return _mm512_xor_si512(raw_i8x64, xor_i8x64);
}

NK_INTERNAL __m512i nk_u8x64_comparable_to_fp8x64_skylake_(__m512i cmp_i8x64) {
    __mmask64 was_neg_m64 = _mm512_cmplt_epu8_mask(cmp_i8x64, _mm512_set1_epi8((char)0x80));
    __m512i neg_xor_i8x64 = _mm512_set1_epi8((char)0xFF);
    __m512i pos_xor_i8x64 = _mm512_set1_epi8((char)0x80);
    __m512i xor_i8x64 = _mm512_mask_mov_epi8(pos_xor_i8x64, was_neg_m64, neg_xor_i8x64);
    return _mm512_xor_si512(cmp_i8x64, xor_i8x64);
}
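
/* Editorial sketch (not part of the original file): the scalar rule behind the
 * two bijections above. Sign-magnitude encodings (like these FP8 formats)
 * compare correctly as unsigned bytes once positive values get the sign bit
 * set and negative values are bitwise inverted; the second function inverts
 * the map. Per byte: */
NK_INTERNAL nk_u8_t nk_example_fp8_to_comparable_(nk_u8_t raw) {
    return (raw & 0x80) ? (nk_u8_t)(raw ^ 0xFF)  // negative: invert all bits
                        : (nk_u8_t)(raw | 0x80); // positive: set the sign bit
}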

/** @brief Horizontal argmin: returns index of first minimum unsigned byte in ZMM register. */
NK_INTERNAL nk_size_t nk_argmin_u8x64_skylake_(__m512i data_u8x64) {
    nk_u8_t min_val = nk_reduce_min_u8x64_skylake_(data_u8x64);
    __mmask64 eq_m64 = _mm512_cmpeq_epi8_mask(data_u8x64, _mm512_set1_epi8((char)min_val));
    return (nk_size_t)_tzcnt_u64(eq_m64);
}

/** @brief Horizontal argmax: returns index of first maximum unsigned byte in ZMM register. */
NK_INTERNAL nk_size_t nk_argmax_u8x64_skylake_(__m512i data_u8x64) {
    nk_u8_t max_val = nk_reduce_max_u8x64_skylake_(data_u8x64);
    __mmask64 eq_m64 = _mm512_cmpeq_epi8_mask(data_u8x64, _mm512_set1_epi8((char)max_val));
    return (nk_size_t)_tzcnt_u64(eq_m64);
}
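
/* Editorial note (not part of the original file): combined with the bijections
 * above, these helpers give order-correct argmin/argmax over FP8 lanes, e.g.
 * nk_argmin_u8x64_skylake_(nk_fp8x64_to_u8x64_comparable_skylake_(v)). Ties
 * resolve to the lowest lane index thanks to _tzcnt_u64. */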

NK_INTERNAL __m512i nk_fp6x64_to_u8x64_comparable_skylake_(__m512i raw_i8x64) {
    raw_i8x64 = _mm512_and_si512(raw_i8x64, _mm512_set1_epi8(0x3F)); // mask to 6 valid bits
    __mmask64 neg_m64 = _mm512_test_epi8_mask(raw_i8x64, _mm512_set1_epi8(0x20));
    __m512i pos_xor_i8x64 = _mm512_set1_epi8(0x20);
    __m512i neg_xor_i8x64 = _mm512_set1_epi8(0x3F);
    __m512i xor_i8x64 = _mm512_mask_mov_epi8(pos_xor_i8x64, neg_m64, neg_xor_i8x64);
    return _mm512_xor_si512(raw_i8x64, xor_i8x64);
}

NK_INTERNAL __m512i nk_u8x64_comparable_to_fp6x64_skylake_(__m512i cmp_i8x64) {
    __mmask64 was_neg_m64 = _mm512_cmplt_epu8_mask(cmp_i8x64, _mm512_set1_epi8(0x20));
    __m512i neg_xor_i8x64 = _mm512_set1_epi8(0x3F);
    __m512i pos_xor_i8x64 = _mm512_set1_epi8(0x20);
    __m512i xor_i8x64 = _mm512_mask_mov_epi8(pos_xor_i8x64, was_neg_m64, neg_xor_i8x64);
    return _mm512_xor_si512(cmp_i8x64, xor_i8x64);
}

NK_INTERNAL void nk_reduce_moments_f32_skylake_contiguous_( //
    nk_f32_t const *data_ptr, nk_size_t count,               //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    __m512d sum_low_f64x8 = _mm512_setzero_pd(), sum_high_f64x8 = _mm512_setzero_pd();
    __m512d sumsq_low_f64x8 = _mm512_setzero_pd(), sumsq_high_f64x8 = _mm512_setzero_pd();
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512d low_f64x8 = _mm512_cvtps_pd(_mm256_loadu_ps(data_ptr + idx));
        __m512d high_f64x8 = _mm512_cvtps_pd(_mm256_loadu_ps(data_ptr + idx + 8));
        sum_low_f64x8 = _mm512_add_pd(sum_low_f64x8, low_f64x8);
        sum_high_f64x8 = _mm512_add_pd(sum_high_f64x8, high_f64x8);
        sumsq_low_f64x8 = _mm512_fmadd_pd(low_f64x8, low_f64x8, sumsq_low_f64x8);
        sumsq_high_f64x8 = _mm512_fmadd_pd(high_f64x8, high_f64x8, sumsq_high_f64x8);
    }
    __m512d sum_f64x8 = _mm512_add_pd(sum_low_f64x8, sum_high_f64x8);
    __m512d sumsq_f64x8 = _mm512_add_pd(sumsq_low_f64x8, sumsq_high_f64x8);
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining);
        __m512 tail_f32x16 = _mm512_maskz_loadu_ps(tail_mask, data_ptr + idx);
        __m512d low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(tail_f32x16));
        __m512d high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(tail_f32x16, 1));
        sum_f64x8 = _mm512_add_pd(sum_f64x8, low_f64x8);
        sumsq_f64x8 = _mm512_fmadd_pd(low_f64x8, low_f64x8, sumsq_f64x8);
        if (remaining > 8)
            sum_f64x8 = _mm512_add_pd(sum_f64x8, high_f64x8),
            sumsq_f64x8 = _mm512_fmadd_pd(high_f64x8, high_f64x8, sumsq_f64x8);
    }
    *sum_ptr = nk_reduce_add_f64x8_skylake_(sum_f64x8);
    *sumsq_ptr = nk_reduce_add_f64x8_skylake_(sumsq_f64x8);
}
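
/* Editorial sketch (not part of the original file): the scalar contract of the
 * kernel above, up to floating-point reassociation — double-precision
 * accumulation of the sum and the sum of squares: */
NK_INTERNAL void nk_example_moments_f32_reference_(nk_f32_t const *data, nk_size_t count,
                                                   nk_f64_t *sum, nk_f64_t *sumsq) {
    nk_f64_t s = 0, sq = 0;
    for (nk_size_t i = 0; i < count; ++i) {
        nk_f64_t v = (nk_f64_t)data[i];
        s += v, sq += v * v;
    }
    *sum = s, *sumsq = sq;
}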

NK_INTERNAL void nk_reduce_moments_f32_skylake_gather_(                   //
    nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes,   //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_i32_t stride_elements = (nk_i32_t)(stride_bytes / sizeof(nk_f32_t));
    __m512i indices_i32x16 = _mm512_mullo_epi32(_mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15),
                                                _mm512_set1_epi32(stride_elements));
    __m512d sum_f64x8 = _mm512_setzero_pd();
    __m512d sumsq_f64x8 = _mm512_setzero_pd();
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512 gathered_f32x16 = _mm512_i32gather_ps(indices_i32x16, data_ptr + idx * stride_elements,
                                                     sizeof(nk_f32_t));
        __m256 low_f32x8 = _mm512_castps512_ps256(gathered_f32x16);
        __m256 high_f32x8 = _mm512_extractf32x8_ps(gathered_f32x16, 1);
        __m512d low_f64x8 = _mm512_cvtps_pd(low_f32x8);
        __m512d high_f64x8 = _mm512_cvtps_pd(high_f32x8);
        sum_f64x8 = _mm512_add_pd(sum_f64x8, low_f64x8);
        sum_f64x8 = _mm512_add_pd(sum_f64x8, high_f64x8);
        sumsq_f64x8 = _mm512_fmadd_pd(low_f64x8, low_f64x8, sumsq_f64x8);
        sumsq_f64x8 = _mm512_fmadd_pd(high_f64x8, high_f64x8, sumsq_f64x8);
    }
    nk_f64_t sum = nk_reduce_add_f64x8_skylake_(sum_f64x8);
    nk_f64_t sumsq = nk_reduce_add_f64x8_skylake_(sumsq_f64x8);
    unsigned char const *ptr = (unsigned char const *)(data_ptr + idx * stride_elements);
    for (; idx < count; ++idx, ptr += stride_bytes) {
        nk_f64_t val = (nk_f64_t)(*(nk_f32_t const *)ptr);
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_INTERNAL void nk_reduce_moments_f32_skylake_strided_(                     //
    nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_elements,   //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    __mmask16 stride_mask = nk_stride_mask_b32x16_(stride_elements);
    __m512d sum_f64x8 = _mm512_setzero_pd(), sumsq_f64x8 = _mm512_setzero_pd();
    nk_size_t idx = 0, total = count * stride_elements;
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask) * stride_elements;
    for (; idx + step <= total; idx += step) {
        __m512 data_f32x16 = _mm512_maskz_loadu_ps(stride_mask, data_ptr + idx);
        __m512d low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(data_f32x16));
        __m512d high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(data_f32x16, 1));
        sum_f64x8 = _mm512_add_pd(sum_f64x8, low_f64x8);
        sum_f64x8 = _mm512_add_pd(sum_f64x8, high_f64x8);
        sumsq_f64x8 = _mm512_fmadd_pd(low_f64x8, low_f64x8, sumsq_f64x8);
        sumsq_f64x8 = _mm512_fmadd_pd(high_f64x8, high_f64x8, sumsq_f64x8);
    }
    nk_size_t remaining = total - idx;
    if (remaining > 0) {
        __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining);
        __m512 data_f32x16 = _mm512_maskz_loadu_ps(stride_mask & tail_mask, data_ptr + idx);
        __m512d low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(data_f32x16));
        __m512d high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(data_f32x16, 1));
        sum_f64x8 = _mm512_add_pd(sum_f64x8, low_f64x8);
        sumsq_f64x8 = _mm512_fmadd_pd(low_f64x8, low_f64x8, sumsq_f64x8);
        if (remaining > 8)
            sum_f64x8 = _mm512_add_pd(sum_f64x8, high_f64x8),
            sumsq_f64x8 = _mm512_fmadd_pd(high_f64x8, high_f64x8, sumsq_f64x8);
    }
    *sum_ptr = nk_reduce_add_f64x8_skylake_(sum_f64x8);
    *sumsq_ptr = nk_reduce_add_f64x8_skylake_(sumsq_f64x8);
}

NK_PUBLIC void nk_reduce_moments_f32_skylake(                             //
    nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes,   //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
    int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_f32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
        nk_size_t left_count = count / 2;
        nk_f64_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_f32_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_f32_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                      &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_f32_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 8)
        nk_reduce_moments_f32_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_f32_skylake_gather_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}
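
/* Editorial sketch (not part of the original file): a typical caller, deriving
 * the mean and population variance from the two returned moments. The helper
 * name is hypothetical; `stride_bytes` is sizeof(nk_f32_t) for a dense array,
 * and count is assumed to be non-zero. */
NK_INTERNAL void nk_example_mean_variance_f32_(nk_f32_t const *data, nk_size_t count,
                                               nk_f64_t *mean, nk_f64_t *variance) {
    nk_f64_t sum, sumsq;
    nk_reduce_moments_f32_skylake(data, count, sizeof(nk_f32_t), &sum, &sumsq);
    *mean = sum / (nk_f64_t)count;
    *variance = sumsq / (nk_f64_t)count - *mean * *mean; // E[x^2] - E[x]^2
}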

NK_INTERNAL void nk_reduce_minmax_f32_skylake_contiguous_( //
    nk_f32_t const *data_ptr, nk_size_t count,             //
    nk_f32_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_f32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512 min_f32x16 = _mm512_set1_ps(NK_F32_MAX);
    __m512 max_f32x16 = _mm512_set1_ps(NK_F32_MIN);
    __m512i min_loop_cycle_u32x16 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u32x16 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u32x16 = _mm512_setzero_si512();
    __m512i one_u32x16 = _mm512_set1_epi32(1);

    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512 data_f32x16 = _mm512_loadu_ps(data_ptr + idx);
        __mmask16 min_changed_mask = _mm512_cmp_ps_mask(data_f32x16, min_f32x16, _CMP_LT_OQ);
        __mmask16 max_changed_mask = _mm512_cmp_ps_mask(data_f32x16, max_f32x16, _CMP_GT_OQ);
        min_f32x16 = _mm512_mask_mov_ps(min_f32x16, min_changed_mask, data_f32x16);
        max_f32x16 = _mm512_mask_mov_ps(max_f32x16, max_changed_mask, data_f32x16);
        min_loop_cycle_u32x16 = _mm512_mask_mov_epi32(min_loop_cycle_u32x16, min_changed_mask,
                                                      current_loop_cycle_u32x16);
        max_loop_cycle_u32x16 = _mm512_mask_mov_epi32(max_loop_cycle_u32x16, max_changed_mask,
                                                      current_loop_cycle_u32x16);
        current_loop_cycle_u32x16 = _mm512_add_epi32(current_loop_cycle_u32x16, one_u32x16);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask16 tail_load = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining);
        __m512 tail_f32x16 = _mm512_maskz_loadu_ps(tail_load, data_ptr + idx);
        __mmask16 min_changed_mask = _mm512_mask_cmp_ps_mask(tail_load, tail_f32x16, min_f32x16, _CMP_LT_OQ);
        __mmask16 max_changed_mask = _mm512_mask_cmp_ps_mask(tail_load, tail_f32x16, max_f32x16, _CMP_GT_OQ);
        min_f32x16 = _mm512_mask_mov_ps(min_f32x16, min_changed_mask, tail_f32x16);
        max_f32x16 = _mm512_mask_mov_ps(max_f32x16, max_changed_mask, tail_f32x16);
        min_loop_cycle_u32x16 = _mm512_mask_mov_epi32(min_loop_cycle_u32x16, min_changed_mask,
                                                      current_loop_cycle_u32x16);
        max_loop_cycle_u32x16 = _mm512_mask_mov_epi32(max_loop_cycle_u32x16, max_changed_mask,
                                                      current_loop_cycle_u32x16);
    }

    nk_f32_t min_value = nk_reduce_min_f32x16_skylake_(min_f32x16);
    nk_f32_t max_value = nk_reduce_max_f32x16_skylake_(max_f32x16);
    if (min_value == NK_F32_MAX && max_value == NK_F32_MIN) {
        *min_value_ptr = NK_F32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F32_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }
    unsigned int min_lane, max_lane;
    {
        __mmask16 value_match_mask = _mm512_cmp_ps_mask(min_f32x16, _mm512_set1_ps(min_value), _CMP_EQ_OQ);
        __m512i masked_cycle_u32x16 = _mm512_mask_blend_epi32(value_match_mask, _mm512_set1_epi32((int)NK_U32_MAX),
                                                              min_loop_cycle_u32x16);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x16_skylake_(masked_cycle_u32x16);
        __mmask16 cycle_match_mask = _mm512_cmpeq_epi32_mask(masked_cycle_u32x16,
                                                             _mm512_set1_epi32((int)earliest_loop_cycle));
        min_lane = _tzcnt_u32(cycle_match_mask);
    }
    {
        __mmask16 value_match_mask = _mm512_cmp_ps_mask(max_f32x16, _mm512_set1_ps(max_value), _CMP_EQ_OQ);
        __m512i masked_cycle_u32x16 = _mm512_mask_blend_epi32(value_match_mask, _mm512_set1_epi32((int)NK_U32_MAX),
                                                              max_loop_cycle_u32x16);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x16_skylake_(masked_cycle_u32x16);
        __mmask16 cycle_match_mask = _mm512_cmpeq_epi32_mask(masked_cycle_u32x16,
                                                             _mm512_set1_epi32((int)earliest_loop_cycle));
        max_lane = _tzcnt_u32(cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u32x16;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u32s[min_lane] * 16 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u32x16;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u32s[max_lane] * 16 + max_lane;
}
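
// The kernel above never stores flat element indices: per lane it records only
// the loop iteration ("cycle") that last improved that lane, and rebuilds the
// position at the end as cycle * 16 + lane. A scalar sketch of the same
// encoding, with our own names, purely for illustration:

#include <stddef.h>

#define EXAMPLE_LANES 16
size_t example_argmin_by_cycle(float const *data, size_t count) {
    float best = data[0];
    size_t best_cycle = 0, best_lane = 0;
    for (size_t i = 0; i < count; ++i)
        if (data[i] < best) // strict `<`, so ties keep the earliest position
            best = data[i], best_cycle = i / EXAMPLE_LANES, best_lane = i % EXAMPLE_LANES;
    return best_cycle * EXAMPLE_LANES + best_lane; // the encoding the SIMD kernel stores
}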

NK_PUBLIC void nk_reduce_minmax_f32_skylake(                           //
    nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_f32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
    int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_F32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F32_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_f32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (count > (nk_size_t)NK_U32_MAX * 16) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_f32_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                     &left_max_index);
        nk_reduce_minmax_f32_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                     &right_min, &right_min_index, &right_max, &right_max_index);
        if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_f32_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_f32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}
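
// A minimal usage sketch for the public entry point above; the data and the
// `example_minmax` name are ours, and `nk_size_t` is assumed to be printable
// as a plain `size_t`:

#include <stdio.h>

void example_minmax(void) {
    nk_f32_t data[5] = {3.0f, -1.0f, 4.0f, -1.0f, 5.0f};
    nk_f32_t min_value, max_value;
    nk_size_t min_index, max_index;
    nk_reduce_minmax_f32_skylake(data, 5, sizeof(nk_f32_t), //
                                 &min_value, &min_index, &max_value, &max_index);
    // The strict `<` in the kernel resolves ties to the earliest position,
    // so min_index is 1 here, not 3.
    printf("min %g at %zu, max %g at %zu\n", min_value, (size_t)min_index, max_value, (size_t)max_index);
}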

NK_INTERNAL void nk_reduce_moments_f64_skylake_contiguous_( //
    nk_f64_t const *data_ptr, nk_size_t count,              //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    __m512d sum_f64x8 = _mm512_setzero_pd();
    __m512d sum_comp_f64x8 = _mm512_setzero_pd();
    __m512d sumsq_f64x8 = _mm512_setzero_pd();
    __m512d sumsq_comp_f64x8 = _mm512_setzero_pd();
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m512d val_f64x8 = _mm512_loadu_pd(data_ptr + idx);
        // Knuth 2-SUM for sum
        __m512d tentative_f64x8 = _mm512_add_pd(sum_f64x8, val_f64x8);
        __m512d round_f64x8 = _mm512_sub_pd(tentative_f64x8, sum_f64x8);
        __m512d corr_f64x8 = _mm512_add_pd(_mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_f64x8, round_f64x8)),
                                           _mm512_sub_pd(val_f64x8, round_f64x8));
        sum_comp_f64x8 = _mm512_add_pd(sum_comp_f64x8, corr_f64x8);
        sum_f64x8 = tentative_f64x8;
        // Knuth 2-SUM for sumsq
        __m512d sq_f64x8 = _mm512_mul_pd(val_f64x8, val_f64x8);
        __m512d tentative_sq_f64x8 = _mm512_add_pd(sumsq_f64x8, sq_f64x8);
        __m512d round_sq_f64x8 = _mm512_sub_pd(tentative_sq_f64x8, sumsq_f64x8);
        __m512d corr_sq_f64x8 = _mm512_add_pd(
            _mm512_sub_pd(sumsq_f64x8, _mm512_sub_pd(tentative_sq_f64x8, round_sq_f64x8)),
            _mm512_sub_pd(sq_f64x8, round_sq_f64x8));
        sumsq_comp_f64x8 = _mm512_add_pd(sumsq_comp_f64x8, corr_sq_f64x8);
        sumsq_f64x8 = tentative_sq_f64x8;
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask8 tail_mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)remaining);
        __m512d val_f64x8 = _mm512_maskz_loadu_pd(tail_mask, data_ptr + idx);
        // Knuth 2-SUM for sum
        __m512d tentative_f64x8 = _mm512_add_pd(sum_f64x8, val_f64x8);
        __m512d round_f64x8 = _mm512_sub_pd(tentative_f64x8, sum_f64x8);
        __m512d corr_f64x8 = _mm512_add_pd(_mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_f64x8, round_f64x8)),
                                           _mm512_sub_pd(val_f64x8, round_f64x8));
        sum_comp_f64x8 = _mm512_add_pd(sum_comp_f64x8, corr_f64x8);
        sum_f64x8 = tentative_f64x8;
        // Knuth 2-SUM for sumsq
        __m512d sq_f64x8 = _mm512_mul_pd(val_f64x8, val_f64x8);
        __m512d tentative_sq_f64x8 = _mm512_add_pd(sumsq_f64x8, sq_f64x8);
        __m512d round_sq_f64x8 = _mm512_sub_pd(tentative_sq_f64x8, sumsq_f64x8);
        __m512d corr_sq_f64x8 = _mm512_add_pd(
            _mm512_sub_pd(sumsq_f64x8, _mm512_sub_pd(tentative_sq_f64x8, round_sq_f64x8)),
            _mm512_sub_pd(sq_f64x8, round_sq_f64x8));
        sumsq_comp_f64x8 = _mm512_add_pd(sumsq_comp_f64x8, corr_sq_f64x8);
        sumsq_f64x8 = tentative_sq_f64x8;
    }
    *sum_ptr = nk_reduce_add_f64x8_skylake_(_mm512_add_pd(sum_f64x8, sum_comp_f64x8));
    *sumsq_ptr = nk_reduce_add_f64x8_skylake_(_mm512_add_pd(sumsq_f64x8, sumsq_comp_f64x8));
}
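
// The tentative/round/corr triples above are Knuth's branch-free 2-SUM applied
// to all eight lanes at once: every addition also yields its exact rounding
// error, which drains into a separate compensation register. The same
// transformation in scalar form, as a reference sketch only (not library code):

// Error-free transformation: t = fl(a + b), and a + b == t + *err exactly.
double example_two_sum(double a, double b, double *err) {
    double t = a + b;               // "tentative"
    double r = t - a;               // "round"
    *err = (a - (t - r)) + (b - r); // "corr"
    return t;
}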

NK_INTERNAL void nk_reduce_moments_f64_skylake_strided_(                  //
    nk_f64_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    __mmask8 stride_mask = nk_stride_mask_b64x8_(stride_elements);
    __m512d sum_f64x8 = _mm512_setzero_pd();
    __m512d sum_comp_f64x8 = _mm512_setzero_pd();
    __m512d sumsq_f64x8 = _mm512_setzero_pd();
    __m512d sumsq_comp_f64x8 = _mm512_setzero_pd();
    nk_size_t idx = 0, total = count * stride_elements;
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask) * stride_elements;
    for (; idx + step <= total; idx += step) {
        __m512d val_f64x8 = _mm512_maskz_loadu_pd(stride_mask, data_ptr + idx);
        // Knuth 2-SUM for sum
        __m512d tentative_f64x8 = _mm512_add_pd(sum_f64x8, val_f64x8);
        __m512d round_f64x8 = _mm512_sub_pd(tentative_f64x8, sum_f64x8);
        __m512d corr_f64x8 = _mm512_add_pd(_mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_f64x8, round_f64x8)),
                                           _mm512_sub_pd(val_f64x8, round_f64x8));
        sum_comp_f64x8 = _mm512_add_pd(sum_comp_f64x8, corr_f64x8);
        sum_f64x8 = tentative_f64x8;
        // Knuth 2-SUM for sumsq
        __m512d sq_f64x8 = _mm512_mul_pd(val_f64x8, val_f64x8);
        __m512d tentative_sq_f64x8 = _mm512_add_pd(sumsq_f64x8, sq_f64x8);
        __m512d round_sq_f64x8 = _mm512_sub_pd(tentative_sq_f64x8, sumsq_f64x8);
        __m512d corr_sq_f64x8 = _mm512_add_pd(
            _mm512_sub_pd(sumsq_f64x8, _mm512_sub_pd(tentative_sq_f64x8, round_sq_f64x8)),
            _mm512_sub_pd(sq_f64x8, round_sq_f64x8));
        sumsq_comp_f64x8 = _mm512_add_pd(sumsq_comp_f64x8, corr_sq_f64x8);
        sumsq_f64x8 = tentative_sq_f64x8;
    }
    nk_size_t remaining = total - idx;
    if (remaining > 0) {
        __mmask8 tail_mask = stride_mask & (__mmask8)_bzhi_u32(0xFF, (unsigned int)remaining);
        __m512d val_f64x8 = _mm512_maskz_loadu_pd(tail_mask, data_ptr + idx);
        // Knuth 2-SUM for sum
        __m512d tentative_f64x8 = _mm512_add_pd(sum_f64x8, val_f64x8);
        __m512d round_f64x8 = _mm512_sub_pd(tentative_f64x8, sum_f64x8);
        __m512d corr_f64x8 = _mm512_add_pd(_mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_f64x8, round_f64x8)),
                                           _mm512_sub_pd(val_f64x8, round_f64x8));
        sum_comp_f64x8 = _mm512_add_pd(sum_comp_f64x8, corr_f64x8);
        sum_f64x8 = tentative_f64x8;
        // Knuth 2-SUM for sumsq
        __m512d sq_f64x8 = _mm512_mul_pd(val_f64x8, val_f64x8);
        __m512d tentative_sq_f64x8 = _mm512_add_pd(sumsq_f64x8, sq_f64x8);
        __m512d round_sq_f64x8 = _mm512_sub_pd(tentative_sq_f64x8, sumsq_f64x8);
        __m512d corr_sq_f64x8 = _mm512_add_pd(
            _mm512_sub_pd(sumsq_f64x8, _mm512_sub_pd(tentative_sq_f64x8, round_sq_f64x8)),
            _mm512_sub_pd(sq_f64x8, round_sq_f64x8));
        sumsq_comp_f64x8 = _mm512_add_pd(sumsq_comp_f64x8, corr_sq_f64x8);
        sumsq_f64x8 = tentative_sq_f64x8;
    }
    *sum_ptr = nk_reduce_add_f64x8_skylake_(_mm512_add_pd(sum_f64x8, sum_comp_f64x8));
    *sumsq_ptr = nk_reduce_add_f64x8_skylake_(_mm512_add_pd(sumsq_f64x8, sumsq_comp_f64x8));
}
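
// The strided kernel leans on nk_stride_mask_b64x8_, whose definition is not
// part of this diff. A plausible scalar model, under the assumption that the
// mask simply marks every stride-th lane of the 8-lane vector — which is what
// the popcount/step arithmetic above implies:

#include <stddef.h>
#include <stdint.h>

uint8_t example_stride_mask_b64x8(size_t stride_elements) {
    uint8_t mask = 0;
    for (size_t lane = 0; lane < 8; lane += stride_elements)
        mask |= (uint8_t)(1u << lane);
    return mask; // stride 3 -> lanes 0, 3, 6 -> 0b01001001, popcount 3, step 9 scalars
}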

NK_INTERNAL void nk_reduce_moments_f64_skylake_gather_(                //
    nk_f64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_i32_t stride_elements = (nk_i32_t)(stride_bytes / sizeof(nk_f64_t));
    __m256i indices_i32x8 = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7),
                                               _mm256_set1_epi32(stride_elements));
    __m512d sum_f64x8 = _mm512_setzero_pd();
    __m512d sum_comp_f64x8 = _mm512_setzero_pd();
    __m512d sumsq_f64x8 = _mm512_setzero_pd();
    __m512d sumsq_comp_f64x8 = _mm512_setzero_pd();
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m512d val_f64x8 = _mm512_i32gather_pd(indices_i32x8, data_ptr + idx * stride_elements, sizeof(nk_f64_t));
        // Knuth 2-SUM for sum
        __m512d tentative_f64x8 = _mm512_add_pd(sum_f64x8, val_f64x8);
        __m512d round_f64x8 = _mm512_sub_pd(tentative_f64x8, sum_f64x8);
        __m512d corr_f64x8 = _mm512_add_pd(_mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_f64x8, round_f64x8)),
                                           _mm512_sub_pd(val_f64x8, round_f64x8));
        sum_comp_f64x8 = _mm512_add_pd(sum_comp_f64x8, corr_f64x8);
        sum_f64x8 = tentative_f64x8;
        // Knuth 2-SUM for sumsq
        __m512d sq_f64x8 = _mm512_mul_pd(val_f64x8, val_f64x8);
        __m512d tentative_sq_f64x8 = _mm512_add_pd(sumsq_f64x8, sq_f64x8);
        __m512d round_sq_f64x8 = _mm512_sub_pd(tentative_sq_f64x8, sumsq_f64x8);
        __m512d corr_sq_f64x8 = _mm512_add_pd(
            _mm512_sub_pd(sumsq_f64x8, _mm512_sub_pd(tentative_sq_f64x8, round_sq_f64x8)),
            _mm512_sub_pd(sq_f64x8, round_sq_f64x8));
        sumsq_comp_f64x8 = _mm512_add_pd(sumsq_comp_f64x8, corr_sq_f64x8);
        sumsq_f64x8 = tentative_sq_f64x8;
    }
    nk_f64_t sum = nk_reduce_add_f64x8_skylake_(_mm512_add_pd(sum_f64x8, sum_comp_f64x8));
    nk_f64_t sumsq = nk_reduce_add_f64x8_skylake_(_mm512_add_pd(sumsq_f64x8, sumsq_comp_f64x8));
    unsigned char const *ptr = (unsigned char const *)(data_ptr + idx * stride_elements);
    for (; idx < count; ++idx, ptr += stride_bytes) {
        nk_f64_t val = *(nk_f64_t const *)ptr;
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_f64_skylake(                          //
    nk_f64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f64_t);
    int aligned = (stride_bytes % sizeof(nk_f64_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_f64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
        nk_size_t left_count = count / 2;
        nk_f64_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_f64_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_f64_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                      &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_f64_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 8)
        nk_reduce_moments_f64_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_f64_skylake_gather_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_moments_i8_skylake_contiguous_( //
    nk_i8_t const *data_ptr, nk_size_t count,              //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum: VPSADBW with XOR bias (same as nk_reduce_add_i8_skylake_contiguous_).
    // Sumsq: widen i8→i16, VPMADDWD(x,x) → i32 (pairs of squares), accumulate i32.
    // i32 overflow safe: max per lane = (128² + 128²) * 65536 iters = 2³¹ ≈ 2.1B, right at the i32 limit.
    // The dispatch recurses at (NK_U16_MAX + 1) * 64 elements → at most 65536 iterations here.
    __m512i bias_i8x64 = _mm512_set1_epi8((char)0x80);
    __m512i zero_i8x64 = _mm512_setzero_si512();
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
    nk_size_t idx = 0;
    for (; idx + 64 <= count; idx += 64) {
        __m512i data_i8x64 = _mm512_loadu_si512(data_ptr + idx);
        __m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64, bias_i8x64);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
        __m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
        __m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
        sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, _mm512_madd_epi16(low_i16x32, low_i16x32));
        sumsq_high_i32x16 = _mm512_add_epi32(sumsq_high_i32x16, _mm512_madd_epi16(high_i16x32, high_i16x32));
    }
    // Flush i32 → i64 once
    sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
    __m512i sumsq_i64x8 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
    sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
    nk_i64_t sum = (nk_i64_t)nk_reduce_add_u64x8_skylake_(sum_u64x8);
    sum -= (nk_i64_t)128 * (nk_i64_t)idx;
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
    for (; idx < count; ++idx) {
        nk_i64_t val = (nk_i64_t)data_ptr[idx];
        sum += val, sumsq += (nk_u64_t)(val * val);
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
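
// The XOR-with-0x80 step works because x ^ 0x80 re-biases a signed byte into
// an unsigned one: it equals x + 128 for every i8 value, which VPSADBW can
// then sum as u8. A scalar restatement of the identity, with a hypothetical
// helper name of our own:

#include <stddef.h>
#include <stdint.h>

// sum(x) == sum((u8)(x ^ 0x80)) - 128 * n, since x ^ 0x80 == x + 128 for i8.
int64_t example_sum_i8_via_bias(int8_t const *x, size_t n) {
    uint64_t biased = 0;
    for (size_t i = 0; i < n; ++i) biased += (uint8_t)(x[i] ^ 0x80); // each in [0, 255]
    return (int64_t)biased - 128 * (int64_t)n;                       // undo the bias
}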

NK_INTERNAL void nk_reduce_moments_i8_skylake_strided_(                  //
    nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __mmask64 stride_mask_m64 = nk_stride_mask_u1x64_(stride_elements);
    nk_size_t elements_per_vector = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m64);
    __m512i masked_bias_i8x64 = _mm512_maskz_mov_epi8(stride_mask_m64, _mm512_set1_epi8((char)0x80));
    __m512i zero_i8x64 = _mm512_setzero_si512();
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t vector_element_count = 0;
    nk_size_t step = elements_per_vector * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m512i data_i8x64 = _mm512_maskz_loadu_epi8(stride_mask_m64, data_ptr + idx_scalars);
        __m512i unsigned_i8x64 = _mm512_xor_si512(data_i8x64, masked_bias_i8x64);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(unsigned_i8x64, zero_i8x64));
        __m512i low_i16x32 = _mm512_cvtepi8_epi16(_mm512_castsi512_si256(data_i8x64));
        __m512i high_i16x32 = _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(data_i8x64, 1));
        sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, _mm512_madd_epi16(low_i16x32, low_i16x32));
        sumsq_high_i32x16 = _mm512_add_epi32(sumsq_high_i32x16, _mm512_madd_epi16(high_i16x32, high_i16x32));
        vector_element_count += elements_per_vector;
    }
    sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
    __m512i sumsq_i64x8 = _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
    sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
    nk_i64_t sum = (nk_i64_t)nk_reduce_add_u64x8_skylake_(sum_u64x8);
    sum -= (nk_i64_t)128 * (nk_i64_t)vector_element_count;
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
    nk_i8_t const *ptr = data_ptr + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_i64_t val = (nk_i64_t)*ptr;
        sum += val, sumsq += (nk_u64_t)(val * val);
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_i8_skylake(                          //
    nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
    int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 64) {
        nk_size_t left_count = count / 2;
        nk_i64_t left_sum, right_sum;
        nk_u64_t left_sumsq, right_sumsq;
        nk_reduce_moments_i8_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_i8_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                     &right_sum, &right_sumsq);
        *sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
        *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
    }
    else if (stride_elements == 1) nk_reduce_moments_i8_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 16)
        nk_reduce_moments_i8_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}
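
// The (NK_U16_MAX + 1) * 64 split threshold above is exactly what keeps the
// contiguous kernel's i32 lanes from wrapping: at most 65536 iterations, each
// adding at most two squared -128s per lane. A compile-time restatement of
// that budget, with our own constant names mirroring the kernel's comment:

#include <stdint.h>

enum {
    example_worst_pair_of_squares = 2 * 128 * 128, // two (-128)^2 products per VPMADDWD lane
    example_max_loop_iterations = 65536            // (NK_U16_MAX + 1) * 64 elements / 64 per step
};
_Static_assert((int64_t)example_worst_pair_of_squares * example_max_loop_iterations == (int64_t)1 << 31,
               "worst-case lane total sits right at the i32 boundary");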

NK_INTERNAL void nk_reduce_minmax_i8_skylake_contiguous_( //
    nk_i8_t const *data_ptr, nk_size_t count,             //
    nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i min_i8x64 = _mm512_set1_epi8((char)NK_I8_MAX);
    __m512i max_i8x64 = _mm512_set1_epi8(NK_I8_MIN);
    __m512i min_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i one_u8x64 = _mm512_set1_epi8(1);

    nk_size_t idx = 0;
    for (; idx + 64 <= count; idx += 64) {
        __m512i data_i8x64 = _mm512_loadu_si512(data_ptr + idx);
        __mmask64 min_changed_mask = _mm512_cmp_epi8_mask(data_i8x64, min_i8x64, _MM_CMPINT_LT);
        __mmask64 max_changed_mask = _mm512_cmp_epi8_mask(data_i8x64, max_i8x64, _MM_CMPINT_NLE);
        min_i8x64 = _mm512_mask_mov_epi8(min_i8x64, min_changed_mask, data_i8x64);
        max_i8x64 = _mm512_mask_mov_epi8(max_i8x64, max_changed_mask, data_i8x64);
        min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
        max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
        current_loop_cycle_u8x64 = _mm512_add_epi8(current_loop_cycle_u8x64, one_u8x64);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask64 tail_load = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
        __m512i tail_i8x64 = _mm512_maskz_loadu_epi8(tail_load, data_ptr + idx);
        __mmask64 min_changed_mask = _mm512_mask_cmp_epi8_mask(tail_load, tail_i8x64, min_i8x64, _MM_CMPINT_LT);
        __mmask64 max_changed_mask = _mm512_mask_cmp_epi8_mask(tail_load, tail_i8x64, max_i8x64, _MM_CMPINT_NLE);
        min_i8x64 = _mm512_mask_mov_epi8(min_i8x64, min_changed_mask, tail_i8x64);
        max_i8x64 = _mm512_mask_mov_epi8(max_i8x64, max_changed_mask, tail_i8x64);
        min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
        max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
    }

    nk_i8_t min_value = nk_reduce_min_i8x64_skylake_(min_i8x64);
    nk_i8_t max_value = nk_reduce_max_i8x64_skylake_(max_i8x64);
    unsigned int min_lane, max_lane;
    {
        __mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(min_i8x64, _mm512_set1_epi8(min_value));
        __m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
                                                            min_loop_cycle_u8x64);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
        __mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
                                                            _mm512_set1_epi8((char)earliest_loop_cycle));
        min_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
    }
    {
        __mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(max_i8x64, _mm512_set1_epi8(max_value));
        __m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
                                                            max_loop_cycle_u8x64);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
        __mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
                                                            _mm512_set1_epi8((char)earliest_loop_cycle));
        max_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u8x64;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 64 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u8x64;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 64 + max_lane;
}
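
// Here the cycle counters are single bytes, so the recoverable index tops out
// at 255 * 64 + 63; that is why the public dispatcher below splits anything
// longer than (NK_U8_MAX + 1) * 64 elements. A one-line compile-time check of
// that bound, ours:

_Static_assert(255 * 64 + 63 == 256 * 64 - 1, "a u8 cycle counter addresses exactly one 16384-element block");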

NK_PUBLIC void nk_reduce_minmax_i8_skylake(                           //
    nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
    int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_I8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I8_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_i8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                   max_index_ptr);
    else if (count > (nk_size_t)(NK_U8_MAX + 1) * 64) {
        nk_size_t left_count = count / 2;
        nk_i8_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_i8_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                    &left_max_index);
        nk_reduce_minmax_i8_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                    &right_min, &right_min_index, &right_max, &right_max_index);
        if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_i8_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                max_index_ptr);
    else
        nk_reduce_minmax_i8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                   max_index_ptr);
}

NK_INTERNAL void nk_reduce_moments_u8_skylake_contiguous_( //
    nk_u8_t const *data_ptr, nk_size_t count,              //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum: VPSADBW directly (same as nk_reduce_add_u8_skylake_contiguous_).
    // Sumsq: widen u8→i16, VPMADDWD(x,x) → i32 (pairs of squares), accumulate i32.
    // i32 overflow safe: max per lane = (255² + 255²) * 1024 iters ≈ 133M < 2.1B.
    __m512i zero_u8x64 = _mm512_setzero_si512();
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
    nk_size_t idx = 0;
    for (; idx + 64 <= count; idx += 64) {
        __m512i data_u8x64 = _mm512_loadu_si512(data_ptr + idx);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
        __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
        __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
        sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, _mm512_madd_epi16(low_i16x32, low_i16x32));
        sumsq_high_i32x16 = _mm512_add_epi32(sumsq_high_i32x16, _mm512_madd_epi16(high_i16x32, high_i16x32));
    }
    // Flush i32 → u64 once
    sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
    __m512i sumsq_u64x8 = _mm512_cvtepu32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
    sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
    nk_u64_t sum = nk_reduce_add_u64x8_skylake_(sum_u64x8);
    nk_u64_t sumsq = nk_reduce_add_u64x8_skylake_(sumsq_u64x8);
    for (; idx < count; ++idx) {
        nk_u64_t val = (nk_u64_t)data_ptr[idx];
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_INTERNAL void nk_reduce_moments_u8_skylake_strided_(                  //
    nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __mmask64 stride_mask_m64 = nk_stride_mask_u1x64_(stride_elements);
    __m512i zero_u8x64 = _mm512_setzero_si512();
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_low_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_high_i32x16 = _mm512_setzero_si512();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m64) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m512i data_u8x64 = _mm512_maskz_loadu_epi8(stride_mask_m64, data_ptr + idx_scalars);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(data_u8x64, zero_u8x64));
        __m512i low_i16x32 = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(data_u8x64));
        __m512i high_i16x32 = _mm512_cvtepu8_epi16(_mm512_extracti64x4_epi64(data_u8x64, 1));
        sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, _mm512_madd_epi16(low_i16x32, low_i16x32));
        sumsq_high_i32x16 = _mm512_add_epi32(sumsq_high_i32x16, _mm512_madd_epi16(high_i16x32, high_i16x32));
    }
    sumsq_low_i32x16 = _mm512_add_epi32(sumsq_low_i32x16, sumsq_high_i32x16);
    __m512i sumsq_u64x8 = _mm512_cvtepu32_epi64(_mm512_castsi512_si256(sumsq_low_i32x16));
    sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64(sumsq_low_i32x16, 1)));
    nk_u64_t sum = nk_reduce_add_u64x8_skylake_(sum_u64x8);
    nk_u64_t sumsq = nk_reduce_add_u64x8_skylake_(sumsq_u64x8);
    nk_u8_t const *ptr = data_ptr + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_u64_t val = (nk_u64_t)*ptr;
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_u8_skylake(                          //
    nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
    int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U8_MAX + 1) * 64) {
        nk_size_t left_count = count / 2;
        nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_u8_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_u8_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                     &right_sum, &right_sumsq);
        *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
        *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
    }
    else if (stride_elements == 1) nk_reduce_moments_u8_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 16)
        nk_reduce_moments_u8_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_minmax_u8_skylake_contiguous_( //
    nk_u8_t const *data_ptr, nk_size_t count,             //
    nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i min_u8x64 = _mm512_set1_epi8((char)NK_U8_MAX);
    __m512i max_u8x64 = _mm512_setzero_si512();
    __m512i min_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i one_u8x64 = _mm512_set1_epi8(1);

    nk_size_t idx = 0;
    for (; idx + 64 <= count; idx += 64) {
        __m512i data_u8x64 = _mm512_loadu_si512(data_ptr + idx);
        __mmask64 min_changed_mask = _mm512_cmp_epu8_mask(data_u8x64, min_u8x64, _MM_CMPINT_LT);
        __mmask64 max_changed_mask = _mm512_cmp_epu8_mask(data_u8x64, max_u8x64, _MM_CMPINT_NLE);
        min_u8x64 = _mm512_mask_mov_epi8(min_u8x64, min_changed_mask, data_u8x64);
        max_u8x64 = _mm512_mask_mov_epi8(max_u8x64, max_changed_mask, data_u8x64);
        min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
        max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
        current_loop_cycle_u8x64 = _mm512_add_epi8(current_loop_cycle_u8x64, one_u8x64);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask64 tail_load = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
        __m512i tail_u8x64 = _mm512_maskz_loadu_epi8(tail_load, data_ptr + idx);
        __mmask64 min_changed_mask = _mm512_mask_cmp_epu8_mask(tail_load, tail_u8x64, min_u8x64, _MM_CMPINT_LT);
        __mmask64 max_changed_mask = _mm512_mask_cmp_epu8_mask(tail_load, tail_u8x64, max_u8x64, _MM_CMPINT_NLE);
        min_u8x64 = _mm512_mask_mov_epi8(min_u8x64, min_changed_mask, tail_u8x64);
        max_u8x64 = _mm512_mask_mov_epi8(max_u8x64, max_changed_mask, tail_u8x64);
        min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
        max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
    }

    nk_u8_t min_value = nk_reduce_min_u8x64_skylake_(min_u8x64);
    nk_u8_t max_value = nk_reduce_max_u8x64_skylake_(max_u8x64);
    unsigned int min_lane, max_lane;
    {
        __mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(min_u8x64, _mm512_set1_epi8((char)min_value));
        __m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
                                                            min_loop_cycle_u8x64);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
        __mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
                                                            _mm512_set1_epi8((char)earliest_loop_cycle));
        min_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
    }
    {
        __mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(max_u8x64, _mm512_set1_epi8((char)max_value));
        __m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
                                                            max_loop_cycle_u8x64);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
        __mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
                                                            _mm512_set1_epi8((char)earliest_loop_cycle));
        max_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u8x64;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 64 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u8x64;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 64 + max_lane;
}

NK_PUBLIC void nk_reduce_minmax_u8_skylake(                           //
    nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
    int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_U8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_u8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                   max_index_ptr);
    else if (count > (nk_size_t)(NK_U8_MAX + 1) * 64) {
        nk_size_t left_count = count / 2;
        nk_u8_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_u8_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                    &left_max_index);
        nk_reduce_minmax_u8_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                    &right_min, &right_min_index, &right_max, &right_max_index);
        if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_u8_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                max_index_ptr);
    else
        nk_reduce_minmax_u8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                   max_index_ptr);
}

NK_INTERNAL void nk_reduce_moments_i16_skylake_contiguous_( //
    nk_i16_t const *data_ptr, nk_size_t count,              //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum: VPMADDWD(data_ptr, ones) → i32 pairs, accumulate i32, single flush at end.
    // Within 65536-element block (2048 iters), max i32 = ±65536 * 2048 ≈ ±134M — safe.
    // Sumsq: VPMADDWD(data_ptr, data_ptr) → i32, each up to ~2.1B — must flush to i64 every iteration.
    __m512i ones_i16x32 = _mm512_set1_epi16(1);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_i64x8 = _mm512_setzero_si512();
    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m512i data_i16x32 = _mm512_loadu_si512(data_ptr + idx);
        sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(data_i16x32, ones_i16x32));
        __m512i sq_i32x16 = _mm512_madd_epi16(data_i16x32, data_i16x32);
        sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
        sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
    }
    __m512i sum_i64x8 = _mm512_add_epi64(                                  //
        _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sum_i32x16)),         //
        _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sum_i32x16, 1)));  //
    nk_i64_t sum = nk_reduce_add_i64x8_skylake_(sum_i64x8);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
    for (; idx < count; ++idx) {
        nk_i64_t val = (nk_i64_t)data_ptr[idx];
        sum += val, sumsq += (nk_u64_t)(val * val);
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
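
// The asymmetry in the i16 kernel above — sums flushed to i64 once, squares
// flushed every iteration — follows from a quick budget. A compile-time
// restatement with our own constants; the 2048-iteration figure is the
// kernel comment's 65536-element block:

#include <stdint.h>

enum { example_iters_per_block = 65536 / 32 }; // 2048 iterations of 32 lanes
// Sum lane: |a + b| <= 2 * 32768 = 65536 per iteration; over a whole block:
_Static_assert((int64_t)65536 * example_iters_per_block == 134217728,
               "~134M per lane, comfortably below INT32_MAX");
// A squares lane can already reach ~2 * 32767^2 ~= 2.147e9 in one VPMADDWD,
// leaving no headroom to accumulate a second iteration in i32 — hence the
// per-iteration widening to i64 above.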

NK_INTERNAL void nk_reduce_moments_i16_skylake_strided_(                  //
    nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __mmask32 stride_mask_m32 = nk_stride_mask_b16x32_(stride_elements);
    __m512i ones_i16x32 = _mm512_set1_epi16(1);
    __m512i sum_i32x16 = _mm512_setzero_si512();
    __m512i sumsq_i64x8 = _mm512_setzero_si512();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m32) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m512i data_i16x32 = _mm512_maskz_loadu_epi16(stride_mask_m32, data_ptr + idx_scalars);
        sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(data_i16x32, ones_i16x32));
        __m512i sq_i32x16 = _mm512_madd_epi16(data_i16x32, data_i16x32);
        sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sq_i32x16)));
        sumsq_i64x8 = _mm512_add_epi64(sumsq_i64x8, _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sq_i32x16, 1)));
    }
    __m512i sum_i64x8 = _mm512_add_epi64(                                  //
        _mm512_cvtepi32_epi64(_mm512_castsi512_si256(sum_i32x16)),         //
        _mm512_cvtepi32_epi64(_mm512_extracti64x4_epi64(sum_i32x16, 1)));  //
    nk_i64_t sum = nk_reduce_add_i64x8_skylake_(sum_i64x8);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_i64x8);
    nk_i16_t const *ptr = data_ptr + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_i64_t val = (nk_i64_t)*ptr;
        sum += val, sumsq += (nk_u64_t)(val * val);
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_i16_skylake(                          //
    nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
    int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_I16_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_i64_t left_sum, right_sum;
        nk_u64_t left_sumsq, right_sumsq;
        nk_reduce_moments_i16_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_i16_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                      &right_sum, &right_sumsq);
        *sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
        *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
    }
    else if (stride_elements == 1) nk_reduce_moments_i16_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 16)
        nk_reduce_moments_i16_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_minmax_i16_skylake_contiguous_( //
    nk_i16_t const *data_ptr, nk_size_t count,             //
    nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i min_i16x32 = _mm512_set1_epi16((short)NK_I16_MAX);
    __m512i max_i16x32 = _mm512_set1_epi16(NK_I16_MIN);
    __m512i min_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i one_u16x32 = _mm512_set1_epi16(1);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m512i data_i16x32 = _mm512_loadu_si512(data_ptr + idx);
        __mmask32 min_changed_mask = _mm512_cmp_epi16_mask(data_i16x32, min_i16x32, _MM_CMPINT_LT);
        __mmask32 max_changed_mask = _mm512_cmp_epi16_mask(data_i16x32, max_i16x32, _MM_CMPINT_NLE);
        min_i16x32 = _mm512_mask_mov_epi16(min_i16x32, min_changed_mask, data_i16x32);
        max_i16x32 = _mm512_mask_mov_epi16(max_i16x32, max_changed_mask, data_i16x32);
        min_loop_cycle_u16x32 = _mm512_mask_mov_epi16(min_loop_cycle_u16x32, min_changed_mask,
                                                      current_loop_cycle_u16x32);
        max_loop_cycle_u16x32 = _mm512_mask_mov_epi16(max_loop_cycle_u16x32, max_changed_mask,
                                                      current_loop_cycle_u16x32);
        current_loop_cycle_u16x32 = _mm512_add_epi16(current_loop_cycle_u16x32, one_u16x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask32 tail_load = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
        __m512i tail_i16x32 = _mm512_maskz_loadu_epi16(tail_load, data_ptr + idx);
        __mmask32 min_changed_mask = _mm512_mask_cmp_epi16_mask(tail_load, tail_i16x32, min_i16x32, _MM_CMPINT_LT);
        __mmask32 max_changed_mask = _mm512_mask_cmp_epi16_mask(tail_load, tail_i16x32, max_i16x32, _MM_CMPINT_NLE);
        min_i16x32 = _mm512_mask_mov_epi16(min_i16x32, min_changed_mask, tail_i16x32);
        max_i16x32 = _mm512_mask_mov_epi16(max_i16x32, max_changed_mask, tail_i16x32);
        min_loop_cycle_u16x32 = _mm512_mask_mov_epi16(min_loop_cycle_u16x32, min_changed_mask,
                                                      current_loop_cycle_u16x32);
        max_loop_cycle_u16x32 = _mm512_mask_mov_epi16(max_loop_cycle_u16x32, max_changed_mask,
                                                      current_loop_cycle_u16x32);
    }

    nk_i16_t min_value = nk_reduce_min_i16x32_skylake_(min_i16x32);
    nk_i16_t max_value = nk_reduce_max_i16x32_skylake_(max_i16x32);
    unsigned int min_lane, max_lane;
    {
        __mmask32 value_match_mask = _mm512_cmpeq_epi16_mask(min_i16x32, _mm512_set1_epi16(min_value));
        __m512i masked_cycle_u16x32 = _mm512_mask_blend_epi16(value_match_mask, _mm512_set1_epi16((short)NK_U16_MAX),
                                                              min_loop_cycle_u16x32);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x32_skylake_(masked_cycle_u16x32);
        __mmask32 cycle_match_mask = _mm512_cmpeq_epi16_mask(masked_cycle_u16x32,
                                                             _mm512_set1_epi16((short)earliest_loop_cycle));
        min_lane = _tzcnt_u32(cycle_match_mask);
    }
    {
        __mmask32 value_match_mask = _mm512_cmpeq_epi16_mask(max_i16x32, _mm512_set1_epi16(max_value));
        __m512i masked_cycle_u16x32 = _mm512_mask_blend_epi16(value_match_mask, _mm512_set1_epi16((short)NK_U16_MAX),
                                                              max_loop_cycle_u16x32);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x32_skylake_(masked_cycle_u16x32);
        __mmask32 cycle_match_mask = _mm512_cmpeq_epi16_mask(masked_cycle_u16x32,
                                                             _mm512_set1_epi16((short)earliest_loop_cycle));
        max_lane = _tzcnt_u32(cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u16x32;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u16s[min_lane] * 32 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u16x32;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u16s[max_lane] * 32 + max_lane;
}

NK_PUBLIC void nk_reduce_minmax_i16_skylake(                           //
    nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
    int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_I16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_i16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_i16_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_i16_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                     &left_max_index);
        nk_reduce_minmax_i16_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                     &right_min, &right_min_index, &right_max, &right_max_index);
        if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_i16_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_i16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}

NK_INTERNAL void nk_reduce_moments_u16_skylake_contiguous_( //
    nk_u16_t const *data_ptr, nk_size_t count,              //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Widen u16→u32, square in u32, widen to u64. Avoids bias trick whose
    // VPMADDWD pair-of-squares overflows i32 when both lanes map to -32768.
    __m512i zero = _mm512_setzero_si512();
    __m512i sum_u32x16 = _mm512_setzero_si512();
    __m512i sumsq_u64x8 = _mm512_setzero_si512();
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512i data_u32x16 = _mm512_cvtepu16_epi32(_mm256_loadu_si256((__m256i const *)(data_ptr + idx)));
        sum_u32x16 = _mm512_add_epi32(sum_u32x16, data_u32x16);
        __m512i sq_u32x16 = _mm512_mullo_epi32(data_u32x16, data_u32x16);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(sq_u32x16, zero));
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(sq_u32x16, zero));
    }
    if (idx < count) {
        __mmask16 tail_mask = (__mmask16)((1u << (count - idx)) - 1);
        __m512i data_u32x16 = _mm512_cvtepu16_epi32(_mm256_maskz_loadu_epi16(tail_mask, data_ptr + idx));
        sum_u32x16 = _mm512_add_epi32(sum_u32x16, data_u32x16);
        __m512i sq_u32x16 = _mm512_mullo_epi32(data_u32x16, data_u32x16);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(sq_u32x16, zero));
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(sq_u32x16, zero));
    }
    __m512i sum_u64x8 = _mm512_add_epi64(          //
        _mm512_unpacklo_epi32(sum_u32x16, zero),   //
        _mm512_unpackhi_epi32(sum_u32x16, zero));  //
    *sum_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sum_u64x8);
    *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_u64x8);
}
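
// The corner case the comment above guards against, stated in scalar form:
// after the u16 -> i16 bias, 32768 maps to -32768, and one VPMADDWD lane
// would need to hold two (-32768)^2 products at once. A tiny demonstration,
// with a hypothetical helper name of our own:

#include <stdint.h>
#include <stdio.h>

void example_vpmaddwd_edge(void) {
    int64_t pair = 2 * (int64_t)(-32768) * (-32768); // what one i32 lane would need to hold
    printf("%lld > %d\n", (long long)pair, INT32_MAX); // 2147483648 > 2147483647
}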
|
|
1542
|
+
|
|
1543
|
+
NK_INTERNAL void nk_reduce_moments_u16_skylake_strided_(                   //
    nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    __mmask32 stride_mask_m32 = nk_stride_mask_b16x32_(stride_elements);
    __m512i zero = _mm512_setzero_si512();
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_u64x8 = _mm512_setzero_si512();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m32) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m512i data_u16x32 = _mm512_maskz_loadu_epi16(stride_mask_m32, data_ptr + idx_scalars);
        __m512i low_u32x16 = _mm512_unpacklo_epi16(data_u16x32, zero);
        __m512i high_u32x16 = _mm512_unpackhi_epi16(data_u16x32, zero);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpacklo_epi32(low_u32x16, zero));
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpackhi_epi32(low_u32x16, zero));
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpacklo_epi32(high_u32x16, zero));
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpackhi_epi32(high_u32x16, zero));
        __m512i low_sq = _mm512_mullo_epi32(low_u32x16, low_u32x16);
        __m512i high_sq = _mm512_mullo_epi32(high_u32x16, high_u32x16);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(low_sq, zero));
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(low_sq, zero));
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(high_sq, zero));
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(high_sq, zero));
    }
    nk_u64_t sum = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sum_u64x8);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_u64x8);
    nk_u16_t const *ptr = data_ptr + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_u64_t val = (nk_u64_t)*ptr;
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_u16_skylake(                          //
    nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
    int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_u16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
        nk_size_t left_count = count / 2;
        nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_u16_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_u16_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                      &right_sum, &right_sumsq);
        *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
        *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
    }
    else if (stride_elements == 1) nk_reduce_moments_u16_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 16)
        nk_reduce_moments_u16_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_u16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

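// --- Editor's illustrative sketch; not part of the numkong sources. ---
// Assuming this translation unit and the nk_* types are visible, a caller can
// derive mean and variance from the two moments. `stride_bytes` is the byte
// distance between consecutive elements, so sizeof(nk_u16_t) means "dense".
static void sketch_mean_variance_u16(nk_u16_t const *data, nk_size_t count) {
    if (count == 0) return;
    nk_u64_t sum, sumsq;
    nk_reduce_moments_u16_skylake(data, count, sizeof(nk_u16_t), &sum, &sumsq);
    double mean = (double)sum / (double)count;
    double variance = (double)sumsq / (double)count - mean * mean; // E[x^2] - E[x]^2
    (void)variance;
}
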
NK_INTERNAL void nk_reduce_minmax_u16_skylake_contiguous_( //
    nk_u16_t const *data_ptr, nk_size_t count,             //
    nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i min_u16x32 = _mm512_set1_epi16((short)NK_U16_MAX);
    __m512i max_u16x32 = _mm512_setzero_si512();
    __m512i min_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i one_u16x32 = _mm512_set1_epi16(1);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m512i data_u16x32 = _mm512_loadu_si512(data_ptr + idx);
        __mmask32 min_changed_mask = _mm512_cmp_epu16_mask(data_u16x32, min_u16x32, _MM_CMPINT_LT);
        __mmask32 max_changed_mask = _mm512_cmp_epu16_mask(data_u16x32, max_u16x32, _MM_CMPINT_NLE);
        min_u16x32 = _mm512_mask_mov_epi16(min_u16x32, min_changed_mask, data_u16x32);
        max_u16x32 = _mm512_mask_mov_epi16(max_u16x32, max_changed_mask, data_u16x32);
        min_loop_cycle_u16x32 = _mm512_mask_mov_epi16(min_loop_cycle_u16x32, min_changed_mask,
                                                      current_loop_cycle_u16x32);
        max_loop_cycle_u16x32 = _mm512_mask_mov_epi16(max_loop_cycle_u16x32, max_changed_mask,
                                                      current_loop_cycle_u16x32);
        current_loop_cycle_u16x32 = _mm512_add_epi16(current_loop_cycle_u16x32, one_u16x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask32 tail_load = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
        __m512i tail_u16x32 = _mm512_maskz_loadu_epi16(tail_load, data_ptr + idx);
        __mmask32 min_changed_mask = _mm512_mask_cmp_epu16_mask(tail_load, tail_u16x32, min_u16x32, _MM_CMPINT_LT);
        __mmask32 max_changed_mask = _mm512_mask_cmp_epu16_mask(tail_load, tail_u16x32, max_u16x32, _MM_CMPINT_NLE);
        min_u16x32 = _mm512_mask_mov_epi16(min_u16x32, min_changed_mask, tail_u16x32);
        max_u16x32 = _mm512_mask_mov_epi16(max_u16x32, max_changed_mask, tail_u16x32);
        min_loop_cycle_u16x32 = _mm512_mask_mov_epi16(min_loop_cycle_u16x32, min_changed_mask,
                                                      current_loop_cycle_u16x32);
        max_loop_cycle_u16x32 = _mm512_mask_mov_epi16(max_loop_cycle_u16x32, max_changed_mask,
                                                      current_loop_cycle_u16x32);
    }

    nk_u16_t min_value = nk_reduce_min_u16x32_skylake_(min_u16x32);
    nk_u16_t max_value = nk_reduce_max_u16x32_skylake_(max_u16x32);
    unsigned int min_lane, max_lane;
    {
        __mmask32 value_match_mask = _mm512_cmpeq_epi16_mask(min_u16x32, _mm512_set1_epi16((short)min_value));
        __m512i masked_cycle_u16x32 = _mm512_mask_blend_epi16(value_match_mask, _mm512_set1_epi16((short)NK_U16_MAX),
                                                              min_loop_cycle_u16x32);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x32_skylake_(masked_cycle_u16x32);
        __mmask32 cycle_match_mask = _mm512_cmpeq_epi16_mask(masked_cycle_u16x32,
                                                             _mm512_set1_epi16((short)earliest_loop_cycle));
        min_lane = _tzcnt_u32(cycle_match_mask);
    }
    {
        __mmask32 value_match_mask = _mm512_cmpeq_epi16_mask(max_u16x32, _mm512_set1_epi16((short)max_value));
        __m512i masked_cycle_u16x32 = _mm512_mask_blend_epi16(value_match_mask, _mm512_set1_epi16((short)NK_U16_MAX),
                                                              max_loop_cycle_u16x32);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x32_skylake_(masked_cycle_u16x32);
        __mmask32 cycle_match_mask = _mm512_cmpeq_epi16_mask(masked_cycle_u16x32,
                                                             _mm512_set1_epi16((short)earliest_loop_cycle));
        max_lane = _tzcnt_u32(cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u16x32;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u16s[min_lane] * 32 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u16x32;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u16s[max_lane] * 32 + max_lane;
}

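// --- Editor's illustrative sketch; not part of the numkong sources. ---
// The index-tracking trick above, restated for one scalar "lane": instead of
// storing a full index per update, the kernel stamps the loop-cycle counter
// whenever a lane's extremum improves and rebuilds the global index at the
// end as cycle * lanes + lane. Because the stamps are u16 they wrap after
// 65536 cycles, which is why the public dispatcher below splits inputs
// longer than (NK_U16_MAX + 1) * 32. Assumes count > 0.
#include <stdint.h>
#include <stddef.h>
static size_t sketch_min_index_via_cycles(const uint16_t *data, size_t count) {
    size_t lane = 0, lanes = 32; // model of lane 0 in the 32-lane kernel
    uint16_t best = UINT16_MAX;
    uint16_t best_cycle = 0, cycle = 0;
    for (size_t i = lane; i < count; i += lanes, ++cycle)
        if (data[i] < best) best = data[i], best_cycle = cycle;
    return (size_t)best_cycle * lanes + lane; // matches the kernel's final math
}
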
NK_PUBLIC void nk_reduce_minmax_u16_skylake(                           //
    nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
    int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_U16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_u16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_u16_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_u16_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                     &left_max_index);
        nk_reduce_minmax_u16_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                     &right_min, &right_min_index, &right_max, &right_max_index);
        if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_u16_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_u16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}

/** @brief Unsigned saturating add of two u64x8 vectors (3 uops). */
NK_INTERNAL __m512i nk_u64_sadd_epi64_skylake_(__m512i a, __m512i b) {
    __m512i result = _mm512_add_epi64(a, b);
    __mmask8 ovf = _mm512_cmp_epu64_mask(result, a, _MM_CMPINT_LT);
    return _mm512_mask_mov_epi64(result, ovf, _mm512_set1_epi64((nk_i64_t)-1));
}

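// --- Editor's illustrative sketch; not part of the numkong sources. ---
// The same saturating-add idea in scalar form: after a wrapping add, the
// result is smaller than an operand exactly when the add carried out, so the
// overflow test needs no wider type.
#include <stdint.h>
static uint64_t sketch_u64_sadd(uint64_t a, uint64_t b) {
    uint64_t r = a + b;            // wraps modulo 2^64
    return r < a ? UINT64_MAX : r; // carry detected: clamp, like the masked move above
}
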
/** @brief Saturating i64 square: clamp when |val| > floor(sqrt(INT64_MAX)). */
NK_INTERNAL __m512i nk_i64_smul_sq_epi64_skylake_(__m512i val) {
    __m512i sq = _mm512_mullo_epi64(val, val);
    __m512i abs_val = _mm512_abs_epi64(val);
    __mmask8 ovf = _mm512_cmp_epu64_mask(abs_val, _mm512_set1_epi64(3037000499ll), _MM_CMPINT_NLE);
    return _mm512_mask_mov_epi64(sq, ovf, _mm512_set1_epi64(9223372036854775807ll));
}

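// --- Editor's illustrative sketch; not part of the numkong sources. ---
// The scalar shape of the clamp above. 3037000499 is floor(sqrt(INT64_MAX)),
// so any |val| beyond it would overflow val * val in 64 bits; within it,
// the square is at most 3037000499^2 < INT64_MAX and is exact.
#include <stdint.h>
static int64_t sketch_i64_sat_square(int64_t val) {
    uint64_t a = val < 0 ? (uint64_t)0 - (uint64_t)val : (uint64_t)val; // |val| without UB
    return a > 3037000499ull ? INT64_MAX : val * val;
}
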
/** @brief Saturating u64 square: clamp when val > floor(sqrt(UINT64_MAX)). */
NK_INTERNAL __m512i nk_u64_smul_sq_epi64_skylake_(__m512i val) {
    __m512i sq = _mm512_mullo_epi64(val, val);
    __mmask8 ovf = _mm512_cmp_epu64_mask(val, _mm512_set1_epi64(4294967295ll), _MM_CMPINT_NLE);
    return _mm512_mask_mov_epi64(sq, ovf, _mm512_set1_epi64((nk_i64_t)-1));
}

/** @brief Saturating horizontal sum of 8 u64 lanes.
 *  Tree reduction: unsigned saturating add is order-independent because the
 *  accumulator can only increase; once saturated to UINT64_MAX, it stays there.
 *  The result equals min(true_sum, UINT64_MAX) regardless of reduction order. */
NK_INTERNAL nk_u64_t nk_reduce_sadd_u64x8_skylake_(__m512i v) {
    // 8→4: fold high 256 bits into low 256 bits (VSHUFI64X2 + 3-uop sat-add)
    v = nk_u64_sadd_epi64_skylake_(v, _mm512_shuffle_i64x2(v, v, _MM_SHUFFLE(1, 0, 3, 2)));
    // 4→2: fold lanes 2-3 into lanes 0-1
    v = nk_u64_sadd_epi64_skylake_(v, _mm512_shuffle_i64x2(v, v, _MM_SHUFFLE(2, 3, 0, 1)));
    // 2→1: fold lane 1 into lane 0 (VALIGNQ + 3-uop sat-add)
    v = nk_u64_sadd_epi64_skylake_(v, _mm512_alignr_epi64(v, v, 1));
    return (nk_u64_t)_mm_cvtsi128_si64(_mm512_castsi512_si128(v));
}

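// --- Editor's illustrative sketch; not part of the numkong sources. ---
// Why the fold order is free to vary: with the clamp from sketch_u64_sadd
// above, sat(sat(a + b) + c) == sat(a + sat(b + c)) == min(a + b + c, 2^64 - 1),
// because a saturated intermediate can never be pulled back down.
#include <assert.h>
static void sketch_sadd_is_order_independent(void) {
    uint64_t a = UINT64_MAX - 5, b = 10, c = 3;
    assert(sketch_u64_sadd(sketch_u64_sadd(a, b), c) ==
           sketch_u64_sadd(a, sketch_u64_sadd(b, c))); // both are UINT64_MAX
}
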
NK_INTERNAL void nk_reduce_moments_i32_skylake_contiguous_( //
    nk_i32_t const *data_ptr, nk_size_t count,              //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum: 128-bit accumulation (lower + upper words), so no block cap is needed.
    // Sumsq: unsigned wrapping accumulation with carry-based overflow detection.
    __m512i sum_lower_i64x8 = _mm512_setzero_si512();
    __m512i sum_upper_i64x8 = _mm512_setzero_si512();
    __m512i sumsq_u64x8 = _mm512_setzero_si512();
    __mmask8 sumsq_overflow_mask = 0;
    __m512i one_i64x8 = _mm512_set1_epi64(1);
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512i data_i32x16 = _mm512_loadu_si512(data_ptr + idx);
        __m256i low_i32x8 = _mm512_castsi512_si256(data_i32x16);
        __m256i high_i32x8 = _mm512_extracti64x4_epi64(data_i32x16, 1);
        // 128-bit sum: lower half of the vector
        __m512i widened_low_i64x8 = _mm512_cvtepi32_epi64(low_i32x8);
        __m512i sum_before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, widened_low_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, _mm512_srai_epi64(widened_low_i64x8, 63));
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
        // 128-bit sum: upper half of the vector
        __m512i widened_high_i64x8 = _mm512_cvtepi32_epi64(high_i32x8);
        sum_before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, widened_high_i64x8);
        carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, _mm512_srai_epi64(widened_high_i64x8, 63));
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
        // Sumsq: unsigned accumulation with carry detection
        __m512i even_sq_u64x8 = _mm512_mul_epi32(data_i32x16, data_i32x16);
        __m512i odd_i32x16 = _mm512_srli_epi64(data_i32x16, 32);
        __m512i odd_sq_u64x8 = _mm512_mul_epi32(odd_i32x16, odd_i32x16);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, even_sq_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, even_sq_u64x8, _MM_CMPINT_LT);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, odd_sq_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, odd_sq_u64x8, _MM_CMPINT_LT);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining);
        __m512i data_i32x16 = _mm512_maskz_loadu_epi32(tail_mask, data_ptr + idx);
        __m256i low_i32x8 = _mm512_castsi512_si256(data_i32x16);
        __m256i high_i32x8 = _mm512_extracti64x4_epi64(data_i32x16, 1);
        __m512i widened_low_i64x8 = _mm512_cvtepi32_epi64(low_i32x8);
        __m512i sum_before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, widened_low_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, _mm512_srai_epi64(widened_low_i64x8, 63));
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
        if (remaining > 8) {
            __m512i widened_high_i64x8 = _mm512_cvtepi32_epi64(high_i32x8);
            sum_before_i64x8 = sum_lower_i64x8;
            sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, widened_high_i64x8);
            carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
            sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, _mm512_srai_epi64(widened_high_i64x8, 63));
            sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
        }
        __m512i even_sq_u64x8 = _mm512_mul_epi32(data_i32x16, data_i32x16);
        __m512i odd_i32x16 = _mm512_srli_epi64(data_i32x16, 32);
        __m512i odd_sq_u64x8 = _mm512_mul_epi32(odd_i32x16, odd_i32x16);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, even_sq_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, even_sq_u64x8, _MM_CMPINT_LT);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, odd_sq_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, odd_sq_u64x8, _MM_CMPINT_LT);
    }
    // Sumsq: horizontal unsigned saturating reduction
    nk_u64_t sumsq;
    if (sumsq_overflow_mask) sumsq = NK_U64_MAX;
    else sumsq = nk_reduce_sadd_u64x8_skylake_(sumsq_u64x8);
    // Sum: horizontal 128-bit tree reduction, same as the i64 Skylake kernel
    { // 8→4
        __m512i fold_lower_i64x8 = _mm512_shuffle_i64x2(sum_lower_i64x8, sum_lower_i64x8, _MM_SHUFFLE(1, 0, 3, 2));
        __m512i fold_upper_i64x8 = _mm512_shuffle_i64x2(sum_upper_i64x8, sum_upper_i64x8, _MM_SHUFFLE(1, 0, 3, 2));
        __m512i before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, fold_lower_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, fold_upper_i64x8);
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
    }
    { // 4→2
        __m512i fold_lower_i64x8 = _mm512_shuffle_i64x2(sum_lower_i64x8, sum_lower_i64x8, _MM_SHUFFLE(2, 3, 0, 1));
        __m512i fold_upper_i64x8 = _mm512_shuffle_i64x2(sum_upper_i64x8, sum_upper_i64x8, _MM_SHUFFLE(2, 3, 0, 1));
        __m512i before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, fold_lower_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, fold_upper_i64x8);
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
    }
    { // 2→1
        __m512i fold_lower_i64x8 = _mm512_alignr_epi64(sum_lower_i64x8, sum_lower_i64x8, 1);
        __m512i fold_upper_i64x8 = _mm512_alignr_epi64(sum_upper_i64x8, sum_upper_i64x8, 1);
        __m512i before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, fold_lower_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, fold_upper_i64x8);
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
    }
    nk_i64_t sum_lower = _mm_cvtsi128_si64(_mm512_castsi512_si128(sum_lower_i64x8));
    nk_i64_t sum_upper = _mm_cvtsi128_si64(_mm512_castsi512_si128(sum_upper_i64x8));
    if (sum_upper == (sum_lower >> 63)) *sum_ptr = sum_lower;
    else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
    else *sum_ptr = NK_I64_MIN;
    *sumsq_ptr = sumsq;
}

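// --- Editor's illustrative sketch; not part of the numkong sources. ---
// The per-lane 128-bit accumulator above, in scalar form. The low word adds
// with wraparound and the carry-out is the unsigned "result < before" test;
// the high word absorbs the addend's sign extension plus that carry. The
// `>> 63` assumes an arithmetic right shift, matching VPSRAQ. Clamping back
// to i64 afterwards mirrors the `sum_upper == (sum_lower >> 63)` check.
#include <stdint.h>
typedef struct { uint64_t lo; int64_t hi; } sketch_i128_t;
static void sketch_i128_add(sketch_i128_t *acc, int64_t addend) {
    uint64_t before = acc->lo;
    acc->lo += (uint64_t)addend;        // wrapping low-word add
    acc->hi += addend >> 63;            // sign extension of the addend: 0 or -1
    if (acc->lo < before) acc->hi += 1; // carry out of the low word
}
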
NK_PUBLIC void nk_reduce_moments_i32_skylake(                          //
    nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
    int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_i32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (stride_elements == 1) nk_reduce_moments_i32_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_i32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_minmax_i32_skylake_contiguous_( //
    nk_i32_t const *data_ptr, nk_size_t count,             //
    nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i min_i32x16 = _mm512_set1_epi32(NK_I32_MAX);
    __m512i max_i32x16 = _mm512_set1_epi32(NK_I32_MIN);
    __m512i min_loop_cycle_u32x16 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u32x16 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u32x16 = _mm512_setzero_si512();
    __m512i one_u32x16 = _mm512_set1_epi32(1);

    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512i data_i32x16 = _mm512_loadu_si512(data_ptr + idx);
        __mmask16 min_changed_mask = _mm512_cmp_epi32_mask(data_i32x16, min_i32x16, _MM_CMPINT_LT);
        __mmask16 max_changed_mask = _mm512_cmp_epi32_mask(data_i32x16, max_i32x16, _MM_CMPINT_NLE);
        min_i32x16 = _mm512_mask_mov_epi32(min_i32x16, min_changed_mask, data_i32x16);
        max_i32x16 = _mm512_mask_mov_epi32(max_i32x16, max_changed_mask, data_i32x16);
        min_loop_cycle_u32x16 = _mm512_mask_mov_epi32(min_loop_cycle_u32x16, min_changed_mask,
                                                      current_loop_cycle_u32x16);
        max_loop_cycle_u32x16 = _mm512_mask_mov_epi32(max_loop_cycle_u32x16, max_changed_mask,
                                                      current_loop_cycle_u32x16);
        current_loop_cycle_u32x16 = _mm512_add_epi32(current_loop_cycle_u32x16, one_u32x16);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask16 tail_load = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining);
        __m512i tail_i32x16 = _mm512_maskz_loadu_epi32(tail_load, data_ptr + idx);
        __mmask16 min_changed_mask = _mm512_mask_cmp_epi32_mask(tail_load, tail_i32x16, min_i32x16, _MM_CMPINT_LT);
        __mmask16 max_changed_mask = _mm512_mask_cmp_epi32_mask(tail_load, tail_i32x16, max_i32x16, _MM_CMPINT_NLE);
        min_i32x16 = _mm512_mask_mov_epi32(min_i32x16, min_changed_mask, tail_i32x16);
        max_i32x16 = _mm512_mask_mov_epi32(max_i32x16, max_changed_mask, tail_i32x16);
        min_loop_cycle_u32x16 = _mm512_mask_mov_epi32(min_loop_cycle_u32x16, min_changed_mask,
                                                      current_loop_cycle_u32x16);
        max_loop_cycle_u32x16 = _mm512_mask_mov_epi32(max_loop_cycle_u32x16, max_changed_mask,
                                                      current_loop_cycle_u32x16);
    }

    nk_i32_t min_value = nk_reduce_min_i32x16_skylake_(min_i32x16);
    nk_i32_t max_value = nk_reduce_max_i32x16_skylake_(max_i32x16);
    unsigned int min_lane, max_lane;
    {
        __mmask16 value_match_mask = _mm512_cmpeq_epi32_mask(min_i32x16, _mm512_set1_epi32(min_value));
        __m512i masked_cycle_u32x16 = _mm512_mask_blend_epi32(value_match_mask, _mm512_set1_epi32((int)NK_U32_MAX),
                                                              min_loop_cycle_u32x16);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x16_skylake_(masked_cycle_u32x16);
        __mmask16 cycle_match_mask = _mm512_cmpeq_epi32_mask(masked_cycle_u32x16,
                                                             _mm512_set1_epi32((int)earliest_loop_cycle));
        min_lane = _tzcnt_u32(cycle_match_mask);
    }
    {
        __mmask16 value_match_mask = _mm512_cmpeq_epi32_mask(max_i32x16, _mm512_set1_epi32(max_value));
        __m512i masked_cycle_u32x16 = _mm512_mask_blend_epi32(value_match_mask, _mm512_set1_epi32((int)NK_U32_MAX),
                                                              max_loop_cycle_u32x16);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x16_skylake_(masked_cycle_u32x16);
        __mmask16 cycle_match_mask = _mm512_cmpeq_epi32_mask(masked_cycle_u32x16,
                                                             _mm512_set1_epi32((int)earliest_loop_cycle));
        max_lane = _tzcnt_u32(cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u32x16;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u32s[min_lane] * 16 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u32x16;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u32s[max_lane] * 16 + max_lane;
}

NK_PUBLIC void nk_reduce_minmax_i32_skylake(                           //
    nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
    int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_I32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I32_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_i32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (count > (nk_size_t)NK_U32_MAX * 16) {
        nk_size_t left_count = count / 2;
        nk_i32_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_i32_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                     &left_max_index);
        nk_reduce_minmax_i32_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                     &right_min, &right_min_index, &right_max, &right_max_index);
        if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_i32_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_i32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}

NK_INTERNAL void nk_reduce_moments_u32_skylake_contiguous_( //
    nk_u32_t const *data_ptr, nk_size_t count,              //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum: widen u32→u64, accumulate. Sumsq: VPMULUDQ for even/odd lanes (5-cycle, 1 uop each).
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_u64x8 = _mm512_setzero_si512();
    __mmask8 sumsq_overflow_mask = 0;
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512i data_u32x16 = _mm512_loadu_si512(data_ptr + idx);
        __m256i low_u32x8 = _mm512_castsi512_si256(data_u32x16);
        __m256i high_u32x8 = _mm512_extracti64x4_epi64(data_u32x16, 1);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_cvtepu32_epi64(low_u32x8));
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_cvtepu32_epi64(high_u32x8));
        __m512i even_sq_u64x8 = _mm512_mul_epu32(data_u32x16, data_u32x16);
        __m512i odd_u32x16 = _mm512_srli_epi64(data_u32x16, 32);
        __m512i odd_sq_u64x8 = _mm512_mul_epu32(odd_u32x16, odd_u32x16);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, even_sq_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, even_sq_u64x8, _MM_CMPINT_LT);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, odd_sq_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, odd_sq_u64x8, _MM_CMPINT_LT);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask16 tail_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining);
        __m512i data_u32x16 = _mm512_maskz_loadu_epi32(tail_mask, data_ptr + idx);
        __m256i low_u32x8 = _mm512_castsi512_si256(data_u32x16);
        __m256i high_u32x8 = _mm512_extracti64x4_epi64(data_u32x16, 1);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_cvtepu32_epi64(low_u32x8));
        if (remaining > 8) sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_cvtepu32_epi64(high_u32x8));
        __m512i even_sq_u64x8 = _mm512_mul_epu32(data_u32x16, data_u32x16);
        __m512i odd_u32x16 = _mm512_srli_epi64(data_u32x16, 32);
        __m512i odd_sq_u64x8 = _mm512_mul_epu32(odd_u32x16, odd_u32x16);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, even_sq_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, even_sq_u64x8, _MM_CMPINT_LT);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, odd_sq_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, odd_sq_u64x8, _MM_CMPINT_LT);
    }
    nk_u64_t sum = nk_reduce_add_u64x8_skylake_(sum_u64x8);
    nk_u64_t sumsq;
    if (sumsq_overflow_mask) sumsq = NK_U64_MAX;
    else sumsq = nk_reduce_sadd_u64x8_skylake_(sumsq_u64x8);
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_u32_skylake(                          //
    nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
    int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_u32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
        nk_size_t left_count = count / 2;
        nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_u32_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_u32_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                      &right_sum, &right_sumsq);
        *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
        *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
    }
    else if (stride_elements == 1) nk_reduce_moments_u32_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_u32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

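// --- Editor's illustrative sketch; not part of the numkong sources. ---
// The halving recursion above is a range-capping device, not a speed trick:
// each half stays small enough for the unchecked u64 lane sums to remain
// exact, and the partial results are then merged with a saturating add.
// A scalar model of that merge, i.e. the contract the
// nk_u64_saturating_add_serial calls rely on:
#include <stdint.h>
static uint64_t sketch_merge_saturating(uint64_t left, uint64_t right) {
    uint64_t total = left + right;          // wraps modulo 2^64
    return total < left ? UINT64_MAX : total;
}
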
NK_INTERNAL void nk_reduce_minmax_u32_skylake_contiguous_( //
    nk_u32_t const *data_ptr, nk_size_t count,             //
    nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i min_u32x16 = _mm512_set1_epi32((nk_i32_t)NK_U32_MAX);
    __m512i max_u32x16 = _mm512_setzero_si512();
    __m512i min_loop_cycle_u32x16 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u32x16 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u32x16 = _mm512_setzero_si512();
    __m512i one_u32x16 = _mm512_set1_epi32(1);

    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512i data_u32x16 = _mm512_loadu_si512(data_ptr + idx);
        __mmask16 min_changed_mask = _mm512_cmp_epu32_mask(data_u32x16, min_u32x16, _MM_CMPINT_LT);
        __mmask16 max_changed_mask = _mm512_cmp_epu32_mask(data_u32x16, max_u32x16, _MM_CMPINT_NLE);
        min_u32x16 = _mm512_mask_mov_epi32(min_u32x16, min_changed_mask, data_u32x16);
        max_u32x16 = _mm512_mask_mov_epi32(max_u32x16, max_changed_mask, data_u32x16);
        min_loop_cycle_u32x16 = _mm512_mask_mov_epi32(min_loop_cycle_u32x16, min_changed_mask,
                                                      current_loop_cycle_u32x16);
        max_loop_cycle_u32x16 = _mm512_mask_mov_epi32(max_loop_cycle_u32x16, max_changed_mask,
                                                      current_loop_cycle_u32x16);
        current_loop_cycle_u32x16 = _mm512_add_epi32(current_loop_cycle_u32x16, one_u32x16);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask16 tail_load = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining);
        __m512i tail_u32x16 = _mm512_maskz_loadu_epi32(tail_load, data_ptr + idx);
        __mmask16 min_changed_mask = _mm512_mask_cmp_epu32_mask(tail_load, tail_u32x16, min_u32x16, _MM_CMPINT_LT);
        __mmask16 max_changed_mask = _mm512_mask_cmp_epu32_mask(tail_load, tail_u32x16, max_u32x16, _MM_CMPINT_NLE);
        min_u32x16 = _mm512_mask_mov_epi32(min_u32x16, min_changed_mask, tail_u32x16);
        max_u32x16 = _mm512_mask_mov_epi32(max_u32x16, max_changed_mask, tail_u32x16);
        min_loop_cycle_u32x16 = _mm512_mask_mov_epi32(min_loop_cycle_u32x16, min_changed_mask,
                                                      current_loop_cycle_u32x16);
        max_loop_cycle_u32x16 = _mm512_mask_mov_epi32(max_loop_cycle_u32x16, max_changed_mask,
                                                      current_loop_cycle_u32x16);
    }

    nk_u32_t min_value = nk_reduce_min_u32x16_skylake_(min_u32x16);
    nk_u32_t max_value = nk_reduce_max_u32x16_skylake_(max_u32x16);
    unsigned int min_lane, max_lane;
    {
        __mmask16 value_match_mask = _mm512_cmpeq_epi32_mask(min_u32x16, _mm512_set1_epi32((nk_i32_t)min_value));
        __m512i masked_cycle_u32x16 = _mm512_mask_blend_epi32(value_match_mask, _mm512_set1_epi32((int)NK_U32_MAX),
                                                              min_loop_cycle_u32x16);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x16_skylake_(masked_cycle_u32x16);
        __mmask16 cycle_match_mask = _mm512_cmpeq_epi32_mask(masked_cycle_u32x16,
                                                             _mm512_set1_epi32((int)earliest_loop_cycle));
        min_lane = _tzcnt_u32(cycle_match_mask);
    }
    {
        __mmask16 value_match_mask = _mm512_cmpeq_epi32_mask(max_u32x16, _mm512_set1_epi32((nk_i32_t)max_value));
        __m512i masked_cycle_u32x16 = _mm512_mask_blend_epi32(value_match_mask, _mm512_set1_epi32((int)NK_U32_MAX),
                                                              max_loop_cycle_u32x16);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x16_skylake_(masked_cycle_u32x16);
        __mmask16 cycle_match_mask = _mm512_cmpeq_epi32_mask(masked_cycle_u32x16,
                                                             _mm512_set1_epi32((int)earliest_loop_cycle));
        max_lane = _tzcnt_u32(cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u32x16;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u32s[min_lane] * 16 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u32x16;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u32s[max_lane] * 16 + max_lane;
}

NK_PUBLIC void nk_reduce_minmax_u32_skylake(                           //
    nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
    int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_U32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_u32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (count > (nk_size_t)NK_U32_MAX * 16) {
        nk_size_t left_count = count / 2;
        nk_u32_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_u32_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                     &left_max_index);
        nk_reduce_minmax_u32_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                     &right_min, &right_min_index, &right_max, &right_max_index);
        if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_u32_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_u32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}

NK_INTERNAL void nk_reduce_moments_i64_skylake_contiguous_( //
    nk_i64_t const *data_ptr, nk_size_t count,              //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum: double-width 128-bit accumulation per lane.
    // Sumsq: unsigned wrapping accumulation with carry-based overflow detection.
    __m512i sum_lower_i64x8 = _mm512_setzero_si512();
    __m512i sum_upper_i64x8 = _mm512_setzero_si512();
    __m512i sumsq_u64x8 = _mm512_setzero_si512();
    __mmask8 sumsq_overflow_mask = 0;
    __m512i one_i64x8 = _mm512_set1_epi64(1);
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m512i data_i64x8 = _mm512_loadu_si512(data_ptr + idx);
        __m512i squared_i64x8 = nk_i64_smul_sq_epi64_skylake_(data_i64x8);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, squared_i64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, squared_i64x8, _MM_CMPINT_LT);
        __m512i sum_before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, data_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, _mm512_srai_epi64(data_i64x8, 63));
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask8 tail_mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)remaining);
        __m512i data_i64x8 = _mm512_maskz_loadu_epi64(tail_mask, data_ptr + idx);
        __m512i squared_i64x8 = nk_i64_smul_sq_epi64_skylake_(data_i64x8);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, squared_i64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, squared_i64x8, _MM_CMPINT_LT);
        __m512i sum_before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, data_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, _mm512_srai_epi64(data_i64x8, 63));
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
    }
    // Sumsq: horizontal unsigned saturating reduction
    nk_u64_t sumsq;
    if (sumsq_overflow_mask) sumsq = NK_U64_MAX;
    else sumsq = nk_reduce_sadd_u64x8_skylake_(sumsq_u64x8);
    // Sum: horizontal 128-bit tree reduction (8→4→2→1), then clamp to i64
    { // 8→4: fold high 256 bits into low 256 bits
        __m512i fold_lower_i64x8 = _mm512_shuffle_i64x2(sum_lower_i64x8, sum_lower_i64x8, _MM_SHUFFLE(1, 0, 3, 2));
        __m512i fold_upper_i64x8 = _mm512_shuffle_i64x2(sum_upper_i64x8, sum_upper_i64x8, _MM_SHUFFLE(1, 0, 3, 2));
        __m512i before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, fold_lower_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, fold_upper_i64x8);
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
    }
    { // 4→2: fold lanes 2-3 into lanes 0-1
        __m512i fold_lower_i64x8 = _mm512_shuffle_i64x2(sum_lower_i64x8, sum_lower_i64x8, _MM_SHUFFLE(2, 3, 0, 1));
        __m512i fold_upper_i64x8 = _mm512_shuffle_i64x2(sum_upper_i64x8, sum_upper_i64x8, _MM_SHUFFLE(2, 3, 0, 1));
        __m512i before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, fold_lower_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, fold_upper_i64x8);
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
    }
    { // 2→1: fold lane 1 into lane 0
        __m512i fold_lower_i64x8 = _mm512_alignr_epi64(sum_lower_i64x8, sum_lower_i64x8, 1);
        __m512i fold_upper_i64x8 = _mm512_alignr_epi64(sum_upper_i64x8, sum_upper_i64x8, 1);
        __m512i before_i64x8 = sum_lower_i64x8;
        sum_lower_i64x8 = _mm512_add_epi64(sum_lower_i64x8, fold_lower_i64x8);
        __mmask8 carry = _mm512_cmp_epu64_mask(sum_lower_i64x8, before_i64x8, _MM_CMPINT_LT);
        sum_upper_i64x8 = _mm512_add_epi64(sum_upper_i64x8, fold_upper_i64x8);
        sum_upper_i64x8 = _mm512_mask_add_epi64(sum_upper_i64x8, carry, sum_upper_i64x8, one_i64x8);
    }
    // Clamp the 128-bit result to [INT64_MIN, INT64_MAX]: it fits iff upper == sign-extension of lower
    nk_i64_t sum_lower = _mm_cvtsi128_si64(_mm512_castsi512_si128(sum_lower_i64x8));
    nk_i64_t sum_upper = _mm_cvtsi128_si64(_mm512_castsi512_si128(sum_upper_i64x8));
    if (sum_upper == (sum_lower >> 63)) *sum_ptr = sum_lower;
    else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
    else *sum_ptr = NK_I64_MIN;
    *sumsq_ptr = sumsq;
}

NK_PUBLIC void nk_reduce_moments_i64_skylake(                          //
    nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
    int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_i64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (stride_elements == 1) nk_reduce_moments_i64_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_i64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_minmax_i64_skylake_contiguous_( //
    nk_i64_t const *data_ptr, nk_size_t count,             //
    nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i min_i64x8 = _mm512_set1_epi64(NK_I64_MAX);
    __m512i max_i64x8 = _mm512_set1_epi64(NK_I64_MIN);
    __m512i min_loop_cycle_u64x8 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u64x8 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u64x8 = _mm512_setzero_si512();
    __m512i one_u64x8 = _mm512_set1_epi64(1);

    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m512i data_i64x8 = _mm512_loadu_si512(data_ptr + idx);
        __mmask8 min_changed_mask = _mm512_cmp_epi64_mask(data_i64x8, min_i64x8, _MM_CMPINT_LT);
        __mmask8 max_changed_mask = _mm512_cmp_epi64_mask(data_i64x8, max_i64x8, _MM_CMPINT_NLE);
        min_i64x8 = _mm512_mask_mov_epi64(min_i64x8, min_changed_mask, data_i64x8);
        max_i64x8 = _mm512_mask_mov_epi64(max_i64x8, max_changed_mask, data_i64x8);
        min_loop_cycle_u64x8 = _mm512_mask_mov_epi64(min_loop_cycle_u64x8, min_changed_mask, current_loop_cycle_u64x8);
        max_loop_cycle_u64x8 = _mm512_mask_mov_epi64(max_loop_cycle_u64x8, max_changed_mask, current_loop_cycle_u64x8);
        current_loop_cycle_u64x8 = _mm512_add_epi64(current_loop_cycle_u64x8, one_u64x8);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask8 tail_load = (__mmask8)_bzhi_u32(0xFF, (unsigned int)remaining);
        __m512i tail_i64x8 = _mm512_maskz_loadu_epi64(tail_load, data_ptr + idx);
        __mmask8 min_changed_mask = _mm512_mask_cmp_epi64_mask(tail_load, tail_i64x8, min_i64x8, _MM_CMPINT_LT);
        __mmask8 max_changed_mask = _mm512_mask_cmp_epi64_mask(tail_load, tail_i64x8, max_i64x8, _MM_CMPINT_NLE);
        min_i64x8 = _mm512_mask_mov_epi64(min_i64x8, min_changed_mask, tail_i64x8);
        max_i64x8 = _mm512_mask_mov_epi64(max_i64x8, max_changed_mask, tail_i64x8);
        min_loop_cycle_u64x8 = _mm512_mask_mov_epi64(min_loop_cycle_u64x8, min_changed_mask, current_loop_cycle_u64x8);
        max_loop_cycle_u64x8 = _mm512_mask_mov_epi64(max_loop_cycle_u64x8, max_changed_mask, current_loop_cycle_u64x8);
    }

    nk_i64_t min_value = nk_reduce_min_i64x8_skylake_(min_i64x8);
    nk_i64_t max_value = nk_reduce_max_i64x8_skylake_(max_i64x8);
    unsigned int min_lane, max_lane;
    {
        __mmask8 value_match_mask = _mm512_cmpeq_epi64_mask(min_i64x8, _mm512_set1_epi64(min_value));
        __m512i masked_cycle_u64x8 = _mm512_mask_blend_epi64(value_match_mask, _mm512_set1_epi64((nk_i64_t)NK_U64_MAX),
                                                             min_loop_cycle_u64x8);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x8_skylake_(masked_cycle_u64x8);
        __mmask8 cycle_match_mask = _mm512_cmpeq_epi64_mask(masked_cycle_u64x8,
                                                            _mm512_set1_epi64((nk_i64_t)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)cycle_match_mask);
    }
    {
        __mmask8 value_match_mask = _mm512_cmpeq_epi64_mask(max_i64x8, _mm512_set1_epi64(max_value));
        __m512i masked_cycle_u64x8 = _mm512_mask_blend_epi64(value_match_mask, _mm512_set1_epi64((nk_i64_t)NK_U64_MAX),
                                                             max_loop_cycle_u64x8);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x8_skylake_(masked_cycle_u64x8);
        __mmask8 cycle_match_mask = _mm512_cmpeq_epi64_mask(masked_cycle_u64x8,
                                                            _mm512_set1_epi64((nk_i64_t)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u64x8;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u64s[min_lane] * 8 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u64x8;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u64s[max_lane] * 8 + max_lane;
}

NK_PUBLIC void nk_reduce_minmax_i64_skylake(                           //
    nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
    int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_I64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I64_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_i64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (stride_elements == 1)
        nk_reduce_minmax_i64_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_i64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}

NK_INTERNAL void nk_reduce_moments_u64_skylake_contiguous_( //
    nk_u64_t const *data_ptr, nk_size_t count,              //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Unsigned saturating addition is order-independent: sat(sat(a+b)+c) == sat(a+b+c).
    // Once a lane saturates it stays saturated, so a running overflow mask is sufficient
    // for any count; no block cap or 128-bit accumulation is needed.
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_u64x8 = _mm512_setzero_si512();
    __mmask8 sum_overflow_mask = 0, sumsq_overflow_mask = 0;
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m512i data_u64x8 = _mm512_loadu_si512(data_ptr + idx);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, data_u64x8);
        sum_overflow_mask |= _mm512_cmp_epu64_mask(sum_u64x8, data_u64x8, _MM_CMPINT_LT);
        __m512i squared_u64x8 = nk_u64_smul_sq_epi64_skylake_(data_u64x8);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, squared_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, squared_u64x8, _MM_CMPINT_LT);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask8 tail_mask = (__mmask8)_bzhi_u32(0xFF, (unsigned int)remaining);
        __m512i data_u64x8 = _mm512_maskz_loadu_epi64(tail_mask, data_ptr + idx);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, data_u64x8);
        sum_overflow_mask |= _mm512_cmp_epu64_mask(sum_u64x8, data_u64x8, _MM_CMPINT_LT);
        __m512i squared_u64x8 = nk_u64_smul_sq_epi64_skylake_(data_u64x8);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, squared_u64x8);
        sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, squared_u64x8, _MM_CMPINT_LT);
    }
    nk_u64_t sum_scalar;
    if (sum_overflow_mask) sum_scalar = NK_U64_MAX;
    else sum_scalar = nk_reduce_sadd_u64x8_skylake_(sum_u64x8);
    nk_u64_t sumsq_scalar;
    if (sumsq_overflow_mask) sumsq_scalar = NK_U64_MAX;
    else sumsq_scalar = nk_reduce_sadd_u64x8_skylake_(sumsq_u64x8);
    *sum_ptr = sum_scalar, *sumsq_ptr = sumsq_scalar;
}

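// --- Editor's illustrative sketch; not part of the numkong sources. ---
// The running-overflow-mask pattern above, reduced to one scalar lane: keep
// adding with wraparound, remember whether any step carried, and decide only
// at the end. For sequential addition this is exact: the true sum exceeded
// UINT64_MAX precisely when some step carried.
#include <stdint.h>
#include <stddef.h>
static uint64_t sketch_sum_u64_saturating(const uint64_t *data, size_t count) {
    uint64_t acc = 0;
    int overflowed = 0;
    for (size_t i = 0; i < count; ++i) {
        acc += data[i];
        overflowed |= acc < data[i]; // carry-out test, as the _MM_CMPINT_LT mask does
    }
    return overflowed ? UINT64_MAX : acc;
}
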
NK_PUBLIC void nk_reduce_moments_u64_skylake(                          //
    nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u64_t);
    int aligned = (stride_bytes % sizeof(nk_u64_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_u64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (stride_elements == 1) nk_reduce_moments_u64_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_u64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_minmax_u64_skylake_contiguous_( //
    nk_u64_t const *data_ptr, nk_size_t count,             //
    nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i min_u64x8 = _mm512_set1_epi64((nk_i64_t)NK_U64_MAX);
    __m512i max_u64x8 = _mm512_setzero_si512();
    __m512i min_loop_cycle_u64x8 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u64x8 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u64x8 = _mm512_setzero_si512();
    __m512i one_u64x8 = _mm512_set1_epi64(1);

    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m512i data_u64x8 = _mm512_loadu_si512(data_ptr + idx);
        __mmask8 min_changed_mask = _mm512_cmp_epu64_mask(data_u64x8, min_u64x8, _MM_CMPINT_LT);
        __mmask8 max_changed_mask = _mm512_cmp_epu64_mask(data_u64x8, max_u64x8, _MM_CMPINT_NLE);
        min_u64x8 = _mm512_mask_mov_epi64(min_u64x8, min_changed_mask, data_u64x8);
        max_u64x8 = _mm512_mask_mov_epi64(max_u64x8, max_changed_mask, data_u64x8);
        min_loop_cycle_u64x8 = _mm512_mask_mov_epi64(min_loop_cycle_u64x8, min_changed_mask, current_loop_cycle_u64x8);
        max_loop_cycle_u64x8 = _mm512_mask_mov_epi64(max_loop_cycle_u64x8, max_changed_mask, current_loop_cycle_u64x8);
        current_loop_cycle_u64x8 = _mm512_add_epi64(current_loop_cycle_u64x8, one_u64x8);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask8 tail_load = (__mmask8)_bzhi_u32(0xFF, (unsigned int)remaining);
        __m512i tail_u64x8 = _mm512_maskz_loadu_epi64(tail_load, data_ptr + idx);
        __mmask8 min_changed_mask = _mm512_mask_cmp_epu64_mask(tail_load, tail_u64x8, min_u64x8, _MM_CMPINT_LT);
        __mmask8 max_changed_mask = _mm512_mask_cmp_epu64_mask(tail_load, tail_u64x8, max_u64x8, _MM_CMPINT_NLE);
        min_u64x8 = _mm512_mask_mov_epi64(min_u64x8, min_changed_mask, tail_u64x8);
        max_u64x8 = _mm512_mask_mov_epi64(max_u64x8, max_changed_mask, tail_u64x8);
        min_loop_cycle_u64x8 = _mm512_mask_mov_epi64(min_loop_cycle_u64x8, min_changed_mask, current_loop_cycle_u64x8);
        max_loop_cycle_u64x8 = _mm512_mask_mov_epi64(max_loop_cycle_u64x8, max_changed_mask, current_loop_cycle_u64x8);
    }

    nk_u64_t min_value = nk_reduce_min_u64x8_skylake_(min_u64x8);
    nk_u64_t max_value = nk_reduce_max_u64x8_skylake_(max_u64x8);
    unsigned int min_lane, max_lane;
    {
        __mmask8 value_match_mask = _mm512_cmpeq_epi64_mask(min_u64x8, _mm512_set1_epi64((nk_i64_t)min_value));
        __m512i masked_cycle_u64x8 = _mm512_mask_blend_epi64(value_match_mask, _mm512_set1_epi64((nk_i64_t)NK_U64_MAX),
                                                             min_loop_cycle_u64x8);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x8_skylake_(masked_cycle_u64x8);
        __mmask8 cycle_match_mask = _mm512_cmpeq_epi64_mask(masked_cycle_u64x8,
                                                            _mm512_set1_epi64((nk_i64_t)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)cycle_match_mask);
    }
    {
        __mmask8 value_match_mask = _mm512_cmpeq_epi64_mask(max_u64x8, _mm512_set1_epi64((nk_i64_t)max_value));
        __m512i masked_cycle_u64x8 = _mm512_mask_blend_epi64(value_match_mask, _mm512_set1_epi64((nk_i64_t)NK_U64_MAX),
                                                             max_loop_cycle_u64x8);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x8_skylake_(masked_cycle_u64x8);
        __mmask8 cycle_match_mask = _mm512_cmpeq_epi64_mask(masked_cycle_u64x8,
                                                            _mm512_set1_epi64((nk_i64_t)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u64x8;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u64s[min_lane] * 8 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u64x8;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u64s[max_lane] * 8 + max_lane;
}

NK_PUBLIC void nk_reduce_minmax_u64_skylake(                           //
    nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_u64_t);
    int aligned = (stride_bytes % sizeof(nk_u64_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_U64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_u64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (stride_elements == 1)
        nk_reduce_minmax_u64_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_u64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}

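// Example (illustrative only): calling the public strided entry point on a packed array,
// where `stride_bytes` is the distance in bytes between consecutive elements.
static void nk_minmax_u64_usage_sketch_(void) {
    nk_u64_t data[5] = {42, 7, 7, 99, 99};
    nk_u64_t min_v, max_v;
    nk_size_t min_i, max_i;
    nk_reduce_minmax_u64_skylake(data, 5, sizeof(nk_u64_t), &min_v, &min_i, &max_v, &max_i);
    // Expect min_v == 7 at min_i == 1 and max_v == 99 at max_i == 3 (first occurrences).
}
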
NK_INTERNAL void nk_reduce_minmax_f64_skylake_contiguous_( //
    nk_f64_t const *data_ptr, nk_size_t count,             //
    nk_f64_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_f64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512d min_f64x8 = _mm512_set1_pd(NK_F64_MAX);
    __m512d max_f64x8 = _mm512_set1_pd(NK_F64_MIN);
    __m512i min_loop_cycle_u64x8 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u64x8 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u64x8 = _mm512_setzero_si512();
    __m512i one_u64x8 = _mm512_set1_epi64(1);

    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m512d data_f64x8 = _mm512_loadu_pd(data_ptr + idx);
        __mmask8 min_changed_mask = _mm512_cmp_pd_mask(data_f64x8, min_f64x8, _CMP_LT_OQ);
        __mmask8 max_changed_mask = _mm512_cmp_pd_mask(data_f64x8, max_f64x8, _CMP_GT_OQ);
        min_f64x8 = _mm512_mask_mov_pd(min_f64x8, min_changed_mask, data_f64x8);
        max_f64x8 = _mm512_mask_mov_pd(max_f64x8, max_changed_mask, data_f64x8);
        min_loop_cycle_u64x8 = _mm512_mask_mov_epi64(min_loop_cycle_u64x8, min_changed_mask, current_loop_cycle_u64x8);
        max_loop_cycle_u64x8 = _mm512_mask_mov_epi64(max_loop_cycle_u64x8, max_changed_mask, current_loop_cycle_u64x8);
        current_loop_cycle_u64x8 = _mm512_add_epi64(current_loop_cycle_u64x8, one_u64x8);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask8 tail_load = (__mmask8)_bzhi_u32(0xFF, (unsigned int)remaining);
        __m512d tail_f64x8 = _mm512_maskz_loadu_pd(tail_load, data_ptr + idx);
        __mmask8 min_changed_mask = _mm512_mask_cmp_pd_mask(tail_load, tail_f64x8, min_f64x8, _CMP_LT_OQ);
        __mmask8 max_changed_mask = _mm512_mask_cmp_pd_mask(tail_load, tail_f64x8, max_f64x8, _CMP_GT_OQ);
        min_f64x8 = _mm512_mask_mov_pd(min_f64x8, min_changed_mask, tail_f64x8);
        max_f64x8 = _mm512_mask_mov_pd(max_f64x8, max_changed_mask, tail_f64x8);
        min_loop_cycle_u64x8 = _mm512_mask_mov_epi64(min_loop_cycle_u64x8, min_changed_mask, current_loop_cycle_u64x8);
        max_loop_cycle_u64x8 = _mm512_mask_mov_epi64(max_loop_cycle_u64x8, max_changed_mask, current_loop_cycle_u64x8);
    }

    nk_f64_t min_value = nk_reduce_min_f64x8_skylake_(min_f64x8);
    nk_f64_t max_value = nk_reduce_max_f64x8_skylake_(max_f64x8);
    if (min_value == NK_F64_MAX && max_value == NK_F64_MIN) {
        *min_value_ptr = NK_F64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F64_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }
    unsigned int min_lane, max_lane;
    {
        __mmask8 value_match_mask = _mm512_cmp_pd_mask(min_f64x8, _mm512_set1_pd(min_value), _CMP_EQ_OQ);
        __m512i masked_cycle_u64x8 = _mm512_mask_blend_epi64(value_match_mask, _mm512_set1_epi64((nk_i64_t)NK_U64_MAX),
                                                             min_loop_cycle_u64x8);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x8_skylake_(masked_cycle_u64x8);
        __mmask8 cycle_match_mask = _mm512_cmpeq_epi64_mask(masked_cycle_u64x8,
                                                            _mm512_set1_epi64((nk_i64_t)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)cycle_match_mask);
    }
    {
        __mmask8 value_match_mask = _mm512_cmp_pd_mask(max_f64x8, _mm512_set1_pd(max_value), _CMP_EQ_OQ);
        __m512i masked_cycle_u64x8 = _mm512_mask_blend_epi64(value_match_mask, _mm512_set1_epi64((nk_i64_t)NK_U64_MAX),
                                                             max_loop_cycle_u64x8);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x8_skylake_(masked_cycle_u64x8);
        __mmask8 cycle_match_mask = _mm512_cmpeq_epi64_mask(masked_cycle_u64x8,
                                                            _mm512_set1_epi64((nk_i64_t)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u64x8;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u64s[min_lane] * 8 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u64x8;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u64s[max_lane] * 8 + max_lane;
}

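// Note on the NaN behavior above: _CMP_LT_OQ and _CMP_GT_OQ are *ordered* predicates, so a
// NaN lane compares false and can never displace the running min/max; an all-NaN input
// leaves both sentinels untouched, which is what the early-return branch detects.
// Scalar restatement of the same rule (illustrative sketch, hypothetical helper name):
static inline int nk_f64_improves_min_sketch_(nk_f64_t candidate, nk_f64_t current_min) {
    return candidate < current_min; // false when `candidate` is NaN, mirroring _CMP_LT_OQ
}
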
NK_PUBLIC void nk_reduce_minmax_f64_skylake(                           //
    nk_f64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_f64_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f64_t);
    int aligned = (stride_bytes % sizeof(nk_f64_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_F64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F64_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_f64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (stride_elements == 1)
        nk_reduce_minmax_f64_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_f64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}

NK_INTERNAL void nk_reduce_moments_e4m3_skylake_contiguous_( //
    nk_e4m3_t const *data_ptr, nk_size_t count,              //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512 data_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)(data_ptr + idx)));
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b512_vec_t vec;
        nk_partial_load_e4m3x16_to_f32x16_skylake_(data_ptr + idx, &vec, remaining);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, vec.zmm_ps);
        sumsq_f32x16 = _mm512_fmadd_ps(vec.zmm_ps, vec.zmm_ps, sumsq_f32x16);
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}

NK_INTERNAL void nk_reduce_moments_e4m3_skylake_strided_(                  //
    nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __mmask16 stride_mask_m16 = (__mmask16)nk_stride_mask_u1x64_(stride_elements);
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m16) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m128i data_e4m3x16 = _mm_maskz_loadu_epi8(stride_mask_m16, data_ptr + idx_scalars);
        __m512 data_f32x16 = nk_e4m3x16_to_f32x16_skylake_(data_e4m3x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining_bytes = total_scalars - idx_scalars;
    if (remaining_bytes > 0) {
        __mmask16 tail_mask = stride_mask_m16 & (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining_bytes);
        __m128i data_e4m3x16 = _mm_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
        __m512 data_f32x16 = nk_e4m3x16_to_f32x16_skylake_(data_e4m3x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}

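// Sketch of the stride-mask trick used above (an assumption about nk_stride_mask_u1x64_'s
// contract: bit i is set whenever i % stride_elements == 0). With such a mask, a masked byte
// load zeroes out the in-between scalars so they contribute nothing to the sum and sum of
// squares, avoiding a hardware gather, and `step = popcnt(mask) * stride_elements` is the
// number of scalars consumed per iteration. Portable construction for illustration only:
static inline nk_u64_t nk_stride_mask_sketch_(nk_size_t stride_elements) {
    nk_u64_t mask = 0;
    for (nk_size_t i = 0; i < 64; i += stride_elements) mask |= 1ull << i;
    return mask; // e.g. stride 4 -> 0x1111111111111111
}
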
NK_PUBLIC void nk_reduce_moments_e4m3_skylake(                          //
    nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
    int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 64) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_e4m3_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_e4m3_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                       &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_e4m3_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements >= 2 && stride_elements <= 16)
        nk_reduce_moments_e4m3_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

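// The halving branch above is a pairwise-summation-style guard (an assumption about intent,
// consistent with the fixed split threshold): by recursing once `count` exceeds
// (NK_U16_MAX + 1) * 64, the number of values folded into any single f32 accumulator stays
// bounded, limiting rounding-error growth. Scalar sketch of the same divide-and-conquer
// shape (illustrative only; the 4096 cutoff is arbitrary):
static float nk_pairwise_sum_sketch_(float const *x, nk_size_t n) {
    if (n <= 4096) {
        float s = 0;
        for (nk_size_t i = 0; i < n; ++i) s += x[i]; // small blocks: plain loop
        return s;
    }
    return nk_pairwise_sum_sketch_(x, n / 2) + nk_pairwise_sum_sketch_(x + n / 2, n - n / 2);
}
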
NK_INTERNAL void nk_reduce_minmax_e4m3_skylake_contiguous_( //
    nk_e4m3_t const *data_ptr, nk_size_t count,             //
    nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
    // E4M3 NaN: comparable 0x00 (neg NaN) and 0xFF (pos NaN). Replace with neutral values.
    nk_b512_vec_t min_vec, max_vec;
    min_vec.zmm = _mm512_set1_epi8((char)0xFF);
    max_vec.zmm = _mm512_setzero_si512();
    __m512i min_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i one_u8x64 = _mm512_set1_epi8(1);

    nk_size_t idx = 0;
    for (; idx + 64 <= count; idx += 64) {
        __m512i data_i8x64 = _mm512_loadu_si512(data_ptr + idx);
        __m512i data_cmp_u8x64 = nk_fp8x64_to_u8x64_comparable_skylake_(data_i8x64);
        __mmask64 is_nan_m64 = _mm512_cmpeq_epi8_mask(data_cmp_u8x64, _mm512_setzero_si512()) |
                               _mm512_cmpeq_epi8_mask(data_cmp_u8x64, _mm512_set1_epi8((char)0xFF));
        __m512i data_min_u8x64 = _mm512_mask_blend_epi8(is_nan_m64, data_cmp_u8x64, _mm512_set1_epi8((char)0xFF));
        __m512i data_max_u8x64 = _mm512_mask_blend_epi8(is_nan_m64, data_cmp_u8x64, _mm512_setzero_si512());
        __m512i new_min_u8x64 = _mm512_min_epu8(min_vec.zmm, data_min_u8x64);
        __mmask64 min_changed_mask = ~_mm512_cmpeq_epi8_mask(new_min_u8x64, min_vec.zmm);
        min_vec.zmm = new_min_u8x64;
        min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
        __m512i new_max_u8x64 = _mm512_max_epu8(max_vec.zmm, data_max_u8x64);
        __mmask64 max_changed_mask = ~_mm512_cmpeq_epi8_mask(new_max_u8x64, max_vec.zmm);
        max_vec.zmm = new_max_u8x64;
        max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
        current_loop_cycle_u8x64 = _mm512_add_epi8(current_loop_cycle_u8x64, one_u8x64);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask64 tail_load = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
        __m512i data_i8x64 = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0xFF), tail_load, data_ptr + idx);
        __m512i data_cmp_u8x64 = nk_fp8x64_to_u8x64_comparable_skylake_(data_i8x64);
        __mmask64 is_nan_m64 = _mm512_cmpeq_epi8_mask(data_cmp_u8x64, _mm512_setzero_si512()) |
                               _mm512_cmpeq_epi8_mask(data_cmp_u8x64, _mm512_set1_epi8((char)0xFF));
        __mmask64 valid_non_nan_m64 = tail_load & ~is_nan_m64;
        __m512i data_cmp_min = _mm512_mask_blend_epi8(valid_non_nan_m64, _mm512_set1_epi8((char)0xFF), data_cmp_u8x64);
        __m512i data_cmp_max = _mm512_mask_blend_epi8(valid_non_nan_m64, _mm512_setzero_si512(), data_cmp_u8x64);
        __m512i new_min_u8x64 = _mm512_min_epu8(min_vec.zmm, data_cmp_min);
        __mmask64 min_changed_mask = ~_mm512_cmpeq_epi8_mask(new_min_u8x64, min_vec.zmm);
        min_vec.zmm = new_min_u8x64;
        min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
        __m512i new_max_u8x64 = _mm512_max_epu8(max_vec.zmm, data_cmp_max);
        __mmask64 max_changed_mask = ~_mm512_cmpeq_epi8_mask(new_max_u8x64, max_vec.zmm);
        max_vec.zmm = new_max_u8x64;
        max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
    }

    nk_u8_t min_value_comparable = nk_reduce_min_u8x64_skylake_(min_vec.zmm);
    nk_u8_t max_value_comparable = nk_reduce_max_u8x64_skylake_(max_vec.zmm);

    // All-NaN early return: both sentinels unchanged means no valid data was found
    if (min_value_comparable == 0xFF && max_value_comparable == 0x00) {
        *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX;
        *max_value_ptr = NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    if (min_value_comparable == 0xFF) { *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX; }
    else {
        unsigned int min_lane;
        __mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(min_vec.zmm, _mm512_set1_epi8((char)min_value_comparable));
        __m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
                                                            min_loop_cycle_u8x64);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
        __mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
                                                            _mm512_set1_epi8((char)earliest_loop_cycle));
        min_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
        nk_b512_vec_t loop_cycle_vec;
        loop_cycle_vec.zmm = min_loop_cycle_u8x64;
        *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 64 + min_lane;
        nk_b512_vec_t min_raw_vec;
        min_raw_vec.zmm = nk_u8x64_comparable_to_fp8x64_skylake_(min_vec.zmm);
        *min_value_ptr = min_raw_vec.e4m3s[min_lane];
    }
    if (max_value_comparable == 0x00) { *max_value_ptr = NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX; }
    else {
        unsigned int max_lane;
        __mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(max_vec.zmm, _mm512_set1_epi8((char)max_value_comparable));
        __m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
                                                            max_loop_cycle_u8x64);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
        __mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
                                                            _mm512_set1_epi8((char)earliest_loop_cycle));
        max_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
        nk_b512_vec_t loop_cycle_vec;
        loop_cycle_vec.zmm = max_loop_cycle_u8x64;
        *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 64 + max_lane;
        nk_b512_vec_t max_raw_vec;
        max_raw_vec.zmm = nk_u8x64_comparable_to_fp8x64_skylake_(max_vec.zmm);
        *max_value_ptr = max_raw_vec.e4m3s[max_lane];
    }
}

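// Sketch of the order-preserving byte transform assumed above for
// nk_fp8x64_to_u8x64_comparable_skylake_: the usual sign-magnitude trick, where negative
// encodings are bitwise inverted and non-negative ones get the sign bit set, so unsigned
// byte comparisons agree with the floating-point order. Scalar form (illustrative only):
static inline nk_u8_t nk_fp8_to_comparable_sketch_(nk_u8_t raw) {
    return (raw & 0x80) ? (nk_u8_t)~raw : (nk_u8_t)(raw | 0x80);
}
// Under this map the E4M3 NaNs 0x7F and 0xFF land on 0xFF and 0x00 respectively, matching
// the two byte values the kernel filters out before taking unsigned min/max.
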
NK_PUBLIC void nk_reduce_minmax_e4m3_skylake(                           //
    nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr,                 //
    nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
    if (count == 0)
        *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E4M3_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (stride_elements == 1 && count > (nk_size_t)(NK_U8_MAX + 1) * 64) {
        nk_size_t left_count = count / 2;
        nk_e4m3_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_e4m3_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                      &left_max_index);
        nk_reduce_minmax_e4m3_skylake(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
                                      &right_min_index, &right_max, &right_max_index);
        if (left_min_index == NK_SIZE_MAX)
            *min_value_ptr = right_min,
            *min_index_ptr = right_min_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_min_index;
        else if (right_min_index == NK_SIZE_MAX || nk_e4m3_order_serial(left_min, right_min) <= 0)
            *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        else *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        if (left_max_index == NK_SIZE_MAX)
            *max_value_ptr = right_max,
            *max_index_ptr = right_max_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_max_index;
        else if (right_max_index == NK_SIZE_MAX || nk_e4m3_order_serial(left_max, right_max) >= 0)
            *max_value_ptr = left_max, *max_index_ptr = left_max_index;
        else *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_e4m3_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                  max_index_ptr);
    else
        nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                     max_index_ptr);
}

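// Why this dispatcher splits at (NK_U8_MAX + 1) * 64: the contiguous kernel tracks loop
// cycles in per-lane u8 counters and each cycle consumes 64 elements, so indices are only
// recoverable for 256 * 64 = 16384 elements before the counters would wrap; longer inputs
// are halved recursively and the two argmin/argmax results merged above. Compile-time
// restatement of that capacity (illustrative only):
typedef char nk_e4m3_cycle_capacity_sketch_[((NK_U8_MAX + 1) * 64 == 16384) ? 1 : -1];
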
NK_INTERNAL void nk_reduce_moments_e5m2_skylake_contiguous_( //
    nk_e5m2_t const *data_ptr, nk_size_t count,              //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512 data_f32x16 = nk_e5m2x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)(data_ptr + idx)));
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b512_vec_t vec;
        nk_partial_load_e5m2x16_to_f32x16_skylake_(data_ptr + idx, &vec, remaining);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, vec.zmm_ps);
        sumsq_f32x16 = _mm512_fmadd_ps(vec.zmm_ps, vec.zmm_ps, sumsq_f32x16);
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}

NK_INTERNAL void nk_reduce_moments_e5m2_skylake_strided_(                  //
    nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __mmask16 stride_mask_m16 = (__mmask16)nk_stride_mask_u1x64_(stride_elements);
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m16) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m128i data_e5m2x16 = _mm_maskz_loadu_epi8(stride_mask_m16, data_ptr + idx_scalars);
        __m512 data_f32x16 = nk_e5m2x16_to_f32x16_skylake_(data_e5m2x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining_bytes = total_scalars - idx_scalars;
    if (remaining_bytes > 0) {
        __mmask16 tail_mask = stride_mask_m16 & (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining_bytes);
        __m128i data_e5m2x16 = _mm_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
        __m512 data_f32x16 = nk_e5m2x16_to_f32x16_skylake_(data_e5m2x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}

NK_PUBLIC void nk_reduce_moments_e5m2_skylake(                          //
    nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
    int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 64) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_e5m2_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_e5m2_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                       &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_e5m2_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements >= 2 && stride_elements <= 16)
        nk_reduce_moments_e5m2_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_moments_e2m3_skylake_contiguous_( //
    nk_e2m3_t const *data_ptr, nk_size_t count,              //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512 data_f32x16 = nk_e2m3x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)(data_ptr + idx)));
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b512_vec_t vec;
        nk_partial_load_e2m3x16_to_f32x16_skylake_(data_ptr + idx, &vec, remaining);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, vec.zmm_ps);
        sumsq_f32x16 = _mm512_fmadd_ps(vec.zmm_ps, vec.zmm_ps, sumsq_f32x16);
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}

NK_INTERNAL void nk_reduce_moments_e2m3_skylake_strided_(                  //
    nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __mmask16 stride_mask_m16 = (__mmask16)nk_stride_mask_u1x64_(stride_elements);
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m16) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m128i data_e2m3x16 = _mm_maskz_loadu_epi8(stride_mask_m16, data_ptr + idx_scalars);
        __m512 data_f32x16 = nk_e2m3x16_to_f32x16_skylake_(data_e2m3x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining_bytes = total_scalars - idx_scalars;
    if (remaining_bytes > 0) {
        __mmask16 tail_mask = stride_mask_m16 & (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining_bytes);
        __m128i data_e2m3x16 = _mm_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
        __m512 data_f32x16 = nk_e2m3x16_to_f32x16_skylake_(data_e2m3x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}

NK_PUBLIC void nk_reduce_moments_e2m3_skylake(                          //
    nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
    int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 64) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_e2m3_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_e2m3_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                       &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_e2m3_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements >= 2 && stride_elements <= 16)
        nk_reduce_moments_e2m3_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_moments_e3m2_skylake_contiguous_( //
    nk_e3m2_t const *data_ptr, nk_size_t count,              //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m512 data_f32x16 = nk_e3m2x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)(data_ptr + idx)));
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b512_vec_t vec;
        nk_partial_load_e3m2x16_to_f32x16_skylake_(data_ptr + idx, &vec, remaining);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, vec.zmm_ps);
        sumsq_f32x16 = _mm512_fmadd_ps(vec.zmm_ps, vec.zmm_ps, sumsq_f32x16);
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}

NK_INTERNAL void nk_reduce_moments_e3m2_skylake_strided_(                  //
    nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __mmask16 stride_mask_m16 = (__mmask16)nk_stride_mask_u1x64_(stride_elements);
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m16) * stride_elements;
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m128i data_e3m2x16 = _mm_maskz_loadu_epi8(stride_mask_m16, data_ptr + idx_scalars);
        __m512 data_f32x16 = nk_e3m2x16_to_f32x16_skylake_(data_e3m2x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining_bytes = total_scalars - idx_scalars;
    if (remaining_bytes > 0) {
        __mmask16 tail_mask = stride_mask_m16 & (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining_bytes);
        __m128i data_e3m2x16 = _mm_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
        __m512 data_f32x16 = nk_e3m2x16_to_f32x16_skylake_(data_e3m2x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}

NK_PUBLIC void nk_reduce_moments_e3m2_skylake(                          //
    nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
    int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 64) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_e3m2_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_e3m2_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                       &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_e3m2_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements >= 2 && stride_elements <= 16)
        nk_reduce_moments_e3m2_skylake_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_minmax_e5m2_skylake_contiguous_( //
    nk_e5m2_t const *data_ptr, nk_size_t count,             //
    nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
    // E5M2 NaN in comparable form: 0x00-0x02 (neg NaN) and 0xFD-0xFF (pos NaN).
    nk_b512_vec_t min_vec, max_vec;
    min_vec.zmm = _mm512_set1_epi8((char)0xFF);
    max_vec.zmm = _mm512_setzero_si512();
    __m512i min_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u8x64 = _mm512_setzero_si512();
    __m512i one_u8x64 = _mm512_set1_epi8(1);

    nk_size_t idx = 0;
    for (; idx + 64 <= count; idx += 64) {
        __m512i data_i8x64 = _mm512_loadu_si512(data_ptr + idx);
        __m512i data_cmp_u8x64 = nk_fp8x64_to_u8x64_comparable_skylake_(data_i8x64);
        __mmask64 is_nan_m64 = _mm512_cmple_epu8_mask(data_cmp_u8x64, _mm512_set1_epi8(0x02)) |
                               _mm512_cmpge_epu8_mask(data_cmp_u8x64, _mm512_set1_epi8((char)0xFD));
        __m512i data_min_u8x64 = _mm512_mask_blend_epi8(is_nan_m64, data_cmp_u8x64, _mm512_set1_epi8((char)0xFF));
        __m512i data_max_u8x64 = _mm512_mask_blend_epi8(is_nan_m64, data_cmp_u8x64, _mm512_setzero_si512());
        __m512i new_min_u8x64 = _mm512_min_epu8(min_vec.zmm, data_min_u8x64);
        __mmask64 min_changed_mask = ~_mm512_cmpeq_epi8_mask(new_min_u8x64, min_vec.zmm);
        min_vec.zmm = new_min_u8x64;
        min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
        __m512i new_max_u8x64 = _mm512_max_epu8(max_vec.zmm, data_max_u8x64);
        __mmask64 max_changed_mask = ~_mm512_cmpeq_epi8_mask(new_max_u8x64, max_vec.zmm);
        max_vec.zmm = new_max_u8x64;
        max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
        current_loop_cycle_u8x64 = _mm512_add_epi8(current_loop_cycle_u8x64, one_u8x64);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask64 tail_load = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
        __m512i data_i8x64 = _mm512_mask_loadu_epi8(_mm512_set1_epi8((char)0xFF), tail_load, data_ptr + idx);
        __m512i data_cmp_u8x64 = nk_fp8x64_to_u8x64_comparable_skylake_(data_i8x64);
        __mmask64 is_nan_m64 = _mm512_cmple_epu8_mask(data_cmp_u8x64, _mm512_set1_epi8(0x02)) |
                               _mm512_cmpge_epu8_mask(data_cmp_u8x64, _mm512_set1_epi8((char)0xFD));
        __mmask64 valid_non_nan_m64 = tail_load & ~is_nan_m64;
        __m512i data_cmp_min = _mm512_mask_blend_epi8(valid_non_nan_m64, _mm512_set1_epi8((char)0xFF), data_cmp_u8x64);
        __m512i data_cmp_max = _mm512_mask_blend_epi8(valid_non_nan_m64, _mm512_setzero_si512(), data_cmp_u8x64);
        __m512i new_min_u8x64 = _mm512_min_epu8(min_vec.zmm, data_cmp_min);
        __mmask64 min_changed_mask = ~_mm512_cmpeq_epi8_mask(new_min_u8x64, min_vec.zmm);
        min_vec.zmm = new_min_u8x64;
        min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
        __m512i new_max_u8x64 = _mm512_max_epu8(max_vec.zmm, data_cmp_max);
        __mmask64 max_changed_mask = ~_mm512_cmpeq_epi8_mask(new_max_u8x64, max_vec.zmm);
        max_vec.zmm = new_max_u8x64;
        max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
    }

    nk_u8_t min_value_comparable = nk_reduce_min_u8x64_skylake_(min_vec.zmm);
    nk_u8_t max_value_comparable = nk_reduce_max_u8x64_skylake_(max_vec.zmm);

    // All-NaN early return: both sentinels unchanged means no valid data was found
    if (min_value_comparable == 0xFF && max_value_comparable == 0x00) {
        *min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX;
        *max_value_ptr = NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    if (min_value_comparable == 0xFF) { *min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX; }
    else {
        unsigned int min_lane;
        __mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(min_vec.zmm, _mm512_set1_epi8((char)min_value_comparable));
        __m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
                                                            min_loop_cycle_u8x64);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
        __mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
                                                            _mm512_set1_epi8((char)earliest_loop_cycle));
        min_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
        nk_b512_vec_t loop_cycle_vec;
        loop_cycle_vec.zmm = min_loop_cycle_u8x64;
        *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 64 + min_lane;
        nk_b512_vec_t min_raw_vec;
        min_raw_vec.zmm = nk_u8x64_comparable_to_fp8x64_skylake_(min_vec.zmm);
        *min_value_ptr = min_raw_vec.e5m2s[min_lane];
    }
    if (max_value_comparable == 0x00) { *max_value_ptr = NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX; }
    else {
        unsigned int max_lane;
        __mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(max_vec.zmm, _mm512_set1_epi8((char)max_value_comparable));
        __m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
                                                            max_loop_cycle_u8x64);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
        __mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
                                                            _mm512_set1_epi8((char)earliest_loop_cycle));
        max_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
        nk_b512_vec_t loop_cycle_vec;
        loop_cycle_vec.zmm = max_loop_cycle_u8x64;
        *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 64 + max_lane;
        nk_b512_vec_t max_raw_vec;
        max_raw_vec.zmm = nk_u8x64_comparable_to_fp8x64_skylake_(max_vec.zmm);
        *max_value_ptr = max_raw_vec.e5m2s[max_lane];
    }
}

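// Why the E5M2 NaN test above checks [0x00, 0x02] and [0xFD, 0xFF]: E5M2 NaNs carry an
// all-ones exponent with a nonzero 2-bit mantissa, i.e. raw bytes 0x7D-0x7F and 0xFD-0xFF,
// and the comparable transform (assumed sign-magnitude flip, as sketched earlier) sends the
// positive ones to 0xFD-0xFF and the negative ones down to 0x02-0x00. Illustrative predicate:
static inline int nk_e5m2_comparable_is_nan_sketch_(nk_u8_t comparable) {
    return comparable <= 0x02 || comparable >= 0xFD;
}
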
NK_PUBLIC void nk_reduce_minmax_e5m2_skylake( //
|
|
3019
|
+
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3020
|
+
nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3021
|
+
nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3022
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
|
|
3023
|
+
if (count == 0)
|
|
3024
|
+
*min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E5M2_MIN,
|
|
3025
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
3026
|
+
else if (stride_elements == 1 && count > (nk_size_t)(NK_U8_MAX + 1) * 64) {
|
|
3027
|
+
nk_size_t left_count = count / 2;
|
|
3028
|
+
nk_e5m2_t left_min, right_min, left_max, right_max;
|
|
3029
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
3030
|
+
nk_reduce_minmax_e5m2_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
3031
|
+
&left_max_index);
|
|
3032
|
+
nk_reduce_minmax_e5m2_skylake(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
|
|
3033
|
+
&right_min_index, &right_max, &right_max_index);
|
|
3034
|
+
if (left_min_index == NK_SIZE_MAX)
|
|
3035
|
+
*min_value_ptr = right_min,
|
|
3036
|
+
*min_index_ptr = right_min_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_min_index;
|
|
3037
|
+
else if (right_min_index == NK_SIZE_MAX || nk_e5m2_order_serial(left_min, right_min) <= 0)
|
|
3038
|
+
*min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
3039
|
+
else *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
3040
|
+
if (left_max_index == NK_SIZE_MAX)
|
|
3041
|
+
*max_value_ptr = right_max,
|
|
3042
|
+
*max_index_ptr = right_max_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_max_index;
|
|
3043
|
+
else if (right_max_index == NK_SIZE_MAX || nk_e5m2_order_serial(left_max, right_max) >= 0)
|
|
3044
|
+
*max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
3045
|
+
else *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
3046
|
+
}
|
|
3047
|
+
else if (stride_elements == 1)
|
|
3048
|
+
nk_reduce_minmax_e5m2_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3049
|
+
max_index_ptr);
|
|
3050
|
+
else
|
|
3051
|
+
nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3052
|
+
max_index_ptr);
|
|
3053
|
+
}
|
|
3054
|
+
|
|
3055
|
+
NK_INTERNAL void nk_reduce_minmax_e2m3_skylake_contiguous_( //
|
|
3056
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, //
|
|
3057
|
+
nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3058
|
+
nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3059
|
+
nk_b512_vec_t min_vec, max_vec;
|
|
3060
|
+
min_vec.zmm = _mm512_set1_epi8((char)0xFF);
|
|
3061
|
+
max_vec.zmm = _mm512_setzero_si512();
|
|
3062
|
+
__m512i min_loop_cycle_u8x64 = _mm512_setzero_si512();
|
|
3063
|
+
__m512i max_loop_cycle_u8x64 = _mm512_setzero_si512();
|
|
3064
|
+
__m512i current_loop_cycle_u8x64 = _mm512_setzero_si512();
|
|
3065
|
+
__m512i one_u8x64 = _mm512_set1_epi8(1);
|
|
3066
|
+
|
|
3067
|
+
nk_size_t idx = 0;
|
|
3068
|
+
for (; idx + 64 <= count; idx += 64) {
|
|
3069
|
+
__m512i data_i8x64 = _mm512_loadu_si512(data_ptr + idx);
|
|
3070
|
+
__m512i data_cmp_u8x64 = nk_fp6x64_to_u8x64_comparable_skylake_(data_i8x64);
|
|
3071
|
+
__mmask64 min_changed_mask = _mm512_cmp_epu8_mask(data_cmp_u8x64, min_vec.zmm, _MM_CMPINT_LT);
|
|
3072
|
+
min_vec.zmm = _mm512_min_epu8(min_vec.zmm, data_cmp_u8x64);
|
|
3073
|
+
min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
|
|
3074
|
+
__mmask64 max_changed_mask = _mm512_cmp_epu8_mask(data_cmp_u8x64, max_vec.zmm, _MM_CMPINT_NLE);
|
|
3075
|
+
max_vec.zmm = _mm512_max_epu8(max_vec.zmm, data_cmp_u8x64);
|
|
3076
|
+
max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
|
|
3077
|
+
current_loop_cycle_u8x64 = _mm512_add_epi8(current_loop_cycle_u8x64, one_u8x64);
|
|
3078
|
+
}
|
|
3079
|
+
|
|
3080
|
+
nk_size_t remaining = count - idx;
|
|
3081
|
+
if (remaining > 0) {
|
|
3082
|
+
__mmask64 tail_load = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
|
|
3083
|
+
__m512i data_i8x64 = _mm512_mask_loadu_epi8(_mm512_set1_epi8(0x3F), tail_load, data_ptr + idx);
|
|
3084
|
+
__m512i data_cmp_u8x64 = nk_fp6x64_to_u8x64_comparable_skylake_(data_i8x64);
|
|
3085
|
+
__mmask64 min_changed_mask = _mm512_mask_cmp_epu8_mask(tail_load, data_cmp_u8x64, min_vec.zmm, _MM_CMPINT_LT);
|
|
3086
|
+
min_vec.zmm = _mm512_mask_min_epu8(min_vec.zmm, tail_load, min_vec.zmm, data_cmp_u8x64);
|
|
3087
|
+
min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
|
|
3088
|
+
__mmask64 max_changed_mask = _mm512_mask_cmp_epu8_mask(tail_load, data_cmp_u8x64, max_vec.zmm, _MM_CMPINT_NLE);
|
|
3089
|
+
max_vec.zmm = _mm512_mask_max_epu8(max_vec.zmm, tail_load, max_vec.zmm, data_cmp_u8x64);
|
|
3090
|
+
max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
|
|
3091
|
+
}
|
|
3092
|
+
|
|
3093
|
+
nk_u8_t min_value_comparable = nk_reduce_min_u8x64_skylake_(min_vec.zmm);
|
|
3094
|
+
nk_u8_t max_value_comparable = nk_reduce_max_u8x64_skylake_(max_vec.zmm);
|
|
3095
|
+
unsigned int min_lane, max_lane;
|
|
3096
|
+
{
|
|
3097
|
+
__mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(min_vec.zmm, _mm512_set1_epi8((char)min_value_comparable));
|
|
3098
|
+
__m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
|
|
3099
|
+
min_loop_cycle_u8x64);
|
|
3100
|
+
nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
|
|
3101
|
+
__mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
|
|
3102
|
+
_mm512_set1_epi8((char)earliest_loop_cycle));
|
|
3103
|
+
min_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
|
|
3104
|
+
}
|
|
3105
|
+
{
|
|
3106
|
+
__mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(max_vec.zmm, _mm512_set1_epi8((char)max_value_comparable));
|
|
3107
|
+
__m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
|
|
3108
|
+
max_loop_cycle_u8x64);
|
|
3109
|
+
nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
|
|
3110
|
+
__mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
|
|
3111
|
+
_mm512_set1_epi8((char)earliest_loop_cycle));
|
|
3112
|
+
max_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
|
|
3113
|
+
}
|
|
3114
|
+
nk_b512_vec_t loop_cycle_vec;
|
|
3115
|
+
loop_cycle_vec.zmm = min_loop_cycle_u8x64;
|
|
3116
|
+
*min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 64 + min_lane;
|
|
3117
|
+
loop_cycle_vec.zmm = max_loop_cycle_u8x64;
|
|
3118
|
+
*max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 64 + max_lane;
|
|
3119
|
+
min_vec.zmm = nk_u8x64_comparable_to_fp6x64_skylake_(min_vec.zmm);
|
|
3120
|
+
max_vec.zmm = nk_u8x64_comparable_to_fp6x64_skylake_(max_vec.zmm);
|
|
3121
|
+
*min_value_ptr = min_vec.e2m3s[min_lane];
|
|
3122
|
+
*max_value_ptr = max_vec.e2m3s[max_lane];
|
|
3123
|
+
}
|
|
3124
|
+
|
|
3125
|
+
NK_PUBLIC void nk_reduce_minmax_e2m3_skylake( //
|
|
3126
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3127
|
+
nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3128
|
+
nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3129
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
|
|
3130
|
+
if (count == 0)
|
|
3131
|
+
*min_value_ptr = NK_E2M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E2M3_MIN,
|
|
3132
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
3133
|
+
else if (stride_elements == 1 && count > (nk_size_t)(NK_U8_MAX + 1) * 64) {
|
|
3134
|
+
nk_size_t left_count = count / 2;
|
|
3135
|
+
nk_e2m3_t left_min, right_min, left_max, right_max;
|
|
3136
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
3137
|
+
nk_reduce_minmax_e2m3_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
3138
|
+
&left_max_index);
|
|
3139
|
+
nk_reduce_minmax_e2m3_skylake(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
|
|
3140
|
+
&right_min_index, &right_max, &right_max_index);
|
|
3141
|
+
if (nk_e2m3_order_serial(right_min, left_min) < 0)
|
|
3142
|
+
*min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
3143
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
3144
|
+
if (nk_e2m3_order_serial(right_max, left_max) > 0)
|
|
3145
|
+
*max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
3146
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
3147
|
+
}
|
|
3148
|
+
else if (stride_elements == 1)
|
|
3149
|
+
nk_reduce_minmax_e2m3_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3150
|
+
max_index_ptr);
|
|
3151
|
+
else
|
|
3152
|
+
nk_reduce_minmax_e2m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3153
|
+
max_index_ptr);
|
|
3154
|
+
}
|
|
3155
|
+
|
|
3156
|
+
NK_INTERNAL void nk_reduce_minmax_e3m2_skylake_contiguous_( //
|
|
3157
|
+
nk_e3m2_t const *data_ptr, nk_size_t count, //
|
|
3158
|
+
nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3159
|
+
nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3160
|
+
nk_b512_vec_t min_vec, max_vec;
|
|
3161
|
+
min_vec.zmm = _mm512_set1_epi8((char)0xFF);
|
|
3162
|
+
max_vec.zmm = _mm512_setzero_si512();
|
|
3163
|
+
__m512i min_loop_cycle_u8x64 = _mm512_setzero_si512();
|
|
3164
|
+
__m512i max_loop_cycle_u8x64 = _mm512_setzero_si512();
|
|
3165
|
+
__m512i current_loop_cycle_u8x64 = _mm512_setzero_si512();
|
|
3166
|
+
__m512i one_u8x64 = _mm512_set1_epi8(1);
|
|
3167
|
+
|
|
3168
|
+
nk_size_t idx = 0;
|
|
3169
|
+
for (; idx + 64 <= count; idx += 64) {
|
|
3170
|
+
__m512i data_i8x64 = _mm512_loadu_si512(data_ptr + idx);
|
|
3171
|
+
__m512i data_cmp_u8x64 = nk_fp6x64_to_u8x64_comparable_skylake_(data_i8x64);
|
|
3172
|
+
__mmask64 min_changed_mask = _mm512_cmp_epu8_mask(data_cmp_u8x64, min_vec.zmm, _MM_CMPINT_LT);
|
|
3173
|
+
min_vec.zmm = _mm512_min_epu8(min_vec.zmm, data_cmp_u8x64);
|
|
3174
|
+
min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
|
|
3175
|
+
__mmask64 max_changed_mask = _mm512_cmp_epu8_mask(data_cmp_u8x64, max_vec.zmm, _MM_CMPINT_NLE);
|
|
3176
|
+
max_vec.zmm = _mm512_max_epu8(max_vec.zmm, data_cmp_u8x64);
|
|
3177
|
+
max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
|
|
3178
|
+
current_loop_cycle_u8x64 = _mm512_add_epi8(current_loop_cycle_u8x64, one_u8x64);
|
|
3179
|
+
}
|
|
3180
|
+
|
|
3181
|
+
nk_size_t remaining = count - idx;
|
|
3182
|
+
if (remaining > 0) {
|
|
3183
|
+
__mmask64 tail_load = _bzhi_u64(0xFFFFFFFFFFFFFFFFull, (unsigned int)remaining);
|
|
3184
|
+
__m512i data_i8x64 = _mm512_mask_loadu_epi8(_mm512_set1_epi8(0x3F), tail_load, data_ptr + idx);
|
|
3185
|
+
__m512i data_cmp_u8x64 = nk_fp6x64_to_u8x64_comparable_skylake_(data_i8x64);
|
|
3186
|
+
__mmask64 min_changed_mask = _mm512_mask_cmp_epu8_mask(tail_load, data_cmp_u8x64, min_vec.zmm, _MM_CMPINT_LT);
|
|
3187
|
+
min_vec.zmm = _mm512_mask_min_epu8(min_vec.zmm, tail_load, min_vec.zmm, data_cmp_u8x64);
|
|
3188
|
+
min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
|
|
3189
|
+
__mmask64 max_changed_mask = _mm512_mask_cmp_epu8_mask(tail_load, data_cmp_u8x64, max_vec.zmm, _MM_CMPINT_NLE);
|
|
3190
|
+
max_vec.zmm = _mm512_mask_max_epu8(max_vec.zmm, tail_load, max_vec.zmm, data_cmp_u8x64);
|
|
3191
|
+
max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
|
|
3192
|
+
}
|
|
3193
|
+
|
|
3194
|
+
nk_u8_t min_value_comparable = nk_reduce_min_u8x64_skylake_(min_vec.zmm);
|
|
3195
|
+
nk_u8_t max_value_comparable = nk_reduce_max_u8x64_skylake_(max_vec.zmm);
|
|
3196
|
+
unsigned int min_lane, max_lane;
|
|
3197
|
+
{
|
|
3198
|
+
__mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(min_vec.zmm, _mm512_set1_epi8((char)min_value_comparable));
|
|
3199
|
+
__m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
|
|
3200
|
+
min_loop_cycle_u8x64);
|
|
3201
|
+
nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
|
|
3202
|
+
__mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
|
|
3203
|
+
_mm512_set1_epi8((char)earliest_loop_cycle));
|
|
3204
|
+
min_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
|
|
3205
|
+
}
|
|
3206
|
+
{
|
|
3207
|
+
__mmask64 value_match_mask = _mm512_cmpeq_epi8_mask(max_vec.zmm, _mm512_set1_epi8((char)max_value_comparable));
|
|
3208
|
+
__m512i masked_cycle_u8x64 = _mm512_mask_blend_epi8(value_match_mask, _mm512_set1_epi8((char)NK_U8_MAX),
|
|
3209
|
+
max_loop_cycle_u8x64);
|
|
3210
|
+
nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x64_skylake_(masked_cycle_u8x64);
|
|
3211
|
+
__mmask64 cycle_match_mask = _mm512_cmpeq_epi8_mask(masked_cycle_u8x64,
|
|
3212
|
+
_mm512_set1_epi8((char)earliest_loop_cycle));
|
|
3213
|
+
max_lane = (unsigned int)_tzcnt_u64(cycle_match_mask);
|
|
3214
|
+
}
|
|
3215
|
+
nk_b512_vec_t loop_cycle_vec;
|
|
3216
|
+
loop_cycle_vec.zmm = min_loop_cycle_u8x64;
|
|
3217
|
+
*min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 64 + min_lane;
|
|
3218
|
+
loop_cycle_vec.zmm = max_loop_cycle_u8x64;
|
|
3219
|
+
*max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 64 + max_lane;
|
|
3220
|
+
min_vec.zmm = nk_u8x64_comparable_to_fp6x64_skylake_(min_vec.zmm);
|
|
3221
|
+
max_vec.zmm = nk_u8x64_comparable_to_fp6x64_skylake_(max_vec.zmm);
|
|
3222
|
+
*min_value_ptr = min_vec.e3m2s[min_lane];
|
|
3223
|
+
*max_value_ptr = max_vec.e3m2s[max_lane];
|
|
3224
|
+
}
|
|
3225
|
+
|
|
3226
|
+
NK_PUBLIC void nk_reduce_minmax_e3m2_skylake( //
|
|
3227
|
+
nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3228
|
+
nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3229
|
+
nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3230
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
|
|
3231
|
+
if (count == 0)
|
|
3232
|
+
*min_value_ptr = NK_E3M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E3M2_MIN,
|
|
3233
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
3234
|
+
else if (stride_elements == 1 && count > (nk_size_t)(NK_U8_MAX + 1) * 64) {
|
|
3235
|
+
nk_size_t left_count = count / 2;
|
|
3236
|
+
nk_e3m2_t left_min, right_min, left_max, right_max;
|
|
3237
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
3238
|
+
nk_reduce_minmax_e3m2_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
3239
|
+
&left_max_index);
|
|
3240
|
+
nk_reduce_minmax_e3m2_skylake(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
|
|
3241
|
+
&right_min_index, &right_max, &right_max_index);
|
|
3242
|
+
if (nk_e3m2_order_serial(right_min, left_min) < 0)
|
|
3243
|
+
*min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
3244
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
3245
|
+
if (nk_e3m2_order_serial(right_max, left_max) > 0)
|
|
3246
|
+
*max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
3247
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
3248
|
+
}
|
|
3249
|
+
else if (stride_elements == 1)
|
|
3250
|
+
nk_reduce_minmax_e3m2_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3251
|
+
max_index_ptr);
|
|
3252
|
+
else
|
|
3253
|
+
nk_reduce_minmax_e3m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3254
|
+
max_index_ptr);
|
|
3255
|
+
}

NK_INTERNAL void nk_reduce_moments_i4_skylake_contiguous_( //
    nk_i4x2_t const *data_ptr, nk_size_t count, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum: XOR-bias nibbles to unsigned, vpsadbw, unbias at the end.
    // Sumsq: squares are sign-independent; LUT maps nibble→square (max 225 fits u8), vpsadbw to u64.
    __m512i mask_0f_i8x64 = _mm512_set1_epi8(0x0F);
    __m512i eight_i8x64 = _mm512_set1_epi8(8);
    __m512i zero_i8x64 = _mm512_setzero_si512();
    // Squares LUT: sq_lut[n] = n² for n in [0,15], all fit in u8 (max 225)
    __m512i sq_lut_u8x64 = _mm512_set_epi8( //
        (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
        (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
        (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
        (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0);
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_u64x8 = _mm512_setzero_si512();
    nk_size_t count_bytes = nk_size_divide_round_up_(count, 2);
    unsigned char const *ptr = (unsigned char const *)data_ptr;
    while (count_bytes > 0) {
        __m512i raw_i8x64;
        if (count_bytes < 64) {
            __mmask64 tail_mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)count_bytes);
            raw_i8x64 = _mm512_maskz_loadu_epi8(tail_mask, ptr);
            count_bytes = 0;
        }
        else {
            raw_i8x64 = _mm512_loadu_si512(ptr);
            ptr += 64, count_bytes -= 64;
        }
        // Extract nibbles as unsigned [0,15]
        __m512i low_u4x64 = _mm512_and_si512(raw_i8x64, mask_0f_i8x64);
        __m512i high_u4x64 = _mm512_and_si512(_mm512_srli_epi16(raw_i8x64, 4), mask_0f_i8x64);
        // Sum: XOR-bias nibbles to unsigned [0,15], add lo+hi per byte, vpsadbw
        __m512i low_biased_u4x64 = _mm512_xor_si512(low_u4x64, eight_i8x64);
        __m512i high_biased_u4x64 = _mm512_xor_si512(high_u4x64, eight_i8x64);
        __m512i pair_sum = _mm512_add_epi8(low_biased_u4x64, high_biased_u4x64);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(pair_sum, zero_i8x64));
        // Sumsq: squares are sign-independent, use LUT on unsigned nibbles
        __m512i low_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64, low_u4x64);
        __m512i high_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64, high_u4x64);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_sad_epu8(low_sq_u8x64, zero_i8x64));
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_sad_epu8(high_sq_u8x64, zero_i8x64));
    }
    // The XOR-8 bias adds 8 per nibble to the SAD total, so subtract 8 × total nibbles processed
    // (including zero-masked register bytes: a padding nibble is 0, XORs to 8, and after unbiasing
    // contributes 0, matching its signed value).
    nk_size_t nibbles_processed = nk_size_round_up_to_multiple_(nk_size_divide_round_up_(count, 2), 64) * 2;
    nk_i64_t sum = (nk_i64_t)nk_reduce_add_u64x8_skylake_(sum_u64x8) - (nk_i64_t)8 * (nk_i64_t)nibbles_processed;
    // Sumsq uses sq_lut[0] = 0 for zero-padded nibbles, so no register-padding correction is needed.
    nk_u64_t sumsq = nk_reduce_add_u64x8_skylake_(sumsq_u64x8);
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
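
// Editorial sketch (illustrative, not part of the package): the XOR-bias
// trick above in scalar form. A signed nibble n in [-8, 7] is stored in two's
// complement, and flipping its top bit (n XOR 8) yields exactly n + 8 in
// [0, 15]; e.g. -3 = 0b1101 becomes 0b0101 = 5 = -3 + 8. VPSADBW can then sum
// the biased values as unsigned bytes, and one subtraction recovers the
// signed total. The helper name below is hypothetical:
#if 0
static long long nk_example_i4_sum_(unsigned char const *nibbles, size_t n) {
    unsigned long long biased = 0;
    for (size_t i = 0; i < n; ++i) biased += (nibbles[i] & 0x0F) ^ 0x08; // n + 8 per nibble
    return (long long)biased - 8LL * (long long)n;                      // unbias: one 8 per nibble
}
#endif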

NK_PUBLIC void nk_reduce_moments_i4_skylake( //
    nk_i4x2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    count = nk_size_round_up_to_multiple_(count, 2);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (stride_bytes == 1) nk_reduce_moments_i4_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_i4_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}
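
// Editorial note (a reading of the code, not a documented contract): rounding
// an odd `count` up to the byte boundary folds the final byte's padding
// nibble into the reduction. That is harmless only when the padding nibble is
// zero: it then contributes 0 to the sum after unbiasing, and 0 to the sum of
// squares via sq_lut[0] = 0.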

NK_INTERNAL void nk_reduce_moments_u4_skylake_contiguous_( //
    nk_u4x2_t const *data_ptr, nk_size_t count, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum: VPSADBW on extracted nibbles. Sumsq: LUT maps nibble→square (max 225 fits u8), vpsadbw to u64.
    __m512i mask_0f_i8x64 = _mm512_set1_epi8(0x0F);
    __m512i zero_i8x64 = _mm512_setzero_si512();
    // Squares LUT: sq_lut[n] = n² for n in [0,15], all fit in u8 (max 225)
    __m512i sq_lut_u8x64 = _mm512_set_epi8( //
        (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
        (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
        (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0, //
        (char)225, (char)196, (char)169, (char)144, 121, 100, 81, 64, 49, 36, 25, 16, 9, 4, 1, 0);
    __m512i sum_u64x8 = _mm512_setzero_si512();
    __m512i sumsq_u64x8 = _mm512_setzero_si512();
    nk_size_t count_bytes = count / 2;
    unsigned char const *ptr = (unsigned char const *)data_ptr;
    while (count_bytes > 0) {
        __m512i raw_i8x64;
        if (count_bytes < 64) {
            __mmask64 tail_mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)count_bytes);
            raw_i8x64 = _mm512_maskz_loadu_epi8(tail_mask, ptr);
            count_bytes = 0;
        }
        else {
            raw_i8x64 = _mm512_loadu_si512(ptr);
            ptr += 64, count_bytes -= 64;
        }
        __m512i low_u4x64 = _mm512_and_si512(raw_i8x64, mask_0f_i8x64);
        __m512i high_u4x64 = _mm512_and_si512(_mm512_srli_epi16(raw_i8x64, 4), mask_0f_i8x64);
        __m512i pair_sum = _mm512_add_epi8(low_u4x64, high_u4x64);
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(pair_sum, zero_i8x64));
        // Sumsq: LUT maps nibble→square, vpsadbw accumulates into u64
        __m512i low_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64, low_u4x64);
        __m512i high_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64, high_u4x64);
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_sad_epu8(low_sq_u8x64, zero_i8x64));
        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_sad_epu8(high_sq_u8x64, zero_i8x64));
    }
    nk_u64_t sum = _mm512_reduce_add_epi64(sum_u64x8);
    nk_u64_t sumsq = nk_reduce_add_u64x8_skylake_(sumsq_u64x8);
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
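
// Editorial note (illustrative): VPSHUFB indexes each 16-byte lane by the low
// 4 bits of the selector byte, so the 16-entry squares table replicated
// across all four lanes maps any nibble straight to its square, e.g. 13 → 169.
// Because the largest square, 15² = 225, still fits in a u8, VPSADBW can then
// fold eight squared bytes at a time into u64 lanes with no overflow.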

NK_PUBLIC void nk_reduce_moments_u4_skylake( //
    nk_u4x2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    count = nk_size_round_up_to_multiple_(count, 2);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (stride_bytes == 1) nk_reduce_moments_u4_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_u4_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_moments_u1_skylake_contiguous_( //
    nk_u1x8_t const *data_ptr, nk_size_t count, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    // Sum = popcount via 4-bit LUT (same as nk_reduce_add_u1_skylake). Sumsq = sum for bits.
    __m512i lut_i8x64 = _mm512_set_epi8( //
        4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0, //
        4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0, //
        4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0, //
        4, 3, 3, 2, 3, 2, 2, 1, 3, 2, 2, 1, 2, 1, 1, 0);
    __m512i mask_0f_i8x64 = _mm512_set1_epi8(0x0F);
    __m512i zero_i8x64 = _mm512_setzero_si512();
    __m512i sum_u64x8 = _mm512_setzero_si512();
    nk_size_t count_bytes = count / 8;
    unsigned char const *ptr = (unsigned char const *)data_ptr;
    while (count_bytes > 0) {
        __m512i raw_i8x64;
        if (count_bytes < 64) {
            __mmask64 tail_mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)count_bytes);
            raw_i8x64 = _mm512_maskz_loadu_epi8(tail_mask, ptr);
            count_bytes = 0;
        }
        else {
            raw_i8x64 = _mm512_loadu_si512(ptr);
            ptr += 64, count_bytes -= 64;
        }
        __m512i low_nibble_u8x64 = _mm512_and_si512(raw_i8x64, mask_0f_i8x64);
        __m512i high_nibble_u8x64 = _mm512_and_si512(_mm512_srli_epi16(raw_i8x64, 4), mask_0f_i8x64);
        __m512i popcnt_u8x64 = _mm512_add_epi8(_mm512_shuffle_epi8(lut_i8x64, low_nibble_u8x64),
                                               _mm512_shuffle_epi8(lut_i8x64, high_nibble_u8x64));
        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(popcnt_u8x64, zero_i8x64));
    }
    nk_u64_t sum = _mm512_reduce_add_epi64(sum_u64x8);
    *sum_ptr = sum;
    *sumsq_ptr = sum;
}
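
// Editorial sketch (illustrative, not part of the package): for 1-bit values
// b, b·b == b, so the sum of squares equals the popcount and one reduction
// serves both moments. The LUT above is the classic nibble-popcount table in
// scalar form; the helper name is hypothetical:
#if 0
static unsigned nk_example_byte_popcount_(unsigned char byte) {
    static unsigned char const lut[16] = {0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4};
    return lut[byte & 0x0F] + lut[byte >> 4]; // one lookup per nibble, as in the kernel
}
#endif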

NK_PUBLIC void nk_reduce_moments_u1_skylake( //
    nk_u1x8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
    count = nk_size_round_up_to_multiple_(count, 8);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (stride_bytes == 1) nk_reduce_moments_u1_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_u1_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL void nk_reduce_moments_bf16_skylake_contiguous_( //
    nk_bf16_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m512 low_f32x16 = _mm512_castsi512_ps(
            _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((__m256i const *)(data_ptr + idx))), 16));
        __m512 high_f32x16 = _mm512_castsi512_ps(
            _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((__m256i const *)(data_ptr + idx + 16))), 16));
        sum_f32x16 = _mm512_add_ps(sum_f32x16, low_f32x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, high_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(low_f32x16, low_f32x16, sumsq_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(high_f32x16, high_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask16 low_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)(remaining > 16 ? 16 : remaining));
        __m512 low_f32x16 = _mm512_castsi512_ps(
            _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_maskz_loadu_epi16(low_mask, data_ptr + idx)), 16));
        sum_f32x16 = _mm512_add_ps(sum_f32x16, low_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(low_f32x16, low_f32x16, sumsq_f32x16);
        if (remaining > 16) {
            __mmask16 high_mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)(remaining - 16));
            __m512 high_f32x16 = _mm512_castsi512_ps(
                _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_maskz_loadu_epi16(high_mask, data_ptr + idx + 16)), 16));
            sum_f32x16 = _mm512_add_ps(sum_f32x16, high_f32x16);
            sumsq_f32x16 = _mm512_fmadd_ps(high_f32x16, high_f32x16, sumsq_f32x16);
        }
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}
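
// Editorial note (illustrative): bf16 is exactly the top 16 bits of an f32,
// so the widening above is a zero-extend plus a 16-bit left shift: bf16
// 0x3F80 (1.0) becomes the f32 bit pattern 0x3F800000 (1.0f), with no
// rounding involved.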

NK_PUBLIC void nk_reduce_moments_bf16_skylake( //
    nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
    int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_bf16_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_bf16_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                       &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum;
        *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_bf16_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}
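
// Editorial note (an interpretation, not stated by the package): moments have
// no per-lane cycle counter, so the (NK_U16_MAX + 1) * 32 split here looks
// chosen for accuracy rather than indexing. Halving recursively yields a
// pairwise summation whose f32 rounding error grows roughly with log2(n)
// instead of n for one long-running accumulator.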

NK_INTERNAL __m512i nk_bf16x32_to_comparable_i16x32_skylake_(__m512i raw_u16x32) {
    __m512i sign = _mm512_srai_epi16(raw_u16x32, 15);
    __m512i flip = _mm512_srli_epi16(sign, 1);
    return _mm512_xor_si512(raw_u16x32, flip);
}
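
// Editorial note (illustrative): this is the standard monotone remap of IEEE
// sign-magnitude bit patterns onto two's-complement integers. Non-negative
// inputs (sign = 0) pass through unchanged; negative inputs XOR the low 15
// bits, which reverses their magnitude order. For example, bf16 -2.0 (0xC000)
// maps to 0xBFFF and -1.0 (0xBF80) maps to 0xC07F, so -2.0 < -1.0 holds under
// plain signed 16-bit comparison. The scalar inverse appears at the end of
// nk_reduce_minmax_bf16_skylake_contiguous_ below.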

NK_INTERNAL void nk_reduce_minmax_bf16_skylake_contiguous_( //
    nk_bf16_t const *data_ptr, nk_size_t count, //
    nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i abs_mask_u16x32 = _mm512_set1_epi16(0x7FFF);
    __m512i nan_threshold_u16x32 = _mm512_set1_epi16((short)0x7F80);
    __m512i min_cmp_i16x32 = _mm512_set1_epi16((short)0x7FFF);
    __m512i max_cmp_i16x32 = _mm512_set1_epi16((short)0x8000);
    __m512i min_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i one_u16x32 = _mm512_set1_epi16(1);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m512i raw_u16x32 = _mm512_loadu_si512(data_ptr + idx);
        __m512i data_cmp_i16x32 = nk_bf16x32_to_comparable_i16x32_skylake_(raw_u16x32);
        __m512i abs_u16x32 = _mm512_and_si512(raw_u16x32, abs_mask_u16x32);
        __mmask32 not_nan_m32 = _mm512_cmp_epu16_mask(abs_u16x32, nan_threshold_u16x32, _MM_CMPINT_LE);
        __mmask32 min_changed_mask = _mm512_mask_cmp_epi16_mask(not_nan_m32, data_cmp_i16x32, min_cmp_i16x32,
                                                                _MM_CMPINT_LT);
        __mmask32 max_changed_mask = _mm512_mask_cmp_epi16_mask(not_nan_m32, data_cmp_i16x32, max_cmp_i16x32,
                                                                _MM_CMPINT_NLE);
        min_cmp_i16x32 = _mm512_mask_mov_epi16(min_cmp_i16x32, min_changed_mask, data_cmp_i16x32);
        max_cmp_i16x32 = _mm512_mask_mov_epi16(max_cmp_i16x32, max_changed_mask, data_cmp_i16x32);
        min_loop_cycle_u16x32 = _mm512_mask_mov_epi16(min_loop_cycle_u16x32, min_changed_mask,
                                                      current_loop_cycle_u16x32);
        max_loop_cycle_u16x32 = _mm512_mask_mov_epi16(max_loop_cycle_u16x32, max_changed_mask,
                                                      current_loop_cycle_u16x32);
        current_loop_cycle_u16x32 = _mm512_add_epi16(current_loop_cycle_u16x32, one_u16x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask32 tail_load_m32 = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
        __m512i raw_u16x32 = _mm512_maskz_loadu_epi16(tail_load_m32, data_ptr + idx);
        __m512i data_cmp_i16x32 = nk_bf16x32_to_comparable_i16x32_skylake_(raw_u16x32);
        __m512i abs_u16x32 = _mm512_and_si512(raw_u16x32, abs_mask_u16x32);
        __mmask32 not_nan_m32 = _mm512_cmp_epu16_mask(abs_u16x32, nan_threshold_u16x32, _MM_CMPINT_LE);
        __mmask32 valid_m32 = tail_load_m32 & not_nan_m32;
        __mmask32 min_changed_mask = _mm512_mask_cmp_epi16_mask(valid_m32, data_cmp_i16x32, min_cmp_i16x32,
                                                                _MM_CMPINT_LT);
        __mmask32 max_changed_mask = _mm512_mask_cmp_epi16_mask(valid_m32, data_cmp_i16x32, max_cmp_i16x32,
                                                                _MM_CMPINT_NLE);
        min_cmp_i16x32 = _mm512_mask_mov_epi16(min_cmp_i16x32, min_changed_mask, data_cmp_i16x32);
        max_cmp_i16x32 = _mm512_mask_mov_epi16(max_cmp_i16x32, max_changed_mask, data_cmp_i16x32);
        min_loop_cycle_u16x32 = _mm512_mask_mov_epi16(min_loop_cycle_u16x32, min_changed_mask,
                                                      current_loop_cycle_u16x32);
        max_loop_cycle_u16x32 = _mm512_mask_mov_epi16(max_loop_cycle_u16x32, max_changed_mask,
                                                      current_loop_cycle_u16x32);
    }

    nk_i16_t min_value_comparable = nk_reduce_min_i16x32_skylake_(min_cmp_i16x32);
    nk_i16_t max_value_comparable = nk_reduce_max_i16x32_skylake_(max_cmp_i16x32);
    if (min_value_comparable == 0x7FFF && max_value_comparable == (nk_i16_t)0x8000) {
        *min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }
    unsigned int min_lane, max_lane;
    {
        __mmask32 value_match_mask = _mm512_cmpeq_epi16_mask(min_cmp_i16x32, _mm512_set1_epi16(min_value_comparable));
        __m512i masked_cycle_u16x32 = _mm512_mask_blend_epi16(value_match_mask, _mm512_set1_epi16((short)NK_U16_MAX),
                                                              min_loop_cycle_u16x32);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x32_skylake_(masked_cycle_u16x32);
        __mmask32 cycle_match_mask = _mm512_cmpeq_epi16_mask(masked_cycle_u16x32,
                                                             _mm512_set1_epi16((short)earliest_loop_cycle));
        min_lane = _tzcnt_u32(cycle_match_mask);
    }
    {
        __mmask32 value_match_mask = _mm512_cmpeq_epi16_mask(max_cmp_i16x32, _mm512_set1_epi16(max_value_comparable));
        __m512i masked_cycle_u16x32 = _mm512_mask_blend_epi16(value_match_mask, _mm512_set1_epi16((short)NK_U16_MAX),
                                                              max_loop_cycle_u16x32);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x32_skylake_(masked_cycle_u16x32);
        __mmask32 cycle_match_mask = _mm512_cmpeq_epi16_mask(masked_cycle_u16x32,
                                                             _mm512_set1_epi16((short)earliest_loop_cycle));
        max_lane = _tzcnt_u32(cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u16x32;
    *min_index_ptr = (nk_size_t)loop_cycle_vec.u16s[min_lane] * 32 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u16x32;
    *max_index_ptr = (nk_size_t)loop_cycle_vec.u16s[max_lane] * 32 + max_lane;
    nk_i16_t min_sign = min_value_comparable >> 15;
    *min_value_ptr = (nk_bf16_t)((nk_u16_t)min_value_comparable ^ ((nk_u16_t)min_sign >> 1));
    nk_i16_t max_sign = max_value_comparable >> 15;
    *max_value_ptr = (nk_bf16_t)((nk_u16_t)max_value_comparable ^ ((nk_u16_t)max_sign >> 1));
}
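
// Editorial note (illustrative): index recovery runs in two stages. First the
// global extreme value is reduced out; then, among lanes holding that value,
// the earliest recorded loop cycle wins, with ties broken toward the lowest
// lane by _tzcnt_u32. Since the strict comparisons never overwrite on equal
// values, this reproduces the first-occurrence semantics of a scalar scan,
// and the flat index is cycle * 32 + lane because each iteration consumes 32
// contiguous elements.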

NK_PUBLIC void nk_reduce_minmax_bf16_skylake( //
    nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
    int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                     max_index_ptr);
    else if (stride_elements == 1 && count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_bf16_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_bf16_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                      &left_max_index);
        nk_reduce_minmax_bf16_skylake(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
                                      &right_min_index, &right_max, &right_max_index);
        if (nk_bf16_order_serial(right_min, left_min) < 0)
            *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (nk_bf16_order_serial(right_max, left_max) > 0)
            *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_bf16_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                  max_index_ptr);
    else
        nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                     max_index_ptr);
}
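
// Editorial note (a reading of the dispatch, not a documented contract): a
// stride that is not a whole number of bf16 elements (an odd byte count)
// cannot be walked with element-typed pointer arithmetic, so the !aligned
// branch above falls back to the serial kernel unconditionally.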

NK_INTERNAL void nk_reduce_moments_f16_skylake_contiguous_( //
    nk_f16_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    __m512 sum_f32x16 = _mm512_setzero_ps();
    __m512 sumsq_f32x16 = _mm512_setzero_ps();
    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m512 low_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)(data_ptr + idx)));
        __m512 high_f32x16 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i const *)(data_ptr + idx + 16)));
        sum_f32x16 = _mm512_add_ps(sum_f32x16, low_f32x16);
        sum_f32x16 = _mm512_add_ps(sum_f32x16, high_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(low_f32x16, low_f32x16, sumsq_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(high_f32x16, high_f32x16, sumsq_f32x16);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask16 low_mask_m16 = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)(remaining > 16 ? 16 : remaining));
        __m512 low_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(low_mask_m16, data_ptr + idx));
        sum_f32x16 = _mm512_add_ps(sum_f32x16, low_f32x16);
        sumsq_f32x16 = _mm512_fmadd_ps(low_f32x16, low_f32x16, sumsq_f32x16);
        if (remaining > 16) {
            __mmask16 high_mask_m16 = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)(remaining - 16));
            __m512 high_f32x16 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(high_mask_m16, data_ptr + idx + 16));
            sum_f32x16 = _mm512_add_ps(sum_f32x16, high_f32x16);
            sumsq_f32x16 = _mm512_fmadd_ps(high_f32x16, high_f32x16, sumsq_f32x16);
        }
    }
    *sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
    *sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
}
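
// Editorial note: unlike the bf16 path above, f16 widening needs a real
// format conversion (VCVTPH2PS) rather than a shift, since f16 carries 5
// exponent bits to f32's 8. The conversion is still exact: every f16 value is
// representable as an f32, so the moments match a full-precision f32 pass.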

NK_PUBLIC void nk_reduce_moments_f16_skylake( //
    nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
    int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_f16_skylake(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_f16_skylake(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                      &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum;
        *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_f16_skylake_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}

NK_INTERNAL __m512i nk_f16x32_to_comparable_i16x32_skylake_(__m512i raw_u16x32) {
    __m512i sign = _mm512_srai_epi16(raw_u16x32, 15);
    __m512i flip = _mm512_srli_epi16(sign, 1);
    return _mm512_xor_si512(raw_u16x32, flip);
}
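
// Editorial note: bit-identical to nk_bf16x32_to_comparable_i16x32_skylake_
// above; the remap depends only on the sign bit, not the exponent width. What
// differs between the formats is the NaN screen in the callers: |x| <= 0x7F80
// for bf16 (8 exponent bits) versus |x| <= 0x7C00 for f16 (5 exponent bits),
// each keeping ±infinity as an ordinary, ordered value.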

NK_INTERNAL void nk_reduce_minmax_f16_skylake_contiguous_( //
    nk_f16_t const *data_ptr, nk_size_t count, //
    nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    __m512i abs_mask_u16x32 = _mm512_set1_epi16(0x7FFF);
    __m512i nan_threshold_u16x32 = _mm512_set1_epi16((short)0x7C00);
    __m512i min_cmp_i16x32 = _mm512_set1_epi16((short)0x7FFF);
    __m512i max_cmp_i16x32 = _mm512_set1_epi16((short)0x8000);
    __m512i min_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i max_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i current_loop_cycle_u16x32 = _mm512_setzero_si512();
    __m512i one_u16x32 = _mm512_set1_epi16(1);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m512i raw_u16x32 = _mm512_loadu_si512(data_ptr + idx);
        __m512i data_cmp_i16x32 = nk_f16x32_to_comparable_i16x32_skylake_(raw_u16x32);
        __m512i abs_u16x32 = _mm512_and_si512(raw_u16x32, abs_mask_u16x32);
        __mmask32 not_nan_m32 = _mm512_cmp_epu16_mask(abs_u16x32, nan_threshold_u16x32, _MM_CMPINT_LE);
        __mmask32 min_changed_mask = _mm512_mask_cmp_epi16_mask(not_nan_m32, data_cmp_i16x32, min_cmp_i16x32,
                                                                _MM_CMPINT_LT);
        __mmask32 max_changed_mask = _mm512_mask_cmp_epi16_mask(not_nan_m32, data_cmp_i16x32, max_cmp_i16x32,
                                                                _MM_CMPINT_NLE);
        min_cmp_i16x32 = _mm512_mask_mov_epi16(min_cmp_i16x32, min_changed_mask, data_cmp_i16x32);
        max_cmp_i16x32 = _mm512_mask_mov_epi16(max_cmp_i16x32, max_changed_mask, data_cmp_i16x32);
        min_loop_cycle_u16x32 = _mm512_mask_mov_epi16(min_loop_cycle_u16x32, min_changed_mask,
                                                      current_loop_cycle_u16x32);
        max_loop_cycle_u16x32 = _mm512_mask_mov_epi16(max_loop_cycle_u16x32, max_changed_mask,
                                                      current_loop_cycle_u16x32);
        current_loop_cycle_u16x32 = _mm512_add_epi16(current_loop_cycle_u16x32, one_u16x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        __mmask32 tail_load_m32 = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
        __m512i raw_u16x32 = _mm512_maskz_loadu_epi16(tail_load_m32, data_ptr + idx);
        __m512i data_cmp_i16x32 = nk_f16x32_to_comparable_i16x32_skylake_(raw_u16x32);
        __m512i abs_u16x32 = _mm512_and_si512(raw_u16x32, abs_mask_u16x32);
        __mmask32 not_nan_m32 = _mm512_cmp_epu16_mask(abs_u16x32, nan_threshold_u16x32, _MM_CMPINT_LE);
        __mmask32 valid_m32 = tail_load_m32 & not_nan_m32;
        __mmask32 min_changed_mask = _mm512_mask_cmp_epi16_mask(valid_m32, data_cmp_i16x32, min_cmp_i16x32,
                                                                _MM_CMPINT_LT);
        __mmask32 max_changed_mask = _mm512_mask_cmp_epi16_mask(valid_m32, data_cmp_i16x32, max_cmp_i16x32,
                                                                _MM_CMPINT_NLE);
        min_cmp_i16x32 = _mm512_mask_mov_epi16(min_cmp_i16x32, min_changed_mask, data_cmp_i16x32);
        max_cmp_i16x32 = _mm512_mask_mov_epi16(max_cmp_i16x32, max_changed_mask, data_cmp_i16x32);
        min_loop_cycle_u16x32 = _mm512_mask_mov_epi16(min_loop_cycle_u16x32, min_changed_mask,
                                                      current_loop_cycle_u16x32);
        max_loop_cycle_u16x32 = _mm512_mask_mov_epi16(max_loop_cycle_u16x32, max_changed_mask,
                                                      current_loop_cycle_u16x32);
    }

    nk_i16_t min_value_comparable = nk_reduce_min_i16x32_skylake_(min_cmp_i16x32);
    nk_i16_t max_value_comparable = nk_reduce_max_i16x32_skylake_(max_cmp_i16x32);
    if (min_value_comparable == 0x7FFF && max_value_comparable == (nk_i16_t)0x8000) {
        *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }
    unsigned int min_lane, max_lane;
    {
        __mmask32 value_match_mask = _mm512_cmpeq_epi16_mask(min_cmp_i16x32, _mm512_set1_epi16(min_value_comparable));
        __m512i masked_cycle_u16x32 = _mm512_mask_blend_epi16(value_match_mask, _mm512_set1_epi16((short)NK_U16_MAX),
                                                              min_loop_cycle_u16x32);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x32_skylake_(masked_cycle_u16x32);
        __mmask32 cycle_match_mask = _mm512_cmpeq_epi16_mask(masked_cycle_u16x32,
                                                             _mm512_set1_epi16((short)earliest_loop_cycle));
        min_lane = _tzcnt_u32(cycle_match_mask);
    }
    {
        __mmask32 value_match_mask = _mm512_cmpeq_epi16_mask(max_cmp_i16x32, _mm512_set1_epi16(max_value_comparable));
        __m512i masked_cycle_u16x32 = _mm512_mask_blend_epi16(value_match_mask, _mm512_set1_epi16((short)NK_U16_MAX),
                                                              max_loop_cycle_u16x32);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x32_skylake_(masked_cycle_u16x32);
        __mmask32 cycle_match_mask = _mm512_cmpeq_epi16_mask(masked_cycle_u16x32,
                                                             _mm512_set1_epi16((short)earliest_loop_cycle));
        max_lane = _tzcnt_u32(cycle_match_mask);
    }
    nk_b512_vec_t loop_cycle_vec;
    loop_cycle_vec.zmm = min_loop_cycle_u16x32;
    *min_index_ptr = (nk_size_t)loop_cycle_vec.u16s[min_lane] * 32 + min_lane;
    loop_cycle_vec.zmm = max_loop_cycle_u16x32;
    *max_index_ptr = (nk_size_t)loop_cycle_vec.u16s[max_lane] * 32 + max_lane;
    nk_i16_t min_sign = min_value_comparable >> 15;
    *min_value_ptr = (nk_f16_t)((nk_u16_t)min_value_comparable ^ ((nk_u16_t)min_sign >> 1));
    nk_i16_t max_sign = max_value_comparable >> 15;
    *max_value_ptr = (nk_f16_t)((nk_u16_t)max_value_comparable ^ ((nk_u16_t)max_sign >> 1));
}

NK_PUBLIC void nk_reduce_minmax_f16_skylake( //
    nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
    int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_f16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (stride_elements == 1 && count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_f16_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_f16_skylake(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                     &left_max_index);
        nk_reduce_minmax_f16_skylake(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
                                     &right_min_index, &right_max, &right_max_index);
        if (nk_f16_order_serial(right_min, left_min) < 0)
            *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (nk_f16_order_serial(right_max, left_max) > 0)
            *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_f16_skylake_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_f16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}

#if defined(__clang__)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif

#if defined(__cplusplus)
} // extern "C"
#endif

#endif // NK_TARGET_SKYLAKE
#endif // NK_TARGET_X86_
#endif // NK_REDUCE_SKYLAKE_H