numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -140,22 +140,22 @@ NK_PUBLIC void nk_reduce_moments_u16_serial( //
|
|
|
140
140
|
NK_PUBLIC void nk_reduce_moments_i32_serial( //
|
|
141
141
|
nk_i32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
|
|
142
142
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
143
|
-
nk_u64_t
|
|
144
|
-
nk_i64_t
|
|
143
|
+
nk_u64_t sum_low = 0;
|
|
144
|
+
nk_i64_t sum_high = 0;
|
|
145
145
|
nk_u64_t sumsq = 0;
|
|
146
146
|
unsigned char const *ptr = (unsigned char const *)data;
|
|
147
147
|
for (nk_size_t i = 0; i < count; ++i, ptr += stride_bytes) {
|
|
148
148
|
nk_i64_t val = (nk_i64_t)(*(nk_i32_t const *)ptr);
|
|
149
149
|
nk_u64_t product = (nk_u64_t)(val * val);
|
|
150
|
-
nk_u64_t sum_before =
|
|
151
|
-
|
|
152
|
-
if (
|
|
153
|
-
|
|
150
|
+
nk_u64_t sum_before = sum_low;
|
|
151
|
+
sum_low += (nk_u64_t)val;
|
|
152
|
+
if (sum_low < sum_before) sum_high++;
|
|
153
|
+
sum_high += (val >> 63);
|
|
154
154
|
sumsq = nk_u64_saturating_add_serial(sumsq, product);
|
|
155
155
|
}
|
|
156
|
-
nk_i64_t
|
|
157
|
-
if (
|
|
158
|
-
else if (
|
|
156
|
+
nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
|
|
157
|
+
if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
|
|
158
|
+
else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
|
|
159
159
|
else *sum_ptr = NK_I64_MIN;
|
|
160
160
|
*sumsq_ptr = sumsq;
|
|
161
161
|
}
|
|
@@ -177,8 +177,8 @@ NK_PUBLIC void nk_reduce_moments_u32_serial( //
|
|
|
177
177
|
NK_PUBLIC void nk_reduce_moments_i64_serial( //
|
|
178
178
|
nk_i64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
|
|
179
179
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
180
|
-
nk_u64_t
|
|
181
|
-
nk_i64_t
|
|
180
|
+
nk_u64_t sum_low = 0;
|
|
181
|
+
nk_i64_t sum_high = 0;
|
|
182
182
|
nk_u64_t sumsq = 0;
|
|
183
183
|
unsigned char const *ptr = (unsigned char const *)data;
|
|
184
184
|
for (nk_size_t i = 0; i < count; ++i, ptr += stride_bytes) {
|
|
@@ -186,14 +186,14 @@ NK_PUBLIC void nk_reduce_moments_i64_serial( //
|
|
|
186
186
|
nk_i64_t product = nk_i64_saturating_mul_serial(val, val);
|
|
187
187
|
nk_u64_t unsigned_product = (nk_u64_t)product;
|
|
188
188
|
sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
|
|
189
|
-
nk_u64_t sum_before =
|
|
190
|
-
|
|
191
|
-
if (
|
|
192
|
-
|
|
193
|
-
}
|
|
194
|
-
nk_i64_t
|
|
195
|
-
if (
|
|
196
|
-
else if (
|
|
189
|
+
nk_u64_t sum_before = sum_low;
|
|
190
|
+
sum_low += (nk_u64_t)val;
|
|
191
|
+
if (sum_low < sum_before) sum_high++;
|
|
192
|
+
sum_high += (val >> 63);
|
|
193
|
+
}
|
|
194
|
+
nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
|
|
195
|
+
if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
|
|
196
|
+
else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
|
|
197
197
|
else *sum_ptr = NK_I64_MIN;
|
|
198
198
|
*sumsq_ptr = sumsq;
|
|
199
199
|
}
|
|
@@ -572,13 +572,11 @@ NK_PUBLIC void nk_reduce_minmax_f16_serial( //
|
|
|
572
572
|
nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
573
573
|
nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
574
574
|
unsigned char const *ptr = (unsigned char const *)data;
|
|
575
|
-
nk_f16_t min_value =
|
|
575
|
+
nk_f16_t min_value = NK_F16_MAX, max_value = NK_F16_MIN;
|
|
576
576
|
nk_size_t min_idx = NK_SIZE_MAX, max_idx = NK_SIZE_MAX;
|
|
577
577
|
for (nk_size_t i = 0; i < count; ++i, ptr += stride_bytes) {
|
|
578
578
|
nk_f16_t raw_value = *(nk_f16_t const *)ptr;
|
|
579
|
-
|
|
580
|
-
raw_fui.f = raw_value;
|
|
581
|
-
if (nk_f16_is_nan_(raw_fui.u)) continue;
|
|
579
|
+
if (nk_f16_is_nan_(raw_value)) continue;
|
|
582
580
|
if (min_idx == NK_SIZE_MAX || nk_f16_order_serial(raw_value, min_value) < 0) min_value = raw_value, min_idx = i;
|
|
583
581
|
if (max_idx == NK_SIZE_MAX || nk_f16_order_serial(raw_value, max_value) > 0) max_value = raw_value, max_idx = i;
|
|
584
582
|
}
|
|
@@ -591,13 +589,11 @@ NK_PUBLIC void nk_reduce_minmax_bf16_serial( //
|
|
|
591
589
|
nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
592
590
|
nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
593
591
|
unsigned char const *ptr = (unsigned char const *)data;
|
|
594
|
-
nk_bf16_t min_value =
|
|
592
|
+
nk_bf16_t min_value = NK_BF16_MAX, max_value = NK_BF16_MIN;
|
|
595
593
|
nk_size_t min_idx = NK_SIZE_MAX, max_idx = NK_SIZE_MAX;
|
|
596
594
|
for (nk_size_t i = 0; i < count; ++i, ptr += stride_bytes) {
|
|
597
595
|
nk_bf16_t raw_value = *(nk_bf16_t const *)ptr;
|
|
598
|
-
|
|
599
|
-
raw_fui.bf = raw_value;
|
|
600
|
-
if (nk_bf16_is_nan_(raw_fui.u)) continue;
|
|
596
|
+
if (nk_bf16_is_nan_(raw_value)) continue;
|
|
601
597
|
if (min_idx == NK_SIZE_MAX || nk_bf16_order_serial(raw_value, min_value) < 0)
|
|
602
598
|
min_value = raw_value, min_idx = i;
|
|
603
599
|
if (max_idx == NK_SIZE_MAX || nk_bf16_order_serial(raw_value, max_value) > 0)
|
|
@@ -7,8 +7,8 @@
|
|
|
7
7
|
* @sa include/numkong/reduce.h
|
|
8
8
|
*
|
|
9
9
|
* Uses AVX-VNNI-INT8 (256-bit) for efficient widening dot-products on i8, u8, and e2m3:
|
|
10
|
-
* - `_mm256_dpbssd_epi32`: i8
|
|
11
|
-
* - `_mm256_dpbuud_epi32`: u8
|
|
10
|
+
* - `_mm256_dpbssd_epi32`: i8 × i8 → i32 signed dot product (AVXVNNIINT8)
|
|
11
|
+
* - `_mm256_dpbuud_epi32`: u8 × u8 → u32 unsigned dot product (AVXVNNIINT8)
|
|
12
12
|
*/
|
|
13
13
|
#ifndef NK_REDUCE_SIERRA_H
|
|
14
14
|
#define NK_REDUCE_SIERRA_H
|
|
@@ -68,7 +68,7 @@ NK_INTERNAL void nk_reduce_moments_i8_sierra_strided_( //
|
|
|
68
68
|
nk_size_t idx_scalars = 0;
|
|
69
69
|
nk_size_t total_scalars = count * stride_elements;
|
|
70
70
|
nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
|
|
71
|
-
for (; idx_scalars +
|
|
71
|
+
for (; idx_scalars + stride_elements + 31 <= total_scalars; idx_scalars += step) {
|
|
72
72
|
__m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
|
|
73
73
|
data_i8x32 = _mm256_and_si256(data_i8x32, stride_mask_i8x32);
|
|
74
74
|
sum_i32x8 = _mm256_dpbssd_epi32(sum_i32x8, data_i8x32, ones_i8x32);
|
|
@@ -109,7 +109,7 @@ NK_PUBLIC void nk_reduce_moments_i8_sierra( //
|
|
|
109
109
|
}
|
|
110
110
|
|
|
111
111
|
/**
|
|
112
|
-
* @section u8 moments via VPDPBUUD (unsigned u8
|
|
112
|
+
* @section u8 moments via VPDPBUUD (unsigned u8 × u8 → u32)
|
|
113
113
|
*
|
|
114
114
|
* Sierra's `_mm256_dpbuud_epi32` provides native u8×u8→u32 dot product, replacing
|
|
115
115
|
* Haswell's 8-instruction SAD+widen+MADD sequence with 3 instructions per 32 elements.
|
|
@@ -153,7 +153,7 @@ NK_INTERNAL void nk_reduce_moments_u8_sierra_strided_( //
|
|
|
153
153
|
nk_size_t idx_scalars = 0;
|
|
154
154
|
nk_size_t total_scalars = count * stride_elements;
|
|
155
155
|
nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
|
|
156
|
-
for (; idx_scalars +
|
|
156
|
+
for (; idx_scalars + stride_elements + 31 <= total_scalars; idx_scalars += step) {
|
|
157
157
|
__m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
|
|
158
158
|
data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
|
|
159
159
|
sum_i32x8 = _mm256_dpbuud_epi32(sum_i32x8, data_u8x32, ones_u8x32);
|
|
@@ -203,10 +203,10 @@ NK_PUBLIC void nk_reduce_moments_u8_sierra( //
|
|
|
203
203
|
NK_INTERNAL void nk_reduce_moments_e2m3_sierra_contiguous_( //
|
|
204
204
|
nk_e2m3_t const *data, nk_size_t count, //
|
|
205
205
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
206
|
-
__m256i const
|
|
207
|
-
|
|
208
|
-
__m256i const
|
|
209
|
-
|
|
206
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
207
|
+
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
208
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
209
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
210
210
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
211
211
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
212
212
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
@@ -221,8 +221,8 @@ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_contiguous_( //
|
|
|
221
221
|
__m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
|
|
222
222
|
__m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
|
|
223
223
|
half_select_u8x32);
|
|
224
|
-
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
225
|
-
_mm256_shuffle_epi8(
|
|
224
|
+
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, shuffle_idx_u8x32),
|
|
225
|
+
_mm256_shuffle_epi8(lut_high_u8x32, shuffle_idx_u8x32),
|
|
226
226
|
upper_select_u8x32);
|
|
227
227
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
228
228
|
__m256i signed_i8x32 = _mm256_blendv_epi8(
|
|
@@ -241,8 +241,8 @@ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_contiguous_( //
|
|
|
241
241
|
__m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
|
|
242
242
|
__m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
|
|
243
243
|
half_select_u8x32);
|
|
244
|
-
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
245
|
-
_mm256_shuffle_epi8(
|
|
244
|
+
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, shuffle_idx_u8x32),
|
|
245
|
+
_mm256_shuffle_epi8(lut_high_u8x32, shuffle_idx_u8x32),
|
|
246
246
|
upper_select_u8x32);
|
|
247
247
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
248
248
|
__m256i signed_i8x32 = _mm256_blendv_epi8(
|
|
@@ -258,10 +258,10 @@ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_strided_( //
|
|
|
258
258
|
nk_e2m3_t const *data, nk_size_t count, nk_size_t stride_elements, //
|
|
259
259
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
260
260
|
__m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
|
|
261
|
-
__m256i const
|
|
262
|
-
|
|
263
|
-
__m256i const
|
|
264
|
-
|
|
261
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
262
|
+
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
263
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
264
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
265
265
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
266
266
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
267
267
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
@@ -272,15 +272,15 @@ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_strided_( //
|
|
|
272
272
|
nk_size_t idx_scalars = 0;
|
|
273
273
|
nk_size_t total_scalars = count * stride_elements;
|
|
274
274
|
nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
|
|
275
|
-
for (; idx_scalars +
|
|
275
|
+
for (; idx_scalars + stride_elements + 31 <= total_scalars; idx_scalars += step) {
|
|
276
276
|
__m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
|
|
277
277
|
data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
|
|
278
278
|
__m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
|
|
279
279
|
__m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
|
|
280
280
|
__m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
|
|
281
281
|
half_select_u8x32);
|
|
282
|
-
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
283
|
-
_mm256_shuffle_epi8(
|
|
282
|
+
__m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, shuffle_idx_u8x32),
|
|
283
|
+
_mm256_shuffle_epi8(lut_high_u8x32, shuffle_idx_u8x32),
|
|
284
284
|
upper_select_u8x32);
|
|
285
285
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
286
286
|
__m256i signed_i8x32 = _mm256_blendv_epi8(
|