numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
package/include/numkong/reduce/skylake.h

@@ -56,12 +56,12 @@ extern "C" {
 
 /** @brief Horizontal sum of 16 floats in a ZMM register (native f32 precision). */
 NK_INTERNAL nk_f32_t nk_reduce_add_f32x16_skylake_(__m512 sum_f32x16) {
-    __m256 …
-    __m256 …
-    __m256 sum_f32x8 = _mm256_add_ps(…
-    __m128 …
-    __m128 …
-    __m128 sum_f32x4 = _mm_add_ps(…
+    __m256 low_f32x8 = _mm512_castps512_ps256(sum_f32x16);
+    __m256 high_f32x8 = _mm512_extractf32x8_ps(sum_f32x16, 1);
+    __m256 sum_f32x8 = _mm256_add_ps(low_f32x8, high_f32x8);
+    __m128 low_f32x4 = _mm256_castps256_ps128(sum_f32x8);
+    __m128 high_f32x4 = _mm256_extractf128_ps(sum_f32x8, 1);
+    __m128 sum_f32x4 = _mm_add_ps(low_f32x4, high_f32x4);
     sum_f32x4 = _mm_hadd_ps(sum_f32x4, sum_f32x4);
     sum_f32x4 = _mm_hadd_ps(sum_f32x4, sum_f32x4);
     return _mm_cvtss_f32(sum_f32x4);
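The renamed temporaries spell out the standard AVX-512 horizontal-sum ladder: fold the upper 256-bit half onto the lower half, fold the upper 128-bit half likewise, then finish with horizontal adds. A minimal scalar sketch of the same halve-and-add tree (an illustration, not part of the package):

#include <stddef.h>

/* Halve-and-add tree: 16 -> 8 -> 4 -> 2 -> 1 lanes. The first two steps mirror
 * the extract+add pairs above; the last two stand in for the hadd instructions,
 * which pair adjacent lanes but form a balanced tree of the same depth. */
static float reduce_add_f32x16_reference(float const lanes[16]) {
    float folded[16];
    for (size_t i = 0; i != 16; ++i) folded[i] = lanes[i];
    for (size_t width = 8; width != 0; width /= 2)
        for (size_t i = 0; i != width; ++i)
            folded[i] += folded[i + width]; /* fold high half onto low half */
    return folded[0];
}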
@@ -69,24 +69,24 @@ NK_INTERNAL nk_f32_t nk_reduce_add_f32x16_skylake_(__m512 sum_f32x16) {
 
 /** @brief Horizontal sum of 8 doubles in a ZMM register. */
 NK_INTERNAL nk_f64_t nk_reduce_add_f64x8_skylake_(__m512d sum_f64x8) {
-    __m256d …
-    __m256d …
-    __m256d sum_f64x4 = _mm256_add_pd(…
-    __m128d …
-    __m128d …
-    __m128d sum_f64x2 = _mm_add_pd(…
+    __m256d low_f64x4 = _mm512_castpd512_pd256(sum_f64x8);
+    __m256d high_f64x4 = _mm512_extractf64x4_pd(sum_f64x8, 1);
+    __m256d sum_f64x4 = _mm256_add_pd(low_f64x4, high_f64x4);
+    __m128d low_f64x2 = _mm256_castpd256_pd128(sum_f64x4);
+    __m128d high_f64x2 = _mm256_extractf128_pd(sum_f64x4, 1);
+    __m128d sum_f64x2 = _mm_add_pd(low_f64x2, high_f64x2);
     sum_f64x2 = _mm_hadd_pd(sum_f64x2, sum_f64x2);
     return _mm_cvtsd_f64(sum_f64x2);
 }
 
 /** @brief Horizontal min of 16 floats in a ZMM register. */
 NK_INTERNAL nk_f32_t nk_reduce_min_f32x16_skylake_(__m512 min_f32x16) {
-    __m256 …
-    __m256 …
-    __m256 min_f32x8 = _mm256_min_ps(…
-    __m128 …
-    __m128 …
-    __m128 min_f32x4 = _mm_min_ps(…
+    __m256 low_f32x8 = _mm512_castps512_ps256(min_f32x16);
+    __m256 high_f32x8 = _mm512_extractf32x8_ps(min_f32x16, 1);
+    __m256 min_f32x8 = _mm256_min_ps(low_f32x8, high_f32x8);
+    __m128 low_f32x4 = _mm256_castps256_ps128(min_f32x8);
+    __m128 high_f32x4 = _mm256_extractf128_ps(min_f32x8, 1);
+    __m128 min_f32x4 = _mm_min_ps(low_f32x4, high_f32x4);
     min_f32x4 = _mm_min_ps(min_f32x4, _mm_shuffle_ps(min_f32x4, min_f32x4, _MM_SHUFFLE(2, 3, 0, 1)));
     min_f32x4 = _mm_min_ps(min_f32x4, _mm_shuffle_ps(min_f32x4, min_f32x4, _MM_SHUFFLE(1, 0, 3, 2)));
     return _mm_cvtss_f32(min_f32x4);
@@ -94,12 +94,12 @@ NK_INTERNAL nk_f32_t nk_reduce_min_f32x16_skylake_(__m512 min_f32x16) {
 
 /** @brief Horizontal max of 16 floats in a ZMM register. */
 NK_INTERNAL nk_f32_t nk_reduce_max_f32x16_skylake_(__m512 max_f32x16) {
-    __m256 …
-    __m256 …
-    __m256 max_f32x8 = _mm256_max_ps(…
-    __m128 …
-    __m128 …
-    __m128 max_f32x4 = _mm_max_ps(…
+    __m256 low_f32x8 = _mm512_castps512_ps256(max_f32x16);
+    __m256 high_f32x8 = _mm512_extractf32x8_ps(max_f32x16, 1);
+    __m256 max_f32x8 = _mm256_max_ps(low_f32x8, high_f32x8);
+    __m128 low_f32x4 = _mm256_castps256_ps128(max_f32x8);
+    __m128 high_f32x4 = _mm256_extractf128_ps(max_f32x8, 1);
+    __m128 max_f32x4 = _mm_max_ps(low_f32x4, high_f32x4);
     max_f32x4 = _mm_max_ps(max_f32x4, _mm_shuffle_ps(max_f32x4, max_f32x4, _MM_SHUFFLE(2, 3, 0, 1)));
     max_f32x4 = _mm_max_ps(max_f32x4, _mm_shuffle_ps(max_f32x4, max_f32x4, _MM_SHUFFLE(1, 0, 3, 2)));
     return _mm_cvtss_f32(max_f32x4);
@@ -107,36 +107,36 @@ NK_INTERNAL nk_f32_t nk_reduce_max_f32x16_skylake_(__m512 max_f32x16) {
 
 /** @brief Horizontal min of 8 doubles in a ZMM register. */
 NK_INTERNAL nk_f64_t nk_reduce_min_f64x8_skylake_(__m512d min_f64x8) {
-    __m256d …
-    __m256d …
-    __m256d min_f64x4 = _mm256_min_pd(…
-    __m128d …
-    __m128d …
-    __m128d min_f64x2 = _mm_min_pd(…
+    __m256d low_f64x4 = _mm512_castpd512_pd256(min_f64x8);
+    __m256d high_f64x4 = _mm512_extractf64x4_pd(min_f64x8, 1);
+    __m256d min_f64x4 = _mm256_min_pd(low_f64x4, high_f64x4);
+    __m128d low_f64x2 = _mm256_castpd256_pd128(min_f64x4);
+    __m128d high_f64x2 = _mm256_extractf128_pd(min_f64x4, 1);
+    __m128d min_f64x2 = _mm_min_pd(low_f64x2, high_f64x2);
     min_f64x2 = _mm_min_pd(min_f64x2, _mm_shuffle_pd(min_f64x2, min_f64x2, 1));
     return _mm_cvtsd_f64(min_f64x2);
 }
 
 /** @brief Horizontal max of 8 doubles in a ZMM register. */
 NK_INTERNAL nk_f64_t nk_reduce_max_f64x8_skylake_(__m512d max_f64x8) {
-    __m256d …
-    __m256d …
-    __m256d max_f64x4 = _mm256_max_pd(…
-    __m128d …
-    __m128d …
-    __m128d max_f64x2 = _mm_max_pd(…
+    __m256d low_f64x4 = _mm512_castpd512_pd256(max_f64x8);
+    __m256d high_f64x4 = _mm512_extractf64x4_pd(max_f64x8, 1);
+    __m256d max_f64x4 = _mm256_max_pd(low_f64x4, high_f64x4);
+    __m128d low_f64x2 = _mm256_castpd256_pd128(max_f64x4);
+    __m128d high_f64x2 = _mm256_extractf128_pd(max_f64x4, 1);
+    __m128d max_f64x2 = _mm_max_pd(low_f64x2, high_f64x2);
     max_f64x2 = _mm_max_pd(max_f64x2, _mm_shuffle_pd(max_f64x2, max_f64x2, 1));
     return _mm_cvtsd_f64(max_f64x2);
 }
 
 /** @brief Horizontal sum of 16 i32s in a ZMM register. */
 NK_INTERNAL nk_i32_t nk_reduce_add_i32x16_skylake_(__m512i sum_i32x16) {
-    __m256i …
-    __m256i …
-    __m256i sum_i32x8 = _mm256_add_epi32(…
-    __m128i …
-    __m128i …
-    __m128i sum_i32x4 = _mm_add_epi32(…
+    __m256i low_i32x8 = _mm512_castsi512_si256(sum_i32x16);
+    __m256i high_i32x8 = _mm512_extracti32x8_epi32(sum_i32x16, 1);
+    __m256i sum_i32x8 = _mm256_add_epi32(low_i32x8, high_i32x8);
+    __m128i low_i32x4 = _mm256_castsi256_si128(sum_i32x8);
+    __m128i high_i32x4 = _mm256_extracti128_si256(sum_i32x8, 1);
+    __m128i sum_i32x4 = _mm_add_epi32(low_i32x4, high_i32x4);
     sum_i32x4 = _mm_hadd_epi32(sum_i32x4, sum_i32x4);
     sum_i32x4 = _mm_hadd_epi32(sum_i32x4, sum_i32x4);
     return _mm_cvtsi128_si32(sum_i32x4);
@@ -144,12 +144,12 @@ NK_INTERNAL nk_i32_t nk_reduce_add_i32x16_skylake_(__m512i sum_i32x16) {
 
 /** @brief Horizontal sum of 8 i64s in a ZMM register. */
 NK_INTERNAL nk_i64_t nk_reduce_add_i64x8_skylake_(__m512i sum_i64x8) {
-    __m256i …
-    __m256i …
-    __m256i sum_i64x4 = _mm256_add_epi64(…
-    __m128i …
-    __m128i …
-    __m128i sum_i64x2 = _mm_add_epi64(…
+    __m256i low_i64x4 = _mm512_castsi512_si256(sum_i64x8);
+    __m256i high_i64x4 = _mm512_extracti64x4_epi64(sum_i64x8, 1);
+    __m256i sum_i64x4 = _mm256_add_epi64(low_i64x4, high_i64x4);
+    __m128i low_i64x2 = _mm256_castsi256_si128(sum_i64x4);
+    __m128i high_i64x2 = _mm256_extracti128_si256(sum_i64x4, 1);
+    __m128i sum_i64x2 = _mm_add_epi64(low_i64x2, high_i64x2);
     sum_i64x2 = _mm_add_epi64(sum_i64x2, _mm_shuffle_epi32(sum_i64x2, _MM_SHUFFLE(1, 0, 3, 2)));
     return _mm_cvtsi128_si64(sum_i64x2);
 }
@@ -284,12 +284,12 @@ NK_INTERNAL nk_size_t nk_stride_elems_b64x8_(nk_size_t stride) {
 
 /** @brief Horizontal min of 64 signed i8s in a ZMM register. */
 NK_INTERNAL nk_i8_t nk_reduce_min_i8x64_skylake_(__m512i min_i8x64) {
-    __m256i …
-    __m256i …
-    __m256i min_i8x32 = _mm256_min_epi8(…
-    __m128i …
-    __m128i …
-    __m128i min_i8x16 = _mm_min_epi8(…
+    __m256i low_i8x32 = _mm512_castsi512_si256(min_i8x64);
+    __m256i high_i8x32 = _mm512_extracti64x4_epi64(min_i8x64, 1);
+    __m256i min_i8x32 = _mm256_min_epi8(low_i8x32, high_i8x32);
+    __m128i low_i8x16 = _mm256_castsi256_si128(min_i8x32);
+    __m128i high_i8x16 = _mm256_extracti128_si256(min_i8x32, 1);
+    __m128i min_i8x16 = _mm_min_epi8(low_i8x16, high_i8x16);
     min_i8x16 = _mm_min_epi8(min_i8x16, _mm_shuffle_epi32(min_i8x16, _MM_SHUFFLE(2, 3, 0, 1)));
     min_i8x16 = _mm_min_epi8(min_i8x16, _mm_shuffle_epi32(min_i8x16, _MM_SHUFFLE(1, 0, 3, 2)));
     min_i8x16 = _mm_min_epi8(min_i8x16, _mm_srli_si128(min_i8x16, 2));
@@ -299,12 +299,12 @@ NK_INTERNAL nk_i8_t nk_reduce_min_i8x64_skylake_(__m512i min_i8x64) {
 
 /** @brief Horizontal max of 64 signed i8s in a ZMM register. */
 NK_INTERNAL nk_i8_t nk_reduce_max_i8x64_skylake_(__m512i max_i8x64) {
-    __m256i …
-    __m256i …
-    __m256i max_i8x32 = _mm256_max_epi8(…
-    __m128i …
-    __m128i …
-    __m128i max_i8x16 = _mm_max_epi8(…
+    __m256i low_i8x32 = _mm512_castsi512_si256(max_i8x64);
+    __m256i high_i8x32 = _mm512_extracti64x4_epi64(max_i8x64, 1);
+    __m256i max_i8x32 = _mm256_max_epi8(low_i8x32, high_i8x32);
+    __m128i low_i8x16 = _mm256_castsi256_si128(max_i8x32);
+    __m128i high_i8x16 = _mm256_extracti128_si256(max_i8x32, 1);
+    __m128i max_i8x16 = _mm_max_epi8(low_i8x16, high_i8x16);
     max_i8x16 = _mm_max_epi8(max_i8x16, _mm_shuffle_epi32(max_i8x16, _MM_SHUFFLE(2, 3, 0, 1)));
     max_i8x16 = _mm_max_epi8(max_i8x16, _mm_shuffle_epi32(max_i8x16, _MM_SHUFFLE(1, 0, 3, 2)));
     max_i8x16 = _mm_max_epi8(max_i8x16, _mm_srli_si128(max_i8x16, 2));
@@ -314,12 +314,12 @@ NK_INTERNAL nk_i8_t nk_reduce_max_i8x64_skylake_(__m512i max_i8x64) {
 
 /** @brief Horizontal min of 64 unsigned u8s in a ZMM register. */
 NK_INTERNAL nk_u8_t nk_reduce_min_u8x64_skylake_(__m512i min_u8x64) {
-    __m256i …
-    __m256i …
-    __m256i min_u8x32 = _mm256_min_epu8(…
-    __m128i …
-    __m128i …
-    __m128i min_u8x16 = _mm_min_epu8(…
+    __m256i low_u8x32 = _mm512_castsi512_si256(min_u8x64);
+    __m256i high_u8x32 = _mm512_extracti64x4_epi64(min_u8x64, 1);
+    __m256i min_u8x32 = _mm256_min_epu8(low_u8x32, high_u8x32);
+    __m128i low_u8x16 = _mm256_castsi256_si128(min_u8x32);
+    __m128i high_u8x16 = _mm256_extracti128_si256(min_u8x32, 1);
+    __m128i min_u8x16 = _mm_min_epu8(low_u8x16, high_u8x16);
     min_u8x16 = _mm_min_epu8(min_u8x16, _mm_shuffle_epi32(min_u8x16, _MM_SHUFFLE(2, 3, 0, 1)));
     min_u8x16 = _mm_min_epu8(min_u8x16, _mm_shuffle_epi32(min_u8x16, _MM_SHUFFLE(1, 0, 3, 2)));
     min_u8x16 = _mm_min_epu8(min_u8x16, _mm_srli_si128(min_u8x16, 2));
@@ -329,12 +329,12 @@ NK_INTERNAL nk_u8_t nk_reduce_min_u8x64_skylake_(__m512i min_u8x64) {
 
 /** @brief Horizontal max of 64 unsigned u8s in a ZMM register. */
 NK_INTERNAL nk_u8_t nk_reduce_max_u8x64_skylake_(__m512i max_u8x64) {
-    __m256i …
-    __m256i …
-    __m256i max_u8x32 = _mm256_max_epu8(…
-    __m128i …
-    __m128i …
-    __m128i max_u8x16 = _mm_max_epu8(…
+    __m256i low_u8x32 = _mm512_castsi512_si256(max_u8x64);
+    __m256i high_u8x32 = _mm512_extracti64x4_epi64(max_u8x64, 1);
+    __m256i max_u8x32 = _mm256_max_epu8(low_u8x32, high_u8x32);
+    __m128i low_u8x16 = _mm256_castsi256_si128(max_u8x32);
+    __m128i high_u8x16 = _mm256_extracti128_si256(max_u8x32, 1);
+    __m128i max_u8x16 = _mm_max_epu8(low_u8x16, high_u8x16);
     max_u8x16 = _mm_max_epu8(max_u8x16, _mm_shuffle_epi32(max_u8x16, _MM_SHUFFLE(2, 3, 0, 1)));
     max_u8x16 = _mm_max_epu8(max_u8x16, _mm_shuffle_epi32(max_u8x16, _MM_SHUFFLE(1, 0, 3, 2)));
     max_u8x16 = _mm_max_epu8(max_u8x16, _mm_srli_si128(max_u8x16, 2));
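The byte and word reductions in this stretch all walk the same ladder: narrow 512 -> 256 -> 128 bits, then keep halving the distance between compared lanes (the two shuffles and the 2-byte shift above, plus narrower shifts outside the shown context) until lane 0 holds the result. Since min and max are associative and commutative, the order of the gap sizes does not matter. A scalar sketch of that ladder, not taken from the package:

/* Gap-halving ladder: compare elements `gap` apart, then halve the gap.
 * Four steps collapse 16 bytes to one, matching the shuffle/shift count
 * in the 128-bit tail of the SIMD helpers. */
static signed char reduce_min_i8x16_reference(signed char const lanes[16]) {
    signed char work[16];
    for (int i = 0; i != 16; ++i) work[i] = lanes[i];
    for (int gap = 8; gap != 0; gap /= 2)
        for (int i = 0; i != gap; ++i)
            if (work[i + gap] < work[i]) work[i] = work[i + gap];
    return work[0];
}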
@@ -344,12 +344,12 @@ NK_INTERNAL nk_u8_t nk_reduce_max_u8x64_skylake_(__m512i max_u8x64) {
 
 /** @brief Horizontal min of 32 signed i16s in a ZMM register. */
 NK_INTERNAL nk_i16_t nk_reduce_min_i16x32_skylake_(__m512i min_i16x32) {
-    __m256i …
-    __m256i …
-    __m256i min_i16x16 = _mm256_min_epi16(…
-    __m128i …
-    __m128i …
-    __m128i min_i16x8 = _mm_min_epi16(…
+    __m256i low_i16x16 = _mm512_castsi512_si256(min_i16x32);
+    __m256i high_i16x16 = _mm512_extracti64x4_epi64(min_i16x32, 1);
+    __m256i min_i16x16 = _mm256_min_epi16(low_i16x16, high_i16x16);
+    __m128i low_i16x8 = _mm256_castsi256_si128(min_i16x16);
+    __m128i high_i16x8 = _mm256_extracti128_si256(min_i16x16, 1);
+    __m128i min_i16x8 = _mm_min_epi16(low_i16x8, high_i16x8);
     min_i16x8 = _mm_min_epi16(min_i16x8, _mm_shuffle_epi32(min_i16x8, _MM_SHUFFLE(2, 3, 0, 1)));
     min_i16x8 = _mm_min_epi16(min_i16x8, _mm_shuffle_epi32(min_i16x8, _MM_SHUFFLE(1, 0, 3, 2)));
     min_i16x8 = _mm_min_epi16(min_i16x8, _mm_srli_si128(min_i16x8, 2));
@@ -358,12 +358,12 @@ NK_INTERNAL nk_i16_t nk_reduce_min_i16x32_skylake_(__m512i min_i16x32) {
 
 /** @brief Horizontal max of 32 signed i16s in a ZMM register. */
 NK_INTERNAL nk_i16_t nk_reduce_max_i16x32_skylake_(__m512i max_i16x32) {
-    __m256i …
-    __m256i …
-    __m256i max_i16x16 = _mm256_max_epi16(…
-    __m128i …
-    __m128i …
-    __m128i max_i16x8 = _mm_max_epi16(…
+    __m256i low_i16x16 = _mm512_castsi512_si256(max_i16x32);
+    __m256i high_i16x16 = _mm512_extracti64x4_epi64(max_i16x32, 1);
+    __m256i max_i16x16 = _mm256_max_epi16(low_i16x16, high_i16x16);
+    __m128i low_i16x8 = _mm256_castsi256_si128(max_i16x16);
+    __m128i high_i16x8 = _mm256_extracti128_si256(max_i16x16, 1);
+    __m128i max_i16x8 = _mm_max_epi16(low_i16x8, high_i16x8);
     max_i16x8 = _mm_max_epi16(max_i16x8, _mm_shuffle_epi32(max_i16x8, _MM_SHUFFLE(2, 3, 0, 1)));
     max_i16x8 = _mm_max_epi16(max_i16x8, _mm_shuffle_epi32(max_i16x8, _MM_SHUFFLE(1, 0, 3, 2)));
     max_i16x8 = _mm_max_epi16(max_i16x8, _mm_srli_si128(max_i16x8, 2));
@@ -372,12 +372,12 @@ NK_INTERNAL nk_i16_t nk_reduce_max_i16x32_skylake_(__m512i max_i16x32) {
 
 /** @brief Horizontal min of 32 unsigned u16s in a ZMM register. */
 NK_INTERNAL nk_u16_t nk_reduce_min_u16x32_skylake_(__m512i min_u16x32) {
-    __m256i …
-    __m256i …
-    __m256i min_u16x16 = _mm256_min_epu16(…
-    __m128i …
-    __m128i …
-    __m128i min_u16x8 = _mm_min_epu16(…
+    __m256i low_u16x16 = _mm512_castsi512_si256(min_u16x32);
+    __m256i high_u16x16 = _mm512_extracti64x4_epi64(min_u16x32, 1);
+    __m256i min_u16x16 = _mm256_min_epu16(low_u16x16, high_u16x16);
+    __m128i low_u16x8 = _mm256_castsi256_si128(min_u16x16);
+    __m128i high_u16x8 = _mm256_extracti128_si256(min_u16x16, 1);
+    __m128i min_u16x8 = _mm_min_epu16(low_u16x8, high_u16x8);
     min_u16x8 = _mm_min_epu16(min_u16x8, _mm_shuffle_epi32(min_u16x8, _MM_SHUFFLE(2, 3, 0, 1)));
     min_u16x8 = _mm_min_epu16(min_u16x8, _mm_shuffle_epi32(min_u16x8, _MM_SHUFFLE(1, 0, 3, 2)));
     min_u16x8 = _mm_min_epu16(min_u16x8, _mm_srli_si128(min_u16x8, 2));
@@ -386,12 +386,12 @@ NK_INTERNAL nk_u16_t nk_reduce_min_u16x32_skylake_(__m512i min_u16x32) {
 
 /** @brief Horizontal max of 32 unsigned u16s in a ZMM register. */
 NK_INTERNAL nk_u16_t nk_reduce_max_u16x32_skylake_(__m512i max_u16x32) {
-    __m256i …
-    __m256i …
-    __m256i max_u16x16 = _mm256_max_epu16(…
-    __m128i …
-    __m128i …
-    __m128i max_u16x8 = _mm_max_epu16(…
+    __m256i low_u16x16 = _mm512_castsi512_si256(max_u16x32);
+    __m256i high_u16x16 = _mm512_extracti64x4_epi64(max_u16x32, 1);
+    __m256i max_u16x16 = _mm256_max_epu16(low_u16x16, high_u16x16);
+    __m128i low_u16x8 = _mm256_castsi256_si128(max_u16x16);
+    __m128i high_u16x8 = _mm256_extracti128_si256(max_u16x16, 1);
+    __m128i max_u16x8 = _mm_max_epu16(low_u16x8, high_u16x8);
     max_u16x8 = _mm_max_epu16(max_u16x8, _mm_shuffle_epi32(max_u16x8, _MM_SHUFFLE(2, 3, 0, 1)));
     max_u16x8 = _mm_max_epu16(max_u16x8, _mm_shuffle_epi32(max_u16x8, _MM_SHUFFLE(1, 0, 3, 2)));
     max_u16x8 = _mm_max_epu16(max_u16x8, _mm_srli_si128(max_u16x8, 2));
@@ -400,12 +400,12 @@ NK_INTERNAL nk_u16_t nk_reduce_max_u16x32_skylake_(__m512i max_u16x32) {
 
 /** @brief Horizontal min of 16 signed i32s in a ZMM register. */
 NK_INTERNAL nk_i32_t nk_reduce_min_i32x16_skylake_(__m512i min_i32x16) {
-    __m256i …
-    __m256i …
-    __m256i min_i32x8 = _mm256_min_epi32(…
-    __m128i …
-    __m128i …
-    __m128i min_i32x4 = _mm_min_epi32(…
+    __m256i low_i32x8 = _mm512_castsi512_si256(min_i32x16);
+    __m256i high_i32x8 = _mm512_extracti64x4_epi64(min_i32x16, 1);
+    __m256i min_i32x8 = _mm256_min_epi32(low_i32x8, high_i32x8);
+    __m128i low_i32x4 = _mm256_castsi256_si128(min_i32x8);
+    __m128i high_i32x4 = _mm256_extracti128_si256(min_i32x8, 1);
+    __m128i min_i32x4 = _mm_min_epi32(low_i32x4, high_i32x4);
     min_i32x4 = _mm_min_epi32(min_i32x4, _mm_shuffle_epi32(min_i32x4, _MM_SHUFFLE(2, 3, 0, 1)));
     min_i32x4 = _mm_min_epi32(min_i32x4, _mm_shuffle_epi32(min_i32x4, _MM_SHUFFLE(1, 0, 3, 2)));
     return _mm_cvtsi128_si32(min_i32x4);
@@ -413,12 +413,12 @@ NK_INTERNAL nk_i32_t nk_reduce_min_i32x16_skylake_(__m512i min_i32x16) {
 
 /** @brief Horizontal max of 16 signed i32s in a ZMM register. */
 NK_INTERNAL nk_i32_t nk_reduce_max_i32x16_skylake_(__m512i max_i32x16) {
-    __m256i …
-    __m256i …
-    __m256i max_i32x8 = _mm256_max_epi32(…
-    __m128i …
-    __m128i …
-    __m128i max_i32x4 = _mm_max_epi32(…
+    __m256i low_i32x8 = _mm512_castsi512_si256(max_i32x16);
+    __m256i high_i32x8 = _mm512_extracti64x4_epi64(max_i32x16, 1);
+    __m256i max_i32x8 = _mm256_max_epi32(low_i32x8, high_i32x8);
+    __m128i low_i32x4 = _mm256_castsi256_si128(max_i32x8);
+    __m128i high_i32x4 = _mm256_extracti128_si256(max_i32x8, 1);
+    __m128i max_i32x4 = _mm_max_epi32(low_i32x4, high_i32x4);
     max_i32x4 = _mm_max_epi32(max_i32x4, _mm_shuffle_epi32(max_i32x4, _MM_SHUFFLE(2, 3, 0, 1)));
     max_i32x4 = _mm_max_epi32(max_i32x4, _mm_shuffle_epi32(max_i32x4, _MM_SHUFFLE(1, 0, 3, 2)));
     return _mm_cvtsi128_si32(max_i32x4);
@@ -426,12 +426,12 @@ NK_INTERNAL nk_i32_t nk_reduce_max_i32x16_skylake_(__m512i max_i32x16) {
 
 /** @brief Horizontal min of 16 unsigned u32s in a ZMM register. */
 NK_INTERNAL nk_u32_t nk_reduce_min_u32x16_skylake_(__m512i min_u32x16) {
-    __m256i …
-    __m256i …
-    __m256i min_u32x8 = _mm256_min_epu32(…
-    __m128i …
-    __m128i …
-    __m128i min_u32x4 = _mm_min_epu32(…
+    __m256i low_u32x8 = _mm512_castsi512_si256(min_u32x16);
+    __m256i high_u32x8 = _mm512_extracti64x4_epi64(min_u32x16, 1);
+    __m256i min_u32x8 = _mm256_min_epu32(low_u32x8, high_u32x8);
+    __m128i low_u32x4 = _mm256_castsi256_si128(min_u32x8);
+    __m128i high_u32x4 = _mm256_extracti128_si256(min_u32x8, 1);
+    __m128i min_u32x4 = _mm_min_epu32(low_u32x4, high_u32x4);
     min_u32x4 = _mm_min_epu32(min_u32x4, _mm_shuffle_epi32(min_u32x4, _MM_SHUFFLE(2, 3, 0, 1)));
     min_u32x4 = _mm_min_epu32(min_u32x4, _mm_shuffle_epi32(min_u32x4, _MM_SHUFFLE(1, 0, 3, 2)));
     return (nk_u32_t)_mm_cvtsi128_si32(min_u32x4);
@@ -439,12 +439,12 @@ NK_INTERNAL nk_u32_t nk_reduce_min_u32x16_skylake_(__m512i min_u32x16) {
 
 /** @brief Horizontal max of 16 unsigned u32s in a ZMM register. */
 NK_INTERNAL nk_u32_t nk_reduce_max_u32x16_skylake_(__m512i max_u32x16) {
-    __m256i …
-    __m256i …
-    __m256i max_u32x8 = _mm256_max_epu32(…
-    __m128i …
-    __m128i …
-    __m128i max_u32x4 = _mm_max_epu32(…
+    __m256i low_u32x8 = _mm512_castsi512_si256(max_u32x16);
+    __m256i high_u32x8 = _mm512_extracti64x4_epi64(max_u32x16, 1);
+    __m256i max_u32x8 = _mm256_max_epu32(low_u32x8, high_u32x8);
+    __m128i low_u32x4 = _mm256_castsi256_si128(max_u32x8);
+    __m128i high_u32x4 = _mm256_extracti128_si256(max_u32x8, 1);
+    __m128i max_u32x4 = _mm_max_epu32(low_u32x4, high_u32x4);
     max_u32x4 = _mm_max_epu32(max_u32x4, _mm_shuffle_epi32(max_u32x4, _MM_SHUFFLE(2, 3, 0, 1)));
     max_u32x4 = _mm_max_epu32(max_u32x4, _mm_shuffle_epi32(max_u32x4, _MM_SHUFFLE(1, 0, 3, 2)));
     return (nk_u32_t)_mm_cvtsi128_si32(max_u32x4);
@@ -452,67 +452,67 @@ NK_INTERNAL nk_u32_t nk_reduce_max_u32x16_skylake_(__m512i max_u32x16) {
 
 /** @brief Horizontal min of 8 signed i64s in a ZMM register. */
 NK_INTERNAL nk_i64_t nk_reduce_min_i64x8_skylake_(__m512i min_i64x8) {
-    __m256i …
-    __m256i …
-    __m256i min_i64x4 = _mm256_min_epi64(…
-    __m128i …
-    __m128i …
-    __m128i min_i64x2 = _mm_min_epi64(…
-    __m128i …
-    __m128i …
-    return _mm_cvtsi128_si64(…
+    __m256i low_i64x4 = _mm512_castsi512_si256(min_i64x8);
+    __m256i high_i64x4 = _mm512_extracti64x4_epi64(min_i64x8, 1);
+    __m256i min_i64x4 = _mm256_min_epi64(low_i64x4, high_i64x4);
+    __m128i low_i64x2 = _mm256_castsi256_si128(min_i64x4);
+    __m128i high_i64x2 = _mm256_extracti128_si256(min_i64x4, 1);
+    __m128i min_i64x2 = _mm_min_epi64(low_i64x2, high_i64x2);
+    __m128i high_lane_i64x2 = _mm_unpackhi_epi64(min_i64x2, min_i64x2);
+    __m128i final_i64x2 = _mm_min_epi64(min_i64x2, high_lane_i64x2);
+    return _mm_cvtsi128_si64(final_i64x2);
 }
 
 /** @brief Horizontal max of 8 signed i64s in a ZMM register. */
 NK_INTERNAL nk_i64_t nk_reduce_max_i64x8_skylake_(__m512i max_i64x8) {
-    __m256i …
-    __m256i …
-    __m256i max_i64x4 = _mm256_max_epi64(…
-    __m128i …
-    __m128i …
-    __m128i max_i64x2 = _mm_max_epi64(…
-    __m128i …
-    __m128i …
-    return _mm_cvtsi128_si64(…
+    __m256i low_i64x4 = _mm512_castsi512_si256(max_i64x8);
+    __m256i high_i64x4 = _mm512_extracti64x4_epi64(max_i64x8, 1);
+    __m256i max_i64x4 = _mm256_max_epi64(low_i64x4, high_i64x4);
+    __m128i low_i64x2 = _mm256_castsi256_si128(max_i64x4);
+    __m128i high_i64x2 = _mm256_extracti128_si256(max_i64x4, 1);
+    __m128i max_i64x2 = _mm_max_epi64(low_i64x2, high_i64x2);
+    __m128i high_lane_i64x2 = _mm_unpackhi_epi64(max_i64x2, max_i64x2);
+    __m128i final_i64x2 = _mm_max_epi64(max_i64x2, high_lane_i64x2);
+    return _mm_cvtsi128_si64(final_i64x2);
 }
 
 /** @brief Horizontal min of 8 unsigned u64s in a ZMM register. */
 NK_INTERNAL nk_u64_t nk_reduce_min_u64x8_skylake_(__m512i min_u64x8) {
-    __m256i …
-    __m256i …
-    __m256i min_u64x4 = _mm256_min_epu64(…
-    __m128i …
-    __m128i …
-    __m128i min_u64x2 = _mm_min_epu64(…
-    __m128i …
-    __m128i …
-    return (nk_u64_t)_mm_cvtsi128_si64(…
+    __m256i low_u64x4 = _mm512_castsi512_si256(min_u64x8);
+    __m256i high_u64x4 = _mm512_extracti64x4_epi64(min_u64x8, 1);
+    __m256i min_u64x4 = _mm256_min_epu64(low_u64x4, high_u64x4);
+    __m128i low_u64x2 = _mm256_castsi256_si128(min_u64x4);
+    __m128i high_u64x2 = _mm256_extracti128_si256(min_u64x4, 1);
+    __m128i min_u64x2 = _mm_min_epu64(low_u64x2, high_u64x2);
+    __m128i high_lane_u64x2 = _mm_unpackhi_epi64(min_u64x2, min_u64x2);
+    __m128i final_u64x2 = _mm_min_epu64(min_u64x2, high_lane_u64x2);
+    return (nk_u64_t)_mm_cvtsi128_si64(final_u64x2);
 }
 
 /** @brief Horizontal max of 8 unsigned u64s in a ZMM register. */
 NK_INTERNAL nk_u64_t nk_reduce_max_u64x8_skylake_(__m512i max_u64x8) {
-    __m256i …
-    __m256i …
-    __m256i max_u64x4 = _mm256_max_epu64(…
-    __m128i …
-    __m128i …
-    __m128i max_u64x2 = _mm_max_epu64(…
-    __m128i …
-    __m128i …
-    return (nk_u64_t)_mm_cvtsi128_si64(…
+    __m256i low_u64x4 = _mm512_castsi512_si256(max_u64x8);
+    __m256i high_u64x4 = _mm512_extracti64x4_epi64(max_u64x8, 1);
+    __m256i max_u64x4 = _mm256_max_epu64(low_u64x4, high_u64x4);
+    __m128i low_u64x2 = _mm256_castsi256_si128(max_u64x4);
+    __m128i high_u64x2 = _mm256_extracti128_si256(max_u64x4, 1);
+    __m128i max_u64x2 = _mm_max_epu64(low_u64x2, high_u64x2);
+    __m128i high_lane_u64x2 = _mm_unpackhi_epi64(max_u64x2, max_u64x2);
+    __m128i final_u64x2 = _mm_max_epu64(max_u64x2, high_lane_u64x2);
+    return (nk_u64_t)_mm_cvtsi128_si64(final_u64x2);
 }
 
 /** @brief Horizontal sum of 8 unsigned u64s in a ZMM register. */
 NK_INTERNAL nk_u64_t nk_reduce_add_u64x8_skylake_(__m512i sum_u64x8) {
-    __m256i …
-    __m256i …
-    __m256i sum_u64x4 = _mm256_add_epi64(…
-    __m128i …
-    __m128i …
-    __m128i sum_u64x2 = _mm_add_epi64(…
-    __m128i …
-    __m128i …
-    return (nk_u64_t)_mm_cvtsi128_si64(…
+    __m256i low_u64x4 = _mm512_castsi512_si256(sum_u64x8);
+    __m256i high_u64x4 = _mm512_extracti64x4_epi64(sum_u64x8, 1);
+    __m256i sum_u64x4 = _mm256_add_epi64(low_u64x4, high_u64x4);
+    __m128i low_u64x2 = _mm256_castsi256_si128(sum_u64x4);
+    __m128i high_u64x2 = _mm256_extracti128_si256(sum_u64x4, 1);
+    __m128i sum_u64x2 = _mm_add_epi64(low_u64x2, high_u64x2);
+    __m128i high_lane_u64x2 = _mm_unpackhi_epi64(sum_u64x2, sum_u64x2);
+    __m128i final_u64x2 = _mm_add_epi64(sum_u64x2, high_lane_u64x2);
+    return (nk_u64_t)_mm_cvtsi128_si64(final_u64x2);
 }
 
 NK_INTERNAL __m512i nk_fp8x64_to_u8x64_comparable_skylake_(__m512i raw_i8x64) {
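The 64-bit variants cannot finish with hadd or the 32-bit shuffle ladder alone, so the new code broadcasts the upper 64-bit lane with _mm_unpackhi_epi64 and combines once more. A self-contained sketch of that final 2-to-1 step (a generic SSE2 pattern, shown with add; not package code):

#include <immintrin.h>

/* _mm_unpackhi_epi64(v, v) puts the old upper lane into both lanes, so one
 * more combine leaves the full reduction in lane 0, ready for extraction. */
static long long reduce_add_i64x2_sketch(__m128i v) {
    __m128i high_lane = _mm_unpackhi_epi64(v, v);
    __m128i total = _mm_add_epi64(v, high_lane);
    return _mm_cvtsi128_si64(total);
}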
@@ -1514,7 +1514,7 @@ NK_INTERNAL void nk_reduce_moments_u16_skylake_contiguous_( //
     nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
     // Widen u16→u32, square in u32, widen to u64. Avoids bias trick whose
     // VPMADDWD pair-of-squares overflows i32 when both lanes map to -32768.
-    __m512i …
+    __m512i zero_u32x16 = _mm512_setzero_si512();
     __m512i sum_u32x16 = _mm512_setzero_si512();
     __m512i sumsq_u64x8 = _mm512_setzero_si512();
     nk_size_t idx = 0;
@@ -1522,20 +1522,20 @@ NK_INTERNAL void nk_reduce_moments_u16_skylake_contiguous_( //
         __m512i data_u32x16 = _mm512_cvtepu16_epi32(_mm256_loadu_si256((__m256i const *)(data_ptr + idx)));
         sum_u32x16 = _mm512_add_epi32(sum_u32x16, data_u32x16);
         __m512i sq_u32x16 = _mm512_mullo_epi32(data_u32x16, data_u32x16);
-        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(sq_u32x16, …
-        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(sq_u32x16, …
+        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(sq_u32x16, zero_u32x16));
+        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(sq_u32x16, zero_u32x16));
     }
     if (idx < count) {
         __mmask16 tail_mask = (__mmask16)((1u << (count - idx)) - 1);
         __m512i data_u32x16 = _mm512_cvtepu16_epi32(_mm256_maskz_loadu_epi16(tail_mask, data_ptr + idx));
         sum_u32x16 = _mm512_add_epi32(sum_u32x16, data_u32x16);
         __m512i sq_u32x16 = _mm512_mullo_epi32(data_u32x16, data_u32x16);
-        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(sq_u32x16, …
-        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(sq_u32x16, …
+        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(sq_u32x16, zero_u32x16));
+        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(sq_u32x16, zero_u32x16));
     }
-    __m512i sum_u64x8 = _mm512_add_epi64(…
-        _mm512_unpacklo_epi32(sum_u32x16, …
-        _mm512_unpackhi_epi32(sum_u32x16, …
+    __m512i sum_u64x8 = _mm512_add_epi64(               //
+        _mm512_unpacklo_epi32(sum_u32x16, zero_u32x16), //
+        _mm512_unpackhi_epi32(sum_u32x16, zero_u32x16)); //
     *sum_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sum_u64x8);
     *sumsq_ptr = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_u64x8);
 }
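The comment at the top of this function names a real corner case: the classic trick biases u16 values by 0x8000 into i16 range so VPMADDWD can sum pairs of squares, but (-32768)^2 + (-32768)^2 = 2^31, one past INT32_MAX. Squaring in plain u32 is safe, since 65535^2 < 2^32. A scalar model of the widening scheme (a sketch, not the package API):

#include <stddef.h>
#include <stdint.h>

static void moments_u16_reference(uint16_t const *data, size_t count,
                                  uint64_t *sum_ptr, uint64_t *sumsq_ptr) {
    uint64_t sum = 0, sumsq = 0;
    for (size_t i = 0; i != count; ++i) {
        uint32_t widened = data[i];           /* u16 -> u32 */
        uint32_t squared = widened * widened; /* 65535^2 < 2^32: no overflow */
        sum += widened;                       /* widen to u64 on accumulate */
        sumsq += squared;
    }
    *sum_ptr = sum;
    *sumsq_ptr = sumsq;
}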
@@ -1544,7 +1544,7 @@ NK_INTERNAL void nk_reduce_moments_u16_skylake_strided_( //
     nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
     nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
     __mmask32 stride_mask_m32 = nk_stride_mask_b16x32_(stride_elements);
-    __m512i …
+    __m512i zero_u32x16 = _mm512_setzero_si512();
     __m512i sum_u64x8 = _mm512_setzero_si512();
     __m512i sumsq_u64x8 = _mm512_setzero_si512();
     nk_size_t idx_scalars = 0;
@@ -1552,18 +1552,18 @@ NK_INTERNAL void nk_reduce_moments_u16_skylake_strided_( //
     nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m32) * stride_elements;
     for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
         __m512i data_u16x32 = _mm512_maskz_loadu_epi16(stride_mask_m32, data_ptr + idx_scalars);
-        __m512i low_u32x16 = _mm512_unpacklo_epi16(data_u16x32, …
-        __m512i high_u32x16 = _mm512_unpackhi_epi16(data_u16x32, …
-        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpacklo_epi32(low_u32x16, …
-        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpackhi_epi32(low_u32x16, …
-        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpacklo_epi32(high_u32x16, …
-        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpackhi_epi32(high_u32x16, …
-        __m512i …
-        __m512i …
-        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(…
-        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(…
-        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(…
-        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(…
+        __m512i low_u32x16 = _mm512_unpacklo_epi16(data_u16x32, zero_u32x16);
+        __m512i high_u32x16 = _mm512_unpackhi_epi16(data_u16x32, zero_u32x16);
+        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpacklo_epi32(low_u32x16, zero_u32x16));
+        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpackhi_epi32(low_u32x16, zero_u32x16));
+        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpacklo_epi32(high_u32x16, zero_u32x16));
+        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_unpackhi_epi32(high_u32x16, zero_u32x16));
+        __m512i low_sq_u32x16 = _mm512_mullo_epi32(low_u32x16, low_u32x16);
+        __m512i high_sq_u32x16 = _mm512_mullo_epi32(high_u32x16, high_u32x16);
+        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(low_sq_u32x16, zero_u32x16));
+        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(low_sq_u32x16, zero_u32x16));
+        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpacklo_epi32(high_sq_u32x16, zero_u32x16));
+        sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_unpackhi_epi32(high_sq_u32x16, zero_u32x16));
     }
     nk_u64_t sum = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sum_u64x8);
     nk_u64_t sumsq = (nk_u64_t)nk_reduce_add_i64x8_skylake_(sumsq_u64x8);
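The strided path leans on masked loads: a 32-lane mask keeps one u16 out of every stride_elements lanes, so each 512-bit masked load consumes popcount(mask) logical elements, which is exactly the popcount(mask) * stride_elements step in the loop above. The package's nk_stride_mask_b16x32_ is not shown in this diff; the following is a hypothetical sketch of how such a mask could be built, for illustration only:

#include <stdint.h>

/* Set one bit per full stride-sized group inside a 32-lane window, so a
 * masked load consumes popcount(mask) * stride scalars. Hypothetical helper. */
static uint32_t stride_mask_b16x32_sketch(uint32_t stride) {
    uint32_t mask = 0;
    for (uint32_t lane = 0; lane + stride <= 32; lane += stride)
        mask |= UINT32_C(1) << lane; /* keep lane 0 of each group */
    return mask;
}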
@@ -1699,24 +1699,24 @@ NK_PUBLIC void nk_reduce_minmax_u16_skylake( //
 
 /** @brief Unsigned saturating add of two i64x8 vectors (3 uops). */
 NK_INTERNAL __m512i nk_u64_sadd_epi64_skylake_(__m512i a, __m512i b) {
-    __m512i …
-    __mmask8 ovf = _mm512_cmp_epu64_mask(…
-    return _mm512_mask_mov_epi64(…
+    __m512i result_u64x8 = _mm512_add_epi64(a, b);
+    __mmask8 ovf = _mm512_cmp_epu64_mask(result_u64x8, a, _MM_CMPINT_LT);
+    return _mm512_mask_mov_epi64(result_u64x8, ovf, _mm512_set1_epi64((nk_i64_t)-1));
 }
 
 /** @brief Saturating i64 square: clamp when |val| > floor(sqrt(INT64_MAX)). */
 NK_INTERNAL __m512i nk_i64_smul_sq_epi64_skylake_(__m512i val) {
-    __m512i …
-    __m512i …
-    __mmask8 ovf = _mm512_cmp_epu64_mask(…
-    return _mm512_mask_mov_epi64(…
+    __m512i sq_i64x8 = _mm512_mullo_epi64(val, val);
+    __m512i abs_val_u64x8 = _mm512_abs_epi64(val);
+    __mmask8 ovf = _mm512_cmp_epu64_mask(abs_val_u64x8, _mm512_set1_epi64(3037000499ll), _MM_CMPINT_NLE);
+    return _mm512_mask_mov_epi64(sq_i64x8, ovf, _mm512_set1_epi64(9223372036854775807ll));
 }
 
 /** @brief Saturating u64 square: clamp when val > floor(sqrt(UINT64_MAX)). */
 NK_INTERNAL __m512i nk_u64_smul_sq_epi64_skylake_(__m512i val) {
-    __m512i …
+    __m512i sq_u64x8 = _mm512_mullo_epi64(val, val);
     __mmask8 ovf = _mm512_cmp_epu64_mask(val, _mm512_set1_epi64(4294967295ll), _MM_CMPINT_NLE);
-    return _mm512_mask_mov_epi64(…
+    return _mm512_mask_mov_epi64(sq_u64x8, ovf, _mm512_set1_epi64((nk_i64_t)-1));
 }
 
 /** @brief Saturating horizontal sum of 8 unsigned u64 lanes.
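Both saturating helpers rely on arithmetic identities rather than wider types: an unsigned add overflowed exactly when the wrapped result compares below an operand, and a square overflows exactly when the magnitude exceeds the integer square root of the type's maximum (3037000499 for i64, 4294967295 for u64). A scalar model of the same logic (a sketch, not the package API):

#include <stdint.h>

static uint64_t u64_sadd_reference(uint64_t a, uint64_t b) {
    uint64_t result = a + b;                 /* wrapping add */
    return result < a ? UINT64_MAX : result; /* wrapped, so saturate */
}

static int64_t i64_smul_sq_reference(int64_t val) {
    uint64_t abs_val = val < 0 ? 0u - (uint64_t)val : (uint64_t)val;
    if (abs_val > 3037000499ull) return INT64_MAX; /* |val| > floor(sqrt(INT64_MAX)) */
    return val * val; /* cannot overflow below the threshold */
}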
@@ -1738,8 +1738,8 @@ NK_INTERNAL void nk_reduce_moments_i32_skylake_contiguous_( //
     nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
     // Sum: 128-bit accumulation (lower + upper) — no block cap needed.
     // Sumsq: unsigned wrapping accumulation with carry-based overflow detection.
-    __m512i …
-    __m512i …
+    __m512i sum_low_i64x8 = _mm512_setzero_si512();
+    __m512i sum_high_i64x8 = _mm512_setzero_si512();
     __m512i sumsq_u64x8 = _mm512_setzero_si512();
     __mmask8 sumsq_overflow_mask = 0;
     __m512i one_i64x8 = _mm512_set1_epi64(1);
@@ -1750,18 +1750,18 @@ NK_INTERNAL void nk_reduce_moments_i32_skylake_contiguous_( //
         __m256i high_i32x8 = _mm512_extracti64x4_epi64(data_i32x16, 1);
         // 128-bit sum: lower half
         __m512i widened_low_i64x8 = _mm512_cvtepi32_epi64(low_i32x8);
-        __m512i sum_before_i64x8 = …
-        …
-        __mmask8 carry = _mm512_cmp_epu64_mask(…
-        …
-        …
+        __m512i sum_before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, widened_low_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, _mm512_srai_epi64(widened_low_i64x8, 63));
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
         // 128-bit sum: upper half
         __m512i widened_high_i64x8 = _mm512_cvtepi32_epi64(high_i32x8);
-        sum_before_i64x8 = …
-        …
-        carry = _mm512_cmp_epu64_mask(…
-        …
-        …
+        sum_before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, widened_high_i64x8);
+        carry = _mm512_cmp_epu64_mask(sum_low_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, _mm512_srai_epi64(widened_high_i64x8, 63));
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
         // Sumsq: unsigned accumulation with carry detection
         __m512i even_sq_u64x8 = _mm512_mul_epi32(data_i32x16, data_i32x16);
         __m512i odd_i32x16 = _mm512_srli_epi64(data_i32x16, 32);
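The carry idiom repeated above is two's-complement 128-bit addition in each lane: the low word adds with wraparound, and the high word absorbs the addend's sign (an arithmetic shift by 63 gives 0 or -1, matching _mm512_srai_epi64) plus the carry out of the low word, detected as the wrapped low word comparing below its old value. In scalar form (a sketch, not package code):

#include <stdint.h>

typedef struct {
    uint64_t low; /* wrapping low word */
    int64_t high; /* sign extension plus carries */
} i128_sketch_t;

static void i128_add_i64(i128_sketch_t *acc, int64_t value) {
    uint64_t before = acc->low;
    acc->low += (uint64_t)value;
    acc->high += value >> 63;              /* sign-extend the addend: 0 or -1 */
    if (acc->low < before) acc->high += 1; /* carry out of the low word */
}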
@@ -1778,18 +1778,18 @@ NK_INTERNAL void nk_reduce_moments_i32_skylake_contiguous_( //
         __m256i low_i32x8 = _mm512_castsi512_si256(data_i32x16);
         __m256i high_i32x8 = _mm512_extracti64x4_epi64(data_i32x16, 1);
         __m512i widened_low_i64x8 = _mm512_cvtepi32_epi64(low_i32x8);
-        __m512i sum_before_i64x8 = …
-        …
-        __mmask8 carry = _mm512_cmp_epu64_mask(…
-        …
-        …
+        __m512i sum_before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, widened_low_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, _mm512_srai_epi64(widened_low_i64x8, 63));
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
         if (remaining > 8) {
             __m512i widened_high_i64x8 = _mm512_cvtepi32_epi64(high_i32x8);
-            sum_before_i64x8 = …
-            …
-            carry = _mm512_cmp_epu64_mask(…
-            …
-            …
+            sum_before_i64x8 = sum_low_i64x8;
+            sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, widened_high_i64x8);
+            carry = _mm512_cmp_epu64_mask(sum_low_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
+            sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, _mm512_srai_epi64(widened_high_i64x8, 63));
+            sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
         }
         __m512i even_sq_u64x8 = _mm512_mul_epi32(data_i32x16, data_i32x16);
         __m512i odd_i32x16 = _mm512_srli_epi64(data_i32x16, 32);
@@ -1805,36 +1805,36 @@ NK_INTERNAL void nk_reduce_moments_i32_skylake_contiguous_( //
     else sumsq = nk_reduce_sadd_u64x8_skylake_(sumsq_u64x8);
     // Sum: horizontal 128-bit tree reduction, same as i64 skylake
     { // 8→4
-        __m512i
-        __m512i
-        __m512i before_i64x8 =
-
-        __mmask8 carry = _mm512_cmp_epu64_mask(
-
-
+        __m512i fold_low_i64x8 = _mm512_shuffle_i64x2(sum_low_i64x8, sum_low_i64x8, _MM_SHUFFLE(1, 0, 3, 2));
+        __m512i fold_high_i64x8 = _mm512_shuffle_i64x2(sum_high_i64x8, sum_high_i64x8, _MM_SHUFFLE(1, 0, 3, 2));
+        __m512i before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, fold_low_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, fold_high_i64x8);
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
     }
     { // 4→2
-        __m512i
-        __m512i
-        __m512i before_i64x8 =
-
-        __mmask8 carry = _mm512_cmp_epu64_mask(
-
-
+        __m512i fold_low_i64x8 = _mm512_shuffle_i64x2(sum_low_i64x8, sum_low_i64x8, _MM_SHUFFLE(2, 3, 0, 1));
+        __m512i fold_high_i64x8 = _mm512_shuffle_i64x2(sum_high_i64x8, sum_high_i64x8, _MM_SHUFFLE(2, 3, 0, 1));
+        __m512i before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, fold_low_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, fold_high_i64x8);
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
     }
     { // 2→1
-        __m512i
-        __m512i
-        __m512i before_i64x8 =
-
-        __mmask8 carry = _mm512_cmp_epu64_mask(
-
-
-    }
-    nk_i64_t
-    nk_i64_t
-    if (
-    else if (
+        __m512i fold_low_i64x8 = _mm512_alignr_epi64(sum_low_i64x8, sum_low_i64x8, 1);
+        __m512i fold_high_i64x8 = _mm512_alignr_epi64(sum_high_i64x8, sum_high_i64x8, 1);
+        __m512i before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, fold_low_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, fold_high_i64x8);
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
+    }
+    nk_i64_t sum_low = _mm_cvtsi128_si64(_mm512_castsi512_si128(sum_low_i64x8));
+    nk_i64_t sum_high = _mm_cvtsi128_si64(_mm512_castsi512_si128(sum_high_i64x8));
+    if (sum_high == (sum_low >> 63)) *sum_ptr = sum_low;
+    else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
     else *sum_ptr = NK_I64_MIN;
     *sumsq_ptr = sumsq;
 }
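The new accumulation scheme is easier to follow in scalar form: each lane keeps the sum as a 128-bit (high:low) pair, a carry out of the low word is detected via unsigned wrap-around, and the final clamp checks whether the high word is just the sign extension of the low word. A minimal scalar model of that logic (illustrative sketch only, not part of numkong's API):

```c
#include <stddef.h>
#include <stdint.h>

/* Sketch: accumulate i64 addends into a 128-bit (high:low) pair and clamp
 * the total to the i64 range, mirroring the per-lane vector logic above. */
static int64_t sum_i64_clamped(const int64_t *data, size_t count) {
    uint64_t low = 0; /* low 64 bits, wrapping unsigned arithmetic */
    int64_t high = 0; /* high 64 bits: sign extensions plus carries */
    for (size_t i = 0; i < count; ++i) {
        uint64_t before = low;
        low += (uint64_t)data[i];
        int carry = low < before; /* unsigned wrap-around == carry out */
        /* `data[i] >> 63` is the scalar `_mm512_srai_epi64(x, 63)`; the
         * masked `+1` becomes an ordinary carry add. */
        high += (data[i] >> 63) + carry;
    }
    /* Fits in i64 iff the high word equals the sign extension of the low
     * word; otherwise saturate toward the overflowed side. */
    if (high == ((int64_t)low >> 63)) return (int64_t)low;
    return high >= 0 ? INT64_MAX : INT64_MIN;
}
```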
@@ -2119,8 +2119,8 @@ NK_INTERNAL void nk_reduce_moments_i64_skylake_contiguous_( //
     nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
     // Sum: double-width 128-bit accumulation per lane.
     // Sumsq: unsigned wrapping accumulation with carry-based overflow detection.
-    __m512i
-    __m512i
+    __m512i sum_low_i64x8 = _mm512_setzero_si512();
+    __m512i sum_high_i64x8 = _mm512_setzero_si512();
     __m512i sumsq_u64x8 = _mm512_setzero_si512();
     __mmask8 sumsq_overflow_mask = 0;
     __m512i one_i64x8 = _mm512_set1_epi64(1);
@@ -2130,11 +2130,11 @@ NK_INTERNAL void nk_reduce_moments_i64_skylake_contiguous_( //
         __m512i squared_i64x8 = nk_i64_smul_sq_epi64_skylake_(data_i64x8);
         sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, squared_i64x8);
         sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, squared_i64x8, _MM_CMPINT_LT);
-        __m512i sum_before_i64x8 =
-
-        __mmask8 carry = _mm512_cmp_epu64_mask(
-
-
+        __m512i sum_before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, data_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, _mm512_srai_epi64(data_i64x8, 63));
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
     }
     nk_size_t remaining = count - idx;
     if (remaining > 0) {
@@ -2143,11 +2143,11 @@ NK_INTERNAL void nk_reduce_moments_i64_skylake_contiguous_( //
         __m512i squared_i64x8 = nk_i64_smul_sq_epi64_skylake_(data_i64x8);
         sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, squared_i64x8);
         sumsq_overflow_mask |= _mm512_cmp_epu64_mask(sumsq_u64x8, squared_i64x8, _MM_CMPINT_LT);
-        __m512i sum_before_i64x8 =
-
-        __mmask8 carry = _mm512_cmp_epu64_mask(
-
-
+        __m512i sum_before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, data_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, sum_before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, _mm512_srai_epi64(data_i64x8, 63));
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
     }
     // Sumsq: horizontal unsigned saturating reduction
     nk_u64_t sumsq;
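The sumsq lanes rely on a single idiom throughout: add with wrap-around, then flag overflow by comparing the result against the addend. A scalar sketch of that check (hypothetical helper name):

```c
#include <stdint.h>

/* Sketch: the wrapping-add-plus-compare idiom used for the sumsq lanes.
 * After `sum += x`, `sum < x` (unsigned) is exactly the overflow flag,
 * matching `_mm512_cmp_epu64_mask(sumsq, squared, _MM_CMPINT_LT)` above. */
static uint64_t sadd_u64(uint64_t sum, uint64_t x, int *overflowed) {
    uint64_t wrapped = sum + x;
    *overflowed |= wrapped < x;
    return wrapped; /* caller saturates to UINT64_MAX once *overflowed is set */
}
```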
@@ -2155,37 +2155,37 @@ NK_INTERNAL void nk_reduce_moments_i64_skylake_contiguous_( //
     else sumsq = nk_reduce_sadd_u64x8_skylake_(sumsq_u64x8);
     // Sum: horizontal 128-bit tree reduction (8→4→2→1), then clamp to i64
     { // 8→4: fold high 256 bits into low 256 bits
-        __m512i
-        __m512i
-        __m512i before_i64x8 =
-
-        __mmask8 carry = _mm512_cmp_epu64_mask(
-
-
+        __m512i fold_low_i64x8 = _mm512_shuffle_i64x2(sum_low_i64x8, sum_low_i64x8, _MM_SHUFFLE(1, 0, 3, 2));
+        __m512i fold_high_i64x8 = _mm512_shuffle_i64x2(sum_high_i64x8, sum_high_i64x8, _MM_SHUFFLE(1, 0, 3, 2));
+        __m512i before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, fold_low_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, fold_high_i64x8);
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
     }
     { // 4→2: fold lanes 2-3 into lanes 0-1
-        __m512i
-        __m512i
-        __m512i before_i64x8 =
-
-        __mmask8 carry = _mm512_cmp_epu64_mask(
-
-
+        __m512i fold_low_i64x8 = _mm512_shuffle_i64x2(sum_low_i64x8, sum_low_i64x8, _MM_SHUFFLE(2, 3, 0, 1));
+        __m512i fold_high_i64x8 = _mm512_shuffle_i64x2(sum_high_i64x8, sum_high_i64x8, _MM_SHUFFLE(2, 3, 0, 1));
+        __m512i before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, fold_low_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, fold_high_i64x8);
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
     }
     { // 2→1: fold lane 1 into lane 0
-        __m512i
-        __m512i
-        __m512i before_i64x8 =
-
-        __mmask8 carry = _mm512_cmp_epu64_mask(
-
-
+        __m512i fold_low_i64x8 = _mm512_alignr_epi64(sum_low_i64x8, sum_low_i64x8, 1);
+        __m512i fold_high_i64x8 = _mm512_alignr_epi64(sum_high_i64x8, sum_high_i64x8, 1);
+        __m512i before_i64x8 = sum_low_i64x8;
+        sum_low_i64x8 = _mm512_add_epi64(sum_low_i64x8, fold_low_i64x8);
+        __mmask8 carry = _mm512_cmp_epu64_mask(sum_low_i64x8, before_i64x8, _MM_CMPINT_LT);
+        sum_high_i64x8 = _mm512_add_epi64(sum_high_i64x8, fold_high_i64x8);
+        sum_high_i64x8 = _mm512_mask_add_epi64(sum_high_i64x8, carry, sum_high_i64x8, one_i64x8);
     }
     // Clamp 128-bit result to [INT64_MIN, INT64_MAX]: fits iff upper == sign-extension of lower
-    nk_i64_t
-    nk_i64_t
-    if (
-    else if (
+    nk_i64_t sum_low = _mm_cvtsi128_si64(_mm512_castsi512_si128(sum_low_i64x8));
+    nk_i64_t sum_high = _mm_cvtsi128_si64(_mm512_castsi512_si128(sum_high_i64x8));
+    if (sum_high == (sum_low >> 63)) *sum_ptr = sum_low;
+    else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
     else *sum_ptr = NK_I64_MIN;
     *sumsq_ptr = sumsq;
 }
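The three fold steps are lane rotations: `_mm512_shuffle_i64x2` with `_MM_SHUFFLE(1, 0, 3, 2)` rotates by four 64-bit lanes, `_MM_SHUFFLE(2, 3, 0, 1)` by two, and `_mm512_alignr_epi64(..., 1)` by one, each time propagating low-word carries into the parallel high-word vector. A plain-array sketch of one step (hypothetical helper, not library code):

```c
#include <stdint.h>

/* Sketch: one fold step of the 8→4→2→1 tree reduction on plain arrays.
 * Adds a rotated copy of the vector onto itself and carries overflow from
 * each low word into the matching high word. */
static void fold_step(uint64_t low[8], int64_t high[8], int rot) {
    uint64_t low_rot[8];
    int64_t high_rot[8];
    for (int i = 0; i < 8; ++i) {
        low_rot[i] = low[(i + rot) % 8];
        high_rot[i] = high[(i + rot) % 8];
    }
    for (int i = 0; i < 8; ++i) {
        uint64_t before = low[i];
        low[i] += low_rot[i];
        high[i] += high_rot[i] + (int64_t)(low[i] < before); /* carry out */
    }
}

/* fold_step(low, high, 4); fold_step(low, high, 2); fold_step(low, high, 1);
 * leaves the full 128-bit total in (high[0] : low[0]). */
```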
@@ -2534,16 +2534,16 @@ NK_INTERNAL void nk_reduce_moments_e4m3_skylake_strided_( //
     nk_size_t total_scalars = count * stride_elements;
     nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m16) * stride_elements;
     for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
-        __m128i
-        __m512 data_f32x16 = nk_e4m3x16_to_f32x16_skylake_(
+        __m128i data_e4m3_u8x16 = _mm_maskz_loadu_epi8(stride_mask_m16, data_ptr + idx_scalars);
+        __m512 data_f32x16 = nk_e4m3x16_to_f32x16_skylake_(data_e4m3_u8x16);
         sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
         sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
     }
     nk_size_t remaining_bytes = total_scalars - idx_scalars;
     if (remaining_bytes > 0) {
         __mmask16 tail_mask = stride_mask_m16 & (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining_bytes);
-        __m128i
-        __m512 data_f32x16 = nk_e4m3x16_to_f32x16_skylake_(
+        __m128i data_e4m3_u8x16 = _mm_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
+        __m512 data_f32x16 = nk_e4m3x16_to_f32x16_skylake_(data_e4m3_u8x16);
         sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
         sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
     }
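All four FP8 strided reducers share the same load trick: a 16-bit mask marks which bytes of each 16-byte window begin an element, `_mm_maskz_loadu_epi8` zero-fills the gaps, and the zeroed bytes decode to 0.0f, which is neutral for both sum and sum-of-squares. One plausible way such a mask could be built (hypothetical sketch; the actual constructor lives outside this hunk):

```c
#include <stdint.h>

/* Hypothetical sketch: set one bit per element start within a 16-byte
 * window for a given byte stride. popcount(mask) * stride then gives the
 * number of bytes the window advances per iteration, as in the loop above. */
static uint16_t stride_mask_for(unsigned stride_bytes) {
    uint16_t mask = 0;
    for (unsigned bit = 0; bit < 16; bit += stride_bytes) mask |= (uint16_t)(1u << bit);
    return mask;
}
```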
@@ -2612,13 +2612,14 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_skylake_contiguous_( //
     __mmask64 is_nan_m64 = _mm512_cmpeq_epi8_mask(data_cmp_u8x64, _mm512_setzero_si512()) |
                            _mm512_cmpeq_epi8_mask(data_cmp_u8x64, _mm512_set1_epi8((char)0xFF));
     __mmask64 valid_non_nan_m64 = tail_load & ~is_nan_m64;
-    __m512i
-
-    __m512i
+    __m512i data_cmp_min_u8x64 = _mm512_mask_blend_epi8(valid_non_nan_m64, _mm512_set1_epi8((char)0xFF),
+                                                        data_cmp_u8x64);
+    __m512i data_cmp_max_u8x64 = _mm512_mask_blend_epi8(valid_non_nan_m64, _mm512_setzero_si512(), data_cmp_u8x64);
+    __m512i new_min_u8x64 = _mm512_min_epu8(min_vec.zmm, data_cmp_min_u8x64);
     __mmask64 min_changed_mask = ~_mm512_cmpeq_epi8_mask(new_min_u8x64, min_vec.zmm);
     min_vec.zmm = new_min_u8x64;
     min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
-    __m512i new_max_u8x64 = _mm512_max_epu8(max_vec.zmm,
+    __m512i new_max_u8x64 = _mm512_max_epu8(max_vec.zmm, data_cmp_max_u8x64);
     __mmask64 max_changed_mask = ~_mm512_cmpeq_epi8_mask(new_max_u8x64, max_vec.zmm);
     max_vec.zmm = new_max_u8x64;
     max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
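The blend constants are the identity elements of each reduction in the comparable-byte domain: lanes that are masked off or NaN become 0xFF for the running minimum and 0x00 for the running maximum, so they can never win a comparison. In scalar form (sketch, hypothetical helper name):

```c
#include <stdint.h>

/* Sketch: per-byte model of the masked min/max update. Invalid lanes are
 * substituted with the identity of each reduction before comparing. */
static void minmax_step(uint8_t cmp_byte, int valid, uint8_t *min8, uint8_t *max8) {
    uint8_t min_candidate = valid ? cmp_byte : 0xFFu; /* never wins a min */
    uint8_t max_candidate = valid ? cmp_byte : 0x00u; /* never wins a max */
    if (min_candidate < *min8) *min8 = min_candidate;
    if (max_candidate > *max8) *max8 = max_candidate;
}
```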
@@ -2739,16 +2740,16 @@ NK_INTERNAL void nk_reduce_moments_e5m2_skylake_strided_( //
     nk_size_t total_scalars = count * stride_elements;
     nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m16) * stride_elements;
     for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
-        __m128i
-        __m512 data_f32x16 = nk_e5m2x16_to_f32x16_skylake_(
+        __m128i data_e5m2_u8x16 = _mm_maskz_loadu_epi8(stride_mask_m16, data_ptr + idx_scalars);
+        __m512 data_f32x16 = nk_e5m2x16_to_f32x16_skylake_(data_e5m2_u8x16);
         sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
         sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
     }
     nk_size_t remaining_bytes = total_scalars - idx_scalars;
     if (remaining_bytes > 0) {
         __mmask16 tail_mask = stride_mask_m16 & (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining_bytes);
-        __m128i
-        __m512 data_f32x16 = nk_e5m2x16_to_f32x16_skylake_(
+        __m128i data_e5m2_u8x16 = _mm_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
+        __m512 data_f32x16 = nk_e5m2x16_to_f32x16_skylake_(data_e5m2_u8x16);
         sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
         sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
     }
@@ -2809,16 +2810,16 @@ NK_INTERNAL void nk_reduce_moments_e2m3_skylake_strided_( //
     nk_size_t total_scalars = count * stride_elements;
     nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m16) * stride_elements;
     for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
-        __m128i
-        __m512 data_f32x16 = nk_e2m3x16_to_f32x16_skylake_(
+        __m128i data_e2m3_u8x16 = _mm_maskz_loadu_epi8(stride_mask_m16, data_ptr + idx_scalars);
+        __m512 data_f32x16 = nk_e2m3x16_to_f32x16_skylake_(data_e2m3_u8x16);
         sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
         sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
     }
     nk_size_t remaining_bytes = total_scalars - idx_scalars;
     if (remaining_bytes > 0) {
         __mmask16 tail_mask = stride_mask_m16 & (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining_bytes);
-        __m128i
-        __m512 data_f32x16 = nk_e2m3x16_to_f32x16_skylake_(
+        __m128i data_e2m3_u8x16 = _mm_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
+        __m512 data_f32x16 = nk_e2m3x16_to_f32x16_skylake_(data_e2m3_u8x16);
         sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
         sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
     }
@@ -2879,16 +2880,16 @@ NK_INTERNAL void nk_reduce_moments_e3m2_skylake_strided_( //
     nk_size_t total_scalars = count * stride_elements;
     nk_size_t step = (nk_size_t)_mm_popcnt_u64((nk_u64_t)stride_mask_m16) * stride_elements;
     for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
-        __m128i
-        __m512 data_f32x16 = nk_e3m2x16_to_f32x16_skylake_(
+        __m128i data_e3m2_u8x16 = _mm_maskz_loadu_epi8(stride_mask_m16, data_ptr + idx_scalars);
+        __m512 data_f32x16 = nk_e3m2x16_to_f32x16_skylake_(data_e3m2_u8x16);
         sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
         sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
     }
     nk_size_t remaining_bytes = total_scalars - idx_scalars;
     if (remaining_bytes > 0) {
         __mmask16 tail_mask = stride_mask_m16 & (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)remaining_bytes);
-        __m128i
-        __m512 data_f32x16 = nk_e3m2x16_to_f32x16_skylake_(
+        __m128i data_e3m2_u8x16 = _mm_maskz_loadu_epi8(tail_mask, data_ptr + idx_scalars);
+        __m512 data_f32x16 = nk_e3m2x16_to_f32x16_skylake_(data_e3m2_u8x16);
         sum_f32x16 = _mm512_add_ps(sum_f32x16, data_f32x16);
         sumsq_f32x16 = _mm512_fmadd_ps(data_f32x16, data_f32x16, sumsq_f32x16);
     }
@@ -2957,13 +2958,14 @@ NK_INTERNAL void nk_reduce_minmax_e5m2_skylake_contiguous_( //
     __mmask64 is_nan_m64 = _mm512_cmple_epu8_mask(data_cmp_u8x64, _mm512_set1_epi8(0x02)) |
                            _mm512_cmpge_epu8_mask(data_cmp_u8x64, _mm512_set1_epi8((char)0xFD));
     __mmask64 valid_non_nan_m64 = tail_load & ~is_nan_m64;
-    __m512i
-
-    __m512i
+    __m512i data_cmp_min_u8x64 = _mm512_mask_blend_epi8(valid_non_nan_m64, _mm512_set1_epi8((char)0xFF),
+                                                        data_cmp_u8x64);
+    __m512i data_cmp_max_u8x64 = _mm512_mask_blend_epi8(valid_non_nan_m64, _mm512_setzero_si512(), data_cmp_u8x64);
+    __m512i new_min_u8x64 = _mm512_min_epu8(min_vec.zmm, data_cmp_min_u8x64);
     __mmask64 min_changed_mask = ~_mm512_cmpeq_epi8_mask(new_min_u8x64, min_vec.zmm);
     min_vec.zmm = new_min_u8x64;
     min_loop_cycle_u8x64 = _mm512_mask_mov_epi8(min_loop_cycle_u8x64, min_changed_mask, current_loop_cycle_u8x64);
-    __m512i new_max_u8x64 = _mm512_max_epu8(max_vec.zmm,
+    __m512i new_max_u8x64 = _mm512_max_epu8(max_vec.zmm, data_cmp_max_u8x64);
     __mmask64 max_changed_mask = ~_mm512_cmpeq_epi8_mask(new_max_u8x64, max_vec.zmm);
     max_vec.zmm = new_max_u8x64;
     max_loop_cycle_u8x64 = _mm512_mask_mov_epi8(max_loop_cycle_u8x64, max_changed_mask, current_loop_cycle_u8x64);
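The 0x02/0xFD thresholds make sense once the comparable mapping is spelled out. Assuming the usual sign-magnitude-to-unsigned transform (negative bytes flip entirely, positive bytes set the top bit; the actual helper is defined outside this hunk), the three e5m2 NaN encodings of each sign land exactly at the two extremes of the unsigned range:

```c
#include <stdint.h>

/* Assumed sketch of the comparable mapping: after it, unsigned byte order
 * matches e5m2 numeric order. */
static uint8_t e5m2_to_comparable(uint8_t raw) {
    return (raw & 0x80u) ? (uint8_t)~raw : (uint8_t)(raw | 0x80u);
}
/* Positive NaNs 0x7D-0x7F map to 0xFD-0xFF, negative NaNs 0xFD-0xFF map to
 * 0x00-0x02: precisely the bands the cmple/cmpge pair rejects above. */
```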
@@ -3284,16 +3286,16 @@ NK_INTERNAL void nk_reduce_moments_i4_skylake_contiguous_( //
             ptr += 64, count_bytes -= 64;
         }
         // Extract nibbles as unsigned [0,15]
-        __m512i
-        __m512i
+        __m512i low_u4_u8x64 = _mm512_and_si512(raw_i8x64, mask_0f_i8x64);
+        __m512i high_u4_u8x64 = _mm512_and_si512(_mm512_srli_epi16(raw_i8x64, 4), mask_0f_i8x64);
         // Sum: XOR-bias nibbles to unsigned [0,15], add lo+hi per byte, vpsadbw
-        __m512i low_biased_u4x64 = _mm512_xor_si512(
-        __m512i high_biased_u4x64 = _mm512_xor_si512(
-        __m512i
-        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(
+        __m512i low_biased_u4x64 = _mm512_xor_si512(low_u4_u8x64, eight_i8x64);
+        __m512i high_biased_u4x64 = _mm512_xor_si512(high_u4_u8x64, eight_i8x64);
+        __m512i pair_sum_u8x64 = _mm512_add_epi8(low_biased_u4x64, high_biased_u4x64);
+        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(pair_sum_u8x64, zero_i8x64));
         // Sumsq: squares are sign-independent, use LUT on unsigned nibbles
-        __m512i low_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64,
-        __m512i high_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64,
+        __m512i low_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64, low_u4_u8x64);
+        __m512i high_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64, high_u4_u8x64);
         sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_sad_epu8(low_sq_u8x64, zero_i8x64));
         sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_sad_epu8(high_sq_u8x64, zero_i8x64));
     }
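The i4 path leans on two byte-level tricks: XOR with 8 rebases a signed nibble from [-8, 7] to [0, 15], so `vpsadbw` can accumulate it as an unsigned byte with the bias subtracted once at the end, and a 16-entry shuffle LUT returns each nibble's square directly, since squares ignore sign. A scalar model (sketch with hypothetical helper names):

```c
#include <stddef.h>
#include <stdint.h>

/* Squares of the signed values encoded by raw nibbles 0..15 (two's
 * complement: 8 -> -8, ..., 15 -> -1). Squares are sign-independent. */
static const uint8_t sq_lut[16] = {0,  1,  4,  9,  16, 25, 36, 49,
                                   64, 49, 36, 25, 16, 9,  4,  1};

static void i4_moments(const uint8_t *packed, size_t count_bytes,
                       int64_t *sum, uint64_t *sumsq) {
    uint64_t biased_sum = 0, sq = 0;
    for (size_t i = 0; i < count_bytes; ++i) {
        uint8_t lo = packed[i] & 0x0F, hi = packed[i] >> 4;
        biased_sum += (uint64_t)((lo ^ 8) + (hi ^ 8)); /* each in [0,15] */
        sq += sq_lut[lo] + sq_lut[hi];
    }
    *sum = (int64_t)biased_sum - 8 * (int64_t)(2 * count_bytes); /* undo bias */
    *sumsq = sq;
}
```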
@@ -3342,13 +3344,13 @@ NK_INTERNAL void nk_reduce_moments_u4_skylake_contiguous_( //
             raw_i8x64 = _mm512_loadu_si512(ptr);
             ptr += 64, count_bytes -= 64;
         }
-        __m512i
-        __m512i
-        __m512i
-        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(
+        __m512i low_u4_u8x64 = _mm512_and_si512(raw_i8x64, mask_0f_i8x64);
+        __m512i high_u4_u8x64 = _mm512_and_si512(_mm512_srli_epi16(raw_i8x64, 4), mask_0f_i8x64);
+        __m512i pair_sum_u8x64 = _mm512_add_epi8(low_u4_u8x64, high_u4_u8x64);
+        sum_u64x8 = _mm512_add_epi64(sum_u64x8, _mm512_sad_epu8(pair_sum_u8x64, zero_i8x64));
         // Sumsq: LUT maps nibble→square, vpsadbw accumulates into u64
-        __m512i low_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64,
-        __m512i high_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64,
+        __m512i low_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64, low_u4_u8x64);
+        __m512i high_sq_u8x64 = _mm512_shuffle_epi8(sq_lut_u8x64, high_u4_u8x64);
         sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_sad_epu8(low_sq_u8x64, zero_i8x64));
         sumsq_u64x8 = _mm512_add_epi64(sumsq_u64x8, _mm512_sad_epu8(high_sq_u8x64, zero_i8x64));
     }
@@ -3467,9 +3469,9 @@ NK_PUBLIC void nk_reduce_moments_bf16_skylake( //
 }

 NK_INTERNAL __m512i nk_bf16x32_to_comparable_i16x32_skylake_(__m512i raw_u16x32) {
-    __m512i
-    __m512i
-    return _mm512_xor_si512(raw_u16x32,
+    __m512i sign_i16x32 = _mm512_srai_epi16(raw_u16x32, 15);
+    __m512i flip_i16x32 = _mm512_srli_epi16(sign_i16x32, 1);
+    return _mm512_xor_si512(raw_u16x32, flip_i16x32);
 }

 NK_INTERNAL void nk_reduce_minmax_bf16_skylake_contiguous_( //
@@ -3648,9 +3650,9 @@ NK_PUBLIC void nk_reduce_moments_f16_skylake( //
 }

 NK_INTERNAL __m512i nk_f16x32_to_comparable_i16x32_skylake_(__m512i raw_u16x32) {
-    __m512i
-    __m512i
-    return _mm512_xor_si512(raw_u16x32,
+    __m512i sign_i16x32 = _mm512_srai_epi16(raw_u16x32, 15);
+    __m512i flip_i16x32 = _mm512_srli_epi16(sign_i16x32, 1);
+    return _mm512_xor_si512(raw_u16x32, flip_i16x32);
 }

 NK_INTERNAL void nk_reduce_minmax_f16_skylake_contiguous_( //
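Both 16-bit conversions use the same two-instruction trick: an arithmetic shift broadcasts the sign bit, a logical shift turns that into a 0x7FFF mask for negative lanes only, and the XOR flips the magnitude bits of negatives so that plain signed 16-bit comparison orders the floats correctly, including -0 below +0. Scalar equivalent (sketch):

```c
#include <stdint.h>

/* Sketch: scalar form of nk_f16x32_to_comparable_i16x32_skylake_ above.
 * Negative bit patterns get their 15 magnitude bits flipped; positives pass
 * through unchanged, so signed integer order matches floating-point order. */
static int16_t f16_bits_to_comparable(uint16_t raw) {
    uint16_t sign = (uint16_t)((int16_t)raw >> 15); /* 0x0000 or 0xFFFF */
    uint16_t flip = (uint16_t)(sign >> 1);          /* 0x0000 or 0x7FFF */
    return (int16_t)(raw ^ flip);
}
```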