numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,10 +8,10 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_alder_instructions AVX-VNNI Instructions Performance
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm256_dpbusd_epi32
|
|
13
|
-
* _mm256_madd_epi16
|
|
14
|
-
* _mm256_sad_epu8
|
|
11
|
+
* Intrinsic Instruction Alder Lake Raptor Lake
|
|
12
|
+
* _mm256_dpbusd_epi32 VPDPBUSD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
|
|
13
|
+
* _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
|
|
14
|
+
* _mm256_sad_epu8 VPSADBW (YMM, YMM, YMM) 3cy @ p5 3cy @ p5
|
|
15
15
|
*
|
|
16
16
|
* Alder Lake and Raptor Lake support AVX-VNNI (256-bit VNNI)
|
|
17
17
|
* for accelerated integer dot products. This is the 256-bit variant of AVX-512 VNNI found on Ice Lake.
|
|
@@ -208,13 +208,13 @@ NK_INTERNAL void nk_dot_i8x32_finalize_alder(
|
|
|
208
208
|
_mm256_extracti128_si256(state_d->biased_product_sum_i32x8, 1));
|
|
209
209
|
|
|
210
210
|
// 4-way transpose reduce
|
|
211
|
-
__m128i
|
|
212
|
-
__m128i
|
|
213
|
-
__m128i
|
|
214
|
-
__m128i
|
|
211
|
+
__m128i t_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
212
|
+
__m128i t_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
213
|
+
__m128i t_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
214
|
+
__m128i t_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
215
215
|
__m128i biased_i32x4 = _mm_add_epi32(
|
|
216
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
217
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
216
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
|
|
217
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
|
|
218
218
|
|
|
219
219
|
// Apply compensation: result = biased − 128 × Σb
|
|
220
220
|
__m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
|
|
@@ -328,13 +328,13 @@ NK_INTERNAL void nk_dot_u8x32_finalize_alder(
|
|
|
328
328
|
_mm256_extracti128_si256(state_d->biased_product_sum_i32x8, 1));
|
|
329
329
|
|
|
330
330
|
// 4-way transpose reduce
|
|
331
|
-
__m128i
|
|
332
|
-
__m128i
|
|
333
|
-
__m128i
|
|
334
|
-
__m128i
|
|
331
|
+
__m128i t_ab_low = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
332
|
+
__m128i t_cd_low = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
333
|
+
__m128i t_ab_high = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
334
|
+
__m128i t_cd_high = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
335
335
|
__m128i biased_i32x4 = _mm_add_epi32(
|
|
336
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
337
|
-
_mm_add_epi32(_mm_unpacklo_epi64(
|
|
336
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_low, t_cd_low), _mm_unpackhi_epi64(t_ab_low, t_cd_low)),
|
|
337
|
+
_mm_add_epi32(_mm_unpacklo_epi64(t_ab_high, t_cd_high), _mm_unpackhi_epi64(t_ab_high, t_cd_high)));
|
|
338
338
|
|
|
339
339
|
// Apply compensation: result = biased + 128 × Σb
|
|
340
340
|
__m128i correction_i32x4 = _mm_slli_epi32(b_sums.xmm, 7); // × 128
|
|
@@ -355,20 +355,20 @@ NK_INTERNAL void nk_sum_i8x32_init_alder(nk_sum_i8x32_state_alder_t *state) {
|
|
|
355
355
|
state->biased_sum_u64x4 = _mm256_setzero_si256();
|
|
356
356
|
}
|
|
357
357
|
NK_INTERNAL void nk_sum_i8x32_update_alder(nk_sum_i8x32_state_alder_t *state, nk_b256_vec_t vector) {
|
|
358
|
-
|
|
358
|
+
// Convert signed→unsigned via XOR 0x80, then SAD against zero gives sum of unsigned values
|
|
359
359
|
__m256i vector_unsigned_u8x32 = _mm256_xor_si256(vector.ymm, _mm256_set1_epi8((char)0x80));
|
|
360
360
|
__m256i sad_result_u64x4 = _mm256_sad_epu8(vector_unsigned_u8x32, _mm256_setzero_si256());
|
|
361
361
|
state->biased_sum_u64x4 = _mm256_add_epi64(state->biased_sum_u64x4, sad_result_u64x4);
|
|
362
362
|
}
|
|
363
363
|
NK_INTERNAL nk_i32_t nk_sum_i8x32_finalize_alder(nk_sum_i8x32_state_alder_t const *state, nk_size_t count) {
|
|
364
|
-
|
|
364
|
+
// Horizontal reduce u64x4 → scalar
|
|
365
365
|
__m128i low_u64x2 = _mm256_castsi256_si128(state->biased_sum_u64x4);
|
|
366
366
|
__m128i high_u64x2 = _mm256_extracti128_si256(state->biased_sum_u64x4, 1);
|
|
367
367
|
__m128i paired_u64x2 = _mm_add_epi64(low_u64x2, high_u64x2);
|
|
368
368
|
__m128i shuffled_u64x2 = _mm_shuffle_epi32(paired_u64x2, _MM_SHUFFLE(1, 0, 3, 2));
|
|
369
369
|
__m128i total_u64x2 = _mm_add_epi64(paired_u64x2, shuffled_u64x2);
|
|
370
370
|
nk_u64_t unsigned_sum = (nk_u64_t)_mm_cvtsi128_si64(total_u64x2);
|
|
371
|
-
|
|
371
|
+
// Undo XOR bias: signed_sum = unsigned_sum - 128 * count
|
|
372
372
|
return (nk_i32_t)((nk_i64_t)unsigned_sum - 128 * (nk_i64_t)count);
|
|
373
373
|
}
|
|
374
374
|
|
|
@@ -403,10 +403,10 @@ NK_PUBLIC void nk_dot_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_
|
|
|
403
403
|
// This is the Alder Lake (256-bit AVX-VNNI) variant of the Ice Lake kernel.
|
|
404
404
|
// DPBUSD replaces MADDUBS+MADD (2 instructions → 1), accumulating u8×i8→i32 directly.
|
|
405
405
|
//
|
|
406
|
-
__m256i const
|
|
406
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8( //
|
|
407
407
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
408
408
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
409
|
-
__m256i const
|
|
409
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8( //
|
|
410
410
|
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
411
411
|
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
412
412
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
@@ -436,18 +436,18 @@ nk_dot_e2m3_alder_cycle:
|
|
|
436
436
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
437
437
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
438
438
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
439
|
-
__m256i
|
|
440
|
-
|
|
441
|
-
__m256i
|
|
442
|
-
|
|
439
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
440
|
+
half_select_u8x32);
|
|
441
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
442
|
+
half_select_u8x32);
|
|
443
443
|
|
|
444
444
|
// Dual VPSHUFB: lookup in both halves, blend based on bit 4
|
|
445
|
-
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
446
|
-
_mm256_shuffle_epi8(
|
|
447
|
-
|
|
448
|
-
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
449
|
-
_mm256_shuffle_epi8(
|
|
450
|
-
|
|
445
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
|
|
446
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
|
|
447
|
+
a_high_select_u8x32);
|
|
448
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
|
|
449
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
|
|
450
|
+
b_high_select_u8x32);
|
|
451
451
|
|
|
452
452
|
// Combined sign: (a ^ b) & 0x20, negate b where signs differ
|
|
453
453
|
__m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
|
|
@@ -474,10 +474,10 @@ NK_INTERNAL void nk_dot_e2m3x32_update_alder(nk_dot_e2m3x32_state_alder_t *state
|
|
|
474
474
|
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
475
475
|
nk_unused_(depth_offset);
|
|
476
476
|
nk_unused_(active_dimensions);
|
|
477
|
-
__m256i const
|
|
477
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8( //
|
|
478
478
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
479
479
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
480
|
-
__m256i const
|
|
480
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8( //
|
|
481
481
|
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
482
482
|
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
483
483
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
@@ -493,18 +493,18 @@ NK_INTERNAL void nk_dot_e2m3x32_update_alder(nk_dot_e2m3x32_state_alder_t *state
|
|
|
493
493
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
494
494
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
495
495
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
496
|
-
__m256i
|
|
497
|
-
|
|
498
|
-
__m256i
|
|
499
|
-
|
|
496
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
497
|
+
half_select_u8x32);
|
|
498
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
499
|
+
half_select_u8x32);
|
|
500
500
|
|
|
501
501
|
// Dual VPSHUFB + blend
|
|
502
|
-
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
503
|
-
_mm256_shuffle_epi8(
|
|
504
|
-
|
|
505
|
-
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
506
|
-
_mm256_shuffle_epi8(
|
|
507
|
-
|
|
502
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
|
|
503
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
|
|
504
|
+
a_high_select_u8x32);
|
|
505
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
|
|
506
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
|
|
507
|
+
b_high_select_u8x32);
|
|
508
508
|
|
|
509
509
|
// Combined sign + conditional negate
|
|
510
510
|
__m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for Diamond Rapids.
|
|
3
|
+
* @file include/numkong/dot/diamond.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_diamond_instructions Key AVX10.2 FP8 + FP16 VNNI Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Diamond Rapids
|
|
12
|
+
* _mm512_cvthf8_ph VCVTHF82PH (ZMM, YMM) ~3cy (estimated)
|
|
13
|
+
* _mm512_cvtbf8_ph VCVTBF82PH (ZMM, YMM) ~3cy (estimated)
|
|
14
|
+
* _mm512_dpph_ps VDPPHPS (ZMM, ZMM, ZMM) ~6cy (estimated)
|
|
15
|
+
*
|
|
16
|
+
* Diamond Rapids (AVX10.2) introduces native FP8→FP16 conversion via VCVTHF82PH (E4M3→FP16)
|
|
17
|
+
* and VCVTBF82PH (E5M2→FP16), replacing the multi-instruction arithmetic conversion used by
|
|
18
|
+
* Genoa's BF16 path. VDPPHPS then computes two FP16 dot products per 32-bit lane, accumulating
|
|
19
|
+
* into FP32 — providing the same 32-element throughput as Genoa's VDPBF16PS but with FP16
|
|
20
|
+
* intermediate precision (10-bit mantissa vs BF16's 7-bit).
|
|
21
|
+
*
|
|
22
|
+
* @section dot_diamond_stateful Stateful Streaming Logic
|
|
23
|
+
*
|
|
24
|
+
* Defines stateful init/update/finalize helpers for tiled GEMM via the dots/ macros:
|
|
25
|
+
* - nk_dot_through_f16_state_diamond_t_ shared by both E4M3 and E5M2 (FP16→VDPPHPS→FP32)
|
|
26
|
+
*/
|
|
27
|
+
#ifndef NK_DOT_DIAMOND_H
|
|
28
|
+
#define NK_DOT_DIAMOND_H
|
|
29
|
+
|
|
30
|
+
#if NK_TARGET_X86_
|
|
31
|
+
#if NK_TARGET_DIAMOND
|
|
32
|
+
|
|
33
|
+
#include "numkong/types.h"
|
|
34
|
+
#include "numkong/cast/diamond.h" // `nk_load_e4m3x32_to_f16x32_diamond_`
|
|
35
|
+
#include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
|
|
36
|
+
#include "numkong/dot/skylake.h" // `nk_dot_through_f32_finalize_skylake_`
|
|
37
|
+
|
|
38
|
+
#if defined(__cplusplus)
|
|
39
|
+
extern "C" {
|
|
40
|
+
#endif
|
|
41
|
+
|
|
42
|
+
#if defined(__clang__)
|
|
43
|
+
#pragma clang attribute push( \
|
|
44
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx10.2-512,f16c,fma,bmi,bmi2"))), \
|
|
45
|
+
apply_to = function)
|
|
46
|
+
#elif defined(__GNUC__)
|
|
47
|
+
#pragma GCC push_options
|
|
48
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx10.2-512", "f16c", "fma", \
|
|
49
|
+
"bmi", "bmi2")
|
|
50
|
+
#endif
|
|
51
|
+
|
|
52
|
+
NK_PUBLIC void nk_dot_e4m3_diamond(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
53
|
+
nk_f32_t *result) {
|
|
54
|
+
__m256i a_e4m3x32, b_e4m3x32;
|
|
55
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
56
|
+
|
|
57
|
+
nk_dot_e4m3_diamond_cycle:
|
|
58
|
+
if (count_scalars < 32) {
|
|
59
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
|
|
60
|
+
a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a_scalars);
|
|
61
|
+
b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b_scalars);
|
|
62
|
+
count_scalars = 0;
|
|
63
|
+
}
|
|
64
|
+
else {
|
|
65
|
+
a_e4m3x32 = _mm256_loadu_epi8(a_scalars);
|
|
66
|
+
b_e4m3x32 = _mm256_loadu_epi8(b_scalars);
|
|
67
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
68
|
+
}
|
|
69
|
+
__m512h a_f16x32 = _mm512_cvthf8_ph(a_e4m3x32);
|
|
70
|
+
__m512h b_f16x32 = _mm512_cvthf8_ph(b_e4m3x32);
|
|
71
|
+
sum_f32x16 = _mm512_dpph_ps(sum_f32x16, a_f16x32, b_f16x32);
|
|
72
|
+
if (count_scalars) goto nk_dot_e4m3_diamond_cycle;
|
|
73
|
+
|
|
74
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
NK_PUBLIC void nk_dot_e5m2_diamond(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
78
|
+
nk_f32_t *result) {
|
|
79
|
+
__m256i a_e5m2x32, b_e5m2x32;
|
|
80
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
81
|
+
|
|
82
|
+
nk_dot_e5m2_diamond_cycle:
|
|
83
|
+
if (count_scalars < 32) {
|
|
84
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
|
|
85
|
+
a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a_scalars);
|
|
86
|
+
b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b_scalars);
|
|
87
|
+
count_scalars = 0;
|
|
88
|
+
}
|
|
89
|
+
else {
|
|
90
|
+
a_e5m2x32 = _mm256_loadu_epi8(a_scalars);
|
|
91
|
+
b_e5m2x32 = _mm256_loadu_epi8(b_scalars);
|
|
92
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
93
|
+
}
|
|
94
|
+
__m512h a_f16x32 = _mm512_cvtbf8_ph(a_e5m2x32);
|
|
95
|
+
__m512h b_f16x32 = _mm512_cvtbf8_ph(b_e5m2x32);
|
|
96
|
+
sum_f32x16 = _mm512_dpph_ps(sum_f32x16, a_f16x32, b_f16x32);
|
|
97
|
+
if (count_scalars) goto nk_dot_e5m2_diamond_cycle;
|
|
98
|
+
|
|
99
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
NK_PUBLIC void nk_dot_f16_diamond(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
103
|
+
nk_f32_t *result) {
|
|
104
|
+
__m512h a_f16x32, b_f16x32;
|
|
105
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
106
|
+
|
|
107
|
+
nk_dot_f16_diamond_cycle:
|
|
108
|
+
if (count_scalars < 32) {
|
|
109
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
|
|
110
|
+
a_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a_scalars));
|
|
111
|
+
b_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b_scalars));
|
|
112
|
+
count_scalars = 0;
|
|
113
|
+
}
|
|
114
|
+
else {
|
|
115
|
+
a_f16x32 = _mm512_castsi512_ph(_mm512_loadu_epi16(a_scalars));
|
|
116
|
+
b_f16x32 = _mm512_castsi512_ph(_mm512_loadu_epi16(b_scalars));
|
|
117
|
+
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
118
|
+
}
|
|
119
|
+
sum_f32x16 = _mm512_dpph_ps(sum_f32x16, a_f16x32, b_f16x32);
|
|
120
|
+
if (count_scalars) goto nk_dot_f16_diamond_cycle;
|
|
121
|
+
|
|
122
|
+
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
typedef nk_dot_through_f32_state_skylake_t_ nk_dot_through_f16_state_diamond_t_;
|
|
126
|
+
|
|
127
|
+
NK_INTERNAL void nk_dot_through_f16_init_diamond_(nk_dot_through_f16_state_diamond_t_ *state) {
|
|
128
|
+
state->sum_f32x16 = _mm512_setzero();
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
/**
 *  @brief  Folds one 512-bit FP16 slice of `a` and `b` into the running FP32 accumulator.
 *
 *  Both operands are consumed through their `zmm_ph` view by a single `_mm512_dpph_ps`.
 *  `depth_offset` and `active_dimensions` are accepted for interface compatibility but unused.
 *  NOTE(review): presumably callers zero-pad partial slices so the unused lanes contribute
 *  nothing — confirm against the dots/ macros.
 */
NK_INTERNAL void nk_dot_through_f16_update_diamond_(nk_dot_through_f16_state_diamond_t_ *state, nk_b512_vec_t a,
                                                    nk_b512_vec_t b, nk_size_t depth_offset,
                                                    nk_size_t active_dimensions) {
    nk_unused_(active_dimensions);
    nk_unused_(depth_offset);
    __m512 const previous_f32x16 = state->sum_f32x16;
    state->sum_f32x16 = _mm512_dpph_ps(previous_f32x16, a.zmm_ph, b.zmm_ph);
}
|
|
138
|
+
|
|
139
|
+
NK_INTERNAL void nk_dot_through_f16_finalize_diamond_( //
|
|
140
|
+
nk_dot_through_f16_state_diamond_t_ const *state_a, nk_dot_through_f16_state_diamond_t_ const *state_b, //
|
|
141
|
+
nk_dot_through_f16_state_diamond_t_ const *state_c, nk_dot_through_f16_state_diamond_t_ const *state_d, //
|
|
142
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
143
|
+
nk_dot_through_f32_finalize_skylake_(state_a, state_b, state_c, state_d, total_dimensions, result);
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
#if defined(__clang__)
|
|
147
|
+
#pragma clang attribute pop
|
|
148
|
+
#elif defined(__GNUC__)
|
|
149
|
+
#pragma GCC pop_options
|
|
150
|
+
#endif
|
|
151
|
+
|
|
152
|
+
#if defined(__cplusplus)
|
|
153
|
+
} // extern "C"
|
|
154
|
+
#endif
|
|
155
|
+
|
|
156
|
+
#endif // NK_TARGET_DIAMOND
|
|
157
|
+
#endif // NK_TARGET_X86_
|
|
158
|
+
#endif // NK_DOT_DIAMOND_H
|
|
@@ -8,10 +8,10 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_genoa_instructions Key AVX-512 BF16 Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm512_dpbf16_ps
|
|
13
|
-
* _mm512_fmadd_ps
|
|
14
|
-
* _mm512_add_ps
|
|
11
|
+
* Intrinsic Instruction Genoa Alder Lake
|
|
12
|
+
* _mm512_dpbf16_ps VDPBF16PS (ZMM, ZMM, ZMM) 6cy @ p01 8cy @ p0+p0+p5+p5
|
|
13
|
+
* _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p01 4cy @ p0
|
|
14
|
+
* _mm512_add_ps VADDPS (ZMM, ZMM, ZMM) 4cy @ p01 3cy @ p05
|
|
15
15
|
*
|
|
16
16
|
* AMD Genoa introduces native AVX-512 BF16 support with VDPBF16PS, which computes two BF16 dot products
|
|
17
17
|
* per 32-bit lane (32 BF16 multiplies accumulated into 16 FP32 values per instruction). This provides
|
|
@@ -208,32 +208,6 @@ nk_vdot_bf16c_genoa_cycle:
|
|
|
208
208
|
result->imag = nk_reduce_add_f32x16_skylake_(sum_imag_f32x16);
|
|
209
209
|
}
|
|
210
210
|
|
|
211
|
-
NK_PUBLIC void nk_dot_e4m3_genoa(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
212
|
-
nk_f32_t *result) {
|
|
213
|
-
__m256i a_e4m3x32, b_e4m3x32;
|
|
214
|
-
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
215
|
-
|
|
216
|
-
nk_dot_e4m3_genoa_cycle:
|
|
217
|
-
if (count_scalars < 32) {
|
|
218
|
-
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count_scalars);
|
|
219
|
-
a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a_scalars);
|
|
220
|
-
b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b_scalars);
|
|
221
|
-
count_scalars = 0;
|
|
222
|
-
}
|
|
223
|
-
else {
|
|
224
|
-
a_e4m3x32 = _mm256_loadu_epi8(a_scalars);
|
|
225
|
-
b_e4m3x32 = _mm256_loadu_epi8(b_scalars);
|
|
226
|
-
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
227
|
-
}
|
|
228
|
-
// Convert E4M3 to BF16 and compute dot product
|
|
229
|
-
__m512i a_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(a_e4m3x32);
|
|
230
|
-
__m512i b_bf16x32 = nk_e4m3x32_to_bf16x32_icelake_(b_e4m3x32);
|
|
231
|
-
sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
232
|
-
if (count_scalars) goto nk_dot_e4m3_genoa_cycle;
|
|
233
|
-
|
|
234
|
-
*result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
235
|
-
}
|
|
236
|
-
|
|
237
211
|
NK_PUBLIC void nk_dot_e5m2_genoa(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
238
212
|
nk_f32_t *result) {
|
|
239
213
|
__m256i a_e5m2x32, b_e5m2x32;
|