numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
* @sa include/numkong/reduce.h
|
|
8
8
|
*
|
|
9
9
|
* Uses AVX-VNNI (256-bit) for efficient widening dot-products:
|
|
10
|
-
* - `_mm256_dpwssd_epi32`: i16
|
|
10
|
+
* - `_mm256_dpwssd_epi32`: i16 × i16 → i32 accumulation (AVXVNNI, used for i16 and e3m2)
|
|
11
11
|
* - `_mm256_sad_epu8` + `_mm256_madd_epi16`: pure AVX2 SAD/MADD (used for u8)
|
|
12
12
|
* - `_mm256_cvtepu16_epi32` + `_mm256_mullo_epi32`: pure AVX2 (used for u16)
|
|
13
13
|
*/
|
|
@@ -81,7 +81,7 @@ NK_INTERNAL void nk_reduce_moments_u8_alder_strided_( //
|
|
|
81
81
|
nk_size_t idx_scalars = 0;
|
|
82
82
|
nk_size_t total_scalars = count * stride_elements;
|
|
83
83
|
nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
|
|
84
|
-
for (; idx_scalars +
|
|
84
|
+
for (; idx_scalars + stride_elements + 31 <= total_scalars; idx_scalars += step) {
|
|
85
85
|
__m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
|
|
86
86
|
data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
|
|
87
87
|
sum_u64x4 = _mm256_add_epi64(sum_u64x4, _mm256_sad_epu8(data_u8x32, zero_u8x32));
|
|
@@ -171,7 +171,7 @@ NK_INTERNAL void nk_reduce_moments_i16_alder_strided_( //
|
|
|
171
171
|
nk_size_t idx_scalars = 0;
|
|
172
172
|
nk_size_t total_scalars = count * stride_elements;
|
|
173
173
|
nk_size_t step = nk_size_round_up_to_multiple_(16, stride_elements);
|
|
174
|
-
for (; idx_scalars +
|
|
174
|
+
for (; idx_scalars + stride_elements + 15 <= total_scalars; idx_scalars += step) {
|
|
175
175
|
__m256i data_i16x16 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
|
|
176
176
|
data_i16x16 = _mm256_and_si256(data_i16x16, stride_mask_i16x16);
|
|
177
177
|
__m256i sum_i32x8 = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), data_i16x16, ones_i16x16);
|
|
@@ -257,7 +257,7 @@ NK_INTERNAL void nk_reduce_moments_u16_alder_strided_( //
|
|
|
257
257
|
nk_size_t idx_scalars = 0;
|
|
258
258
|
nk_size_t total_scalars = count * stride_elements;
|
|
259
259
|
nk_size_t step = nk_size_round_up_to_multiple_(16, stride_elements);
|
|
260
|
-
for (; idx_scalars +
|
|
260
|
+
for (; idx_scalars + stride_elements + 15 <= total_scalars; idx_scalars += step) {
|
|
261
261
|
__m256i data_u16x16 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
|
|
262
262
|
data_u16x16 = _mm256_and_si256(data_u16x16, stride_mask_i16x16);
|
|
263
263
|
__m256i low_u32x8 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(data_u16x16));
|
|
@@ -313,19 +313,19 @@ NK_PUBLIC void nk_reduce_moments_u16_alder( //
|
|
|
313
313
|
/**
|
|
314
314
|
* @section e3m2 moments via integer VNNI (dpwssd)
|
|
315
315
|
*
|
|
316
|
-
* Every e3m2 value
|
|
316
|
+
* Every e3m2 value × 16 is an exact integer in [-448, +448] (i16 range).
|
|
317
317
|
* We use dual-VPSHUFB for the low byte + threshold compare for the high byte,
|
|
318
318
|
* then UNPACKLO/HI to form unsigned i16, apply sign via `_mm256_sign_epi16`,
|
|
319
|
-
* and accumulate with `_mm256_dpwssd_epi32` (signed i16
|
|
319
|
+
* and accumulate with `_mm256_dpwssd_epi32` (signed i16 × signed i16 → i32).
|
|
320
320
|
* Final: sum = i32_sum / 16, sumsq = i32_sumsq / 256.
|
|
321
321
|
*/
|
|
322
322
|
NK_INTERNAL void nk_reduce_moments_e3m2_alder_contiguous_( //
|
|
323
323
|
nk_e3m2_t const *data, nk_size_t count, //
|
|
324
324
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
325
|
-
__m256i const
|
|
325
|
+
__m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
|
|
326
326
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
327
327
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0); //
|
|
328
|
-
__m256i const
|
|
328
|
+
__m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
|
|
329
329
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
330
330
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32); //
|
|
331
331
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
@@ -344,25 +344,26 @@ NK_INTERNAL void nk_reduce_moments_e3m2_alder_contiguous_( //
|
|
|
344
344
|
__m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
|
|
345
345
|
__m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
|
|
346
346
|
half_select_u8x32);
|
|
347
|
-
__m256i
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
__m256i
|
|
351
|
-
|
|
352
|
-
__m256i
|
|
347
|
+
__m256i low_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, shuffle_idx_u8x32),
|
|
348
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, shuffle_idx_u8x32),
|
|
349
|
+
upper_select_u8x32);
|
|
350
|
+
__m256i high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(magnitude_u8x32, high_threshold_u8x32),
|
|
351
|
+
ones_u8x32);
|
|
352
|
+
__m256i unsigned_low_i16x16 = _mm256_unpacklo_epi8(low_bytes_u8x32, high_bytes_u8x32);
|
|
353
|
+
__m256i unsigned_high_i16x16 = _mm256_unpackhi_epi8(low_bytes_u8x32, high_bytes_u8x32);
|
|
353
354
|
// Sign handling: extract sign bit, widen to i16, create +1/-1, apply via VPSIGNW
|
|
354
355
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
355
|
-
__m256i
|
|
356
|
-
__m256i
|
|
357
|
-
__m256i
|
|
358
|
-
|
|
359
|
-
__m256i
|
|
360
|
-
|
|
361
|
-
// VNNI accumulation: dpwssd (signed i16
|
|
362
|
-
sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8,
|
|
363
|
-
sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8,
|
|
364
|
-
sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8,
|
|
365
|
-
sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8,
|
|
356
|
+
__m256i negate_low_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
357
|
+
__m256i negate_high_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
358
|
+
__m256i signed_low_i16x16 = _mm256_sign_epi16(unsigned_low_i16x16,
|
|
359
|
+
_mm256_or_si256(negate_low_i16x16, ones_i16x16));
|
|
360
|
+
__m256i signed_high_i16x16 = _mm256_sign_epi16(unsigned_high_i16x16,
|
|
361
|
+
_mm256_or_si256(negate_high_i16x16, ones_i16x16));
|
|
362
|
+
// VNNI accumulation: dpwssd (signed i16 × signed i16 → i32)
|
|
363
|
+
sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8, signed_low_i16x16, ones_i16x16);
|
|
364
|
+
sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8, signed_high_i16x16, ones_i16x16);
|
|
365
|
+
sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8, signed_low_i16x16, signed_low_i16x16);
|
|
366
|
+
sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8, signed_high_i16x16, signed_high_i16x16);
|
|
366
367
|
}
|
|
367
368
|
nk_i32_t sum = nk_reduce_add_i32x8_haswell_(sum_i32x8);
|
|
368
369
|
nk_i32_t sumsq = nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
|
|
@@ -375,25 +376,27 @@ NK_INTERNAL void nk_reduce_moments_e3m2_alder_contiguous_( //
|
|
|
375
376
|
__m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
|
|
376
377
|
__m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
|
|
377
378
|
half_select_u8x32);
|
|
378
|
-
__m256i
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
__m256i
|
|
382
|
-
|
|
383
|
-
__m256i
|
|
379
|
+
__m256i low_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, shuffle_idx_u8x32),
|
|
380
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, shuffle_idx_u8x32),
|
|
381
|
+
upper_select_u8x32);
|
|
382
|
+
__m256i high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(magnitude_u8x32, high_threshold_u8x32),
|
|
383
|
+
ones_u8x32);
|
|
384
|
+
__m256i unsigned_low_i16x16 = _mm256_unpacklo_epi8(low_bytes_u8x32, high_bytes_u8x32);
|
|
385
|
+
__m256i unsigned_high_i16x16 = _mm256_unpackhi_epi8(low_bytes_u8x32, high_bytes_u8x32);
|
|
384
386
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
385
|
-
__m256i
|
|
386
|
-
__m256i
|
|
387
|
-
__m256i
|
|
388
|
-
|
|
389
|
-
__m256i
|
|
390
|
-
|
|
391
|
-
__m256i
|
|
392
|
-
|
|
393
|
-
__m256i
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
387
|
+
__m256i negate_low_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
388
|
+
__m256i negate_high_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
389
|
+
__m256i signed_low_i16x16 = _mm256_sign_epi16(unsigned_low_i16x16,
|
|
390
|
+
_mm256_or_si256(negate_low_i16x16, ones_i16x16));
|
|
391
|
+
__m256i signed_high_i16x16 = _mm256_sign_epi16(unsigned_high_i16x16,
|
|
392
|
+
_mm256_or_si256(negate_high_i16x16, ones_i16x16));
|
|
393
|
+
__m256i tail_sum_i32x8 = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), signed_low_i16x16, ones_i16x16);
|
|
394
|
+
tail_sum_i32x8 = _mm256_dpwssd_avx_epi32(tail_sum_i32x8, signed_high_i16x16, ones_i16x16);
|
|
395
|
+
__m256i tail_sumsq_i32x8 = _mm256_dpwssd_avx_epi32(_mm256_setzero_si256(), signed_low_i16x16,
|
|
396
|
+
signed_low_i16x16);
|
|
397
|
+
tail_sumsq_i32x8 = _mm256_dpwssd_avx_epi32(tail_sumsq_i32x8, signed_high_i16x16, signed_high_i16x16);
|
|
398
|
+
sum += nk_reduce_add_i32x8_haswell_(tail_sum_i32x8);
|
|
399
|
+
sumsq += nk_reduce_add_i32x8_haswell_(tail_sumsq_i32x8);
|
|
397
400
|
}
|
|
398
401
|
*sum_ptr = (nk_f32_t)sum / 16.0f;
|
|
399
402
|
*sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
|
|
@@ -403,10 +406,10 @@ NK_INTERNAL void nk_reduce_moments_e3m2_alder_strided_( //
|
|
|
403
406
|
nk_e3m2_t const *data, nk_size_t count, nk_size_t stride_elements, //
|
|
404
407
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
405
408
|
__m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
|
|
406
|
-
__m256i const
|
|
409
|
+
__m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
|
|
407
410
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
408
411
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0); //
|
|
409
|
-
__m256i const
|
|
412
|
+
__m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
|
|
410
413
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
411
414
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32); //
|
|
412
415
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
@@ -421,30 +424,31 @@ NK_INTERNAL void nk_reduce_moments_e3m2_alder_strided_( //
|
|
|
421
424
|
nk_size_t idx_scalars = 0;
|
|
422
425
|
nk_size_t total_scalars = count * stride_elements;
|
|
423
426
|
nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
|
|
424
|
-
for (; idx_scalars +
|
|
427
|
+
for (; idx_scalars + stride_elements + 31 <= total_scalars; idx_scalars += step) {
|
|
425
428
|
__m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
|
|
426
429
|
data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
|
|
427
430
|
__m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
|
|
428
431
|
__m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
|
|
429
432
|
__m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
|
|
430
433
|
half_select_u8x32);
|
|
431
|
-
__m256i
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
__m256i
|
|
435
|
-
|
|
436
|
-
__m256i
|
|
434
|
+
__m256i low_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, shuffle_idx_u8x32),
|
|
435
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, shuffle_idx_u8x32),
|
|
436
|
+
upper_select_u8x32);
|
|
437
|
+
__m256i high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(magnitude_u8x32, high_threshold_u8x32),
|
|
438
|
+
ones_u8x32);
|
|
439
|
+
__m256i unsigned_low_i16x16 = _mm256_unpacklo_epi8(low_bytes_u8x32, high_bytes_u8x32);
|
|
440
|
+
__m256i unsigned_high_i16x16 = _mm256_unpackhi_epi8(low_bytes_u8x32, high_bytes_u8x32);
|
|
437
441
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
438
|
-
__m256i
|
|
439
|
-
__m256i
|
|
440
|
-
__m256i
|
|
441
|
-
|
|
442
|
-
__m256i
|
|
443
|
-
|
|
444
|
-
sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8,
|
|
445
|
-
sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8,
|
|
446
|
-
sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8,
|
|
447
|
-
sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8,
|
|
442
|
+
__m256i negate_low_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
443
|
+
__m256i negate_high_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
444
|
+
__m256i signed_low_i16x16 = _mm256_sign_epi16(unsigned_low_i16x16,
|
|
445
|
+
_mm256_or_si256(negate_low_i16x16, ones_i16x16));
|
|
446
|
+
__m256i signed_high_i16x16 = _mm256_sign_epi16(unsigned_high_i16x16,
|
|
447
|
+
_mm256_or_si256(negate_high_i16x16, ones_i16x16));
|
|
448
|
+
sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8, signed_low_i16x16, ones_i16x16);
|
|
449
|
+
sum_i32x8 = _mm256_dpwssd_avx_epi32(sum_i32x8, signed_high_i16x16, ones_i16x16);
|
|
450
|
+
sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8, signed_low_i16x16, signed_low_i16x16);
|
|
451
|
+
sumsq_i32x8 = _mm256_dpwssd_avx_epi32(sumsq_i32x8, signed_high_i16x16, signed_high_i16x16);
|
|
448
452
|
}
|
|
449
453
|
nk_i32_t sum = nk_reduce_add_i32x8_haswell_(sum_i32x8);
|
|
450
454
|
nk_i32_t sumsq = nk_reduce_add_i32x8_haswell_(sumsq_i32x8);
|
|
@@ -493,10 +497,10 @@ NK_PUBLIC void nk_reduce_moments_e3m2_alder( //
|
|
|
493
497
|
NK_INTERNAL void nk_reduce_moments_e2m3_alder_contiguous_( //
|
|
494
498
|
nk_e2m3_t const *data, nk_size_t count, //
|
|
495
499
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
496
|
-
__m256i const
|
|
497
|
-
|
|
498
|
-
__m256i const
|
|
499
|
-
|
|
500
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
501
|
+
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
502
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
503
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
500
504
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
501
505
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
502
506
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
@@ -512,8 +516,8 @@ NK_INTERNAL void nk_reduce_moments_e2m3_alder_contiguous_( //
|
|
|
512
516
|
__m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
|
|
513
517
|
__m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
|
|
514
518
|
half_select_u8x32);
|
|
515
|
-
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
516
|
-
_mm256_shuffle_epi8(
|
|
519
|
+
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, shuffle_idx_u8x32),
|
|
520
|
+
_mm256_shuffle_epi8(lut_high_u8x32, shuffle_idx_u8x32),
|
|
517
521
|
upper_select_u8x32);
|
|
518
522
|
// Sign vector: +1 for positive, -1 for negative (i8)
|
|
519
523
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
@@ -534,8 +538,8 @@ NK_INTERNAL void nk_reduce_moments_e2m3_alder_contiguous_( //
|
|
|
534
538
|
__m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
|
|
535
539
|
__m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
|
|
536
540
|
half_select_u8x32);
|
|
537
|
-
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
538
|
-
_mm256_shuffle_epi8(
|
|
541
|
+
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, shuffle_idx_u8x32),
|
|
542
|
+
_mm256_shuffle_epi8(lut_high_u8x32, shuffle_idx_u8x32),
|
|
539
543
|
upper_select_u8x32);
|
|
540
544
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
541
545
|
__m256i sign_i8x32 = _mm256_blendv_epi8(ones_i8x32, neg_ones_i8x32, negate_mask_u8x32);
|
|
@@ -552,10 +556,10 @@ NK_INTERNAL void nk_reduce_moments_e2m3_alder_strided_( //
|
|
|
552
556
|
nk_e2m3_t const *data, nk_size_t count, nk_size_t stride_elements, //
|
|
553
557
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
554
558
|
__m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
|
|
555
|
-
__m256i const
|
|
556
|
-
|
|
557
|
-
__m256i const
|
|
558
|
-
|
|
559
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
560
|
+
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
561
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
562
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
559
563
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
560
564
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
561
565
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
@@ -567,15 +571,15 @@ NK_INTERNAL void nk_reduce_moments_e2m3_alder_strided_( //
|
|
|
567
571
|
nk_size_t idx_scalars = 0;
|
|
568
572
|
nk_size_t total_scalars = count * stride_elements;
|
|
569
573
|
nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
|
|
570
|
-
for (; idx_scalars +
|
|
574
|
+
for (; idx_scalars + stride_elements + 31 <= total_scalars; idx_scalars += step) {
|
|
571
575
|
__m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
|
|
572
576
|
data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
|
|
573
577
|
__m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
|
|
574
578
|
__m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
|
|
575
579
|
__m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
|
|
576
580
|
half_select_u8x32);
|
|
577
|
-
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
578
|
-
_mm256_shuffle_epi8(
|
|
581
|
+
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, shuffle_idx_u8x32),
|
|
582
|
+
_mm256_shuffle_epi8(lut_high_u8x32, shuffle_idx_u8x32),
|
|
579
583
|
upper_select_u8x32);
|
|
580
584
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
581
585
|
__m256i sign_i8x32 = _mm256_blendv_epi8(ones_i8x32, neg_ones_i8x32, negate_mask_u8x32);
|