numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,14 +8,13 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_icelake_instructions VNNI Instructions Performance
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm512_dpwssd_epi32
|
|
13
|
-
* _mm512_dpbusd_epi32
|
|
14
|
-
* _mm512_madd_epi16
|
|
11
|
+
* Intrinsic Instruction Icelake Genoa
|
|
12
|
+
* _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
|
|
13
|
+
* _mm512_dpbusd_epi32 VPDPBUSD (ZMM, ZMM, ZMM) 5cy @ p0 4cy @ p01
|
|
14
|
+
* _mm512_madd_epi16 VPMADDWD (ZMM, ZMM, ZMM) 5cy @ p05 3cy @ p01
|
|
15
15
|
*
|
|
16
16
|
* Ice Lake introduces AVX-512 VNNI for accelerated integer dot products. VNNI instructions bottleneck
|
|
17
17
|
* on port 0, limiting throughput to 1/cy. AMD Genoa dual-issues on ports 0-1, achieving 0.5 cy reciprocal throughput (2/cy).
|
|
18
|
-
* We use VPDPWSSD for signed i8 inputs after widening to i16, since VPDPBUSD is asymmetric (unsigned x signed).
|
|
19
18
|
*
|
|
20
19
|
* @section dot_icelake_stateful Stateful Streaming Logic
|
|
21
20
|
*
|
|
@@ -80,6 +79,7 @@
|
|
|
80
79
|
#if NK_TARGET_ICELAKE
|
|
81
80
|
|
|
82
81
|
#include "numkong/types.h"
|
|
82
|
+
#include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
|
|
83
83
|
|
|
84
84
|
#if defined(__cplusplus)
|
|
85
85
|
extern "C" {
|
|
@@ -268,13 +268,13 @@ NK_INTERNAL void nk_dot_i8x64_finalize_icelake(
|
|
|
268
268
|
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
|
|
269
269
|
|
|
270
270
|
// 4-way transpose reduce
|
|
271
|
-
__m128i
|
|
272
|
-
__m128i
|
|
273
|
-
__m128i
|
|
274
|
-
__m128i
|
|
271
|
+
__m128i t_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
272
|
+
__m128i t_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
273
|
+
__m128i t_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
274
|
+
__m128i t_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
275
275
|
__m128i biased_i32x4 = _mm_add_epi32(
|
|
276
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
277
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
276
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
|
|
277
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
|
|
278
278
|
|
|
279
279
|
// Apply compensation: result = biased − 128 × Σb
|
|
280
280
|
__m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
|
|
@@ -335,13 +335,13 @@ NK_INTERNAL void nk_dot_u8x64_finalize_icelake(
|
|
|
335
335
|
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
|
|
336
336
|
|
|
337
337
|
// 4-way transpose reduce
|
|
338
|
-
__m128i
|
|
339
|
-
__m128i
|
|
340
|
-
__m128i
|
|
341
|
-
__m128i
|
|
338
|
+
__m128i t_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
339
|
+
__m128i t_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
340
|
+
__m128i t_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
341
|
+
__m128i t_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
342
342
|
__m128i biased_i32x4 = _mm_add_epi32(
|
|
343
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
344
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
343
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
|
|
344
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
|
|
345
345
|
|
|
346
346
|
// Apply compensation: result = biased + 128 × Σb
|
|
347
347
|
__m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
|
|
@@ -402,18 +402,18 @@ NK_INTERNAL void nk_sum_i4x128_update_icelake(nk_sum_i4x128_state_icelake_t *sta
|
|
|
402
402
|
__m512i const nibble_mask_u8x64 = _mm512_set1_epi8(0x0F);
|
|
403
403
|
__m512i const xor_mask_u8x64 = _mm512_set1_epi8(0x08);
|
|
404
404
|
__m512i const zeros_u8x64 = _mm512_setzero_si512();
|
|
405
|
-
|
|
405
|
+
// Extract low and high nibbles, XOR with 8 to get unsigned representation
|
|
406
406
|
__m512i low_u8x64 = _mm512_and_si512(v.zmm, nibble_mask_u8x64);
|
|
407
407
|
__m512i high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(v.zmm, 4), nibble_mask_u8x64);
|
|
408
408
|
__m512i low_biased_u8x64 = _mm512_xor_si512(low_u8x64, xor_mask_u8x64);
|
|
409
409
|
__m512i high_biased_u8x64 = _mm512_xor_si512(high_u8x64, xor_mask_u8x64);
|
|
410
|
-
|
|
410
|
+
// SAD against zero gives sum of unsigned values, accumulate in u64 lanes
|
|
411
411
|
state->biased_sum_u64x8 = _mm512_add_epi64(state->biased_sum_u64x8, _mm512_sad_epu8(low_biased_u8x64, zeros_u8x64));
|
|
412
412
|
state->biased_sum_u64x8 = _mm512_add_epi64(state->biased_sum_u64x8,
|
|
413
413
|
_mm512_sad_epu8(high_biased_u8x64, zeros_u8x64));
|
|
414
414
|
}
|
|
415
415
|
NK_INTERNAL nk_i32_t nk_sum_i4x128_finalize_icelake(nk_sum_i4x128_state_icelake_t const *state, nk_size_t count) {
|
|
416
|
-
|
|
416
|
+
// Reduce u64x8 → scalar, then undo XOR bias: signed_sum = unsigned_sum - 8 * count
|
|
417
417
|
nk_i64_t unsigned_sum = _mm512_reduce_add_epi64(state->biased_sum_u64x8);
|
|
418
418
|
return (nk_i32_t)(unsigned_sum - 8 * (nk_i64_t)count);
|
|
419
419
|
}
|
|
@@ -454,26 +454,26 @@ nk_dot_i4_icelake_cycle:
|
|
|
454
454
|
}
|
|
455
455
|
|
|
456
456
|
// Extract low and high nibbles
|
|
457
|
-
__m512i
|
|
458
|
-
__m512i
|
|
459
|
-
__m512i
|
|
460
|
-
__m512i
|
|
457
|
+
__m512i a_low_u8x64 = _mm512_and_si512(a_i4x128, nibble_mask_u8x64);
|
|
458
|
+
__m512i a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4x128, 4), nibble_mask_u8x64);
|
|
459
|
+
__m512i b_low_u8x64 = _mm512_and_si512(b_i4x128, nibble_mask_u8x64);
|
|
460
|
+
__m512i b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4x128, 4), nibble_mask_u8x64);
|
|
461
461
|
|
|
462
462
|
// XOR with 8 to get cx, dx values for the algebraic transformation
|
|
463
|
-
__m512i
|
|
464
|
-
__m512i
|
|
465
|
-
__m512i
|
|
466
|
-
__m512i
|
|
463
|
+
__m512i c_low_u8x64 = _mm512_xor_si512(a_low_u8x64, xor_mask_u8x64);
|
|
464
|
+
__m512i c_high_u8x64 = _mm512_xor_si512(a_high_u8x64, xor_mask_u8x64);
|
|
465
|
+
__m512i d_low_u8x64 = _mm512_xor_si512(b_low_u8x64, xor_mask_u8x64);
|
|
466
|
+
__m512i d_high_u8x64 = _mm512_xor_si512(b_high_u8x64, xor_mask_u8x64);
|
|
467
467
|
|
|
468
468
|
// Compute dot products of cx*dx for low and high nibbles
|
|
469
|
-
sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16,
|
|
470
|
-
sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16,
|
|
469
|
+
sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16, c_low_u8x64, d_low_u8x64);
|
|
470
|
+
sum_cd_i32x16 = _mm512_dpbusd_epi32(sum_cd_i32x16, c_high_u8x64, d_high_u8x64);
|
|
471
471
|
|
|
472
472
|
// Accumulate sums of cx and dx using SAD against zeros
|
|
473
|
-
sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(
|
|
474
|
-
sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(
|
|
475
|
-
sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(
|
|
476
|
-
sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(
|
|
473
|
+
sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(c_low_u8x64, zeros_u8x64));
|
|
474
|
+
sum_cx_i64x8 = _mm512_add_epi64(sum_cx_i64x8, _mm512_sad_epu8(c_high_u8x64, zeros_u8x64));
|
|
475
|
+
sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(d_low_u8x64, zeros_u8x64));
|
|
476
|
+
sum_dx_i64x8 = _mm512_add_epi64(sum_dx_i64x8, _mm512_sad_epu8(d_high_u8x64, zeros_u8x64));
|
|
477
477
|
if (n_bytes) goto nk_dot_i4_icelake_cycle;
|
|
478
478
|
|
|
479
479
|
// Reduce partial sums and apply algebraic correction
|
|
@@ -509,15 +509,15 @@ nk_dot_u4_icelake_cycle:
|
|
|
509
509
|
}
|
|
510
510
|
|
|
511
511
|
// Extract low and high nibbles
|
|
512
|
-
__m512i
|
|
513
|
-
__m512i
|
|
514
|
-
__m512i
|
|
515
|
-
__m512i
|
|
512
|
+
__m512i a_low_u8x64 = _mm512_and_si512(a_u4x128, nibble_mask_u8x64);
|
|
513
|
+
__m512i a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4x128, 4), nibble_mask_u8x64);
|
|
514
|
+
__m512i b_low_u8x64 = _mm512_and_si512(b_u4x128, nibble_mask_u8x64);
|
|
515
|
+
__m512i b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4x128, 4), nibble_mask_u8x64);
|
|
516
516
|
|
|
517
517
|
// DPBUSD works directly for u4 since values are ∈ [0,15]
|
|
518
518
|
// and the signed interpretation of [0,15] is the same as unsigned
|
|
519
|
-
sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16,
|
|
520
|
-
sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16,
|
|
519
|
+
sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, a_low_u8x64, b_low_u8x64);
|
|
520
|
+
sum_i32x16 = _mm512_dpbusd_epi32(sum_i32x16, a_high_u8x64, b_high_u8x64);
|
|
521
521
|
if (n_bytes) goto nk_dot_u4_icelake_cycle;
|
|
522
522
|
|
|
523
523
|
*result = (nk_u32_t)_mm512_reduce_add_epi32(sum_i32x16);
|
|
@@ -545,22 +545,22 @@ NK_INTERNAL void nk_dot_i4x128_update_icelake(nk_dot_i4x128_state_icelake_t *sta
|
|
|
545
545
|
__m512i b_i4x128 = b.zmm;
|
|
546
546
|
|
|
547
547
|
// Extract low and high nibbles (all 128 nibbles from 64 bytes)
|
|
548
|
-
__m512i
|
|
549
|
-
__m512i
|
|
550
|
-
__m512i
|
|
551
|
-
__m512i
|
|
548
|
+
__m512i a_low_u8x64 = _mm512_and_si512(a_i4x128, nibble_mask_u8x64);
|
|
549
|
+
__m512i a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_i4x128, 4), nibble_mask_u8x64);
|
|
550
|
+
__m512i b_low_u8x64 = _mm512_and_si512(b_i4x128, nibble_mask_u8x64);
|
|
551
|
+
__m512i b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_i4x128, 4), nibble_mask_u8x64);
|
|
552
552
|
|
|
553
553
|
// Apply bias transformation: XOR with 8
|
|
554
|
-
__m512i
|
|
555
|
-
__m512i
|
|
556
|
-
__m512i
|
|
557
|
-
__m512i
|
|
554
|
+
__m512i a_biased_low_u8x64 = _mm512_xor_si512(a_low_u8x64, bias_xor_mask_u8x64);
|
|
555
|
+
__m512i a_biased_high_u8x64 = _mm512_xor_si512(a_high_u8x64, bias_xor_mask_u8x64);
|
|
556
|
+
__m512i b_biased_low_u8x64 = _mm512_xor_si512(b_low_u8x64, bias_xor_mask_u8x64);
|
|
557
|
+
__m512i b_biased_high_u8x64 = _mm512_xor_si512(b_high_u8x64, bias_xor_mask_u8x64);
|
|
558
558
|
|
|
559
559
|
// Compute dot products of a_biased×b_biased — no SAD correction accumulators
|
|
560
|
-
state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16,
|
|
561
|
-
|
|
562
|
-
state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16,
|
|
563
|
-
|
|
560
|
+
state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, a_biased_low_u8x64,
|
|
561
|
+
b_biased_low_u8x64);
|
|
562
|
+
state->biased_product_sum_i32x16 = _mm512_dpbusd_epi32(state->biased_product_sum_i32x16, a_biased_high_u8x64,
|
|
563
|
+
b_biased_high_u8x64);
|
|
564
564
|
}
|
|
565
565
|
|
|
566
566
|
NK_INTERNAL void nk_dot_i4x128_finalize_icelake( //
|
|
@@ -596,13 +596,13 @@ NK_INTERNAL void nk_dot_i4x128_finalize_icelake(
|
|
|
596
596
|
_mm256_extracti128_si256(product_d_i32x8, 1));
|
|
597
597
|
|
|
598
598
|
// 4-way transpose reduce
|
|
599
|
-
__m128i
|
|
600
|
-
__m128i
|
|
601
|
-
__m128i
|
|
602
|
-
__m128i
|
|
599
|
+
__m128i t_ab_low = _mm_unpacklo_epi32(product_a_i32x4, product_b_i32x4);
|
|
600
|
+
__m128i t_cd_low = _mm_unpacklo_epi32(product_c_i32x4, product_d_i32x4);
|
|
601
|
+
__m128i t_ab_high = _mm_unpackhi_epi32(product_a_i32x4, product_b_i32x4);
|
|
602
|
+
__m128i t_cd_high = _mm_unpackhi_epi32(product_c_i32x4, product_d_i32x4);
|
|
603
603
|
__m128i biased_i32x4 = _mm_add_epi32(
|
|
604
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
605
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
604
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
|
|
605
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
|
|
606
606
|
|
|
607
607
|
// Apply compensation: result = biased − 8×(Σa + Σb) − 64×depth_padded
|
|
608
608
|
__m128i a_sum_broadcast_i32x4 = _mm_set1_epi32(a_sum);
|
|
@@ -633,14 +633,14 @@ NK_INTERNAL void nk_dot_u4x128_update_icelake(nk_dot_u4x128_state_icelake_t *sta
|
|
|
633
633
|
__m512i b_u4x128 = b.zmm;
|
|
634
634
|
|
|
635
635
|
// Extract low and high nibbles (all 128 nibbles from 64 bytes)
|
|
636
|
-
__m512i
|
|
637
|
-
__m512i
|
|
638
|
-
__m512i
|
|
639
|
-
__m512i
|
|
636
|
+
__m512i a_low_u8x64 = _mm512_and_si512(a_u4x128, nibble_mask_u8x64);
|
|
637
|
+
__m512i a_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(a_u4x128, 4), nibble_mask_u8x64);
|
|
638
|
+
__m512i b_low_u8x64 = _mm512_and_si512(b_u4x128, nibble_mask_u8x64);
|
|
639
|
+
__m512i b_high_u8x64 = _mm512_and_si512(_mm512_srli_epi16(b_u4x128, 4), nibble_mask_u8x64);
|
|
640
640
|
|
|
641
641
|
// DPBUSD works directly for u4 since values are ∈ [0,15]
|
|
642
|
-
state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16,
|
|
643
|
-
state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16,
|
|
642
|
+
state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16, a_low_u8x64, b_low_u8x64);
|
|
643
|
+
state->sum_i32x16 = _mm512_dpbusd_epi32(state->sum_i32x16, a_high_u8x64, b_high_u8x64);
|
|
644
644
|
}
|
|
645
645
|
|
|
646
646
|
NK_INTERNAL void nk_dot_u4x128_finalize_icelake( //
|
|
@@ -667,16 +667,17 @@ NK_INTERNAL void nk_dot_u4x128_finalize_icelake(
|
|
|
667
667
|
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
|
|
668
668
|
|
|
669
669
|
// 4-way transpose to get [a,b,c,d] in lanes
|
|
670
|
-
__m128i
|
|
671
|
-
__m128i
|
|
672
|
-
__m128i
|
|
673
|
-
__m128i
|
|
674
|
-
__m128i
|
|
675
|
-
__m128i
|
|
676
|
-
__m128i
|
|
677
|
-
__m128i
|
|
678
|
-
|
|
679
|
-
__m128i final_i32x4 = _mm_add_epi32(_mm_add_epi32(
|
|
670
|
+
__m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
671
|
+
__m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
672
|
+
__m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
673
|
+
__m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
674
|
+
__m128i sum_lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
675
|
+
__m128i sum_lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
676
|
+
__m128i sum_lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
677
|
+
__m128i sum_lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
678
|
+
|
|
679
|
+
__m128i final_i32x4 = _mm_add_epi32(_mm_add_epi32(sum_lane0_i32x4, sum_lane1_i32x4),
|
|
680
|
+
_mm_add_epi32(sum_lane2_i32x4, sum_lane3_i32x4));
|
|
680
681
|
result->xmm = final_i32x4;
|
|
681
682
|
}
|
|
682
683
|
|
|
@@ -801,7 +802,120 @@ nk_dot_e3m2_icelake_cycle:
|
|
|
801
802
|
*result = (nk_f32_t)_mm512_reduce_add_epi32(sum_i32x16) / 256.0f;
|
|
802
803
|
}
|
|
803
804
|
|
|
804
|
-
#pragma region
|
|
805
|
+
#pragma region F16 and BF16 Floats
|
|
806
|
+
|
|
807
|
+
NK_PUBLIC void nk_dot_e4m3_icelake(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
808
|
+
nk_f32_t *result) {
|
|
809
|
+
// E4M3 dot product via octave decomposition + VPDPBUSD integer MAC.
|
|
810
|
+
// Splits 4-bit exponent into 2 octave bits + 2 remainder bits, maps low 5 bits via VPERMB
|
|
811
|
+
// to u8 integers [0, 120], then 16 VPDPBUSD cross-products across 4×4 octave pairs.
|
|
812
|
+
|
|
813
|
+
__m512i const lut_normal_u8x64 = _mm512_set_epi8( //
|
|
814
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
815
|
+
30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8, //
|
|
816
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
817
|
+
30, 28, 26, 24, 22, 20, 18, 16, 15, 14, 13, 12, 11, 10, 9, 8); //
|
|
818
|
+
__m512i const lut_subnorm_u8x64 = _mm512_set_epi8( //
|
|
819
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
|
|
820
|
+
0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
821
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
|
|
822
|
+
0, 0, 0, 0, 0, 0, 0, 0, 14, 12, 10, 8, 6, 4, 2, 0); //
|
|
823
|
+
__m512i const magnitude_mask_u8x64 = _mm512_set1_epi8(0x7F);
|
|
824
|
+
__m512i const subnorm_threshold_u8x64 = _mm512_set1_epi8(0x08);
|
|
825
|
+
__m512i const oct_threshold_20_u8x64 = _mm512_set1_epi8(0x20);
|
|
826
|
+
__m512i const oct_threshold_40_u8x64 = _mm512_set1_epi8(0x40);
|
|
827
|
+
__m512i const oct_threshold_60_u8x64 = _mm512_set1_epi8(0x60);
|
|
828
|
+
|
|
829
|
+
__m512i dot0_i32x16 = _mm512_setzero_si512();
|
|
830
|
+
__m512i dot1_i32x16 = _mm512_setzero_si512();
|
|
831
|
+
__m512i dot2_i32x16 = _mm512_setzero_si512();
|
|
832
|
+
__m512i dot3_i32x16 = _mm512_setzero_si512();
|
|
833
|
+
__m512i dot4_i32x16 = _mm512_setzero_si512();
|
|
834
|
+
__m512i dot5_i32x16 = _mm512_setzero_si512();
|
|
835
|
+
__m512i dot6_i32x16 = _mm512_setzero_si512();
|
|
836
|
+
__m512i a_e4m3_u8x64, b_e4m3_u8x64;
|
|
837
|
+
|
|
838
|
+
nk_dot_e4m3_icelake_cycle:
|
|
839
|
+
if (count_scalars < 64) {
|
|
840
|
+
__mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, count_scalars);
|
|
841
|
+
a_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, a_scalars);
|
|
842
|
+
b_e4m3_u8x64 = _mm512_maskz_loadu_epi8(mask, b_scalars);
|
|
843
|
+
count_scalars = 0;
|
|
844
|
+
}
|
|
845
|
+
else {
|
|
846
|
+
a_e4m3_u8x64 = _mm512_loadu_si512(a_scalars);
|
|
847
|
+
b_e4m3_u8x64 = _mm512_loadu_si512(b_scalars);
|
|
848
|
+
a_scalars += 64, b_scalars += 64, count_scalars -= 64;
|
|
849
|
+
}
|
|
850
|
+
|
|
851
|
+
__m512i a_magnitude_u8x64 = _mm512_and_si512(a_e4m3_u8x64, magnitude_mask_u8x64);
|
|
852
|
+
__m512i b_magnitude_u8x64 = _mm512_and_si512(b_e4m3_u8x64, magnitude_mask_u8x64);
|
|
853
|
+
__m512i a_base_u8x64 = _mm512_permutexvar_epi8(a_magnitude_u8x64, lut_normal_u8x64);
|
|
854
|
+
__m512i b_base_u8x64 = _mm512_permutexvar_epi8(b_magnitude_u8x64, lut_normal_u8x64);
|
|
855
|
+
|
|
856
|
+
// Subnormal fixup via VPERMB (avoids VPADDB on Zen 4 ports 8+9 / SPR port 0)
|
|
857
|
+
a_base_u8x64 = _mm512_mask_permutexvar_epi8(a_base_u8x64,
|
|
858
|
+
_mm512_cmplt_epu8_mask(a_magnitude_u8x64, subnorm_threshold_u8x64),
|
|
859
|
+
a_magnitude_u8x64, lut_subnorm_u8x64);
|
|
860
|
+
b_base_u8x64 = _mm512_mask_permutexvar_epi8(b_base_u8x64,
|
|
861
|
+
_mm512_cmplt_epu8_mask(b_magnitude_u8x64, subnorm_threshold_u8x64),
|
|
862
|
+
b_magnitude_u8x64, lut_subnorm_u8x64);
|
|
863
|
+
|
|
864
|
+
// Sign via ternary logic: (a ^ b) & ~0x7F in one VPTERNLOGD (imm 0x14)
|
|
865
|
+
__m512i sign_diff_u8x64 = _mm512_ternarylogic_epi64(a_e4m3_u8x64, b_e4m3_u8x64, magnitude_mask_u8x64, 0x14);
|
|
866
|
+
__m512i b_signed_i8x64 = _mm512_mask_sub_epi8(b_base_u8x64, _mm512_test_epi8_mask(sign_diff_u8x64, sign_diff_u8x64),
|
|
867
|
+
_mm512_setzero_si512(), b_base_u8x64);
|
|
868
|
+
|
|
869
|
+
// Octave masks via cascaded range compares on magnitude
|
|
870
|
+
__mmask64 ka_lt20 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_20_u8x64);
|
|
871
|
+
__mmask64 ka_lt40 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_40_u8x64);
|
|
872
|
+
__mmask64 ka_lt60 = _mm512_cmplt_epu8_mask(a_magnitude_u8x64, oct_threshold_60_u8x64);
|
|
873
|
+
__mmask64 kb_lt20 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_20_u8x64);
|
|
874
|
+
__mmask64 kb_lt40 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_40_u8x64);
|
|
875
|
+
__mmask64 kb_lt60 = _mm512_cmplt_epu8_mask(b_magnitude_u8x64, oct_threshold_60_u8x64);
|
|
876
|
+
|
|
877
|
+
__m512i a0_u8x64 = _mm512_maskz_mov_epi8(ka_lt20, a_base_u8x64);
|
|
878
|
+
__m512i a1_u8x64 = _mm512_maskz_mov_epi8(ka_lt40 & ~ka_lt20, a_base_u8x64);
|
|
879
|
+
__m512i a2_u8x64 = _mm512_maskz_mov_epi8(ka_lt60 & ~ka_lt40, a_base_u8x64);
|
|
880
|
+
__m512i a3_u8x64 = _mm512_maskz_mov_epi8(~ka_lt60, a_base_u8x64);
|
|
881
|
+
|
|
882
|
+
__m512i b0_i8x64 = _mm512_maskz_mov_epi8(kb_lt20, b_signed_i8x64);
|
|
883
|
+
__m512i b1_i8x64 = _mm512_maskz_mov_epi8(kb_lt40 & ~kb_lt20, b_signed_i8x64);
|
|
884
|
+
__m512i b2_i8x64 = _mm512_maskz_mov_epi8(kb_lt60 & ~kb_lt40, b_signed_i8x64);
|
|
885
|
+
__m512i b3_i8x64 = _mm512_maskz_mov_epi8(~kb_lt60, b_signed_i8x64);
|
|
886
|
+
|
|
887
|
+
// 16 VPDPBUSD into 7 accumulators grouped by octave sum k = oa + ob
|
|
888
|
+
dot0_i32x16 = _mm512_dpbusd_epi32(dot0_i32x16, a0_u8x64, b0_i8x64);
|
|
889
|
+
dot1_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot1_i32x16, a0_u8x64, b1_i8x64), a1_u8x64, b0_i8x64);
|
|
890
|
+
dot2_i32x16 = _mm512_dpbusd_epi32(
|
|
891
|
+
_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot2_i32x16, a0_u8x64, b2_i8x64), a1_u8x64, b1_i8x64), a2_u8x64,
|
|
892
|
+
b0_i8x64);
|
|
893
|
+
dot3_i32x16 = _mm512_dpbusd_epi32(
|
|
894
|
+
_mm512_dpbusd_epi32(
|
|
895
|
+
_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot3_i32x16, a0_u8x64, b3_i8x64), a1_u8x64, b2_i8x64), a2_u8x64,
|
|
896
|
+
b1_i8x64),
|
|
897
|
+
a3_u8x64, b0_i8x64);
|
|
898
|
+
dot4_i32x16 = _mm512_dpbusd_epi32(
|
|
899
|
+
_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot4_i32x16, a1_u8x64, b3_i8x64), a2_u8x64, b2_i8x64), a3_u8x64,
|
|
900
|
+
b1_i8x64);
|
|
901
|
+
dot5_i32x16 = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(dot5_i32x16, a2_u8x64, b3_i8x64), a3_u8x64, b2_i8x64);
|
|
902
|
+
dot6_i32x16 = _mm512_dpbusd_epi32(dot6_i32x16, a3_u8x64, b3_i8x64);
|
|
903
|
+
|
|
904
|
+
if (count_scalars) goto nk_dot_e4m3_icelake_cycle;
|
|
905
|
+
|
|
906
|
+
__m512 sum_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(dot0_i32x16), _mm512_set1_ps(9.5367431640625e-07f));
|
|
907
|
+
sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot1_i32x16), _mm512_set1_ps(1.52587890625e-05f), sum_f32x16);
|
|
908
|
+
sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot2_i32x16), _mm512_set1_ps(2.44140625e-04f), sum_f32x16);
|
|
909
|
+
sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot3_i32x16), _mm512_set1_ps(3.90625e-03f), sum_f32x16);
|
|
910
|
+
sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot4_i32x16), _mm512_set1_ps(6.25e-02f), sum_f32x16);
|
|
911
|
+
sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot5_i32x16), _mm512_set1_ps(1.0f), sum_f32x16);
|
|
912
|
+
sum_f32x16 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(dot6_i32x16), _mm512_set1_ps(16.0f), sum_f32x16);
|
|
913
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
914
|
+
}
|
|
915
|
+
|
|
916
|
+
#pragma endregion F16 and BF16 Floats
|
|
917
|
+
|
|
918
|
+
#pragma region Binary
|
|
805
919
|
|
|
806
920
|
NK_PUBLIC void nk_dot_u1_icelake(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
|
|
807
921
|
nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
|
|
@@ -866,7 +980,7 @@ NK_INTERNAL void nk_dot_u1x512_finalize_icelake( //
|
|
|
866
980
|
result->xmm = _mm_hadd_epi32(ab_i32x4, cd_i32x4);
|
|
867
981
|
}
|
|
868
982
|
|
|
869
|
-
#pragma endregion
|
|
983
|
+
#pragma endregion Binary
|
|
870
984
|
|
|
871
985
|
#if defined(__clang__)
|
|
872
986
|
#pragma clang attribute pop
|