numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -620,8 +620,8 @@ NK_INTERNAL void nk_reduce_minmax_f16_v128relaxed_contiguous_( //
|
|
|
620
620
|
if (val > max_value_f32) max_value_f32 = val, max_idx = idx;
|
|
621
621
|
}
|
|
622
622
|
if (min_value_f32 == NK_F32_MAX && max_value_f32 == NK_F32_MIN) {
|
|
623
|
-
*min_value_ptr =
|
|
624
|
-
*
|
|
623
|
+
*min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
|
|
624
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
625
625
|
return;
|
|
626
626
|
}
|
|
627
627
|
*min_value_ptr = data[min_idx], *min_index_ptr = min_idx;
|
|
@@ -635,8 +635,8 @@ NK_PUBLIC void nk_reduce_minmax_f16_v128relaxed( //
|
|
|
635
635
|
nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
|
|
636
636
|
int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
|
|
637
637
|
if (count == 0)
|
|
638
|
-
*min_value_ptr =
|
|
639
|
-
*
|
|
638
|
+
*min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
|
|
639
|
+
*max_index_ptr = NK_SIZE_MAX;
|
|
640
640
|
else if (!aligned)
|
|
641
641
|
nk_reduce_minmax_f16_serial(data, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
642
642
|
max_index_ptr);
|
|
@@ -856,8 +856,8 @@ NK_PUBLIC void nk_reduce_moments_u16_v128relaxed( //
|
|
|
856
856
|
NK_INTERNAL void nk_reduce_moments_i32_v128relaxed_contiguous_( //
|
|
857
857
|
nk_i32_t const *data, nk_size_t count, //
|
|
858
858
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
859
|
-
v128_t
|
|
860
|
-
v128_t
|
|
859
|
+
v128_t sum_low_u64x2 = wasm_i64x2_splat(0);
|
|
860
|
+
v128_t sum_high_i64x2 = wasm_i64x2_splat(0);
|
|
861
861
|
v128_t sumsq_u64x2 = wasm_i64x2_splat(0);
|
|
862
862
|
v128_t sumsq_overflow_u64x2 = wasm_i64x2_splat(0);
|
|
863
863
|
v128_t sign_bit_i64x2 = wasm_i64x2_splat((nk_i64_t)0x8000000000000000LL);
|
|
@@ -865,21 +865,21 @@ NK_INTERNAL void nk_reduce_moments_i32_v128relaxed_contiguous_( //
|
|
|
865
865
|
for (; idx + 4 <= count; idx += 4) {
|
|
866
866
|
v128_t data_i32x4 = wasm_v128_load(data + idx);
|
|
867
867
|
v128_t data_low_i64x2 = wasm_i64x2_extend_low_i32x4(data_i32x4);
|
|
868
|
-
v128_t before_u64x2 =
|
|
869
|
-
|
|
870
|
-
v128_t result_biased_i64x2 = wasm_v128_xor(
|
|
868
|
+
v128_t before_u64x2 = sum_low_u64x2;
|
|
869
|
+
sum_low_u64x2 = wasm_i64x2_add(sum_low_u64x2, data_low_i64x2);
|
|
870
|
+
v128_t result_biased_i64x2 = wasm_v128_xor(sum_low_u64x2, sign_bit_i64x2);
|
|
871
871
|
v128_t before_biased_i64x2 = wasm_v128_xor(before_u64x2, sign_bit_i64x2);
|
|
872
872
|
v128_t carry_u64x2 = wasm_i64x2_gt(before_biased_i64x2, result_biased_i64x2);
|
|
873
|
-
|
|
874
|
-
|
|
873
|
+
sum_high_i64x2 = wasm_i64x2_sub(sum_high_i64x2, carry_u64x2);
|
|
874
|
+
sum_high_i64x2 = wasm_i64x2_add(sum_high_i64x2, wasm_i64x2_shr(data_low_i64x2, 63));
|
|
875
875
|
v128_t data_high_i64x2 = wasm_i64x2_extend_high_i32x4(data_i32x4);
|
|
876
|
-
before_u64x2 =
|
|
877
|
-
|
|
878
|
-
result_biased_i64x2 = wasm_v128_xor(
|
|
876
|
+
before_u64x2 = sum_low_u64x2;
|
|
877
|
+
sum_low_u64x2 = wasm_i64x2_add(sum_low_u64x2, data_high_i64x2);
|
|
878
|
+
result_biased_i64x2 = wasm_v128_xor(sum_low_u64x2, sign_bit_i64x2);
|
|
879
879
|
before_biased_i64x2 = wasm_v128_xor(before_u64x2, sign_bit_i64x2);
|
|
880
880
|
carry_u64x2 = wasm_i64x2_gt(before_biased_i64x2, result_biased_i64x2);
|
|
881
|
-
|
|
882
|
-
|
|
881
|
+
sum_high_i64x2 = wasm_i64x2_sub(sum_high_i64x2, carry_u64x2);
|
|
882
|
+
sum_high_i64x2 = wasm_i64x2_add(sum_high_i64x2, wasm_i64x2_shr(data_high_i64x2, 63));
|
|
883
883
|
v128_t sq_low_i64x2 = wasm_i64x2_extmul_low_i32x4(data_i32x4, data_i32x4);
|
|
884
884
|
v128_t sq_high_i64x2 = wasm_i64x2_extmul_high_i32x4(data_i32x4, data_i32x4);
|
|
885
885
|
v128_t sq_before_u64x2 = sumsq_u64x2;
|
|
@@ -897,26 +897,26 @@ NK_INTERNAL void nk_reduce_moments_i32_v128relaxed_contiguous_( //
|
|
|
897
897
|
wasm_i64x2_extract_lane(sumsq_overflow_u64x2, 1));
|
|
898
898
|
nk_u64_t sumsq = sumsq_overflow ? NK_U64_MAX : nk_reduce_sadd_u64x2_v128relaxed_(sumsq_u64x2);
|
|
899
899
|
nk_b128_vec_t lower_vec, upper_vec;
|
|
900
|
-
lower_vec.v128 =
|
|
901
|
-
upper_vec.v128 =
|
|
902
|
-
nk_u64_t
|
|
903
|
-
nk_i64_t
|
|
904
|
-
nk_u64_t sum_before =
|
|
905
|
-
|
|
906
|
-
sum_before =
|
|
907
|
-
|
|
900
|
+
lower_vec.v128 = sum_low_u64x2;
|
|
901
|
+
upper_vec.v128 = sum_high_i64x2;
|
|
902
|
+
nk_u64_t sum_low = 0;
|
|
903
|
+
nk_i64_t sum_high = 0;
|
|
904
|
+
nk_u64_t sum_before = sum_low;
|
|
905
|
+
sum_low += lower_vec.u64s[0], sum_high += (sum_low < sum_before) + upper_vec.i64s[0];
|
|
906
|
+
sum_before = sum_low;
|
|
907
|
+
sum_low += lower_vec.u64s[1], sum_high += (sum_low < sum_before) + upper_vec.i64s[1];
|
|
908
908
|
for (; idx < count; ++idx) {
|
|
909
909
|
nk_i64_t val = (nk_i64_t)data[idx];
|
|
910
|
-
sum_before =
|
|
911
|
-
|
|
912
|
-
if (
|
|
913
|
-
|
|
910
|
+
sum_before = sum_low;
|
|
911
|
+
sum_low += (nk_u64_t)val;
|
|
912
|
+
if (sum_low < sum_before) sum_high++;
|
|
913
|
+
sum_high += (val >> 63);
|
|
914
914
|
nk_u64_t product = (nk_u64_t)(val * val);
|
|
915
915
|
sumsq = nk_u64_saturating_add_serial(sumsq, product);
|
|
916
916
|
}
|
|
917
|
-
nk_i64_t
|
|
918
|
-
if (
|
|
919
|
-
else if (
|
|
917
|
+
nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
|
|
918
|
+
if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
|
|
919
|
+
else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
|
|
920
920
|
else *sum_ptr = NK_I64_MIN;
|
|
921
921
|
*sumsq_ptr = sumsq;
|
|
922
922
|
}
|
|
@@ -981,8 +981,8 @@ NK_PUBLIC void nk_reduce_moments_u32_v128relaxed( //
|
|
|
981
981
|
NK_INTERNAL void nk_reduce_moments_i64_v128relaxed_contiguous_( //
|
|
982
982
|
nk_i64_t const *data, nk_size_t count, //
|
|
983
983
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
984
|
-
v128_t
|
|
985
|
-
v128_t
|
|
984
|
+
v128_t sum_low_u64x2 = wasm_i64x2_splat(0);
|
|
985
|
+
v128_t sum_high_i64x2 = wasm_i64x2_splat(0);
|
|
986
986
|
v128_t sumsq_u64x2 = wasm_i64x2_splat(0);
|
|
987
987
|
v128_t sumsq_overflow_u64x2 = wasm_i64x2_splat(0);
|
|
988
988
|
v128_t sign_bit_i64x2 = wasm_i64x2_splat((nk_i64_t)0x8000000000000000LL);
|
|
@@ -995,36 +995,36 @@ NK_INTERNAL void nk_reduce_moments_i64_v128relaxed_contiguous_( //
|
|
|
995
995
|
sumsq_overflow_u64x2 = wasm_v128_or(
|
|
996
996
|
sumsq_overflow_u64x2,
|
|
997
997
|
wasm_i64x2_gt(wasm_v128_xor(sq_before_u64x2, sign_bit_i64x2), wasm_v128_xor(sumsq_u64x2, sign_bit_i64x2)));
|
|
998
|
-
v128_t before_u64x2 =
|
|
999
|
-
|
|
998
|
+
v128_t before_u64x2 = sum_low_u64x2;
|
|
999
|
+
sum_low_u64x2 = wasm_i64x2_add(sum_low_u64x2, data_i64x2);
|
|
1000
1000
|
v128_t carry_u64x2 = wasm_i64x2_gt(wasm_v128_xor(before_u64x2, sign_bit_i64x2),
|
|
1001
|
-
wasm_v128_xor(
|
|
1002
|
-
|
|
1003
|
-
|
|
1001
|
+
wasm_v128_xor(sum_low_u64x2, sign_bit_i64x2));
|
|
1002
|
+
sum_high_i64x2 = wasm_i64x2_sub(sum_high_i64x2, carry_u64x2);
|
|
1003
|
+
sum_high_i64x2 = wasm_i64x2_add(sum_high_i64x2, wasm_i64x2_shr(data_i64x2, 63));
|
|
1004
1004
|
}
|
|
1005
1005
|
int sumsq_overflow = (int)(wasm_i64x2_extract_lane(sumsq_overflow_u64x2, 0) |
|
|
1006
1006
|
wasm_i64x2_extract_lane(sumsq_overflow_u64x2, 1));
|
|
1007
1007
|
nk_u64_t sumsq = sumsq_overflow ? NK_U64_MAX : nk_reduce_sadd_u64x2_v128relaxed_(sumsq_u64x2);
|
|
1008
|
-
nk_u64_t
|
|
1009
|
-
nk_i64_t
|
|
1008
|
+
nk_u64_t sum_low = (nk_u64_t)wasm_i64x2_extract_lane(sum_low_u64x2, 0);
|
|
1009
|
+
nk_i64_t sum_high = wasm_i64x2_extract_lane(sum_high_i64x2, 0);
|
|
1010
1010
|
{
|
|
1011
|
-
nk_u64_t sum_before =
|
|
1012
|
-
|
|
1013
|
-
if (
|
|
1014
|
-
|
|
1011
|
+
nk_u64_t sum_before = sum_low;
|
|
1012
|
+
sum_low += (nk_u64_t)wasm_i64x2_extract_lane(sum_low_u64x2, 1);
|
|
1013
|
+
if (sum_low < sum_before) sum_high++;
|
|
1014
|
+
sum_high += wasm_i64x2_extract_lane(sum_high_i64x2, 1);
|
|
1015
1015
|
}
|
|
1016
1016
|
for (; idx < count; ++idx) {
|
|
1017
1017
|
nk_i64_t val = data[idx];
|
|
1018
1018
|
nk_u64_t unsigned_product = (nk_u64_t)nk_i64_saturating_mul_serial(val, val);
|
|
1019
1019
|
sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
|
|
1020
|
-
nk_u64_t sum_before =
|
|
1021
|
-
|
|
1022
|
-
if (
|
|
1023
|
-
|
|
1024
|
-
}
|
|
1025
|
-
nk_i64_t
|
|
1026
|
-
if (
|
|
1027
|
-
else if (
|
|
1020
|
+
nk_u64_t sum_before = sum_low;
|
|
1021
|
+
sum_low += (nk_u64_t)val;
|
|
1022
|
+
if (sum_low < sum_before) sum_high++;
|
|
1023
|
+
sum_high += (val >> 63);
|
|
1024
|
+
}
|
|
1025
|
+
nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
|
|
1026
|
+
if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
|
|
1027
|
+
else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
|
|
1028
1028
|
else *sum_ptr = NK_I64_MIN;
|
|
1029
1029
|
*sumsq_ptr = sumsq;
|
|
1030
1030
|
}
|
package/include/numkong/reduce.h
CHANGED
|
@@ -446,19 +446,13 @@ NK_PUBLIC void nk_reduce_minmax_e4m3_neon(nk_e4m3_t const *, nk_size_t, nk_size_
|
|
|
446
446
|
/** @copydoc nk_reduce_minmax_f64 */
|
|
447
447
|
NK_PUBLIC void nk_reduce_minmax_e5m2_neon(nk_e5m2_t const *, nk_size_t, nk_size_t, nk_e5m2_t *, nk_size_t *,
|
|
448
448
|
nk_e5m2_t *, nk_size_t *);
|
|
449
|
-
#endif // NK_TARGET_NEON
|
|
450
|
-
|
|
451
|
-
#if NK_TARGET_NEONHALF
|
|
452
449
|
/** @copydoc nk_reduce_moments_f64 */
|
|
453
|
-
NK_PUBLIC void
|
|
454
|
-
#endif //
|
|
450
|
+
NK_PUBLIC void nk_reduce_moments_f16_neon(nk_f16_t const *, nk_size_t, nk_size_t, nk_f32_t *, nk_f32_t *);
|
|
451
|
+
#endif // NK_TARGET_NEON
|
|
455
452
|
|
|
456
453
|
#if NK_TARGET_NEONBFDOT
|
|
457
454
|
/** @copydoc nk_reduce_moments_f64 */
|
|
458
455
|
NK_PUBLIC void nk_reduce_moments_bf16_neonbfdot(nk_bf16_t const *, nk_size_t, nk_size_t, nk_f32_t *, nk_f32_t *);
|
|
459
|
-
/** @copydoc nk_reduce_minmax_f64 */
|
|
460
|
-
NK_PUBLIC void nk_reduce_minmax_bf16_neonbfdot(nk_bf16_t const *, nk_size_t, nk_size_t, nk_bf16_t *, nk_size_t *,
|
|
461
|
-
nk_bf16_t *, nk_size_t *);
|
|
462
456
|
#endif // NK_TARGET_NEONBFDOT
|
|
463
457
|
|
|
464
458
|
#if NK_TARGET_NEONSDOT
|
|
@@ -475,12 +469,6 @@ NK_PUBLIC void nk_reduce_moments_e2m3_neonsdot(nk_e2m3_t const *, nk_size_t, nk_
|
|
|
475
469
|
NK_PUBLIC void nk_reduce_moments_e4m3_neonfhm(nk_e4m3_t const *, nk_size_t, nk_size_t, nk_f32_t *, nk_f32_t *);
|
|
476
470
|
/** @copydoc nk_reduce_moments_f64 */
|
|
477
471
|
NK_PUBLIC void nk_reduce_moments_e5m2_neonfhm(nk_e5m2_t const *, nk_size_t, nk_size_t, nk_f32_t *, nk_f32_t *);
|
|
478
|
-
/** @copydoc nk_reduce_minmax_f64 */
|
|
479
|
-
NK_PUBLIC void nk_reduce_minmax_e4m3_neonfhm(nk_e4m3_t const *, nk_size_t, nk_size_t, nk_e4m3_t *, nk_size_t *,
|
|
480
|
-
nk_e4m3_t *, nk_size_t *);
|
|
481
|
-
/** @copydoc nk_reduce_minmax_f64 */
|
|
482
|
-
NK_PUBLIC void nk_reduce_minmax_e5m2_neonfhm(nk_e5m2_t const *, nk_size_t, nk_size_t, nk_e5m2_t *, nk_size_t *,
|
|
483
|
-
nk_e5m2_t *, nk_size_t *);
|
|
484
472
|
#endif // NK_TARGET_NEONFHM
|
|
485
473
|
|
|
486
474
|
#if NK_TARGET_HASWELL
|
|
@@ -950,7 +938,6 @@ NK_INTERNAL nk_dtype_t nk_reduce_minmax_value_dtype(nk_dtype_t dtype) {
|
|
|
950
938
|
|
|
951
939
|
#include "numkong/reduce/serial.h"
|
|
952
940
|
#include "numkong/reduce/neon.h"
|
|
953
|
-
#include "numkong/reduce/neonhalf.h"
|
|
954
941
|
#include "numkong/reduce/neonbfdot.h"
|
|
955
942
|
#include "numkong/reduce/neonsdot.h"
|
|
956
943
|
#include "numkong/reduce/neonfhm.h"
|
|
@@ -1324,8 +1311,8 @@ NK_PUBLIC void nk_reduce_moments_f16(nk_f16_t const *d, nk_size_t n, nk_size_t s
|
|
|
1324
1311
|
nk_reduce_moments_f16_skylake(d, n, s, sum, sumsq);
|
|
1325
1312
|
#elif NK_TARGET_HASWELL
|
|
1326
1313
|
nk_reduce_moments_f16_haswell(d, n, s, sum, sumsq);
|
|
1327
|
-
#elif
|
|
1328
|
-
|
|
1314
|
+
#elif NK_TARGET_NEON
|
|
1315
|
+
nk_reduce_moments_f16_neon(d, n, s, sum, sumsq);
|
|
1329
1316
|
#elif NK_TARGET_RVV
|
|
1330
1317
|
nk_reduce_moments_f16_rvv(d, n, s, sum, sumsq);
|
|
1331
1318
|
#elif NK_TARGET_V128RELAXED
|
|
@@ -1376,8 +1363,6 @@ NK_PUBLIC void nk_reduce_minmax_bf16(nk_bf16_t const *d, nk_size_t n, nk_size_t
|
|
|
1376
1363
|
nk_reduce_minmax_bf16_skylake(d, n, s, mn, mi, mx, xi);
|
|
1377
1364
|
#elif NK_TARGET_HASWELL
|
|
1378
1365
|
nk_reduce_minmax_bf16_haswell(d, n, s, mn, mi, mx, xi);
|
|
1379
|
-
#elif NK_TARGET_NEONBFDOT
|
|
1380
|
-
nk_reduce_minmax_bf16_neonbfdot(d, n, s, mn, mi, mx, xi);
|
|
1381
1366
|
#elif NK_TARGET_RVV
|
|
1382
1367
|
nk_reduce_minmax_bf16_rvv(d, n, s, mn, mi, mx, xi);
|
|
1383
1368
|
#elif NK_TARGET_V128RELAXED
|
|
@@ -1413,8 +1398,6 @@ NK_PUBLIC void nk_reduce_minmax_e4m3(nk_e4m3_t const *d, nk_size_t n, nk_size_t
|
|
|
1413
1398
|
nk_reduce_minmax_e4m3_skylake(d, n, s, mn, mi, mx, xi);
|
|
1414
1399
|
#elif NK_TARGET_HASWELL
|
|
1415
1400
|
nk_reduce_minmax_e4m3_haswell(d, n, s, mn, mi, mx, xi);
|
|
1416
|
-
#elif NK_TARGET_NEONFHM
|
|
1417
|
-
nk_reduce_minmax_e4m3_neonfhm(d, n, s, mn, mi, mx, xi);
|
|
1418
1401
|
#elif NK_TARGET_NEON
|
|
1419
1402
|
nk_reduce_minmax_e4m3_neon(d, n, s, mn, mi, mx, xi);
|
|
1420
1403
|
#elif NK_TARGET_RVV
|
|
@@ -1452,8 +1435,6 @@ NK_PUBLIC void nk_reduce_minmax_e5m2(nk_e5m2_t const *d, nk_size_t n, nk_size_t
|
|
|
1452
1435
|
nk_reduce_minmax_e5m2_skylake(d, n, s, mn, mi, mx, xi);
|
|
1453
1436
|
#elif NK_TARGET_HASWELL
|
|
1454
1437
|
nk_reduce_minmax_e5m2_haswell(d, n, s, mn, mi, mx, xi);
|
|
1455
|
-
#elif NK_TARGET_NEONFHM
|
|
1456
|
-
nk_reduce_minmax_e5m2_neonfhm(d, n, s, mn, mi, mx, xi);
|
|
1457
1438
|
#elif NK_TARGET_NEON
|
|
1458
1439
|
nk_reduce_minmax_e5m2_neon(d, n, s, mn, mi, mx, xi);
|
|
1459
1440
|
#elif NK_TARGET_RVV
|
|
@@ -192,13 +192,95 @@ void reduce_minmax(in_type_ const *data, std::size_t count, std::size_t stride_b
|
|
|
192
192
|
if (max_index) *max_index = static_cast<std::size_t>(max_offset);
|
|
193
193
|
}
|
|
194
194
|
|
|
195
|
+
/** @brief Compute sum and sum-of-squares over a vector view. */
|
|
196
|
+
template <numeric_dtype in_type_, numeric_dtype sum_type_ = typename in_type_::reduce_moments_sum_t,
|
|
197
|
+
numeric_dtype sumsq_type_ = typename in_type_::reduce_moments_sumsq_t,
|
|
198
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
199
|
+
void reduce_moments(vector_view<in_type_> input, sum_type_ *sum, sumsq_type_ *sumsq) noexcept {
|
|
200
|
+
reduce_moments<in_type_, sum_type_, sumsq_type_, allow_simd_>(
|
|
201
|
+
input.data(), input.size(), static_cast<std::size_t>(input.stride_bytes()), sum, sumsq);
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
/** @brief Find minimum and maximum elements with their indices over a vector view. */
|
|
205
|
+
template <numeric_dtype in_type_, numeric_dtype minmax_type_ = typename in_type_::reduce_minmax_value_t,
|
|
206
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
207
|
+
void reduce_minmax(vector_view<in_type_> input, minmax_type_ *min_value, std::size_t *min_index,
|
|
208
|
+
minmax_type_ *max_value, std::size_t *max_index) noexcept {
|
|
209
|
+
reduce_minmax<in_type_, minmax_type_, allow_simd_>(input.data(), input.size(),
|
|
210
|
+
static_cast<std::size_t>(input.stride_bytes()), min_value,
|
|
211
|
+
min_index, max_value, max_index);
|
|
212
|
+
}
|
|
213
|
+
|
|
195
214
|
} // namespace ashvardanian::numkong
|
|
196
215
|
|
|
197
216
|
#include "numkong/tensor.hpp"
|
|
198
217
|
|
|
199
218
|
namespace ashvardanian::numkong {
|
|
200
219
|
|
|
201
|
-
#pragma region
|
|
220
|
+
#pragma region Tensor Reduction Helpers
|
|
221
|
+
|
|
222
|
+
/** @brief Result of detecting how many trailing dimensions form a single arithmetic progression. */
|
|
223
|
+
struct uniform_stride_tail_result_t_ {
|
|
224
|
+
std::size_t tail_dims; ///< Number of collapsible trailing dimensions.
|
|
225
|
+
std::size_t element_count; ///< Product of collapsed extents.
|
|
226
|
+
std::size_t stride_bytes; ///< Absolute stride of the innermost collapsed dimension.
|
|
227
|
+
};
|
|
228
|
+
|
|
229
|
+
/** @brief Detect trailing dimensions where stride[i] == stride[i+1] * extent[i+1].
|
|
230
|
+
* When this holds, the tail is a single strided sequence and can be passed to a SIMD
|
|
231
|
+
* kernel in one call with (element_count, stride_bytes). */
|
|
232
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
233
|
+
uniform_stride_tail_result_t_ uniform_stride_tail_(tensor_view<value_type_, max_rank_> input) noexcept {
|
|
234
|
+
if constexpr (dimensions_per_value<value_type_>() > 1) return {0, 0, 0};
|
|
235
|
+
auto rank = input.rank();
|
|
236
|
+
if (rank == 0) return {0, 1, sizeof(value_type_)};
|
|
237
|
+
std::size_t tail = 1;
|
|
238
|
+
auto innermost_stride = input.stride_bytes(rank - 1);
|
|
239
|
+
auto expected_stride = innermost_stride;
|
|
240
|
+
for (std::size_t i = rank - 1; i > 0; --i) {
|
|
241
|
+
expected_stride *= static_cast<std::ptrdiff_t>(input.extent(i));
|
|
242
|
+
if (input.stride_bytes(i - 1) != expected_stride) break;
|
|
243
|
+
++tail;
|
|
244
|
+
}
|
|
245
|
+
std::size_t count = 1;
|
|
246
|
+
for (std::size_t i = rank - tail; i < rank; ++i) count *= input.extent(i);
|
|
247
|
+
return {tail, count, static_cast<std::size_t>(innermost_stride < 0 ? -innermost_stride : innermost_stride)};
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
/** @brief Collapse the trailing `tail.tail_dims` dimensions into one, preserving outer dims and strides. */
|
|
251
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
252
|
+
tensor_view<value_type_, max_rank_> collapse_uniform_tail_(tensor_view<value_type_, max_rank_> input,
|
|
253
|
+
uniform_stride_tail_result_t_ const &tail) noexcept {
|
|
254
|
+
shape_storage_<max_rank_> s;
|
|
255
|
+
s.rank = input.rank() - tail.tail_dims + 1;
|
|
256
|
+
for (std::size_t i = 0; i + tail.tail_dims < input.rank(); ++i) {
|
|
257
|
+
s.extents[i] = input.extent(i);
|
|
258
|
+
s.strides[i] = input.stride_bytes(i);
|
|
259
|
+
}
|
|
260
|
+
s.extents[s.rank - 1] = tail.element_count;
|
|
261
|
+
s.strides[s.rank - 1] = input.stride_bytes(input.rank() - 1);
|
|
262
|
+
return {input.byte_data(), s};
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
/** @brief Normalize a fully-collapsed tail for SIMD kernel consumption, handling negative strides. */
|
|
266
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
267
|
+
normalized_rank1_lane_<value_type_, max_rank_> normalize_rank1_lane_from_tail_(
|
|
268
|
+
tensor_view<value_type_, max_rank_> input, uniform_stride_tail_result_t_ const &tail) noexcept {
|
|
269
|
+
normalized_rank1_lane_<value_type_, max_rank_> lane;
|
|
270
|
+
lane.count = tail.element_count;
|
|
271
|
+
lane.stride_bytes = tail.stride_bytes;
|
|
272
|
+
auto innermost_stride = input.stride_bytes(input.rank() - 1);
|
|
273
|
+
if (innermost_stride >= 0) {
|
|
274
|
+
lane.data = input.data();
|
|
275
|
+
lane.reversed = false;
|
|
276
|
+
}
|
|
277
|
+
else {
|
|
278
|
+
lane.data = reinterpret_cast<value_type_ const *>(
|
|
279
|
+
input.byte_data() + static_cast<std::ptrdiff_t>(lane.count - 1) * innermost_stride);
|
|
280
|
+
lane.reversed = true;
|
|
281
|
+
}
|
|
282
|
+
return lane;
|
|
283
|
+
}
|
|
202
284
|
|
|
203
285
|
template <numeric_dtype value_type_, std::size_t max_rank_>
|
|
204
286
|
bool reduce_rank1_moments_(tensor_view<value_type_, max_rank_> input, typename value_type_::reduce_moments_sum_t &sum,
|
|
@@ -391,9 +473,9 @@ bool reduce_minmax_axis_packed_(tensor_view<value_type_, max_rank_> input, std::
|
|
|
391
473
|
return true;
|
|
392
474
|
}
|
|
393
475
|
|
|
394
|
-
#pragma endregion
|
|
476
|
+
#pragma endregion Tensor Reduction Helpers
|
|
395
477
|
|
|
396
|
-
#pragma region
|
|
478
|
+
#pragma region Scalar Reductions
|
|
397
479
|
|
|
398
480
|
/** @brief Compute Σxᵢ and Σxᵢ² in a single pass. Returns zeroed result for empty tensors. */
|
|
399
481
|
template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
|
|
@@ -403,11 +485,14 @@ moments_result<typename value_type_::reduce_moments_sum_t, typename value_type_:
|
|
|
403
485
|
using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
|
|
404
486
|
moments_result<sum_t, sumsq_t> result {};
|
|
405
487
|
if (input.empty() || input.numel() == 0 || !tensor_layout_supported_(input)) return result;
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
488
|
+
auto tail = uniform_stride_tail_(input);
|
|
489
|
+
if (tail.tail_dims == input.rank()) {
|
|
490
|
+
auto lane = normalize_rank1_lane_from_tail_<value_type_, max_rank_>(input, tail);
|
|
491
|
+
numkong::reduce_moments<value_type_>(lane.data, lane.count, lane.stride_bytes, &result.sum, &result.sumsq);
|
|
409
492
|
return result;
|
|
410
493
|
}
|
|
494
|
+
if (tail.tail_dims >= 2) return moments<value_type_, max_rank_>(collapse_uniform_tail_(input, tail));
|
|
495
|
+
// Sub-byte rank-1 fallback: uniform_stride_tail_ returns {0,0,0} for packed types.
|
|
411
496
|
if (input.rank() == 1) {
|
|
412
497
|
reduce_rank1_moments_(input, result.sum, result.sumsq);
|
|
413
498
|
return result;
|
|
@@ -426,11 +511,19 @@ minmax_result<typename value_type_::reduce_minmax_value_t> minmax(tensor_view<va
|
|
|
426
511
|
using minmax_t = typename value_type_::reduce_minmax_value_t;
|
|
427
512
|
minmax_result<minmax_t> result {};
|
|
428
513
|
if (input.empty() || input.numel() == 0 || !tensor_layout_supported_(input)) return result;
|
|
429
|
-
|
|
430
|
-
|
|
514
|
+
auto tail = uniform_stride_tail_(input);
|
|
515
|
+
if (tail.tail_dims == input.rank()) {
|
|
516
|
+
auto lane = normalize_rank1_lane_from_tail_<value_type_, max_rank_>(input, tail);
|
|
517
|
+
numkong::reduce_minmax<value_type_>(lane.data, lane.count, lane.stride_bytes, &result.min_value,
|
|
431
518
|
&result.min_index, &result.max_value, &result.max_index);
|
|
519
|
+
if (lane.reversed) {
|
|
520
|
+
result.min_index = tail.element_count - 1 - result.min_index;
|
|
521
|
+
result.max_index = tail.element_count - 1 - result.max_index;
|
|
522
|
+
}
|
|
432
523
|
return result;
|
|
433
524
|
}
|
|
525
|
+
if (tail.tail_dims >= 2) return minmax<value_type_, max_rank_>(collapse_uniform_tail_(input, tail));
|
|
526
|
+
// Sub-byte rank-1 fallback.
|
|
434
527
|
if (input.rank() == 1) {
|
|
435
528
|
reduce_rank1_minmax_(input, result);
|
|
436
529
|
return result;
|
|
@@ -484,9 +577,61 @@ std::size_t argmax(tensor_view<value_type_, max_rank_> input) noexcept {
|
|
|
484
577
|
return minmax(input).max_index;
|
|
485
578
|
}
|
|
486
579
|
|
|
487
|
-
|
|
580
|
+
/** @brief Compute Σxᵢ and Σxᵢ² over a vector view. */
|
|
581
|
+
template <numeric_dtype value_type_>
|
|
582
|
+
moments_result<typename value_type_::reduce_moments_sum_t, typename value_type_::reduce_moments_sumsq_t> moments(
|
|
583
|
+
vector_view<value_type_> input) noexcept {
|
|
584
|
+
using sum_t = typename value_type_::reduce_moments_sum_t;
|
|
585
|
+
using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
|
|
586
|
+
moments_result<sum_t, sumsq_t> result {};
|
|
587
|
+
if (input.size() == 0) return result;
|
|
588
|
+
reduce_moments<value_type_>(input, &result.sum, &result.sumsq);
|
|
589
|
+
return result;
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
/** @brief Find min and max values with their indices over a vector view. */
|
|
593
|
+
template <numeric_dtype value_type_>
|
|
594
|
+
minmax_result<typename value_type_::reduce_minmax_value_t> minmax(vector_view<value_type_> input) noexcept {
|
|
595
|
+
using minmax_t = typename value_type_::reduce_minmax_value_t;
|
|
596
|
+
minmax_result<minmax_t> result {};
|
|
597
|
+
if (input.size() == 0) return result;
|
|
598
|
+
reduce_minmax<value_type_>(input, &result.min_value, &result.min_index, &result.max_value, &result.max_index);
|
|
599
|
+
return result;
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
/** @brief Σ of all elements in a vector view. */
|
|
603
|
+
template <numeric_dtype value_type_>
|
|
604
|
+
typename value_type_::reduce_moments_sum_t sum(vector_view<value_type_> input) noexcept {
|
|
605
|
+
return moments(input).sum;
|
|
606
|
+
}
|
|
607
|
+
|
|
608
|
+
/** @brief Find the minimum element value in a vector view. */
|
|
609
|
+
template <numeric_dtype value_type_>
|
|
610
|
+
typename value_type_::reduce_minmax_value_t min(vector_view<value_type_> input) noexcept {
|
|
611
|
+
return minmax(input).min_value;
|
|
612
|
+
}
|
|
613
|
+
|
|
614
|
+
/** @brief Find the maximum element value in a vector view. */
|
|
615
|
+
template <numeric_dtype value_type_>
|
|
616
|
+
typename value_type_::reduce_minmax_value_t max(vector_view<value_type_> input) noexcept {
|
|
617
|
+
return minmax(input).max_value;
|
|
618
|
+
}
|
|
619
|
+
|
|
620
|
+
/** @brief Index of the minimum element in a vector view. */
|
|
621
|
+
template <numeric_dtype value_type_>
|
|
622
|
+
std::size_t argmin(vector_view<value_type_> input) noexcept {
|
|
623
|
+
return minmax(input).min_index;
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
/** @brief Index of the maximum element in a vector view. */
|
|
627
|
+
template <numeric_dtype value_type_>
|
|
628
|
+
std::size_t argmax(vector_view<value_type_> input) noexcept {
|
|
629
|
+
return minmax(input).max_index;
|
|
630
|
+
}
|
|
631
|
+
|
|
632
|
+
#pragma endregion Scalar Reductions
|
|
488
633
|
|
|
489
|
-
#pragma region
|
|
634
|
+
#pragma region Axis Reductions
|
|
490
635
|
|
|
491
636
|
/** @brief Σ along a single axis. Returns empty tensor on failure. */
|
|
492
637
|
template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
|
|
@@ -626,7 +771,7 @@ tensor<typename value_type_::reduce_minmax_value_t, allocator_type_, max_rank_>
|
|
|
626
771
|
return try_minmax<value_type_, max_rank_, allocator_type_>(input, axis, keep_dims).max_value;
|
|
627
772
|
}
|
|
628
773
|
|
|
629
|
-
#pragma endregion
|
|
774
|
+
#pragma endregion Axis Reductions
|
|
630
775
|
|
|
631
776
|
} // namespace ashvardanian::numkong
|
|
632
777
|
|
|
@@ -6,21 +6,21 @@ Ordering functions (`nk_f16_order`, `nk_bf16_order`, `nk_e4m3_order`) convert fl
|
|
|
6
6
|
|
|
7
7
|
Reciprocal square root:
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
$$
|
|
10
10
|
\text{rsqrt}(x) = \frac{1}{\sqrt{x}}
|
|
11
|
-
|
|
11
|
+
$$
|
|
12
12
|
|
|
13
13
|
Fused multiply-add:
|
|
14
14
|
|
|
15
|
-
|
|
15
|
+
$$
|
|
16
16
|
\text{fma}(a, b, c) = a \cdot b + c
|
|
17
|
-
|
|
17
|
+
$$
|
|
18
18
|
|
|
19
19
|
Saturating addition:
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
$$
|
|
22
22
|
\text{sat\_add}(a, b) = \text{clamp}(a + b, \text{T\_MIN}, \text{T\_MAX})
|
|
23
|
-
|
|
23
|
+
$$
|
|
24
24
|
|
|
25
25
|
Reformulating as Python pseudocode:
|
|
26
26
|
|
|
@@ -8,13 +8,13 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section scalars_haswell_instructions Key AVX2/FMA Scalar Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm_sqrt_ps
|
|
13
|
-
* _mm_sqrt_pd
|
|
14
|
-
* _mm_fmadd_ss
|
|
15
|
-
* _mm_fmadd_sd
|
|
16
|
-
* _mm_cvtps_ph
|
|
17
|
-
* _mm_cvtph_ps
|
|
11
|
+
* Intrinsic Instruction Haswell Genoa
|
|
12
|
+
* _mm_sqrt_ps VSQRTPS (XMM, XMM) 11cy @ p0 15cy @ p01
|
|
13
|
+
* _mm_sqrt_pd VSQRTPD (XMM, XMM) 16cy @ p0 15cy @ p01
|
|
14
|
+
* _mm_fmadd_ss VFMADD (XMM, XMM, XMM) 5cy @ p01 4cy @ p01
|
|
15
|
+
* _mm_fmadd_sd VFMADD (XMM, XMM, XMM) 5cy @ p01 4cy @ p01
|
|
16
|
+
* _mm_cvtps_ph VCVTPS2PH (XMM, XMM, I8) 5cy @ p01 4cy @ p12+p23
|
|
17
|
+
* _mm_cvtph_ps VCVTPH2PS (XMM, XMM) 5cy @ p01 4cy @ p12+p23
|
|
18
18
|
*/
|
|
19
19
|
#ifndef NK_SCALAR_HASWELL_H
|
|
20
20
|
#define NK_SCALAR_HASWELL_H
|
|
@@ -52,23 +52,32 @@ NK_PUBLIC nk_f64_t nk_f64_fma_haswell(nk_f64_t a, nk_f64_t b, nk_f64_t c) {
|
|
|
52
52
|
return _mm_cvtsd_f64(_mm_fmadd_sd(_mm_set_sd(a), _mm_set_sd(b), _mm_set_sd(c)));
|
|
53
53
|
}
|
|
54
54
|
NK_PUBLIC nk_f16_t nk_f16_sqrt_haswell(nk_f16_t x) {
    /* Widen the f16 bit pattern to f32 with F16C, take the SSE square root,
       and narrow back to f16 with round-to-nearest. Only lane 0 is meaningful. */
    nk_fui16_t in_bits, out_bits;
    in_bits.f = x;
    __m128 const widened_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(in_bits.u));
    __m128 const root_f32x4 = _mm_sqrt_ps(widened_f32x4);
    __m128i const narrowed_f16x8 = _mm_cvtps_ph(root_f32x4, _MM_FROUND_TO_NEAREST_INT);
    out_bits.u = (nk_u16_t)_mm_cvtsi128_si32(narrowed_f16x8);
    return out_bits.f;
}
|
|
58
61
|
/**
 * @brief Reciprocal square root of an f16 value, computed in f32 and rounded back to f16.
 *
 * Uses the `RSQRTSS` estimate refined by one Newton-Raphson step: y' = 0.5 * y * (3 - x * y * y).
 * The refinement evaluates to NaN whenever `x * y * y` is `0 * inf`, which happens for x = ±0
 * (estimate ±inf) and x = +inf (estimate 0). In those lanes the raw estimate, which is already
 * the exact answer, is kept. For NaN or negative inputs the estimate is itself NaN, so NaN propagates.
 */
NK_PUBLIC nk_f16_t nk_f16_rsqrt_haswell(nk_f16_t x) {
    nk_fui16_t x_fui, out_fui;
    x_fui.f = x;
    __m128 x_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(x_fui.u));
    __m128 estimate_f32x4 = _mm_rsqrt_ss(x_f32x4);
    __m128 refinement_f32x4 = _mm_mul_ss(_mm_mul_ss(x_f32x4, estimate_f32x4), estimate_f32x4);
    refinement_f32x4 = _mm_sub_ss(_mm_set_ss(3.0f), refinement_f32x4);
    __m128 refined_f32x4 = _mm_mul_ss(_mm_mul_ss(_mm_set_ss(0.5f), estimate_f32x4), refinement_f32x4);
    /* All-ones in lane 0 when the Newton-Raphson step produced NaN; fall back to the estimate there. */
    __m128 invalid_mask_f32x4 = _mm_cmpunord_ss(refined_f32x4, refined_f32x4);
    __m128 result_f32x4 = _mm_or_ps(_mm_and_ps(invalid_mask_f32x4, estimate_f32x4),
                                    _mm_andnot_ps(invalid_mask_f32x4, refined_f32x4));
    out_fui.u = (nk_u16_t)_mm_cvtsi128_si32(_mm_cvtps_ph(result_f32x4, _MM_FROUND_TO_NEAREST_INT));
    return out_fui.f;
}
|
|
66
72
|
NK_PUBLIC nk_f16_t nk_f16_fma_haswell(nk_f16_t a, nk_f16_t b, nk_f16_t c) {
    /* Widen all three f16 operands to f32 via F16C, fuse a * b + c in lane 0 with FMA3,
       then narrow the f32 result back to f16 with round-to-nearest. */
    nk_fui16_t a_bits, b_bits, c_bits, out_bits;
    a_bits.f = a;
    b_bits.f = b;
    c_bits.f = c;
    __m128 const a_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(a_bits.u));
    __m128 const b_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(b_bits.u));
    __m128 const c_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(c_bits.u));
    __m128 const fused_f32x4 = _mm_fmadd_ss(a_f32x4, b_f32x4, c_f32x4);
    __m128i const narrowed_f16x8 = _mm_cvtps_ph(fused_f32x4, _MM_FROUND_TO_NEAREST_INT);
    out_bits.u = (nk_u16_t)_mm_cvtsi128_si32(narrowed_f16x8);
    return out_bits.f;
}
|
|
73
82
|
NK_PUBLIC nk_u8_t nk_u8_saturating_add_haswell(nk_u8_t a, nk_u8_t b) {
|
|
74
83
|
return (nk_u8_t)_mm_cvtsi128_si32(_mm_adds_epu8(_mm_cvtsi32_si128(a), _mm_cvtsi32_si128(b)));
|
|
@@ -89,8 +98,8 @@ NK_PUBLIC nk_u64_t nk_u64_saturating_mul_haswell(nk_u64_t a, nk_u64_t b) {
|
|
|
89
98
|
}
|
|
90
99
|
NK_PUBLIC nk_i64_t nk_i64_saturating_mul_haswell(nk_i64_t a, nk_i64_t b) {
|
|
91
100
|
int sign = (a < 0) ^ (b < 0);
|
|
92
|
-
nk_u64_t abs_a = a < 0 ? -(nk_u64_t)a : (nk_u64_t)a;
|
|
93
|
-
nk_u64_t abs_b = b < 0 ? -(nk_u64_t)b : (nk_u64_t)b;
|
|
101
|
+
nk_u64_t abs_a = a < 0 ? (0u - (nk_u64_t)a) : (nk_u64_t)a;
|
|
102
|
+
nk_u64_t abs_b = b < 0 ? (0u - (nk_u64_t)b) : (nk_u64_t)b;
|
|
94
103
|
unsigned long long high;
|
|
95
104
|
unsigned long long low = _mulx_u64(abs_a, abs_b, &high);
|
|
96
105
|
if (high || (sign && low > 9223372036854775808ull) || (!sign && low > 9223372036854775807ull))
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Scalar Math Helpers for LoongArch LASX.
|
|
3
|
+
* @file include/numkong/scalar/loongsonasx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/scalar.h
|
|
8
|
+
*
|
|
9
|
+
* LASX provides `xvfrsqrt` (full-precision reciprocal sqrt) and `xvfsqrt`
|
|
10
|
+
* (full-precision sqrt). No Newton-Raphson refinement needed.
|
|
11
|
+
* Full-precision sqrt uses the hardware `xvfsqrt` instruction.
|
|
12
|
+
* Broadcast via `xvreplgr2vr`, extract via `xvpickve2gr` — no memory round-trips.
|
|
13
|
+
*/
|
|
14
|
+
#ifndef NK_SCALAR_LOONGSONASX_H
|
|
15
|
+
#define NK_SCALAR_LOONGSONASX_H
|
|
16
|
+
|
|
17
|
+
#if NK_TARGET_LOONGARCH_
|
|
18
|
+
#if NK_TARGET_LOONGSONASX
|
|
19
|
+
|
|
20
|
+
#include "numkong/types.h"
|
|
21
|
+
|
|
22
|
+
#if defined(__cplusplus)
|
|
23
|
+
extern "C" {
|
|
24
|
+
#endif
|
|
25
|
+
|
|
26
|
+
/**
 * @brief Splat one f32 value across the four 32-bit lanes of a 128-bit LSX register.
 *
 * Goes through the integer bit pattern and `__lsx_vreplgr2vr_w`, so it compiles with both GCC and Clang.
 */
NK_INTERNAL __m128 nk_xvreplgr2vr_s_128_(float x) {
    nk_fui32_t pun;
    pun.f = x;
    int const lane_bits = (int)pun.u;
    return (__m128)__lsx_vreplgr2vr_w(lane_bits);
}
|
|
32
|
+
|
|
33
|
+
/**
 * @brief Splat one f32 value across the eight 32-bit lanes of a 256-bit LASX register.
 *
 * Goes through the integer bit pattern and `__lasx_xvreplgr2vr_w`, so it compiles with both GCC and Clang.
 */
NK_INTERNAL __m256 nk_xvfreplgr2vr_s_(float x) {
    nk_fui32_t pun;
    pun.f = x;
    int const lane_bits = (int)pun.u;
    return (__m256)__lasx_xvreplgr2vr_w(lane_bits);
}
|
|
39
|
+
|
|
40
|
+
/**
 * @brief Splat one f64 value across the four 64-bit lanes of a 256-bit LASX register.
 *
 * Goes through the integer bit pattern and `__lasx_xvreplgr2vr_d`, so it compiles with both GCC and Clang.
 */
NK_INTERNAL __m256d nk_xvfreplgr2vr_d_(double x) {
    nk_fui64_t pun;
    pun.f = x;
    long long const lane_bits = (long long)pun.u;
    return (__m256d)__lasx_xvreplgr2vr_d(lane_bits);
}
|
|
46
|
+
|
|
47
|
+
NK_PUBLIC nk_f32_t nk_f32_rsqrt_loongsonasx(nk_f32_t x) {
    /* Splat x, apply the hardware reciprocal square root `xvfrsqrt.s` to every lane, and read
       lane 0 back as raw bits. No Newton-Raphson step is applied on top of the instruction. */
    __m256 const rsqrt_f32x8 = __lasx_xvfrsqrt_s(nk_xvfreplgr2vr_s_(x));
    nk_fui32_t bits;
    bits.u = (nk_u32_t)__lasx_xvpickve2gr_w((__m256i)rsqrt_f32x8, 0);
    return bits.f;
}
|
|
55
|
+
|
|
56
|
+
/**
 * @brief Square root of an f32 using the hardware `xvfsqrt.s` instruction.
 *
 * Follows the same splat / compute / extract-lane-0 pattern as `nk_f64_sqrt_loongsonasx`.
 * Computing it directly, rather than as `x * rsqrt(x)`, avoids three problems:
 * - NaN for +inf (inf * 0),
 * - NaN inputs being mapped to 0,
 * - a second rounding.
 * Negative inputs yield NaN, consistent with the f64 variant.
 */
NK_PUBLIC nk_f32_t nk_f32_sqrt_loongsonasx(nk_f32_t x) {
    __m256 x_f32x8 = nk_xvfreplgr2vr_s_(x);
    __m256 result_f32x8 = __lasx_xvfsqrt_s(x_f32x8);
    nk_fui32_t c;
    c.u = (nk_u32_t)__lasx_xvpickve2gr_w((__m256i)result_f32x8, 0);
    return c.f;
}
|
|
57
|
+
|
|
58
|
+
NK_PUBLIC nk_f64_t nk_f64_sqrt_loongsonasx(nk_f64_t x) {
    /* Splat x across the four f64 lanes, run the hardware `xvfsqrt.d`,
       and read lane 0 back as a raw 64-bit pattern. */
    __m256d const root_f64x4 = __lasx_xvfsqrt_d(nk_xvfreplgr2vr_d_(x));
    nk_fui64_t bits;
    bits.u = (nk_u64_t)__lasx_xvpickve2gr_du((__m256i)root_f64x4, 0);
    return bits.f;
}
|
|
65
|
+
|
|
66
|
+
NK_PUBLIC nk_f64_t nk_f64_rsqrt_loongsonasx(nk_f64_t x) {
    /* Reciprocal of the hardware square root computed by `nk_f64_sqrt_loongsonasx`. */
    nk_f64_t const root = nk_f64_sqrt_loongsonasx(x);
    return 1.0 / root;
}
|
|
67
|
+
|
|
68
|
+
#if defined(__cplusplus)
|
|
69
|
+
} // extern "C"
|
|
70
|
+
#endif
|
|
71
|
+
|
|
72
|
+
#endif // NK_TARGET_LOONGSONASX
|
|
73
|
+
#endif // NK_TARGET_LOONGARCH_
|
|
74
|
+
#endif // NK_SCALAR_LOONGSONASX_H
|