numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
/**
|
|
2
|
-
* @brief ARMv8.4-FHM implementations for the redesigned reduction API
|
|
2
|
+
* @brief ARMv8.4-FHM implementations for the redesigned reduction API.
|
|
3
3
|
* @file include/numkong/reduce/neonfhm.h
|
|
4
4
|
* @author Ash Vardanian
|
|
5
5
|
* @date February 13, 2026
|
|
@@ -38,7 +38,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_contiguous_( //
|
|
|
38
38
|
nk_size_t idx = 0;
|
|
39
39
|
|
|
40
40
|
for (; idx + 8 <= count; idx += 8) {
|
|
41
|
-
uint8x8_t data_u8x8 = vld1_u8((
|
|
41
|
+
uint8x8_t data_u8x8 = vld1_u8((nk_u8_t const *)(data_ptr + idx));
|
|
42
42
|
float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(data_u8x8);
|
|
43
43
|
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
44
44
|
sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
@@ -71,7 +71,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
|
|
|
71
71
|
nk_size_t idx = 0;
|
|
72
72
|
|
|
73
73
|
if (stride_elements == 2) {
|
|
74
|
-
for (; idx + 8
|
|
74
|
+
for (; idx + 8 < count; idx += 8) {
|
|
75
75
|
uint8x8x2_t loaded = vld2_u8((nk_u8_t const *)(data_ptr + idx * 2));
|
|
76
76
|
float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(loaded.val[0]);
|
|
77
77
|
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
@@ -81,7 +81,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
|
|
|
81
81
|
}
|
|
82
82
|
}
|
|
83
83
|
else if (stride_elements == 3) {
|
|
84
|
-
for (; idx + 8
|
|
84
|
+
for (; idx + 8 < count; idx += 8) {
|
|
85
85
|
uint8x8x3_t loaded = vld3_u8((nk_u8_t const *)(data_ptr + idx * 3));
|
|
86
86
|
float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(loaded.val[0]);
|
|
87
87
|
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
@@ -91,7 +91,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
|
|
|
91
91
|
}
|
|
92
92
|
}
|
|
93
93
|
else if (stride_elements == 4) {
|
|
94
|
-
for (; idx + 8
|
|
94
|
+
for (; idx + 8 < count; idx += 8) {
|
|
95
95
|
uint8x8x4_t loaded = vld4_u8((nk_u8_t const *)(data_ptr + idx * 4));
|
|
96
96
|
float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(loaded.val[0]);
|
|
97
97
|
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
@@ -163,7 +163,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_contiguous_( //
|
|
|
163
163
|
nk_size_t idx = 0;
|
|
164
164
|
|
|
165
165
|
for (; idx + 8 <= count; idx += 8) {
|
|
166
|
-
uint8x8_t data_u8x8 = vld1_u8((
|
|
166
|
+
uint8x8_t data_u8x8 = vld1_u8((nk_u8_t const *)(data_ptr + idx));
|
|
167
167
|
float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(data_u8x8);
|
|
168
168
|
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
169
169
|
sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
@@ -196,7 +196,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
|
|
|
196
196
|
nk_size_t idx = 0;
|
|
197
197
|
|
|
198
198
|
if (stride_elements == 2) {
|
|
199
|
-
for (; idx + 8
|
|
199
|
+
for (; idx + 8 < count; idx += 8) {
|
|
200
200
|
uint8x8x2_t loaded = vld2_u8((nk_u8_t const *)(data_ptr + idx * 2));
|
|
201
201
|
float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded.val[0]);
|
|
202
202
|
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
@@ -206,7 +206,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
|
|
|
206
206
|
}
|
|
207
207
|
}
|
|
208
208
|
else if (stride_elements == 3) {
|
|
209
|
-
for (; idx + 8
|
|
209
|
+
for (; idx + 8 < count; idx += 8) {
|
|
210
210
|
uint8x8x3_t loaded = vld3_u8((nk_u8_t const *)(data_ptr + idx * 3));
|
|
211
211
|
float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded.val[0]);
|
|
212
212
|
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
@@ -216,7 +216,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
|
|
|
216
216
|
}
|
|
217
217
|
}
|
|
218
218
|
else if (stride_elements == 4) {
|
|
219
|
-
for (; idx + 8
|
|
219
|
+
for (; idx + 8 < count; idx += 8) {
|
|
220
220
|
uint8x8x4_t loaded = vld4_u8((nk_u8_t const *)(data_ptr + idx * 4));
|
|
221
221
|
float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded.val[0]);
|
|
222
222
|
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
|
|
@@ -278,378 +278,6 @@ NK_PUBLIC void nk_reduce_moments_e5m2_neonfhm( //
|
|
|
278
278
|
else nk_reduce_moments_e5m2_neonfhm_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
279
279
|
}
|
|
280
280
|
|
|
281
|
-
NK_INTERNAL void nk_reduce_minmax_e4m3_neonfhm_contiguous_( //
|
|
282
|
-
nk_e4m3_t const *data_ptr, nk_size_t count, //
|
|
283
|
-
nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
284
|
-
nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
285
|
-
uint8x16_t min_u8x16 = vdupq_n_u8(0xFF);
|
|
286
|
-
uint8x16_t max_u8x16 = vdupq_n_u8(0x00);
|
|
287
|
-
uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
|
|
288
|
-
uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
|
|
289
|
-
nk_size_t idx = 0;
|
|
290
|
-
for (; idx + 16 <= count; idx += 16) {
|
|
291
|
-
uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
|
|
292
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(raw_u8x16);
|
|
293
|
-
uint8x16_t less_u8x16 = vcltq_u8(comparable_u8x16, min_u8x16);
|
|
294
|
-
uint8x16_t greater_u8x16 = vcgtq_u8(comparable_u8x16, max_u8x16);
|
|
295
|
-
min_u8x16 = vbslq_u8(less_u8x16, comparable_u8x16, min_u8x16);
|
|
296
|
-
max_u8x16 = vbslq_u8(greater_u8x16, comparable_u8x16, max_u8x16);
|
|
297
|
-
min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
|
|
298
|
-
max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
|
|
299
|
-
iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
|
|
300
|
-
}
|
|
301
|
-
nk_size_t remaining = count - idx;
|
|
302
|
-
if (remaining > 0) {
|
|
303
|
-
nk_b128_vec_t tail_vec;
|
|
304
|
-
nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
|
|
305
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
|
|
306
|
-
uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
|
|
307
|
-
vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
|
|
308
|
-
uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
|
|
309
|
-
uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
|
|
310
|
-
uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
|
|
311
|
-
uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
|
|
312
|
-
uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
|
|
313
|
-
min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
|
|
314
|
-
max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
|
|
315
|
-
min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
|
|
316
|
-
max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
|
|
317
|
-
}
|
|
318
|
-
nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
|
|
319
|
-
// All-NaN early return: both sentinels unchanged means no valid data was found
|
|
320
|
-
if (min_comparable == 0xFF && max_comparable == 0x00) {
|
|
321
|
-
*min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
|
|
322
|
-
*max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
|
|
323
|
-
return;
|
|
324
|
-
}
|
|
325
|
-
uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
|
|
326
|
-
uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
|
|
327
|
-
nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
|
|
328
|
-
uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
|
|
329
|
-
uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
|
|
330
|
-
nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
|
|
331
|
-
uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
|
|
332
|
-
vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
|
|
333
|
-
uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
|
|
334
|
-
uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
|
|
335
|
-
uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
|
|
336
|
-
nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
|
|
337
|
-
nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
|
|
338
|
-
uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
|
|
339
|
-
uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
|
|
340
|
-
uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
|
|
341
|
-
nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
|
|
342
|
-
nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
|
|
343
|
-
*min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
|
|
344
|
-
*max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
|
|
345
|
-
}
|
|
346
|
-
|
|
347
|
-
NK_INTERNAL void nk_reduce_minmax_e4m3_neonfhm_strided_( //
|
|
348
|
-
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
349
|
-
nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
350
|
-
nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
351
|
-
uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
|
|
352
|
-
uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
|
|
353
|
-
uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
|
|
354
|
-
uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
|
|
355
|
-
vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
|
|
356
|
-
nk_size_t idx = 0;
|
|
357
|
-
uint8x16_t data_for_min_u8x16, data_for_max_u8x16;
|
|
358
|
-
|
|
359
|
-
nk_reduce_minmax_e4m3_neonfhm_cycle:
|
|
360
|
-
if (stride_elements == 2 && idx + 16 <= count) {
|
|
361
|
-
uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
|
|
362
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
|
|
363
|
-
data_for_min_u8x16 = comparable_u8x16;
|
|
364
|
-
data_for_max_u8x16 = comparable_u8x16;
|
|
365
|
-
idx += 16;
|
|
366
|
-
}
|
|
367
|
-
else if (stride_elements == 3 && idx + 16 <= count) {
|
|
368
|
-
uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
|
|
369
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
|
|
370
|
-
data_for_min_u8x16 = comparable_u8x16;
|
|
371
|
-
data_for_max_u8x16 = comparable_u8x16;
|
|
372
|
-
idx += 16;
|
|
373
|
-
}
|
|
374
|
-
else if (stride_elements == 4 && idx + 16 <= count) {
|
|
375
|
-
uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
|
|
376
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
|
|
377
|
-
data_for_min_u8x16 = comparable_u8x16;
|
|
378
|
-
data_for_max_u8x16 = comparable_u8x16;
|
|
379
|
-
idx += 16;
|
|
380
|
-
}
|
|
381
|
-
else if (idx < count) {
|
|
382
|
-
nk_b128_vec_t tail_vec;
|
|
383
|
-
nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
|
|
384
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
|
|
385
|
-
uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
|
|
386
|
-
data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
|
|
387
|
-
data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
|
|
388
|
-
idx = count;
|
|
389
|
-
}
|
|
390
|
-
else {
|
|
391
|
-
nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
|
|
392
|
-
if (min_comparable == 0xFF && max_comparable == 0x00) {
|
|
393
|
-
*min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
|
|
394
|
-
*max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
|
|
395
|
-
return;
|
|
396
|
-
}
|
|
397
|
-
uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
|
|
398
|
-
uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
|
|
399
|
-
nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
|
|
400
|
-
uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
|
|
401
|
-
uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
|
|
402
|
-
nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
|
|
403
|
-
uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
|
|
404
|
-
uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
|
|
405
|
-
uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
|
|
406
|
-
nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
|
|
407
|
-
nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
|
|
408
|
-
uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
|
|
409
|
-
uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
|
|
410
|
-
uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
|
|
411
|
-
nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
|
|
412
|
-
nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
|
|
413
|
-
*min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
|
|
414
|
-
*max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
|
|
415
|
-
return;
|
|
416
|
-
}
|
|
417
|
-
|
|
418
|
-
// Shared update body
|
|
419
|
-
uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
|
|
420
|
-
uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
|
|
421
|
-
min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
|
|
422
|
-
max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
|
|
423
|
-
min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
|
|
424
|
-
max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
|
|
425
|
-
iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
|
|
426
|
-
goto nk_reduce_minmax_e4m3_neonfhm_cycle;
|
|
427
|
-
}
|
|
428
|
-
|
|
429
|
-
NK_PUBLIC void nk_reduce_minmax_e4m3_neonfhm( //
|
|
430
|
-
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
431
|
-
nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
432
|
-
nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
433
|
-
nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
|
|
434
|
-
int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
|
|
435
|
-
if (count == 0)
|
|
436
|
-
*min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E4M3_MIN,
|
|
437
|
-
*max_index_ptr = NK_SIZE_MAX;
|
|
438
|
-
else if (!aligned)
|
|
439
|
-
nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
440
|
-
max_index_ptr);
|
|
441
|
-
else if (count > (nk_size_t)256 * 16) {
|
|
442
|
-
nk_size_t left_count = count / 2;
|
|
443
|
-
nk_e4m3_t left_min_value, right_min_value, left_max_value, right_max_value;
|
|
444
|
-
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
445
|
-
nk_reduce_minmax_e4m3_neonfhm(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
|
|
446
|
-
&left_max_value, &left_max_index);
|
|
447
|
-
nk_reduce_minmax_e4m3_neonfhm(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
448
|
-
&right_min_value, &right_min_index, &right_max_value, &right_max_index);
|
|
449
|
-
if (nk_e4m3_order_serial(right_min_value, left_min_value) < 0)
|
|
450
|
-
*min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
|
|
451
|
-
else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
|
|
452
|
-
if (nk_e4m3_order_serial(right_max_value, left_max_value) > 0)
|
|
453
|
-
*max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
|
|
454
|
-
else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
|
|
455
|
-
}
|
|
456
|
-
else if (stride_elements == 1)
|
|
457
|
-
nk_reduce_minmax_e4m3_neonfhm_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
458
|
-
max_index_ptr);
|
|
459
|
-
else if (stride_elements <= 4)
|
|
460
|
-
nk_reduce_minmax_e4m3_neonfhm_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
|
|
461
|
-
max_value_ptr, max_index_ptr);
|
|
462
|
-
else
|
|
463
|
-
nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
464
|
-
max_index_ptr);
|
|
465
|
-
}
|
|
466
|
-
|
|
467
|
-
NK_INTERNAL void nk_reduce_minmax_e5m2_neonfhm_contiguous_( //
|
|
468
|
-
nk_e5m2_t const *data_ptr, nk_size_t count, //
|
|
469
|
-
nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
470
|
-
nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
471
|
-
uint8x16_t min_u8x16 = vdupq_n_u8(0xFF);
|
|
472
|
-
uint8x16_t max_u8x16 = vdupq_n_u8(0x00);
|
|
473
|
-
uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
|
|
474
|
-
uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
|
|
475
|
-
nk_size_t idx = 0;
|
|
476
|
-
for (; idx + 16 <= count; idx += 16) {
|
|
477
|
-
uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
|
|
478
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(raw_u8x16);
|
|
479
|
-
uint8x16_t less_u8x16 = vcltq_u8(comparable_u8x16, min_u8x16);
|
|
480
|
-
uint8x16_t greater_u8x16 = vcgtq_u8(comparable_u8x16, max_u8x16);
|
|
481
|
-
min_u8x16 = vbslq_u8(less_u8x16, comparable_u8x16, min_u8x16);
|
|
482
|
-
max_u8x16 = vbslq_u8(greater_u8x16, comparable_u8x16, max_u8x16);
|
|
483
|
-
min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
|
|
484
|
-
max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
|
|
485
|
-
iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
|
|
486
|
-
}
|
|
487
|
-
nk_size_t remaining = count - idx;
|
|
488
|
-
if (remaining > 0) {
|
|
489
|
-
nk_b128_vec_t tail_vec;
|
|
490
|
-
nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
|
|
491
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
|
|
492
|
-
uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
|
|
493
|
-
vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
|
|
494
|
-
uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
|
|
495
|
-
uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
|
|
496
|
-
uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
|
|
497
|
-
uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
|
|
498
|
-
uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
|
|
499
|
-
min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
|
|
500
|
-
max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
|
|
501
|
-
min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
|
|
502
|
-
max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
|
|
503
|
-
}
|
|
504
|
-
nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
|
|
505
|
-
// All-NaN early return: both sentinels unchanged means no valid data was found
|
|
506
|
-
if (min_comparable == 0xFF && max_comparable == 0x00) {
|
|
507
|
-
*min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
|
|
508
|
-
*max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
|
|
509
|
-
return;
|
|
510
|
-
}
|
|
511
|
-
uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
|
|
512
|
-
uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
|
|
513
|
-
nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
|
|
514
|
-
uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
|
|
515
|
-
uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
|
|
516
|
-
nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
|
|
517
|
-
uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
|
|
518
|
-
vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
|
|
519
|
-
uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
|
|
520
|
-
uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
|
|
521
|
-
uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
|
|
522
|
-
nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
|
|
523
|
-
nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
|
|
524
|
-
uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
|
|
525
|
-
uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
|
|
526
|
-
uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
|
|
527
|
-
nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
|
|
528
|
-
nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
|
|
529
|
-
*min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
|
|
530
|
-
*max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
|
|
531
|
-
}
|
|
532
|
-
|
|
533
|
-
NK_INTERNAL void nk_reduce_minmax_e5m2_neonfhm_strided_( //
|
|
534
|
-
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
535
|
-
nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
536
|
-
nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
537
|
-
uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
|
|
538
|
-
uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
|
|
539
|
-
uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
|
|
540
|
-
uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
|
|
541
|
-
vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
|
|
542
|
-
nk_size_t idx = 0;
|
|
543
|
-
uint8x16_t data_for_min_u8x16, data_for_max_u8x16;
|
|
544
|
-
|
|
545
|
-
nk_reduce_minmax_e5m2_neonfhm_cycle:
|
|
546
|
-
if (stride_elements == 2 && idx + 16 <= count) {
|
|
547
|
-
uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
|
|
548
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
|
|
549
|
-
data_for_min_u8x16 = comparable_u8x16;
|
|
550
|
-
data_for_max_u8x16 = comparable_u8x16;
|
|
551
|
-
idx += 16;
|
|
552
|
-
}
|
|
553
|
-
else if (stride_elements == 3 && idx + 16 <= count) {
|
|
554
|
-
uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
|
|
555
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
|
|
556
|
-
data_for_min_u8x16 = comparable_u8x16;
|
|
557
|
-
data_for_max_u8x16 = comparable_u8x16;
|
|
558
|
-
idx += 16;
|
|
559
|
-
}
|
|
560
|
-
else if (stride_elements == 4 && idx + 16 <= count) {
|
|
561
|
-
uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
|
|
562
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
|
|
563
|
-
data_for_min_u8x16 = comparable_u8x16;
|
|
564
|
-
data_for_max_u8x16 = comparable_u8x16;
|
|
565
|
-
idx += 16;
|
|
566
|
-
}
|
|
567
|
-
else if (idx < count) {
|
|
568
|
-
nk_b128_vec_t tail_vec;
|
|
569
|
-
nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
|
|
570
|
-
uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
|
|
571
|
-
uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
|
|
572
|
-
data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
|
|
573
|
-
data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
|
|
574
|
-
idx = count;
|
|
575
|
-
}
|
|
576
|
-
else {
|
|
577
|
-
nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
|
|
578
|
-
if (min_comparable == 0xFF && max_comparable == 0x00) {
|
|
579
|
-
*min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
|
|
580
|
-
*max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
|
|
581
|
-
return;
|
|
582
|
-
}
|
|
583
|
-
uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
|
|
584
|
-
uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
|
|
585
|
-
nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
|
|
586
|
-
uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
|
|
587
|
-
uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
|
|
588
|
-
nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
|
|
589
|
-
uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
|
|
590
|
-
uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
|
|
591
|
-
uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
|
|
592
|
-
nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
|
|
593
|
-
nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
|
|
594
|
-
uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
|
|
595
|
-
uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
|
|
596
|
-
uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
|
|
597
|
-
nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
|
|
598
|
-
nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
|
|
599
|
-
*min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
|
|
600
|
-
*max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
|
|
601
|
-
return;
|
|
602
|
-
}
|
|
603
|
-
|
|
604
|
-
// Shared update body
|
|
605
|
-
uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
|
|
606
|
-
uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
|
|
607
|
-
min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
|
|
608
|
-
max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
|
|
609
|
-
min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
|
|
610
|
-
max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
|
|
611
|
-
iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
|
|
612
|
-
goto nk_reduce_minmax_e5m2_neonfhm_cycle;
|
|
613
|
-
}
|
|
614
|
-
|
|
615
|
-
NK_PUBLIC void nk_reduce_minmax_e5m2_neonfhm( //
|
|
616
|
-
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
617
|
-
nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
618
|
-
nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
619
|
-
nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
|
|
620
|
-
int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
|
|
621
|
-
if (count == 0)
|
|
622
|
-
*min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E5M2_MIN,
|
|
623
|
-
*max_index_ptr = NK_SIZE_MAX;
|
|
624
|
-
else if (!aligned)
|
|
625
|
-
nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
626
|
-
max_index_ptr);
|
|
627
|
-
else if (count > (nk_size_t)256 * 16) {
|
|
628
|
-
nk_size_t left_count = count / 2;
|
|
629
|
-
nk_e5m2_t left_min_value, right_min_value, left_max_value, right_max_value;
|
|
630
|
-
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
631
|
-
nk_reduce_minmax_e5m2_neonfhm(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
|
|
632
|
-
&left_max_value, &left_max_index);
|
|
633
|
-
nk_reduce_minmax_e5m2_neonfhm(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
634
|
-
&right_min_value, &right_min_index, &right_max_value, &right_max_index);
|
|
635
|
-
if (nk_e5m2_order_serial(right_min_value, left_min_value) < 0)
|
|
636
|
-
*min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
|
|
637
|
-
else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
|
|
638
|
-
if (nk_e5m2_order_serial(right_max_value, left_max_value) > 0)
|
|
639
|
-
*max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
|
|
640
|
-
else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
|
|
641
|
-
}
|
|
642
|
-
else if (stride_elements == 1)
|
|
643
|
-
nk_reduce_minmax_e5m2_neonfhm_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
644
|
-
max_index_ptr);
|
|
645
|
-
else if (stride_elements <= 4)
|
|
646
|
-
nk_reduce_minmax_e5m2_neonfhm_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
|
|
647
|
-
max_value_ptr, max_index_ptr);
|
|
648
|
-
else
|
|
649
|
-
nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
650
|
-
max_index_ptr);
|
|
651
|
-
}
|
|
652
|
-
|
|
653
281
|
#if defined(__clang__)
|
|
654
282
|
#pragma clang attribute pop
|
|
655
283
|
#elif defined(__GNUC__)
|
|
@@ -60,7 +60,7 @@ NK_INTERNAL void nk_reduce_moments_i8_neonsdot_strided_( //
|
|
|
60
60
|
int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
|
|
61
61
|
nk_size_t idx = 0;
|
|
62
62
|
if (stride_elements == 2) {
|
|
63
|
-
for (; idx + 16
|
|
63
|
+
for (; idx + 16 < count; idx += 16) {
|
|
64
64
|
int8x16x2_t loaded = vld2q_s8(data_ptr + idx * 2);
|
|
65
65
|
int8x16_t data_i8x16 = loaded.val[0];
|
|
66
66
|
sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
|
|
@@ -68,7 +68,7 @@ NK_INTERNAL void nk_reduce_moments_i8_neonsdot_strided_( //
|
|
|
68
68
|
}
|
|
69
69
|
}
|
|
70
70
|
else if (stride_elements == 3) {
|
|
71
|
-
for (; idx + 16
|
|
71
|
+
for (; idx + 16 < count; idx += 16) {
|
|
72
72
|
int8x16x3_t loaded = vld3q_s8(data_ptr + idx * 3);
|
|
73
73
|
int8x16_t data_i8x16 = loaded.val[0];
|
|
74
74
|
sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
|
|
@@ -76,7 +76,7 @@ NK_INTERNAL void nk_reduce_moments_i8_neonsdot_strided_( //
|
|
|
76
76
|
}
|
|
77
77
|
}
|
|
78
78
|
else if (stride_elements == 4) {
|
|
79
|
-
for (; idx + 16
|
|
79
|
+
for (; idx + 16 < count; idx += 16) {
|
|
80
80
|
int8x16x4_t loaded = vld4q_s8(data_ptr + idx * 4);
|
|
81
81
|
int8x16_t data_i8x16 = loaded.val[0];
|
|
82
82
|
sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
|
|
@@ -151,7 +151,7 @@ NK_INTERNAL void nk_reduce_moments_u8_neonsdot_strided_( //
|
|
|
151
151
|
uint32x4_t sumsq_u32x4 = vdupq_n_u32(0);
|
|
152
152
|
nk_size_t idx = 0;
|
|
153
153
|
if (stride_elements == 2) {
|
|
154
|
-
for (; idx + 16
|
|
154
|
+
for (; idx + 16 < count; idx += 16) {
|
|
155
155
|
uint8x16x2_t loaded = vld2q_u8(data_ptr + idx * 2);
|
|
156
156
|
uint8x16_t data_u8x16 = loaded.val[0];
|
|
157
157
|
sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
|
|
@@ -159,7 +159,7 @@ NK_INTERNAL void nk_reduce_moments_u8_neonsdot_strided_( //
|
|
|
159
159
|
}
|
|
160
160
|
}
|
|
161
161
|
else if (stride_elements == 3) {
|
|
162
|
-
for (; idx + 16
|
|
162
|
+
for (; idx + 16 < count; idx += 16) {
|
|
163
163
|
uint8x16x3_t loaded = vld3q_u8(data_ptr + idx * 3);
|
|
164
164
|
uint8x16_t data_u8x16 = loaded.val[0];
|
|
165
165
|
sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
|
|
@@ -167,7 +167,7 @@ NK_INTERNAL void nk_reduce_moments_u8_neonsdot_strided_( //
|
|
|
167
167
|
}
|
|
168
168
|
}
|
|
169
169
|
else if (stride_elements == 4) {
|
|
170
|
-
for (; idx + 16
|
|
170
|
+
for (; idx + 16 < count; idx += 16) {
|
|
171
171
|
uint8x16x4_t loaded = vld4q_u8(data_ptr + idx * 4);
|
|
172
172
|
uint8x16_t data_u8x16 = loaded.val[0];
|
|
173
173
|
sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
|
|
@@ -268,7 +268,7 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_strided_( //
|
|
|
268
268
|
int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
|
|
269
269
|
nk_size_t idx = 0;
|
|
270
270
|
if (stride_elements == 2) {
|
|
271
|
-
for (; idx + 16
|
|
271
|
+
for (; idx + 16 < count; idx += 16) {
|
|
272
272
|
uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
|
|
273
273
|
uint8x16_t raw_u8x16 = loaded_u8x16x2.val[0];
|
|
274
274
|
uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
|
|
@@ -282,7 +282,7 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_strided_( //
|
|
|
282
282
|
}
|
|
283
283
|
}
|
|
284
284
|
else if (stride_elements == 3) {
|
|
285
|
-
for (; idx + 16
|
|
285
|
+
for (; idx + 16 < count; idx += 16) {
|
|
286
286
|
uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
|
|
287
287
|
uint8x16_t raw_u8x16 = loaded_u8x16x3.val[0];
|
|
288
288
|
uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
|
|
@@ -296,7 +296,7 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_strided_( //
|
|
|
296
296
|
}
|
|
297
297
|
}
|
|
298
298
|
else if (stride_elements == 4) {
|
|
299
|
-
for (; idx + 16
|
|
299
|
+
for (; idx + 16 < count; idx += 16) {
|
|
300
300
|
uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
|
|
301
301
|
uint8x16_t raw_u8x16 = loaded_u8x16x4.val[0];
|
|
302
302
|
uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
|