numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,3783 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief AVX2 implementations for the redesigned reduction API (moments + minmax).
|
|
3
|
+
* @file include/numkong/reduce/haswell.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 12, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/reduce.h
|
|
8
|
+
*
|
|
9
|
+
* @section reduce_block_caps Block-Cap Overflow Thresholds
|
|
10
|
+
*
|
|
11
|
+
* Dispatch functions use pairwise recursion when `count` exceeds a block cap.
|
|
12
|
+
* The cap is sized so the iteration counter in the contiguous kernel never wraps.
|
|
13
|
+
*
|
|
14
|
+
* Iteration counters start at 0 (initial load) and increment by 1 per SIMD chunk.
|
|
15
|
+
* A u8 counter holds 0..255 → 256 iterations → processes 256 × lanes elements.
|
|
16
|
+
* A u16 counter holds 0..65535 → 65536 iterations → processes 65536 × lanes elements.
|
|
17
|
+
* A u32 counter holds 0..4294967295 → ~4.3 billion iterations.
|
|
18
|
+
*
|
|
19
|
+
* Threshold formula: count > (COUNTER_MAX + 1) × lanes_per_chunk
|
|
20
|
+
* - u8 minmax: (NK_U8_MAX + 1) × lanes (e.g. 256 × 32 = 8192 for i8x32)
|
|
21
|
+
* - u16 minmax: (NK_U16_MAX + 1) × lanes (e.g. 65536 × 16 = 1048576 for i16x16)
|
|
22
|
+
* - u32 minmax: NK_U32_MAX × lanes (no +1: NK_U32_MAX + 1 overflows unsigned)
|
|
23
|
+
*
|
|
24
|
+
* Moments block caps are sized for accumulator overflow, not counter overflow.
|
|
25
|
+
* See individual dispatch functions for type-specific derivations.
|
|
26
|
+
*/
|
|
27
|
+
#ifndef NK_REDUCE_HASWELL_H
|
|
28
|
+
#define NK_REDUCE_HASWELL_H
|
|
29
|
+
|
|
30
|
+
#if NK_TARGET_X86_
|
|
31
|
+
#if NK_TARGET_HASWELL
|
|
32
|
+
|
|
33
|
+
#include "numkong/types.h"
|
|
34
|
+
#include "numkong/cast/haswell.h"
|
|
35
|
+
#include "numkong/reduce/serial.h"
|
|
36
|
+
|
|
37
|
+
#if defined(__cplusplus)
|
|
38
|
+
extern "C" {
|
|
39
|
+
#endif
|
|
40
|
+
|
|
41
|
+
#if defined(__clang__)
|
|
42
|
+
#pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
|
|
43
|
+
#elif defined(__GNUC__)
|
|
44
|
+
#pragma GCC push_options
|
|
45
|
+
#pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
|
|
46
|
+
#endif
|
|
47
|
+
|
|
48
|
+
/** @brief Horizontal sum of 4 doubles in a YMM register. */
|
|
49
|
+
NK_INTERNAL nk_f64_t nk_reduce_add_f64x4_haswell_(__m256d sum_f64x4) {
|
|
50
|
+
__m128d lo_f64x2 = _mm256_castpd256_pd128(sum_f64x4);
|
|
51
|
+
__m128d hi_f64x2 = _mm256_extractf128_pd(sum_f64x4, 1);
|
|
52
|
+
__m128d sum_f64x2 = _mm_add_pd(lo_f64x2, hi_f64x2);
|
|
53
|
+
sum_f64x2 = _mm_hadd_pd(sum_f64x2, sum_f64x2);
|
|
54
|
+
return _mm_cvtsd_f64(sum_f64x2);
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
/** @brief Horizontal sum of 8 floats in a YMM register (native f32 precision). */
|
|
58
|
+
NK_INTERNAL nk_f32_t nk_reduce_add_f32x8_haswell_(__m256 sum_f32x8) {
|
|
59
|
+
__m128 lo_f32x4 = _mm256_castps256_ps128(sum_f32x8);
|
|
60
|
+
__m128 hi_f32x4 = _mm256_extractf128_ps(sum_f32x8, 1);
|
|
61
|
+
__m128 sum_f32x4 = _mm_add_ps(lo_f32x4, hi_f32x4);
|
|
62
|
+
sum_f32x4 = _mm_hadd_ps(sum_f32x4, sum_f32x4);
|
|
63
|
+
sum_f32x4 = _mm_hadd_ps(sum_f32x4, sum_f32x4);
|
|
64
|
+
return _mm_cvtss_f32(sum_f32x4);
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
/** @brief Horizontal sum of 8 i32s in a YMM register. */
|
|
68
|
+
NK_INTERNAL nk_i32_t nk_reduce_add_i32x8_haswell_(__m256i sum_i32x8) {
|
|
69
|
+
__m128i lo_i32x4 = _mm256_castsi256_si128(sum_i32x8);
|
|
70
|
+
__m128i hi_i32x4 = _mm256_extracti128_si256(sum_i32x8, 1);
|
|
71
|
+
__m128i sum_i32x4 = _mm_add_epi32(lo_i32x4, hi_i32x4);
|
|
72
|
+
sum_i32x4 = _mm_hadd_epi32(sum_i32x4, sum_i32x4);
|
|
73
|
+
sum_i32x4 = _mm_hadd_epi32(sum_i32x4, sum_i32x4);
|
|
74
|
+
return _mm_cvtsi128_si32(sum_i32x4);
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
/** @brief Horizontal sum of 4 i64s in a YMM register. */
|
|
78
|
+
NK_INTERNAL nk_i64_t nk_reduce_add_i64x4_haswell_(__m256i sum_i64x4) {
|
|
79
|
+
__m128i lo_i64x2 = _mm256_castsi256_si128(sum_i64x4);
|
|
80
|
+
__m128i hi_i64x2 = _mm256_extracti128_si256(sum_i64x4, 1);
|
|
81
|
+
__m128i sum_i64x2 = _mm_add_epi64(lo_i64x2, hi_i64x2);
|
|
82
|
+
__m128i hi_lane_i64 = _mm_unpackhi_epi64(sum_i64x2, sum_i64x2);
|
|
83
|
+
__m128i final_i64 = _mm_add_epi64(sum_i64x2, hi_lane_i64);
|
|
84
|
+
return _mm_cvtsi128_si64(final_i64);
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
/** @brief Horizontal min of 32 signed i8s in a YMM register. */
NK_INTERNAL nk_i8_t nk_reduce_min_i8x32_haswell_(__m256i min_i8x32) {
    __m128i lo_i8x16 = _mm256_castsi256_si128(min_i8x32);
    __m128i hi_i8x16 = _mm256_extracti128_si256(min_i8x32, 1);
    __m128i min_i8x16 = _mm_min_epi8(lo_i8x16, hi_i8x16);
    // Log2 folding: dword swap, qword swap, then 2-byte and 1-byte shifts leave the minimum in byte 0.
    min_i8x16 = _mm_min_epi8(min_i8x16, _mm_shuffle_epi32(min_i8x16, _MM_SHUFFLE(2, 3, 0, 1)));
    min_i8x16 = _mm_min_epi8(min_i8x16, _mm_shuffle_epi32(min_i8x16, _MM_SHUFFLE(1, 0, 3, 2)));
    min_i8x16 = _mm_min_epi8(min_i8x16, _mm_srli_si128(min_i8x16, 2));
    min_i8x16 = _mm_min_epi8(min_i8x16, _mm_srli_si128(min_i8x16, 1));
    return (nk_i8_t)_mm_cvtsi128_si32(min_i8x16);
}
|
|
98
|
+
|
|
99
|
+
/** @brief Horizontal max of 32 signed i8s in a YMM register. */
NK_INTERNAL nk_i8_t nk_reduce_max_i8x32_haswell_(__m256i max_i8x32) {
    __m128i lo_i8x16 = _mm256_castsi256_si128(max_i8x32);
    __m128i hi_i8x16 = _mm256_extracti128_si256(max_i8x32, 1);
    __m128i max_i8x16 = _mm_max_epi8(lo_i8x16, hi_i8x16);
    // Log2 folding: dword swap, qword swap, then 2-byte and 1-byte shifts leave the maximum in byte 0.
    max_i8x16 = _mm_max_epi8(max_i8x16, _mm_shuffle_epi32(max_i8x16, _MM_SHUFFLE(2, 3, 0, 1)));
    max_i8x16 = _mm_max_epi8(max_i8x16, _mm_shuffle_epi32(max_i8x16, _MM_SHUFFLE(1, 0, 3, 2)));
    max_i8x16 = _mm_max_epi8(max_i8x16, _mm_srli_si128(max_i8x16, 2));
    max_i8x16 = _mm_max_epi8(max_i8x16, _mm_srli_si128(max_i8x16, 1));
    return (nk_i8_t)_mm_cvtsi128_si32(max_i8x16);
}
|
|
110
|
+
|
|
111
|
+
/** @brief Horizontal min of 32 unsigned u8s in a YMM register. */
NK_INTERNAL nk_u8_t nk_reduce_min_u8x32_haswell_(__m256i min_u8x32) {
    __m128i lo_u8x16 = _mm256_castsi256_si128(min_u8x32);
    __m128i hi_u8x16 = _mm256_extracti128_si256(min_u8x32, 1);
    __m128i min_u8x16 = _mm_min_epu8(lo_u8x16, hi_u8x16);
    // Log2 folding: dword swap, qword swap, then 2-byte and 1-byte shifts leave the minimum in byte 0.
    min_u8x16 = _mm_min_epu8(min_u8x16, _mm_shuffle_epi32(min_u8x16, _MM_SHUFFLE(2, 3, 0, 1)));
    min_u8x16 = _mm_min_epu8(min_u8x16, _mm_shuffle_epi32(min_u8x16, _MM_SHUFFLE(1, 0, 3, 2)));
    min_u8x16 = _mm_min_epu8(min_u8x16, _mm_srli_si128(min_u8x16, 2));
    min_u8x16 = _mm_min_epu8(min_u8x16, _mm_srli_si128(min_u8x16, 1));
    return (nk_u8_t)_mm_cvtsi128_si32(min_u8x16);
}
|
|
122
|
+
|
|
123
|
+
/** @brief Horizontal max of 32 unsigned u8s in a YMM register. */
NK_INTERNAL nk_u8_t nk_reduce_max_u8x32_haswell_(__m256i max_u8x32) {
    __m128i lo_u8x16 = _mm256_castsi256_si128(max_u8x32);
    __m128i hi_u8x16 = _mm256_extracti128_si256(max_u8x32, 1);
    __m128i max_u8x16 = _mm_max_epu8(lo_u8x16, hi_u8x16);
    // Log2 folding: dword swap, qword swap, then 2-byte and 1-byte shifts leave the maximum in byte 0.
    max_u8x16 = _mm_max_epu8(max_u8x16, _mm_shuffle_epi32(max_u8x16, _MM_SHUFFLE(2, 3, 0, 1)));
    max_u8x16 = _mm_max_epu8(max_u8x16, _mm_shuffle_epi32(max_u8x16, _MM_SHUFFLE(1, 0, 3, 2)));
    max_u8x16 = _mm_max_epu8(max_u8x16, _mm_srli_si128(max_u8x16, 2));
    max_u8x16 = _mm_max_epu8(max_u8x16, _mm_srli_si128(max_u8x16, 1));
    return (nk_u8_t)_mm_cvtsi128_si32(max_u8x16);
}
|
|
134
|
+
|
|
135
|
+
/** @brief Horizontal min of 16 signed i16s in a YMM register. */
|
|
136
|
+
NK_INTERNAL nk_i16_t nk_reduce_min_i16x16_haswell_(__m256i min_i16x16) {
|
|
137
|
+
__m128i lo_i16x8 = _mm256_castsi256_si128(min_i16x16);
|
|
138
|
+
__m128i hi_i16x8 = _mm256_extracti128_si256(min_i16x16, 1);
|
|
139
|
+
__m128i min_i16x8 = _mm_min_epi16(lo_i16x8, hi_i16x8);
|
|
140
|
+
min_i16x8 = _mm_min_epi16(min_i16x8, _mm_shuffle_epi32(min_i16x8, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
141
|
+
min_i16x8 = _mm_min_epi16(min_i16x8, _mm_shuffle_epi32(min_i16x8, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
142
|
+
min_i16x8 = _mm_min_epi16(min_i16x8, _mm_srli_si128(min_i16x8, 2));
|
|
143
|
+
return (nk_i16_t)_mm_cvtsi128_si32(min_i16x8);
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
/** @brief Horizontal max of 16 signed i16s in a YMM register. */
|
|
147
|
+
NK_INTERNAL nk_i16_t nk_reduce_max_i16x16_haswell_(__m256i max_i16x16) {
|
|
148
|
+
__m128i lo_i16x8 = _mm256_castsi256_si128(max_i16x16);
|
|
149
|
+
__m128i hi_i16x8 = _mm256_extracti128_si256(max_i16x16, 1);
|
|
150
|
+
__m128i max_i16x8 = _mm_max_epi16(lo_i16x8, hi_i16x8);
|
|
151
|
+
max_i16x8 = _mm_max_epi16(max_i16x8, _mm_shuffle_epi32(max_i16x8, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
152
|
+
max_i16x8 = _mm_max_epi16(max_i16x8, _mm_shuffle_epi32(max_i16x8, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
153
|
+
max_i16x8 = _mm_max_epi16(max_i16x8, _mm_srli_si128(max_i16x8, 2));
|
|
154
|
+
return (nk_i16_t)_mm_cvtsi128_si32(max_i16x8);
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
/** @brief Horizontal min of 16 unsigned u16s in a YMM register. */
|
|
158
|
+
NK_INTERNAL nk_u16_t nk_reduce_min_u16x16_haswell_(__m256i min_u16x16) {
|
|
159
|
+
__m128i lo_u16x8 = _mm256_castsi256_si128(min_u16x16);
|
|
160
|
+
__m128i hi_u16x8 = _mm256_extracti128_si256(min_u16x16, 1);
|
|
161
|
+
__m128i min_u16x8 = _mm_min_epu16(lo_u16x8, hi_u16x8);
|
|
162
|
+
min_u16x8 = _mm_min_epu16(min_u16x8, _mm_shuffle_epi32(min_u16x8, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
163
|
+
min_u16x8 = _mm_min_epu16(min_u16x8, _mm_shuffle_epi32(min_u16x8, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
164
|
+
min_u16x8 = _mm_min_epu16(min_u16x8, _mm_srli_si128(min_u16x8, 2));
|
|
165
|
+
return (nk_u16_t)_mm_cvtsi128_si32(min_u16x8);
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
/** @brief Horizontal max of 16 unsigned u16s in a YMM register. */
|
|
169
|
+
NK_INTERNAL nk_u16_t nk_reduce_max_u16x16_haswell_(__m256i max_u16x16) {
|
|
170
|
+
__m128i lo_u16x8 = _mm256_castsi256_si128(max_u16x16);
|
|
171
|
+
__m128i hi_u16x8 = _mm256_extracti128_si256(max_u16x16, 1);
|
|
172
|
+
__m128i max_u16x8 = _mm_max_epu16(lo_u16x8, hi_u16x8);
|
|
173
|
+
max_u16x8 = _mm_max_epu16(max_u16x8, _mm_shuffle_epi32(max_u16x8, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
174
|
+
max_u16x8 = _mm_max_epu16(max_u16x8, _mm_shuffle_epi32(max_u16x8, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
175
|
+
max_u16x8 = _mm_max_epu16(max_u16x8, _mm_srli_si128(max_u16x8, 2));
|
|
176
|
+
return (nk_u16_t)_mm_cvtsi128_si32(max_u16x8);
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
/** @brief Horizontal min of 8 signed i32s in a YMM register. */
|
|
180
|
+
NK_INTERNAL nk_i32_t nk_reduce_min_i32x8_haswell_(__m256i min_i32x8) {
|
|
181
|
+
__m128i lo_i32x4 = _mm256_castsi256_si128(min_i32x8);
|
|
182
|
+
__m128i hi_i32x4 = _mm256_extracti128_si256(min_i32x8, 1);
|
|
183
|
+
__m128i min_i32x4 = _mm_min_epi32(lo_i32x4, hi_i32x4);
|
|
184
|
+
min_i32x4 = _mm_min_epi32(min_i32x4, _mm_shuffle_epi32(min_i32x4, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
185
|
+
min_i32x4 = _mm_min_epi32(min_i32x4, _mm_shuffle_epi32(min_i32x4, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
186
|
+
return _mm_cvtsi128_si32(min_i32x4);
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
/** @brief Horizontal max of 8 signed i32s in a YMM register. */
|
|
190
|
+
NK_INTERNAL nk_i32_t nk_reduce_max_i32x8_haswell_(__m256i max_i32x8) {
|
|
191
|
+
__m128i lo_i32x4 = _mm256_castsi256_si128(max_i32x8);
|
|
192
|
+
__m128i hi_i32x4 = _mm256_extracti128_si256(max_i32x8, 1);
|
|
193
|
+
__m128i max_i32x4 = _mm_max_epi32(lo_i32x4, hi_i32x4);
|
|
194
|
+
max_i32x4 = _mm_max_epi32(max_i32x4, _mm_shuffle_epi32(max_i32x4, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
195
|
+
max_i32x4 = _mm_max_epi32(max_i32x4, _mm_shuffle_epi32(max_i32x4, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
196
|
+
return _mm_cvtsi128_si32(max_i32x4);
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
/** @brief Horizontal min of 8 unsigned u32s in a YMM register. */
|
|
200
|
+
NK_INTERNAL nk_u32_t nk_reduce_min_u32x8_haswell_(__m256i min_u32x8) {
|
|
201
|
+
__m128i lo_u32x4 = _mm256_castsi256_si128(min_u32x8);
|
|
202
|
+
__m128i hi_u32x4 = _mm256_extracti128_si256(min_u32x8, 1);
|
|
203
|
+
__m128i min_u32x4 = _mm_min_epu32(lo_u32x4, hi_u32x4);
|
|
204
|
+
min_u32x4 = _mm_min_epu32(min_u32x4, _mm_shuffle_epi32(min_u32x4, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
205
|
+
min_u32x4 = _mm_min_epu32(min_u32x4, _mm_shuffle_epi32(min_u32x4, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
206
|
+
return (nk_u32_t)_mm_cvtsi128_si32(min_u32x4);
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
/** @brief Horizontal max of 8 unsigned u32s in a YMM register. */
|
|
210
|
+
NK_INTERNAL nk_u32_t nk_reduce_max_u32x8_haswell_(__m256i max_u32x8) {
|
|
211
|
+
__m128i lo_u32x4 = _mm256_castsi256_si128(max_u32x8);
|
|
212
|
+
__m128i hi_u32x4 = _mm256_extracti128_si256(max_u32x8, 1);
|
|
213
|
+
__m128i max_u32x4 = _mm_max_epu32(lo_u32x4, hi_u32x4);
|
|
214
|
+
max_u32x4 = _mm_max_epu32(max_u32x4, _mm_shuffle_epi32(max_u32x4, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
215
|
+
max_u32x4 = _mm_max_epu32(max_u32x4, _mm_shuffle_epi32(max_u32x4, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
216
|
+
return (nk_u32_t)_mm_cvtsi128_si32(max_u32x4);
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
/** @brief Horizontal min of 4 signed i64s in a YMM register using comparison+blend.
 *
 *  AVX2 has no 64-bit min instruction, so `lo > hi` is computed with _mm_cmpgt_epi64 and
 *  the smaller operand selected with _mm_blendv_epi8 (blendv picks the SECOND argument
 *  where the mask is set, hence `hi` goes second when `lo` is greater). */
NK_INTERNAL nk_i64_t nk_reduce_min_i64x4_haswell_(__m256i min_i64x4) {
    __m128i lo_i64x2 = _mm256_castsi256_si128(min_i64x4);
    __m128i hi_i64x2 = _mm256_extracti128_si256(min_i64x4, 1);
    __m128i cmp_i64x2 = _mm_cmpgt_epi64(lo_i64x2, hi_i64x2);
    __m128i min_i64x2 = _mm_blendv_epi8(lo_i64x2, hi_i64x2, cmp_i64x2);
    // Duplicate the high qword into both halves and repeat the compare+blend for the final scalar.
    __m128i hi_lane_i64 = _mm_unpackhi_epi64(min_i64x2, min_i64x2);
    __m128i cmp_final = _mm_cmpgt_epi64(min_i64x2, hi_lane_i64);
    __m128i result_i64 = _mm_blendv_epi8(min_i64x2, hi_lane_i64, cmp_final);
    return _mm_cvtsi128_si64(result_i64);
}
|
|
230
|
+
|
|
231
|
+
/** @brief Horizontal max of 4 signed i64s in a YMM register using comparison+blend.
 *
 *  AVX2 has no 64-bit max instruction, so `lo > hi` is computed with _mm_cmpgt_epi64 and
 *  the larger operand selected with _mm_blendv_epi8 (blendv picks the SECOND argument
 *  where the mask is set, hence `lo` goes second when it is greater). */
NK_INTERNAL nk_i64_t nk_reduce_max_i64x4_haswell_(__m256i max_i64x4) {
    __m128i lo_i64x2 = _mm256_castsi256_si128(max_i64x4);
    __m128i hi_i64x2 = _mm256_extracti128_si256(max_i64x4, 1);
    __m128i cmp_i64x2 = _mm_cmpgt_epi64(lo_i64x2, hi_i64x2);
    __m128i max_i64x2 = _mm_blendv_epi8(hi_i64x2, lo_i64x2, cmp_i64x2);
    // Duplicate the high qword into both halves and repeat the compare+blend for the final scalar.
    __m128i hi_lane_i64 = _mm_unpackhi_epi64(max_i64x2, max_i64x2);
    __m128i cmp_final = _mm_cmpgt_epi64(max_i64x2, hi_lane_i64);
    __m128i result_i64 = _mm_blendv_epi8(hi_lane_i64, max_i64x2, cmp_final);
    return _mm_cvtsi128_si64(result_i64);
}
|
|
242
|
+
|
|
243
|
+
/** @brief Horizontal min of 4 unsigned u64s in a YMM register using XOR trick for unsigned comparison.
 *
 *  There is no unsigned 64-bit compare in AVX2: XORing both operands with 0x8000000000000000
 *  flips the sign bit, mapping unsigned order onto signed order, so _mm_cmpgt_epi64 can be used. */
NK_INTERNAL nk_u64_t nk_reduce_min_u64x4_haswell_(__m256i min_u64x4) {
    __m128i sign_bit_i64 = _mm_set1_epi64x((nk_i64_t)0x8000000000000000ull);
    __m128i lo_u64x2 = _mm256_castsi256_si128(min_u64x4);
    __m128i hi_u64x2 = _mm256_extracti128_si256(min_u64x4, 1);
    __m128i cmp_i64x2 = _mm_cmpgt_epi64(_mm_xor_si128(lo_u64x2, sign_bit_i64), _mm_xor_si128(hi_u64x2, sign_bit_i64));
    // blendv picks the second argument where the mask is set: keep `hi` when `lo` is greater.
    __m128i min_u64x2 = _mm_blendv_epi8(lo_u64x2, hi_u64x2, cmp_i64x2);
    __m128i hi_lane_u64 = _mm_unpackhi_epi64(min_u64x2, min_u64x2);
    __m128i cmp_final = _mm_cmpgt_epi64(_mm_xor_si128(min_u64x2, sign_bit_i64),
                                        _mm_xor_si128(hi_lane_u64, sign_bit_i64));
    __m128i result_u64 = _mm_blendv_epi8(min_u64x2, hi_lane_u64, cmp_final);
    return (nk_u64_t)_mm_cvtsi128_si64(result_u64);
}
|
|
256
|
+
|
|
257
|
+
/** @brief Horizontal max of 4 unsigned u64s in a YMM register using XOR trick for unsigned comparison.
 *
 *  There is no unsigned 64-bit compare in AVX2: XORing both operands with 0x8000000000000000
 *  flips the sign bit, mapping unsigned order onto signed order, so _mm_cmpgt_epi64 can be used. */
NK_INTERNAL nk_u64_t nk_reduce_max_u64x4_haswell_(__m256i max_u64x4) {
    __m128i sign_bit_i64 = _mm_set1_epi64x((nk_i64_t)0x8000000000000000ull);
    __m128i lo_u64x2 = _mm256_castsi256_si128(max_u64x4);
    __m128i hi_u64x2 = _mm256_extracti128_si256(max_u64x4, 1);
    __m128i cmp_i64x2 = _mm_cmpgt_epi64(_mm_xor_si128(lo_u64x2, sign_bit_i64), _mm_xor_si128(hi_u64x2, sign_bit_i64));
    // blendv picks the second argument where the mask is set: keep `lo` when it is greater.
    __m128i max_u64x2 = _mm_blendv_epi8(hi_u64x2, lo_u64x2, cmp_i64x2);
    __m128i hi_lane_u64 = _mm_unpackhi_epi64(max_u64x2, max_u64x2);
    __m128i cmp_final = _mm_cmpgt_epi64(_mm_xor_si128(max_u64x2, sign_bit_i64),
                                        _mm_xor_si128(hi_lane_u64, sign_bit_i64));
    __m128i result_u64 = _mm_blendv_epi8(hi_lane_u64, max_u64x2, cmp_final);
    return (nk_u64_t)_mm_cvtsi128_si64(result_u64);
}
|
|
270
|
+
|
|
271
|
+
/** @brief Horizontal min of 8 floats in a YMM register. */
|
|
272
|
+
NK_INTERNAL nk_f32_t nk_reduce_min_f32x8_haswell_(__m256 min_f32x8) {
|
|
273
|
+
__m128 lo_f32x4 = _mm256_castps256_ps128(min_f32x8);
|
|
274
|
+
__m128 hi_f32x4 = _mm256_extractf128_ps(min_f32x8, 1);
|
|
275
|
+
__m128 min_f32x4 = _mm_min_ps(lo_f32x4, hi_f32x4);
|
|
276
|
+
min_f32x4 = _mm_min_ps(min_f32x4, _mm_shuffle_ps(min_f32x4, min_f32x4, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
277
|
+
min_f32x4 = _mm_min_ps(min_f32x4, _mm_shuffle_ps(min_f32x4, min_f32x4, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
278
|
+
return _mm_cvtss_f32(min_f32x4);
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
/** @brief Horizontal max of 8 floats in a YMM register. */
|
|
282
|
+
NK_INTERNAL nk_f32_t nk_reduce_max_f32x8_haswell_(__m256 max_f32x8) {
|
|
283
|
+
__m128 lo_f32x4 = _mm256_castps256_ps128(max_f32x8);
|
|
284
|
+
__m128 hi_f32x4 = _mm256_extractf128_ps(max_f32x8, 1);
|
|
285
|
+
__m128 max_f32x4 = _mm_max_ps(lo_f32x4, hi_f32x4);
|
|
286
|
+
max_f32x4 = _mm_max_ps(max_f32x4, _mm_shuffle_ps(max_f32x4, max_f32x4, _MM_SHUFFLE(2, 3, 0, 1)));
|
|
287
|
+
max_f32x4 = _mm_max_ps(max_f32x4, _mm_shuffle_ps(max_f32x4, max_f32x4, _MM_SHUFFLE(1, 0, 3, 2)));
|
|
288
|
+
return _mm_cvtss_f32(max_f32x4);
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
/** @brief Horizontal min of 4 doubles in a YMM register. */
|
|
292
|
+
NK_INTERNAL nk_f64_t nk_reduce_min_f64x4_haswell_(__m256d min_f64x4) {
|
|
293
|
+
__m128d lo_f64x2 = _mm256_castpd256_pd128(min_f64x4);
|
|
294
|
+
__m128d hi_f64x2 = _mm256_extractf128_pd(min_f64x4, 1);
|
|
295
|
+
__m128d min_f64x2 = _mm_min_pd(lo_f64x2, hi_f64x2);
|
|
296
|
+
min_f64x2 = _mm_min_pd(min_f64x2, _mm_shuffle_pd(min_f64x2, min_f64x2, 1));
|
|
297
|
+
return _mm_cvtsd_f64(min_f64x2);
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
/** @brief Horizontal max of 4 doubles in a YMM register. */
|
|
301
|
+
NK_INTERNAL nk_f64_t nk_reduce_max_f64x4_haswell_(__m256d max_f64x4) {
|
|
302
|
+
__m128d lo_f64x2 = _mm256_castpd256_pd128(max_f64x4);
|
|
303
|
+
__m128d hi_f64x2 = _mm256_extractf128_pd(max_f64x4, 1);
|
|
304
|
+
__m128d max_f64x2 = _mm_max_pd(lo_f64x2, hi_f64x2);
|
|
305
|
+
max_f64x2 = _mm_max_pd(max_f64x2, _mm_shuffle_pd(max_f64x2, max_f64x2, 1));
|
|
306
|
+
return _mm_cvtsd_f64(max_f64x2);
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
/** @brief Remaps 32 FP8 bit patterns to unsigned bytes whose unsigned order matches the FP8 order.
 *
 *  Bytes with the sign bit set (negative) are bitwise inverted (XOR 0xFF); the rest get the
 *  sign bit set (XOR 0x80). Inverse of nk_u8x32_comparable_to_fp8x32_haswell_.
 *  NOTE(review): assumes FP8 values are sign-magnitude encoded — confirm against callers. */
NK_INTERNAL __m256i nk_fp8x32_to_u8x32_comparable_haswell_(__m256i raw_i8x32) {
    // In AVX2, use signed comparison: 0 > x means x < 0 (negative)
    __m256i neg_i8x32 = _mm256_cmpgt_epi8(_mm256_setzero_si256(), raw_i8x32);
    __m256i pos_xor_i8x32 = _mm256_set1_epi8((char)0x80);
    __m256i neg_xor_i8x32 = _mm256_set1_epi8((char)0xFF);
    // blendv selects the 0xFF mask for negative bytes, 0x80 for the rest.
    __m256i xor_i8x32 = _mm256_blendv_epi8(pos_xor_i8x32, neg_xor_i8x32, neg_i8x32);
    return _mm256_xor_si256(raw_i8x32, xor_i8x32);
}
|
|
317
|
+
|
|
318
|
+
/** @brief Inverse of nk_fp8x32_to_u8x32_comparable_haswell_: comparable bytes back to raw FP8. */
NK_INTERNAL __m256i nk_u8x32_comparable_to_fp8x32_haswell_(__m256i cmp_i8x32) {
    // Bytes with the top bit clear encoded negative FP8 values (undo the full inversion);
    // bytes with it set encoded positive values (undo the 0x80 bias).
    __m256i top_bit_i8x32 = _mm256_set1_epi8((char)0x80);
    __m256i invert_all_i8x32 = _mm256_set1_epi8((char)0xFF);
    __m256i top_clear_i8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(cmp_i8x32, top_bit_i8x32), _mm256_setzero_si256());
    __m256i undo_i8x32 = _mm256_blendv_epi8(top_bit_i8x32, invert_all_i8x32, top_clear_i8x32);
    return _mm256_xor_si256(cmp_i8x32, undo_i8x32);
}
|
|
327
|
+
|
|
328
|
+
/** @brief Horizontal argmin: returns index of first minimum unsigned byte in YMM register. */
|
|
329
|
+
NK_INTERNAL nk_size_t nk_argmin_u8x32_haswell_(__m256i data_u8x32) {
|
|
330
|
+
nk_u8_t min_val = nk_reduce_min_u8x32_haswell_(data_u8x32);
|
|
331
|
+
__m256i eq_i8x32 = _mm256_cmpeq_epi8(data_u8x32, _mm256_set1_epi8((char)min_val));
|
|
332
|
+
int eq_bits = _mm256_movemask_epi8(eq_i8x32);
|
|
333
|
+
return (nk_size_t)_tzcnt_u32((unsigned int)eq_bits);
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
/** @brief Horizontal argmax: returns index of first maximum unsigned byte in YMM register. */
|
|
337
|
+
NK_INTERNAL nk_size_t nk_argmax_u8x32_haswell_(__m256i data_u8x32) {
|
|
338
|
+
nk_u8_t max_val = nk_reduce_max_u8x32_haswell_(data_u8x32);
|
|
339
|
+
__m256i eq_i8x32 = _mm256_cmpeq_epi8(data_u8x32, _mm256_set1_epi8((char)max_val));
|
|
340
|
+
int eq_bits = _mm256_movemask_epi8(eq_i8x32);
|
|
341
|
+
return (nk_size_t)_tzcnt_u32((unsigned int)eq_bits);
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
/** @brief Maps 16 raw bf16 words to signed 16-bit integers that compare in the same order. */
NK_INTERNAL __m256i nk_bf16x16_to_comparable_i16x16_haswell_(__m256i raw_u16x16) {
    // Arithmetic shift broadcasts the sign across each word; a logical shift by one
    // turns that into a 0x7FFF mask for negative lanes (0x0000 for positive ones),
    // so negative values get their magnitude bits inverted — monotone under signed compare.
    __m256i sign_broadcast_i16x16 = _mm256_srai_epi16(raw_u16x16, 15);
    __m256i magnitude_mask_i16x16 = _mm256_srli_epi16(sign_broadcast_i16x16, 1);
    return _mm256_xor_si256(raw_u16x16, magnitude_mask_i16x16);
}
|
|
349
|
+
|
|
350
|
+
/** @brief Maps 16 raw f16 words to signed 16-bit integers that compare in the same order. */
NK_INTERNAL __m256i nk_f16x16_to_comparable_i16x16_haswell_(__m256i raw_u16x16) {
    // Same sign-magnitude trick as the bf16 variant: invert the magnitude bits of
    // negative lanes (xor 0x7FFF), leave positive lanes untouched (xor 0x0000).
    __m256i sign_fill_i16x16 = _mm256_srai_epi16(raw_u16x16, 15);
    __m256i negative_fix_i16x16 = _mm256_srli_epi16(sign_fill_i16x16, 1);
    return _mm256_xor_si256(raw_u16x16, negative_fix_i16x16);
}
|
|
355
|
+
|
|
356
|
+
/** @brief Lane-wise saturating unsigned 64-bit addition; wrapped lanes become U64_MAX. */
NK_INTERNAL __m256i nk_u64_sadd_epi64_haswell_(__m256i a_u64x4, __m256i b_u64x4) {
    __m256i sum_u64x4 = _mm256_add_epi64(a_u64x4, b_u64x4);
    // Unsigned overflow iff sum < a, but AVX2 only has signed 64-bit compares:
    // biasing both operands by 2^63 makes the signed compare behave unsigned.
    __m256i bias_i64x4 = _mm256_set1_epi64x((nk_i64_t)0x8000000000000000ULL);
    __m256i a_shifted_i64x4 = _mm256_xor_si256(a_u64x4, bias_i64x4);
    __m256i sum_shifted_i64x4 = _mm256_xor_si256(sum_u64x4, bias_i64x4);
    __m256i wrapped_i64x4 = _mm256_cmpgt_epi64(a_shifted_i64x4, sum_shifted_i64x4);
    // OR with the all-ones overflow mask saturates wrapped lanes to U64_MAX.
    return _mm256_or_si256(sum_u64x4, wrapped_i64x4);
}
|
|
365
|
+
|
|
366
|
+
/** @brief Lane-wise saturating signed 64-bit square; overflowing lanes become I64_MAX. */
NK_INTERNAL __m256i nk_i64_smul_sq_epi64_haswell_(__m256i val_i64x4) {
    // abs(val) — AVX2 lacks _mm256_abs_epi64, emulate via the xor/sub sign trick.
    __m256i sign_i64x4 = _mm256_cmpgt_epi64(_mm256_setzero_si256(), val_i64x4);
    __m256i abs_val_u64x4 = _mm256_sub_epi64(_mm256_xor_si256(val_i64x4, sign_i64x4), sign_i64x4);
    // Square the low 32 bits: _mm256_mul_epu32 multiplies even 32-bit lanes -> 64-bit products.
    __m256i low_halves_u32x4 = _mm256_and_si256(abs_val_u64x4, _mm256_set1_epi64x(0xFFFFFFFF));
    __m256i low_squared_u64x4 = _mm256_mul_epu32(low_halves_u32x4, low_halves_u32x4);
    // The square fits an i64 only when |val| < 2^32 (high 32 bits zero) AND the 64-bit
    // product does not exceed I64_MAX (its top bit is clear). The second condition
    // matters for |val| in [3037000500, 2^32): high bits are zero yet |val|^2 > I64_MAX,
    // and the previous high-bits-only check returned the wrapped (negative) product.
    __m256i high_bits_u64x4 = _mm256_srli_epi64(abs_val_u64x4, 32);
    __m256i is_small_u64x4 = _mm256_cmpeq_epi64(high_bits_u64x4, _mm256_setzero_si256());
    __m256i square_overflow_u64x4 = _mm256_cmpgt_epi64(_mm256_setzero_si256(), low_squared_u64x4);
    __m256i fits_u64x4 = _mm256_andnot_si256(square_overflow_u64x4, is_small_u64x4);
    // Saturate: I64_MAX on overflow (the true square is always non-negative).
    __m256i saturated_u64x4 = _mm256_set1_epi64x(NK_I64_MAX);
    return _mm256_blendv_epi8(saturated_u64x4, low_squared_u64x4, fits_u64x4);
}
|
|
380
|
+
|
|
381
|
+
/** @brief Lane-wise saturating unsigned 64-bit square; overflowing lanes become U64_MAX. */
NK_INTERNAL __m256i nk_u64_smul_sq_epi64_haswell_(__m256i val_u64x4) {
    // For values below 2^32 the full product fits exactly: (2^32 - 1)^2 < 2^64.
    __m256i low_mask_u64x4 = _mm256_set1_epi64x(0xFFFFFFFF);
    __m256i low_u32x4 = _mm256_and_si256(val_u64x4, low_mask_u64x4);
    __m256i squared_u64x4 = _mm256_mul_epu32(low_u32x4, low_u32x4);
    // Any set bit in the upper 32 bits means the square cannot fit 64 bits.
    __m256i high_u64x4 = _mm256_srli_epi64(val_u64x4, 32);
    __m256i fits_u64x4 = _mm256_cmpeq_epi64(high_u64x4, _mm256_setzero_si256());
    __m256i all_ones_u64x4 = _mm256_set1_epi64x((nk_i64_t)-1);
    return _mm256_blendv_epi8(all_ones_u64x4, squared_u64x4, fits_u64x4);
}
|
|
389
|
+
|
|
390
|
+
/** @brief Horizontal saturating unsigned sum of 4 u64 lanes; each fold saturates to U64_MAX on wrap. */
NK_INTERNAL nk_u64_t nk_reduce_sadd_u64x4_haswell_(__m256i v_u64x4) {
    // 4->2: fold high 128 into low 128
    __m128i high_u64x2 = _mm256_extracti128_si256(v_u64x4, 1);
    __m128i low_u64x2 = _mm256_castsi256_si128(v_u64x4);
    __m128i sum_u64x2 = _mm_add_epi64(low_u64x2, high_u64x2);
    // Unsigned overflow detection (sum < addend) via signed compare after biasing by 2^63.
    __m128i sign_bit_i64x2 = _mm_set1_epi64x((nk_i64_t)0x8000000000000000ULL);
    __m128i low_biased_i64x2 = _mm_xor_si128(low_u64x2, sign_bit_i64x2);
    __m128i sum_biased_i64x2 = _mm_xor_si128(sum_u64x2, sign_bit_i64x2);
    __m128i overflow_u64x2 = _mm_cmpgt_epi64(low_biased_i64x2, sum_biased_i64x2);
    // OR with the all-ones overflow mask saturates wrapped lanes to U64_MAX.
    sum_u64x2 = _mm_or_si128(sum_u64x2, overflow_u64x2);
    // 2->1: fold lane 1 into lane 0
    __m128i swapped_u64x2 = _mm_unpackhi_epi64(sum_u64x2, sum_u64x2);
    __m128i final_u64x2 = _mm_add_epi64(sum_u64x2, swapped_u64x2);
    // Same biased-compare overflow check for the last fold.
    __m128i sum2_biased_i64x2 = _mm_xor_si128(sum_u64x2, sign_bit_i64x2);
    __m128i final_biased_i64x2 = _mm_xor_si128(final_u64x2, sign_bit_i64x2);
    __m128i overflow2_u64x2 = _mm_cmpgt_epi64(sum2_biased_i64x2, final_biased_i64x2);
    final_u64x2 = _mm_or_si128(final_u64x2, overflow2_u64x2);
    return (nk_u64_t)_mm_cvtsi128_si64(final_u64x2);
}
|
|
409
|
+
|
|
410
|
+
/** @brief Byte blend mask with -1 at byte positions 0, stride, 2*stride, ... within 32 bytes.
 *  Used to select the first element of each strided group out of a contiguous 32-byte load.
 *  Supports strides 2..8; any other stride yields an all-zero mask. */
NK_INTERNAL __m256i nk_stride_blend_u1x32_(nk_size_t stride) {
    switch (stride) {
    case 2:
        return _mm256_setr_epi8(-1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1,
                                0, -1, 0, -1, 0, -1, 0);
    case 3:
        return _mm256_setr_epi8(-1, 0, 0, -1, 0, 0, -1, 0, 0, -1, 0, 0, -1, 0, 0, -1, 0, 0, -1, 0, 0, -1, 0, 0, -1, 0,
                                0, -1, 0, 0, -1, 0);
    case 4:
        return _mm256_setr_epi8(-1, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0,
                                0, -1, 0, 0, 0);
    case 5:
        return _mm256_setr_epi8(-1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0,
                                0, 0, 0, -1, 0);
    case 6:
        return _mm256_setr_epi8(-1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, -1, 0, 0, 0,
                                0, 0, -1, 0);
    case 7:
        return _mm256_setr_epi8(-1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0,
                                -1, 0, 0, 0);
    case 8:
        return _mm256_setr_epi8(-1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0,
                                0, 0, 0, 0);
    default: return _mm256_setzero_si256();
    }
}
|
|
436
|
+
|
|
437
|
+
/** @brief 16-bit-lane blend mask with -1 at lanes 0, stride, 2*stride, ... within 16 lanes.
 *  Supports strides 2..8; any other stride yields an all-zero mask. */
NK_INTERNAL __m256i nk_stride_blend_b16x16_(nk_size_t stride) {
    switch (stride) {
    case 2: return _mm256_setr_epi16(-1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0);
    case 3: return _mm256_setr_epi16(-1, 0, 0, -1, 0, 0, -1, 0, 0, -1, 0, 0, -1, 0, 0, -1);
    case 4: return _mm256_setr_epi16(-1, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0, 0, -1, 0, 0, 0);
    case 5: return _mm256_setr_epi16(-1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1);
    case 6: return _mm256_setr_epi16(-1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, -1, 0, 0, 0);
    case 7: return _mm256_setr_epi16(-1, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, -1, 0);
    case 8: return _mm256_setr_epi16(-1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0);
    default: return _mm256_setzero_si256();
    }
}
|
|
449
|
+
|
|
450
|
+
/** @brief 32-bit-lane blend mask with -1 at lanes 0, stride, 2*stride, ... within 8 lanes.
 *  The per-case comment gives the number of selected elements (see nk_stride_elems_b32x8_).
 *  Supports strides 2..8; any other stride yields an all-zero mask. */
NK_INTERNAL __m256i nk_stride_blend_b32x8_(nk_size_t stride) {
    switch (stride) {
    case 2: return _mm256_setr_epi32(-1, 0, -1, 0, -1, 0, -1, 0); // 4 elems
    case 3: return _mm256_setr_epi32(-1, 0, 0, -1, 0, 0, -1, 0); // 3 elems
    case 4: return _mm256_setr_epi32(-1, 0, 0, 0, -1, 0, 0, 0); // 2 elems
    case 5: return _mm256_setr_epi32(-1, 0, 0, 0, 0, -1, 0, 0); // 2 elems
    case 6: return _mm256_setr_epi32(-1, 0, 0, 0, 0, 0, -1, 0); // 2 elems
    case 7: return _mm256_setr_epi32(-1, 0, 0, 0, 0, 0, 0, -1); // 2 elems
    case 8: return _mm256_setr_epi32(-1, 0, 0, 0, 0, 0, 0, 0); // 1 elem
    default: return _mm256_setzero_si256();
    }
}
|
|
462
|
+
|
|
463
|
+
/** @brief 64-bit-lane blend mask with -1 at lanes 0, stride, ... within 4 lanes.
 *  Unlike the narrower variants, strides above 4 fall back to a single-element mask
 *  rather than an all-zero mask. */
NK_INTERNAL __m256i nk_stride_blend_b64x4_(nk_size_t stride) {
    switch (stride) {
    case 2: return _mm256_setr_epi64x(-1, 0, -1, 0); // 2 elems
    case 3: return _mm256_setr_epi64x(-1, 0, 0, -1); // 2 elems (wraps)
    case 4: return _mm256_setr_epi64x(-1, 0, 0, 0); // 1 elem
    default: return _mm256_setr_epi64x(-1, 0, 0, 0); // 1 elem for stride 5+
    }
}
|
|
471
|
+
|
|
472
|
+
/** @brief Count of -1 lanes in nk_stride_blend_b32x8_(stride): how many strided
 *  elements one 8-lane 32-bit vector covers. Zero for unsupported strides. */
NK_INTERNAL nk_size_t nk_stride_elems_b32x8_(nk_size_t stride) {
    // For strides 2..8 this is ceil(8 / stride); anything else is unsupported.
    if (stride < 2 || stride > 8) return 0;
    if (stride == 2) return 4;
    if (stride == 3) return 3;
    return stride == 8 ? 1 : 2;
}
|
|
484
|
+
|
|
485
|
+
/** @brief Count of -1 lanes in nk_stride_blend_b64x4_(stride): how many strided
 *  elements one 4-lane 64-bit vector covers. */
NK_INTERNAL nk_size_t nk_stride_elems_b64x4_(nk_size_t stride) {
    // Strides 2 and 3 pack two elements per vector; everything else one.
    return (stride == 2 || stride == 3) ? 2 : 1;
}
|
|
493
|
+
|
|
494
|
+
/** @brief Sum and sum-of-squares of `count` contiguous f32 values, accumulated in f64.
 *  Two independent accumulator pairs (low/high halves of each 8-float load) shorten
 *  the FP dependency chain; the scalar tail handles the last count % 8 elements. */
NK_INTERNAL void nk_reduce_moments_f32_haswell_contiguous_( //
    nk_f32_t const *data_ptr, nk_size_t count, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {

    __m256d sum_low_f64x4 = _mm256_setzero_pd(), sum_high_f64x4 = _mm256_setzero_pd();
    __m256d sumsq_low_f64x4 = _mm256_setzero_pd(), sumsq_high_f64x4 = _mm256_setzero_pd();
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        // Widen each 4-float half to doubles before accumulating.
        __m256d low_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(data_ptr + idx));
        __m256d high_f64x4 = _mm256_cvtps_pd(_mm_loadu_ps(data_ptr + idx + 4));
        sum_low_f64x4 = _mm256_add_pd(sum_low_f64x4, low_f64x4);
        sum_high_f64x4 = _mm256_add_pd(sum_high_f64x4, high_f64x4);
        sumsq_low_f64x4 = _mm256_fmadd_pd(low_f64x4, low_f64x4, sumsq_low_f64x4);
        sumsq_high_f64x4 = _mm256_fmadd_pd(high_f64x4, high_f64x4, sumsq_high_f64x4);
    }
    __m256d sum_f64x4 = _mm256_add_pd(sum_low_f64x4, sum_high_f64x4);
    __m256d sumsq_f64x4 = _mm256_add_pd(sumsq_low_f64x4, sumsq_high_f64x4);
    nk_f64_t sum = nk_reduce_add_f64x4_haswell_(sum_f64x4);
    nk_f64_t sumsq = nk_reduce_add_f64x4_haswell_(sumsq_f64x4);
    // Scalar tail for the remaining count % 8 elements.
    for (; idx < count; ++idx) {
        nk_f64_t val = (nk_f64_t)data_ptr[idx];
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
519
|
+
|
|
520
|
+
/** @brief Sum and sum-of-squares of `count` f32 values spaced `stride_elements` apart
 *  (strides 2..8). Loads 8 contiguous floats per step and blends away non-element lanes,
 *  advancing by round_up(8, stride) so lane positions repeat exactly each iteration.
 *  NOTE(review): each iteration reads 8 contiguous floats even though only some are
 *  elements, so it may read up to stride-1 floats past the final element — assumes the
 *  strided buffer extends to count * stride_elements floats; confirm with callers. */
NK_INTERNAL void nk_reduce_moments_f32_haswell_strided_( //
    nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {

    __m256i blend_mask_i32x8 = nk_stride_blend_b32x8_(stride_elements);
    __m128 blend_low_f32x4 = _mm_castsi128_ps(_mm256_castsi256_si128(blend_mask_i32x8));
    __m128 blend_high_f32x4 = _mm_castsi128_ps(_mm256_extracti128_si256(blend_mask_i32x8, 1));
    __m128 zero_f32x4 = _mm_setzero_ps();
    __m256d sum_low_f64x4 = _mm256_setzero_pd(), sum_high_f64x4 = _mm256_setzero_pd();
    __m256d sumsq_low_f64x4 = _mm256_setzero_pd(), sumsq_high_f64x4 = _mm256_setzero_pd();
    nk_size_t idx = 0, total = count * stride_elements;
    // Step is a multiple of the stride, so element lanes line up with the blend mask.
    nk_size_t step = nk_size_round_up_to_multiple_(8, stride_elements);
    for (; idx + step <= total; idx += step) {
        // Zero out lanes that are not strided elements; zeros are neutral for both sums.
        __m128 low_f32x4 = _mm_blendv_ps(zero_f32x4, _mm_loadu_ps(data_ptr + idx), blend_low_f32x4);
        __m128 high_f32x4 = _mm_blendv_ps(zero_f32x4, _mm_loadu_ps(data_ptr + idx + 4), blend_high_f32x4);
        __m256d low_f64x4 = _mm256_cvtps_pd(low_f32x4);
        __m256d high_f64x4 = _mm256_cvtps_pd(high_f32x4);
        sum_low_f64x4 = _mm256_add_pd(sum_low_f64x4, low_f64x4);
        sum_high_f64x4 = _mm256_add_pd(sum_high_f64x4, high_f64x4);
        sumsq_low_f64x4 = _mm256_fmadd_pd(low_f64x4, low_f64x4, sumsq_low_f64x4);
        sumsq_high_f64x4 = _mm256_fmadd_pd(high_f64x4, high_f64x4, sumsq_high_f64x4);
    }
    __m256d sum_f64x4 = _mm256_add_pd(sum_low_f64x4, sum_high_f64x4);
    __m256d sumsq_f64x4 = _mm256_add_pd(sumsq_low_f64x4, sumsq_high_f64x4);
    nk_f64_t sum = nk_reduce_add_f64x4_haswell_(sum_f64x4);
    nk_f64_t sumsq = nk_reduce_add_f64x4_haswell_(sumsq_f64x4);
    // Scalar tail: idx is a multiple of the stride, so idx / stride is elements consumed.
    nk_f32_t const *ptr = data_ptr + idx;
    nk_size_t remaining = count - idx / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_f64_t val = (nk_f64_t)(*ptr);
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
554
|
+
|
|
555
|
+
/** @brief Sum and sum-of-squares over f32 values with a large stride, using AVX2 gathers
 *  (8 elements per iteration, indices 0..7 scaled by the element stride).
 *  NOTE(review): the gather indices are 32-bit, so idx-relative offsets up to
 *  7 * stride_elements must fit an i32 — assumes strides are moderate; confirm limits. */
NK_INTERNAL void nk_reduce_moments_f32_haswell_gather_( //
    nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {

    nk_i32_t stride_elements = (nk_i32_t)(stride_bytes / sizeof(nk_f32_t));
    __m256i indices_i32x8 = _mm256_mullo_epi32(_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7),
                                               _mm256_set1_epi32(stride_elements));
    __m256d sum_f64x4 = _mm256_setzero_pd();
    __m256d sumsq_f64x4 = _mm256_setzero_pd();
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        // Gather 8 strided floats, then widen each half to doubles before accumulating.
        __m256 gathered_f32x8 = _mm256_i32gather_ps(data_ptr + idx * stride_elements, indices_i32x8, sizeof(nk_f32_t));
        __m256d low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(gathered_f32x8));
        __m256d high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(gathered_f32x8, 1));
        sum_f64x4 = _mm256_add_pd(sum_f64x4, low_f64x4);
        sum_f64x4 = _mm256_add_pd(sum_f64x4, high_f64x4);
        sumsq_f64x4 = _mm256_fmadd_pd(low_f64x4, low_f64x4, sumsq_f64x4);
        sumsq_f64x4 = _mm256_fmadd_pd(high_f64x4, high_f64x4, sumsq_f64x4);
    }
    nk_f64_t sum = nk_reduce_add_f64x4_haswell_(sum_f64x4);
    nk_f64_t sumsq = nk_reduce_add_f64x4_haswell_(sumsq_f64x4);
    // Scalar tail walks byte-wise so it works for any element-aligned stride.
    unsigned char const *ptr = (unsigned char const *)(data_ptr + idx * stride_elements);
    for (; idx < count; ++idx, ptr += stride_bytes) {
        nk_f64_t val = (nk_f64_t)(*(nk_f32_t const *)ptr);
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
583
|
+
|
|
584
|
+
/** @brief Public dispatcher for f32 sum / sum-of-squares on Haswell.
 *  Picks a kernel by stride: contiguous, blend-based (stride 2..8), or gather (larger).
 *  Misaligned strides fall back to the serial kernel; very large counts are split
 *  recursively so each half's accumulation length stays bounded.
 *  Writes 0 to both outputs when count == 0. */
NK_PUBLIC void nk_reduce_moments_f32_haswell( //
    nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {

    nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
    int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_f32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
        // Divide-and-conquer split: recurse on both halves and combine.
        nk_size_t left_count = count / 2;
        nk_f64_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_f32_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_f32_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                      &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_f32_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 8)
        // NOTE(review): stride_elements == 0 (stride_bytes == 0) also lands here; the
        // strided kernel's blend tables only cover 2..8 — confirm callers never pass 0.
        nk_reduce_moments_f32_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_f32_haswell_gather_(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}
|
|
605
|
+
|
|
606
|
+
/** @brief Min/max with first-occurrence indices over contiguous f32 data.
 *  Per-lane running min/max plus a per-lane "loop cycle" (iteration number) record when
 *  each lane last improved; the element index is reconstructed as cycle * 8 + lane.
 *  NOTE(review): assumes NK_F32_MIN is the most-negative float (lowest sentinel), not
 *  FLT_MIN — confirm against the project's constants header.
 *  NOTE(review): if every element equals the sentinel (exactly NK_F32_MAX for min, or
 *  NK_F32_MIN for max), the strict comparisons never fire and the index stays
 *  NK_SIZE_MAX; NaN elements are likewise never selected (_CMP_*_OQ is false on NaN). */
NK_INTERNAL void nk_reduce_minmax_f32_haswell_contiguous_( //
    nk_f32_t const *data_ptr, nk_size_t count, //
    nk_f32_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_f32_t *max_value_ptr, nk_size_t *max_index_ptr) {

    __m256 min_f32x8 = _mm256_set1_ps(NK_F32_MAX);
    __m256 max_f32x8 = _mm256_set1_ps(NK_F32_MIN);
    __m256i min_loop_cycle_u32x8 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u32x8 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u32x8 = _mm256_setzero_si256();
    __m256i one_u32x8 = _mm256_set1_epi32(1);

    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m256 data_f32x8 = _mm256_loadu_ps(data_ptr + idx);
        // Strict compares keep the FIRST occurrence: later ties do not update the cycle.
        __m256 less_b32x8 = _mm256_cmp_ps(data_f32x8, min_f32x8, _CMP_LT_OQ);
        __m256 greater_b32x8 = _mm256_cmp_ps(data_f32x8, max_f32x8, _CMP_GT_OQ);
        min_f32x8 = _mm256_blendv_ps(min_f32x8, data_f32x8, less_b32x8);
        max_f32x8 = _mm256_blendv_ps(max_f32x8, data_f32x8, greater_b32x8);
        // Remember in which iteration each lane last improved its extremum.
        min_loop_cycle_u32x8 = _mm256_blendv_epi8(min_loop_cycle_u32x8, current_loop_cycle_u32x8,
                                                  _mm256_castps_si256(less_b32x8));
        max_loop_cycle_u32x8 = _mm256_blendv_epi8(max_loop_cycle_u32x8, current_loop_cycle_u32x8,
                                                  _mm256_castps_si256(greater_b32x8));
        current_loop_cycle_u32x8 = _mm256_add_epi32(current_loop_cycle_u32x8, one_u32x8);
    }

    // Reduce SIMD lanes and scan for indices
    nk_f32_t min_value = NK_F32_MAX, max_value = NK_F32_MIN;
    nk_size_t min_idx = NK_SIZE_MAX, max_idx = NK_SIZE_MAX;

    // Locate the minimum index
    if (idx > 0) min_value = nk_reduce_min_f32x8_haswell_(min_f32x8);
    if (min_value < NK_F32_MAX) {
        // Among lanes holding the global minimum, pick the one with the earliest cycle,
        // then the lowest lane number — that is the first occurrence in memory order.
        __m256 value_match_b32x8 = _mm256_cmp_ps(min_f32x8, _mm256_set1_ps(min_value), _CMP_EQ_OQ);
        __m256i masked_cycle_u32x8 = _mm256_blendv_epi8(_mm256_set1_epi32((int)NK_U32_MAX), min_loop_cycle_u32x8,
                                                        _mm256_castps_si256(value_match_b32x8));
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x8_haswell_(masked_cycle_u32x8);
        __m256i cycle_match_b32x8 = _mm256_cmpeq_epi32(masked_cycle_u32x8, _mm256_set1_epi32((int)earliest_loop_cycle));
        unsigned int min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_ps(_mm256_castsi256_ps(cycle_match_b32x8)));
        nk_b256_vec_t loop_cycle_vec;
        loop_cycle_vec.ymm = min_loop_cycle_u32x8;
        // Element index = (iteration in which the lane improved) * 8 + lane number.
        min_idx = (nk_size_t)loop_cycle_vec.u32s[min_lane] * 8 + min_lane;
    }
    // Locate the maximum index
    if (idx > 0) max_value = nk_reduce_max_f32x8_haswell_(max_f32x8);
    if (max_value > NK_F32_MIN) {
        __m256 value_match_b32x8 = _mm256_cmp_ps(max_f32x8, _mm256_set1_ps(max_value), _CMP_EQ_OQ);
        __m256i masked_cycle_u32x8 = _mm256_blendv_epi8(_mm256_set1_epi32((int)NK_U32_MAX), max_loop_cycle_u32x8,
                                                        _mm256_castps_si256(value_match_b32x8));
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x8_haswell_(masked_cycle_u32x8);
        __m256i cycle_match_b32x8 = _mm256_cmpeq_epi32(masked_cycle_u32x8, _mm256_set1_epi32((int)earliest_loop_cycle));
        unsigned int max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_ps(_mm256_castsi256_ps(cycle_match_b32x8)));
        nk_b256_vec_t loop_cycle_vec;
        loop_cycle_vec.ymm = max_loop_cycle_u32x8;
        max_idx = (nk_size_t)loop_cycle_vec.u32s[max_lane] * 8 + max_lane;
    }

    // Scalar tail
    for (; idx < count; ++idx) {
        nk_f32_t val = data_ptr[idx];
        if (val < min_value) min_value = val, min_idx = idx;
        if (val > max_value) max_value = val, max_idx = idx;
    }
    *min_value_ptr = min_value, *min_index_ptr = min_idx;
    *max_value_ptr = max_value, *max_index_ptr = max_idx;
}
|
|
672
|
+
|
|
673
|
+
/** @brief Public dispatcher for f32 min/max with indices on Haswell.
 *  Contiguous data uses the SIMD kernel; misaligned or strided data falls back to the
 *  serial kernel. Counts beyond what the 32-bit loop-cycle counters can index are split
 *  recursively and the halves combined (right-half indices offset by left_count).
 *  Empty input yields sentinels: NK_F32_MAX / NK_F32_MIN values, NK_SIZE_MAX indices. */
NK_PUBLIC void nk_reduce_minmax_f32_haswell( //
    nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_f32_t *max_value_ptr, nk_size_t *max_index_ptr) {

    nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
    int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
    if (count == 0)
        *min_value_ptr = NK_F32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F32_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (!aligned)
        nk_reduce_minmax_f32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
    else if (count > (nk_size_t)NK_U32_MAX * 8) {
        // Split so each half's cycle counter (count / 8 iterations) fits a u32
        // and stays below the NK_U32_MAX sentinel used during index recovery.
        nk_size_t left_count = count / 2;
        nk_f32_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_f32_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                     &left_max_index);
        nk_reduce_minmax_f32_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                     &right_min, &right_min_index, &right_max, &right_max_index);
        // Strict comparisons keep the LEFT (earlier) half on ties — first occurrence wins.
        if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
        else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_f32_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                 max_index_ptr);
    else
        nk_reduce_minmax_f32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                    max_index_ptr);
}
|
|
706
|
+
|
|
707
|
+
/** @brief Compensated sum and sum-of-squares of `count` contiguous f64 values.
 *  Each accumulation uses the branch-free TwoSum error term (corr = (sum - (t - round))
 *  + (val - round)) collected into a separate compensation vector, folded in at the end.
 *  The tail (count % 4) is zero-padded via a partial load; zeros are neutral here. */
NK_INTERNAL void nk_reduce_moments_f64_haswell_contiguous_( //
    nk_f64_t const *data_ptr, nk_size_t count, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {

    __m256d sum_f64x4 = _mm256_setzero_pd();
    __m256d sum_comp_f64x4 = _mm256_setzero_pd();
    __m256d sumsq_f64x4 = _mm256_setzero_pd();
    __m256d sumsq_comp_f64x4 = _mm256_setzero_pd();
    nk_size_t idx = 0;
    for (; idx + 4 <= count; idx += 4) {
        __m256d val_f64x4 = _mm256_loadu_pd(data_ptr + idx);
        // TwoSum for the plain sum: capture the rounding error of sum + val.
        __m256d tentative_f64x4 = _mm256_add_pd(sum_f64x4, val_f64x4);
        __m256d round_f64x4 = _mm256_sub_pd(tentative_f64x4, sum_f64x4);
        __m256d corr_f64x4 = _mm256_add_pd(_mm256_sub_pd(sum_f64x4, _mm256_sub_pd(tentative_f64x4, round_f64x4)),
                                           _mm256_sub_pd(val_f64x4, round_f64x4));
        sum_comp_f64x4 = _mm256_add_pd(sum_comp_f64x4, corr_f64x4);
        sum_f64x4 = tentative_f64x4;
        // TwoSum for the sum of squares.
        __m256d sq_f64x4 = _mm256_mul_pd(val_f64x4, val_f64x4);
        __m256d tentative_sq_f64x4 = _mm256_add_pd(sumsq_f64x4, sq_f64x4);
        __m256d round_sq_f64x4 = _mm256_sub_pd(tentative_sq_f64x4, sumsq_f64x4);
        __m256d corr_sq_f64x4 = _mm256_add_pd(
            _mm256_sub_pd(sumsq_f64x4, _mm256_sub_pd(tentative_sq_f64x4, round_sq_f64x4)),
            _mm256_sub_pd(sq_f64x4, round_sq_f64x4));
        sumsq_comp_f64x4 = _mm256_add_pd(sumsq_comp_f64x4, corr_sq_f64x4);
        sumsq_f64x4 = tentative_sq_f64x4;
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Partial load pads the missing lanes; padded lanes contribute nothing.
        nk_b256_vec_t tail_vec;
        nk_partial_load_b64x4_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256d val_f64x4 = tail_vec.ymm_pd;
        __m256d tentative_f64x4 = _mm256_add_pd(sum_f64x4, val_f64x4);
        __m256d round_f64x4 = _mm256_sub_pd(tentative_f64x4, sum_f64x4);
        __m256d corr_f64x4 = _mm256_add_pd(_mm256_sub_pd(sum_f64x4, _mm256_sub_pd(tentative_f64x4, round_f64x4)),
                                           _mm256_sub_pd(val_f64x4, round_f64x4));
        sum_comp_f64x4 = _mm256_add_pd(sum_comp_f64x4, corr_f64x4);
        sum_f64x4 = tentative_f64x4;
        __m256d sq_f64x4 = _mm256_mul_pd(val_f64x4, val_f64x4);
        __m256d tentative_sq_f64x4 = _mm256_add_pd(sumsq_f64x4, sq_f64x4);
        __m256d round_sq_f64x4 = _mm256_sub_pd(tentative_sq_f64x4, sumsq_f64x4);
        __m256d corr_sq_f64x4 = _mm256_add_pd(
            _mm256_sub_pd(sumsq_f64x4, _mm256_sub_pd(tentative_sq_f64x4, round_sq_f64x4)),
            _mm256_sub_pd(sq_f64x4, round_sq_f64x4));
        sumsq_comp_f64x4 = _mm256_add_pd(sumsq_comp_f64x4, corr_sq_f64x4);
        sumsq_f64x4 = tentative_sq_f64x4;
    }
    // Fold the compensation vectors back in before the horizontal reduction.
    *sum_ptr = nk_reduce_add_f64x4_haswell_(_mm256_add_pd(sum_f64x4, sum_comp_f64x4)),
    *sumsq_ptr = nk_reduce_add_f64x4_haswell_(_mm256_add_pd(sumsq_f64x4, sumsq_comp_f64x4));
}
|
|
756
|
+
|
|
757
|
+
/** @brief Compensated sum and sum-of-squares of `count` f64 values spaced
 *  `stride_elements` apart (strides 2..4). Loads 4 contiguous doubles per step and
 *  blends away non-element lanes; same TwoSum compensation as the contiguous kernel.
 *  NOTE(review): like the f32 strided kernel, each iteration reads full 4-double
 *  windows and may read past the final element — assumes the buffer extends to
 *  count * stride_elements doubles; confirm with callers. */
NK_INTERNAL void nk_reduce_moments_f64_haswell_strided_( //
    nk_f64_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {

    __m256i blend_mask_i64x4 = nk_stride_blend_b64x4_(stride_elements);
    __m256d blend_f64x4 = _mm256_castsi256_pd(blend_mask_i64x4);
    __m256d zero_f64x4 = _mm256_setzero_pd();
    __m256d sum_f64x4 = _mm256_setzero_pd();
    __m256d sum_comp_f64x4 = _mm256_setzero_pd();
    __m256d sumsq_f64x4 = _mm256_setzero_pd();
    __m256d sumsq_comp_f64x4 = _mm256_setzero_pd();
    nk_size_t idx = 0, total = count * stride_elements;
    // Step is a multiple of the stride, so element lanes line up with the blend mask.
    nk_size_t step = nk_size_round_up_to_multiple_(4, stride_elements);
    for (; idx + step <= total; idx += step) {
        // Zero out non-element lanes; zeros are neutral for both compensated sums.
        __m256d val_f64x4 = _mm256_blendv_pd(zero_f64x4, _mm256_loadu_pd(data_ptr + idx), blend_f64x4);
        // TwoSum for the plain sum.
        __m256d tentative_f64x4 = _mm256_add_pd(sum_f64x4, val_f64x4);
        __m256d round_f64x4 = _mm256_sub_pd(tentative_f64x4, sum_f64x4);
        __m256d corr_f64x4 = _mm256_add_pd(_mm256_sub_pd(sum_f64x4, _mm256_sub_pd(tentative_f64x4, round_f64x4)),
                                           _mm256_sub_pd(val_f64x4, round_f64x4));
        sum_comp_f64x4 = _mm256_add_pd(sum_comp_f64x4, corr_f64x4);
        sum_f64x4 = tentative_f64x4;
        // TwoSum for the sum of squares.
        __m256d sq_f64x4 = _mm256_mul_pd(val_f64x4, val_f64x4);
        __m256d tentative_sq_f64x4 = _mm256_add_pd(sumsq_f64x4, sq_f64x4);
        __m256d round_sq_f64x4 = _mm256_sub_pd(tentative_sq_f64x4, sumsq_f64x4);
        __m256d corr_sq_f64x4 = _mm256_add_pd(
            _mm256_sub_pd(sumsq_f64x4, _mm256_sub_pd(tentative_sq_f64x4, round_sq_f64x4)),
            _mm256_sub_pd(sq_f64x4, round_sq_f64x4));
        sumsq_comp_f64x4 = _mm256_add_pd(sumsq_comp_f64x4, corr_sq_f64x4);
        sumsq_f64x4 = tentative_sq_f64x4;
    }
    nk_f64_t sum = nk_reduce_add_f64x4_haswell_(_mm256_add_pd(sum_f64x4, sum_comp_f64x4));
    nk_f64_t sumsq = nk_reduce_add_f64x4_haswell_(_mm256_add_pd(sumsq_f64x4, sumsq_comp_f64x4));
    // Scalar tail: idx is a multiple of the stride, so idx / stride is elements consumed.
    nk_f64_t const *ptr = data_ptr + idx;
    nk_size_t remaining_elements = count - idx / stride_elements;
    for (nk_size_t i = 0; i < remaining_elements; ++i, ptr += stride_elements) {
        nk_f64_t val = *ptr;
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
797
|
+
|
|
798
|
+
/** @brief Public dispatcher for f64 sum / sum-of-squares on Haswell.
 *  Contiguous data uses the compensated SIMD kernel; strides 2..4 the blend-based one;
 *  misaligned or larger strides fall back to the serial kernel. Very large counts are
 *  split recursively so each half's accumulation length stays bounded.
 *  Writes 0 to both outputs when count == 0. */
NK_PUBLIC void nk_reduce_moments_f64_haswell( //
    nk_f64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {

    nk_size_t stride_elements = stride_bytes / sizeof(nk_f64_t);
    int aligned = (stride_bytes % sizeof(nk_f64_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_f64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 4) {
        // Divide-and-conquer split: recurse on both halves and combine.
        nk_size_t left_count = count / 2;
        nk_f64_t left_sum, left_sumsq, right_sum, right_sumsq;
        nk_reduce_moments_f64_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
        nk_reduce_moments_f64_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                      &right_sum, &right_sumsq);
        *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
    }
    else if (stride_elements == 1) nk_reduce_moments_f64_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements <= 4)
        nk_reduce_moments_f64_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_f64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}
|
|
819
|
+
|
|
820
|
+
/// Finds the minimum and maximum of a contiguous f64 array, together with the
/// index of the FIRST occurrence of each extreme.
///
/// Instead of tracking full 64-bit indices per lane, each lane remembers the
/// loop-cycle (iteration number) in which it last improved; the global index is
/// reconstructed at the end as `cycle * 4 + lane`. Strict `<` / `>` comparisons
/// ensure only the earliest occurrence within a lane updates the cycle.
NK_INTERNAL void nk_reduce_minmax_f64_haswell_contiguous_( //
    nk_f64_t const *data_ptr, nk_size_t count, //
    nk_f64_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_f64_t *max_value_ptr, nk_size_t *max_index_ptr) {

    __m256d min_f64x4 = _mm256_set1_pd(NK_F64_MAX);
    __m256d max_f64x4 = _mm256_set1_pd(NK_F64_MIN);
    __m256i min_loop_cycle_u64x4 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u64x4 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u64x4 = _mm256_setzero_si256();
    __m256i one_u64x4 = _mm256_set1_epi64x(1);

    nk_size_t idx = 0;
    for (; idx + 4 <= count; idx += 4) {
        __m256d data_f64x4 = _mm256_loadu_pd(data_ptr + idx);
        // Ordered, non-signaling comparisons: NaN lanes never win min or max.
        __m256d less_b64x4 = _mm256_cmp_pd(data_f64x4, min_f64x4, _CMP_LT_OQ);
        __m256d greater_b64x4 = _mm256_cmp_pd(data_f64x4, max_f64x4, _CMP_GT_OQ);
        min_f64x4 = _mm256_blendv_pd(min_f64x4, data_f64x4, less_b64x4);
        max_f64x4 = _mm256_blendv_pd(max_f64x4, data_f64x4, greater_b64x4);
        // Record the iteration in which the winning lane last changed.
        min_loop_cycle_u64x4 = _mm256_blendv_epi8(min_loop_cycle_u64x4, current_loop_cycle_u64x4,
                                                  _mm256_castpd_si256(less_b64x4));
        max_loop_cycle_u64x4 = _mm256_blendv_epi8(max_loop_cycle_u64x4, current_loop_cycle_u64x4,
                                                  _mm256_castpd_si256(greater_b64x4));
        current_loop_cycle_u64x4 = _mm256_add_epi64(current_loop_cycle_u64x4, one_u64x4);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b64x4_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i lane_indices_i64x4 = _mm256_setr_epi64x(0, 1, 2, 3);
        __m256i threshold_i64x4 = _mm256_set1_epi64x((nk_i64_t)remaining);
        __m256i valid_i64x4 = _mm256_cmpgt_epi64(threshold_i64x4, lane_indices_i64x4);
        __m256d valid_b64x4 = _mm256_castsi256_pd(valid_i64x4);
        // Invalid tail lanes are filled with quiet NaN so the ordered
        // comparisons below ignore them for both min and max.
        __m256d nan_f64x4 = _mm256_castsi256_pd(_mm256_set1_epi64x((long long)0x7FF8000000000000LL));
        __m256d data_f64x4 = _mm256_blendv_pd(nan_f64x4, tail_vec.ymm_pd, valid_b64x4);
        __m256d less_b64x4 = _mm256_cmp_pd(data_f64x4, min_f64x4, _CMP_LT_OQ);
        __m256d greater_b64x4 = _mm256_cmp_pd(data_f64x4, max_f64x4, _CMP_GT_OQ);
        min_f64x4 = _mm256_blendv_pd(min_f64x4, data_f64x4, less_b64x4);
        max_f64x4 = _mm256_blendv_pd(max_f64x4, data_f64x4, greater_b64x4);
        // `current_loop_cycle_u64x4` already equals idx / 4 here, which is the
        // correct cycle for the tail elements.
        min_loop_cycle_u64x4 = _mm256_blendv_epi8(min_loop_cycle_u64x4, current_loop_cycle_u64x4,
                                                  _mm256_castpd_si256(less_b64x4));
        max_loop_cycle_u64x4 = _mm256_blendv_epi8(max_loop_cycle_u64x4, current_loop_cycle_u64x4,
                                                  _mm256_castpd_si256(greater_b64x4));
    }

    nk_f64_t min_value = nk_reduce_min_f64x4_haswell_(min_f64x4);
    nk_f64_t max_value = nk_reduce_max_f64x4_haswell_(max_f64x4);
    // If neither accumulator moved off its sentinel, no lane ever won a
    // comparison (e.g. all-NaN input): report sentinels with invalid indices.
    // NOTE(review): an input genuinely containing only NK_F64_MAX / NK_F64_MIN
    // would also take this branch — confirm this matches the serial kernel.
    if (min_value == NK_F64_MAX && max_value == NK_F64_MIN) {
        *min_value_ptr = NK_F64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F64_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }
    {
        // Among lanes holding the winning value, pick the one with the earliest
        // loop cycle; ties on cycle resolve to the lowest lane via tzcnt.
        __m256d value_match_b64x4 = _mm256_cmp_pd(min_f64x4, _mm256_set1_pd(min_value), _CMP_EQ_OQ);
        __m256i masked_cycle_u64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x((nk_i64_t)NK_U64_MAX), min_loop_cycle_u64x4,
                                                        _mm256_castpd_si256(value_match_b64x4));
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x4_haswell_(masked_cycle_u64x4);
        __m256i cycle_match_b64x4 = _mm256_cmpeq_epi64(masked_cycle_u64x4,
                                                       _mm256_set1_epi64x((nk_i64_t)earliest_loop_cycle));
        unsigned int min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_pd(_mm256_castsi256_pd(cycle_match_b64x4)));
        nk_b256_vec_t loop_cycle_vec;
        loop_cycle_vec.ymm = min_loop_cycle_u64x4;
        *min_value_ptr = min_value;
        *min_index_ptr = (nk_size_t)loop_cycle_vec.u64s[min_lane] * 4 + min_lane;
    }
    {
        // Same index reconstruction for the maximum.
        __m256d value_match_b64x4 = _mm256_cmp_pd(max_f64x4, _mm256_set1_pd(max_value), _CMP_EQ_OQ);
        __m256i masked_cycle_u64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x((nk_i64_t)NK_U64_MAX), max_loop_cycle_u64x4,
                                                        _mm256_castpd_si256(value_match_b64x4));
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x4_haswell_(masked_cycle_u64x4);
        __m256i cycle_match_b64x4 = _mm256_cmpeq_epi64(masked_cycle_u64x4,
                                                       _mm256_set1_epi64x((nk_i64_t)earliest_loop_cycle));
        unsigned int max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_pd(_mm256_castsi256_pd(cycle_match_b64x4)));
        nk_b256_vec_t loop_cycle_vec;
        loop_cycle_vec.ymm = max_loop_cycle_u64x4;
        *max_value_ptr = max_value;
        *max_index_ptr = (nk_size_t)loop_cycle_vec.u64s[max_lane] * 4 + max_lane;
    }
}
|
|
900
|
+
|
|
901
|
+
NK_PUBLIC void nk_reduce_minmax_f64_haswell( //
|
|
902
|
+
nk_f64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
903
|
+
nk_f64_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
904
|
+
nk_f64_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
905
|
+
|
|
906
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_f64_t);
|
|
907
|
+
int aligned = (stride_bytes % sizeof(nk_f64_t) == 0);
|
|
908
|
+
if (count == 0)
|
|
909
|
+
*min_value_ptr = NK_F64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F64_MIN,
|
|
910
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
911
|
+
else if (!aligned)
|
|
912
|
+
nk_reduce_minmax_f64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
913
|
+
max_index_ptr);
|
|
914
|
+
else if (stride_elements == 1)
|
|
915
|
+
nk_reduce_minmax_f64_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
916
|
+
max_index_ptr);
|
|
917
|
+
else
|
|
918
|
+
nk_reduce_minmax_f64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
919
|
+
max_index_ptr);
|
|
920
|
+
}
|
|
921
|
+
|
|
922
|
+
/// Sum and sum-of-squares of a contiguous i8 array.
///
/// The sum uses the XOR-0x80 trick: i8 values biased into the u8 domain can be
/// summed with `_mm256_sad_epu8`, and the accumulated bias (128 per element) is
/// subtracted at the end. Squares are bias-independent, so they are computed on
/// the raw i8 values widened to i16 and multiplied-accumulated with
/// `_mm256_madd_epi16` into 32-bit lanes.
NK_INTERNAL void nk_reduce_moments_i8_haswell_contiguous_( //
    nk_i8_t const *data_ptr, nk_size_t count, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    __m256i bias_i8x32 = _mm256_set1_epi8((char)0x80);
    __m256i zero_i8x32 = _mm256_setzero_si256();
    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_low_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_high_i32x8 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        __m256i unsigned_u8x32 = _mm256_xor_si256(data_i8x32, bias_i8x32);
        // SAD against zero yields per-8-byte-group sums in 64-bit lanes.
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(unsigned_u8x32, zero_i8x32));
        __m256i low_i16x16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(data_i8x32));
        __m256i high_i16x16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(data_i8x32, 1));
        sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, _mm256_madd_epi16(low_i16x16, low_i16x16));
        sumsq_high_i32x8 = _mm256_add_epi32(sumsq_high_i32x8, _mm256_madd_epi16(high_i16x16, high_i16x16));
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b8x32_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_i8x32 = tail_vec.ymm;
        // Build masked bias: only bias valid lanes so zero-padded lanes stay zero
        // (zero lanes then contribute nothing to either sum or sumsq).
        __m256i lane_indices_u8x32 = _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                                                      19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31);
        __m256i valid_b8x32 = _mm256_cmpgt_epi8(_mm256_set1_epi8((char)remaining), lane_indices_u8x32);
        __m256i masked_bias_i8x32 = _mm256_and_si256(bias_i8x32, valid_b8x32);
        __m256i unsigned_u8x32 = _mm256_xor_si256(data_i8x32, masked_bias_i8x32);
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(unsigned_u8x32, zero_i8x32));
        __m256i low_i16x16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(data_i8x32));
        __m256i high_i16x16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(data_i8x32, 1));
        sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, _mm256_madd_epi16(low_i16x16, low_i16x16));
        sumsq_high_i32x8 = _mm256_add_epi32(sumsq_high_i32x8, _mm256_madd_epi16(high_i16x16, high_i16x16));
    }
    // Widen the eight 32-bit square accumulators to 64 bits and reduce.
    sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, sumsq_high_i32x8);
    __m256i sumsq_i64x4 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sumsq_low_i32x8));
    sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sumsq_low_i32x8, 1)));
    nk_i64_t sum = (nk_i64_t)(nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    // Undo the +128 bias applied to every one of the `count` elements.
    sum -= (nk_i64_t)128 * (nk_i64_t)count;
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_i64x4);
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
966
|
+
|
|
967
|
+
/// Sum and sum-of-squares of an i8 array with a small element stride.
///
/// Loads 32 consecutive bytes and zeroes out the lanes that do not fall on the
/// stride grid via a precomputed blend mask, then reuses the SAD/MADD scheme of
/// the contiguous kernel. The XOR bias is applied only to the valid lanes, so
/// the final correction subtracts 128 per *vectorized* element.
NK_INTERNAL void nk_reduce_moments_i8_haswell_strided_( //
    nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    __m256i stride_mask_i8x32 = nk_stride_blend_u1x32_(stride_elements);
    nk_size_t elements_per_vector = nk_size_divide_round_up_(32, stride_elements);
    __m256i masked_bias_i8x32 = _mm256_and_si256(_mm256_set1_epi8((char)0x80), stride_mask_i8x32);
    __m256i zero_i8x32 = _mm256_setzero_si256();
    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_low_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_high_i32x8 = _mm256_setzero_si256();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    nk_size_t vector_element_count = 0;
    // Advance by a whole number of strided elements per iteration so the fixed
    // lane mask lines up every time.
    nk_size_t step = elements_per_vector * stride_elements;
    // NOTE(review): with a fixed mask, the same 32-bit madd lanes accumulate on
    // every iteration; the caller's recursion threshold must keep the per-lane
    // total below INT32_MAX — verify against the dispatcher's split limit.
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx_scalars));
        data_i8x32 = _mm256_and_si256(data_i8x32, stride_mask_i8x32);
        __m256i unsigned_u8x32 = _mm256_xor_si256(data_i8x32, masked_bias_i8x32);
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(unsigned_u8x32, zero_i8x32));
        __m256i low_i16x16 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(data_i8x32));
        __m256i high_i16x16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(data_i8x32, 1));
        sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, _mm256_madd_epi16(low_i16x16, low_i16x16));
        sumsq_high_i32x8 = _mm256_add_epi32(sumsq_high_i32x8, _mm256_madd_epi16(high_i16x16, high_i16x16));
        vector_element_count += elements_per_vector;
    }
    sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, sumsq_high_i32x8);
    __m256i sumsq_i64x4 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sumsq_low_i32x8));
    sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sumsq_low_i32x8, 1)));
    nk_i64_t sum = (nk_i64_t)(nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    // Only vectorized elements were biased; scalar-tail elements were not.
    sum -= (nk_i64_t)128 * (nk_i64_t)vector_element_count;
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_i64x4);
    // Scalar tail for the elements that did not fill a whole vector step.
    nk_i8_t const *ptr = data_ptr + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_i64_t val = (nk_i64_t)*ptr;
        sum += val, sumsq += (nk_u64_t)(val * val);
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
1007
|
+
|
|
1008
|
+
NK_PUBLIC void nk_reduce_moments_i8_haswell( //
|
|
1009
|
+
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1010
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1011
|
+
|
|
1012
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
|
|
1013
|
+
int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
|
|
1014
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1015
|
+
else if (!aligned) nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1016
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
|
|
1017
|
+
nk_size_t left_count = count / 2;
|
|
1018
|
+
nk_i64_t left_sum, right_sum;
|
|
1019
|
+
nk_u64_t left_sumsq, right_sumsq;
|
|
1020
|
+
nk_reduce_moments_i8_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
1021
|
+
nk_reduce_moments_i8_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1022
|
+
&right_sum, &right_sumsq);
|
|
1023
|
+
*sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
|
|
1024
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
|
|
1025
|
+
}
|
|
1026
|
+
else if (stride_elements == 1) nk_reduce_moments_i8_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
1027
|
+
else if (stride_elements <= 8)
|
|
1028
|
+
nk_reduce_moments_i8_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
1029
|
+
else nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1030
|
+
}
|
|
1031
|
+
|
|
1032
|
+
/// Finds the min and max of a contiguous i8 array plus the index of the FIRST
/// occurrence of each, using per-lane 8-bit loop-cycle counters: the global
/// index of a winner is `cycle * 32 + lane`. The 8-bit cycle counter limits
/// this kernel to 256 iterations (8192 elements); the dispatcher splits larger
/// inputs.
NK_INTERNAL void nk_reduce_minmax_i8_haswell_contiguous_( //
    nk_i8_t const *data_ptr, nk_size_t count, //
    nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {

    __m256i min_i8x32 = _mm256_set1_epi8((char)NK_I8_MAX);
    __m256i max_i8x32 = _mm256_set1_epi8(NK_I8_MIN);
    __m256i min_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i one_u8x32 = _mm256_set1_epi8(1);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        // Strict comparisons: equal values never update, so the FIRST
        // occurrence's cycle is retained per lane.
        __m256i less_b8x32 = _mm256_cmpgt_epi8(min_i8x32, data_i8x32);
        __m256i greater_b8x32 = _mm256_cmpgt_epi8(data_i8x32, max_i8x32);
        min_i8x32 = _mm256_blendv_epi8(min_i8x32, data_i8x32, less_b8x32);
        max_i8x32 = _mm256_blendv_epi8(max_i8x32, data_i8x32, greater_b8x32);
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, less_b8x32);
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, greater_b8x32);
        current_loop_cycle_u8x32 = _mm256_add_epi8(current_loop_cycle_u8x32, one_u8x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b8x32_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_i8x32 = tail_vec.ymm;
        // Build lane mask and fill invalid lanes with identity values
        // (I8_MAX for the min pass, I8_MIN for the max pass) so they never win.
        __m256i lane_indices_u8x32 = _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                                                      19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31);
        __m256i valid_b8x32 = _mm256_cmpgt_epi8(_mm256_set1_epi8((char)remaining), lane_indices_u8x32);
        data_i8x32 = _mm256_blendv_epi8(_mm256_set1_epi8(NK_I8_MAX), data_i8x32, valid_b8x32);
        __m256i data_max_i8x32 = _mm256_blendv_epi8(_mm256_set1_epi8(NK_I8_MIN), tail_vec.ymm, valid_b8x32);
        __m256i less_b8x32 = _mm256_cmpgt_epi8(min_i8x32, data_i8x32);
        __m256i greater_b8x32 = _mm256_cmpgt_epi8(data_max_i8x32, max_i8x32);
        min_i8x32 = _mm256_blendv_epi8(min_i8x32, data_i8x32, less_b8x32);
        max_i8x32 = _mm256_blendv_epi8(max_i8x32, data_max_i8x32, greater_b8x32);
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, less_b8x32);
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, greater_b8x32);
    }

    nk_i8_t min_value = nk_reduce_min_i8x32_haswell_(min_i8x32);
    nk_i8_t max_value = nk_reduce_max_i8x32_haswell_(max_i8x32);
    unsigned int min_lane, max_lane;
    {
        // Among lanes holding the winning value, find the earliest loop cycle;
        // ties resolve to the lowest lane index via tzcnt.
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(min_i8x32, _mm256_set1_epi8(min_value));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), min_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
    }
    {
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(max_i8x32, _mm256_set1_epi8(max_value));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), max_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
    }
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u8x32;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 32 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u8x32;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 32 + max_lane;
}
|
|
1100
|
+
|
|
1101
|
+
NK_PUBLIC void nk_reduce_minmax_i8_haswell( //
|
|
1102
|
+
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1103
|
+
nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1104
|
+
nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1105
|
+
|
|
1106
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
|
|
1107
|
+
int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
|
|
1108
|
+
if (count == 0)
|
|
1109
|
+
*min_value_ptr = NK_I8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I8_MIN,
|
|
1110
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
1111
|
+
else if (!aligned)
|
|
1112
|
+
nk_reduce_minmax_i8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1113
|
+
max_index_ptr);
|
|
1114
|
+
else if (count > (nk_size_t)(NK_U8_MAX + 1) * 32) {
|
|
1115
|
+
nk_size_t left_count = count / 2;
|
|
1116
|
+
nk_i8_t left_min, right_min, left_max, right_max;
|
|
1117
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
1118
|
+
nk_reduce_minmax_i8_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
1119
|
+
&left_max_index);
|
|
1120
|
+
nk_reduce_minmax_i8_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1121
|
+
&right_min, &right_min_index, &right_max, &right_max_index);
|
|
1122
|
+
if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
1123
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
1124
|
+
if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
1125
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
1126
|
+
}
|
|
1127
|
+
else if (stride_elements == 1)
|
|
1128
|
+
nk_reduce_minmax_i8_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1129
|
+
max_index_ptr);
|
|
1130
|
+
else
|
|
1131
|
+
nk_reduce_minmax_i8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1132
|
+
max_index_ptr);
|
|
1133
|
+
}
|
|
1134
|
+
|
|
1135
|
+
/// Sum and sum-of-squares of a contiguous u8 array.
///
/// Sums use `_mm256_sad_epu8` directly (no bias needed for unsigned input);
/// squares widen to i16 and multiply-accumulate with `_mm256_madd_epi16` into
/// 32-bit lanes. Zero-padded tail lanes contribute nothing to either result.
NK_INTERNAL void nk_reduce_moments_u8_haswell_contiguous_( //
    nk_u8_t const *data_ptr, nk_size_t count, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    __m256i zero_u8x32 = _mm256_setzero_si256();
    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_low_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_high_i32x8 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    // NOTE(review): each 32-bit sumsq lane can gain up to 2 * 255^2 = 130050
    // per iteration; the caller must cap `count` so per-lane totals stay below
    // 2^32 (the final widening is unsigned) — verify the dispatcher threshold.
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(data_u8x32, zero_u8x32));
        __m256i low_i16x16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(data_u8x32));
        __m256i high_i16x16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(data_u8x32, 1));
        sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, _mm256_madd_epi16(low_i16x16, low_i16x16));
        sumsq_high_i32x8 = _mm256_add_epi32(sumsq_high_i32x8, _mm256_madd_epi16(high_i16x16, high_i16x16));
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Partial load zero-fills invalid lanes; zeros are the identity for
        // both the sum and the sum of squares, so no masking is needed.
        nk_b256_vec_t tail_vec;
        nk_partial_load_b8x32_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_u8x32 = tail_vec.ymm;
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(data_u8x32, zero_u8x32));
        __m256i low_i16x16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(data_u8x32));
        __m256i high_i16x16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(data_u8x32, 1));
        sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, _mm256_madd_epi16(low_i16x16, low_i16x16));
        sumsq_high_i32x8 = _mm256_add_epi32(sumsq_high_i32x8, _mm256_madd_epi16(high_i16x16, high_i16x16));
    }
    // Widen 32-bit square accumulators (unsigned) to 64 bits and reduce.
    sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, sumsq_high_i32x8);
    __m256i sumsq_u64x4 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sumsq_low_i32x8));
    sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sumsq_low_i32x8, 1)));
    *sum_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4),
    *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);
}
|
|
1169
|
+
|
|
1170
|
+
/// Sum and sum-of-squares of a u8 array with a small element stride.
///
/// Loads 32 consecutive bytes, zeroes the off-stride lanes via a precomputed
/// blend mask, and reuses the SAD/MADD scheme. Leftover elements that do not
/// fill a whole vector step are handled by a scalar tail loop.
NK_INTERNAL void nk_reduce_moments_u8_haswell_strided_( //
    nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    __m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
    __m256i zero_u8x32 = _mm256_setzero_si256();
    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_low_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_high_i32x8 = _mm256_setzero_si256();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    // Advance by a whole number of strided elements so the fixed lane mask
    // lines up on every iteration.
    nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
    // NOTE(review): with a fixed mask the same 32-bit madd lanes accumulate
    // every iteration (up to 255^2 each); the dispatcher's split threshold must
    // bound per-lane totals below 2^32 — verify.
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx_scalars));
        data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(data_u8x32, zero_u8x32));
        __m256i low_i16x16 = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(data_u8x32));
        __m256i high_i16x16 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(data_u8x32, 1));
        sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, _mm256_madd_epi16(low_i16x16, low_i16x16));
        sumsq_high_i32x8 = _mm256_add_epi32(sumsq_high_i32x8, _mm256_madd_epi16(high_i16x16, high_i16x16));
    }
    sumsq_low_i32x8 = _mm256_add_epi32(sumsq_low_i32x8, sumsq_high_i32x8);
    __m256i sumsq_u64x4 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sumsq_low_i32x8));
    sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sumsq_low_i32x8, 1)));
    nk_u64_t sum = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);
    // Scalar tail: `idx_scalars` is a multiple of the stride, so the division
    // is exact.
    nk_u8_t const *ptr = data_ptr + idx_scalars;
    nk_size_t remaining_elements = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining_elements; ++i, ptr += stride_elements) {
        nk_u64_t val = (nk_u64_t)*ptr;
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
1204
|
+
|
|
1205
|
+
NK_PUBLIC void nk_reduce_moments_u8_haswell( //
|
|
1206
|
+
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1207
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1208
|
+
|
|
1209
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
|
|
1210
|
+
int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
|
|
1211
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1212
|
+
else if (!aligned) nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1213
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
|
|
1214
|
+
nk_size_t left_count = count / 2;
|
|
1215
|
+
nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
1216
|
+
nk_reduce_moments_u8_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
1217
|
+
nk_reduce_moments_u8_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1218
|
+
&right_sum, &right_sumsq);
|
|
1219
|
+
*sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
|
|
1220
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
|
|
1221
|
+
}
|
|
1222
|
+
else if (stride_elements == 1) nk_reduce_moments_u8_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
1223
|
+
else if (stride_elements <= 8)
|
|
1224
|
+
nk_reduce_moments_u8_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
1225
|
+
else nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1226
|
+
}
|
|
1227
|
+
|
|
1228
|
+
/// Finds the min and max of a contiguous u8 array plus the index of the FIRST
/// occurrence of each. AVX2 has no unsigned byte compare, so values are XORed
/// with 0x80 into the signed domain for `_mm256_cmpgt_epi8` and un-biased at
/// the end. Index recovery uses per-lane 8-bit loop-cycle counters
/// (`index = cycle * 32 + lane`), limiting this kernel to 256 iterations.
NK_INTERNAL void nk_reduce_minmax_u8_haswell_contiguous_( //
    nk_u8_t const *data_ptr, nk_size_t count, //
    nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // XOR-bias to signed domain for _mm256_cmpgt_epi8
    __m256i bias_u8x32 = _mm256_set1_epi8((char)0x80);
    __m256i min_biased_i8x32 = _mm256_set1_epi8((char)NK_I8_MAX);
    __m256i max_biased_i8x32 = _mm256_set1_epi8(NK_I8_MIN);
    __m256i min_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i one_u8x32 = _mm256_set1_epi8(1);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_biased_i8x32 = _mm256_xor_si256(_mm256_loadu_si256((__m256i const *)(data_ptr + idx)), bias_u8x32);
        // Strict comparisons preserve the FIRST occurrence's cycle per lane.
        __m256i less_b8x32 = _mm256_cmpgt_epi8(min_biased_i8x32, data_biased_i8x32);
        __m256i greater_b8x32 = _mm256_cmpgt_epi8(data_biased_i8x32, max_biased_i8x32);
        min_biased_i8x32 = _mm256_blendv_epi8(min_biased_i8x32, data_biased_i8x32, less_b8x32);
        max_biased_i8x32 = _mm256_blendv_epi8(max_biased_i8x32, data_biased_i8x32, greater_b8x32);
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, less_b8x32);
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, greater_b8x32);
        current_loop_cycle_u8x32 = _mm256_add_epi8(current_loop_cycle_u8x32, one_u8x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b8x32_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_biased_i8x32 = _mm256_xor_si256(tail_vec.ymm, bias_u8x32);
        __m256i lane_indices_u8x32 = _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                                                      19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31);
        __m256i valid_b8x32 = _mm256_cmpgt_epi8(_mm256_set1_epi8((char)remaining), lane_indices_u8x32);
        // Biased identity: NK_U8_MAX ^ 0x80 = 0x7F for min, NK_U8_MIN ^ 0x80 = 0x80 for max
        __m256i data_min_i8x32 = _mm256_blendv_epi8(_mm256_set1_epi8(0x7F), data_biased_i8x32, valid_b8x32);
        __m256i data_max_i8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)0x80), data_biased_i8x32, valid_b8x32);
        __m256i less_b8x32 = _mm256_cmpgt_epi8(min_biased_i8x32, data_min_i8x32);
        __m256i greater_b8x32 = _mm256_cmpgt_epi8(data_max_i8x32, max_biased_i8x32);
        min_biased_i8x32 = _mm256_blendv_epi8(min_biased_i8x32, data_min_i8x32, less_b8x32);
        max_biased_i8x32 = _mm256_blendv_epi8(max_biased_i8x32, data_max_i8x32, greater_b8x32);
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, less_b8x32);
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, greater_b8x32);
    }

    // Un-bias to get real u8 values
    __m256i min_u8x32 = _mm256_xor_si256(min_biased_i8x32, bias_u8x32);
    __m256i max_u8x32 = _mm256_xor_si256(max_biased_i8x32, bias_u8x32);
    nk_u8_t min_value = nk_reduce_min_u8x32_haswell_(min_u8x32);
    nk_u8_t max_value = nk_reduce_max_u8x32_haswell_(max_u8x32);
    unsigned int min_lane, max_lane;
    {
        // Among lanes holding the winning value, find the earliest loop cycle;
        // ties resolve to the lowest lane index via tzcnt.
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(min_u8x32, _mm256_set1_epi8((char)min_value));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), min_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
    }
    {
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(max_u8x32, _mm256_set1_epi8((char)max_value));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), max_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
    }
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u8x32;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 32 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u8x32;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 32 + max_lane;
}
|
|
1301
|
+
|
|
1302
|
+
NK_PUBLIC void nk_reduce_minmax_u8_haswell( //
|
|
1303
|
+
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1304
|
+
nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1305
|
+
nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1306
|
+
|
|
1307
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
|
|
1308
|
+
int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
|
|
1309
|
+
if (count == 0)
|
|
1310
|
+
*min_value_ptr = NK_U8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
|
|
1311
|
+
else if (!aligned)
|
|
1312
|
+
nk_reduce_minmax_u8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1313
|
+
max_index_ptr);
|
|
1314
|
+
else if (count > (nk_size_t)(NK_U8_MAX + 1) * 32) {
|
|
1315
|
+
nk_size_t left_count = count / 2;
|
|
1316
|
+
nk_u8_t left_min, right_min, left_max, right_max;
|
|
1317
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
1318
|
+
nk_reduce_minmax_u8_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
1319
|
+
&left_max_index);
|
|
1320
|
+
nk_reduce_minmax_u8_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1321
|
+
&right_min, &right_min_index, &right_max, &right_max_index);
|
|
1322
|
+
if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
1323
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
1324
|
+
if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
1325
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
1326
|
+
}
|
|
1327
|
+
else if (stride_elements == 1)
|
|
1328
|
+
nk_reduce_minmax_u8_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1329
|
+
max_index_ptr);
|
|
1330
|
+
else
|
|
1331
|
+
nk_reduce_minmax_u8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1332
|
+
max_index_ptr);
|
|
1333
|
+
}
|
|
1334
|
+
|
|
1335
|
+
NK_INTERNAL void nk_reduce_moments_i16_haswell_contiguous_( //
    nk_i16_t const *data_ptr, nk_size_t count, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    // Computes sum and sum-of-squares of contiguous i16 data with AVX2.
    // Sum: `madd(data, 1)` adds adjacent i16 pairs into 8 i32 lanes.
    // Sum of squares: `madd(data, data)` gives x0*x0 + x1*x1 per i32 lane,
    // widened to i64 every iteration so the 64-bit accumulator is safe.
    __m256i ones_i16x16 = _mm256_set1_epi16(1);
    __m256i sum_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_i64x4 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m256i data_i16x16 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        // Pair-sum 16 i16 into 8 i32 accumulator lanes.
        // NOTE(review): the i32 lanes are only widened at the end, so the caller
        // must bound `count` to keep this accumulation from overflowing.
        sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(data_i16x16, ones_i16x16));
        __m256i sq_i32x8 = _mm256_madd_epi16(data_i16x16, data_i16x16);
        // Widen the per-iteration squares to i64 before accumulating.
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sq_i32x8)));
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sq_i32x8, 1)));
    }
    // Tail: zero-padded partial load; zero elements add nothing to either moment.
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b16x16_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_i16x16 = tail_vec.ymm;
        sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(data_i16x16, ones_i16x16));
        __m256i sq_i32x8 = _mm256_madd_epi16(data_i16x16, data_i16x16);
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sq_i32x8)));
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sq_i32x8, 1)));
    }
    // Widen the i32 sums to i64 and reduce both accumulators horizontally.
    __m256i sum_i64x4 = _mm256_add_epi64( //
        _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum_i32x8)), //
        _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum_i32x8, 1))); //
    *sum_ptr = nk_reduce_add_i64x4_haswell_(sum_i64x4),
        *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_i64x4);
}
|
|
1366
|
+
|
|
1367
|
+
NK_INTERNAL void nk_reduce_moments_i16_haswell_strided_( //
    nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    // Strided variant: load full 16-lane vectors, then zero out the lanes that
    // are not at a stride multiple, so they contribute nothing to the moments.
    __m256i stride_mask_i16x16 = nk_stride_blend_b16x16_(stride_elements);
    __m256i ones_i16x16 = _mm256_set1_epi16(1);
    __m256i sum_i32x8 = _mm256_setzero_si256();
    __m256i sumsq_i64x4 = _mm256_setzero_si256();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    // Advance by a whole number of strides per vector so the mask stays valid.
    nk_size_t step = nk_size_round_up_to_multiple_(16, stride_elements);
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_i16x16 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx_scalars));
        data_i16x16 = _mm256_and_si256(data_i16x16, stride_mask_i16x16);
        // Pair-sum into 8 i32 lanes; the caller bounds `count` against overflow.
        sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(data_i16x16, ones_i16x16));
        __m256i sq_i32x8 = _mm256_madd_epi16(data_i16x16, data_i16x16);
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sq_i32x8)));
        sumsq_i64x4 = _mm256_add_epi64(sumsq_i64x4, _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sq_i32x8, 1)));
    }
    // Horizontal reduction of the vector accumulators.
    __m256i sum_i64x4 = _mm256_add_epi64( //
        _mm256_cvtepi32_epi64(_mm256_castsi256_si128(sum_i32x8)), //
        _mm256_cvtepi32_epi64(_mm256_extracti128_si256(sum_i32x8, 1))); //
    nk_i64_t sum = nk_reduce_add_i64x4_haswell_(sum_i64x4);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_i64x4);
    // Scalar tail: `idx_scalars` is a multiple of `stride_elements`, so the
    // division recovers the number of elements already consumed.
    nk_i16_t const *ptr = data_ptr + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_i64_t val = (nk_i64_t)*ptr;
        sum += val, sumsq += (nk_u64_t)(val * val);
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
1399
|
+
|
|
1400
|
+
NK_PUBLIC void nk_reduce_moments_i16_haswell( //
|
|
1401
|
+
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1402
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1403
|
+
|
|
1404
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
|
|
1405
|
+
int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
|
|
1406
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1407
|
+
else if (!aligned) nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1408
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
|
|
1409
|
+
nk_size_t left_count = count / 2;
|
|
1410
|
+
nk_i64_t left_sum, right_sum;
|
|
1411
|
+
nk_u64_t left_sumsq, right_sumsq;
|
|
1412
|
+
nk_reduce_moments_i16_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
1413
|
+
nk_reduce_moments_i16_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1414
|
+
&right_sum, &right_sumsq);
|
|
1415
|
+
*sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
|
|
1416
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
|
|
1417
|
+
}
|
|
1418
|
+
else if (stride_elements == 1) nk_reduce_moments_i16_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
1419
|
+
else if (stride_elements <= 8)
|
|
1420
|
+
nk_reduce_moments_i16_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
1421
|
+
else nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1422
|
+
}
|
|
1423
|
+
|
|
1424
|
+
NK_INTERNAL void nk_reduce_minmax_i16_haswell_contiguous_( //
    nk_i16_t const *data_ptr, nk_size_t count, //
    nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // Argmin/argmax over contiguous i16 data. Alongside the running min/max
    // vectors, each lane records the loop cycle (iteration number) at which its
    // current extreme was captured; index = cycle * 16 + lane at the end.
    __m256i min_i16x16 = _mm256_set1_epi16((short)NK_I16_MAX);
    __m256i max_i16x16 = _mm256_set1_epi16(NK_I16_MIN);
    __m256i min_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i one_u16x16 = _mm256_set1_epi16(1);

    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m256i data_i16x16 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        // Strict comparisons keep the earliest occurrence on ties.
        __m256i less_b16x16 = _mm256_cmpgt_epi16(min_i16x16, data_i16x16);
        __m256i greater_b16x16 = _mm256_cmpgt_epi16(data_i16x16, max_i16x16);
        min_i16x16 = _mm256_blendv_epi8(min_i16x16, data_i16x16, less_b16x16);
        max_i16x16 = _mm256_blendv_epi8(max_i16x16, data_i16x16, greater_b16x16);
        // Record the cycle at which each lane's extreme was updated.
        min_loop_cycle_u16x16 = _mm256_blendv_epi8(min_loop_cycle_u16x16, current_loop_cycle_u16x16, less_b16x16);
        max_loop_cycle_u16x16 = _mm256_blendv_epi8(max_loop_cycle_u16x16, current_loop_cycle_u16x16, greater_b16x16);
        current_loop_cycle_u16x16 = _mm256_add_epi16(current_loop_cycle_u16x16, one_u16x16);
    }

    // Tail: blend padded lanes to the reduction identities so they never win.
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b16x16_serial_(data_ptr + idx, &tail_vec, remaining);
        // Build 16-bit lane mask
        __m256i lane_indices_u16x16 = _mm256_setr_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
        __m256i valid_b16x16 = _mm256_cmpgt_epi16(_mm256_set1_epi16((short)remaining), lane_indices_u16x16);
        __m256i data_min_i16x16 = _mm256_blendv_epi8(_mm256_set1_epi16(NK_I16_MAX), tail_vec.ymm, valid_b16x16);
        __m256i data_max_i16x16 = _mm256_blendv_epi8(_mm256_set1_epi16(NK_I16_MIN), tail_vec.ymm, valid_b16x16);
        __m256i less_b16x16 = _mm256_cmpgt_epi16(min_i16x16, data_min_i16x16);
        __m256i greater_b16x16 = _mm256_cmpgt_epi16(data_max_i16x16, max_i16x16);
        min_i16x16 = _mm256_blendv_epi8(min_i16x16, data_min_i16x16, less_b16x16);
        max_i16x16 = _mm256_blendv_epi8(max_i16x16, data_max_i16x16, greater_b16x16);
        min_loop_cycle_u16x16 = _mm256_blendv_epi8(min_loop_cycle_u16x16, current_loop_cycle_u16x16, less_b16x16);
        max_loop_cycle_u16x16 = _mm256_blendv_epi8(max_loop_cycle_u16x16, current_loop_cycle_u16x16, greater_b16x16);
    }

    nk_i16_t min_value = nk_reduce_min_i16x16_haswell_(min_i16x16);
    nk_i16_t max_value = nk_reduce_max_i16x16_haswell_(max_i16x16);
    unsigned int min_lane, max_lane;
    {
        // Among lanes that hold the global min, pick the one with the earliest
        // cycle (non-matching lanes are masked to NK_U16_MAX so they lose).
        __m256i value_match_b16x16 = _mm256_cmpeq_epi16(min_i16x16, _mm256_set1_epi16(min_value));
        __m256i masked_cycle_u16x16 = _mm256_blendv_epi8(_mm256_set1_epi16((short)NK_U16_MAX), min_loop_cycle_u16x16,
                                                         value_match_b16x16);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x16_haswell_(masked_cycle_u16x16);
        __m256i cycle_match_b16x16 = _mm256_cmpeq_epi16(masked_cycle_u16x16,
                                                        _mm256_set1_epi16((short)earliest_loop_cycle));
        // movemask is per byte; divide by 2 to convert byte lane to word lane.
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b16x16)) / 2;
    }
    {
        // Same tie-break for the max.
        __m256i value_match_b16x16 = _mm256_cmpeq_epi16(max_i16x16, _mm256_set1_epi16(max_value));
        __m256i masked_cycle_u16x16 = _mm256_blendv_epi8(_mm256_set1_epi16((short)NK_U16_MAX), max_loop_cycle_u16x16,
                                                         value_match_b16x16);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x16_haswell_(masked_cycle_u16x16);
        __m256i cycle_match_b16x16 = _mm256_cmpeq_epi16(masked_cycle_u16x16,
                                                        _mm256_set1_epi16((short)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b16x16)) / 2;
    }
    // Reconstruct absolute indices: element index = cycle * 16 + lane.
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u16x16;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u16s[min_lane] * 16 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u16x16;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u16s[max_lane] * 16 + max_lane;
}
|
|
1492
|
+
|
|
1493
|
+
NK_PUBLIC void nk_reduce_minmax_i16_haswell( //
|
|
1494
|
+
nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1495
|
+
nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1496
|
+
nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1497
|
+
|
|
1498
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
|
|
1499
|
+
int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
|
|
1500
|
+
if (count == 0)
|
|
1501
|
+
*min_value_ptr = NK_I16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I16_MIN,
|
|
1502
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
1503
|
+
else if (!aligned)
|
|
1504
|
+
nk_reduce_minmax_i16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1505
|
+
max_index_ptr);
|
|
1506
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
|
|
1507
|
+
nk_size_t left_count = count / 2;
|
|
1508
|
+
nk_i16_t left_min, right_min, left_max, right_max;
|
|
1509
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
1510
|
+
nk_reduce_minmax_i16_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
1511
|
+
&left_max_index);
|
|
1512
|
+
nk_reduce_minmax_i16_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1513
|
+
&right_min, &right_min_index, &right_max, &right_max_index);
|
|
1514
|
+
if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
1515
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
1516
|
+
if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
1517
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
1518
|
+
}
|
|
1519
|
+
else if (stride_elements == 1)
|
|
1520
|
+
nk_reduce_minmax_i16_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1521
|
+
max_index_ptr);
|
|
1522
|
+
else
|
|
1523
|
+
nk_reduce_minmax_i16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1524
|
+
max_index_ptr);
|
|
1525
|
+
}
|
|
1526
|
+
|
|
1527
|
+
NK_INTERNAL void nk_reduce_moments_u16_haswell_contiguous_( //
    nk_u16_t const *data_ptr, nk_size_t count, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    // Widen u16→u32, square in u32, widen to u64.
    // Sum stays in u32 lanes until the final reduction; the caller bounds
    // `count` so these lanes cannot overflow. Squares (up to ~2^32) are widened
    // to u64 every iteration.
    __m256i sum_u32x8 = _mm256_setzero_si256();
    __m256i sumsq_u64x4 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        // Load 8 x u16 and zero-extend to 8 x u32.
        __m256i data_u32x8 = _mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const *)(data_ptr + idx)));
        sum_u32x8 = _mm256_add_epi32(sum_u32x8, data_u32x8);
        __m256i sq_u32x8 = _mm256_mullo_epi32(data_u32x8, data_u32x8);
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sq_u32x8)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sq_u32x8, 1)));
    }
    // Tail: zero-padded partial load; only the low 8 lanes (remaining < 8) are
    // used, and padding zeros contribute nothing.
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b16x16_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_u32x8 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(tail_vec.ymm));
        sum_u32x8 = _mm256_add_epi32(sum_u32x8, data_u32x8);
        __m256i sq_u32x8 = _mm256_mullo_epi32(data_u32x8, data_u32x8);
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sq_u32x8)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sq_u32x8, 1)));
    }
    // Widen the u32 sums to u64 and reduce both accumulators horizontally.
    __m256i sum_u64x4 = _mm256_add_epi64( //
        _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum_u32x8)), //
        _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum_u32x8, 1))); //
    *sum_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4),
        *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);
}
|
|
1558
|
+
|
|
1559
|
+
NK_INTERNAL void nk_reduce_moments_u16_haswell_strided_( //
    nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    // Strided variant: load 16 consecutive scalars, zero the lanes that are not
    // at a stride multiple, then widen and accumulate as in the contiguous path.
    __m256i stride_mask_i16x16 = nk_stride_blend_b16x16_(stride_elements);
    __m256i sum_u32x8 = _mm256_setzero_si256();
    __m256i sumsq_u64x4 = _mm256_setzero_si256();
    nk_size_t idx_scalars = 0;
    nk_size_t total_scalars = count * stride_elements;
    // Advance by a whole number of strides per vector so the mask stays valid.
    nk_size_t step = nk_size_round_up_to_multiple_(16, stride_elements);
    for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
        __m256i data_u16x16 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx_scalars));
        data_u16x16 = _mm256_and_si256(data_u16x16, stride_mask_i16x16);
        // Zero-extend both 128-bit halves to u32 lanes.
        __m256i lo_u32x8 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(data_u16x16));
        __m256i hi_u32x8 = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(data_u16x16, 1));
        sum_u32x8 = _mm256_add_epi32(sum_u32x8, lo_u32x8);
        sum_u32x8 = _mm256_add_epi32(sum_u32x8, hi_u32x8);
        // Square in u32, widen each half to u64 before accumulating.
        __m256i lo_sq_u32x8 = _mm256_mullo_epi32(lo_u32x8, lo_u32x8);
        __m256i hi_sq_u32x8 = _mm256_mullo_epi32(hi_u32x8, hi_u32x8);
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(lo_sq_u32x8)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(lo_sq_u32x8, 1)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(hi_sq_u32x8)));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(hi_sq_u32x8, 1)));
    }
    // Horizontal reduction of the vector accumulators.
    __m256i sum_u64x4 = _mm256_add_epi64( //
        _mm256_cvtepu32_epi64(_mm256_castsi256_si128(sum_u32x8)), //
        _mm256_cvtepu32_epi64(_mm256_extracti128_si256(sum_u32x8, 1))); //
    nk_u64_t sum = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);
    // Scalar tail: `idx_scalars` is a multiple of `stride_elements`, so the
    // division recovers the number of elements already consumed.
    nk_u16_t const *ptr = data_ptr + idx_scalars;
    nk_size_t remaining = count - idx_scalars / stride_elements;
    for (nk_size_t i = 0; i < remaining; ++i, ptr += stride_elements) {
        nk_u64_t val = (nk_u64_t)*ptr;
        sum += val, sumsq += val * val;
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
1596
|
+
|
|
1597
|
+
NK_PUBLIC void nk_reduce_moments_u16_haswell( //
|
|
1598
|
+
nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1599
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1600
|
+
|
|
1601
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
|
|
1602
|
+
int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
|
|
1603
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1604
|
+
else if (!aligned) nk_reduce_moments_u16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1605
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
|
|
1606
|
+
nk_size_t left_count = count / 2;
|
|
1607
|
+
nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
1608
|
+
nk_reduce_moments_u16_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
1609
|
+
nk_reduce_moments_u16_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1610
|
+
&right_sum, &right_sumsq);
|
|
1611
|
+
*sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
|
|
1612
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
|
|
1613
|
+
}
|
|
1614
|
+
else if (stride_elements == 1) nk_reduce_moments_u16_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
1615
|
+
else if (stride_elements <= 8)
|
|
1616
|
+
nk_reduce_moments_u16_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
1617
|
+
else nk_reduce_moments_u16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1618
|
+
}
|
|
1619
|
+
|
|
1620
|
+
NK_INTERNAL void nk_reduce_minmax_u16_haswell_contiguous_( //
    nk_u16_t const *data_ptr, nk_size_t count, //
    nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // XOR-bias to signed domain for _mm256_cmpgt_epi16
    // (AVX2 has no unsigned 16-bit compare; x ^ 0x8000 maps unsigned order onto
    // signed order). Each lane also records the loop cycle at which its extreme
    // was captured, so absolute indices can be reconstructed at the end.
    __m256i bias_u16x16 = _mm256_set1_epi16((short)0x8000);
    __m256i min_biased_i16x16 = _mm256_set1_epi16((short)NK_I16_MAX);
    __m256i max_biased_i16x16 = _mm256_set1_epi16(NK_I16_MIN);
    __m256i min_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i one_u16x16 = _mm256_set1_epi16(1);

    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m256i data_biased_i16x16 = _mm256_xor_si256(_mm256_loadu_si256((__m256i const *)(data_ptr + idx)),
                                                      bias_u16x16);
        // Strict comparisons keep the earliest occurrence on ties.
        __m256i less_b16x16 = _mm256_cmpgt_epi16(min_biased_i16x16, data_biased_i16x16);
        __m256i greater_b16x16 = _mm256_cmpgt_epi16(data_biased_i16x16, max_biased_i16x16);
        min_biased_i16x16 = _mm256_blendv_epi8(min_biased_i16x16, data_biased_i16x16, less_b16x16);
        max_biased_i16x16 = _mm256_blendv_epi8(max_biased_i16x16, data_biased_i16x16, greater_b16x16);
        min_loop_cycle_u16x16 = _mm256_blendv_epi8(min_loop_cycle_u16x16, current_loop_cycle_u16x16, less_b16x16);
        max_loop_cycle_u16x16 = _mm256_blendv_epi8(max_loop_cycle_u16x16, current_loop_cycle_u16x16, greater_b16x16);
        current_loop_cycle_u16x16 = _mm256_add_epi16(current_loop_cycle_u16x16, one_u16x16);
    }

    // Tail: blend padded lanes to the (biased) reduction identities.
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b16x16_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_biased_i16x16 = _mm256_xor_si256(tail_vec.ymm, bias_u16x16);
        __m256i lane_indices_u16x16 = _mm256_setr_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
        __m256i valid_b16x16 = _mm256_cmpgt_epi16(_mm256_set1_epi16((short)remaining), lane_indices_u16x16);
        // Biased identity: NK_U16_MAX ^ 0x8000 = 0x7FFF for min, 0 ^ 0x8000 = 0x8000 for max
        __m256i data_min_i16x16 = _mm256_blendv_epi8(_mm256_set1_epi16(0x7FFF), data_biased_i16x16, valid_b16x16);
        __m256i data_max_i16x16 = _mm256_blendv_epi8(_mm256_set1_epi16((short)0x8000), data_biased_i16x16,
                                                     valid_b16x16);
        __m256i less_b16x16 = _mm256_cmpgt_epi16(min_biased_i16x16, data_min_i16x16);
        __m256i greater_b16x16 = _mm256_cmpgt_epi16(data_max_i16x16, max_biased_i16x16);
        min_biased_i16x16 = _mm256_blendv_epi8(min_biased_i16x16, data_min_i16x16, less_b16x16);
        max_biased_i16x16 = _mm256_blendv_epi8(max_biased_i16x16, data_max_i16x16, greater_b16x16);
        min_loop_cycle_u16x16 = _mm256_blendv_epi8(min_loop_cycle_u16x16, current_loop_cycle_u16x16, less_b16x16);
        max_loop_cycle_u16x16 = _mm256_blendv_epi8(max_loop_cycle_u16x16, current_loop_cycle_u16x16, greater_b16x16);
    }

    // Un-bias back to real u16 values before the horizontal reductions.
    __m256i min_u16x16 = _mm256_xor_si256(min_biased_i16x16, bias_u16x16);
    __m256i max_u16x16 = _mm256_xor_si256(max_biased_i16x16, bias_u16x16);
    nk_u16_t min_value = nk_reduce_min_u16x16_haswell_(min_u16x16);
    nk_u16_t max_value = nk_reduce_max_u16x16_haswell_(max_u16x16);
    unsigned int min_lane, max_lane;
    {
        // Among lanes holding the global min, pick the earliest cycle
        // (non-matching lanes are masked to NK_U16_MAX so they lose).
        __m256i value_match_b16x16 = _mm256_cmpeq_epi16(min_u16x16, _mm256_set1_epi16((short)min_value));
        __m256i masked_cycle_u16x16 = _mm256_blendv_epi8(_mm256_set1_epi16((short)NK_U16_MAX), min_loop_cycle_u16x16,
                                                         value_match_b16x16);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x16_haswell_(masked_cycle_u16x16);
        __m256i cycle_match_b16x16 = _mm256_cmpeq_epi16(masked_cycle_u16x16,
                                                        _mm256_set1_epi16((short)earliest_loop_cycle));
        // movemask is per byte; divide by 2 to convert byte lane to word lane.
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b16x16)) / 2;
    }
    {
        // Same tie-break for the max.
        __m256i value_match_b16x16 = _mm256_cmpeq_epi16(max_u16x16, _mm256_set1_epi16((short)max_value));
        __m256i masked_cycle_u16x16 = _mm256_blendv_epi8(_mm256_set1_epi16((short)NK_U16_MAX), max_loop_cycle_u16x16,
                                                         value_match_b16x16);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x16_haswell_(masked_cycle_u16x16);
        __m256i cycle_match_b16x16 = _mm256_cmpeq_epi16(masked_cycle_u16x16,
                                                        _mm256_set1_epi16((short)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b16x16)) / 2;
    }
    // Reconstruct absolute indices: element index = cycle * 16 + lane.
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u16x16;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u16s[min_lane] * 16 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u16x16;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u16s[max_lane] * 16 + max_lane;
}
|
|
1695
|
+
|
|
1696
|
+
NK_PUBLIC void nk_reduce_minmax_u16_haswell( //
|
|
1697
|
+
nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1698
|
+
nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1699
|
+
nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1700
|
+
|
|
1701
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
|
|
1702
|
+
int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
|
|
1703
|
+
if (count == 0)
|
|
1704
|
+
*min_value_ptr = NK_U16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
|
|
1705
|
+
else if (!aligned)
|
|
1706
|
+
nk_reduce_minmax_u16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1707
|
+
max_index_ptr);
|
|
1708
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
|
|
1709
|
+
nk_size_t left_count = count / 2;
|
|
1710
|
+
nk_u16_t left_min, right_min, left_max, right_max;
|
|
1711
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
1712
|
+
nk_reduce_minmax_u16_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
1713
|
+
&left_max_index);
|
|
1714
|
+
nk_reduce_minmax_u16_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1715
|
+
&right_min, &right_min_index, &right_max, &right_max_index);
|
|
1716
|
+
if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
1717
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
1718
|
+
if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
1719
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
1720
|
+
}
|
|
1721
|
+
else if (stride_elements == 1)
|
|
1722
|
+
nk_reduce_minmax_u16_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1723
|
+
max_index_ptr);
|
|
1724
|
+
else
|
|
1725
|
+
nk_reduce_minmax_u16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1726
|
+
max_index_ptr);
|
|
1727
|
+
}
|
|
1728
|
+
|
|
1729
|
+
// Computes sum and sum-of-squares over a contiguous i32 buffer with AVX2.
// The running sum is kept as a 128-bit signed integer per lane (lower u64 +
// upper i64 carry word) so it can never wrap; the sum of squares is kept as
// u64 lanes with a sticky per-lane overflow mask that saturates the final
// result to NK_U64_MAX if any lane ever wrapped.
NK_INTERNAL void nk_reduce_moments_i32_haswell_contiguous_( //
    nk_i32_t const *data_ptr, nk_size_t count,             //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    __m256i sum_lower_i64x4 = _mm256_setzero_si256();
    __m256i sum_upper_i64x4 = _mm256_setzero_si256();
    __m256i sumsq_u64x4 = _mm256_setzero_si256();
    int sumsq_overflow_mask = 0;
    // AVX2 has no unsigned 64-bit compare: XOR-ing both operands with the
    // sign bit maps unsigned order onto signed order for `_mm256_cmpgt_epi64`.
    __m256i sign_bit_i64x4 = _mm256_set1_epi64x((nk_i64_t)0x8000000000000000ULL);
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m256i data_i32x8 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        // 128-bit sum: lo half
        __m256i widened_lo_i64x4 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(data_i32x8));
        __m256i sum_before_i64x4 = sum_lower_i64x4;
        sum_lower_i64x4 = _mm256_add_epi64(sum_lower_i64x4, widened_lo_i64x4);
        // Unsigned carry out of the lower word: result < before (unsigned).
        __m256i result_biased_i64x4 = _mm256_xor_si256(sum_lower_i64x4, sign_bit_i64x4);
        __m256i before_biased_i64x4 = _mm256_xor_si256(sum_before_i64x4, sign_bit_i64x4);
        __m256i carry_mask_i64x4 = _mm256_cmpgt_epi64(before_biased_i64x4, result_biased_i64x4);
        // Mask lanes are all-ones (-1), so subtracting the mask adds +1 carry.
        sum_upper_i64x4 = _mm256_sub_epi64(sum_upper_i64x4, carry_mask_i64x4);
        // Sign-extend the addend into the upper word: cmpgt(0, v) yields -1
        // for negative lanes, which is exactly the high word of the widened value.
        __m256i sign_ext_i64x4 = _mm256_cmpgt_epi64(_mm256_setzero_si256(), widened_lo_i64x4);
        sum_upper_i64x4 = _mm256_add_epi64(sum_upper_i64x4, sign_ext_i64x4);
        // 128-bit sum: hi half
        __m256i widened_hi_i64x4 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(data_i32x8, 1));
        sum_before_i64x4 = sum_lower_i64x4;
        sum_lower_i64x4 = _mm256_add_epi64(sum_lower_i64x4, widened_hi_i64x4);
        result_biased_i64x4 = _mm256_xor_si256(sum_lower_i64x4, sign_bit_i64x4);
        before_biased_i64x4 = _mm256_xor_si256(sum_before_i64x4, sign_bit_i64x4);
        carry_mask_i64x4 = _mm256_cmpgt_epi64(before_biased_i64x4, result_biased_i64x4);
        sum_upper_i64x4 = _mm256_sub_epi64(sum_upper_i64x4, carry_mask_i64x4);
        sign_ext_i64x4 = _mm256_cmpgt_epi64(_mm256_setzero_si256(), widened_hi_i64x4);
        sum_upper_i64x4 = _mm256_add_epi64(sum_upper_i64x4, sign_ext_i64x4);
        // Sumsq: running mask + wrapping add with unsigned carry detection.
        // `_mm256_mul_epi32` multiplies the signed low 32 bits of each 64-bit
        // lane, so it squares the even elements; shifting right by 32 brings
        // the odd elements into the low position for the second multiply.
        __m256i even_sq_u64x4 = _mm256_mul_epi32(data_i32x8, data_i32x8);
        __m256i odd_i32x8 = _mm256_srli_epi64(data_i32x8, 32);
        __m256i odd_sq_u64x4 = _mm256_mul_epi32(odd_i32x8, odd_i32x8);
        __m256i sumsq_before_u64x4 = sumsq_u64x4;
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, even_sq_u64x4);
        __m256i sq_result_biased_u64x4 = _mm256_xor_si256(sumsq_u64x4, sign_bit_i64x4);
        __m256i sq_before_biased_u64x4 = _mm256_xor_si256(sumsq_before_u64x4, sign_bit_i64x4);
        sumsq_overflow_mask |= _mm256_movemask_pd(
            _mm256_castsi256_pd(_mm256_cmpgt_epi64(sq_before_biased_u64x4, sq_result_biased_u64x4)));
        sumsq_before_u64x4 = sumsq_u64x4;
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, odd_sq_u64x4);
        sq_result_biased_u64x4 = _mm256_xor_si256(sumsq_u64x4, sign_bit_i64x4);
        sq_before_biased_u64x4 = _mm256_xor_si256(sumsq_before_u64x4, sign_bit_i64x4);
        sumsq_overflow_mask |= _mm256_movemask_pd(
            _mm256_castsi256_pd(_mm256_cmpgt_epi64(sq_before_biased_u64x4, sq_result_biased_u64x4)));
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Tail: zero-padded partial load, then the same accumulation steps.
        // Padding with zeros contributes nothing to sum or sumsq.
        nk_b256_vec_t tail_vec;
        nk_partial_load_b32x8_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_i32x8 = tail_vec.ymm;
        __m256i widened_lo_i64x4 = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(data_i32x8));
        __m256i sum_before_i64x4 = sum_lower_i64x4;
        sum_lower_i64x4 = _mm256_add_epi64(sum_lower_i64x4, widened_lo_i64x4);
        __m256i result_biased_i64x4 = _mm256_xor_si256(sum_lower_i64x4, sign_bit_i64x4);
        __m256i before_biased_i64x4 = _mm256_xor_si256(sum_before_i64x4, sign_bit_i64x4);
        __m256i carry_mask_i64x4 = _mm256_cmpgt_epi64(before_biased_i64x4, result_biased_i64x4);
        sum_upper_i64x4 = _mm256_sub_epi64(sum_upper_i64x4, carry_mask_i64x4);
        __m256i sign_ext_i64x4 = _mm256_cmpgt_epi64(_mm256_setzero_si256(), widened_lo_i64x4);
        sum_upper_i64x4 = _mm256_add_epi64(sum_upper_i64x4, sign_ext_i64x4);
        __m256i widened_hi_i64x4 = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(data_i32x8, 1));
        sum_before_i64x4 = sum_lower_i64x4;
        sum_lower_i64x4 = _mm256_add_epi64(sum_lower_i64x4, widened_hi_i64x4);
        result_biased_i64x4 = _mm256_xor_si256(sum_lower_i64x4, sign_bit_i64x4);
        before_biased_i64x4 = _mm256_xor_si256(sum_before_i64x4, sign_bit_i64x4);
        carry_mask_i64x4 = _mm256_cmpgt_epi64(before_biased_i64x4, result_biased_i64x4);
        sum_upper_i64x4 = _mm256_sub_epi64(sum_upper_i64x4, carry_mask_i64x4);
        sign_ext_i64x4 = _mm256_cmpgt_epi64(_mm256_setzero_si256(), widened_hi_i64x4);
        sum_upper_i64x4 = _mm256_add_epi64(sum_upper_i64x4, sign_ext_i64x4);
        __m256i even_sq_u64x4 = _mm256_mul_epi32(data_i32x8, data_i32x8);
        __m256i odd_i32x8 = _mm256_srli_epi64(data_i32x8, 32);
        __m256i odd_sq_u64x4 = _mm256_mul_epi32(odd_i32x8, odd_i32x8);
        __m256i sumsq_before_u64x4 = sumsq_u64x4;
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, even_sq_u64x4);
        __m256i sq_result_biased_u64x4 = _mm256_xor_si256(sumsq_u64x4, sign_bit_i64x4);
        __m256i sq_before_biased_u64x4 = _mm256_xor_si256(sumsq_before_u64x4, sign_bit_i64x4);
        sumsq_overflow_mask |= _mm256_movemask_pd(
            _mm256_castsi256_pd(_mm256_cmpgt_epi64(sq_before_biased_u64x4, sq_result_biased_u64x4)));
        sumsq_before_u64x4 = sumsq_u64x4;
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, odd_sq_u64x4);
        sq_result_biased_u64x4 = _mm256_xor_si256(sumsq_u64x4, sign_bit_i64x4);
        sq_before_biased_u64x4 = _mm256_xor_si256(sumsq_before_u64x4, sign_bit_i64x4);
        sumsq_overflow_mask |= _mm256_movemask_pd(
            _mm256_castsi256_pd(_mm256_cmpgt_epi64(sq_before_biased_u64x4, sq_result_biased_u64x4)));
    }
    // Sumsq: horizontal unsigned saturating reduction; any lane overflow
    // observed during the loop saturates the whole result.
    nk_u64_t sumsq;
    if (sumsq_overflow_mask) sumsq = NK_U64_MAX;
    else sumsq = nk_reduce_sadd_u64x4_haswell_(sumsq_u64x4);
    // Sum: horizontal 128-bit reduction (4 lanes → scalar), propagating
    // carries from the lower words into the upper word.
    nk_b256_vec_t lower_vec, upper_vec;
    lower_vec.ymm = sum_lower_i64x4;
    upper_vec.ymm = sum_upper_i64x4;
    nk_u64_t sum_lower = 0;
    nk_i64_t sum_upper = 0;
    for (int i = 0; i < 4; i++) {
        nk_u64_t sum_before = sum_lower;
        sum_lower += lower_vec.u64s[i];
        if (sum_lower < sum_before) sum_upper++;
        sum_upper += upper_vec.i64s[i];
    }
    *sumsq_ptr = sumsq;
    // If the upper word is a pure sign extension of the lower word, the
    // 128-bit sum fits in i64; otherwise saturate toward the sign of upper.
    nk_i64_t sum_lower_signed = (nk_i64_t)sum_lower;
    if (sum_upper == (sum_lower_signed >> 63)) *sum_ptr = sum_lower_signed;
    else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
    else *sum_ptr = NK_I64_MIN;
}
|
|
1839
|
+
|
|
1840
|
+
NK_PUBLIC void nk_reduce_moments_i32_haswell( //
|
|
1841
|
+
nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1842
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1843
|
+
|
|
1844
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
|
|
1845
|
+
int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
|
|
1846
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1847
|
+
else if (!aligned) nk_reduce_moments_i32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1848
|
+
else if (stride_elements == 1) nk_reduce_moments_i32_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
1849
|
+
else nk_reduce_moments_i32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1850
|
+
}
|
|
1851
|
+
|
|
1852
|
+
// Finds the minimum and maximum of a contiguous i32 buffer, along with the
// index of the FIRST occurrence of each. Instead of tracking per-lane element
// indices, each lane records the "loop cycle" (iteration number) at which its
// extremum was last improved; the final index is cycle * 8 + lane. The 32-bit
// cycle counter limits this kernel to count <= NK_U32_MAX * 8 — the public
// dispatcher splits larger inputs before calling here.
NK_INTERNAL void nk_reduce_minmax_i32_haswell_contiguous_( //
    nk_i32_t const *data_ptr, nk_size_t count,             //
    nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {

    __m256i min_i32x8 = _mm256_set1_epi32(NK_I32_MAX);
    __m256i max_i32x8 = _mm256_set1_epi32(NK_I32_MIN);
    __m256i min_loop_cycle_u32x8 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u32x8 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u32x8 = _mm256_setzero_si256();
    __m256i one_u32x8 = _mm256_set1_epi32(1);

    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m256i data_i32x8 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        // Strict compares keep the FIRST occurrence on ties.
        __m256i less_b32x8 = _mm256_cmpgt_epi32(min_i32x8, data_i32x8);
        __m256i greater_b32x8 = _mm256_cmpgt_epi32(data_i32x8, max_i32x8);
        min_i32x8 = _mm256_blendv_epi8(min_i32x8, data_i32x8, less_b32x8);
        max_i32x8 = _mm256_blendv_epi8(max_i32x8, data_i32x8, greater_b32x8);
        // Record the cycle at which each lane's extremum changed.
        min_loop_cycle_u32x8 = _mm256_blendv_epi8(min_loop_cycle_u32x8, current_loop_cycle_u32x8, less_b32x8);
        max_loop_cycle_u32x8 = _mm256_blendv_epi8(max_loop_cycle_u32x8, current_loop_cycle_u32x8, greater_b32x8);
        current_loop_cycle_u32x8 = _mm256_add_epi32(current_loop_cycle_u32x8, one_u32x8);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Tail: partial load, then neutralize invalid lanes with the identity
        // element for each comparison (MAX for min, MIN for max).
        nk_b256_vec_t tail_vec;
        nk_partial_load_b32x8_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i lane_indices_u32x8 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        __m256i valid_b32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32((int)remaining), lane_indices_u32x8);
        __m256i data_min_i32x8 = _mm256_blendv_epi8(_mm256_set1_epi32(NK_I32_MAX), tail_vec.ymm, valid_b32x8);
        __m256i data_max_i32x8 = _mm256_blendv_epi8(_mm256_set1_epi32(NK_I32_MIN), tail_vec.ymm, valid_b32x8);
        __m256i less_b32x8 = _mm256_cmpgt_epi32(min_i32x8, data_min_i32x8);
        __m256i greater_b32x8 = _mm256_cmpgt_epi32(data_max_i32x8, max_i32x8);
        min_i32x8 = _mm256_blendv_epi8(min_i32x8, data_min_i32x8, less_b32x8);
        max_i32x8 = _mm256_blendv_epi8(max_i32x8, data_max_i32x8, greater_b32x8);
        // current_loop_cycle was already advanced past the last full block,
        // so the tail's elements live at index cycle * 8 + lane as usual.
        min_loop_cycle_u32x8 = _mm256_blendv_epi8(min_loop_cycle_u32x8, current_loop_cycle_u32x8, less_b32x8);
        max_loop_cycle_u32x8 = _mm256_blendv_epi8(max_loop_cycle_u32x8, current_loop_cycle_u32x8, greater_b32x8);
    }

    nk_i32_t min_value = nk_reduce_min_i32x8_haswell_(min_i32x8);
    nk_i32_t max_value = nk_reduce_max_i32x8_haswell_(max_i32x8);
    unsigned int min_lane, max_lane;
    {
        // Tie-break across lanes holding the winning min: mask out losers,
        // pick the earliest cycle, then the lowest matching lane via tzcnt.
        __m256i value_match_b32x8 = _mm256_cmpeq_epi32(min_i32x8, _mm256_set1_epi32(min_value));
        __m256i masked_cycle_u32x8 = _mm256_blendv_epi8(_mm256_set1_epi32((int)NK_U32_MAX), min_loop_cycle_u32x8,
                                                        value_match_b32x8);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x8_haswell_(masked_cycle_u32x8);
        __m256i cycle_match_b32x8 = _mm256_cmpeq_epi32(masked_cycle_u32x8, _mm256_set1_epi32((int)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_ps(_mm256_castsi256_ps(cycle_match_b32x8)));
    }
    {
        // Same tie-break for the max.
        __m256i value_match_b32x8 = _mm256_cmpeq_epi32(max_i32x8, _mm256_set1_epi32(max_value));
        __m256i masked_cycle_u32x8 = _mm256_blendv_epi8(_mm256_set1_epi32((int)NK_U32_MAX), max_loop_cycle_u32x8,
                                                        value_match_b32x8);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x8_haswell_(masked_cycle_u32x8);
        __m256i cycle_match_b32x8 = _mm256_cmpeq_epi32(masked_cycle_u32x8, _mm256_set1_epi32((int)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_ps(_mm256_castsi256_ps(cycle_match_b32x8)));
    }
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u32x8;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u32s[min_lane] * 8 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u32x8;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u32s[max_lane] * 8 + max_lane;
}
|
|
1917
|
+
|
|
1918
|
+
NK_PUBLIC void nk_reduce_minmax_i32_haswell( //
|
|
1919
|
+
nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1920
|
+
nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
1921
|
+
nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
1922
|
+
|
|
1923
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
|
|
1924
|
+
int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
|
|
1925
|
+
if (count == 0)
|
|
1926
|
+
*min_value_ptr = NK_I32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I32_MIN,
|
|
1927
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
1928
|
+
else if (!aligned)
|
|
1929
|
+
nk_reduce_minmax_i32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1930
|
+
max_index_ptr);
|
|
1931
|
+
else if (count > (nk_size_t)NK_U32_MAX * 8) {
|
|
1932
|
+
nk_size_t left_count = count / 2;
|
|
1933
|
+
nk_i32_t left_min, right_min, left_max, right_max;
|
|
1934
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
1935
|
+
nk_reduce_minmax_i32_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
1936
|
+
&left_max_index);
|
|
1937
|
+
nk_reduce_minmax_i32_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1938
|
+
&right_min, &right_min_index, &right_max, &right_max_index);
|
|
1939
|
+
if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
1940
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
1941
|
+
if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
1942
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
1943
|
+
}
|
|
1944
|
+
else if (stride_elements == 1)
|
|
1945
|
+
nk_reduce_minmax_i32_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1946
|
+
max_index_ptr);
|
|
1947
|
+
else
|
|
1948
|
+
nk_reduce_minmax_i32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
1949
|
+
max_index_ptr);
|
|
1950
|
+
}
|
|
1951
|
+
|
|
1952
|
+
NK_INTERNAL void nk_reduce_moments_u32_haswell_contiguous_( //
|
|
1953
|
+
nk_u32_t const *data_ptr, nk_size_t count, //
|
|
1954
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1955
|
+
|
|
1956
|
+
__m256i sum_u64x4 = _mm256_setzero_si256();
|
|
1957
|
+
__m256i sumsq_u64x4 = _mm256_setzero_si256();
|
|
1958
|
+
nk_size_t idx = 0;
|
|
1959
|
+
for (; idx + 8 <= count; idx += 8) {
|
|
1960
|
+
__m256i data_u32x8 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
|
|
1961
|
+
sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(data_u32x8)));
|
|
1962
|
+
sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(data_u32x8, 1)));
|
|
1963
|
+
__m256i even_sq_u64x4 = _mm256_mul_epu32(data_u32x8, data_u32x8);
|
|
1964
|
+
__m256i odd_u32x8 = _mm256_srli_epi64(data_u32x8, 32);
|
|
1965
|
+
__m256i odd_sq_u64x4 = _mm256_mul_epu32(odd_u32x8, odd_u32x8);
|
|
1966
|
+
sumsq_u64x4 = nk_u64_sadd_epi64_haswell_(sumsq_u64x4, even_sq_u64x4);
|
|
1967
|
+
sumsq_u64x4 = nk_u64_sadd_epi64_haswell_(sumsq_u64x4, odd_sq_u64x4);
|
|
1968
|
+
}
|
|
1969
|
+
nk_size_t remaining = count - idx;
|
|
1970
|
+
if (remaining > 0) {
|
|
1971
|
+
nk_b256_vec_t tail_vec;
|
|
1972
|
+
nk_partial_load_b32x8_serial_(data_ptr + idx, &tail_vec, remaining);
|
|
1973
|
+
__m256i data_u32x8 = tail_vec.ymm;
|
|
1974
|
+
sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_cvtepu32_epi64(_mm256_castsi256_si128(data_u32x8)));
|
|
1975
|
+
sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_cvtepu32_epi64(_mm256_extracti128_si256(data_u32x8, 1)));
|
|
1976
|
+
__m256i even_sq_u64x4 = _mm256_mul_epu32(data_u32x8, data_u32x8);
|
|
1977
|
+
__m256i odd_u32x8 = _mm256_srli_epi64(data_u32x8, 32);
|
|
1978
|
+
__m256i odd_sq_u64x4 = _mm256_mul_epu32(odd_u32x8, odd_u32x8);
|
|
1979
|
+
sumsq_u64x4 = nk_u64_sadd_epi64_haswell_(sumsq_u64x4, even_sq_u64x4);
|
|
1980
|
+
sumsq_u64x4 = nk_u64_sadd_epi64_haswell_(sumsq_u64x4, odd_sq_u64x4);
|
|
1981
|
+
}
|
|
1982
|
+
*sum_ptr = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4),
|
|
1983
|
+
*sumsq_ptr = nk_reduce_sadd_u64x4_haswell_(sumsq_u64x4);
|
|
1984
|
+
}
|
|
1985
|
+
|
|
1986
|
+
NK_PUBLIC void nk_reduce_moments_u32_haswell( //
|
|
1987
|
+
nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
1988
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1989
|
+
|
|
1990
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
|
|
1991
|
+
int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
|
|
1992
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
1993
|
+
else if (!aligned) nk_reduce_moments_u32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
1994
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
|
|
1995
|
+
nk_size_t left_count = count / 2;
|
|
1996
|
+
nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
1997
|
+
nk_reduce_moments_u32_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
1998
|
+
nk_reduce_moments_u32_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
1999
|
+
&right_sum, &right_sumsq);
|
|
2000
|
+
*sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
|
|
2001
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
|
|
2002
|
+
}
|
|
2003
|
+
else if (stride_elements == 1) nk_reduce_moments_u32_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
2004
|
+
else nk_reduce_moments_u32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2005
|
+
}
|
|
2006
|
+
|
|
2007
|
+
// Finds min/max of a contiguous u32 buffer with first-occurrence indices.
// Same loop-cycle tracking scheme as the i32 variant; since AVX2 lacks an
// unsigned 32-bit compare, all values are XOR-biased by 0x80000000 so that
// `_mm256_cmpgt_epi32` implements the unsigned order. The 32-bit cycle
// counter limits this kernel to count <= NK_U32_MAX * 8 — the public
// dispatcher splits larger inputs first.
NK_INTERNAL void nk_reduce_minmax_u32_haswell_contiguous_( //
    nk_u32_t const *data_ptr, nk_size_t count,             //
    nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // XOR-bias to signed domain for _mm256_cmpgt_epi32
    __m256i bias_u32x8 = _mm256_set1_epi32((nk_i32_t)0x80000000);
    // Biased identities: NK_U32_MAX biases to NK_I32_MAX, 0 biases to NK_I32_MIN.
    __m256i min_biased_i32x8 = _mm256_set1_epi32(NK_I32_MAX);
    __m256i max_biased_i32x8 = _mm256_set1_epi32(NK_I32_MIN);
    __m256i min_loop_cycle_u32x8 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u32x8 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u32x8 = _mm256_setzero_si256();
    __m256i one_u32x8 = _mm256_set1_epi32(1);

    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        __m256i data_biased_i32x8 = _mm256_xor_si256(_mm256_loadu_si256((__m256i const *)(data_ptr + idx)), bias_u32x8);
        // Strict compares keep the FIRST occurrence on ties.
        __m256i less_b32x8 = _mm256_cmpgt_epi32(min_biased_i32x8, data_biased_i32x8);
        __m256i greater_b32x8 = _mm256_cmpgt_epi32(data_biased_i32x8, max_biased_i32x8);
        min_biased_i32x8 = _mm256_blendv_epi8(min_biased_i32x8, data_biased_i32x8, less_b32x8);
        max_biased_i32x8 = _mm256_blendv_epi8(max_biased_i32x8, data_biased_i32x8, greater_b32x8);
        // Record the cycle at which each lane's extremum changed.
        min_loop_cycle_u32x8 = _mm256_blendv_epi8(min_loop_cycle_u32x8, current_loop_cycle_u32x8, less_b32x8);
        max_loop_cycle_u32x8 = _mm256_blendv_epi8(max_loop_cycle_u32x8, current_loop_cycle_u32x8, greater_b32x8);
        current_loop_cycle_u32x8 = _mm256_add_epi32(current_loop_cycle_u32x8, one_u32x8);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Tail: partial load, bias, then neutralize invalid lanes with the
        // biased identity element for each comparison.
        nk_b256_vec_t tail_vec;
        nk_partial_load_b32x8_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_biased_i32x8 = _mm256_xor_si256(tail_vec.ymm, bias_u32x8);
        __m256i lane_indices_u32x8 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        __m256i valid_b32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32((int)remaining), lane_indices_u32x8);
        // Biased identity: NK_U32_MAX ^ 0x80000000 = 0x7FFFFFFF for min, 0 ^ 0x80000000 = 0x80000000 for max
        __m256i data_min_i32x8 = _mm256_blendv_epi8(_mm256_set1_epi32(0x7FFFFFFF), data_biased_i32x8, valid_b32x8);
        __m256i data_max_i32x8 = _mm256_blendv_epi8(_mm256_set1_epi32((nk_i32_t)0x80000000), data_biased_i32x8,
                                                    valid_b32x8);
        __m256i less_b32x8 = _mm256_cmpgt_epi32(min_biased_i32x8, data_min_i32x8);
        __m256i greater_b32x8 = _mm256_cmpgt_epi32(data_max_i32x8, max_biased_i32x8);
        min_biased_i32x8 = _mm256_blendv_epi8(min_biased_i32x8, data_min_i32x8, less_b32x8);
        max_biased_i32x8 = _mm256_blendv_epi8(max_biased_i32x8, data_max_i32x8, greater_b32x8);
        min_loop_cycle_u32x8 = _mm256_blendv_epi8(min_loop_cycle_u32x8, current_loop_cycle_u32x8, less_b32x8);
        max_loop_cycle_u32x8 = _mm256_blendv_epi8(max_loop_cycle_u32x8, current_loop_cycle_u32x8, greater_b32x8);
    }

    // Undo the bias before the horizontal reductions.
    __m256i min_u32x8 = _mm256_xor_si256(min_biased_i32x8, bias_u32x8);
    __m256i max_u32x8 = _mm256_xor_si256(max_biased_i32x8, bias_u32x8);
    nk_u32_t min_value = nk_reduce_min_u32x8_haswell_(min_u32x8);
    nk_u32_t max_value = nk_reduce_max_u32x8_haswell_(max_u32x8);
    unsigned int min_lane, max_lane;
    {
        // Tie-break: earliest cycle among lanes holding the winning min,
        // then the lowest matching lane via tzcnt.
        __m256i value_match_b32x8 = _mm256_cmpeq_epi32(min_u32x8, _mm256_set1_epi32((nk_i32_t)min_value));
        __m256i masked_cycle_u32x8 = _mm256_blendv_epi8(_mm256_set1_epi32((int)NK_U32_MAX), min_loop_cycle_u32x8,
                                                        value_match_b32x8);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x8_haswell_(masked_cycle_u32x8);
        __m256i cycle_match_b32x8 = _mm256_cmpeq_epi32(masked_cycle_u32x8, _mm256_set1_epi32((int)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_ps(_mm256_castsi256_ps(cycle_match_b32x8)));
    }
    {
        // Same tie-break for the max.
        __m256i value_match_b32x8 = _mm256_cmpeq_epi32(max_u32x8, _mm256_set1_epi32((nk_i32_t)max_value));
        __m256i masked_cycle_u32x8 = _mm256_blendv_epi8(_mm256_set1_epi32((int)NK_U32_MAX), max_loop_cycle_u32x8,
                                                        value_match_b32x8);
        nk_u32_t earliest_loop_cycle = nk_reduce_min_u32x8_haswell_(masked_cycle_u32x8);
        __m256i cycle_match_b32x8 = _mm256_cmpeq_epi32(masked_cycle_u32x8, _mm256_set1_epi32((int)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_ps(_mm256_castsi256_ps(cycle_match_b32x8)));
    }
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u32x8;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u32s[min_lane] * 8 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u32x8;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u32s[max_lane] * 8 + max_lane;
}
|
|
2079
|
+
|
|
2080
|
+
NK_PUBLIC void nk_reduce_minmax_u32_haswell( //
|
|
2081
|
+
nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2082
|
+
nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2083
|
+
nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2084
|
+
|
|
2085
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
|
|
2086
|
+
int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
|
|
2087
|
+
if (count == 0)
|
|
2088
|
+
*min_value_ptr = NK_U32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
|
|
2089
|
+
else if (!aligned)
|
|
2090
|
+
nk_reduce_minmax_u32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2091
|
+
max_index_ptr);
|
|
2092
|
+
else if (count > (nk_size_t)NK_U32_MAX * 8) {
|
|
2093
|
+
nk_size_t left_count = count / 2;
|
|
2094
|
+
nk_u32_t left_min, right_min, left_max, right_max;
|
|
2095
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
2096
|
+
nk_reduce_minmax_u32_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
2097
|
+
&left_max_index);
|
|
2098
|
+
nk_reduce_minmax_u32_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
2099
|
+
&right_min, &right_min_index, &right_max, &right_max_index);
|
|
2100
|
+
if (right_min < left_min) *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
2101
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
2102
|
+
if (right_max > left_max) *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
2103
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
2104
|
+
}
|
|
2105
|
+
else if (stride_elements == 1)
|
|
2106
|
+
nk_reduce_minmax_u32_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2107
|
+
max_index_ptr);
|
|
2108
|
+
else
|
|
2109
|
+
nk_reduce_minmax_u32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2110
|
+
max_index_ptr);
|
|
2111
|
+
}
|
|
2112
|
+
|
|
2113
|
+
// Computes sum and sum-of-squares over a contiguous i64 buffer with AVX2.
// The sum is accumulated as a 128-bit signed integer per lane (lower u64 +
// upper i64 carry word) and saturated to i64 at the end; sum-of-squares
// saturates to NK_U64_MAX on any lane overflow. The tail (< 4 elements) is
// handled with scalar saturating helpers.
NK_INTERNAL void nk_reduce_moments_i64_haswell_contiguous_( //
    nk_i64_t const *data_ptr, nk_size_t count,              //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    __m256i sum_lower_u64x4 = _mm256_setzero_si256();
    __m256i sum_upper_i64x4 = _mm256_setzero_si256();
    __m256i sumsq_u64x4 = _mm256_setzero_si256();
    int sumsq_overflow_mask = 0;
    // AVX2 lacks unsigned 64-bit compares; XOR-ing with the sign bit maps
    // unsigned order onto the signed `_mm256_cmpgt_epi64`.
    __m256i sign_bit_i64x4 = _mm256_set1_epi64x((nk_i64_t)0x8000000000000000ULL);
    nk_size_t idx = 0;
    for (; idx + 4 <= count; idx += 4) {
        __m256i data_i64x4 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        // Saturating vector square of each i64 lane (project helper).
        __m256i squared_u64x4 = nk_i64_smul_sq_epi64_haswell_(data_i64x4);
        __m256i sumsq_before_u64x4 = sumsq_u64x4;
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, squared_u64x4);
        // Sticky overflow detection: unsigned wrap means result < before.
        __m256i sq_result_biased_u64x4 = _mm256_xor_si256(sumsq_u64x4, sign_bit_i64x4);
        __m256i sq_before_biased_u64x4 = _mm256_xor_si256(sumsq_before_u64x4, sign_bit_i64x4);
        sumsq_overflow_mask |= _mm256_movemask_pd(
            _mm256_castsi256_pd(_mm256_cmpgt_epi64(sq_before_biased_u64x4, sq_result_biased_u64x4)));
        // Vectorized 128-bit carry-propagating sum
        __m256i sum_before_u64x4 = sum_lower_u64x4;
        sum_lower_u64x4 = _mm256_add_epi64(sum_lower_u64x4, data_i64x4);
        __m256i before_biased_u64x4 = _mm256_xor_si256(sum_before_u64x4, sign_bit_i64x4);
        __m256i result_biased_u64x4 = _mm256_xor_si256(sum_lower_u64x4, sign_bit_i64x4);
        __m256i carry_u64x4 = _mm256_cmpgt_epi64(before_biased_u64x4, result_biased_u64x4);
        // Carry mask lanes are -1, so subtracting adds +1 to the upper word.
        sum_upper_i64x4 = _mm256_sub_epi64(sum_upper_i64x4, carry_u64x4);
        // Adding a negative i64 into a 128-bit accumulator requires adding
        // -1 (its sign extension) into the upper word.
        __m256i sign_ext_i64x4 = _mm256_cmpgt_epi64(_mm256_setzero_si256(), data_i64x4);
        sum_upper_i64x4 = _mm256_add_epi64(sum_upper_i64x4, sign_ext_i64x4);
    }
    // Horizontal reduction of 4 lanes to scalar (sum_lower, sum_upper)
    nk_b256_vec_t lower_vec, upper_vec;
    lower_vec.ymm = sum_lower_u64x4;
    upper_vec.ymm = sum_upper_i64x4;
    nk_u64_t sum_lower = 0;
    nk_i64_t sum_upper = 0;
    for (int i = 0; i < 4; i++) {
        nk_u64_t before = sum_lower;
        sum_lower += lower_vec.u64s[i];
        if (sum_lower < before) sum_upper++;
        sum_upper += upper_vec.i64s[i];
    }
    nk_u64_t sumsq;
    if (sumsq_overflow_mask) sumsq = NK_U64_MAX;
    else sumsq = nk_reduce_sadd_u64x4_haswell_(sumsq_u64x4);
    // Scalar tail: saturating square + add for sumsq, 128-bit add for sum.
    for (; idx < count; ++idx) {
        nk_i64_t val = data_ptr[idx];
        nk_i64_t product = nk_i64_saturating_mul_serial(val, val);
        nk_u64_t unsigned_product = (nk_u64_t)product;
        sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
        nk_u64_t before = sum_lower;
        sum_lower += (nk_u64_t)val;
        if (sum_lower < before) sum_upper++;
        // val >> 63 is the sign extension (0 or -1) of the addend.
        sum_upper += (val >> 63);
    }
    *sumsq_ptr = sumsq;
    // If upper is the pure sign extension of lower, the 128-bit sum fits in
    // an i64; otherwise saturate toward the sign of the upper word.
    nk_i64_t sum_lower_signed = (nk_i64_t)sum_lower;
    if (sum_upper == (sum_lower_signed >> 63)) *sum_ptr = sum_lower_signed;
    else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
    else *sum_ptr = NK_I64_MIN;
}
|
|
2173
|
+
|
|
2174
|
+
NK_PUBLIC void nk_reduce_moments_i64_haswell( //
|
|
2175
|
+
nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2176
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
2177
|
+
|
|
2178
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
|
|
2179
|
+
int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
|
|
2180
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
2181
|
+
else if (!aligned) nk_reduce_moments_i64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2182
|
+
else if (stride_elements == 1) nk_reduce_moments_i64_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
2183
|
+
else nk_reduce_moments_i64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2184
|
+
}
|
|
2185
|
+
|
|
2186
|
+
// Finds min/max of a contiguous i64 buffer with first-occurrence indices.
// Same loop-cycle scheme as the 32-bit variants, but with 4 lanes of 64-bit
// cycles (final index = cycle * 4 + lane); the u64 counter cannot overflow
// for any addressable count.
NK_INTERNAL void nk_reduce_minmax_i64_haswell_contiguous_( //
    nk_i64_t const *data_ptr, nk_size_t count,             //
    nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {

    __m256i min_i64x4 = _mm256_set1_epi64x(NK_I64_MAX);
    __m256i max_i64x4 = _mm256_set1_epi64x(NK_I64_MIN);
    __m256i min_loop_cycle_u64x4 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u64x4 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u64x4 = _mm256_setzero_si256();
    __m256i one_u64x4 = _mm256_set1_epi64x(1);

    nk_size_t idx = 0;
    for (; idx + 4 <= count; idx += 4) {
        __m256i data_i64x4 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        // Strict compares keep the FIRST occurrence on ties.
        __m256i less_b64x4 = _mm256_cmpgt_epi64(min_i64x4, data_i64x4);
        __m256i greater_b64x4 = _mm256_cmpgt_epi64(data_i64x4, max_i64x4);
        min_i64x4 = _mm256_blendv_epi8(min_i64x4, data_i64x4, less_b64x4);
        max_i64x4 = _mm256_blendv_epi8(max_i64x4, data_i64x4, greater_b64x4);
        // Record the cycle at which each lane's extremum changed.
        min_loop_cycle_u64x4 = _mm256_blendv_epi8(min_loop_cycle_u64x4, current_loop_cycle_u64x4, less_b64x4);
        max_loop_cycle_u64x4 = _mm256_blendv_epi8(max_loop_cycle_u64x4, current_loop_cycle_u64x4, greater_b64x4);
        current_loop_cycle_u64x4 = _mm256_add_epi64(current_loop_cycle_u64x4, one_u64x4);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Tail: partial load, then neutralize invalid lanes with the identity
        // element for each comparison (MAX for min, MIN for max).
        nk_b256_vec_t tail_vec;
        nk_partial_load_b64x4_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i lane_indices_u64x4 = _mm256_setr_epi64x(0, 1, 2, 3);
        __m256i valid_b64x4 = _mm256_cmpgt_epi64(_mm256_set1_epi64x((long long)remaining), lane_indices_u64x4);
        __m256i data_min_i64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x(NK_I64_MAX), tail_vec.ymm, valid_b64x4);
        __m256i data_max_i64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x(NK_I64_MIN), tail_vec.ymm, valid_b64x4);
        __m256i less_b64x4 = _mm256_cmpgt_epi64(min_i64x4, data_min_i64x4);
        __m256i greater_b64x4 = _mm256_cmpgt_epi64(data_max_i64x4, max_i64x4);
        min_i64x4 = _mm256_blendv_epi8(min_i64x4, data_min_i64x4, less_b64x4);
        max_i64x4 = _mm256_blendv_epi8(max_i64x4, data_max_i64x4, greater_b64x4);
        min_loop_cycle_u64x4 = _mm256_blendv_epi8(min_loop_cycle_u64x4, current_loop_cycle_u64x4, less_b64x4);
        max_loop_cycle_u64x4 = _mm256_blendv_epi8(max_loop_cycle_u64x4, current_loop_cycle_u64x4, greater_b64x4);
    }

    nk_i64_t min_value = nk_reduce_min_i64x4_haswell_(min_i64x4);
    nk_i64_t max_value = nk_reduce_max_i64x4_haswell_(max_i64x4);
    unsigned int min_lane, max_lane;
    {
        // Tie-break: among lanes holding the winning min, pick the earliest
        // cycle, then the lowest matching lane via tzcnt of the lane mask.
        __m256i value_match_b64x4 = _mm256_cmpeq_epi64(min_i64x4, _mm256_set1_epi64x(min_value));
        __m256i masked_cycle_u64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x((nk_i64_t)NK_U64_MAX), min_loop_cycle_u64x4,
                                                        value_match_b64x4);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x4_haswell_(masked_cycle_u64x4);
        __m256i cycle_match_b64x4 = _mm256_cmpeq_epi64(masked_cycle_u64x4,
                                                       _mm256_set1_epi64x((nk_i64_t)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_pd(_mm256_castsi256_pd(cycle_match_b64x4)));
    }
    {
        // Same tie-break for the max.
        __m256i value_match_b64x4 = _mm256_cmpeq_epi64(max_i64x4, _mm256_set1_epi64x(max_value));
        __m256i masked_cycle_u64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x((nk_i64_t)NK_U64_MAX), max_loop_cycle_u64x4,
                                                        value_match_b64x4);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x4_haswell_(masked_cycle_u64x4);
        __m256i cycle_match_b64x4 = _mm256_cmpeq_epi64(masked_cycle_u64x4,
                                                       _mm256_set1_epi64x((nk_i64_t)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_pd(_mm256_castsi256_pd(cycle_match_b64x4)));
    }
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u64x4;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u64s[min_lane] * 4 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u64x4;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u64s[max_lane] * 4 + max_lane;
}
|
|
2253
|
+
|
|
2254
|
+
NK_PUBLIC void nk_reduce_minmax_i64_haswell( //
|
|
2255
|
+
nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2256
|
+
nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2257
|
+
nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2258
|
+
|
|
2259
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
|
|
2260
|
+
int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
|
|
2261
|
+
if (count == 0)
|
|
2262
|
+
*min_value_ptr = NK_I64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I64_MIN,
|
|
2263
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
2264
|
+
else if (!aligned)
|
|
2265
|
+
nk_reduce_minmax_i64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2266
|
+
max_index_ptr);
|
|
2267
|
+
else if (stride_elements == 1)
|
|
2268
|
+
nk_reduce_minmax_i64_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2269
|
+
max_index_ptr);
|
|
2270
|
+
else
|
|
2271
|
+
nk_reduce_minmax_i64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2272
|
+
max_index_ptr);
|
|
2273
|
+
}
|
|
2274
|
+
|
|
2275
|
+
// Sum and sum-of-squares over a contiguous u64 buffer using AVX2.
// Per the helper names (`sadd`/`saturating_*`), both accumulators saturate
// instead of wrapping on overflow — the helpers themselves are defined
// elsewhere; confirm exact clamping semantics there.
NK_INTERNAL void nk_reduce_moments_u64_haswell_contiguous_( //
    nk_u64_t const *data_ptr, nk_size_t count, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_u64x4 = _mm256_setzero_si256();
    nk_size_t idx = 0;
    // Main loop: four u64 lanes per iteration; square first, then fold the
    // products and the raw values into their saturating accumulators.
    for (; idx + 4 <= count; idx += 4) {
        __m256i data_u64x4 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        sumsq_u64x4 = nk_u64_sadd_epi64_haswell_(sumsq_u64x4, nk_u64_smul_sq_epi64_haswell_(data_u64x4));
        sum_u64x4 = nk_u64_sadd_epi64_haswell_(sum_u64x4, data_u64x4);
    }
    // Collapse the vector accumulators to scalars, then handle the up-to-3
    // trailing elements with the scalar saturating helpers.
    nk_u64_t sum = nk_reduce_sadd_u64x4_haswell_(sum_u64x4);
    nk_u64_t sumsq = nk_reduce_sadd_u64x4_haswell_(sumsq_u64x4);
    for (; idx < count; ++idx) {
        nk_u64_t val = data_ptr[idx];
        nk_u64_t product = nk_u64_saturating_mul_serial(val, val);
        sum = nk_u64_saturating_add_serial(sum, val);
        sumsq = nk_u64_saturating_add_serial(sumsq, product);
    }
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
2297
|
+
|
|
2298
|
+
NK_PUBLIC void nk_reduce_moments_u64_haswell( //
    nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    // Moments of an empty array are zero.
    if (count == 0) {
        *sum_ptr = 0;
        *sumsq_ptr = 0;
        return;
    }
    // Strides that are not a whole number of elements can't be vectorized.
    if (stride_bytes % sizeof(nk_u64_t) != 0) {
        nk_reduce_moments_u64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
        return;
    }
    // Dense layout gets the AVX2 kernel; any other stride goes scalar.
    nk_size_t elems_stride = stride_bytes / sizeof(nk_u64_t);
    if (elems_stride == 1) nk_reduce_moments_u64_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_u64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}
|
|
2309
|
+
|
|
2310
|
+
// Min/max with first-occurrence indices over a contiguous u64 buffer.
// AVX2 has no unsigned 64-bit compare, so values are XOR-ed with
// 0x8000000000000000 to map unsigned order onto the signed order that
// `_mm256_cmpgt_epi64` implements. For each of the four lanes we remember
// the loop cycle at which that lane's running extreme last improved; the
// final element index is reconstructed as cycle * 4 + lane.
NK_INTERNAL void nk_reduce_minmax_u64_haswell_contiguous_( //
    nk_u64_t const *data_ptr, nk_size_t count, //
    nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // XOR-bias to signed domain for _mm256_cmpgt_epi64
    __m256i bias_u64x4 = _mm256_set1_epi64x((nk_i64_t)0x8000000000000000ull);
    // Biased identities: NK_I64_MAX/NK_I64_MIN correspond to unsigned
    // NK_U64_MAX (min identity) and 0 (max identity).
    __m256i min_biased_i64x4 = _mm256_set1_epi64x(NK_I64_MAX);
    __m256i max_biased_i64x4 = _mm256_set1_epi64x(NK_I64_MIN);
    __m256i min_loop_cycle_u64x4 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u64x4 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u64x4 = _mm256_setzero_si256();
    __m256i one_u64x4 = _mm256_set1_epi64x(1);

    nk_size_t idx = 0;
    for (; idx + 4 <= count; idx += 4) {
        __m256i data_biased_i64x4 = _mm256_xor_si256(_mm256_loadu_si256((__m256i const *)(data_ptr + idx)), bias_u64x4);
        // Strict compares: only a strictly better value updates a lane, so the
        // recorded cycle is that of the FIRST occurrence of the extreme.
        __m256i less_b64x4 = _mm256_cmpgt_epi64(min_biased_i64x4, data_biased_i64x4);
        __m256i greater_b64x4 = _mm256_cmpgt_epi64(data_biased_i64x4, max_biased_i64x4);
        min_biased_i64x4 = _mm256_blendv_epi8(min_biased_i64x4, data_biased_i64x4, less_b64x4);
        max_biased_i64x4 = _mm256_blendv_epi8(max_biased_i64x4, data_biased_i64x4, greater_b64x4);
        min_loop_cycle_u64x4 = _mm256_blendv_epi8(min_loop_cycle_u64x4, current_loop_cycle_u64x4, less_b64x4);
        max_loop_cycle_u64x4 = _mm256_blendv_epi8(max_loop_cycle_u64x4, current_loop_cycle_u64x4, greater_b64x4);
        current_loop_cycle_u64x4 = _mm256_add_epi64(current_loop_cycle_u64x4, one_u64x4);
    }

    // Tail: load the <4 leftover elements, then neutralize the invalid lanes
    // with the min/max identities so they can never win a comparison.
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b64x4_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_biased_i64x4 = _mm256_xor_si256(tail_vec.ymm, bias_u64x4);
        __m256i lane_indices_u64x4 = _mm256_setr_epi64x(0, 1, 2, 3);
        __m256i valid_b64x4 = _mm256_cmpgt_epi64(_mm256_set1_epi64x((long long)remaining), lane_indices_u64x4);
        // Biased identity: NK_U64_MAX ^ bias = 0x7FFF... for min, 0 ^ bias = 0x8000... for max
        __m256i data_min_i64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x(NK_I64_MAX), data_biased_i64x4, valid_b64x4);
        __m256i data_max_i64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x(NK_I64_MIN), data_biased_i64x4, valid_b64x4);
        __m256i less_b64x4 = _mm256_cmpgt_epi64(min_biased_i64x4, data_min_i64x4);
        __m256i greater_b64x4 = _mm256_cmpgt_epi64(data_max_i64x4, max_biased_i64x4);
        min_biased_i64x4 = _mm256_blendv_epi8(min_biased_i64x4, data_min_i64x4, less_b64x4);
        max_biased_i64x4 = _mm256_blendv_epi8(max_biased_i64x4, data_max_i64x4, greater_b64x4);
        min_loop_cycle_u64x4 = _mm256_blendv_epi8(min_loop_cycle_u64x4, current_loop_cycle_u64x4, less_b64x4);
        max_loop_cycle_u64x4 = _mm256_blendv_epi8(max_loop_cycle_u64x4, current_loop_cycle_u64x4, greater_b64x4);
    }

    // Un-bias back to the unsigned domain before the horizontal reductions.
    __m256i min_u64x4 = _mm256_xor_si256(min_biased_i64x4, bias_u64x4);
    __m256i max_u64x4 = _mm256_xor_si256(max_biased_i64x4, bias_u64x4);
    nk_u64_t min_value = nk_reduce_min_u64x4_haswell_(min_u64x4);
    nk_u64_t max_value = nk_reduce_max_u64x4_haswell_(max_u64x4);
    unsigned int min_lane, max_lane;
    // Among the lanes holding the winning value, pick the one whose extreme
    // appeared in the earliest loop cycle, breaking ties by lowest lane.
    {
        __m256i value_match_b64x4 = _mm256_cmpeq_epi64(min_u64x4, _mm256_set1_epi64x((nk_i64_t)min_value));
        __m256i masked_cycle_u64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x((nk_i64_t)NK_U64_MAX), min_loop_cycle_u64x4,
                                                        value_match_b64x4);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x4_haswell_(masked_cycle_u64x4);
        __m256i cycle_match_b64x4 = _mm256_cmpeq_epi64(masked_cycle_u64x4,
                                                       _mm256_set1_epi64x((nk_i64_t)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_pd(_mm256_castsi256_pd(cycle_match_b64x4)));
    }
    {
        __m256i value_match_b64x4 = _mm256_cmpeq_epi64(max_u64x4, _mm256_set1_epi64x((nk_i64_t)max_value));
        __m256i masked_cycle_u64x4 = _mm256_blendv_epi8(_mm256_set1_epi64x((nk_i64_t)NK_U64_MAX), max_loop_cycle_u64x4,
                                                        value_match_b64x4);
        nk_u64_t earliest_loop_cycle = nk_reduce_min_u64x4_haswell_(masked_cycle_u64x4);
        __m256i cycle_match_b64x4 = _mm256_cmpeq_epi64(masked_cycle_u64x4,
                                                       _mm256_set1_epi64x((nk_i64_t)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_pd(_mm256_castsi256_pd(cycle_match_b64x4)));
    }
    // Element index = winning cycle * 4 lanes + lane offset.
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u64x4;
    *min_value_ptr = min_value, *min_index_ptr = (nk_size_t)loop_cycle_vec.u64s[min_lane] * 4 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u64x4;
    *max_value_ptr = max_value, *max_index_ptr = (nk_size_t)loop_cycle_vec.u64s[max_lane] * 4 + max_lane;
}
|
|
2383
|
+
|
|
2384
|
+
NK_PUBLIC void nk_reduce_minmax_u64_haswell( //
|
|
2385
|
+
nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2386
|
+
nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
2387
|
+
nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
2388
|
+
|
|
2389
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u64_t);
|
|
2390
|
+
int aligned = (stride_bytes % sizeof(nk_u64_t) == 0);
|
|
2391
|
+
if (count == 0)
|
|
2392
|
+
*min_value_ptr = NK_U64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
|
|
2393
|
+
else if (!aligned)
|
|
2394
|
+
nk_reduce_minmax_u64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2395
|
+
max_index_ptr);
|
|
2396
|
+
else if (stride_elements == 1)
|
|
2397
|
+
nk_reduce_minmax_u64_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2398
|
+
max_index_ptr);
|
|
2399
|
+
else
|
|
2400
|
+
nk_reduce_minmax_u64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
2401
|
+
max_index_ptr);
|
|
2402
|
+
}
|
|
2403
|
+
|
|
2404
|
+
// Sum and sum-of-squares of contiguous e4m3 (FP8) data, accumulated in f32.
// Eight packed bytes are widened to eight f32 lanes per iteration; the tail
// goes through a masked partial load (presumably filling invalid lanes with
// zeros, the additive identity — confirm against the partial-load helper).
NK_INTERNAL void nk_reduce_moments_e4m3_haswell_contiguous_( //
    nk_e4m3_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {

    __m256 sum_f32x8 = _mm256_setzero_ps();
    __m256 sumsq_f32x8 = _mm256_setzero_ps();
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        // Load 8 bytes into the low half of an XMM register and upconvert.
        __m256 data_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(data_ptr + idx)));
        sum_f32x8 = _mm256_add_ps(sum_f32x8, data_f32x8);
        // Fused multiply-add: sumsq += data * data in one instruction.
        sumsq_f32x8 = _mm256_fmadd_ps(data_f32x8, data_f32x8, sumsq_f32x8);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t vec;
        nk_partial_load_e4m3x8_to_f32x8_haswell_(data_ptr + idx, &vec, remaining);
        sum_f32x8 = _mm256_add_ps(sum_f32x8, vec.ymm_ps);
        sumsq_f32x8 = _mm256_fmadd_ps(vec.ymm_ps, vec.ymm_ps, sumsq_f32x8);
    }
    // Horizontal reduction of the two 8-lane accumulators.
    *sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
}
|
|
2425
|
+
|
|
2426
|
+
// Strided variant of the e4m3 moments kernel: gathers 8 strided bytes into a
// 64-bit scratch buffer, then reuses the same widen-and-FMA pipeline as the
// contiguous path. The dispatcher only routes strides 2..8 here, keeping the
// scalar gather loop short.
NK_INTERNAL void nk_reduce_moments_e4m3_haswell_strided_( //
    nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {

    __m256 sum_f32x8 = _mm256_setzero_ps();
    __m256 sumsq_f32x8 = _mm256_setzero_ps();
    nk_size_t idx = 0;
    nk_size_t elems_per_chunk = 8;
    for (; idx + elems_per_chunk <= count; idx += elems_per_chunk) {
        // Gather 8 strided FP8 bytes into one u64, then widen to 8 x f32.
        nk_b64_vec_t buf;
        buf.u64 = 0;
        nk_e4m3_t const *ptr = data_ptr + idx * stride_elements;
        for (nk_size_t i = 0; i < elems_per_chunk; ++i) buf.u8s[i] = ptr[i * stride_elements];
        __m256 data_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_cvtsi64_si128(buf.u64));
        sum_f32x8 = _mm256_add_ps(sum_f32x8, data_f32x8);
        sumsq_f32x8 = _mm256_fmadd_ps(data_f32x8, data_f32x8, sumsq_f32x8);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Same gather for the <8 leftovers; the zero-initialized buffer plus
        // the masked partial load keeps invalid lanes out of the sums.
        nk_b64_vec_t buf;
        buf.u64 = 0;
        nk_e4m3_t const *ptr = data_ptr + idx * stride_elements;
        for (nk_size_t i = 0; i < remaining; ++i) buf.u8s[i] = ptr[i * stride_elements];
        nk_b256_vec_t vec;
        nk_partial_load_e4m3x8_to_f32x8_haswell_((nk_e4m3_t *)&buf, &vec, remaining);
        sum_f32x8 = _mm256_add_ps(sum_f32x8, vec.ymm_ps);
        sumsq_f32x8 = _mm256_fmadd_ps(vec.ymm_ps, vec.ymm_ps, sumsq_f32x8);
    }
    *sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
}
|
|
2456
|
+
|
|
2457
|
+
NK_PUBLIC void nk_reduce_moments_e4m3_haswell( //
|
|
2458
|
+
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2459
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2460
|
+
|
|
2461
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
|
|
2462
|
+
int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
|
|
2463
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
2464
|
+
else if (!aligned) nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2465
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
|
|
2466
|
+
nk_size_t left_count = count / 2;
|
|
2467
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
2468
|
+
nk_reduce_moments_e4m3_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
2469
|
+
nk_reduce_moments_e4m3_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
2470
|
+
&right_sum, &right_sumsq);
|
|
2471
|
+
*sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
|
|
2472
|
+
}
|
|
2473
|
+
else if (stride_elements == 1) nk_reduce_moments_e4m3_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
2474
|
+
else if (stride_elements >= 2 && stride_elements <= 8)
|
|
2475
|
+
nk_reduce_moments_e4m3_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
2476
|
+
else nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2477
|
+
}
|
|
2478
|
+
|
|
2479
|
+
// NaN-aware min/max with first-occurrence indices over contiguous e4m3 data.
// Bytes are first mapped to a monotonically "comparable" unsigned domain so
// plain _mm256_min_epu8/_mm256_max_epu8 order FP8 values correctly. NaN
// lanes are replaced with sentinel identities (0xFF for min, 0x00 for max)
// so they never win. Per lane we record the loop cycle of the last
// improvement; index = cycle * 32 + lane. The 8-bit cycle counter limits
// this kernel to (NK_U8_MAX + 1) * 32 elements — the public dispatcher
// splits larger inputs before calling here.
NK_INTERNAL void nk_reduce_minmax_e4m3_haswell_contiguous_( //
    nk_e4m3_t const *data_ptr, nk_size_t count, //
    nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // E4M3 NaN: comparable 0x00 (neg NaN 0xFF) and 0xFF (pos NaN 0x7F).
    // Replace NaN lanes with neutral values: 0xFF for min, 0x00 for max.
    nk_b256_vec_t min_vec, max_vec;
    min_vec.ymm = _mm256_set1_epi8((char)0xFF);
    max_vec.ymm = _mm256_setzero_si256();
    __m256i min_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i one_u8x32 = _mm256_set1_epi8(1);
    __m256i all_ones_u8x32 = _mm256_set1_epi8((char)0xFF);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        __m256i data_cmp_u8x32 = nk_fp8x32_to_u8x32_comparable_haswell_(data_i8x32);
        // Detect NaN: comparable == 0x00 (neg NaN) or comparable == 0xFF (pos NaN)
        __m256i is_nan_low_u8x32 = _mm256_cmpeq_epi8(data_cmp_u8x32, _mm256_setzero_si256());
        __m256i is_nan_high_u8x32 = _mm256_cmpeq_epi8(data_cmp_u8x32, all_ones_u8x32);
        __m256i is_nan_u8x32 = _mm256_or_si256(is_nan_low_u8x32, is_nan_high_u8x32);
        // Replace NaN with neutral values
        __m256i data_min_u8x32 = _mm256_blendv_epi8(data_cmp_u8x32, all_ones_u8x32, is_nan_u8x32);
        __m256i data_max_u8x32 = _mm256_blendv_epi8(data_cmp_u8x32, _mm256_setzero_si256(), is_nan_u8x32);
        // "changed" = NOT(new == old): only a strict improvement records the
        // current cycle, preserving first-occurrence semantics.
        __m256i new_min_u8x32 = _mm256_min_epu8(min_vec.ymm, data_min_u8x32);
        __m256i min_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_min_u8x32, min_vec.ymm), all_ones_u8x32);
        min_vec.ymm = new_min_u8x32;
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, min_changed_i8x32);
        __m256i new_max_u8x32 = _mm256_max_epu8(max_vec.ymm, data_max_u8x32);
        __m256i max_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_max_u8x32, max_vec.ymm), all_ones_u8x32);
        max_vec.ymm = new_max_u8x32;
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, max_changed_i8x32);
        current_loop_cycle_u8x32 = _mm256_add_epi8(current_loop_cycle_u8x32, one_u8x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        tail_vec.ymm = _mm256_set1_epi8((char)0xFF);
        nk_partial_load_b8x32_serial_(data_ptr + idx, &tail_vec, remaining);
        for (nk_size_t i = remaining; i < 32; ++i) tail_vec.u8s[i] = 0xFF;
        __m256i data_cmp_u8x32 = nk_fp8x32_to_u8x32_comparable_haswell_(tail_vec.ymm);
        // Detect NaN in valid lanes, then combine with invalid-lane neutralization
        __m256i is_nan_low_u8x32 = _mm256_cmpeq_epi8(data_cmp_u8x32, _mm256_setzero_si256());
        __m256i is_nan_high_u8x32 = _mm256_cmpeq_epi8(data_cmp_u8x32, all_ones_u8x32);
        __m256i is_nan_u8x32 = _mm256_or_si256(is_nan_low_u8x32, is_nan_high_u8x32);
        __m256i data_min_u8x32 = _mm256_blendv_epi8(data_cmp_u8x32, all_ones_u8x32, is_nan_u8x32);
        __m256i data_max_u8x32 = _mm256_blendv_epi8(data_cmp_u8x32, _mm256_setzero_si256(), is_nan_u8x32);
        // Fill invalid lanes with neutral values
        nk_b256_vec_t min_cmp_vec, max_cmp_vec;
        min_cmp_vec.ymm = data_min_u8x32;
        max_cmp_vec.ymm = data_max_u8x32;
        for (nk_size_t i = remaining; i < 32; ++i) min_cmp_vec.u8s[i] = 0xFF, max_cmp_vec.u8s[i] = 0x00;
        __m256i new_min_u8x32 = _mm256_min_epu8(min_vec.ymm, min_cmp_vec.ymm);
        __m256i min_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_min_u8x32, min_vec.ymm), all_ones_u8x32);
        min_vec.ymm = new_min_u8x32;
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, min_changed_i8x32);
        __m256i new_max_u8x32 = _mm256_max_epu8(max_vec.ymm, max_cmp_vec.ymm);
        __m256i max_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_max_u8x32, max_vec.ymm), all_ones_u8x32);
        max_vec.ymm = new_max_u8x32;
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, max_changed_i8x32);
    }

    nk_u8_t min_value_comparable = nk_reduce_min_u8x32_haswell_(min_vec.ymm);
    nk_u8_t max_value_comparable = nk_reduce_max_u8x32_haswell_(max_vec.ymm);

    // All-NaN early return: both sentinels unchanged means no valid data was found
    if (min_value_comparable == 0xFF && max_value_comparable == 0x00) {
        *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX;
        *max_value_ptr = NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Among lanes holding the winning comparable value, choose the one with
    // the earliest loop cycle (ties broken by lowest lane via tzcnt).
    if (min_value_comparable == 0xFF) { *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX; }
    else {
        unsigned int min_lane;
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(min_vec.ymm, _mm256_set1_epi8((char)min_value_comparable));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), min_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
        nk_b256_vec_t loop_cycle_vec;
        loop_cycle_vec.ymm = min_loop_cycle_u8x32;
        *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 32 + min_lane;
        // Convert the comparable byte back to its raw e4m3 encoding.
        nk_b256_vec_t min_raw_vec;
        min_raw_vec.ymm = nk_u8x32_comparable_to_fp8x32_haswell_(min_vec.ymm);
        *min_value_ptr = min_raw_vec.e4m3s[min_lane];
    }
    if (max_value_comparable == 0x00) { *max_value_ptr = NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX; }
    else {
        unsigned int max_lane;
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(max_vec.ymm, _mm256_set1_epi8((char)max_value_comparable));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), max_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
        nk_b256_vec_t loop_cycle_vec;
        loop_cycle_vec.ymm = max_loop_cycle_u8x32;
        *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 32 + max_lane;
        nk_b256_vec_t max_raw_vec;
        max_raw_vec.ymm = nk_u8x32_comparable_to_fp8x32_haswell_(max_vec.ymm);
        *max_value_ptr = max_raw_vec.e4m3s[max_lane];
    }
}
|
|
2588
|
+
|
|
2589
|
+
// Public e4m3 min/max dispatcher. Inputs longer than (NK_U8_MAX + 1) * 32
// are split in half recursively because the contiguous kernel tracks loop
// cycles in 8-bit lanes (its counter advances with _mm256_add_epi8 once per
// 32-element cycle and would wrap beyond 256 cycles). The halves' results
// are then merged NaN-aware: an NK_SIZE_MAX index marks an all-NaN half.
// No byte-alignment check is needed here since sizeof(nk_e4m3_t) is 1 —
// any stride is a whole number of elements.
NK_PUBLIC void nk_reduce_minmax_e4m3_haswell( //
    nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {

    nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
    if (count == 0)
        // Empty input: identity extrema plus sentinel indices.
        *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E4M3_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (stride_elements == 1 && count > (nk_size_t)(NK_U8_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_e4m3_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_e4m3_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                      &left_max_index);
        nk_reduce_minmax_e4m3_haswell(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
                                      &right_min_index, &right_max, &right_max_index);
        // Prefer the side that found valid data (NK_SIZE_MAX means all-NaN)
        // Ties (order == 0) keep the left side, preserving first occurrence.
        if (left_min_index == NK_SIZE_MAX)
            *min_value_ptr = right_min,
            *min_index_ptr = right_min_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_min_index;
        else if (right_min_index == NK_SIZE_MAX || nk_e4m3_order_serial(left_min, right_min) <= 0)
            *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        else *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        if (left_max_index == NK_SIZE_MAX)
            *max_value_ptr = right_max,
            *max_index_ptr = right_max_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_max_index;
        else if (right_max_index == NK_SIZE_MAX || nk_e4m3_order_serial(left_max, right_max) >= 0)
            *max_value_ptr = left_max, *max_index_ptr = left_max_index;
        else *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_e4m3_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                  max_index_ptr);
    else
        nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                     max_index_ptr);
}
|
|
2627
|
+
|
|
2628
|
+
// Sum and sum-of-squares of contiguous e5m2 (FP8) data, accumulated in f32.
// Mirrors the e4m3 variant: 8 bytes widened to 8 f32 lanes per iteration,
// FMA for the squares, and a masked partial load for the tail (presumably
// zero-filling invalid lanes — confirm against the partial-load helper).
NK_INTERNAL void nk_reduce_moments_e5m2_haswell_contiguous_( //
    nk_e5m2_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {

    __m256 sum_f32x8 = _mm256_setzero_ps();
    __m256 sumsq_f32x8 = _mm256_setzero_ps();
    nk_size_t idx = 0;
    for (; idx + 8 <= count; idx += 8) {
        // Load 8 packed bytes into an XMM low half and upconvert to f32x8.
        __m256 data_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(data_ptr + idx)));
        sum_f32x8 = _mm256_add_ps(sum_f32x8, data_f32x8);
        sumsq_f32x8 = _mm256_fmadd_ps(data_f32x8, data_f32x8, sumsq_f32x8);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t vec;
        nk_partial_load_e5m2x8_to_f32x8_haswell_(data_ptr + idx, &vec, remaining);
        sum_f32x8 = _mm256_add_ps(sum_f32x8, vec.ymm_ps);
        sumsq_f32x8 = _mm256_fmadd_ps(vec.ymm_ps, vec.ymm_ps, sumsq_f32x8);
    }
    // Horizontal reduction of both 8-lane accumulators.
    *sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
}
|
|
2649
|
+
|
|
2650
|
+
// Strided variant of the e5m2 moments kernel: gathers 8 strided bytes into a
// 64-bit scratch buffer, then reuses the widen-and-FMA pipeline of the
// contiguous path. The dispatcher only routes strides 2..8 here.
NK_INTERNAL void nk_reduce_moments_e5m2_haswell_strided_( //
    nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {

    __m256 sum_f32x8 = _mm256_setzero_ps();
    __m256 sumsq_f32x8 = _mm256_setzero_ps();
    nk_size_t idx = 0;
    nk_size_t elems_per_chunk = 8;
    for (; idx + elems_per_chunk <= count; idx += elems_per_chunk) {
        // Gather 8 strided FP8 bytes into one u64, then widen to 8 x f32.
        nk_b64_vec_t buf;
        buf.u64 = 0;
        nk_e5m2_t const *ptr = data_ptr + idx * stride_elements;
        for (nk_size_t i = 0; i < elems_per_chunk; ++i) buf.u8s[i] = ptr[i * stride_elements];
        __m256 data_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_cvtsi64_si128(buf.u64));
        sum_f32x8 = _mm256_add_ps(sum_f32x8, data_f32x8);
        sumsq_f32x8 = _mm256_fmadd_ps(data_f32x8, data_f32x8, sumsq_f32x8);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Gather the <8 leftovers into a zeroed buffer; the masked partial
        // load keeps invalid lanes out of the sums.
        nk_b64_vec_t buf;
        buf.u64 = 0;
        nk_e5m2_t const *ptr = data_ptr + idx * stride_elements;
        for (nk_size_t i = 0; i < remaining; ++i) buf.u8s[i] = ptr[i * stride_elements];
        nk_b256_vec_t vec;
        nk_partial_load_e5m2x8_to_f32x8_haswell_((nk_e5m2_t *)&buf, &vec, remaining);
        sum_f32x8 = _mm256_add_ps(sum_f32x8, vec.ymm_ps);
        sumsq_f32x8 = _mm256_fmadd_ps(vec.ymm_ps, vec.ymm_ps, sumsq_f32x8);
    }
    *sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
}
|
|
2680
|
+
|
|
2681
|
+
NK_PUBLIC void nk_reduce_moments_e5m2_haswell( //
|
|
2682
|
+
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2683
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2684
|
+
|
|
2685
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
|
|
2686
|
+
int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
|
|
2687
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
2688
|
+
else if (!aligned) nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2689
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
|
|
2690
|
+
nk_size_t left_count = count / 2;
|
|
2691
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
2692
|
+
nk_reduce_moments_e5m2_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
2693
|
+
nk_reduce_moments_e5m2_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
2694
|
+
&right_sum, &right_sumsq);
|
|
2695
|
+
*sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
|
|
2696
|
+
}
|
|
2697
|
+
else if (stride_elements == 1) nk_reduce_moments_e5m2_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
2698
|
+
else if (stride_elements >= 2 && stride_elements <= 8)
|
|
2699
|
+
nk_reduce_moments_e5m2_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
2700
|
+
else nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2701
|
+
}
|
|
2702
|
+
|
|
2703
|
+
NK_INTERNAL void nk_reduce_minmax_e5m2_haswell_contiguous_( //
    nk_e5m2_t const *data_ptr, nk_size_t count,             //
    nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // Min/max (with first-occurrence indices) over a contiguous E5M2 array,
    // skipping NaN lanes. Bytes are first remapped into a "comparable" unsigned
    // form so plain unsigned byte min/max implements the float ordering.
    //
    // E5M2 NaN in comparable form: 0x00-0x02 (neg NaN) and 0xFD-0xFF (pos NaN).
    // Infinity (comparable 0x03/0xFC) is NOT NaN and IS included.
    nk_b256_vec_t min_vec, max_vec;
    min_vec.ymm = _mm256_set1_epi8((char)0xFF); // sentinel: above every valid comparable value
    max_vec.ymm = _mm256_setzero_si256();       // sentinel: below every valid comparable value
    // Per-lane record of the 32-byte "loop cycle" in which each running extremum
    // last changed; used afterwards to reconstruct the absolute element index.
    // Cycles are 8-bit, hence the caller-side split above (NK_U8_MAX + 1) * 32 elements.
    __m256i min_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i one_u8x32 = _mm256_set1_epi8(1);
    __m256i all_ones_u8x32 = _mm256_set1_epi8((char)0xFF);
    __m256i low_bound_u8x32 = _mm256_set1_epi8(0x02);
    __m256i high_bound_u8x32 = _mm256_set1_epi8((char)0xFD);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        __m256i data_cmp_u8x32 = nk_fp8x32_to_u8x32_comparable_haswell_(data_i8x32);
        // Detect NaN: comparable <= 0x02 or comparable >= 0xFD
        __m256i is_nan_low_u8x32 = _mm256_cmpeq_epi8(_mm256_min_epu8(data_cmp_u8x32, low_bound_u8x32), data_cmp_u8x32);
        __m256i is_nan_high_u8x32 = _mm256_cmpeq_epi8(_mm256_max_epu8(data_cmp_u8x32, high_bound_u8x32),
                                                      data_cmp_u8x32);
        __m256i is_nan_u8x32 = _mm256_or_si256(is_nan_low_u8x32, is_nan_high_u8x32);
        // NaN lanes are replaced with the identity element of each reduction so
        // they can never win the min (0xFF) or the max (0x00).
        __m256i data_min_u8x32 = _mm256_blendv_epi8(data_cmp_u8x32, all_ones_u8x32, is_nan_u8x32);
        __m256i data_max_u8x32 = _mm256_blendv_epi8(data_cmp_u8x32, _mm256_setzero_si256(), is_nan_u8x32);
        __m256i new_min_u8x32 = _mm256_min_epu8(min_vec.ymm, data_min_u8x32);
        // "changed" = NOT(new == old); records the cycle whenever a lane improves.
        __m256i min_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_min_u8x32, min_vec.ymm), all_ones_u8x32);
        min_vec.ymm = new_min_u8x32;
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, min_changed_i8x32);
        __m256i new_max_u8x32 = _mm256_max_epu8(max_vec.ymm, data_max_u8x32);
        __m256i max_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_max_u8x32, max_vec.ymm), all_ones_u8x32);
        max_vec.ymm = new_max_u8x32;
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, max_changed_i8x32);
        current_loop_cycle_u8x32 = _mm256_add_epi8(current_loop_cycle_u8x32, one_u8x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        tail_vec.ymm = _mm256_set1_epi8((char)0xFF);
        nk_partial_load_b8x32_serial_(data_ptr + idx, &tail_vec, remaining);
        // Pad lanes past the tail with raw 0xFF (a NaN pattern) so the NaN path
        // below also neutralizes them.
        for (nk_size_t i = remaining; i < 32; ++i) tail_vec.u8s[i] = 0xFF;
        __m256i data_cmp_u8x32 = nk_fp8x32_to_u8x32_comparable_haswell_(tail_vec.ymm);
        __m256i is_nan_low_u8x32 = _mm256_cmpeq_epi8(_mm256_min_epu8(data_cmp_u8x32, low_bound_u8x32), data_cmp_u8x32);
        __m256i is_nan_high_u8x32 = _mm256_cmpeq_epi8(_mm256_max_epu8(data_cmp_u8x32, high_bound_u8x32),
                                                      data_cmp_u8x32);
        __m256i is_nan_u8x32 = _mm256_or_si256(is_nan_low_u8x32, is_nan_high_u8x32);
        __m256i data_min_u8x32 = _mm256_blendv_epi8(data_cmp_u8x32, all_ones_u8x32, is_nan_u8x32);
        __m256i data_max_u8x32 = _mm256_blendv_epi8(data_cmp_u8x32, _mm256_setzero_si256(), is_nan_u8x32);
        nk_b256_vec_t min_cmp_vec, max_cmp_vec;
        min_cmp_vec.ymm = data_min_u8x32;
        max_cmp_vec.ymm = data_max_u8x32;
        // Force identity elements on the out-of-range lanes explicitly as well.
        for (nk_size_t i = remaining; i < 32; ++i) min_cmp_vec.u8s[i] = 0xFF, max_cmp_vec.u8s[i] = 0x00;
        __m256i new_min_u8x32 = _mm256_min_epu8(min_vec.ymm, min_cmp_vec.ymm);
        __m256i min_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_min_u8x32, min_vec.ymm), all_ones_u8x32);
        min_vec.ymm = new_min_u8x32;
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, min_changed_i8x32);
        __m256i new_max_u8x32 = _mm256_max_epu8(max_vec.ymm, max_cmp_vec.ymm);
        __m256i max_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_max_u8x32, max_vec.ymm), all_ones_u8x32);
        max_vec.ymm = new_max_u8x32;
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, max_changed_i8x32);
    }

    nk_u8_t min_value_comparable = nk_reduce_min_u8x32_haswell_(min_vec.ymm);
    nk_u8_t max_value_comparable = nk_reduce_max_u8x32_haswell_(max_vec.ymm);

    // All-NaN early return: both sentinels unchanged means no valid data was found
    if (min_value_comparable == 0xFF && max_value_comparable == 0x00) {
        *min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX;
        *max_value_ptr = NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    if (min_value_comparable == 0xFF) { *min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX; }
    else {
        // Find the first occurrence: among lanes holding the winning value, pick
        // the one whose recorded cycle is earliest, then the lowest lane number.
        unsigned int min_lane;
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(min_vec.ymm, _mm256_set1_epi8((char)min_value_comparable));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), min_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
        nk_b256_vec_t loop_cycle_vec;
        loop_cycle_vec.ymm = min_loop_cycle_u8x32;
        // Absolute index = cycle * 32 + lane.
        *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 32 + min_lane;
        nk_b256_vec_t min_raw_vec;
        min_raw_vec.ymm = nk_u8x32_comparable_to_fp8x32_haswell_(min_vec.ymm);
        *min_value_ptr = min_raw_vec.e5m2s[min_lane];
    }
    if (max_value_comparable == 0x00) { *max_value_ptr = NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX; }
    else {
        unsigned int max_lane;
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(max_vec.ymm, _mm256_set1_epi8((char)max_value_comparable));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), max_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
        nk_b256_vec_t loop_cycle_vec;
        loop_cycle_vec.ymm = max_loop_cycle_u8x32;
        *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 32 + max_lane;
        nk_b256_vec_t max_raw_vec;
        max_raw_vec.ymm = nk_u8x32_comparable_to_fp8x32_haswell_(max_vec.ymm);
        *max_value_ptr = max_raw_vec.e5m2s[max_lane];
    }
}
|
|
2813
|
+
|
|
2814
|
+
NK_PUBLIC void nk_reduce_minmax_e5m2_haswell(                 //
    nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr,       //
    nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // Public E5M2 min/max entry point for Haswell (AVX2). Dispatch:
    //   - empty input          -> sentinel values with NK_SIZE_MAX indices
    //   - long contiguous      -> recursive halving (the contiguous kernel's
    //                             8-bit cycle counters overflow past (NK_U8_MAX+1)*32)
    //   - contiguous           -> SIMD kernel
    //   - strided              -> serial fallback
    nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
    if (count == 0)
        *min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E5M2_MIN,
        *max_index_ptr = NK_SIZE_MAX;
    else if (stride_elements == 1 && count > (nk_size_t)(NK_U8_MAX + 1) * 32) {
        nk_size_t left_count = count / 2;
        nk_e5m2_t left_min, right_min, left_max, right_max;
        nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
        nk_reduce_minmax_e5m2_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
                                      &left_max_index);
        nk_reduce_minmax_e5m2_haswell(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
                                      &right_min_index, &right_max, &right_max_index);
        // Merge the two halves. An NK_SIZE_MAX index marks an all-NaN half whose
        // value must be ignored. Ties prefer the left half so the reported index
        // is the first occurrence.
        if (left_min_index == NK_SIZE_MAX)
            *min_value_ptr = right_min,
            *min_index_ptr = right_min_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_min_index;
        else if (right_min_index == NK_SIZE_MAX || nk_e5m2_order_serial(left_min, right_min) <= 0)
            *min_value_ptr = left_min, *min_index_ptr = left_min_index;
        else *min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
        if (left_max_index == NK_SIZE_MAX)
            *max_value_ptr = right_max,
            *max_index_ptr = right_max_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_max_index;
        else if (right_max_index == NK_SIZE_MAX || nk_e5m2_order_serial(left_max, right_max) >= 0)
            *max_value_ptr = left_max, *max_index_ptr = left_max_index;
        else *max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
    }
    else if (stride_elements == 1)
        nk_reduce_minmax_e5m2_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
                                                  max_index_ptr);
    else
        nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
                                     max_index_ptr);
}
|
|
2851
|
+
|
|
2852
|
+
NK_INTERNAL void nk_reduce_moments_e2m3_haswell_contiguous_( //
|
|
2853
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, //
|
|
2854
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2855
|
+
|
|
2856
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
2857
|
+
__m256 sumsq_f32x8 = _mm256_setzero_ps();
|
|
2858
|
+
nk_size_t idx = 0;
|
|
2859
|
+
for (; idx + 8 <= count; idx += 8) {
|
|
2860
|
+
__m256 data_f32x8 = nk_e2m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(data_ptr + idx)));
|
|
2861
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, data_f32x8);
|
|
2862
|
+
sumsq_f32x8 = _mm256_fmadd_ps(data_f32x8, data_f32x8, sumsq_f32x8);
|
|
2863
|
+
}
|
|
2864
|
+
nk_size_t remaining = count - idx;
|
|
2865
|
+
if (remaining > 0) {
|
|
2866
|
+
nk_b256_vec_t vec;
|
|
2867
|
+
nk_partial_load_e2m3x8_to_f32x8_haswell_(data_ptr + idx, &vec, remaining);
|
|
2868
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, vec.ymm_ps);
|
|
2869
|
+
sumsq_f32x8 = _mm256_fmadd_ps(vec.ymm_ps, vec.ymm_ps, sumsq_f32x8);
|
|
2870
|
+
}
|
|
2871
|
+
*sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
|
|
2872
|
+
}
|
|
2873
|
+
|
|
2874
|
+
NK_INTERNAL void nk_reduce_moments_e2m3_haswell_strided_( //
|
|
2875
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
2876
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2877
|
+
|
|
2878
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
2879
|
+
__m256 sumsq_f32x8 = _mm256_setzero_ps();
|
|
2880
|
+
nk_size_t idx = 0;
|
|
2881
|
+
for (; idx + 8 <= count; idx += 8) {
|
|
2882
|
+
nk_b64_vec_t buf;
|
|
2883
|
+
buf.u64 = 0;
|
|
2884
|
+
nk_e2m3_t const *ptr = data_ptr + idx * stride_elements;
|
|
2885
|
+
for (nk_size_t i = 0; i < 8; ++i) buf.u8s[i] = ptr[i * stride_elements];
|
|
2886
|
+
__m256 data_f32x8 = nk_e2m3x8_to_f32x8_haswell_(_mm_cvtsi64_si128(buf.u64));
|
|
2887
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, data_f32x8);
|
|
2888
|
+
sumsq_f32x8 = _mm256_fmadd_ps(data_f32x8, data_f32x8, sumsq_f32x8);
|
|
2889
|
+
}
|
|
2890
|
+
nk_size_t remaining = count - idx;
|
|
2891
|
+
if (remaining > 0) {
|
|
2892
|
+
nk_b64_vec_t buf;
|
|
2893
|
+
buf.u64 = 0;
|
|
2894
|
+
nk_e2m3_t const *ptr = data_ptr + idx * stride_elements;
|
|
2895
|
+
for (nk_size_t i = 0; i < remaining; ++i) buf.u8s[i] = ptr[i * stride_elements];
|
|
2896
|
+
nk_b256_vec_t vec;
|
|
2897
|
+
nk_partial_load_e2m3x8_to_f32x8_haswell_((nk_e2m3_t *)&buf, &vec, remaining);
|
|
2898
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, vec.ymm_ps);
|
|
2899
|
+
sumsq_f32x8 = _mm256_fmadd_ps(vec.ymm_ps, vec.ymm_ps, sumsq_f32x8);
|
|
2900
|
+
}
|
|
2901
|
+
*sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
|
|
2902
|
+
}
|
|
2903
|
+
|
|
2904
|
+
NK_PUBLIC void nk_reduce_moments_e2m3_haswell( //
|
|
2905
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
2906
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
2907
|
+
|
|
2908
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
|
|
2909
|
+
int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
|
|
2910
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
2911
|
+
else if (!aligned) nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2912
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
|
|
2913
|
+
nk_size_t left_count = count / 2;
|
|
2914
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
2915
|
+
nk_reduce_moments_e2m3_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
2916
|
+
nk_reduce_moments_e2m3_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
2917
|
+
&right_sum, &right_sumsq);
|
|
2918
|
+
*sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
|
|
2919
|
+
}
|
|
2920
|
+
else if (stride_elements == 1) nk_reduce_moments_e2m3_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
2921
|
+
else if (stride_elements >= 2 && stride_elements <= 8)
|
|
2922
|
+
nk_reduce_moments_e2m3_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
2923
|
+
else nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
2924
|
+
}
|
|
2925
|
+
|
|
2926
|
+
NK_INTERNAL __m256i nk_fp6x32_to_u8x32_comparable_haswell_(__m256i raw_i8x32) {
    // Remaps 32 sign-magnitude FP6 bytes into a form where unsigned byte
    // comparison matches the numeric ordering: positives get the sign bit set,
    // negatives have all six value bits inverted.
    __m256i six_bits_u8x32 = _mm256_set1_epi8(0x3F);
    __m256i sign_bit_u8x32 = _mm256_set1_epi8(0x20);
    __m256i masked_u8x32 = _mm256_and_si256(raw_i8x32, six_bits_u8x32); // keep only the 6 valid bits
    __m256i is_negative_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(masked_u8x32, sign_bit_u8x32), sign_bit_u8x32);
    // Choose the per-lane flip pattern: sign bit only for positives, all 6 bits for negatives.
    __m256i flip_u8x32 = _mm256_blendv_epi8(sign_bit_u8x32, six_bits_u8x32, is_negative_u8x32);
    return _mm256_xor_si256(masked_u8x32, flip_u8x32);
}
|
|
2935
|
+
|
|
2936
|
+
NK_INTERNAL __m256i nk_u8x32_comparable_to_fp6x32_haswell_(__m256i cmp_i8x32) {
    // Inverse of nk_fp6x32_to_u8x32_comparable_haswell_: in comparable form a
    // clear 0x20 bit means the original value was negative.
    __m256i sign_bit_u8x32 = _mm256_set1_epi8(0x20);
    __m256i six_bits_u8x32 = _mm256_set1_epi8(0x3F);
    __m256i originally_negative_u8x32 =
        _mm256_cmpeq_epi8(_mm256_and_si256(cmp_i8x32, sign_bit_u8x32), _mm256_setzero_si256());
    // Undo the encoding flip: all 6 bits for negatives, sign bit only for positives.
    __m256i flip_u8x32 = _mm256_blendv_epi8(sign_bit_u8x32, six_bits_u8x32, originally_negative_u8x32);
    return _mm256_xor_si256(cmp_i8x32, flip_u8x32);
}
|
|
2944
|
+
|
|
2945
|
+
NK_INTERNAL void nk_reduce_minmax_e2m3_haswell_contiguous_( //
    nk_e2m3_t const *data_ptr, nk_size_t count,             //
    nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // Min/max (with first-occurrence indices) over a contiguous E2M3 array.
    // FP6 has no NaN — use simple unsigned min/max on comparable form
    nk_b256_vec_t min_vec, max_vec;
    min_vec.ymm = _mm256_set1_epi8((char)0xFF); // sentinel above any valid comparable (max is 0x3F)
    max_vec.ymm = _mm256_setzero_si256();       // sentinel below any valid comparable
    // Per-lane record of the 32-byte "loop cycle" in which each running extremum
    // last changed; 8-bit counters, hence the caller-side split for long inputs.
    __m256i min_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i one_u8x32 = _mm256_set1_epi8(1);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        __m256i data_cmp_u8x32 = nk_fp6x32_to_u8x32_comparable_haswell_(data_i8x32);
        __m256i new_min = _mm256_min_epu8(min_vec.ymm, data_cmp_u8x32);
        // "changed" = NOT(new == old); records the cycle whenever a lane improves.
        __m256i min_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_min, min_vec.ymm),
                                                     _mm256_set1_epi8((char)0xFF));
        min_vec.ymm = new_min;
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, min_changed_i8x32);
        __m256i new_max = _mm256_max_epu8(max_vec.ymm, data_cmp_u8x32);
        __m256i max_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_max, max_vec.ymm),
                                                     _mm256_set1_epi8((char)0xFF));
        max_vec.ymm = new_max;
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, max_changed_i8x32);
        current_loop_cycle_u8x32 = _mm256_add_epi8(current_loop_cycle_u8x32, one_u8x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b8x32_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_cmp_u8x32 = nk_fp6x32_to_u8x32_comparable_haswell_(tail_vec.ymm);
        // Fill invalid lanes with identity: 0x3F for min (max comparable), 0x00 for max (min comparable)
        __m256i lane_indices_u8x32 = _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                                                      19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31);
        __m256i valid_b8x32 = _mm256_cmpgt_epi8(_mm256_set1_epi8((char)remaining), lane_indices_u8x32);
        __m256i data_min_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8(0x3F), data_cmp_u8x32, valid_b8x32);
        __m256i data_max_u8x32 = _mm256_blendv_epi8(_mm256_setzero_si256(), data_cmp_u8x32, valid_b8x32);
        __m256i new_min = _mm256_min_epu8(min_vec.ymm, data_min_u8x32);
        __m256i min_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_min, min_vec.ymm),
                                                     _mm256_set1_epi8((char)0xFF));
        min_vec.ymm = new_min;
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, min_changed_i8x32);
        __m256i new_max = _mm256_max_epu8(max_vec.ymm, data_max_u8x32);
        __m256i max_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_max, max_vec.ymm),
                                                     _mm256_set1_epi8((char)0xFF));
        max_vec.ymm = new_max;
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, max_changed_i8x32);
    }

    nk_u8_t min_value_comparable = nk_reduce_min_u8x32_haswell_(min_vec.ymm);
    nk_u8_t max_value_comparable = nk_reduce_max_u8x32_haswell_(max_vec.ymm);
    unsigned int min_lane, max_lane;
    {
        // First occurrence of the minimum: among lanes holding the winning value,
        // pick the earliest recorded cycle, then the lowest lane number.
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(min_vec.ymm, _mm256_set1_epi8((char)min_value_comparable));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), min_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
    }
    {
        // Same lane selection for the maximum.
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(max_vec.ymm, _mm256_set1_epi8((char)max_value_comparable));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), max_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
    }
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u8x32;
    // Absolute index = cycle * 32 + lane.
    *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 32 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u8x32;
    *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 32 + max_lane;
    // Map the winning comparable bytes back to raw E2M3 bit patterns.
    min_vec.ymm = nk_u8x32_comparable_to_fp6x32_haswell_(min_vec.ymm);
    max_vec.ymm = nk_u8x32_comparable_to_fp6x32_haswell_(max_vec.ymm);
    *min_value_ptr = min_vec.e2m3s[min_lane];
    *max_value_ptr = max_vec.e2m3s[max_lane];
}
|
|
3028
|
+
|
|
3029
|
+
NK_PUBLIC void nk_reduce_minmax_e2m3_haswell( //
|
|
3030
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3031
|
+
nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3032
|
+
nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3033
|
+
|
|
3034
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
|
|
3035
|
+
if (count == 0)
|
|
3036
|
+
*min_value_ptr = NK_E2M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E2M3_MIN,
|
|
3037
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
3038
|
+
else if (stride_elements == 1 && count > (nk_size_t)(NK_U8_MAX + 1) * 32) {
|
|
3039
|
+
nk_size_t left_count = count / 2;
|
|
3040
|
+
nk_e2m3_t left_min, right_min, left_max, right_max;
|
|
3041
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
3042
|
+
nk_reduce_minmax_e2m3_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
3043
|
+
&left_max_index);
|
|
3044
|
+
nk_reduce_minmax_e2m3_haswell(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
|
|
3045
|
+
&right_min_index, &right_max, &right_max_index);
|
|
3046
|
+
if (nk_e2m3_order_serial(right_min, left_min) < 0)
|
|
3047
|
+
*min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
3048
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
3049
|
+
if (nk_e2m3_order_serial(right_max, left_max) > 0)
|
|
3050
|
+
*max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
3051
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
3052
|
+
}
|
|
3053
|
+
else if (stride_elements == 1)
|
|
3054
|
+
nk_reduce_minmax_e2m3_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3055
|
+
max_index_ptr);
|
|
3056
|
+
else
|
|
3057
|
+
nk_reduce_minmax_e2m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3058
|
+
max_index_ptr);
|
|
3059
|
+
}
|
|
3060
|
+
|
|
3061
|
+
NK_INTERNAL void nk_reduce_moments_e3m2_haswell_contiguous_( //
|
|
3062
|
+
nk_e3m2_t const *data_ptr, nk_size_t count, //
|
|
3063
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3064
|
+
|
|
3065
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
3066
|
+
__m256 sumsq_f32x8 = _mm256_setzero_ps();
|
|
3067
|
+
nk_size_t idx = 0;
|
|
3068
|
+
for (; idx + 8 <= count; idx += 8) {
|
|
3069
|
+
__m256 data_f32x8 = nk_e3m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)(data_ptr + idx)));
|
|
3070
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, data_f32x8);
|
|
3071
|
+
sumsq_f32x8 = _mm256_fmadd_ps(data_f32x8, data_f32x8, sumsq_f32x8);
|
|
3072
|
+
}
|
|
3073
|
+
nk_size_t remaining = count - idx;
|
|
3074
|
+
if (remaining > 0) {
|
|
3075
|
+
nk_b256_vec_t vec;
|
|
3076
|
+
nk_partial_load_e3m2x8_to_f32x8_haswell_(data_ptr + idx, &vec, remaining);
|
|
3077
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, vec.ymm_ps);
|
|
3078
|
+
sumsq_f32x8 = _mm256_fmadd_ps(vec.ymm_ps, vec.ymm_ps, sumsq_f32x8);
|
|
3079
|
+
}
|
|
3080
|
+
*sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
|
|
3081
|
+
}
|
|
3082
|
+
|
|
3083
|
+
NK_INTERNAL void nk_reduce_moments_e3m2_haswell_strided_( //
|
|
3084
|
+
nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
3085
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3086
|
+
|
|
3087
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
3088
|
+
__m256 sumsq_f32x8 = _mm256_setzero_ps();
|
|
3089
|
+
nk_size_t idx = 0;
|
|
3090
|
+
for (; idx + 8 <= count; idx += 8) {
|
|
3091
|
+
nk_b64_vec_t buf;
|
|
3092
|
+
buf.u64 = 0;
|
|
3093
|
+
nk_e3m2_t const *ptr = data_ptr + idx * stride_elements;
|
|
3094
|
+
for (nk_size_t i = 0; i < 8; ++i) buf.u8s[i] = ptr[i * stride_elements];
|
|
3095
|
+
__m256 data_f32x8 = nk_e3m2x8_to_f32x8_haswell_(_mm_cvtsi64_si128(buf.u64));
|
|
3096
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, data_f32x8);
|
|
3097
|
+
sumsq_f32x8 = _mm256_fmadd_ps(data_f32x8, data_f32x8, sumsq_f32x8);
|
|
3098
|
+
}
|
|
3099
|
+
nk_size_t remaining = count - idx;
|
|
3100
|
+
if (remaining > 0) {
|
|
3101
|
+
nk_b64_vec_t buf;
|
|
3102
|
+
buf.u64 = 0;
|
|
3103
|
+
nk_e3m2_t const *ptr = data_ptr + idx * stride_elements;
|
|
3104
|
+
for (nk_size_t i = 0; i < remaining; ++i) buf.u8s[i] = ptr[i * stride_elements];
|
|
3105
|
+
nk_b256_vec_t vec;
|
|
3106
|
+
nk_partial_load_e3m2x8_to_f32x8_haswell_((nk_e3m2_t *)&buf, &vec, remaining);
|
|
3107
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, vec.ymm_ps);
|
|
3108
|
+
sumsq_f32x8 = _mm256_fmadd_ps(vec.ymm_ps, vec.ymm_ps, sumsq_f32x8);
|
|
3109
|
+
}
|
|
3110
|
+
*sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
|
|
3111
|
+
}
|
|
3112
|
+
|
|
3113
|
+
NK_PUBLIC void nk_reduce_moments_e3m2_haswell( //
|
|
3114
|
+
nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3115
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3116
|
+
|
|
3117
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
|
|
3118
|
+
int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
|
|
3119
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
3120
|
+
else if (!aligned) nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3121
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
|
|
3122
|
+
nk_size_t left_count = count / 2;
|
|
3123
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
3124
|
+
nk_reduce_moments_e3m2_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
3125
|
+
nk_reduce_moments_e3m2_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
3126
|
+
&right_sum, &right_sumsq);
|
|
3127
|
+
*sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
|
|
3128
|
+
}
|
|
3129
|
+
else if (stride_elements == 1) nk_reduce_moments_e3m2_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3130
|
+
else if (stride_elements >= 2 && stride_elements <= 8)
|
|
3131
|
+
nk_reduce_moments_e3m2_haswell_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
3132
|
+
else nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3133
|
+
}
|
|
3134
|
+
|
|
3135
|
+
NK_INTERNAL void nk_reduce_minmax_e3m2_haswell_contiguous_( //
    nk_e3m2_t const *data_ptr, nk_size_t count,             //
    nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // Min/max (with first-occurrence indices) over a contiguous E3M2 array.
    // FP6 has no NaN, so plain unsigned min/max on the comparable form suffices.
    nk_b256_vec_t min_vec, max_vec;
    min_vec.ymm = _mm256_set1_epi8((char)0xFF); // sentinel above any valid comparable (max is 0x3F)
    max_vec.ymm = _mm256_setzero_si256();       // sentinel below any valid comparable
    // Per-lane record of the 32-byte "loop cycle" in which each running extremum
    // last changed; 8-bit counters, hence the caller-side split for long inputs.
    __m256i min_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u8x32 = _mm256_setzero_si256();
    __m256i one_u8x32 = _mm256_set1_epi8(1);

    nk_size_t idx = 0;
    for (; idx + 32 <= count; idx += 32) {
        __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        __m256i data_cmp_u8x32 = nk_fp6x32_to_u8x32_comparable_haswell_(data_i8x32);
        __m256i new_min = _mm256_min_epu8(min_vec.ymm, data_cmp_u8x32);
        // "changed" = NOT(new == old); records the cycle whenever a lane improves.
        __m256i min_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_min, min_vec.ymm),
                                                     _mm256_set1_epi8((char)0xFF));
        min_vec.ymm = new_min;
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, min_changed_i8x32);
        __m256i new_max = _mm256_max_epu8(max_vec.ymm, data_cmp_u8x32);
        __m256i max_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_max, max_vec.ymm),
                                                     _mm256_set1_epi8((char)0xFF));
        max_vec.ymm = new_max;
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, max_changed_i8x32);
        current_loop_cycle_u8x32 = _mm256_add_epi8(current_loop_cycle_u8x32, one_u8x32);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b256_vec_t tail_vec;
        nk_partial_load_b8x32_serial_(data_ptr + idx, &tail_vec, remaining);
        __m256i data_cmp_u8x32 = nk_fp6x32_to_u8x32_comparable_haswell_(tail_vec.ymm);
        // Fill invalid lanes with identity: 0x3F for min (max comparable), 0x00 for max (min comparable).
        __m256i lane_indices_u8x32 = _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                                                      19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31);
        __m256i valid_b8x32 = _mm256_cmpgt_epi8(_mm256_set1_epi8((char)remaining), lane_indices_u8x32);
        __m256i data_min_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8(0x3F), data_cmp_u8x32, valid_b8x32);
        __m256i data_max_u8x32 = _mm256_blendv_epi8(_mm256_setzero_si256(), data_cmp_u8x32, valid_b8x32);
        __m256i new_min = _mm256_min_epu8(min_vec.ymm, data_min_u8x32);
        __m256i min_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_min, min_vec.ymm),
                                                     _mm256_set1_epi8((char)0xFF));
        min_vec.ymm = new_min;
        min_loop_cycle_u8x32 = _mm256_blendv_epi8(min_loop_cycle_u8x32, current_loop_cycle_u8x32, min_changed_i8x32);
        __m256i new_max = _mm256_max_epu8(max_vec.ymm, data_max_u8x32);
        __m256i max_changed_i8x32 = _mm256_xor_si256(_mm256_cmpeq_epi8(new_max, max_vec.ymm),
                                                     _mm256_set1_epi8((char)0xFF));
        max_vec.ymm = new_max;
        max_loop_cycle_u8x32 = _mm256_blendv_epi8(max_loop_cycle_u8x32, current_loop_cycle_u8x32, max_changed_i8x32);
    }

    nk_u8_t min_value_comparable = nk_reduce_min_u8x32_haswell_(min_vec.ymm);
    nk_u8_t max_value_comparable = nk_reduce_max_u8x32_haswell_(max_vec.ymm);
    unsigned int min_lane, max_lane;
    {
        // First occurrence of the minimum: among lanes holding the winning value,
        // pick the earliest recorded cycle, then the lowest lane number.
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(min_vec.ymm, _mm256_set1_epi8((char)min_value_comparable));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), min_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
    }
    {
        // Same lane selection for the maximum.
        __m256i value_match_b8x32 = _mm256_cmpeq_epi8(max_vec.ymm, _mm256_set1_epi8((char)max_value_comparable));
        __m256i masked_cycle_u8x32 = _mm256_blendv_epi8(_mm256_set1_epi8((char)NK_U8_MAX), max_loop_cycle_u8x32,
                                                        value_match_b8x32);
        nk_u8_t earliest_loop_cycle = nk_reduce_min_u8x32_haswell_(masked_cycle_u8x32);
        __m256i cycle_match_b8x32 = _mm256_cmpeq_epi8(masked_cycle_u8x32, _mm256_set1_epi8((char)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b8x32));
    }
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u8x32;
    // Absolute index = cycle * 32 + lane.
    *min_index_ptr = (nk_size_t)loop_cycle_vec.u8s[min_lane] * 32 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u8x32;
    *max_index_ptr = (nk_size_t)loop_cycle_vec.u8s[max_lane] * 32 + max_lane;
    // Map the winning comparable bytes back to raw E3M2 bit patterns.
    min_vec.ymm = nk_u8x32_comparable_to_fp6x32_haswell_(min_vec.ymm);
    max_vec.ymm = nk_u8x32_comparable_to_fp6x32_haswell_(max_vec.ymm);
    *min_value_ptr = min_vec.e3m2s[min_lane];
    *max_value_ptr = max_vec.e3m2s[max_lane];
}
|
|
3216
|
+
|
|
3217
|
+
NK_PUBLIC void nk_reduce_minmax_e3m2_haswell( //
|
|
3218
|
+
nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3219
|
+
nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3220
|
+
nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3221
|
+
|
|
3222
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
|
|
3223
|
+
if (count == 0)
|
|
3224
|
+
*min_value_ptr = NK_E3M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E3M2_MIN,
|
|
3225
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
3226
|
+
else if (stride_elements == 1 && count > (nk_size_t)(NK_U8_MAX + 1) * 32) {
|
|
3227
|
+
nk_size_t left_count = count / 2;
|
|
3228
|
+
nk_e3m2_t left_min, right_min, left_max, right_max;
|
|
3229
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
3230
|
+
nk_reduce_minmax_e3m2_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
3231
|
+
&left_max_index);
|
|
3232
|
+
nk_reduce_minmax_e3m2_haswell(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
|
|
3233
|
+
&right_min_index, &right_max, &right_max_index);
|
|
3234
|
+
if (nk_e3m2_order_serial(right_min, left_min) < 0)
|
|
3235
|
+
*min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
3236
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
3237
|
+
if (nk_e3m2_order_serial(right_max, left_max) > 0)
|
|
3238
|
+
*max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
3239
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
3240
|
+
}
|
|
3241
|
+
else if (stride_elements == 1)
|
|
3242
|
+
nk_reduce_minmax_e3m2_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3243
|
+
max_index_ptr);
|
|
3244
|
+
else
|
|
3245
|
+
nk_reduce_minmax_e3m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3246
|
+
max_index_ptr);
|
|
3247
|
+
}
|
|
3248
|
+
|
|
3249
|
+
NK_INTERNAL void nk_reduce_moments_bf16_haswell_contiguous_( //
|
|
3250
|
+
nk_bf16_t const *data_ptr, nk_size_t count, //
|
|
3251
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3252
|
+
|
|
3253
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
3254
|
+
__m256 sumsq_f32x8 = _mm256_setzero_ps();
|
|
3255
|
+
nk_size_t idx = 0;
|
|
3256
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
3257
|
+
__m256 low_f32x8 = _mm256_castsi256_ps(
|
|
3258
|
+
_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const *)(data_ptr + idx))), 16));
|
|
3259
|
+
__m256 high_f32x8 = _mm256_castsi256_ps(
|
|
3260
|
+
_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const *)(data_ptr + idx + 8))), 16));
|
|
3261
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, low_f32x8);
|
|
3262
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, high_f32x8);
|
|
3263
|
+
sumsq_f32x8 = _mm256_fmadd_ps(low_f32x8, low_f32x8, sumsq_f32x8);
|
|
3264
|
+
sumsq_f32x8 = _mm256_fmadd_ps(high_f32x8, high_f32x8, sumsq_f32x8);
|
|
3265
|
+
}
|
|
3266
|
+
nk_size_t remaining = count - idx;
|
|
3267
|
+
if (remaining > 0) {
|
|
3268
|
+
nk_b256_vec_t partial_vec;
|
|
3269
|
+
nk_partial_load_b16x16_serial_(data_ptr + idx, &partial_vec, remaining);
|
|
3270
|
+
nk_size_t first_half = remaining > 8 ? 8 : remaining;
|
|
3271
|
+
nk_unused_(first_half);
|
|
3272
|
+
__m256 low_f32x8 = _mm256_castsi256_ps(
|
|
3273
|
+
_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm256_castsi256_si128(partial_vec.ymm)), 16));
|
|
3274
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, low_f32x8);
|
|
3275
|
+
sumsq_f32x8 = _mm256_fmadd_ps(low_f32x8, low_f32x8, sumsq_f32x8);
|
|
3276
|
+
if (remaining > 8) {
|
|
3277
|
+
__m256 high_f32x8 = _mm256_castsi256_ps(
|
|
3278
|
+
_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(partial_vec.ymm, 1)), 16));
|
|
3279
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, high_f32x8);
|
|
3280
|
+
sumsq_f32x8 = _mm256_fmadd_ps(high_f32x8, high_f32x8, sumsq_f32x8);
|
|
3281
|
+
}
|
|
3282
|
+
}
|
|
3283
|
+
*sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
|
|
3284
|
+
}
|
|
3285
|
+
|
|
3286
|
+
NK_PUBLIC void nk_reduce_moments_bf16_haswell( //
|
|
3287
|
+
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3288
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3289
|
+
|
|
3290
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
|
|
3291
|
+
int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
|
|
3292
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
3293
|
+
else if (!aligned) nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3294
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
|
|
3295
|
+
nk_size_t left_count = count / 2;
|
|
3296
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
3297
|
+
nk_reduce_moments_bf16_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
3298
|
+
nk_reduce_moments_bf16_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
3299
|
+
&right_sum, &right_sumsq);
|
|
3300
|
+
*sum_ptr = left_sum + right_sum;
|
|
3301
|
+
*sumsq_ptr = left_sumsq + right_sumsq;
|
|
3302
|
+
}
|
|
3303
|
+
else if (stride_elements == 1) nk_reduce_moments_bf16_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3304
|
+
else nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3305
|
+
}
|
|
3306
|
+
|
|
3307
|
+
NK_INTERNAL void nk_reduce_minmax_bf16_haswell_contiguous_( //
    nk_bf16_t const *data_ptr, nk_size_t count,             //
    nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // NaN-skipping argmin/argmax over contiguous bf16, 16 lanes per iteration.
    // Values are mapped to a totally-ordered signed-int encoding (via the
    // comparable helper) so plain epi16 compares implement float ordering.
    // Per-lane "loop cycle" counters remember the iteration in which each lane
    // last improved, which lets us reconstruct the winning element's index at
    // the end without per-iteration index arithmetic.
    __m256i abs_mask_u16x16 = _mm256_set1_epi16(0x7FFF);
    // abs(x) > 0x7F80 means exponent all-ones with nonzero mantissa, i.e. NaN.
    __m256i nan_threshold_u16x16 = _mm256_set1_epi16((short)0x7F80);
    __m256i all_ones_i16x16 = _mm256_set1_epi8((char)0xFF);
    // Running extrema start at the comparable-domain identities (+max / -max).
    __m256i min_cmp_i16x16 = _mm256_set1_epi16((short)0x7FFF);
    __m256i max_cmp_i16x16 = _mm256_set1_epi16((short)0x8000);
    __m256i min_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i one_u16x16 = _mm256_set1_epi16(1);

    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m256i raw_u16x16 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        __m256i data_cmp_i16x16 = nk_bf16x16_to_comparable_i16x16_haswell_(raw_u16x16);
        __m256i abs_u16x16 = _mm256_and_si256(raw_u16x16, abs_mask_u16x16);
        __m256i nan_detect_i16x16 = _mm256_cmpgt_epi16(abs_u16x16, nan_threshold_u16x16);
        __m256i not_nan_i16x16 = _mm256_xor_si256(nan_detect_i16x16, all_ones_i16x16);
        // Strict `<` ensures only genuinely smaller values update lane state,
        // preserving first-occurrence index semantics on ties.
        __m256i less_i16x16 = _mm256_cmpgt_epi16(min_cmp_i16x16, data_cmp_i16x16);
        __m256i min_changed_i16x16 = _mm256_and_si256(less_i16x16, not_nan_i16x16);
        min_cmp_i16x16 = _mm256_blendv_epi8(min_cmp_i16x16, data_cmp_i16x16, min_changed_i16x16);
        min_loop_cycle_u16x16 = _mm256_blendv_epi8(min_loop_cycle_u16x16, current_loop_cycle_u16x16,
                                                   min_changed_i16x16);
        __m256i greater_i16x16 = _mm256_cmpgt_epi16(data_cmp_i16x16, max_cmp_i16x16);
        __m256i max_changed_i16x16 = _mm256_and_si256(greater_i16x16, not_nan_i16x16);
        max_cmp_i16x16 = _mm256_blendv_epi8(max_cmp_i16x16, data_cmp_i16x16, max_changed_i16x16);
        max_loop_cycle_u16x16 = _mm256_blendv_epi8(max_loop_cycle_u16x16, current_loop_cycle_u16x16,
                                                   max_changed_i16x16);
        current_loop_cycle_u16x16 = _mm256_add_epi16(current_loop_cycle_u16x16, one_u16x16);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Tail: same update as the main loop, but lanes past `remaining` are
        // masked out so zero-padded loads can't affect the result.
        nk_b256_vec_t partial_vec;
        nk_partial_load_b16x16_serial_(data_ptr + idx, &partial_vec, remaining);
        __m256i raw_u16x16 = partial_vec.ymm;
        __m256i data_cmp_i16x16 = nk_bf16x16_to_comparable_i16x16_haswell_(raw_u16x16);
        __m256i abs_u16x16 = _mm256_and_si256(raw_u16x16, abs_mask_u16x16);
        __m256i nan_detect_i16x16 = _mm256_cmpgt_epi16(abs_u16x16, nan_threshold_u16x16);
        __m256i lane_indices_u16x16 = _mm256_setr_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
        __m256i valid_i16x16 = _mm256_cmpgt_epi16(_mm256_set1_epi16((short)remaining), lane_indices_u16x16);
        __m256i not_nan_valid_i16x16 = _mm256_andnot_si256(nan_detect_i16x16, valid_i16x16);
        __m256i less_i16x16 = _mm256_cmpgt_epi16(min_cmp_i16x16, data_cmp_i16x16);
        __m256i min_changed_i16x16 = _mm256_and_si256(less_i16x16, not_nan_valid_i16x16);
        min_cmp_i16x16 = _mm256_blendv_epi8(min_cmp_i16x16, data_cmp_i16x16, min_changed_i16x16);
        min_loop_cycle_u16x16 = _mm256_blendv_epi8(min_loop_cycle_u16x16, current_loop_cycle_u16x16,
                                                   min_changed_i16x16);
        __m256i greater_i16x16 = _mm256_cmpgt_epi16(data_cmp_i16x16, max_cmp_i16x16);
        __m256i max_changed_i16x16 = _mm256_and_si256(greater_i16x16, not_nan_valid_i16x16);
        max_cmp_i16x16 = _mm256_blendv_epi8(max_cmp_i16x16, data_cmp_i16x16, max_changed_i16x16);
        max_loop_cycle_u16x16 = _mm256_blendv_epi8(max_loop_cycle_u16x16, current_loop_cycle_u16x16,
                                                   max_changed_i16x16);
    }

    nk_i16_t min_value_comparable = nk_reduce_min_i16x16_haswell_(min_cmp_i16x16);
    nk_i16_t max_value_comparable = nk_reduce_max_i16x16_haswell_(max_cmp_i16x16);
    if (min_value_comparable == 0x7FFF && max_value_comparable == (nk_i16_t)0x8000) {
        // Both accumulators are still at their init values: every element was
        // NaN (or unreachable), so report the empty-input sentinels.
        *min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }
    unsigned int min_lane, max_lane;
    {
        // Among lanes holding the winning min value, pick the one whose win
        // happened in the earliest loop cycle; break remaining ties by lowest
        // lane via tzcnt.
        // NOTE(review): a win recorded in cycle NK_U16_MAX is indistinguishable
        // from the sentinel below — the dispatcher's count cap appears to keep
        // cycles below that, but confirm the boundary.
        __m256i value_match_b16x16 = _mm256_cmpeq_epi16(min_cmp_i16x16, _mm256_set1_epi16(min_value_comparable));
        __m256i masked_cycle_u16x16 = _mm256_blendv_epi8(_mm256_set1_epi16((short)NK_U16_MAX), min_loop_cycle_u16x16,
                                                         value_match_b16x16);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x16_haswell_(masked_cycle_u16x16);
        __m256i cycle_match_b16x16 = _mm256_cmpeq_epi16(masked_cycle_u16x16,
                                                        _mm256_set1_epi16((short)earliest_loop_cycle));
        // movemask is per byte; divide by 2 to convert byte position to lane.
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b16x16)) / 2;
    }
    {
        // Same earliest-cycle tie-break for the max.
        __m256i value_match_b16x16 = _mm256_cmpeq_epi16(max_cmp_i16x16, _mm256_set1_epi16(max_value_comparable));
        __m256i masked_cycle_u16x16 = _mm256_blendv_epi8(_mm256_set1_epi16((short)NK_U16_MAX), max_loop_cycle_u16x16,
                                                         value_match_b16x16);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x16_haswell_(masked_cycle_u16x16);
        __m256i cycle_match_b16x16 = _mm256_cmpeq_epi16(masked_cycle_u16x16,
                                                        _mm256_set1_epi16((short)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b16x16)) / 2;
    }
    // Element index = winning cycle * 16 lanes + lane offset.
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u16x16;
    *min_index_ptr = (nk_size_t)loop_cycle_vec.u16s[min_lane] * 16 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u16x16;
    *max_index_ptr = (nk_size_t)loop_cycle_vec.u16s[max_lane] * 16 + max_lane;
    // Undo the comparable mapping: flip the low 15 bits back when the
    // comparable value is negative (original float was negative).
    nk_i16_t min_sign = min_value_comparable >> 15;
    *min_value_ptr = (nk_bf16_t)((nk_u16_t)min_value_comparable ^ ((nk_u16_t)min_sign >> 1));
    nk_i16_t max_sign = max_value_comparable >> 15;
    *max_value_ptr = (nk_bf16_t)((nk_u16_t)max_value_comparable ^ ((nk_u16_t)max_sign >> 1));
}
|
|
3401
|
+
|
|
3402
|
+
NK_PUBLIC void nk_reduce_minmax_bf16_haswell( //
|
|
3403
|
+
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3404
|
+
nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3405
|
+
nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3406
|
+
|
|
3407
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
|
|
3408
|
+
int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
|
|
3409
|
+
if (count == 0)
|
|
3410
|
+
*min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_BF16_MIN,
|
|
3411
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
3412
|
+
else if (!aligned)
|
|
3413
|
+
nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3414
|
+
max_index_ptr);
|
|
3415
|
+
else if (stride_elements == 1 && count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
|
|
3416
|
+
nk_size_t left_count = count / 2;
|
|
3417
|
+
nk_bf16_t left_min, right_min, left_max, right_max;
|
|
3418
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
3419
|
+
nk_reduce_minmax_bf16_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
3420
|
+
&left_max_index);
|
|
3421
|
+
nk_reduce_minmax_bf16_haswell(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
|
|
3422
|
+
&right_min_index, &right_max, &right_max_index);
|
|
3423
|
+
if (nk_bf16_order_serial(right_min, left_min) < 0)
|
|
3424
|
+
*min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
3425
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
3426
|
+
if (nk_bf16_order_serial(right_max, left_max) > 0)
|
|
3427
|
+
*max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
3428
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
3429
|
+
}
|
|
3430
|
+
else if (stride_elements == 1)
|
|
3431
|
+
nk_reduce_minmax_bf16_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3432
|
+
max_index_ptr);
|
|
3433
|
+
else
|
|
3434
|
+
nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3435
|
+
max_index_ptr);
|
|
3436
|
+
}
|
|
3437
|
+
|
|
3438
|
+
NK_INTERNAL void nk_reduce_moments_f16_haswell_contiguous_( //
|
|
3439
|
+
nk_f16_t const *data_ptr, nk_size_t count, //
|
|
3440
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3441
|
+
|
|
3442
|
+
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
3443
|
+
__m256 sumsq_f32x8 = _mm256_setzero_ps();
|
|
3444
|
+
nk_size_t idx = 0;
|
|
3445
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
3446
|
+
__m256 low_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(data_ptr + idx)));
|
|
3447
|
+
__m256 high_f32x8 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)(data_ptr + idx + 8)));
|
|
3448
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, low_f32x8);
|
|
3449
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, high_f32x8);
|
|
3450
|
+
sumsq_f32x8 = _mm256_fmadd_ps(low_f32x8, low_f32x8, sumsq_f32x8);
|
|
3451
|
+
sumsq_f32x8 = _mm256_fmadd_ps(high_f32x8, high_f32x8, sumsq_f32x8);
|
|
3452
|
+
}
|
|
3453
|
+
nk_size_t remaining = count - idx;
|
|
3454
|
+
if (remaining > 0) {
|
|
3455
|
+
nk_b256_vec_t partial_vec;
|
|
3456
|
+
nk_partial_load_b16x16_serial_(data_ptr + idx, &partial_vec, remaining);
|
|
3457
|
+
__m256 low_f32x8 = _mm256_cvtph_ps(_mm256_castsi256_si128(partial_vec.ymm));
|
|
3458
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, low_f32x8);
|
|
3459
|
+
sumsq_f32x8 = _mm256_fmadd_ps(low_f32x8, low_f32x8, sumsq_f32x8);
|
|
3460
|
+
if (remaining > 8) {
|
|
3461
|
+
__m256 high_f32x8 = _mm256_cvtph_ps(_mm256_extracti128_si256(partial_vec.ymm, 1));
|
|
3462
|
+
sum_f32x8 = _mm256_add_ps(sum_f32x8, high_f32x8);
|
|
3463
|
+
sumsq_f32x8 = _mm256_fmadd_ps(high_f32x8, high_f32x8, sumsq_f32x8);
|
|
3464
|
+
}
|
|
3465
|
+
}
|
|
3466
|
+
*sum_ptr = nk_reduce_add_f32x8_haswell_(sum_f32x8), *sumsq_ptr = nk_reduce_add_f32x8_haswell_(sumsq_f32x8);
|
|
3467
|
+
}
|
|
3468
|
+
|
|
3469
|
+
NK_PUBLIC void nk_reduce_moments_f16_haswell( //
|
|
3470
|
+
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3471
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
3472
|
+
|
|
3473
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
|
|
3474
|
+
int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
|
|
3475
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
3476
|
+
else if (!aligned) nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3477
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
|
|
3478
|
+
nk_size_t left_count = count / 2;
|
|
3479
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
3480
|
+
nk_reduce_moments_f16_haswell(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
3481
|
+
nk_reduce_moments_f16_haswell(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
3482
|
+
&right_sum, &right_sumsq);
|
|
3483
|
+
*sum_ptr = left_sum + right_sum;
|
|
3484
|
+
*sumsq_ptr = left_sumsq + right_sumsq;
|
|
3485
|
+
}
|
|
3486
|
+
else if (stride_elements == 1) nk_reduce_moments_f16_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3487
|
+
else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3488
|
+
}
|
|
3489
|
+
|
|
3490
|
+
NK_INTERNAL void nk_reduce_minmax_f16_haswell_contiguous_( //
    nk_f16_t const *data_ptr, nk_size_t count,             //
    nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {

    // NaN-skipping argmin/argmax over contiguous f16, 16 lanes per iteration.
    // Mirrors the bf16 kernel: values are mapped to a totally-ordered signed
    // encoding so epi16 compares implement float ordering, and per-lane "loop
    // cycle" counters record when each lane last improved so the element index
    // can be reconstructed afterwards.
    __m256i abs_mask_u16x16 = _mm256_set1_epi16(0x7FFF);
    // abs(x) > 0x7C00 means f16 exponent all-ones with nonzero mantissa: NaN.
    __m256i nan_threshold_u16x16 = _mm256_set1_epi16((short)0x7C00);
    __m256i all_ones_i16x16 = _mm256_set1_epi8((char)0xFF);
    // Running extrema start at the comparable-domain identities (+max / -max).
    __m256i min_cmp_i16x16 = _mm256_set1_epi16((short)0x7FFF);
    __m256i max_cmp_i16x16 = _mm256_set1_epi16((short)0x8000);
    __m256i min_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i max_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i current_loop_cycle_u16x16 = _mm256_setzero_si256();
    __m256i one_u16x16 = _mm256_set1_epi16(1);

    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        __m256i raw_u16x16 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
        __m256i data_cmp_i16x16 = nk_f16x16_to_comparable_i16x16_haswell_(raw_u16x16);
        __m256i abs_u16x16 = _mm256_and_si256(raw_u16x16, abs_mask_u16x16);
        __m256i nan_detect_i16x16 = _mm256_cmpgt_epi16(abs_u16x16, nan_threshold_u16x16);
        __m256i not_nan_i16x16 = _mm256_xor_si256(nan_detect_i16x16, all_ones_i16x16);
        // Strict `<` keeps the first occurrence on ties.
        __m256i less_i16x16 = _mm256_cmpgt_epi16(min_cmp_i16x16, data_cmp_i16x16);
        __m256i min_changed_i16x16 = _mm256_and_si256(less_i16x16, not_nan_i16x16);
        min_cmp_i16x16 = _mm256_blendv_epi8(min_cmp_i16x16, data_cmp_i16x16, min_changed_i16x16);
        min_loop_cycle_u16x16 = _mm256_blendv_epi8(min_loop_cycle_u16x16, current_loop_cycle_u16x16,
                                                   min_changed_i16x16);
        __m256i greater_i16x16 = _mm256_cmpgt_epi16(data_cmp_i16x16, max_cmp_i16x16);
        __m256i max_changed_i16x16 = _mm256_and_si256(greater_i16x16, not_nan_i16x16);
        max_cmp_i16x16 = _mm256_blendv_epi8(max_cmp_i16x16, data_cmp_i16x16, max_changed_i16x16);
        max_loop_cycle_u16x16 = _mm256_blendv_epi8(max_loop_cycle_u16x16, current_loop_cycle_u16x16,
                                                   max_changed_i16x16);
        current_loop_cycle_u16x16 = _mm256_add_epi16(current_loop_cycle_u16x16, one_u16x16);
    }

    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Tail: identical update, but lanes past `remaining` are masked out so
        // the zero-padded partial load cannot affect the result.
        nk_b256_vec_t partial_vec;
        nk_partial_load_b16x16_serial_(data_ptr + idx, &partial_vec, remaining);
        __m256i raw_u16x16 = partial_vec.ymm;
        __m256i data_cmp_i16x16 = nk_f16x16_to_comparable_i16x16_haswell_(raw_u16x16);
        __m256i abs_u16x16 = _mm256_and_si256(raw_u16x16, abs_mask_u16x16);
        __m256i nan_detect_i16x16 = _mm256_cmpgt_epi16(abs_u16x16, nan_threshold_u16x16);
        __m256i lane_indices_u16x16 = _mm256_setr_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
        __m256i valid_i16x16 = _mm256_cmpgt_epi16(_mm256_set1_epi16((short)remaining), lane_indices_u16x16);
        __m256i not_nan_valid_i16x16 = _mm256_andnot_si256(nan_detect_i16x16, valid_i16x16);
        __m256i less_i16x16 = _mm256_cmpgt_epi16(min_cmp_i16x16, data_cmp_i16x16);
        __m256i min_changed_i16x16 = _mm256_and_si256(less_i16x16, not_nan_valid_i16x16);
        min_cmp_i16x16 = _mm256_blendv_epi8(min_cmp_i16x16, data_cmp_i16x16, min_changed_i16x16);
        min_loop_cycle_u16x16 = _mm256_blendv_epi8(min_loop_cycle_u16x16, current_loop_cycle_u16x16,
                                                   min_changed_i16x16);
        __m256i greater_i16x16 = _mm256_cmpgt_epi16(data_cmp_i16x16, max_cmp_i16x16);
        __m256i max_changed_i16x16 = _mm256_and_si256(greater_i16x16, not_nan_valid_i16x16);
        max_cmp_i16x16 = _mm256_blendv_epi8(max_cmp_i16x16, data_cmp_i16x16, max_changed_i16x16);
        max_loop_cycle_u16x16 = _mm256_blendv_epi8(max_loop_cycle_u16x16, current_loop_cycle_u16x16,
                                                   max_changed_i16x16);
    }

    nk_i16_t min_value_comparable = nk_reduce_min_i16x16_haswell_(min_cmp_i16x16);
    nk_i16_t max_value_comparable = nk_reduce_max_i16x16_haswell_(max_cmp_i16x16);
    if (min_value_comparable == 0x7FFF && max_value_comparable == (nk_i16_t)0x8000) {
        // Accumulators untouched: all elements were NaN — report sentinels.
        *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
        *max_index_ptr = NK_SIZE_MAX;
        return;
    }
    unsigned int min_lane, max_lane;
    {
        // Among lanes holding the winning min, prefer the earliest loop cycle,
        // then the lowest lane (tzcnt).
        __m256i value_match_b16x16 = _mm256_cmpeq_epi16(min_cmp_i16x16, _mm256_set1_epi16(min_value_comparable));
        __m256i masked_cycle_u16x16 = _mm256_blendv_epi8(_mm256_set1_epi16((short)NK_U16_MAX), min_loop_cycle_u16x16,
                                                         value_match_b16x16);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x16_haswell_(masked_cycle_u16x16);
        __m256i cycle_match_b16x16 = _mm256_cmpeq_epi16(masked_cycle_u16x16,
                                                        _mm256_set1_epi16((short)earliest_loop_cycle));
        // movemask yields one bit per byte; halve to get the 16-bit lane index.
        min_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b16x16)) / 2;
    }
    {
        // Same earliest-cycle tie-break for the max.
        __m256i value_match_b16x16 = _mm256_cmpeq_epi16(max_cmp_i16x16, _mm256_set1_epi16(max_value_comparable));
        __m256i masked_cycle_u16x16 = _mm256_blendv_epi8(_mm256_set1_epi16((short)NK_U16_MAX), max_loop_cycle_u16x16,
                                                         value_match_b16x16);
        nk_u16_t earliest_loop_cycle = nk_reduce_min_u16x16_haswell_(masked_cycle_u16x16);
        __m256i cycle_match_b16x16 = _mm256_cmpeq_epi16(masked_cycle_u16x16,
                                                        _mm256_set1_epi16((short)earliest_loop_cycle));
        max_lane = _tzcnt_u32((unsigned int)_mm256_movemask_epi8(cycle_match_b16x16)) / 2;
    }
    // Element index = winning cycle * 16 lanes + lane offset.
    nk_b256_vec_t loop_cycle_vec;
    loop_cycle_vec.ymm = min_loop_cycle_u16x16;
    *min_index_ptr = (nk_size_t)loop_cycle_vec.u16s[min_lane] * 16 + min_lane;
    loop_cycle_vec.ymm = max_loop_cycle_u16x16;
    *max_index_ptr = (nk_size_t)loop_cycle_vec.u16s[max_lane] * 16 + max_lane;
    // Undo the comparable mapping: flip the low 15 bits back for negatives.
    nk_i16_t min_sign = min_value_comparable >> 15;
    *min_value_ptr = (nk_f16_t)((nk_u16_t)min_value_comparable ^ ((nk_u16_t)min_sign >> 1));
    nk_i16_t max_sign = max_value_comparable >> 15;
    *max_value_ptr = (nk_f16_t)((nk_u16_t)max_value_comparable ^ ((nk_u16_t)max_sign >> 1));
}
|
|
3584
|
+
|
|
3585
|
+
NK_PUBLIC void nk_reduce_minmax_f16_haswell( //
|
|
3586
|
+
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3587
|
+
nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
3588
|
+
nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
3589
|
+
|
|
3590
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
|
|
3591
|
+
int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
|
|
3592
|
+
if (count == 0)
|
|
3593
|
+
*min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
|
|
3594
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
3595
|
+
else if (!aligned)
|
|
3596
|
+
nk_reduce_minmax_f16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3597
|
+
max_index_ptr);
|
|
3598
|
+
else if (stride_elements == 1 && count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
|
|
3599
|
+
nk_size_t left_count = count / 2;
|
|
3600
|
+
nk_f16_t left_min, right_min, left_max, right_max;
|
|
3601
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
3602
|
+
nk_reduce_minmax_f16_haswell(data_ptr, left_count, stride_bytes, &left_min, &left_min_index, &left_max,
|
|
3603
|
+
&left_max_index);
|
|
3604
|
+
nk_reduce_minmax_f16_haswell(data_ptr + left_count, count - left_count, stride_bytes, &right_min,
|
|
3605
|
+
&right_min_index, &right_max, &right_max_index);
|
|
3606
|
+
if (nk_f16_order_serial(right_min, left_min) < 0)
|
|
3607
|
+
*min_value_ptr = right_min, *min_index_ptr = left_count + right_min_index;
|
|
3608
|
+
else *min_value_ptr = left_min, *min_index_ptr = left_min_index;
|
|
3609
|
+
if (nk_f16_order_serial(right_max, left_max) > 0)
|
|
3610
|
+
*max_value_ptr = right_max, *max_index_ptr = left_count + right_max_index;
|
|
3611
|
+
else *max_value_ptr = left_max, *max_index_ptr = left_max_index;
|
|
3612
|
+
}
|
|
3613
|
+
else if (stride_elements == 1)
|
|
3614
|
+
nk_reduce_minmax_f16_haswell_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3615
|
+
max_index_ptr);
|
|
3616
|
+
else
|
|
3617
|
+
nk_reduce_minmax_f16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
3618
|
+
max_index_ptr);
|
|
3619
|
+
}
|
|
3620
|
+
|
|
3621
|
+
NK_INTERNAL void nk_reduce_moments_i4_haswell_contiguous_( //
    nk_i4x2_t const *data_ptr, nk_size_t count,            //
    nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    // Sum and sum-of-squares over packed signed 4-bit nibbles (two per byte).
    // Bytes are summed with PSADBW after an XOR-8 bias that maps signed nibbles
    // to unsigned; the bias is subtracted once at the end. Squares come from a
    // per-lane PSHUFB lookup table of n^2 for n in 0..15.
    __m256i mask_0f_i8x32 = _mm256_set1_epi8(0x0F);
    __m256i eight_i8x32 = _mm256_set1_epi8(8);
    __m256i zero_i8x32 = _mm256_setzero_si256();
    // Squares of 0..15, replicated in both 128-bit halves for PSHUFB.
    __m256i sq_lut_u8x32 = _mm256_setr_epi8(                             //
        0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, (char)144, (char)169, //
        (char)196, (char)225,                                            //
        0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, (char)144, (char)169, //
        (char)196, (char)225);
    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_u64x4 = _mm256_setzero_si256();
    nk_size_t count_bytes = nk_size_divide_round_up_(count, 2);
    unsigned char const *ptr = (unsigned char const *)data_ptr;
    while (count_bytes > 0) {
        nk_b256_vec_t raw_vec;
        if (count_bytes < 32) {
            // Final partial chunk: zero-padded; padding is corrected below.
            nk_partial_load_b8x32_serial_(ptr, &raw_vec, count_bytes);
            count_bytes = 0;
        }
        else {
            raw_vec.ymm = _mm256_loadu_si256((__m256i const *)ptr);
            ptr += 32, count_bytes -= 32;
        }
        __m256i raw_i8x32 = raw_vec.ymm;
        __m256i low_u4x32 = _mm256_and_si256(raw_i8x32, mask_0f_i8x32);
        __m256i high_u4x32 = _mm256_and_si256(_mm256_srli_epi16(raw_i8x32, 4), mask_0f_i8x32);
        // XOR 8 flips the nibble's sign bit: signed value v becomes v + 8 in
        // [0, 15], so byte-wise sums stay non-negative for PSADBW.
        __m256i low_biased_u4x32 = _mm256_xor_si256(low_u4x32, eight_i8x32);
        __m256i high_biased_u4x32 = _mm256_xor_si256(high_u4x32, eight_i8x32);
        // Max per-byte total is 15 + 15 = 30: no 8-bit overflow possible.
        __m256i pair_sum = _mm256_add_epi8(low_biased_u4x32, high_biased_u4x32);
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(pair_sum, zero_i8x32));
        // NOTE(review): the square LUT is indexed with the UNBIASED nibble, so
        // a negative value v (nibble b = v + 16) contributes b^2, not v^2 —
        // e.g. -1 contributes 225 instead of 1. Verify this matches
        // nk_reduce_moments_i4_serial's definition of sumsq.
        __m256i low_sq_u8x32 = _mm256_shuffle_epi8(sq_lut_u8x32, low_u4x32);
        __m256i high_sq_u8x32 = _mm256_shuffle_epi8(sq_lut_u8x32, high_u4x32);
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_sad_epu8(low_sq_u8x32, zero_i8x32));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_sad_epu8(high_sq_u8x32, zero_i8x32));
    }
    // The XOR-8 bias adds 8 per nibble to the SAD total. Subtract 8 × total nibbles processed
    // (including zero-padded register bytes, where 0 XOR 8 = 8, signed = 0).
    nk_size_t nibbles_processed = nk_size_round_up_to_multiple_(nk_size_divide_round_up_(count, 2), 32) * 2;
    nk_i64_t sum = (nk_i64_t)(nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4) -
                   (nk_i64_t)8 * (nk_i64_t)nibbles_processed;
    // sumsq uses sq_lut[0]=0 for zero-padded nibbles, so no register-padding correction needed.
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
3668
|
+
|
|
3669
|
+
NK_PUBLIC void nk_reduce_moments_i4_haswell( //
|
|
3670
|
+
nk_i4x2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3671
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
3672
|
+
|
|
3673
|
+
count = nk_size_round_up_to_multiple_(count, 2);
|
|
3674
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
3675
|
+
else if (stride_bytes == 1) nk_reduce_moments_i4_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3676
|
+
else nk_reduce_moments_i4_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3677
|
+
}
|
|
3678
|
+
|
|
3679
|
+
// Sum and sum-of-squares over `count` contiguous unsigned 4-bit values (packed
// two per byte) using AVX2. Sums use _mm256_sad_epu8 horizontal byte sums;
// squares come from a 16-entry in-register lookup table.
NK_INTERNAL void nk_reduce_moments_u4_haswell_contiguous_( //
    nk_u4x2_t const *data_ptr, nk_size_t count, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    __m256i mask_0f_i8x32 = _mm256_set1_epi8(0x0F);
    __m256i zero_i8x32 = _mm256_setzero_si256();
    // LUT mapping nibble n (0..15) to n^2 (0..225); duplicated in both 128-bit
    // lanes because _mm256_shuffle_epi8 indexes within each lane independently.
    // Entry 0 is 0, so zero-padded tail nibbles contribute nothing to sumsq.
    __m256i sq_lut_u8x32 = _mm256_setr_epi8( //
        0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, (char)144, (char)169, //
        (char)196, (char)225, //
        0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, (char)144, (char)169, //
        (char)196, (char)225);
    __m256i sum_u64x4 = _mm256_setzero_si256();
    __m256i sumsq_u64x4 = _mm256_setzero_si256();
    // Two nibbles per byte; caller guarantees `count` is even.
    nk_size_t count_bytes = count / 2;
    unsigned char const *ptr = (unsigned char const *)data_ptr;
    while (count_bytes > 0) {
        nk_b256_vec_t raw_vec;
        if (count_bytes < 32) {
            // Tail: partial load (presumably zero-fills the unused bytes —
            // zero nibbles add 0 to both sum and sumsq, so no correction needed).
            nk_partial_load_b8x32_serial_(ptr, &raw_vec, count_bytes);
            count_bytes = 0;
        }
        else {
            raw_vec.ymm = _mm256_loadu_si256((__m256i const *)ptr);
            ptr += 32, count_bytes -= 32;
        }
        __m256i raw_i8x32 = raw_vec.ymm;
        // Split each byte into its low and high nibble (values 0..15).
        __m256i low_u4x32 = _mm256_and_si256(raw_i8x32, mask_0f_i8x32);
        __m256i high_u4x32 = _mm256_and_si256(_mm256_srli_epi16(raw_i8x32, 4), mask_0f_i8x32);
        // Per-byte nibble-pair sum: max 15 + 15 = 30, fits in u8 without overflow.
        __m256i pair_sum = _mm256_add_epi8(low_u4x32, high_u4x32);
        // SAD against zero = horizontal sum of 8 bytes into each of four u64 lanes.
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(pair_sum, zero_i8x32));
        // Squared nibble values via table lookup (max 225, fits in u8).
        __m256i low_sq_u8x32 = _mm256_shuffle_epi8(sq_lut_u8x32, low_u4x32);
        __m256i high_sq_u8x32 = _mm256_shuffle_epi8(sq_lut_u8x32, high_u4x32);
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_sad_epu8(low_sq_u8x32, zero_i8x32));
        sumsq_u64x4 = _mm256_add_epi64(sumsq_u64x4, _mm256_sad_epu8(high_sq_u8x32, zero_i8x32));
    }
    // Collapse the four u64 accumulator lanes into scalars.
    nk_u64_t sum = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sumsq_u64x4);
    *sum_ptr = sum, *sumsq_ptr = sumsq;
}
|
|
3718
|
+
|
|
3719
|
+
NK_PUBLIC void nk_reduce_moments_u4_haswell( //
|
|
3720
|
+
nk_u4x2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3721
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
3722
|
+
|
|
3723
|
+
count = nk_size_round_up_to_multiple_(count, 2);
|
|
3724
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
3725
|
+
else if (stride_bytes == 1) nk_reduce_moments_u4_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3726
|
+
else nk_reduce_moments_u4_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3727
|
+
}
|
|
3728
|
+
|
|
3729
|
+
// Sum (= population count) over `count` contiguous single-bit values packed
// 8 per byte, using the classic AVX2 nibble-LUT popcount. Since a bit squared
// equals the bit itself, sumsq is identical to sum.
NK_INTERNAL void nk_reduce_moments_u1_haswell_contiguous_( //
    nk_u1x8_t const *data_ptr, nk_size_t count, //
    nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {

    // LUT: entry n = popcount of nibble n; duplicated across both 128-bit
    // lanes for _mm256_shuffle_epi8's per-lane indexing.
    __m256i lut_i8x32 = _mm256_setr_epi8( //
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4);
    __m256i mask_0f_i8x32 = _mm256_set1_epi8(0x0F);
    __m256i zero_i8x32 = _mm256_setzero_si256();
    __m256i sum_u64x4 = _mm256_setzero_si256();
    // Eight bits per byte; caller guarantees `count` is a multiple of 8.
    nk_size_t count_bytes = count / 8;
    unsigned char const *ptr = (unsigned char const *)data_ptr;
    while (count_bytes > 0) {
        nk_b256_vec_t raw_vec;
        if (count_bytes < 32) {
            // Tail: partial load (presumably zero-fills the unused bytes —
            // zero bytes have popcount 0, so they do not perturb the sum).
            nk_partial_load_b8x32_serial_(ptr, &raw_vec, count_bytes);
            count_bytes = 0;
        }
        else {
            raw_vec.ymm = _mm256_loadu_si256((__m256i const *)ptr);
            ptr += 32, count_bytes -= 32;
        }
        __m256i raw_i8x32 = raw_vec.ymm;
        // Per-byte popcount = LUT[low nibble] + LUT[high nibble] (max 8).
        __m256i low_nibble_u8x32 = _mm256_and_si256(raw_i8x32, mask_0f_i8x32);
        __m256i high_nibble_u8x32 = _mm256_and_si256(_mm256_srli_epi16(raw_i8x32, 4), mask_0f_i8x32);
        __m256i popcnt_u8x32 = _mm256_add_epi8(_mm256_shuffle_epi8(lut_i8x32, low_nibble_u8x32),
                                               _mm256_shuffle_epi8(lut_i8x32, high_nibble_u8x32));
        // SAD against zero = horizontal sum of 8 bytes into each of four u64 lanes.
        sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(popcnt_u8x32, zero_i8x32));
    }
    nk_u64_t sum = (nk_u64_t)nk_reduce_add_i64x4_haswell_(sum_u64x4);
    // For 1-bit values, x^2 == x, so sumsq equals sum.
    *sum_ptr = sum, *sumsq_ptr = sum;
}
|
|
3760
|
+
|
|
3761
|
+
NK_PUBLIC void nk_reduce_moments_u1_haswell( //
|
|
3762
|
+
nk_u1x8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3763
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
3764
|
+
|
|
3765
|
+
count = nk_size_round_up_to_multiple_(count, 8);
|
|
3766
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
3767
|
+
else if (stride_bytes == 1) nk_reduce_moments_u1_haswell_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3768
|
+
else nk_reduce_moments_u1_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3769
|
+
}
|
|
3770
|
+
|
|
3771
|
+
#if defined(__clang__)
|
|
3772
|
+
#pragma clang attribute pop
|
|
3773
|
+
#elif defined(__GNUC__)
|
|
3774
|
+
#pragma GCC pop_options
|
|
3775
|
+
#endif
|
|
3776
|
+
|
|
3777
|
+
#if defined(__cplusplus)
|
|
3778
|
+
} // extern "C"
|
|
3779
|
+
#endif
|
|
3780
|
+
|
|
3781
|
+
#endif // NK_TARGET_HASWELL
|
|
3782
|
+
#endif // NK_TARGET_X86_
|
|
3783
|
+
#endif // NK_REDUCE_HASWELL_H
|