numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,11 +8,11 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_skylake_instructions Key AVX-512 Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm512_madd_epi16
|
|
13
|
-
* _mm512_add_epi32
|
|
14
|
-
* _mm512_fmadd_ps
|
|
15
|
-
* _mm512_cvtepi8_epi16
|
|
11
|
+
* Intrinsic Instruction Skylake-X Genoa
|
|
12
|
+
* _mm512_madd_epi16 VPMADDWD (ZMM, ZMM, ZMM) 5cy @ p05 3cy @ p01
|
|
13
|
+
* _mm512_add_epi32 VPADDD (ZMM, ZMM, ZMM) 1cy @ p05 1cy @ p0123
|
|
14
|
+
* _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
15
|
+
* _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 4cy @ p12
|
|
16
16
|
*
|
|
17
17
|
* Skylake-X server chips feature dual 512-bit FMA units on ports 0 and 5, enabling 0.5cy throughput for
|
|
18
18
|
* VFMADD and arithmetic operations. Client Skylake variants have only one FMA unit with 1cy throughput.
|
|
@@ -123,7 +123,7 @@ NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x8_skylake_(__m512d sum_f64x8, __m512d
|
|
|
123
123
|
return nk_dot_stable_sum_f64x4_haswell_(tentative_sum_f64x4, accumulated_error_f64x4);
|
|
124
124
|
}
|
|
125
125
|
|
|
126
|
-
#pragma region
|
|
126
|
+
#pragma region F32 and F64 Floats
|
|
127
127
|
|
|
128
128
|
/**
|
|
129
129
|
* @brief Internal helper state for dot-products of low-precision types, where 32-bit accumulation is enough.
|
|
@@ -479,7 +479,8 @@ nk_vdot_f64c_skylake_cycle:
|
|
|
479
479
|
result->imag = nk_dot_stable_sum_f64x8_skylake_(sum_imag_f64x8, compensation_imag_f64x8);
|
|
480
480
|
}
|
|
481
481
|
|
|
482
|
-
#pragma
|
|
482
|
+
#pragma endregion F32 and F64 Floats
|
|
483
|
+
#pragma region F16 and BF16 Floats
|
|
483
484
|
|
|
484
485
|
NK_PUBLIC void nk_dot_f16_skylake(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
485
486
|
nk_f32_t *result) {
|
|
@@ -508,24 +509,28 @@ nk_dot_f16_skylake_cycle:
|
|
|
508
509
|
|
|
509
510
|
NK_PUBLIC void nk_dot_bf16_skylake(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
510
511
|
nk_f32_t *result) {
|
|
511
|
-
|
|
512
|
+
__m512i a_bf16_i16x32, b_bf16_i16x32;
|
|
512
513
|
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
514
|
+
__m512i mask_high_u32x16 = _mm512_set1_epi32((int)0xFFFF0000);
|
|
513
515
|
|
|
514
516
|
nk_dot_bf16_skylake_cycle:
|
|
515
|
-
if (count_scalars <
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
517
|
+
if (count_scalars < 32) {
|
|
518
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
|
|
519
|
+
a_bf16_i16x32 = _mm512_maskz_loadu_epi16(mask, a_scalars);
|
|
520
|
+
b_bf16_i16x32 = _mm512_maskz_loadu_epi16(mask, b_scalars);
|
|
519
521
|
count_scalars = 0;
|
|
520
522
|
}
|
|
521
523
|
else {
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
a_scalars +=
|
|
524
|
+
a_bf16_i16x32 = _mm512_loadu_si512(a_scalars);
|
|
525
|
+
b_bf16_i16x32 = _mm512_loadu_si512(b_scalars);
|
|
526
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
525
527
|
}
|
|
526
|
-
__m512
|
|
527
|
-
__m512
|
|
528
|
-
sum_f32x16 = _mm512_fmadd_ps(
|
|
528
|
+
__m512 a_even_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(a_bf16_i16x32, 16));
|
|
529
|
+
__m512 b_even_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(b_bf16_i16x32, 16));
|
|
530
|
+
sum_f32x16 = _mm512_fmadd_ps(a_even_f32x16, b_even_f32x16, sum_f32x16);
|
|
531
|
+
__m512 a_odd_f32x16 = _mm512_castsi512_ps(_mm512_and_si512(a_bf16_i16x32, mask_high_u32x16));
|
|
532
|
+
__m512 b_odd_f32x16 = _mm512_castsi512_ps(_mm512_and_si512(b_bf16_i16x32, mask_high_u32x16));
|
|
533
|
+
sum_f32x16 = _mm512_fmadd_ps(a_odd_f32x16, b_odd_f32x16, sum_f32x16);
|
|
529
534
|
if (count_scalars) goto nk_dot_bf16_skylake_cycle;
|
|
530
535
|
|
|
531
536
|
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
@@ -533,23 +538,23 @@ nk_dot_bf16_skylake_cycle:
|
|
|
533
538
|
|
|
534
539
|
NK_PUBLIC void nk_dot_e4m3_skylake(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
535
540
|
nk_f32_t *result) {
|
|
536
|
-
__m128i
|
|
541
|
+
__m128i a_e4m3_u8x16, b_e4m3_u8x16;
|
|
537
542
|
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
538
543
|
|
|
539
544
|
nk_dot_e4m3_skylake_cycle:
|
|
540
545
|
if (count_scalars < 16) {
|
|
541
546
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
|
|
542
|
-
|
|
543
|
-
|
|
547
|
+
a_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
|
|
548
|
+
b_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
|
|
544
549
|
count_scalars = 0;
|
|
545
550
|
}
|
|
546
551
|
else {
|
|
547
|
-
|
|
548
|
-
|
|
552
|
+
a_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)a_scalars);
|
|
553
|
+
b_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)b_scalars);
|
|
549
554
|
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
550
555
|
}
|
|
551
|
-
__m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(
|
|
552
|
-
__m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(
|
|
556
|
+
__m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3_u8x16);
|
|
557
|
+
__m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3_u8x16);
|
|
553
558
|
sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
|
|
554
559
|
if (count_scalars) goto nk_dot_e4m3_skylake_cycle;
|
|
555
560
|
|
|
@@ -558,23 +563,23 @@ nk_dot_e4m3_skylake_cycle:
|
|
|
558
563
|
|
|
559
564
|
NK_PUBLIC void nk_dot_e5m2_skylake(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
560
565
|
nk_f32_t *result) {
|
|
561
|
-
__m128i
|
|
566
|
+
__m128i a_e5m2_u8x16, b_e5m2_u8x16;
|
|
562
567
|
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
563
568
|
|
|
564
569
|
nk_dot_e5m2_skylake_cycle:
|
|
565
570
|
if (count_scalars < 16) {
|
|
566
571
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
|
|
567
|
-
|
|
568
|
-
|
|
572
|
+
a_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
|
|
573
|
+
b_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
|
|
569
574
|
count_scalars = 0;
|
|
570
575
|
}
|
|
571
576
|
else {
|
|
572
|
-
|
|
573
|
-
|
|
577
|
+
a_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)a_scalars);
|
|
578
|
+
b_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)b_scalars);
|
|
574
579
|
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
575
580
|
}
|
|
576
|
-
__m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(
|
|
577
|
-
__m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(
|
|
581
|
+
__m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2_u8x16);
|
|
582
|
+
__m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2_u8x16);
|
|
578
583
|
sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
|
|
579
584
|
if (count_scalars) goto nk_dot_e5m2_skylake_cycle;
|
|
580
585
|
|
|
@@ -587,12 +592,12 @@ NK_PUBLIC void nk_dot_e2m3_skylake(nk_e2m3_t const *a_scalars, nk_e2m3_t const *
|
|
|
587
592
|
// 64 elements per iteration using AVX-512BW. Result = i32_dot / 256.0f (exact).
|
|
588
593
|
//
|
|
589
594
|
// LUTs replicated 4× for 512-bit VPSHUFB (operates per 128-bit lane):
|
|
590
|
-
__m512i const
|
|
595
|
+
__m512i const lut_low_u8x64 = _mm512_set_epi8( //
|
|
591
596
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
592
597
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
593
598
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
594
599
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
595
|
-
__m512i const
|
|
600
|
+
__m512i const lut_high_u8x64 = _mm512_set_epi8( //
|
|
596
601
|
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
597
602
|
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
598
603
|
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
@@ -625,16 +630,16 @@ nk_dot_e2m3_skylake_cycle:
|
|
|
625
630
|
__m512i b_shuffle_index_u8x64 = _mm512_and_si512(b_magnitude_u8x64, nibble_mask_u8x64);
|
|
626
631
|
|
|
627
632
|
// Bit-4 select via kmask (cleaner than Haswell's vector compare)
|
|
628
|
-
__mmask64
|
|
629
|
-
__mmask64
|
|
633
|
+
__mmask64 a_high_select = _mm512_test_epi8_mask(a_magnitude_u8x64, half_select_u8x64);
|
|
634
|
+
__mmask64 b_high_select = _mm512_test_epi8_mask(b_magnitude_u8x64, half_select_u8x64);
|
|
630
635
|
|
|
631
636
|
// Dual VPSHUFB + mask-blend for 32-entry LUT
|
|
632
|
-
__m512i a_unsigned_u8x64 = _mm512_mask_blend_epi8(
|
|
633
|
-
_mm512_shuffle_epi8(
|
|
634
|
-
_mm512_shuffle_epi8(
|
|
635
|
-
__m512i b_unsigned_u8x64 = _mm512_mask_blend_epi8(
|
|
636
|
-
_mm512_shuffle_epi8(
|
|
637
|
-
_mm512_shuffle_epi8(
|
|
637
|
+
__m512i a_unsigned_u8x64 = _mm512_mask_blend_epi8(a_high_select,
|
|
638
|
+
_mm512_shuffle_epi8(lut_low_u8x64, a_shuffle_index_u8x64),
|
|
639
|
+
_mm512_shuffle_epi8(lut_high_u8x64, a_shuffle_index_u8x64));
|
|
640
|
+
__m512i b_unsigned_u8x64 = _mm512_mask_blend_epi8(b_high_select,
|
|
641
|
+
_mm512_shuffle_epi8(lut_low_u8x64, b_shuffle_index_u8x64),
|
|
642
|
+
_mm512_shuffle_epi8(lut_high_u8x64, b_shuffle_index_u8x64));
|
|
638
643
|
|
|
639
644
|
// Combined sign: (a ^ b) & 0x20, negate b where signs differ using kmask
|
|
640
645
|
__m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e2m3_u8x64, b_e2m3_u8x64), sign_mask_u8x64);
|
|
@@ -657,12 +662,12 @@ NK_PUBLIC void nk_dot_e3m2_skylake(nk_e3m2_t const *a_scalars, nk_e3m2_t const *
|
|
|
657
662
|
// 64 elements per iteration using AVX-512BW. Magnitudes reach 448, requiring i16.
|
|
658
663
|
// Result = i32_dot / 256.0f (exact, no rounding error).
|
|
659
664
|
//
|
|
660
|
-
__m512i const
|
|
665
|
+
__m512i const lut_low_byte_first_u8x64 = _mm512_set_epi8( //
|
|
661
666
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
662
667
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
663
668
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
664
669
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
665
|
-
__m512i const
|
|
670
|
+
__m512i const lut_low_byte_second_u8x64 = _mm512_set_epi8( //
|
|
666
671
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
667
672
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
668
673
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
@@ -695,51 +700,53 @@ nk_dot_e3m2_skylake_cycle:
|
|
|
695
700
|
__m512i b_shuffle_index_u8x64 = _mm512_and_si512(b_magnitude_u8x64, nibble_mask_u8x64);
|
|
696
701
|
|
|
697
702
|
// Bit-4 select via kmask
|
|
698
|
-
__mmask64
|
|
699
|
-
__mmask64
|
|
703
|
+
__mmask64 a_high_select = _mm512_test_epi8_mask(a_magnitude_u8x64, half_select_u8x64);
|
|
704
|
+
__mmask64 b_high_select = _mm512_test_epi8_mask(b_magnitude_u8x64, half_select_u8x64);
|
|
700
705
|
|
|
701
706
|
// Dual VPSHUFB + mask-blend for low bytes
|
|
702
|
-
__m512i
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
__m512i
|
|
706
|
-
|
|
707
|
-
|
|
707
|
+
__m512i a_low_byte_u8x64 = _mm512_mask_blend_epi8(
|
|
708
|
+
a_high_select, _mm512_shuffle_epi8(lut_low_byte_first_u8x64, a_shuffle_index_u8x64),
|
|
709
|
+
_mm512_shuffle_epi8(lut_low_byte_second_u8x64, a_shuffle_index_u8x64));
|
|
710
|
+
__m512i b_low_byte_u8x64 = _mm512_mask_blend_epi8(
|
|
711
|
+
b_high_select, _mm512_shuffle_epi8(lut_low_byte_first_u8x64, b_shuffle_index_u8x64),
|
|
712
|
+
_mm512_shuffle_epi8(lut_low_byte_second_u8x64, b_shuffle_index_u8x64));
|
|
708
713
|
|
|
709
714
|
// High byte: 1 iff magnitude >= 28 (unsigned compare via _mm512_cmpge_epu8_mask)
|
|
710
|
-
__mmask64
|
|
711
|
-
__mmask64
|
|
712
|
-
__m512i
|
|
713
|
-
__m512i
|
|
715
|
+
__mmask64 a_high_mask = _mm512_cmpge_epu8_mask(a_magnitude_u8x64, _mm512_set1_epi8(28));
|
|
716
|
+
__mmask64 b_high_mask = _mm512_cmpge_epu8_mask(b_magnitude_u8x64, _mm512_set1_epi8(28));
|
|
717
|
+
__m512i a_high_byte_u8x64 = _mm512_maskz_mov_epi8(a_high_mask, ones_u8x64);
|
|
718
|
+
__m512i b_high_byte_u8x64 = _mm512_maskz_mov_epi8(b_high_mask, ones_u8x64);
|
|
714
719
|
|
|
715
720
|
// Interleave low and high bytes into i16
|
|
716
|
-
__m512i
|
|
717
|
-
__m512i
|
|
718
|
-
__m512i
|
|
719
|
-
__m512i
|
|
721
|
+
__m512i a_low_i16x32 = _mm512_unpacklo_epi8(a_low_byte_u8x64, a_high_byte_u8x64);
|
|
722
|
+
__m512i a_high_i16x32 = _mm512_unpackhi_epi8(a_low_byte_u8x64, a_high_byte_u8x64);
|
|
723
|
+
__m512i b_low_i16x32 = _mm512_unpacklo_epi8(b_low_byte_u8x64, b_high_byte_u8x64);
|
|
724
|
+
__m512i b_high_i16x32 = _mm512_unpackhi_epi8(b_low_byte_u8x64, b_high_byte_u8x64);
|
|
720
725
|
|
|
721
726
|
// Combined sign: (a ^ b) & 0x20, need to apply at i16 level
|
|
722
727
|
// Compute sign mask at u8 level, widen to match unpacklo/unpackhi ordering via PEXT
|
|
723
728
|
__m512i sign_combined_u8x64 = _mm512_and_si512(_mm512_xor_si512(a_e3m2_u8x64, b_e3m2_u8x64), sign_mask_u8x64);
|
|
724
729
|
__mmask64 negate_u8_mask = _mm512_test_epi8_mask(sign_combined_u8x64, sign_combined_u8x64);
|
|
725
730
|
// Extract bits matching unpacklo element ordering (bytes 0-7,16-23,32-39,48-55 per 64-byte vector)
|
|
726
|
-
__mmask32
|
|
727
|
-
__mmask32
|
|
731
|
+
__mmask32 negate_low_i16 = (__mmask32)_pext_u64(negate_u8_mask, 0x00FF00FF00FF00FFULL);
|
|
732
|
+
__mmask32 negate_high_i16 = (__mmask32)_pext_u64(negate_u8_mask, 0xFF00FF00FF00FF00ULL);
|
|
728
733
|
// Negate b at i16 level using mask_sub
|
|
729
|
-
__m512i
|
|
730
|
-
|
|
734
|
+
__m512i b_signed_low_i16x32 = _mm512_mask_sub_epi16(b_low_i16x32, negate_low_i16, _mm512_setzero_si512(),
|
|
735
|
+
b_low_i16x32);
|
|
736
|
+
__m512i b_signed_high_i16x32 = _mm512_mask_sub_epi16(b_high_i16x32, negate_high_i16, _mm512_setzero_si512(),
|
|
737
|
+
b_high_i16x32);
|
|
731
738
|
|
|
732
739
|
// VPMADDWD: a_i16 × b_signed_i16 → i32 accumulator
|
|
733
|
-
sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(
|
|
734
|
-
sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(
|
|
740
|
+
sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_low_i16x32, b_signed_low_i16x32));
|
|
741
|
+
sum_i32x16 = _mm512_add_epi32(sum_i32x16, _mm512_madd_epi16(a_high_i16x32, b_signed_high_i16x32));
|
|
735
742
|
|
|
736
743
|
if (count_scalars) goto nk_dot_e3m2_skylake_cycle;
|
|
737
744
|
*result = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 256.0f;
|
|
738
745
|
}
|
|
739
746
|
|
|
740
|
-
#pragma endregion
|
|
747
|
+
#pragma endregion F16 and BF16 Floats
|
|
741
748
|
|
|
742
|
-
#pragma region
|
|
749
|
+
#pragma region I8 and U8 Integers
|
|
743
750
|
|
|
744
751
|
NK_PUBLIC void nk_dot_i8_skylake(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
745
752
|
nk_i32_t *result) {
|
|
@@ -869,10 +876,34 @@ NK_INTERNAL void nk_dot_f32x8_finalize_skylake(
|
|
|
869
876
|
result->ymm_pd = _mm256_set_m128d(sum_cd_f64x2, sum_ab_f64x2);
|
|
870
877
|
}
|
|
871
878
|
|
|
872
|
-
#pragma endregion - Traditional Floats
|
|
873
|
-
|
|
874
879
|
typedef nk_dot_through_f32_state_skylake_t_ nk_dot_bf16x16_state_skylake_t;
|
|
875
880
|
|
|
881
|
+
typedef nk_dot_through_f32_state_skylake_t_ nk_dot_bf16x32_state_skylake_t;
|
|
882
|
+
|
|
883
|
+
NK_INTERNAL void nk_dot_bf16x32_init_skylake(nk_dot_bf16x32_state_skylake_t *state) {
|
|
884
|
+
nk_dot_through_f32_init_skylake_(state);
|
|
885
|
+
}
|
|
886
|
+
|
|
887
|
+
NK_INTERNAL void nk_dot_bf16x32_update_skylake(nk_dot_bf16x32_state_skylake_t *state, nk_b512_vec_t a, nk_b512_vec_t b,
|
|
888
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
889
|
+
nk_unused_(depth_offset);
|
|
890
|
+
nk_unused_(active_dimensions);
|
|
891
|
+
__m512i mask_high_u32x16 = _mm512_set1_epi32((int)0xFFFF0000);
|
|
892
|
+
__m512 a_even_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(a.zmm, 16));
|
|
893
|
+
__m512 b_even_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(b.zmm, 16));
|
|
894
|
+
state->sum_f32x16 = _mm512_fmadd_ps(a_even_f32x16, b_even_f32x16, state->sum_f32x16);
|
|
895
|
+
__m512 a_odd_f32x16 = _mm512_castsi512_ps(_mm512_and_si512(a.zmm, mask_high_u32x16));
|
|
896
|
+
__m512 b_odd_f32x16 = _mm512_castsi512_ps(_mm512_and_si512(b.zmm, mask_high_u32x16));
|
|
897
|
+
state->sum_f32x16 = _mm512_fmadd_ps(a_odd_f32x16, b_odd_f32x16, state->sum_f32x16);
|
|
898
|
+
}
|
|
899
|
+
|
|
900
|
+
NK_INTERNAL void nk_dot_bf16x32_finalize_skylake( //
|
|
901
|
+
nk_dot_bf16x32_state_skylake_t const *state_a, nk_dot_bf16x32_state_skylake_t const *state_b, //
|
|
902
|
+
nk_dot_bf16x32_state_skylake_t const *state_c, nk_dot_bf16x32_state_skylake_t const *state_d, //
|
|
903
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
904
|
+
nk_dot_through_f32_finalize_skylake_(state_a, state_b, state_c, state_d, total_dimensions, result);
|
|
905
|
+
}
|
|
906
|
+
|
|
876
907
|
typedef nk_dot_through_f32_state_skylake_t_ nk_dot_f16x16_state_skylake_t;
|
|
877
908
|
|
|
878
909
|
typedef struct nk_dot_e2m3x64_state_skylake_t {
|
|
@@ -887,14 +918,14 @@ NK_INTERNAL void nk_dot_e2m3x64_update_skylake(nk_dot_e2m3x64_state_skylake_t *s
|
|
|
887
918
|
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
888
919
|
nk_unused_(depth_offset);
|
|
889
920
|
nk_unused_(active_dimensions);
|
|
890
|
-
__m512i const
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
__m512i const
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
921
|
+
__m512i const lut_low_u8x64 = _mm512_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
|
|
922
|
+
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26, 24, 22, 20,
|
|
923
|
+
18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26, 24, 22, 20, 18, 16, 14,
|
|
924
|
+
12, 10, 8, 6, 4, 2, 0);
|
|
925
|
+
__m512i const lut_high_u8x64 = _mm512_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
926
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
927
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
928
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
898
929
|
__m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
|
|
899
930
|
__m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
|
|
900
931
|
__m512i const half_select_u8x64 = _mm512_set1_epi8(0x10);
|
|
@@ -909,19 +940,19 @@ NK_INTERNAL void nk_dot_e2m3x64_update_skylake(nk_dot_e2m3x64_state_skylake_t *s
|
|
|
909
940
|
__m512i a_shuffle_idx = _mm512_and_si512(a_magnitude, nibble_mask_u8x64);
|
|
910
941
|
__m512i b_shuffle_idx = _mm512_and_si512(b_magnitude, nibble_mask_u8x64);
|
|
911
942
|
|
|
912
|
-
__mmask64
|
|
913
|
-
__mmask64
|
|
943
|
+
__mmask64 a_high = _mm512_test_epi8_mask(a_magnitude, half_select_u8x64);
|
|
944
|
+
__mmask64 b_high = _mm512_test_epi8_mask(b_magnitude, half_select_u8x64);
|
|
914
945
|
|
|
915
|
-
__m512i a_unsigned = _mm512_mask_blend_epi8(
|
|
916
|
-
_mm512_shuffle_epi8(
|
|
917
|
-
__m512i b_unsigned = _mm512_mask_blend_epi8(
|
|
918
|
-
_mm512_shuffle_epi8(
|
|
946
|
+
__m512i a_unsigned = _mm512_mask_blend_epi8(a_high, _mm512_shuffle_epi8(lut_low_u8x64, a_shuffle_idx),
|
|
947
|
+
_mm512_shuffle_epi8(lut_high_u8x64, a_shuffle_idx));
|
|
948
|
+
__m512i b_unsigned = _mm512_mask_blend_epi8(b_high, _mm512_shuffle_epi8(lut_low_u8x64, b_shuffle_idx),
|
|
949
|
+
_mm512_shuffle_epi8(lut_high_u8x64, b_shuffle_idx));
|
|
919
950
|
|
|
920
951
|
__m512i sign_combined = _mm512_and_si512(_mm512_xor_si512(a_u8x64, b_u8x64), sign_mask_u8x64);
|
|
921
952
|
__mmask64 negate_mask = _mm512_test_epi8_mask(sign_combined, sign_combined);
|
|
922
|
-
__m512i
|
|
953
|
+
__m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_unsigned, negate_mask, _mm512_setzero_si512(), b_unsigned);
|
|
923
954
|
|
|
924
|
-
__m512i products_i16x32 = _mm512_maddubs_epi16(a_unsigned,
|
|
955
|
+
__m512i products_i16x32 = _mm512_maddubs_epi16(a_unsigned, b_signed_i8x64);
|
|
925
956
|
state->sum_i32x16 = _mm512_add_epi32(state->sum_i32x16, _mm512_madd_epi16(products_i16x32, ones_i16x32));
|
|
926
957
|
}
|
|
927
958
|
|
|
@@ -976,10 +1007,10 @@ NK_INTERNAL void nk_dot_e3m2x64_update_skylake(nk_dot_e3m2x64_state_skylake_t *s
|
|
|
976
1007
|
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
977
1008
|
nk_unused_(depth_offset);
|
|
978
1009
|
nk_unused_(active_dimensions);
|
|
979
|
-
__m512i const
|
|
1010
|
+
__m512i const lut_low_byte_first_u8x64 = _mm512_set_epi8( //
|
|
980
1011
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
981
1012
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
982
|
-
__m512i const
|
|
1013
|
+
__m512i const lut_low_byte_second_u8x64 = _mm512_set_epi8( //
|
|
983
1014
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
984
1015
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
985
1016
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
@@ -998,34 +1029,35 @@ NK_INTERNAL void nk_dot_e3m2x64_update_skylake(nk_dot_e3m2x64_state_skylake_t *s
|
|
|
998
1029
|
__m512i a_shuffle_idx = _mm512_and_si512(a_magnitude, nibble_mask_u8x64);
|
|
999
1030
|
__m512i b_shuffle_idx = _mm512_and_si512(b_magnitude, nibble_mask_u8x64);
|
|
1000
1031
|
|
|
1001
|
-
__mmask64
|
|
1002
|
-
__mmask64
|
|
1032
|
+
__mmask64 a_high = _mm512_test_epi8_mask(a_magnitude, half_select_u8x64);
|
|
1033
|
+
__mmask64 b_high = _mm512_test_epi8_mask(b_magnitude, half_select_u8x64);
|
|
1003
1034
|
|
|
1004
|
-
__m512i
|
|
1005
|
-
_mm512_shuffle_epi8(
|
|
1006
|
-
__m512i
|
|
1007
|
-
_mm512_shuffle_epi8(
|
|
1035
|
+
__m512i a_low_byte = _mm512_mask_blend_epi8(a_high, _mm512_shuffle_epi8(lut_low_byte_first_u8x64, a_shuffle_idx),
|
|
1036
|
+
_mm512_shuffle_epi8(lut_low_byte_second_u8x64, a_shuffle_idx));
|
|
1037
|
+
__m512i b_low_byte = _mm512_mask_blend_epi8(b_high, _mm512_shuffle_epi8(lut_low_byte_first_u8x64, b_shuffle_idx),
|
|
1038
|
+
_mm512_shuffle_epi8(lut_low_byte_second_u8x64, b_shuffle_idx));
|
|
1008
1039
|
|
|
1009
|
-
__mmask64
|
|
1010
|
-
__mmask64
|
|
1011
|
-
__m512i
|
|
1012
|
-
__m512i
|
|
1040
|
+
__mmask64 a_high_mask = _mm512_cmpge_epu8_mask(a_magnitude, _mm512_set1_epi8(28));
|
|
1041
|
+
__mmask64 b_high_mask = _mm512_cmpge_epu8_mask(b_magnitude, _mm512_set1_epi8(28));
|
|
1042
|
+
__m512i a_high_byte = _mm512_maskz_mov_epi8(a_high_mask, ones_u8x64);
|
|
1043
|
+
__m512i b_high_byte = _mm512_maskz_mov_epi8(b_high_mask, ones_u8x64);
|
|
1013
1044
|
|
|
1014
|
-
__m512i
|
|
1015
|
-
__m512i
|
|
1016
|
-
__m512i
|
|
1017
|
-
__m512i
|
|
1045
|
+
__m512i a_low_i16x32 = _mm512_unpacklo_epi8(a_low_byte, a_high_byte);
|
|
1046
|
+
__m512i a_high_i16x32 = _mm512_unpackhi_epi8(a_low_byte, a_high_byte);
|
|
1047
|
+
__m512i b_low_i16x32 = _mm512_unpacklo_epi8(b_low_byte, b_high_byte);
|
|
1048
|
+
__m512i b_high_i16x32 = _mm512_unpackhi_epi8(b_low_byte, b_high_byte);
|
|
1018
1049
|
|
|
1019
1050
|
// Combined sign: negate b at i16 level via PEXT + mask_sub
|
|
1020
1051
|
__m512i sign_combined = _mm512_and_si512(_mm512_xor_si512(a_u8x64, b_u8x64), sign_mask_u8x64);
|
|
1021
1052
|
__mmask64 negate_u8 = _mm512_test_epi8_mask(sign_combined, sign_combined);
|
|
1022
|
-
__mmask32
|
|
1023
|
-
__mmask32
|
|
1024
|
-
__m512i
|
|
1025
|
-
__m512i
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
state->
|
|
1053
|
+
__mmask32 negate_low = (__mmask32)_pext_u64(negate_u8, 0x00FF00FF00FF00FFULL);
|
|
1054
|
+
__mmask32 negate_high = (__mmask32)_pext_u64(negate_u8, 0xFF00FF00FF00FF00ULL);
|
|
1055
|
+
__m512i b_signed_low_i16x32 = _mm512_mask_sub_epi16(b_low_i16x32, negate_low, _mm512_setzero_si512(), b_low_i16x32);
|
|
1056
|
+
__m512i b_signed_high_i16x32 = _mm512_mask_sub_epi16(b_high_i16x32, negate_high, _mm512_setzero_si512(),
|
|
1057
|
+
b_high_i16x32);
|
|
1058
|
+
|
|
1059
|
+
state->sum_a_i32x16 = _mm512_add_epi32(state->sum_a_i32x16, _mm512_madd_epi16(a_low_i16x32, b_signed_low_i16x32));
|
|
1060
|
+
state->sum_b_i32x16 = _mm512_add_epi32(state->sum_b_i32x16, _mm512_madd_epi16(a_high_i16x32, b_signed_high_i16x32));
|
|
1029
1061
|
}
|
|
1030
1062
|
|
|
1031
1063
|
NK_INTERNAL void nk_dot_e3m2x64_finalize_skylake( //
|
|
@@ -1067,7 +1099,7 @@ NK_INTERNAL void nk_dot_e3m2x64_finalize_skylake(
|
|
|
1067
1099
|
results->xmm = _mm_castps_si128(sum_f32x4);
|
|
1068
1100
|
}
|
|
1069
1101
|
|
|
1070
|
-
#pragma endregion
|
|
1102
|
+
#pragma endregion I8 and U8 Integers
|
|
1071
1103
|
|
|
1072
1104
|
#if defined(__clang__)
|
|
1073
1105
|
#pragma clang attribute pop
|