numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,18 +8,18 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section sapphire_elementwise_instructions Relevant Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm512_add_ph
|
|
13
|
-
* _mm512_mul_ph
|
|
14
|
-
* _mm512_fmadd_ph
|
|
15
|
-
* _mm512_cvtepi16_ph
|
|
16
|
-
* _mm512_cvtph_epi16
|
|
17
|
-
* _mm512_cvtepi8_epi16
|
|
18
|
-
* _mm512_cvtsepi16_epi8
|
|
19
|
-
* _mm512_packus_epi16
|
|
20
|
-
* _mm256_add_ph
|
|
21
|
-
* _mm512_maskz_loadu_epi16
|
|
22
|
-
* _mm512_mask_storeu_epi16
|
|
11
|
+
* Intrinsic Instruction Sapphire Genoa
|
|
12
|
+
* _mm512_add_ph VADDPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
|
|
13
|
+
* _mm512_mul_ph VMULPH (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p01
|
|
14
|
+
* _mm512_fmadd_ph VFMADD (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
15
|
+
* _mm512_cvtepi16_ph VCVTW2PH (ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
16
|
+
* _mm512_cvtph_epi16 VCVTPH2W (ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
17
|
+
* _mm512_cvtepi8_epi16 VPMOVSXBW (ZMM, YMM) 3cy @ p5 3cy @ p12
|
|
18
|
+
* _mm512_cvtsepi16_epi8 VPMOVSWB (YMM, ZMM) 4cy @ p5 4cy @ p12
|
|
19
|
+
* _mm512_packus_epi16 VPACKUSWB (ZMM, ZMM, ZMM) 1cy @ p5 1cy @ p12
|
|
20
|
+
* _mm256_add_ph VADDPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
|
|
21
|
+
* _mm512_maskz_loadu_epi16 VMOVDQU16 (ZMM {K}, M512) 7cy @ p23 7cy @ p23
|
|
22
|
+
* _mm512_mask_storeu_epi16 VMOVDQU16 (M512 {K}, ZMM) 4cy @ p4 4cy @ p4
|
|
23
23
|
*/
|
|
24
24
|
#ifndef NK_EACH_SAPPHIRE_H
|
|
25
25
|
#define NK_EACH_SAPPHIRE_H
|
|
@@ -54,8 +54,8 @@ nk_each_sum_f16_sapphire_cycle:
|
|
|
54
54
|
n = 0;
|
|
55
55
|
}
|
|
56
56
|
else {
|
|
57
|
-
a_f16_vec =
|
|
58
|
-
b_f16_vec =
|
|
57
|
+
a_f16_vec = _mm512_castsi512_ph(_mm512_loadu_epi16(a));
|
|
58
|
+
b_f16_vec = _mm512_castsi512_ph(_mm512_loadu_epi16(b));
|
|
59
59
|
a += 32, b += 32, n -= 32;
|
|
60
60
|
}
|
|
61
61
|
sum_f16_vec = _mm512_add_ph(a_f16_vec, b_f16_vec);
|
|
@@ -287,146 +287,11 @@ nk_each_blend_i8_sapphire_cycle:
|
|
|
287
287
|
if (n) goto nk_each_blend_i8_sapphire_cycle;
|
|
288
288
|
}
|
|
289
289
|
|
|
290
|
-
NK_PUBLIC void nk_each_fma_i8_sapphire( //
|
|
291
|
-
nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, //
|
|
292
|
-
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {
|
|
293
|
-
|
|
294
|
-
short alpha_short, beta_short;
|
|
295
|
-
nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
|
|
296
|
-
nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
|
|
297
|
-
__mmask64 mask = 0xFFFFFFFFFFFFFFFF;
|
|
298
|
-
__m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
|
|
299
|
-
__m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
|
|
300
|
-
__m256i a_low_i8x32, a_high_i8x32, b_low_i8x32, b_high_i8x32, c_low_i8x32, c_high_i8x32;
|
|
301
|
-
__m512i result_i8x64;
|
|
302
|
-
__m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
|
|
303
|
-
__m512h c_low_f16x32, c_high_f16x32, ab_low_f16x32, ab_high_f16x32;
|
|
304
|
-
__m512h ab_scaled_low_f16x32, ab_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
|
|
305
|
-
__m512i result_low_i16x32, result_high_i16x32;
|
|
306
|
-
__m512h min_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(-128));
|
|
307
|
-
__m512h max_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(127));
|
|
308
|
-
|
|
309
|
-
nk_each_fma_i8_sapphire_cycle:
|
|
310
|
-
if (n < 64) {
|
|
311
|
-
// Tail: use masked 512-bit loads and extract (runs once)
|
|
312
|
-
mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
|
|
313
|
-
__m512i a_i8x64 = _mm512_maskz_loadu_epi8(mask, a);
|
|
314
|
-
__m512i b_i8x64 = _mm512_maskz_loadu_epi8(mask, b);
|
|
315
|
-
__m512i c_i8x64 = _mm512_maskz_loadu_epi8(mask, c);
|
|
316
|
-
a_low_i8x32 = _mm512_castsi512_si256(a_i8x64);
|
|
317
|
-
a_high_i8x32 = _mm512_extracti64x4_epi64(a_i8x64, 1);
|
|
318
|
-
b_low_i8x32 = _mm512_castsi512_si256(b_i8x64);
|
|
319
|
-
b_high_i8x32 = _mm512_extracti64x4_epi64(b_i8x64, 1);
|
|
320
|
-
c_low_i8x32 = _mm512_castsi512_si256(c_i8x64);
|
|
321
|
-
c_high_i8x32 = _mm512_extracti64x4_epi64(c_i8x64, 1);
|
|
322
|
-
n = 0;
|
|
323
|
-
}
|
|
324
|
-
else {
|
|
325
|
-
// Hot path: 2×256-bit loads per vector to avoid VEXTRACTI64X4 (Port 5)
|
|
326
|
-
a_low_i8x32 = _mm256_loadu_epi8(a);
|
|
327
|
-
a_high_i8x32 = _mm256_loadu_epi8(a + 32);
|
|
328
|
-
b_low_i8x32 = _mm256_loadu_epi8(b);
|
|
329
|
-
b_high_i8x32 = _mm256_loadu_epi8(b + 32);
|
|
330
|
-
c_low_i8x32 = _mm256_loadu_epi8(c);
|
|
331
|
-
c_high_i8x32 = _mm256_loadu_epi8(c + 32);
|
|
332
|
-
a += 64, b += 64, c += 64, n -= 64;
|
|
333
|
-
}
|
|
334
|
-
// Upcast from 256-bit halves:
|
|
335
|
-
a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_low_i8x32));
|
|
336
|
-
a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(a_high_i8x32));
|
|
337
|
-
b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_low_i8x32));
|
|
338
|
-
b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(b_high_i8x32));
|
|
339
|
-
c_low_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(c_low_i8x32));
|
|
340
|
-
c_high_f16x32 = _mm512_cvtepi16_ph(_mm512_cvtepi8_epi16(c_high_i8x32));
|
|
341
|
-
// Multiply:
|
|
342
|
-
ab_low_f16x32 = _mm512_mul_ph(a_low_f16x32, b_low_f16x32);
|
|
343
|
-
ab_high_f16x32 = _mm512_mul_ph(a_high_f16x32, b_high_f16x32);
|
|
344
|
-
// Scale:
|
|
345
|
-
ab_scaled_low_f16x32 = _mm512_mul_ph(ab_low_f16x32, alpha_f16x32);
|
|
346
|
-
ab_scaled_high_f16x32 = _mm512_mul_ph(ab_high_f16x32, alpha_f16x32);
|
|
347
|
-
// Add:
|
|
348
|
-
result_low_f16x32 = _mm512_fmadd_ph(c_low_f16x32, beta_f16x32, ab_scaled_low_f16x32);
|
|
349
|
-
result_high_f16x32 = _mm512_fmadd_ph(c_high_f16x32, beta_f16x32, ab_scaled_high_f16x32);
|
|
350
|
-
// Clip the 16-bit result to 8-bit:
|
|
351
|
-
result_low_f16x32 = _mm512_max_ph(_mm512_min_ph(result_low_f16x32, max_f16x32), min_f16x32);
|
|
352
|
-
result_high_f16x32 = _mm512_max_ph(_mm512_min_ph(result_high_f16x32, max_f16x32), min_f16x32);
|
|
353
|
-
// Downcast:
|
|
354
|
-
result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
|
|
355
|
-
result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
|
|
356
|
-
// Merge back:
|
|
357
|
-
result_i8x64 = _mm512_inserti64x4(_mm512_castsi256_si512(_mm512_cvtsepi16_epi8(result_low_i16x32)),
|
|
358
|
-
_mm512_cvtsepi16_epi8(result_high_i16x32), 1);
|
|
359
|
-
_mm512_mask_storeu_epi8(result, mask, result_i8x64);
|
|
360
|
-
result += 64;
|
|
361
|
-
if (n) goto nk_each_fma_i8_sapphire_cycle;
|
|
362
|
-
}
|
|
363
|
-
|
|
364
|
-
NK_PUBLIC void nk_each_fma_u8_sapphire( //
|
|
365
|
-
nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, //
|
|
366
|
-
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {
|
|
367
|
-
|
|
368
|
-
short alpha_short, beta_short;
|
|
369
|
-
nk_f32_to_f16_sapphire(alpha, (nk_f16_t *)&alpha_short);
|
|
370
|
-
nk_f32_to_f16_sapphire(beta, (nk_f16_t *)&beta_short);
|
|
371
|
-
__mmask64 mask = 0xFFFFFFFFFFFFFFFF;
|
|
372
|
-
__m512h alpha_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(alpha_short));
|
|
373
|
-
__m512h beta_f16x32 = _mm512_castsi512_ph(_mm512_set1_epi16(beta_short));
|
|
374
|
-
__m512i a_u8x64, b_u8x64, c_u8x64, result_u8x64;
|
|
375
|
-
__m512h a_low_f16x32, a_high_f16x32, b_low_f16x32, b_high_f16x32;
|
|
376
|
-
__m512h c_low_f16x32, c_high_f16x32, ab_low_f16x32, ab_high_f16x32;
|
|
377
|
-
__m512h ab_scaled_low_f16x32, ab_scaled_high_f16x32, result_low_f16x32, result_high_f16x32;
|
|
378
|
-
__m512i result_low_i16x32, result_high_i16x32;
|
|
379
|
-
__m512h min_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(0));
|
|
380
|
-
__m512h max_f16x32 = _mm512_cvtepi16_ph(_mm512_set1_epi16(255));
|
|
381
|
-
|
|
382
|
-
nk_each_fma_u8_sapphire_cycle:
|
|
383
|
-
if (n < 64) {
|
|
384
|
-
mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFFull, n);
|
|
385
|
-
a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
|
|
386
|
-
b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
|
|
387
|
-
c_u8x64 = _mm512_maskz_loadu_epi8(mask, c);
|
|
388
|
-
n = 0;
|
|
389
|
-
}
|
|
390
|
-
else {
|
|
391
|
-
a_u8x64 = _mm512_loadu_epi8(a);
|
|
392
|
-
b_u8x64 = _mm512_loadu_epi8(b);
|
|
393
|
-
c_u8x64 = _mm512_loadu_epi8(c);
|
|
394
|
-
a += 64, b += 64, c += 64, n -= 64;
|
|
395
|
-
}
|
|
396
|
-
// Upcast:
|
|
397
|
-
a_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(a_u8x64, _mm512_setzero_si512()));
|
|
398
|
-
a_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(a_u8x64, _mm512_setzero_si512()));
|
|
399
|
-
b_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(b_u8x64, _mm512_setzero_si512()));
|
|
400
|
-
b_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(b_u8x64, _mm512_setzero_si512()));
|
|
401
|
-
c_low_f16x32 = _mm512_cvtepi16_ph(_mm512_unpacklo_epi8(c_u8x64, _mm512_setzero_si512()));
|
|
402
|
-
c_high_f16x32 = _mm512_cvtepi16_ph(_mm512_unpackhi_epi8(c_u8x64, _mm512_setzero_si512()));
|
|
403
|
-
// Multiply:
|
|
404
|
-
ab_low_f16x32 = _mm512_mul_ph(a_low_f16x32, b_low_f16x32);
|
|
405
|
-
ab_high_f16x32 = _mm512_mul_ph(a_high_f16x32, b_high_f16x32);
|
|
406
|
-
// Scale:
|
|
407
|
-
ab_scaled_low_f16x32 = _mm512_mul_ph(ab_low_f16x32, alpha_f16x32);
|
|
408
|
-
ab_scaled_high_f16x32 = _mm512_mul_ph(ab_high_f16x32, alpha_f16x32);
|
|
409
|
-
// Add:
|
|
410
|
-
result_low_f16x32 = _mm512_fmadd_ph(c_low_f16x32, beta_f16x32, ab_scaled_low_f16x32);
|
|
411
|
-
result_high_f16x32 = _mm512_fmadd_ph(c_high_f16x32, beta_f16x32, ab_scaled_high_f16x32);
|
|
412
|
-
// Clip the 16-bit result to 8-bit:
|
|
413
|
-
result_low_f16x32 = _mm512_max_ph(_mm512_min_ph(result_low_f16x32, max_f16x32), min_f16x32);
|
|
414
|
-
result_high_f16x32 = _mm512_max_ph(_mm512_min_ph(result_high_f16x32, max_f16x32), min_f16x32);
|
|
415
|
-
// Downcast:
|
|
416
|
-
result_low_i16x32 = _mm512_cvtph_epi16(result_low_f16x32);
|
|
417
|
-
result_high_i16x32 = _mm512_cvtph_epi16(result_high_f16x32);
|
|
418
|
-
// Merge back:
|
|
419
|
-
result_u8x64 = _mm512_packus_epi16(result_low_i16x32, result_high_i16x32);
|
|
420
|
-
_mm512_mask_storeu_epi8(result, mask, result_u8x64);
|
|
421
|
-
result += 64;
|
|
422
|
-
if (n) goto nk_each_fma_u8_sapphire_cycle;
|
|
423
|
-
}
|
|
424
|
-
|
|
425
290
|
NK_PUBLIC void nk_each_sum_e4m3_sapphire(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result) {
|
|
426
291
|
__m256i a_e4m3x32, b_e4m3x32;
|
|
427
|
-
__m256h
|
|
428
|
-
__m256h
|
|
429
|
-
__m128i
|
|
292
|
+
__m256h a_low_f16x16, a_high_f16x16, b_low_f16x16, b_high_f16x16;
|
|
293
|
+
__m256h sum_low_f16x16, sum_high_f16x16;
|
|
294
|
+
__m128i result_low_e4m3x16, result_high_e4m3x16;
|
|
430
295
|
__mmask32 mask = 0xFFFFFFFF;
|
|
431
296
|
nk_each_sum_e4m3_sapphire_cycle:
|
|
432
297
|
if (n < 32) {
|
|
@@ -442,21 +307,22 @@ nk_each_sum_e4m3_sapphire_cycle:
|
|
|
442
307
|
}
|
|
443
308
|
|
|
444
309
|
// Convert e4m3x16 → f16x16 (two halves)
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
310
|
+
a_low_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(a_e4m3x32));
|
|
311
|
+
a_high_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(a_e4m3x32, 1));
|
|
312
|
+
b_low_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_castsi256_si128(b_e4m3x32));
|
|
313
|
+
b_high_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(_mm256_extracti128_si256(b_e4m3x32, 1));
|
|
449
314
|
|
|
450
315
|
// Add in F16 - e4m3 sum is safe (max 896 < 65504)
|
|
451
|
-
|
|
452
|
-
|
|
316
|
+
sum_low_f16x16 = _mm256_add_ph(a_low_f16x16, b_low_f16x16);
|
|
317
|
+
sum_high_f16x16 = _mm256_add_ph(a_high_f16x16, b_high_f16x16);
|
|
453
318
|
|
|
454
319
|
// Convert f16x16 → e4m3x16
|
|
455
|
-
|
|
456
|
-
|
|
320
|
+
result_low_e4m3x16 = nk_f16x16_to_e4m3x16_sapphire_(sum_low_f16x16);
|
|
321
|
+
result_high_e4m3x16 = nk_f16x16_to_e4m3x16_sapphire_(sum_high_f16x16);
|
|
457
322
|
|
|
458
323
|
// Pack and store
|
|
459
|
-
__m256i result_e4m3x32 = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
|
324
|
+
__m256i result_e4m3x32 = _mm256_inserti128_si256(_mm256_castsi128_si256(result_low_e4m3x16), result_high_e4m3x16,
|
|
325
|
+
1);
|
|
460
326
|
_mm256_mask_storeu_epi8(result, mask, result_e4m3x32);
|
|
461
327
|
result += 32;
|
|
462
328
|
if (n) goto nk_each_sum_e4m3_sapphire_cycle;
|
|
@@ -107,8 +107,8 @@ nk_define_each_scale_(i16, f32, nk_assign_from_to_, nk_f32_to_i16_serial) /
|
|
|
107
107
|
nk_define_each_scale_(u16, f32, nk_assign_from_to_, nk_f32_to_u16_serial) // nk_each_scale_u16_serial
|
|
108
108
|
nk_define_each_scale_(i32, f64, nk_assign_from_to_, nk_f64_to_i32_serial) // nk_each_scale_i32_serial
|
|
109
109
|
nk_define_each_scale_(u32, f64, nk_assign_from_to_, nk_f64_to_u32_serial) // nk_each_scale_u32_serial
|
|
110
|
-
nk_define_each_scale_(i64, f64,
|
|
111
|
-
nk_define_each_scale_(u64, f64,
|
|
110
|
+
nk_define_each_scale_(i64, f64, nk_f64_from_i64_, nk_f64_to_i64_serial) // nk_each_scale_i64_serial
|
|
111
|
+
nk_define_each_scale_(u64, f64, nk_f64_from_u64_, nk_f64_to_u64_serial) // nk_each_scale_u64_serial
|
|
112
112
|
|
|
113
113
|
nk_define_each_blend_(f64, f64, nk_assign_from_to_, nk_assign_from_to_) // nk_each_blend_f64_serial
|
|
114
114
|
nk_define_each_blend_(f32, f32, nk_assign_from_to_, nk_assign_from_to_) // nk_each_blend_f32_serial
|
|
@@ -124,8 +124,8 @@ nk_define_each_blend_(i16, f32, nk_assign_from_to_, nk_f32_to_i16_serial) /
|
|
|
124
124
|
nk_define_each_blend_(u16, f32, nk_assign_from_to_, nk_f32_to_u16_serial) // nk_each_blend_u16_serial
|
|
125
125
|
nk_define_each_blend_(i32, f64, nk_assign_from_to_, nk_f64_to_i32_serial) // nk_each_blend_i32_serial
|
|
126
126
|
nk_define_each_blend_(u32, f64, nk_assign_from_to_, nk_f64_to_u32_serial) // nk_each_blend_u32_serial
|
|
127
|
-
nk_define_each_blend_(i64, f64,
|
|
128
|
-
nk_define_each_blend_(u64, f64,
|
|
127
|
+
nk_define_each_blend_(i64, f64, nk_f64_from_i64_, nk_f64_to_i64_serial) // nk_each_blend_i64_serial
|
|
128
|
+
nk_define_each_blend_(u64, f64, nk_f64_from_u64_, nk_f64_to_u64_serial) // nk_each_blend_u64_serial
|
|
129
129
|
|
|
130
130
|
nk_define_each_fma_(f64, f64, nk_assign_from_to_, nk_assign_from_to_) // nk_each_fma_f64_serial
|
|
131
131
|
nk_define_each_fma_(f32, f32, nk_assign_from_to_, nk_assign_from_to_) // nk_each_fma_f32_serial
|
|
@@ -141,8 +141,8 @@ nk_define_each_fma_(i16, f32, nk_assign_from_to_, nk_f32_to_i16_serial) //
|
|
|
141
141
|
nk_define_each_fma_(u16, f32, nk_assign_from_to_, nk_f32_to_u16_serial) // nk_each_fma_u16_serial
|
|
142
142
|
nk_define_each_fma_(i32, f64, nk_assign_from_to_, nk_f64_to_i32_serial) // nk_each_fma_i32_serial
|
|
143
143
|
nk_define_each_fma_(u32, f64, nk_assign_from_to_, nk_f64_to_u32_serial) // nk_each_fma_u32_serial
|
|
144
|
-
nk_define_each_fma_(i64, f64,
|
|
145
|
-
nk_define_each_fma_(u64, f64,
|
|
144
|
+
nk_define_each_fma_(i64, f64, nk_f64_from_i64_, nk_f64_to_i64_serial) // nk_each_fma_i64_serial
|
|
145
|
+
nk_define_each_fma_(u64, f64, nk_f64_from_u64_, nk_f64_to_u64_serial) // nk_each_fma_u64_serial
|
|
146
146
|
|
|
147
147
|
#undef nk_define_each_scale_
|
|
148
148
|
#undef nk_define_each_sum_
|
|
@@ -8,13 +8,13 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section skylake_elementwise_instructions Relevant Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm512_add_ps
|
|
13
|
-
* _mm512_fmadd_ps
|
|
14
|
-
* _mm512_mul_ps
|
|
15
|
-
* _mm512_cvtph_ps
|
|
16
|
-
* _mm512_maskz_loadu_ps
|
|
17
|
-
* _mm512_mask_storeu_ps
|
|
11
|
+
* Intrinsic Instruction SKL ICL Genoa
|
|
12
|
+
* _mm512_add_ps VADDPS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 3cy @ p01
|
|
13
|
+
* _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 4cy @ p01
|
|
14
|
+
* _mm512_mul_ps VMULPS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p0 3cy @ p01
|
|
15
|
+
* _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 7cy @ p0 5cy @ p01
|
|
16
|
+
* _mm512_maskz_loadu_ps VMOVUPS (ZMM {K}, M512) 7cy @ p23 7cy @ p23 7cy @ p23
|
|
17
|
+
* _mm512_mask_storeu_ps VMOVUPS (M512 {K}, ZMM) 4cy @ p4 4cy @ p4 4cy @ p4
|
|
18
18
|
*
|
|
19
19
|
* Skylake-X server chips have dual 512-bit FMA units enabling 0.5cy throughput for arithmetic operations.
|
|
20
20
|
* AVX-512 masked loads and stores eliminate branch misprediction penalties for partial vector processing.
|