numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
@@ -27,14 +27,14 @@ extern "C" {
 #pragma GCC target("arch=armv8-a+simd")
 #endif

-NK_INTERNAL nk_u64_t nk_reduce_sadd_u64x2_neon_(uint64x2_t
-uint64x2_t swapped_u64x2 = vextq_u64(
-return vgetq_lane_u64(vqaddq_u64(
+NK_INTERNAL nk_u64_t nk_reduce_sadd_u64x2_neon_(uint64x2_t v_u64x2) {
+uint64x2_t swapped_u64x2 = vextq_u64(v_u64x2, v_u64x2, 1);
+return vgetq_lane_u64(vqaddq_u64(v_u64x2, swapped_u64x2), 0);
 }

 /** @brief Saturating square of each i64 lane → u64. If |a| >= 2^32, a² overflows u64 → saturate. */
-NK_INTERNAL uint64x2_t nk_i64_smul_sq_i64x2_neon_(int64x2_t
-uint64x2_t absolute_u64x2 = vreinterpretq_u64_s64(vabsq_s64(
+NK_INTERNAL uint64x2_t nk_i64_smul_sq_i64x2_neon_(int64x2_t val_i64x2) {
+uint64x2_t absolute_u64x2 = vreinterpretq_u64_s64(vabsq_s64(val_i64x2));
 uint32x2_t low_halves_u32x2 = vmovn_u64(absolute_u64x2);
 uint64x2_t high_bits_u64x2 = vshrq_n_u64(absolute_u64x2, 32);
 uint64x2_t low_squared_u64x2 = vmull_u32(low_halves_u32x2, low_halves_u32x2);
@@ -43,9 +43,9 @@ NK_INTERNAL uint64x2_t nk_i64_smul_sq_i64x2_neon_(int64x2_t val) {
 }

 /** @brief Saturating square of each u64 lane → u64. If a >= 2^32, a² overflows u64 → saturate. */
-NK_INTERNAL uint64x2_t nk_u64_smul_sq_u64x2_neon_(uint64x2_t
-uint32x2_t low_halves_u32x2 = vmovn_u64(
-uint64x2_t high_bits_u64x2 = vshrq_n_u64(
+NK_INTERNAL uint64x2_t nk_u64_smul_sq_u64x2_neon_(uint64x2_t val_u64x2) {
+uint32x2_t low_halves_u32x2 = vmovn_u64(val_u64x2);
+uint64x2_t high_bits_u64x2 = vshrq_n_u64(val_u64x2, 32);
 uint64x2_t low_squared_u64x2 = vmull_u32(low_halves_u32x2, low_halves_u32x2);
 uint64x2_t is_small_u64x2 = vceqq_u64(high_bits_u64x2, vdupq_n_u64(0));
 return vbslq_u64(is_small_u64x2, low_squared_u64x2, vdupq_n_u64(NK_U64_MAX));
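The two helpers above implement the saturating-square rule stated in their doc comments: the low 32-bit halves are squared with a widening multiply, and any lane whose upper 32 bits are non-zero is clamped to NK_U64_MAX, since its square cannot fit in 64 bits. A minimal scalar sketch of the same rule, for illustration only (the helper name below is not part of the package):

```c
#include <stdint.h>

/* If the upper 32 bits of a are non-zero, a * a overflows 64 bits, so saturate;
 * otherwise the square of the low 32 bits fits exactly. For signed inputs the
 * library takes |a| first, since the square is the same for a and -a. */
static uint64_t saturating_square_u64(uint64_t a) {
    return (a >> 32) == 0 ? a * a : UINT64_MAX;
}
```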
@@ -59,7 +59,7 @@ NK_INTERNAL void nk_reduce_moments_f32_neon_contiguous_( //
 for (; idx + 4 <= count; idx += 4) {
 float32x4_t data_f32x4 = vld1q_f32(data_ptr + idx);
 float64x2_t data_low_f64x2 = vcvt_f64_f32(vget_low_f32(data_f32x4));
-float64x2_t data_high_f64x2 =
+float64x2_t data_high_f64x2 = vcvt_high_f64_f32(data_f32x4);
 sum_f64x2 = vaddq_f64(sum_f64x2, data_low_f64x2);
 sum_f64x2 = vaddq_f64(sum_f64x2, data_high_f64x2);
 sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_low_f64x2, data_low_f64x2);
@@ -79,10 +79,10 @@ NK_INTERNAL void nk_reduce_moments_f32_neon_strided_( //
 float64x2_t sum_f64x2 = vdupq_n_f64(0), sumsq_f64x2 = vdupq_n_f64(0);
 nk_size_t idx = 0;
 if (stride_elements == 2) {
-for (; idx + 4
+for (; idx + 4 < count; idx += 4) {
 float32x4x2_t loaded_f32x4x2 = vld2q_f32(data_ptr + idx * 2);
 float64x2_t data_low_f64x2 = vcvt_f64_f32(vget_low_f32(loaded_f32x4x2.val[0]));
-float64x2_t data_high_f64x2 =
+float64x2_t data_high_f64x2 = vcvt_high_f64_f32(loaded_f32x4x2.val[0]);
 sum_f64x2 = vaddq_f64(sum_f64x2, data_low_f64x2);
 sum_f64x2 = vaddq_f64(sum_f64x2, data_high_f64x2);
 sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_low_f64x2, data_low_f64x2);
@@ -90,10 +90,10 @@ NK_INTERNAL void nk_reduce_moments_f32_neon_strided_( //
 }
 }
 else if (stride_elements == 3) {
-for (; idx + 4
+for (; idx + 4 < count; idx += 4) {
 float32x4x3_t loaded_f32x4x3 = vld3q_f32(data_ptr + idx * 3);
 float64x2_t data_low_f64x2 = vcvt_f64_f32(vget_low_f32(loaded_f32x4x3.val[0]));
-float64x2_t data_high_f64x2 =
+float64x2_t data_high_f64x2 = vcvt_high_f64_f32(loaded_f32x4x3.val[0]);
 sum_f64x2 = vaddq_f64(sum_f64x2, data_low_f64x2);
 sum_f64x2 = vaddq_f64(sum_f64x2, data_high_f64x2);
 sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_low_f64x2, data_low_f64x2);
@@ -101,10 +101,10 @@ NK_INTERNAL void nk_reduce_moments_f32_neon_strided_( //
 }
 }
 else {
-for (; idx + 4
+for (; idx + 4 < count; idx += 4) {
 float32x4x4_t loaded_f32x4x4 = vld4q_f32(data_ptr + idx * 4);
 float64x2_t data_low_f64x2 = vcvt_f64_f32(vget_low_f32(loaded_f32x4x4.val[0]));
-float64x2_t data_high_f64x2 =
+float64x2_t data_high_f64x2 = vcvt_high_f64_f32(loaded_f32x4x4.val[0]);
 sum_f64x2 = vaddq_f64(sum_f64x2, data_low_f64x2);
 sum_f64x2 = vaddq_f64(sum_f64x2, data_high_f64x2);
 sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_low_f64x2, data_low_f64x2);
@@ -165,7 +165,7 @@ NK_INTERNAL void nk_reduce_minmax_f32_neon_contiguous_( //
 nk_partial_load_b32x4_serial_(data_ptr + idx, &tail_vec, remaining);
 uint32x4_t lane_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
 vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
-uint32x4_t valid_u32x4 = vcltq_u32(lane_u32x4, vdupq_n_u32((
+uint32x4_t valid_u32x4 = vcltq_u32(lane_u32x4, vdupq_n_u32((nk_u32_t)remaining));
 float32x4_t data_min_f32x4 = vbslq_f32(valid_u32x4, tail_vec.f32x4, min_f32x4);
 float32x4_t data_max_f32x4 = vbslq_f32(valid_u32x4, tail_vec.f32x4, max_f32x4);
 uint32x4_t less_u32x4 = vcltq_f32(data_min_f32x4, min_f32x4);
@@ -219,19 +219,19 @@ NK_INTERNAL void nk_reduce_minmax_f32_neon_strided_( //
 float32x4_t data_for_min_f32x4, data_for_max_f32x4;

 nk_reduce_minmax_f32_neon_cycle:
-if (stride_elements == 2 && idx + 4
+if (stride_elements == 2 && idx + 4 < count) {
 float32x4x2_t loaded = vld2q_f32(data_ptr + idx * 2);
 data_for_min_f32x4 = loaded.val[0];
 data_for_max_f32x4 = loaded.val[0];
 idx += 4;
 }
-else if (stride_elements == 3 && idx + 4
+else if (stride_elements == 3 && idx + 4 < count) {
 float32x4x3_t loaded = vld3q_f32(data_ptr + idx * 3);
 data_for_min_f32x4 = loaded.val[0];
 data_for_max_f32x4 = loaded.val[0];
 idx += 4;
 }
-else if (stride_elements == 4 && idx + 4
+else if (stride_elements == 4 && idx + 4 < count) {
 float32x4x4_t loaded = vld4q_f32(data_ptr + idx * 4);
 data_for_min_f32x4 = loaded.val[0];
 data_for_max_f32x4 = loaded.val[0];
@@ -240,7 +240,7 @@ nk_reduce_minmax_f32_neon_cycle:
 else if (idx < count) {
 nk_b128_vec_t tail_vec;
 nk_strided_load_b32x4_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
-uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((
+uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((nk_u32_t)(count - idx)));
 data_for_min_f32x4 = vbslq_f32(valid_u32x4, tail_vec.f32x4, min_f32x4);
 data_for_max_f32x4 = vbslq_f32(valid_u32x4, tail_vec.f32x4, max_f32x4);
 idx = count;
@@ -395,8 +395,8 @@ NK_INTERNAL void nk_reduce_minmax_f64_neon_contiguous_( //
 nk_f64_t *min_value_ptr, nk_size_t *min_index_ptr, //
 nk_f64_t *max_value_ptr, nk_size_t *max_index_ptr) {
 float64x2_t min_f64x2 = vdupq_n_f64(NK_F64_MAX), max_f64x2 = vdupq_n_f64(NK_F64_MIN);
-uint64x2_t
-uint64x2_t
+uint64x2_t min_iter_u64x2 = vdupq_n_u64(0), max_iter_u64x2 = vdupq_n_u64(0);
+uint64x2_t iter_u64x2 = vdupq_n_u64(0), one_u64x2 = vdupq_n_u64(1);
 nk_size_t idx = 0;
 for (; idx + 2 <= count; idx += 2) {
 float64x2_t data_f64x2 = vld1q_f64(data_ptr + idx);
@@ -404,15 +404,15 @@ NK_INTERNAL void nk_reduce_minmax_f64_neon_contiguous_( //
 uint64x2_t greater_u64x2 = vcgtq_f64(data_f64x2, max_f64x2);
 min_f64x2 = vbslq_f64(less_u64x2, data_f64x2, min_f64x2);
 max_f64x2 = vbslq_f64(greater_u64x2, data_f64x2, max_f64x2);
-
-
-
+min_iter_u64x2 = vbslq_u64(less_u64x2, iter_u64x2, min_iter_u64x2);
+max_iter_u64x2 = vbslq_u64(greater_u64x2, iter_u64x2, max_iter_u64x2);
+iter_u64x2 = vaddq_u64(iter_u64x2, one_u64x2);
 }
 nk_b128_vec_t min_values_vec, max_values_vec, min_indices_vec, max_indices_vec;
 min_values_vec.f64x2 = min_f64x2;
-min_indices_vec.u64x2 =
+min_indices_vec.u64x2 = min_iter_u64x2;
 max_values_vec.f64x2 = max_f64x2;
-max_indices_vec.u64x2 =
+max_indices_vec.u64x2 = max_iter_u64x2;
 nk_f64_t min_value, max_value;
 nk_size_t min_index, max_index;
 if (min_values_vec.f64s[0] <= min_values_vec.f64s[1])
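The f64 min/max hunks above start tracking, per SIMD lane, the iteration at which the running minimum or maximum was last replaced: the same comparison mask that selects the new value (vbslq_f64) also selects the current iteration counter (vbslq_u64). A scalar sketch of the idea, with illustrative names only and assuming count >= 1:

```c
#include <stddef.h>

/* Track the position of the minimum alongside its value, mirroring the
 * paired vbslq_f64 / vbslq_u64 updates in the vectorized loop above. */
static void argmin_f64_sketch(double const *data, size_t count, double *min_value_ptr, size_t *min_index_ptr) {
    double min_value = data[0];
    size_t min_index = 0;
    for (size_t idx = 1; idx < count; ++idx)
        if (data[idx] < min_value) min_value = data[idx], min_index = idx;
    *min_value_ptr = min_value;
    *min_index_ptr = min_index;
}
```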
@@ -466,10 +466,10 @@ NK_INTERNAL void nk_reduce_moments_i8_neon_contiguous_( //
 int8x16_t data_i8x16 = vld1q_s8(data_ptr + idx);
 int16x8_t pairwise_i16x8 = vpaddlq_s8(data_i8x16);
 sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
-int16x8_t
-int16x8_t
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
+int16x8_t squares_low_i16x8 = vmull_s8(vget_low_s8(data_i8x16), vget_low_s8(data_i8x16));
+int16x8_t squares_high_i16x8 = vmull_high_s8(data_i8x16, data_i8x16);
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_low_i16x8))));
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_high_i16x8))));
 }
 nk_i64_t sum = vaddlvq_s32(sum_i32x4);
 nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
@@ -487,39 +487,39 @@ NK_INTERNAL void nk_reduce_moments_i8_neon_strided_( //
 uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
 nk_size_t idx = 0;
 if (stride_elements == 2) {
-for (; idx + 16
+for (; idx + 16 < count; idx += 16) {
 int8x16x2_t loaded_i8x16x2 = vld2q_s8(data_ptr + idx * 2);
 int8x16_t data_i8x16 = loaded_i8x16x2.val[0];
 int16x8_t pairwise_i16x8 = vpaddlq_s8(data_i8x16);
 sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
-int16x8_t
-int16x8_t
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
+int16x8_t squares_low_i16x8 = vmull_s8(vget_low_s8(data_i8x16), vget_low_s8(data_i8x16));
+int16x8_t squares_high_i16x8 = vmull_high_s8(data_i8x16, data_i8x16);
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_low_i16x8))));
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_high_i16x8))));
 }
 }
 else if (stride_elements == 3) {
-for (; idx + 16
+for (; idx + 16 < count; idx += 16) {
 int8x16x3_t loaded_i8x16x3 = vld3q_s8(data_ptr + idx * 3);
 int8x16_t data_i8x16 = loaded_i8x16x3.val[0];
 int16x8_t pairwise_i16x8 = vpaddlq_s8(data_i8x16);
 sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
-int16x8_t
-int16x8_t
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
+int16x8_t squares_low_i16x8 = vmull_s8(vget_low_s8(data_i8x16), vget_low_s8(data_i8x16));
+int16x8_t squares_high_i16x8 = vmull_high_s8(data_i8x16, data_i8x16);
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_low_i16x8))));
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_high_i16x8))));
 }
 }
 else {
-for (; idx + 16
+for (; idx + 16 < count; idx += 16) {
 int8x16x4_t loaded_i8x16x4 = vld4q_s8(data_ptr + idx * 4);
 int8x16_t data_i8x16 = loaded_i8x16x4.val[0];
 int16x8_t pairwise_i16x8 = vpaddlq_s8(data_i8x16);
 sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
-int16x8_t
-int16x8_t
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
+int16x8_t squares_low_i16x8 = vmull_s8(vget_low_s8(data_i8x16), vget_low_s8(data_i8x16));
+int16x8_t squares_high_i16x8 = vmull_high_s8(data_i8x16, data_i8x16);
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_low_i16x8))));
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_high_i16x8))));
 }
 }
 nk_i64_t sum = vaddlvq_s32(sum_i32x4);
@@ -578,7 +578,7 @@ NK_INTERNAL void nk_reduce_minmax_i8_neon_contiguous_( //
 nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
 uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
 vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
-uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)remaining));
 int8x16_t data_for_min_i8x16 = vbslq_s8(valid_u8x16, tail_vec.i8x16, vdupq_n_s8(NK_I8_MAX));
 int8x16_t data_for_max_i8x16 = vbslq_s8(valid_u8x16, tail_vec.i8x16, vdupq_n_s8(NK_I8_MIN));
 uint8x16_t less_u8x16 = vcltq_s8(data_for_min_i8x16, min_i8x16);
@@ -624,28 +624,28 @@ NK_INTERNAL void nk_reduce_minmax_i8_neon_strided_( //
 int8x16_t data_for_min_i8x16, data_for_max_i8x16;

 nk_reduce_minmax_i8_neon_cycle:
-if (stride_elements == 2 && idx + 16
+if (stride_elements == 2 && idx + 16 < count) {
 int8x16x2_t loaded = vld2q_s8(data_ptr + idx * 2);
 data_for_min_i8x16 = loaded.val[0];
 data_for_max_i8x16 = loaded.val[0];
 idx += 16;
 }
-else if (stride_elements == 3 && idx + 16
+else if (stride_elements == 3 && idx + 16 < count) {
 int8x16x3_t loaded = vld3q_s8(data_ptr + idx * 3);
 data_for_min_i8x16 = loaded.val[0];
 data_for_max_i8x16 = loaded.val[0];
 idx += 16;
 }
-else if (stride_elements == 4 && idx + 16
-int8x16x4_t
-data_for_min_i8x16 =
-data_for_max_i8x16 =
+else if (stride_elements == 4 && idx + 16 < count) {
+int8x16x4_t loaded_i8x16x4 = vld4q_s8(data_ptr + idx * 4);
+data_for_min_i8x16 = loaded_i8x16x4.val[0];
+data_for_max_i8x16 = loaded_i8x16x4.val[0];
 idx += 16;
 }
 else if (idx < count) {
 nk_b128_vec_t tail_vec;
 nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
-uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)(count - idx)));
 data_for_min_i8x16 = vbslq_s8(valid_u8x16, tail_vec.i8x16, min_i8x16);
 data_for_max_i8x16 = vbslq_s8(valid_u8x16, tail_vec.i8x16, max_i8x16);
 idx = count;
@@ -730,14 +730,14 @@ NK_INTERNAL void nk_reduce_moments_u8_neon_contiguous_( //
 nk_size_t idx = 0;
 for (; idx + 16 <= count; idx += 16) {
 uint8x16_t data_u8x16 = vld1q_u8(data_ptr + idx);
-uint16x8_t
-sum_u32x4 = vaddq_u32(sum_u32x4, vpaddlq_u16(
-uint16x8_t
-uint16x8_t
-uint32x4_t
-uint32x4_t
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
+uint16x8_t pairwise_sum_u16x8 = vpaddlq_u8(data_u8x16);
+sum_u32x4 = vaddq_u32(sum_u32x4, vpaddlq_u16(pairwise_sum_u16x8));
+uint16x8_t squares_low_u16x8 = vmull_u8(vget_low_u8(data_u8x16), vget_low_u8(data_u8x16));
+uint16x8_t squares_high_u16x8 = vmull_high_u8(data_u8x16, data_u8x16);
+uint32x4_t squares_low_u32x4 = vpaddlq_u16(squares_low_u16x8);
+uint32x4_t squares_high_u32x4 = vpaddlq_u16(squares_high_u16x8);
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(squares_low_u32x4));
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(squares_high_u32x4));
 }
 nk_u64_t sum = vaddlvq_u32(sum_u32x4);
 nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
@@ -755,39 +755,39 @@ NK_INTERNAL void nk_reduce_moments_u8_neon_strided_( //
 uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
 nk_size_t idx = 0;
 if (stride_elements == 2) {
-for (; idx + 16
+for (; idx + 16 < count; idx += 16) {
 uint8x16x2_t loaded_u8x16x2 = vld2q_u8(data_ptr + idx * 2);
 uint8x16_t data_u8x16 = loaded_u8x16x2.val[0];
 uint16x8_t pairwise_u16x8 = vpaddlq_u8(data_u8x16);
 sum_u32x4 = vaddq_u32(sum_u32x4, vpaddlq_u16(pairwise_u16x8));
-uint16x8_t
-uint16x8_t
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(
+uint16x8_t squares_low_u16x8 = vmull_u8(vget_low_u8(data_u8x16), vget_low_u8(data_u8x16));
+uint16x8_t squares_high_u16x8 = vmull_high_u8(data_u8x16, data_u8x16);
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_low_u16x8)));
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_high_u16x8)));
 }
 }
 else if (stride_elements == 3) {
-for (; idx + 16
+for (; idx + 16 < count; idx += 16) {
 uint8x16x3_t loaded_u8x16x3 = vld3q_u8(data_ptr + idx * 3);
 uint8x16_t data_u8x16 = loaded_u8x16x3.val[0];
 uint16x8_t pairwise_u16x8 = vpaddlq_u8(data_u8x16);
 sum_u32x4 = vaddq_u32(sum_u32x4, vpaddlq_u16(pairwise_u16x8));
-uint16x8_t
-uint16x8_t
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(
+uint16x8_t squares_low_u16x8 = vmull_u8(vget_low_u8(data_u8x16), vget_low_u8(data_u8x16));
+uint16x8_t squares_high_u16x8 = vmull_high_u8(data_u8x16, data_u8x16);
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_low_u16x8)));
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_high_u16x8)));
 }
 }
 else {
-for (; idx + 16
+for (; idx + 16 < count; idx += 16) {
 uint8x16x4_t loaded_u8x16x4 = vld4q_u8(data_ptr + idx * 4);
 uint8x16_t data_u8x16 = loaded_u8x16x4.val[0];
 uint16x8_t pairwise_u16x8 = vpaddlq_u8(data_u8x16);
 sum_u32x4 = vaddq_u32(sum_u32x4, vpaddlq_u16(pairwise_u16x8));
-uint16x8_t
-uint16x8_t
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(
+uint16x8_t squares_low_u16x8 = vmull_u8(vget_low_u8(data_u8x16), vget_low_u8(data_u8x16));
+uint16x8_t squares_high_u16x8 = vmull_high_u8(data_u8x16, data_u8x16);
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_low_u16x8)));
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_high_u16x8)));
 }
 }
 nk_u64_t sum = vaddlvq_u32(sum_u32x4);
@@ -845,7 +845,7 @@ NK_INTERNAL void nk_reduce_minmax_u8_neon_contiguous_( //
 nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
 uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
 vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
-uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)remaining));
 uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, tail_vec.u8x16, vdupq_n_u8(NK_U8_MAX));
 uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, tail_vec.u8x16, vdupq_n_u8(0));
 uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
@@ -891,28 +891,28 @@ NK_INTERNAL void nk_reduce_minmax_u8_neon_strided_( //
 uint8x16_t data_for_min_u8x16, data_for_max_u8x16;

 nk_reduce_minmax_u8_neon_cycle:
-if (stride_elements == 2 && idx + 16
+if (stride_elements == 2 && idx + 16 < count) {
 uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)data_ptr + idx * 2);
 data_for_min_u8x16 = loaded.val[0];
 data_for_max_u8x16 = loaded.val[0];
 idx += 16;
 }
-else if (stride_elements == 3 && idx + 16
+else if (stride_elements == 3 && idx + 16 < count) {
 uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)data_ptr + idx * 3);
 data_for_min_u8x16 = loaded.val[0];
 data_for_max_u8x16 = loaded.val[0];
 idx += 16;
 }
-else if (stride_elements == 4 && idx + 16
-uint8x16x4_t
-data_for_min_u8x16 =
-data_for_max_u8x16 =
+else if (stride_elements == 4 && idx + 16 < count) {
+uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)data_ptr + idx * 4);
+data_for_min_u8x16 = loaded_u8x16x4.val[0];
+data_for_max_u8x16 = loaded_u8x16x4.val[0];
 idx += 16;
 }
 else if (idx < count) {
 nk_b128_vec_t tail_vec;
 nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
-uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)(count - idx)));
 data_for_min_u8x16 = vbslq_u8(valid_u8x16, tail_vec.u8x16, min_u8x16);
 data_for_max_u8x16 = vbslq_u8(valid_u8x16, tail_vec.u8x16, max_u8x16);
 idx = count;
@@ -996,14 +996,14 @@ NK_INTERNAL void nk_reduce_moments_i16_neon_contiguous_( //
 nk_size_t idx = 0;
 for (; idx + 8 <= count; idx += 8) {
 int16x8_t data_i16x8 = vld1q_s16(data_ptr + idx);
-int32x4_t
-sum_i64x2 = vaddq_s64(sum_i64x2, vpaddlq_s32(
+int32x4_t sum32_i32x4 = vpaddlq_s16(data_i16x8);
+sum_i64x2 = vaddq_s64(sum_i64x2, vpaddlq_s32(sum32_i32x4));
 // sumsq: widening multiply i16*i16 -> i32, then widen to u64
-int32x4_t
-int32x4_t
+int32x4_t sq_low_i32x4 = vmull_s16(vget_low_s16(data_i16x8), vget_low_s16(data_i16x8));
+int32x4_t sq_high_i32x4 = vmull_high_s16(data_i16x8, data_i16x8);
 // i16*i16 squares are always non-negative, safe to reinterpret as u32
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
-sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(sq_low_i32x4)));
+sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(sq_high_i32x4)));
 }
 nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
 nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
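The i16 moments hunk above relies on widening arithmetic: vmull_s16 produces exact 32-bit squares, which are then pairwise-widened into 64-bit accumulators, so no intermediate product can overflow. A scalar sketch of the equivalent arithmetic, with illustrative names only:

```c
#include <stdint.h>
#include <stddef.h>

/* Widen each element before multiplying, then accumulate the sum and the sum
 * of squares in 64-bit integers; squares are non-negative, so they go into an
 * unsigned accumulator, matching the vectorized loop above. */
static void moments_i16_sketch(int16_t const *data, size_t count, int64_t *sum_ptr, uint64_t *sumsq_ptr) {
    int64_t sum = 0;
    uint64_t sumsq = 0;
    for (size_t idx = 0; idx < count; ++idx) {
        int32_t widened = (int32_t)data[idx];
        sum += widened;
        sumsq += (uint64_t)((int64_t)widened * widened);
    }
    *sum_ptr = sum;
    *sumsq_ptr = sumsq;
}
```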
@@ -1022,39 +1022,39 @@ NK_INTERNAL void nk_reduce_moments_i16_neon_strided_( //
|
|
|
1022
1022
|
uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
|
|
1023
1023
|
nk_size_t idx = 0;
|
|
1024
1024
|
if (stride_elements == 2) {
|
|
1025
|
-
for (; idx + 8
|
|
1025
|
+
for (; idx + 8 < count; idx += 8) {
|
|
1026
1026
|
int16x8x2_t loaded_i16x8x2 = vld2q_s16(data_ptr + idx * 2);
|
|
1027
1027
|
int16x8_t data_i16x8 = loaded_i16x8x2.val[0];
|
|
1028
1028
|
int32x4_t pairwise_i32x4 = vpaddlq_s16(data_i16x8);
|
|
1029
1029
|
sum_i64x2 = vaddq_s64(sum_i64x2, vpaddlq_s32(pairwise_i32x4));
|
|
1030
|
-
int32x4_t
|
|
1031
|
-
int32x4_t
|
|
1032
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
|
|
1033
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
|
|
1030
|
+
int32x4_t squares_low_i32x4 = vmull_s16(vget_low_s16(data_i16x8), vget_low_s16(data_i16x8));
|
|
1031
|
+
int32x4_t squares_high_i32x4 = vmull_high_s16(data_i16x8, data_i16x8);
|
|
1032
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_i32x4)));
|
|
1033
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_i32x4)));
|
|
1034
1034
|
}
|
|
1035
1035
|
}
|
|
1036
1036
|
else if (stride_elements == 3) {
|
|
1037
|
-
for (; idx + 8
|
|
1037
|
+
for (; idx + 8 < count; idx += 8) {
|
|
1038
1038
|
int16x8x3_t loaded_i16x8x3 = vld3q_s16(data_ptr + idx * 3);
|
|
1039
1039
|
int16x8_t data_i16x8 = loaded_i16x8x3.val[0];
|
|
1040
1040
|
int32x4_t pairwise_i32x4 = vpaddlq_s16(data_i16x8);
|
|
1041
1041
|
sum_i64x2 = vaddq_s64(sum_i64x2, vpaddlq_s32(pairwise_i32x4));
|
|
1042
|
-
int32x4_t
|
|
1043
|
-
int32x4_t
|
|
1044
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
|
|
1045
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
|
|
1042
|
+
int32x4_t squares_low_i32x4 = vmull_s16(vget_low_s16(data_i16x8), vget_low_s16(data_i16x8));
|
|
1043
|
+
int32x4_t squares_high_i32x4 = vmull_high_s16(data_i16x8, data_i16x8);
|
|
1044
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_i32x4)));
|
|
1045
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_i32x4)));
|
|
1046
1046
|
}
|
|
1047
1047
|
}
|
|
1048
1048
|
else {
|
|
1049
|
-
for (; idx + 8
|
|
1049
|
+
for (; idx + 8 < count; idx += 8) {
|
|
1050
1050
|
int16x8x4_t loaded_i16x8x4 = vld4q_s16(data_ptr + idx * 4);
|
|
1051
1051
|
int16x8_t data_i16x8 = loaded_i16x8x4.val[0];
|
|
1052
1052
|
int32x4_t pairwise_i32x4 = vpaddlq_s16(data_i16x8);
|
|
1053
1053
|
sum_i64x2 = vaddq_s64(sum_i64x2, vpaddlq_s32(pairwise_i32x4));
|
|
1054
|
-
int32x4_t
|
|
1055
|
-
int32x4_t
|
|
1056
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
|
|
1057
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
|
|
1054
|
+
int32x4_t squares_low_i32x4 = vmull_s16(vget_low_s16(data_i16x8), vget_low_s16(data_i16x8));
|
|
1055
|
+
int32x4_t squares_high_i32x4 = vmull_high_s16(data_i16x8, data_i16x8);
|
|
1056
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_i32x4)));
|
|
1057
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_i32x4)));
|
|
1058
1058
|
}
|
|
1059
1059
|
}
|
|
1060
1060
|
nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
|
|
@@ -1113,7 +1113,7 @@ NK_INTERNAL void nk_reduce_minmax_i16_neon_contiguous_( //
|
|
|
1113
1113
|
nk_partial_load_b16x8_serial_(data_ptr + idx, &tail_vec, remaining);
|
|
1114
1114
|
uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
|
|
1115
1115
|
vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
|
|
1116
|
-
uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((
|
|
1116
|
+
uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((nk_u16_t)remaining));
|
|
1117
1117
|
int16x8_t data_for_min_i16x8 = vbslq_s16(valid_u16x8, tail_vec.i16x8, vdupq_n_s16(NK_I16_MAX));
|
|
1118
1118
|
int16x8_t data_for_max_i16x8 = vbslq_s16(valid_u16x8, tail_vec.i16x8, vdupq_n_s16(NK_I16_MIN));
|
|
1119
1119
|
uint16x8_t less_u16x8 = vcltq_s16(data_for_min_i16x8, min_i16x8);
|
|
@@ -1159,19 +1159,19 @@ NK_INTERNAL void nk_reduce_minmax_i16_neon_strided_( //
|
|
|
1159
1159
|
int16x8_t data_for_min_i16x8, data_for_max_i16x8;
|
|
1160
1160
|
|
|
1161
1161
|
nk_reduce_minmax_i16_neon_cycle:
|
|
1162
|
-
if (stride_elements == 2 && idx + 8
|
|
1162
|
+
if (stride_elements == 2 && idx + 8 < count) {
|
|
1163
1163
|
int16x8x2_t loaded = vld2q_s16(data_ptr + idx * 2);
|
|
1164
1164
|
data_for_min_i16x8 = loaded.val[0];
|
|
1165
1165
|
data_for_max_i16x8 = loaded.val[0];
|
|
1166
1166
|
idx += 8;
|
|
1167
1167
|
}
|
|
1168
|
-
else if (stride_elements == 3 && idx + 8
|
|
1168
|
+
else if (stride_elements == 3 && idx + 8 < count) {
|
|
1169
1169
|
int16x8x3_t loaded = vld3q_s16(data_ptr + idx * 3);
|
|
1170
1170
|
data_for_min_i16x8 = loaded.val[0];
|
|
1171
1171
|
data_for_max_i16x8 = loaded.val[0];
|
|
1172
1172
|
idx += 8;
|
|
1173
1173
|
}
|
|
1174
|
-
else if (stride_elements == 4 && idx + 8
|
|
1174
|
+
else if (stride_elements == 4 && idx + 8 < count) {
|
|
1175
1175
|
int16x8x4_t loaded = vld4q_s16(data_ptr + idx * 4);
|
|
1176
1176
|
data_for_min_i16x8 = loaded.val[0];
|
|
1177
1177
|
data_for_max_i16x8 = loaded.val[0];
|
|
@@ -1180,7 +1180,7 @@ nk_reduce_minmax_i16_neon_cycle:
|
|
|
1180
1180
|
else if (idx < count) {
|
|
1181
1181
|
nk_b128_vec_t tail_vec;
|
|
1182
1182
|
nk_strided_load_b16x8_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
|
|
1183
|
-
uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((
|
|
1183
|
+
uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((nk_u16_t)(count - idx)));
|
|
1184
1184
|
data_for_min_i16x8 = vbslq_s16(valid_u16x8, tail_vec.i16x8, min_i16x8);
|
|
1185
1185
|
data_for_max_i16x8 = vbslq_s16(valid_u16x8, tail_vec.i16x8, max_i16x8);
|
|
1186
1186
|
idx = count;
|
|
@@ -1265,12 +1265,12 @@ NK_INTERNAL void nk_reduce_moments_u16_neon_contiguous_( //
|
|
|
1265
1265
|
nk_size_t idx = 0;
|
|
1266
1266
|
for (; idx + 8 <= count; idx += 8) {
|
|
1267
1267
|
uint16x8_t data_u16x8 = vld1q_u16(data_ptr + idx);
|
|
1268
|
-
uint32x4_t
|
|
1269
|
-
sum_u64x2 = vaddq_u64(sum_u64x2, vpaddlq_u32(
|
|
1270
|
-
uint32x4_t
|
|
1271
|
-
uint32x4_t
|
|
1272
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
|
|
1273
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
|
|
1268
|
+
uint32x4_t sum32_u32x4 = vpaddlq_u16(data_u16x8);
|
|
1269
|
+
sum_u64x2 = vaddq_u64(sum_u64x2, vpaddlq_u32(sum32_u32x4));
|
|
1270
|
+
uint32x4_t sq_low_u32x4 = vmull_u16(vget_low_u16(data_u16x8), vget_low_u16(data_u16x8));
|
|
1271
|
+
uint32x4_t sq_high_u32x4 = vmull_high_u16(data_u16x8, data_u16x8);
|
|
1272
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_low_u32x4));
|
|
1273
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_high_u32x4));
|
|
1274
1274
|
}
|
|
1275
1275
|
nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
|
|
1276
1276
|
nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
|
|
@@ -1289,39 +1289,39 @@ NK_INTERNAL void nk_reduce_moments_u16_neon_strided_( //
|
|
|
1289
1289
|
nk_size_t idx = 0;
|
|
1290
1290
|
|
|
1291
1291
|
if (stride_elements == 2) {
|
|
1292
|
-
for (; idx + 8
|
|
1292
|
+
for (; idx + 8 < count; idx += 8) {
|
|
1293
1293
|
uint16x8x2_t loaded_u16x8x2 = vld2q_u16(data_ptr + idx * 2);
|
|
1294
1294
|
uint16x8_t data_u16x8 = loaded_u16x8x2.val[0];
|
|
1295
1295
|
uint32x4_t widened_sum_u32x4 = vpaddlq_u16(data_u16x8);
|
|
1296
1296
|
sum_u64x2 = vaddq_u64(sum_u64x2, vpaddlq_u32(widened_sum_u32x4));
|
|
1297
|
-
uint32x4_t
|
|
1298
|
-
uint32x4_t
|
|
1299
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
|
|
1300
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
|
|
1297
|
+
uint32x4_t sq_low_u32x4 = vmull_u16(vget_low_u16(data_u16x8), vget_low_u16(data_u16x8));
|
|
1298
|
+
uint32x4_t sq_high_u32x4 = vmull_high_u16(data_u16x8, data_u16x8);
|
|
1299
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_low_u32x4));
|
|
1300
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_high_u32x4));
|
|
1301
1301
|
}
|
|
1302
1302
|
}
|
|
1303
1303
|
else if (stride_elements == 3) {
|
|
1304
|
-
for (; idx + 8
|
|
1304
|
+
for (; idx + 8 < count; idx += 8) {
|
|
1305
1305
|
uint16x8x3_t loaded_u16x8x3 = vld3q_u16(data_ptr + idx * 3);
|
|
1306
1306
|
uint16x8_t data_u16x8 = loaded_u16x8x3.val[0];
|
|
1307
1307
|
uint32x4_t widened_sum_u32x4 = vpaddlq_u16(data_u16x8);
|
|
1308
1308
|
sum_u64x2 = vaddq_u64(sum_u64x2, vpaddlq_u32(widened_sum_u32x4));
|
|
1309
|
-
uint32x4_t
|
|
1310
|
-
uint32x4_t
|
|
1311
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
|
|
1312
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
|
|
1309
|
+
uint32x4_t sq_low_u32x4 = vmull_u16(vget_low_u16(data_u16x8), vget_low_u16(data_u16x8));
|
|
1310
|
+
uint32x4_t sq_high_u32x4 = vmull_high_u16(data_u16x8, data_u16x8);
|
|
1311
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_low_u32x4));
|
|
1312
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_high_u32x4));
|
|
1313
1313
|
}
|
|
1314
1314
|
}
|
|
1315
1315
|
else {
|
|
1316
|
-
for (; idx + 8
|
|
1316
|
+
for (; idx + 8 < count; idx += 8) {
|
|
1317
1317
|
uint16x8x4_t loaded_u16x8x4 = vld4q_u16(data_ptr + idx * 4);
|
|
1318
1318
|
uint16x8_t data_u16x8 = loaded_u16x8x4.val[0];
|
|
1319
1319
|
uint32x4_t widened_sum_u32x4 = vpaddlq_u16(data_u16x8);
|
|
1320
1320
|
sum_u64x2 = vaddq_u64(sum_u64x2, vpaddlq_u32(widened_sum_u32x4));
|
|
1321
|
-
uint32x4_t
|
|
1322
|
-
uint32x4_t
|
|
1323
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
|
|
1324
|
-
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(
|
|
1321
|
+
uint32x4_t sq_low_u32x4 = vmull_u16(vget_low_u16(data_u16x8), vget_low_u16(data_u16x8));
|
|
1322
|
+
uint32x4_t sq_high_u32x4 = vmull_high_u16(data_u16x8, data_u16x8);
|
|
1323
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_low_u32x4));
|
|
1324
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_high_u32x4));
|
|
1325
1325
|
}
|
|
1326
1326
|
}
|
|
1327
1327
|
|
|
@@ -1380,7 +1380,7 @@ NK_INTERNAL void nk_reduce_minmax_u16_neon_contiguous_( //
|
|
|
1380
1380
|
nk_partial_load_b16x8_serial_(data_ptr + idx, &tail_vec, remaining);
|
|
1381
1381
|
uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
|
|
1382
1382
|
vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
|
|
1383
|
-
uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((
|
|
1383
|
+
uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((nk_u16_t)remaining));
|
|
1384
1384
|
uint16x8_t data_for_min_u16x8 = vbslq_u16(valid_u16x8, tail_vec.u16x8, vdupq_n_u16(NK_U16_MAX));
|
|
1385
1385
|
uint16x8_t data_for_max_u16x8 = vbslq_u16(valid_u16x8, tail_vec.u16x8, vdupq_n_u16(0));
|
|
1386
1386
|
uint16x8_t less_u16x8 = vcltq_u16(data_for_min_u16x8, min_u16x8);
|
|
@@ -1426,19 +1426,19 @@ NK_INTERNAL void nk_reduce_minmax_u16_neon_strided_( //
|
|
|
1426
1426
|
uint16x8_t data_for_min_u16x8, data_for_max_u16x8;
|
|
1427
1427
|
|
|
1428
1428
|
nk_reduce_minmax_u16_neon_cycle:
|
|
1429
|
-
if (stride_elements == 2 && idx + 8
|
|
1429
|
+
if (stride_elements == 2 && idx + 8 < count) {
|
|
1430
1430
|
uint16x8x2_t loaded = vld2q_u16((nk_u16_t const *)data_ptr + idx * 2);
|
|
1431
1431
|
data_for_min_u16x8 = loaded.val[0];
|
|
1432
1432
|
data_for_max_u16x8 = loaded.val[0];
|
|
1433
1433
|
idx += 8;
|
|
1434
1434
|
}
|
|
1435
|
-
else if (stride_elements == 3 && idx + 8
|
|
1435
|
+
else if (stride_elements == 3 && idx + 8 < count) {
|
|
1436
1436
|
uint16x8x3_t loaded = vld3q_u16((nk_u16_t const *)data_ptr + idx * 3);
|
|
1437
1437
|
data_for_min_u16x8 = loaded.val[0];
|
|
1438
1438
|
data_for_max_u16x8 = loaded.val[0];
|
|
1439
1439
|
idx += 8;
|
|
1440
1440
|
}
|
|
1441
|
-
else if (stride_elements == 4 && idx + 8
|
|
1441
|
+
else if (stride_elements == 4 && idx + 8 < count) {
|
|
1442
1442
|
uint16x8x4_t loaded = vld4q_u16((nk_u16_t const *)data_ptr + idx * 4);
|
|
1443
1443
|
data_for_min_u16x8 = loaded.val[0];
|
|
1444
1444
|
data_for_max_u16x8 = loaded.val[0];
|
|
@@ -1447,7 +1447,7 @@ nk_reduce_minmax_u16_neon_cycle:
|
|
|
1447
1447
|
else if (idx < count) {
|
|
1448
1448
|
nk_b128_vec_t tail_vec;
|
|
1449
1449
|
nk_strided_load_b16x8_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
|
|
1450
|
-
uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((
|
|
1450
|
+
uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((nk_u16_t)(count - idx)));
|
|
1451
1451
|
data_for_min_u16x8 = vbslq_u16(valid_u16x8, tail_vec.u16x8, min_u16x8);
|
|
1452
1452
|
data_for_max_u16x8 = vbslq_u16(valid_u16x8, tail_vec.u16x8, max_u16x8);
|
|
1453
1453
|
idx = count;
|
|
@@ -1527,8 +1527,8 @@ NK_INTERNAL void nk_reduce_moments_i32_neon_contiguous_( //
|
|
|
1527
1527
|
nk_i32_t const *data_ptr, nk_size_t count, //
|
|
1528
1528
|
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
1529
1529
|
// 128-bit accumulation: lower (u64) + upper (i64) per lane
|
|
1530
|
-
uint64x2_t
|
|
1531
|
-
int64x2_t
|
|
1530
|
+
uint64x2_t sum_low_u64x2 = vdupq_n_u64(0);
|
|
1531
|
+
int64x2_t sum_high_i64x2 = vdupq_n_s64(0);
|
|
1532
1532
|
uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
|
|
1533
1533
|
int sumsq_overflow = 0;
|
|
1534
1534
|
// XOR sign-bit trick for unsigned u64 compare on NEON
|
|
@@ -1537,39 +1537,41 @@ NK_INTERNAL void nk_reduce_moments_i32_neon_contiguous_( //
 for (; idx + 4 <= count; idx += 4) {
 int32x4_t data_i32x4 = vld1q_s32(data_ptr + idx);
 // Sum: widen i32->i64 and accumulate with carry detection
- int64x2_t
- uint64x2_t
-
- int64x2_t
- int64x2_t
- uint64x2_t
-
-
-
- int64x2_t
-
-
-
-
-
-
-
+ int64x2_t data_low_i64x2 = vmovl_s32(vget_low_s32(data_i32x4));
+ uint64x2_t before_u64x2 = sum_low_u64x2;
+ sum_low_u64x2 = vaddq_u64(sum_low_u64x2, vreinterpretq_u64_s64(data_low_i64x2));
+ int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_low_u64x2), sign_bit_i64x2);
+ int64x2_t before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
+ uint64x2_t carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
+ sum_high_i64x2 = vsubq_s64(sum_high_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_high_i64x2 = vaddq_s64(sum_high_i64x2, vshrq_n_s64(data_low_i64x2, 63));
+
+ int64x2_t data_high_i64x2 = vmovl_high_s32(data_i32x4);
+ before_u64x2 = sum_low_u64x2;
+ sum_low_u64x2 = vaddq_u64(sum_low_u64x2, vreinterpretq_u64_s64(data_high_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_low_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
+ carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
+ sum_high_i64x2 = vsubq_s64(sum_high_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_high_i64x2 = vaddq_s64(sum_high_i64x2, vshrq_n_s64(data_high_i64x2, 63));

 // Sumsq: widening multiply i32*i32 -> i64 (always non-negative for squares)
- int64x2_t
- int64x2_t
- uint64x2_t
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(
-
-
- sumsq_overflow |=
-
-
-
-
-
-
-
+ int64x2_t squares_low_i64x2 = vmull_s32(vget_low_s32(data_i32x4), vget_low_s32(data_i32x4));
+ int64x2_t squares_high_i64x2 = vmull_high_s32(data_i32x4, data_i32x4);
+ uint64x2_t sq_before_u64x2 = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_low_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
+ sumsq_overflow |=
+ (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
+ sq_before_u64x2 = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_high_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
+ sumsq_overflow |=
+ (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
 }
 // Sumsq horizontal saturating reduction
 nk_u64_t sumsq;
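Each vector lane above keeps a 128-bit signed accumulator split into a low u64 and a high i64 half. A scalar model of what one lane does per addend, mirroring the scalar tail of this same function; names are illustrative only:

```c
#include <stdint.h>

// Add a signed 64-bit value into a 128-bit accumulator kept as (high:int64, low:uint64).
// The low half takes the raw 64-bit add; the high half absorbs the carry out of the
// low half plus the sign extension (value >> 63) of the addend.
static inline void acc128_add_i64(int64_t *high, uint64_t *low, int64_t value) {
    uint64_t before = *low;
    *low += (uint64_t)value;
    *high += (*low < before); // carry out of the low 64 bits
    *high += (value >> 63);   // -1 for negative addends, 0 otherwise
}
```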
@@ -1577,29 +1579,29 @@ NK_INTERNAL void nk_reduce_moments_i32_neon_contiguous_( //
 else sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
 // Sum: horizontal 128-bit reduction (2 lanes -> scalar)
 nk_b128_vec_t lower_vec, upper_vec;
- lower_vec.u64x2 =
- upper_vec.i64x2 =
- nk_u64_t
- nk_i64_t
- nk_u64_t sum_before =
-
- sum_before =
-
+ lower_vec.u64x2 = sum_low_u64x2;
+ upper_vec.i64x2 = sum_high_i64x2;
+ nk_u64_t sum_low = 0;
+ nk_i64_t sum_high = 0;
+ nk_u64_t sum_before = sum_low;
+ sum_low += lower_vec.u64s[0], sum_high += (sum_low < sum_before) + upper_vec.i64s[0];
+ sum_before = sum_low;
+ sum_low += lower_vec.u64s[1], sum_high += (sum_low < sum_before) + upper_vec.i64s[1];
 // Scalar tail
 for (; idx < count; ++idx) {
 nk_i64_t value_i64 = (nk_i64_t)data_ptr[idx];
- sum_before =
-
- if (
-
+ sum_before = sum_low;
+ sum_low += (nk_u64_t)value_i64;
+ if (sum_low < sum_before) sum_high++;
+ sum_high += (value_i64 >> 63);
 nk_i64_t product = nk_i64_saturating_mul_serial(value_i64, value_i64);
 nk_u64_t unsigned_product = (nk_u64_t)product;
 sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
 }
 // Clamp 128-bit sum to i64 range
- nk_i64_t
- if (
- else if (
+ nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
+ if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
+ else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
 else *sum_ptr = NK_I64_MIN;
 *sumsq_ptr = sumsq;
 }
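The final clamp above collapses the 128-bit sum back to a saturating i64 result. A standalone sketch of the same condition, using standard C types rather than numkong's `nk_*` aliases:

```c
#include <stdint.h>

// Clamp a 128-bit signed sum (high:int64, low:uint64) to the int64 range.
// The value fits in int64 exactly when the high half equals the sign extension
// of the low half; otherwise the sign of the high half picks the saturation bound.
static inline int64_t clamp128_to_i64(int64_t high, uint64_t low) {
    int64_t low_signed = (int64_t)low;
    if (high == (low_signed >> 63)) return low_signed;
    return high >= 0 ? INT64_MAX : INT64_MIN;
}
```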
@@ -1607,43 +1609,43 @@ NK_INTERNAL void nk_reduce_moments_i32_neon_contiguous_( //
 NK_INTERNAL void nk_reduce_moments_i32_neon_strided_( //
 nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
 nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
- uint64x2_t
- int64x2_t
+ uint64x2_t sum_low_u64x2 = vdupq_n_u64(0);
+ int64x2_t sum_high_i64x2 = vdupq_n_s64(0);
 uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
 int sumsq_overflow = 0;
 int64x2_t sign_bit_i64x2 = vdupq_n_s64((nk_i64_t)0x8000000000000000ULL);
 nk_size_t idx = 0;
 if (stride_elements == 2) {
- for (; idx + 4
+ for (; idx + 4 < count; idx += 4) {
 int32x4x2_t loaded_i32x4x2 = vld2q_s32(data_ptr + idx * 2);
 int32x4_t data_i32x4 = loaded_i32x4x2.val[0];
- int64x2_t
- uint64x2_t before_u64x2 =
-
- int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(
+ int64x2_t low_i64x2 = vmovl_s32(vget_low_s32(data_i32x4));
+ uint64x2_t before_u64x2 = sum_low_u64x2;
+ sum_low_u64x2 = vaddq_u64(sum_low_u64x2, vreinterpretq_u64_s64(low_i64x2));
+ int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_low_u64x2), sign_bit_i64x2);
 int64x2_t before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
 uint64x2_t carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
-
-
- int64x2_t
- before_u64x2 =
-
- result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(
+ sum_high_i64x2 = vsubq_s64(sum_high_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_high_i64x2 = vaddq_s64(sum_high_i64x2, vshrq_n_s64(low_i64x2, 63));
+ int64x2_t high_i64x2 = vmovl_high_s32(data_i32x4);
+ before_u64x2 = sum_low_u64x2;
+ sum_low_u64x2 = vaddq_u64(sum_low_u64x2, vreinterpretq_u64_s64(high_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_low_u64x2), sign_bit_i64x2);
 before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
 carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
-
-
- int64x2_t
- int64x2_t
+ sum_high_i64x2 = vsubq_s64(sum_high_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_high_i64x2 = vaddq_s64(sum_high_i64x2, vshrq_n_s64(high_i64x2, 63));
+ int64x2_t squares_low_i64x2 = vmull_s32(vget_low_s32(data_i32x4), vget_low_s32(data_i32x4));
+ int64x2_t squares_high_i64x2 = vmull_high_s32(data_i32x4, data_i32x4);
 uint64x2_t sq_before_u64x2 = sumsq_u64x2;
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_low_i64x2));
 result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
 before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
 sumsq_overflow |=
 (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
 vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
 sq_before_u64x2 = sumsq_u64x2;
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_high_i64x2));
 result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
 before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
 sumsq_overflow |=
@@ -1652,36 +1654,36 @@ NK_INTERNAL void nk_reduce_moments_i32_neon_strided_( //
 }
 }
 else if (stride_elements == 3) {
- for (; idx + 4
+ for (; idx + 4 < count; idx += 4) {
 int32x4x3_t loaded_i32x4x3 = vld3q_s32(data_ptr + idx * 3);
 int32x4_t data_i32x4 = loaded_i32x4x3.val[0];
- int64x2_t
- uint64x2_t before_u64x2 =
-
- int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(
+ int64x2_t low_i64x2 = vmovl_s32(vget_low_s32(data_i32x4));
+ uint64x2_t before_u64x2 = sum_low_u64x2;
+ sum_low_u64x2 = vaddq_u64(sum_low_u64x2, vreinterpretq_u64_s64(low_i64x2));
+ int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_low_u64x2), sign_bit_i64x2);
 int64x2_t before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
 uint64x2_t carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
-
-
- int64x2_t
- before_u64x2 =
-
- result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(
+ sum_high_i64x2 = vsubq_s64(sum_high_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_high_i64x2 = vaddq_s64(sum_high_i64x2, vshrq_n_s64(low_i64x2, 63));
+ int64x2_t high_i64x2 = vmovl_high_s32(data_i32x4);
+ before_u64x2 = sum_low_u64x2;
+ sum_low_u64x2 = vaddq_u64(sum_low_u64x2, vreinterpretq_u64_s64(high_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_low_u64x2), sign_bit_i64x2);
 before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
 carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
-
-
- int64x2_t
- int64x2_t
+ sum_high_i64x2 = vsubq_s64(sum_high_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_high_i64x2 = vaddq_s64(sum_high_i64x2, vshrq_n_s64(high_i64x2, 63));
+ int64x2_t squares_low_i64x2 = vmull_s32(vget_low_s32(data_i32x4), vget_low_s32(data_i32x4));
+ int64x2_t squares_high_i64x2 = vmull_high_s32(data_i32x4, data_i32x4);
 uint64x2_t sq_before_u64x2 = sumsq_u64x2;
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_low_i64x2));
 result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
 before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
 sumsq_overflow |=
 (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
 vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
 sq_before_u64x2 = sumsq_u64x2;
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_high_i64x2));
 result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
 before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
 sumsq_overflow |=
@@ -1690,36 +1692,36 @@ NK_INTERNAL void nk_reduce_moments_i32_neon_strided_( //
 }
 }
 else {
- for (; idx + 4
+ for (; idx + 4 < count; idx += 4) {
 int32x4x4_t loaded_i32x4x4 = vld4q_s32(data_ptr + idx * 4);
 int32x4_t data_i32x4 = loaded_i32x4x4.val[0];
- int64x2_t
- uint64x2_t before_u64x2 =
-
- int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(
+ int64x2_t low_i64x2 = vmovl_s32(vget_low_s32(data_i32x4));
+ uint64x2_t before_u64x2 = sum_low_u64x2;
+ sum_low_u64x2 = vaddq_u64(sum_low_u64x2, vreinterpretq_u64_s64(low_i64x2));
+ int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_low_u64x2), sign_bit_i64x2);
 int64x2_t before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
 uint64x2_t carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
-
-
- int64x2_t
- before_u64x2 =
-
- result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(
+ sum_high_i64x2 = vsubq_s64(sum_high_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_high_i64x2 = vaddq_s64(sum_high_i64x2, vshrq_n_s64(low_i64x2, 63));
+ int64x2_t high_i64x2 = vmovl_high_s32(data_i32x4);
+ before_u64x2 = sum_low_u64x2;
+ sum_low_u64x2 = vaddq_u64(sum_low_u64x2, vreinterpretq_u64_s64(high_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_low_u64x2), sign_bit_i64x2);
 before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
 carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
-
-
- int64x2_t
- int64x2_t
+ sum_high_i64x2 = vsubq_s64(sum_high_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_high_i64x2 = vaddq_s64(sum_high_i64x2, vshrq_n_s64(high_i64x2, 63));
+ int64x2_t squares_low_i64x2 = vmull_s32(vget_low_s32(data_i32x4), vget_low_s32(data_i32x4));
+ int64x2_t squares_high_i64x2 = vmull_high_s32(data_i32x4, data_i32x4);
 uint64x2_t sq_before_u64x2 = sumsq_u64x2;
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_low_i64x2));
 result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
 before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
 sumsq_overflow |=
 (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
 vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
 sq_before_u64x2 = sumsq_u64x2;
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_high_i64x2));
 result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
 before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
 sumsq_overflow |=
@@ -1731,27 +1733,27 @@ NK_INTERNAL void nk_reduce_moments_i32_neon_strided_( //
 if (sumsq_overflow) sumsq = NK_U64_MAX;
 else sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
 nk_b128_vec_t lower_vec, upper_vec;
- lower_vec.u64x2 =
- upper_vec.i64x2 =
- nk_u64_t
- nk_i64_t
- nk_u64_t sum_before =
-
- sum_before =
-
+ lower_vec.u64x2 = sum_low_u64x2;
+ upper_vec.i64x2 = sum_high_i64x2;
+ nk_u64_t sum_low = 0;
+ nk_i64_t sum_high = 0;
+ nk_u64_t sum_before = sum_low;
+ sum_low += lower_vec.u64s[0], sum_high += (sum_low < sum_before) + upper_vec.i64s[0];
+ sum_before = sum_low;
+ sum_low += lower_vec.u64s[1], sum_high += (sum_low < sum_before) + upper_vec.i64s[1];
 for (; idx < count; ++idx) {
 nk_i64_t val = (nk_i64_t) * (data_ptr + idx * stride_elements);
- sum_before =
-
- if (
-
+ sum_before = sum_low;
+ sum_low += (nk_u64_t)val;
+ if (sum_low < sum_before) sum_high++;
+ sum_high += (val >> 63);
 nk_i64_t product = nk_i64_saturating_mul_serial(val, val);
 nk_u64_t unsigned_product = (nk_u64_t)product;
 sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
 }
- nk_i64_t
- if (
- else if (
+ nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
+ if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
+ else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
 else *sum_ptr = NK_I64_MIN;
 *sumsq_ptr = sumsq;
 }
@@ -1793,7 +1795,7 @@ NK_INTERNAL void nk_reduce_minmax_i32_neon_contiguous_( //
 nk_partial_load_b32x4_serial_(data_ptr + idx, &tail_vec, remaining);
 uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
 vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
- uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((
+ uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((nk_u32_t)remaining));
 int32x4_t data_min_i32x4 = vbslq_s32(valid_u32x4, tail_vec.i32x4, vdupq_n_s32(NK_I32_MAX));
 int32x4_t data_max_i32x4 = vbslq_s32(valid_u32x4, tail_vec.i32x4, vdupq_n_s32(NK_I32_MIN));
 uint32x4_t less_u32x4 = vcltq_s32(data_min_i32x4, min_i32x4);
@@ -1839,19 +1841,19 @@ NK_INTERNAL void nk_reduce_minmax_i32_neon_strided_( //
 int32x4_t data_for_min_i32x4, data_for_max_i32x4;

 nk_reduce_minmax_i32_neon_cycle:
- if (stride_elements == 2 && idx + 4
+ if (stride_elements == 2 && idx + 4 < count) {
 int32x4x2_t loaded = vld2q_s32(data_ptr + idx * 2);
 data_for_min_i32x4 = loaded.val[0];
 data_for_max_i32x4 = loaded.val[0];
 idx += 4;
 }
- else if (stride_elements == 3 && idx + 4
+ else if (stride_elements == 3 && idx + 4 < count) {
 int32x4x3_t loaded = vld3q_s32(data_ptr + idx * 3);
 data_for_min_i32x4 = loaded.val[0];
 data_for_max_i32x4 = loaded.val[0];
 idx += 4;
 }
- else if (stride_elements == 4 && idx + 4
+ else if (stride_elements == 4 && idx + 4 < count) {
 int32x4x4_t loaded = vld4q_s32(data_ptr + idx * 4);
 data_for_min_i32x4 = loaded.val[0];
 data_for_max_i32x4 = loaded.val[0];
@@ -1860,7 +1862,7 @@ nk_reduce_minmax_i32_neon_cycle:
 else if (idx < count) {
 nk_b128_vec_t tail_vec;
 nk_strided_load_b32x4_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
- uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((
+ uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((nk_u32_t)(count - idx)));
 data_for_min_i32x4 = vbslq_s32(valid_u32x4, tail_vec.i32x4, min_i32x4);
 data_for_max_i32x4 = vbslq_s32(valid_u32x4, tail_vec.i32x4, max_i32x4);
 idx = count;
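The tail handling in the minmax hunks above masks the lanes past the end of the data before running a full-width compare. A minimal sketch of that masking step, assuming AArch64 and `<arm_neon.h>`; the wrapper name and `remaining` parameter are illustrative, not numkong API:

```c
#include <arm_neon.h>
#include <stdint.h>

// Lanes whose index is >= remaining are replaced with INT32_MAX so they can
// never win a subsequent minimum; the symmetric max path would use INT32_MIN.
static inline int32x4_t mask_tail_for_min_i32x4(int32x4_t data, uint32_t remaining) {
    uint32x4_t lane_indices = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
                                           vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
    uint32x4_t valid = vcltq_u32(lane_indices, vdupq_n_u32(remaining));
    return vbslq_s32(valid, data, vdupq_n_s32(INT32_MAX));
}
```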
@@ -1951,10 +1953,10 @@ NK_INTERNAL void nk_reduce_moments_u32_neon_contiguous_( //
 sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_u32(vget_low_u32(data_u32x4)));
 sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_high_u32(data_u32x4));
 // Sumsq: widening multiply u32*u32 -> u64, saturating add
- uint64x2_t
- uint64x2_t
- sumsq_u64x2 = vqaddq_u64(sumsq_u64x2,
- sumsq_u64x2 = vqaddq_u64(sumsq_u64x2,
+ uint64x2_t sq_low_u64x2 = vmull_u32(vget_low_u32(data_u32x4), vget_low_u32(data_u32x4));
+ uint64x2_t sq_high_u64x2 = vmull_high_u32(data_u32x4, data_u32x4);
+ sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, sq_low_u64x2);
+ sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, sq_high_u64x2);
 }
 nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
 nk_u64_t sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
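The unsigned moments kernels above lean on the saturating `vqaddq_u64`, so an overflowing sum of squares pins at `UINT64_MAX` instead of wrapping. A minimal sketch of that accumulation step, with an illustrative helper name, assuming AArch64 and `<arm_neon.h>`:

```c
#include <arm_neon.h>

// Fold the squares of four u32 lanes into a u64x2 accumulator: widening
// 32x32->64 multiplies for the low and high halves, then saturating adds.
static inline uint64x2_t accumulate_sq_u32x4(uint64x2_t acc, uint32x4_t data) {
    uint64x2_t sq_low = vmull_u32(vget_low_u32(data), vget_low_u32(data));
    uint64x2_t sq_high = vmull_high_u32(data, data);
    acc = vqaddq_u64(acc, sq_low);
    return vqaddq_u64(acc, sq_high);
}
```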
@@ -1974,39 +1976,39 @@ NK_INTERNAL void nk_reduce_moments_u32_neon_strided_( //
 uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
 nk_size_t idx = 0;
 if (stride_elements == 2) {
- for (; idx + 4
+ for (; idx + 4 < count; idx += 4) {
 uint32x4x2_t loaded_u32x4x2 = vld2q_u32(data_ptr + idx * 2);
 uint32x4_t data_u32x4 = loaded_u32x4x2.val[0];
 sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_u32(vget_low_u32(data_u32x4)));
 sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_high_u32(data_u32x4));
- uint64x2_t
- uint64x2_t
- sumsq_u64x2 = vqaddq_u64(sumsq_u64x2,
- sumsq_u64x2 = vqaddq_u64(sumsq_u64x2,
+ uint64x2_t squares_low_u64x2 = vmull_u32(vget_low_u32(data_u32x4), vget_low_u32(data_u32x4));
+ uint64x2_t squares_high_u64x2 = vmull_high_u32(data_u32x4, data_u32x4);
+ sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_low_u64x2);
+ sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_high_u64x2);
 }
 }
 else if (stride_elements == 3) {
- for (; idx + 4
+ for (; idx + 4 < count; idx += 4) {
 uint32x4x3_t loaded_u32x4x3 = vld3q_u32(data_ptr + idx * 3);
 uint32x4_t data_u32x4 = loaded_u32x4x3.val[0];
 sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_u32(vget_low_u32(data_u32x4)));
 sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_high_u32(data_u32x4));
- uint64x2_t
- uint64x2_t
- sumsq_u64x2 = vqaddq_u64(sumsq_u64x2,
- sumsq_u64x2 = vqaddq_u64(sumsq_u64x2,
+ uint64x2_t squares_low_u64x2 = vmull_u32(vget_low_u32(data_u32x4), vget_low_u32(data_u32x4));
+ uint64x2_t squares_high_u64x2 = vmull_high_u32(data_u32x4, data_u32x4);
+ sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_low_u64x2);
+ sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_high_u64x2);
 }
 }
 else {
- for (; idx + 4
+ for (; idx + 4 < count; idx += 4) {
 uint32x4x4_t loaded_u32x4x4 = vld4q_u32(data_ptr + idx * 4);
 uint32x4_t data_u32x4 = loaded_u32x4x4.val[0];
 sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_u32(vget_low_u32(data_u32x4)));
 sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_high_u32(data_u32x4));
- uint64x2_t
- uint64x2_t
- sumsq_u64x2 = vqaddq_u64(sumsq_u64x2,
- sumsq_u64x2 = vqaddq_u64(sumsq_u64x2,
+ uint64x2_t squares_low_u64x2 = vmull_u32(vget_low_u32(data_u32x4), vget_low_u32(data_u32x4));
+ uint64x2_t squares_high_u64x2 = vmull_high_u32(data_u32x4, data_u32x4);
+ sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_low_u64x2);
+ sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_high_u64x2);
 }
 }
 nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
@@ -2066,7 +2068,7 @@ NK_INTERNAL void nk_reduce_minmax_u32_neon_contiguous_( //
 nk_partial_load_b32x4_serial_(data_ptr + idx, &tail_vec, remaining);
 uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
 vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
- uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((
+ uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((nk_u32_t)remaining));
 uint32x4_t data_min_u32x4 = vbslq_u32(valid_u32x4, tail_vec.u32x4, vdupq_n_u32(NK_U32_MAX));
 uint32x4_t data_max_u32x4 = vbslq_u32(valid_u32x4, tail_vec.u32x4, vdupq_n_u32(0));
 uint32x4_t less_u32x4 = vcltq_u32(data_min_u32x4, min_u32x4);
@@ -2112,19 +2114,19 @@ NK_INTERNAL void nk_reduce_minmax_u32_neon_strided_( //
 uint32x4_t data_for_min_u32x4, data_for_max_u32x4;

 nk_reduce_minmax_u32_neon_cycle:
- if (stride_elements == 2 && idx + 4
+ if (stride_elements == 2 && idx + 4 < count) {
 uint32x4x2_t loaded = vld2q_u32(data_ptr + idx * 2);
 data_for_min_u32x4 = loaded.val[0];
 data_for_max_u32x4 = loaded.val[0];
 idx += 4;
 }
- else if (stride_elements == 3 && idx + 4
+ else if (stride_elements == 3 && idx + 4 < count) {
 uint32x4x3_t loaded = vld3q_u32(data_ptr + idx * 3);
 data_for_min_u32x4 = loaded.val[0];
 data_for_max_u32x4 = loaded.val[0];
 idx += 4;
 }
- else if (stride_elements == 4 && idx + 4
+ else if (stride_elements == 4 && idx + 4 < count) {
 uint32x4x4_t loaded = vld4q_u32(data_ptr + idx * 4);
 data_for_min_u32x4 = loaded.val[0];
 data_for_max_u32x4 = loaded.val[0];
@@ -2133,7 +2135,7 @@ nk_reduce_minmax_u32_neon_cycle:
 else if (idx < count) {
 nk_b128_vec_t tail_vec;
 nk_strided_load_b32x4_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
- uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((
+ uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((nk_u32_t)(count - idx)));
 data_for_min_u32x4 = vbslq_u32(valid_u32x4, tail_vec.u32x4, min_u32x4);
 data_for_max_u32x4 = vbslq_u32(valid_u32x4, tail_vec.u32x4, max_u32x4);
 idx = count;
@@ -2214,8 +2216,8 @@ NK_PUBLIC void nk_reduce_minmax_u32_neon( //
 NK_INTERNAL void nk_reduce_moments_i64_neon_contiguous_( //
 nk_i64_t const *data_ptr, nk_size_t count, //
 nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
- uint64x2_t
- int64x2_t
+ uint64x2_t sum_low_u64x2 = vdupq_n_u64(0);
+ int64x2_t sum_high_i64x2 = vdupq_n_s64(0);
 // NEON can still load/extract i64 vectors for sumsq via scalar nk_i64_smul_
 uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
 int sumsq_overflow = 0;
@@ -2224,31 +2226,32 @@ NK_INTERNAL void nk_reduce_moments_i64_neon_contiguous_( //
 for (; idx + 2 <= count; idx += 2) {
 int64x2_t data_i64x2 = vld1q_s64(data_ptr + idx);
 // Sumsq via helper (scalar per-lane multiply)
- uint64x2_t
- uint64x2_t
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2,
- int64x2_t
- int64x2_t
- sumsq_overflow |=
-
+ uint64x2_t sq_u64x2 = nk_i64_smul_sq_i64x2_neon_(data_i64x2);
+ uint64x2_t sq_before_u64x2 = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, sq_u64x2);
+ int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ int64x2_t before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
+ sumsq_overflow |=
+ (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
 // Vectorized 128-bit carry-propagating sum
- uint64x2_t sum_before_u64x2 =
-
+ uint64x2_t sum_before_u64x2 = sum_low_u64x2;
+ sum_low_u64x2 = vaddq_u64(sum_low_u64x2, vreinterpretq_u64_s64(data_i64x2));
 int64x2_t sb_biased = veorq_s64(vreinterpretq_s64_u64(sum_before_u64x2), sign_bit_i64x2);
- int64x2_t sr_biased = veorq_s64(vreinterpretq_s64_u64(
+ int64x2_t sr_biased = veorq_s64(vreinterpretq_s64_u64(sum_low_u64x2), sign_bit_i64x2);
 uint64x2_t carry_u64x2 = vcgtq_s64(sb_biased, sr_biased);
-
+ sum_high_i64x2 = vsubq_s64(sum_high_i64x2, vreinterpretq_s64_u64(carry_u64x2));
 int64x2_t sign_ext_i64x2 = vshrq_n_s64(data_i64x2, 63);
-
+ sum_high_i64x2 = vaddq_s64(sum_high_i64x2, sign_ext_i64x2);
 }
- // Horizontal reduction of 2 lanes to scalar (
- nk_u64_t
- nk_i64_t
+ // Horizontal reduction of 2 lanes to scalar (sum_low, sum_high)
+ nk_u64_t sum_low = vgetq_lane_u64(sum_low_u64x2, 0);
+ nk_i64_t sum_high = vgetq_lane_s64(sum_high_i64x2, 0);
 {
- nk_u64_t before =
-
- if (
-
+ nk_u64_t before = sum_low;
+ sum_low += vgetq_lane_u64(sum_low_u64x2, 1);
+ if (sum_low < before) sum_high++;
+ sum_high += vgetq_lane_s64(sum_high_i64x2, 1);
 }
 nk_u64_t sumsq;
 if (sumsq_overflow) sumsq = NK_U64_MAX;
@@ -2258,14 +2261,14 @@ NK_INTERNAL void nk_reduce_moments_i64_neon_contiguous_( //
 nk_i64_t product = nk_i64_saturating_mul_serial(val, val);
 nk_u64_t unsigned_product = (nk_u64_t)product;
 sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
- nk_u64_t before =
-
- if (
-
- }
- nk_i64_t
- if (
- else if (
+ nk_u64_t before = sum_low;
+ sum_low += (nk_u64_t)val;
+ if (sum_low < before) sum_high++;
+ sum_high += (val >> 63);
+ }
+ nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
+ if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
+ else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
 else *sum_ptr = NK_I64_MIN;
 *sumsq_ptr = sumsq;
 }
@@ -2286,8 +2289,8 @@ NK_INTERNAL void nk_reduce_minmax_i64_neon_contiguous_( //
 nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
 nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
 int64x2_t min_i64x2 = vdupq_n_s64(NK_I64_MAX), max_i64x2 = vdupq_n_s64(NK_I64_MIN);
- uint64x2_t
- uint64x2_t
+ uint64x2_t min_iter_u64x2 = vdupq_n_u64(0), max_iter_u64x2 = vdupq_n_u64(0);
+ uint64x2_t iter_u64x2 = vdupq_n_u64(0), one_u64x2 = vdupq_n_u64(1);
 nk_size_t idx = 0;
 for (; idx + 2 <= count; idx += 2) {
 int64x2_t data_i64x2 = vld1q_s64(data_ptr + idx);
@@ -2295,15 +2298,15 @@ NK_INTERNAL void nk_reduce_minmax_i64_neon_contiguous_( //
 uint64x2_t greater_u64x2 = vcgtq_s64(data_i64x2, max_i64x2);
 min_i64x2 = vbslq_s64(less_u64x2, data_i64x2, min_i64x2);
 max_i64x2 = vbslq_s64(greater_u64x2, data_i64x2, max_i64x2);
-
-
-
+ min_iter_u64x2 = vbslq_u64(less_u64x2, iter_u64x2, min_iter_u64x2);
+ max_iter_u64x2 = vbslq_u64(greater_u64x2, iter_u64x2, max_iter_u64x2);
+ iter_u64x2 = vaddq_u64(iter_u64x2, one_u64x2);
 }
 nk_b128_vec_t min_values_vec, max_values_vec, min_indices_vec, max_indices_vec;
 min_values_vec.i64x2 = min_i64x2;
- min_indices_vec.u64x2 =
+ min_indices_vec.u64x2 = min_iter_u64x2;
 max_values_vec.i64x2 = max_i64x2;
- max_indices_vec.u64x2 =
+ max_indices_vec.u64x2 = max_iter_u64x2;
 nk_i64_t min_value, max_value;
 nk_size_t min_index, max_index;
 if (min_values_vec.i64s[0] <= min_values_vec.i64s[1])
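The minmax hunks above track per-lane iteration counters next to the running extrema, and `vbslq` copies the counter into the index register only in the lanes where a new extremum was taken. A condensed sketch of one step of that pattern, assuming AArch64 and `<arm_neon.h>`; the struct and function names are illustrative:

```c
#include <arm_neon.h>

typedef struct {
    int64x2_t min_values, max_values;     // running per-lane extrema
    uint64x2_t min_iters, max_iters, iter; // per-lane iteration counters
} minmax_i64_state_t;

// One step over two i64 elements: lanes that produce a new min/max also
// capture the current iteration number; the winning lane index is recovered
// later as iteration * 2 + lane.
static inline void minmax_i64_step(minmax_i64_state_t *s, int64x2_t data) {
    uint64x2_t less = vcltq_s64(data, s->min_values);
    uint64x2_t greater = vcgtq_s64(data, s->max_values);
    s->min_values = vbslq_s64(less, data, s->min_values);
    s->max_values = vbslq_s64(greater, data, s->max_values);
    s->min_iters = vbslq_u64(less, s->iter, s->min_iters);
    s->max_iters = vbslq_u64(greater, s->iter, s->max_iters);
    s->iter = vaddq_u64(s->iter, vdupq_n_u64(1));
}
```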
@@ -2350,8 +2353,8 @@ NK_INTERNAL void nk_reduce_moments_u64_neon_contiguous_( //
 for (; idx + 2 <= count; idx += 2) {
 uint64x2_t data_u64x2 = vld1q_u64(data_ptr + idx);
 sum_u64x2 = vqaddq_u64(sum_u64x2, data_u64x2);
- uint64x2_t
- sumsq_u64x2 = vqaddq_u64(sumsq_u64x2,
+ uint64x2_t sq_u64x2 = nk_u64_smul_sq_u64x2_neon_(data_u64x2);
+ sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, sq_u64x2);
 }
 nk_u64_t sum = nk_reduce_sadd_u64x2_neon_(sum_u64x2);
 nk_u64_t sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
@@ -2380,8 +2383,8 @@ NK_INTERNAL void nk_reduce_minmax_u64_neon_contiguous_( //
 nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
 nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
 uint64x2_t min_u64x2 = vdupq_n_u64(NK_U64_MAX), max_u64x2 = vdupq_n_u64(0);
- uint64x2_t
- uint64x2_t
+ uint64x2_t min_iter_u64x2 = vdupq_n_u64(0), max_iter_u64x2 = vdupq_n_u64(0);
+ uint64x2_t iter_u64x2 = vdupq_n_u64(0), one_u64x2 = vdupq_n_u64(1);
 nk_size_t idx = 0;
 for (; idx + 2 <= count; idx += 2) {
 uint64x2_t data_u64x2 = vld1q_u64(data_ptr + idx);
@@ -2389,15 +2392,15 @@ NK_INTERNAL void nk_reduce_minmax_u64_neon_contiguous_( //
 uint64x2_t greater_u64x2 = vcgtq_u64(data_u64x2, max_u64x2);
 min_u64x2 = vbslq_u64(less_u64x2, data_u64x2, min_u64x2);
 max_u64x2 = vbslq_u64(greater_u64x2, data_u64x2, max_u64x2);
-
-
-
+ min_iter_u64x2 = vbslq_u64(less_u64x2, iter_u64x2, min_iter_u64x2);
+ max_iter_u64x2 = vbslq_u64(greater_u64x2, iter_u64x2, max_iter_u64x2);
+ iter_u64x2 = vaddq_u64(iter_u64x2, one_u64x2);
 }
 nk_b128_vec_t min_values_vec, max_values_vec, min_indices_vec, max_indices_vec;
 min_values_vec.u64x2 = min_u64x2;
- min_indices_vec.u64x2 =
+ min_indices_vec.u64x2 = min_iter_u64x2;
 max_values_vec.u64x2 = max_u64x2;
- max_indices_vec.u64x2 =
+ max_indices_vec.u64x2 = max_iter_u64x2;
 nk_u64_t min_value, max_value;
 nk_size_t min_index, max_index;
 if (min_values_vec.u64s[0] <= min_values_vec.u64s[1])
@@ -2494,10 +2497,10 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neon_contiguous_( //
 int8x16_t scaled_i8x16 = vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
 int16x8_t pairwise_i16x8 = vpaddlq_s8(scaled_i8x16);
 sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
- int16x8_t
- int16x8_t
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
+ int16x8_t squares_low_i16x8 = vmull_s8(vget_low_s8(scaled_i8x16), vget_low_s8(scaled_i8x16));
+ int16x8_t squares_high_i16x8 = vmull_high_s8(scaled_i8x16, scaled_i8x16);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_low_i16x8))));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_high_i16x8))));
 }
 nk_i64_t sum = vaddlvq_s32(sum_i32x4);
 nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
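The e2m3 kernels square 16 decoded signed bytes at a time and fold the results down to two u64 lanes with pairwise widening adds. A minimal sketch of exactly that accumulation step, assuming AArch64 and `<arm_neon.h>`; the wrapper name is illustrative:

```c
#include <arm_neon.h>

// Square 16 signed bytes with widening multiplies, then reduce the i16 squares
// to u64 lanes via two rounds of pairwise widening adds. Squares are
// non-negative, so the unsigned reinterpretation is safe.
static inline uint64x2_t accumulate_sq_i8x16(uint64x2_t acc, int8x16_t scaled) {
    int16x8_t sq_low = vmull_s8(vget_low_s8(scaled), vget_low_s8(scaled));
    int16x8_t sq_high = vmull_high_s8(scaled, scaled);
    acc = vaddq_u64(acc, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(sq_low))));
    acc = vaddq_u64(acc, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(sq_high))));
    return acc;
}
```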
@@ -2527,7 +2530,7 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neon_strided_( //
 uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
 nk_size_t idx = 0;
 if (stride_elements == 2) {
- for (; idx + 16
+ for (; idx + 16 < count; idx += 16) {
 uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
 uint8x16_t raw_u8x16 = loaded_u8x16x2.val[0];
 uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
@@ -2538,14 +2541,14 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neon_strided_( //
 int8x16_t scaled_i8x16 = vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
 int16x8_t pairwise_i16x8 = vpaddlq_s8(scaled_i8x16);
 sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
- int16x8_t
- int16x8_t
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
+ int16x8_t squares_low_i16x8 = vmull_s8(vget_low_s8(scaled_i8x16), vget_low_s8(scaled_i8x16));
+ int16x8_t squares_high_i16x8 = vmull_high_s8(scaled_i8x16, scaled_i8x16);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_low_i16x8))));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_high_i16x8))));
 }
 }
 else if (stride_elements == 3) {
- for (; idx + 16
+ for (; idx + 16 < count; idx += 16) {
 uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
 uint8x16_t raw_u8x16 = loaded_u8x16x3.val[0];
 uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
@@ -2556,14 +2559,14 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neon_strided_( //
 int8x16_t scaled_i8x16 = vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
 int16x8_t pairwise_i16x8 = vpaddlq_s8(scaled_i8x16);
 sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
- int16x8_t
- int16x8_t
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
+ int16x8_t squares_low_i16x8 = vmull_s8(vget_low_s8(scaled_i8x16), vget_low_s8(scaled_i8x16));
+ int16x8_t squares_high_i16x8 = vmull_high_s8(scaled_i8x16, scaled_i8x16);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_low_i16x8))));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_high_i16x8))));
 }
 }
 else {
- for (; idx + 16
+ for (; idx + 16 < count; idx += 16) {
 uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
 uint8x16_t raw_u8x16 = loaded_u8x16x4.val[0];
 uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
@@ -2574,10 +2577,10 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neon_strided_( //
 int8x16_t scaled_i8x16 = vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
 int16x8_t pairwise_i16x8 = vpaddlq_s8(scaled_i8x16);
 sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
- int16x8_t
- int16x8_t
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(
+ int16x8_t squares_low_i16x8 = vmull_s8(vget_low_s8(scaled_i8x16), vget_low_s8(scaled_i8x16));
+ int16x8_t squares_high_i16x8 = vmull_high_s8(scaled_i8x16, scaled_i8x16);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_low_i16x8))));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_high_i16x8))));
 }
 }
 nk_i64_t sum = vaddlvq_s32(sum_i32x4);
@@ -2625,7 +2628,7 @@ NK_INTERNAL void nk_reduce_minmax_e2m3_neon_contiguous_( //
 // Mask invalid lanes: min gets 0xFF (won't be selected), max gets 0x00
 uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
 vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)count));
 first_comparable_u8x16 = vbslq_u8(valid_u8x16, first_comparable_u8x16, vdupq_n_u8(0));
 }
 else {
@@ -2636,7 +2639,7 @@ NK_INTERNAL void nk_reduce_minmax_e2m3_neon_contiguous_( //
 // For max: invalid lanes (0x00) should not win, which is already correct since 0x00 won't beat real data
 uint8x16_t lane_indices_init_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
 vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
- uint8x16_t valid_init_u8x16 = vcltq_u8(lane_indices_init_u8x16, vdupq_n_u8((
+ uint8x16_t valid_init_u8x16 = vcltq_u8(lane_indices_init_u8x16, vdupq_n_u8((nk_u8_t)first_count));
 uint8x16_t min_u8x16 = vbslq_u8(valid_init_u8x16, first_comparable_u8x16, vdupq_n_u8(0xFF));
 uint8x16_t max_u8x16 = first_comparable_u8x16; // invalid lanes are 0x00, safe for max
 uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
@@ -2660,7 +2663,7 @@ NK_INTERNAL void nk_reduce_minmax_e2m3_neon_contiguous_( //
 uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(tail_vec.u8x16);
 uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
 vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)remaining));
 uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
 uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
 uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
@@ -2706,21 +2709,21 @@ NK_INTERNAL void nk_reduce_minmax_e2m3_neon_strided_( //
 uint8x16_t data_for_min_u8x16, data_for_max_u8x16;

 nk_reduce_minmax_e2m3_neon_cycle:
- if (stride_elements == 2 && idx + 16
+ if (stride_elements == 2 && idx + 16 < count) {
 uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
 uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
 data_for_min_u8x16 = comparable_u8x16;
 data_for_max_u8x16 = comparable_u8x16;
 idx += 16;
 }
- else if (stride_elements == 3 && idx + 16
+ else if (stride_elements == 3 && idx + 16 < count) {
 uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
 uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
 data_for_min_u8x16 = comparable_u8x16;
 data_for_max_u8x16 = comparable_u8x16;
 idx += 16;
 }
- else if (stride_elements == 4 && idx + 16
+ else if (stride_elements == 4 && idx + 16 < count) {
 uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
 uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
 data_for_min_u8x16 = comparable_u8x16;
@@ -2731,7 +2734,7 @@ nk_reduce_minmax_e2m3_neon_cycle:
 nk_b128_vec_t tail_vec;
 nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
 uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(tail_vec.u8x16);
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)(count - idx)));
 data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
 data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
 idx = count;
@@ -2812,51 +2815,53 @@ NK_INTERNAL void nk_reduce_moments_e3m2_neon_contiguous_( //
 nk_e3m2_t const *data_ptr, nk_size_t count, //
 nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
 // VTBL LUT: maps 6-bit magnitude (0..31) to (value×16) low byte; max value×16 = 448 needs i16
- uint8x16x2_t
+ uint8x16x2_t lut_e3m2_low;
 // table[0]: low bytes for magnitudes 0..15
 // 0x0706050403020100 → bytes [0..7] = 0,1,2,3,4,5,6,7
 // 0x1C1814100E0C0A08 → bytes [8..15] = 8,10,12,14,16,20,24,28
-
-
+ lut_e3m2_low.val[0] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x1C1814100E0C0A08ULL)));
 // table[1]: low bytes for magnitudes 16..31
 // 0x7060504038302820 → bytes [0..7] = 32,40,48,56,64,80,96,112
 // 0xC0804000E0C0A080 → bytes [8..15] = 128,160,192,224,0,64,128,192
-
-
+ lut_e3m2_low.val[1] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x7060504038302820ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0xC0804000E0C0A080ULL)));
 int32x4_t sum_i32x4 = vdupq_n_s32(0);
 uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
 nk_size_t idx = 0;
 for (; idx + 16 <= count; idx += 16) {
 uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
 uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
- uint8x16_t low_byte_u8x16 = vqtbl2q_u8(
+ uint8x16_t low_byte_u8x16 = vqtbl2q_u8(lut_e3m2_low, magnitude_u8x16);
 uint8x16_t high_byte_u8x16 = vandq_u8(vcgtq_u8(magnitude_u8x16, vdupq_n_u8(27)), vdupq_n_u8(1));
 uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
 // Interleave low+high bytes into i16 values (two halves of 8 each)
- uint16x8_t
- uint16x8_t
+ uint16x8_t unsigned_low_u16x8 = vreinterpretq_u16_u8(vzip1q_u8(low_byte_u8x16, high_byte_u8x16));
+ uint16x8_t unsigned_high_u16x8 = vreinterpretq_u16_u8(vzip2q_u8(low_byte_u8x16, high_byte_u8x16));
 // Sign-extend the per-byte negative mask to per-i16 lanes
- int8x8_t
- int8x8_t
- uint16x8_t
- uint16x8_t
+ int8x8_t is_negative_low_i8x8 = vreinterpret_s8_u8(vget_low_u8(is_negative_u8x16));
+ int8x8_t is_negative_high_i8x8 = vreinterpret_s8_u8(vget_high_u8(is_negative_u8x16));
+ uint16x8_t is_negative_low_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_low_i8x8));
+ uint16x8_t is_negative_high_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_high_i8x8));
 // Apply sign via conditional negate
- int16x8_t
- int16x8_t
-
- int16x8_t
+ int16x8_t positive_low_i16x8 = vreinterpretq_s16_u16(unsigned_low_u16x8);
+ int16x8_t scaled_low_i16x8 = vbslq_s16(is_negative_low_u16x8, vnegq_s16(positive_low_i16x8),
+ positive_low_i16x8);
+ int16x8_t positive_high_i16x8 = vreinterpretq_s16_u16(unsigned_high_u16x8);
+ int16x8_t scaled_high_i16x8 = vbslq_s16(is_negative_high_u16x8, vnegq_s16(positive_high_i16x8),
+ positive_high_i16x8);
 // Sum: i16→i32 widening, accumulate in i32x4
- sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(
- sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_low_i16x8));
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_high_i16x8));
 // Sumsq: vmull_s16→i32 (always positive as squares), widen to u64
- int32x4_t
- int32x4_t
- int32x4_t
- int32x4_t
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
+ int32x4_t squares_low_a_i32x4 = vmull_s16(vget_low_s16(scaled_low_i16x8), vget_low_s16(scaled_low_i16x8));
+ int32x4_t squares_low_b_i32x4 = vmull_high_s16(scaled_low_i16x8, scaled_low_i16x8);
+ int32x4_t squares_high_a_i32x4 = vmull_s16(vget_low_s16(scaled_high_i16x8), vget_low_s16(scaled_high_i16x8));
+ int32x4_t squares_high_b_i32x4 = vmull_high_s16(scaled_high_i16x8, scaled_high_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_b_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_b_i32x4)));
 }
 nk_i64_t sum = vaddlvq_s32(sum_i32x4);
 nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
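The e3m2 hunk above decodes sixteen 6-bit magnitudes at once by splitting a 32-entry lookup table across two 16-byte rows and letting `vqtbl2q_u8` index into both. A minimal sketch of that lookup step, with the table constants copied from the hunk; the wrapper name is illustrative, not numkong's API:

```c
#include <arm_neon.h>

// Map a vector of 6-bit magnitudes (0..31) to the low byte of value*16 via a
// two-row VTBL table; indices 16..31 automatically select the second row.
static inline uint8x16_t e3m2_magnitude_low_bytes(uint8x16_t magnitude) {
    uint8x16x2_t lut;
    lut.val[0] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
                             vreinterpret_u8_u64(vcreate_u64(0x1C1814100E0C0A08ULL)));
    lut.val[1] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x7060504038302820ULL)),
                             vreinterpret_u8_u64(vcreate_u64(0xC0804000E0C0A080ULL)));
    return vqtbl2q_u8(lut, magnitude);
}
```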
@@ -2871,114 +2876,117 @@ NK_INTERNAL void nk_reduce_moments_e3m2_neon_contiguous_( //
 NK_INTERNAL void nk_reduce_moments_e3m2_neon_strided_( //
 nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
 nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
- uint8x16x2_t
+ uint8x16x2_t lut_e3m2_low;
 // table[0]: low bytes for magnitudes 0..15
 // 0x0706050403020100 → bytes [0..7] = 0,1,2,3,4,5,6,7
 // 0x1C1814100E0C0A08 → bytes [8..15] = 8,10,12,14,16,20,24,28
-
-
+ lut_e3m2_low.val[0] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x1C1814100E0C0A08ULL)));
 // table[1]: low bytes for magnitudes 16..31
 // 0x7060504038302820 → bytes [0..7] = 32,40,48,56,64,80,96,112
 // 0xC0804000E0C0A080 → bytes [8..15] = 128,160,192,224,0,64,128,192
-
-
+ lut_e3m2_low.val[1] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x7060504038302820ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0xC0804000E0C0A080ULL)));
 int32x4_t sum_i32x4 = vdupq_n_s32(0);
 uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
 nk_size_t idx = 0;
 if (stride_elements == 2) {
- for (; idx + 16
+ for (; idx + 16 < count; idx += 16) {
 uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
 uint8x16_t raw_u8x16 = loaded_u8x16x2.val[0];
 uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
- uint8x16_t low_byte_u8x16 = vqtbl2q_u8(
+ uint8x16_t low_byte_u8x16 = vqtbl2q_u8(lut_e3m2_low, magnitude_u8x16);
 uint8x16_t high_byte_u8x16 = vandq_u8(vcgtq_u8(magnitude_u8x16, vdupq_n_u8(27)), vdupq_n_u8(1));
 uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
- uint16x8_t
- uint16x8_t
- int8x8_t
- int8x8_t
- uint16x8_t
- uint16x8_t
- int16x8_t
- int16x8_t
-
- int16x8_t
- int16x8_t
-
- sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(
- sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(
- int32x4_t
- int32x4_t
- int32x4_t
-
-
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
+ uint16x8_t unsigned_low_u16x8 = vreinterpretq_u16_u8(vzip1q_u8(low_byte_u8x16, high_byte_u8x16));
+ uint16x8_t unsigned_high_u16x8 = vreinterpretq_u16_u8(vzip2q_u8(low_byte_u8x16, high_byte_u8x16));
+ int8x8_t is_negative_low_i8x8 = vreinterpret_s8_u8(vget_low_u8(is_negative_u8x16));
+ int8x8_t is_negative_high_i8x8 = vreinterpret_s8_u8(vget_high_u8(is_negative_u8x16));
+ uint16x8_t is_negative_low_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_low_i8x8));
+ uint16x8_t is_negative_high_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_high_i8x8));
+ int16x8_t positive_low_i16x8 = vreinterpretq_s16_u16(unsigned_low_u16x8);
+ int16x8_t scaled_low_i16x8 = vbslq_s16(is_negative_low_u16x8, vnegq_s16(positive_low_i16x8),
+ positive_low_i16x8);
+ int16x8_t positive_high_i16x8 = vreinterpretq_s16_u16(unsigned_high_u16x8);
+ int16x8_t scaled_high_i16x8 = vbslq_s16(is_negative_high_u16x8, vnegq_s16(positive_high_i16x8),
+ positive_high_i16x8);
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_low_i16x8));
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_high_i16x8));
+ int32x4_t squares_low_a_i32x4 = vmull_s16(vget_low_s16(scaled_low_i16x8), vget_low_s16(scaled_low_i16x8));
+ int32x4_t squares_low_b_i32x4 = vmull_high_s16(scaled_low_i16x8, scaled_low_i16x8);
+ int32x4_t squares_high_a_i32x4 = vmull_s16(vget_low_s16(scaled_high_i16x8),
+ vget_low_s16(scaled_high_i16x8));
+ int32x4_t squares_high_b_i32x4 = vmull_high_s16(scaled_high_i16x8, scaled_high_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_b_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_b_i32x4)));
 }
 }
 else if (stride_elements == 3) {
- for (; idx + 16
+ for (; idx + 16 < count; idx += 16) {
 uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
 uint8x16_t raw_u8x16 = loaded_u8x16x3.val[0];
 uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
- uint8x16_t low_byte_u8x16 = vqtbl2q_u8(
+ uint8x16_t low_byte_u8x16 = vqtbl2q_u8(lut_e3m2_low, magnitude_u8x16);
 uint8x16_t high_byte_u8x16 = vandq_u8(vcgtq_u8(magnitude_u8x16, vdupq_n_u8(27)), vdupq_n_u8(1));
 uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
- uint16x8_t
- uint16x8_t
- int8x8_t
- int8x8_t
- uint16x8_t
- uint16x8_t
- int16x8_t
- int16x8_t
-
- int16x8_t
- int16x8_t
-
- sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(
- sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(
- int32x4_t
- int32x4_t
- int32x4_t
-
-
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
+ uint16x8_t unsigned_low_u16x8 = vreinterpretq_u16_u8(vzip1q_u8(low_byte_u8x16, high_byte_u8x16));
+ uint16x8_t unsigned_high_u16x8 = vreinterpretq_u16_u8(vzip2q_u8(low_byte_u8x16, high_byte_u8x16));
+ int8x8_t is_negative_low_i8x8 = vreinterpret_s8_u8(vget_low_u8(is_negative_u8x16));
+ int8x8_t is_negative_high_i8x8 = vreinterpret_s8_u8(vget_high_u8(is_negative_u8x16));
+ uint16x8_t is_negative_low_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_low_i8x8));
+ uint16x8_t is_negative_high_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_high_i8x8));
+ int16x8_t positive_low_i16x8 = vreinterpretq_s16_u16(unsigned_low_u16x8);
+ int16x8_t scaled_low_i16x8 = vbslq_s16(is_negative_low_u16x8, vnegq_s16(positive_low_i16x8),
+ positive_low_i16x8);
+ int16x8_t positive_high_i16x8 = vreinterpretq_s16_u16(unsigned_high_u16x8);
+ int16x8_t scaled_high_i16x8 = vbslq_s16(is_negative_high_u16x8, vnegq_s16(positive_high_i16x8),
+ positive_high_i16x8);
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_low_i16x8));
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_high_i16x8));
+ int32x4_t squares_low_a_i32x4 = vmull_s16(vget_low_s16(scaled_low_i16x8), vget_low_s16(scaled_low_i16x8));
+ int32x4_t squares_low_b_i32x4 = vmull_high_s16(scaled_low_i16x8, scaled_low_i16x8);
+ int32x4_t squares_high_a_i32x4 = vmull_s16(vget_low_s16(scaled_high_i16x8),
+ vget_low_s16(scaled_high_i16x8));
+ int32x4_t squares_high_b_i32x4 = vmull_high_s16(scaled_high_i16x8, scaled_high_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_b_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_b_i32x4)));
 }
 }
 else {
- for (; idx + 16
+ for (; idx + 16 < count; idx += 16) {
 uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
 uint8x16_t raw_u8x16 = loaded_u8x16x4.val[0];
 uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
- uint8x16_t low_byte_u8x16 = vqtbl2q_u8(
+ uint8x16_t low_byte_u8x16 = vqtbl2q_u8(lut_e3m2_low, magnitude_u8x16);
 uint8x16_t high_byte_u8x16 = vandq_u8(vcgtq_u8(magnitude_u8x16, vdupq_n_u8(27)), vdupq_n_u8(1));
 uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
- uint16x8_t
- uint16x8_t
- int8x8_t
- int8x8_t
- uint16x8_t
- uint16x8_t
- int16x8_t
- int16x8_t
-
- int16x8_t
- int16x8_t
-
- sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(
- sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(
- int32x4_t
- int32x4_t
- int32x4_t
-
-
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
- sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(
+ uint16x8_t unsigned_low_u16x8 = vreinterpretq_u16_u8(vzip1q_u8(low_byte_u8x16, high_byte_u8x16));
+ uint16x8_t unsigned_high_u16x8 = vreinterpretq_u16_u8(vzip2q_u8(low_byte_u8x16, high_byte_u8x16));
+ int8x8_t is_negative_low_i8x8 = vreinterpret_s8_u8(vget_low_u8(is_negative_u8x16));
+ int8x8_t is_negative_high_i8x8 = vreinterpret_s8_u8(vget_high_u8(is_negative_u8x16));
+ uint16x8_t is_negative_low_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_low_i8x8));
+ uint16x8_t is_negative_high_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_high_i8x8));
+ int16x8_t positive_low_i16x8 = vreinterpretq_s16_u16(unsigned_low_u16x8);
+ int16x8_t scaled_low_i16x8 = vbslq_s16(is_negative_low_u16x8, vnegq_s16(positive_low_i16x8),
+ positive_low_i16x8);
+ int16x8_t positive_high_i16x8 = vreinterpretq_s16_u16(unsigned_high_u16x8);
+ int16x8_t scaled_high_i16x8 = vbslq_s16(is_negative_high_u16x8, vnegq_s16(positive_high_i16x8),
+ positive_high_i16x8);
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_low_i16x8));
|
|
2980
|
+
sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_high_i16x8));
|
|
2981
|
+
int32x4_t squares_low_a_i32x4 = vmull_s16(vget_low_s16(scaled_low_i16x8), vget_low_s16(scaled_low_i16x8));
|
|
2982
|
+
int32x4_t squares_low_b_i32x4 = vmull_high_s16(scaled_low_i16x8, scaled_low_i16x8);
|
|
2983
|
+
int32x4_t squares_high_a_i32x4 = vmull_s16(vget_low_s16(scaled_high_i16x8),
|
|
2984
|
+
vget_low_s16(scaled_high_i16x8));
|
|
2985
|
+
int32x4_t squares_high_b_i32x4 = vmull_high_s16(scaled_high_i16x8, scaled_high_i16x8);
|
|
2986
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_a_i32x4)));
|
|
2987
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_low_b_i32x4)));
|
|
2988
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_a_i32x4)));
|
|
2989
|
+
sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_high_b_i32x4)));
|
|
2982
2990
|
}
|
|
2983
2991
|
}
|
|
2984
2992
|
nk_i64_t sum = vaddlvq_s32(sum_i32x4);
|
|
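The stride-2, stride-3, and stride-4 branches restored above all follow one pattern: decode each 6-bit e3m2 payload to a small signed integer through the byte lookup, negate lanes whose sign bit (0x20) is set, then widen the per-lane sums into 32-bit accumulators and the squared terms into 64-bit accumulators. A minimal scalar sketch of that accumulation order follows; `decode_e3m2_scaled` is a hypothetical stand-in for the vqtbl2q_u8 table plus the magnitude > 27 high-byte fix-up, not a function from the package.

    #include <stdint.h>
    #include <stddef.h>

    /* Hypothetical stand-in for the NEON lookup: maps the 5-bit magnitude of an
     * e3m2 code to the integer "scaled" value the kernel accumulates. */
    extern int16_t decode_e3m2_scaled(uint8_t magnitude_bits);

    void reduce_moments_e3m2_scalar(uint8_t const *data, size_t count, size_t stride,
                                    int64_t *sum_out, uint64_t *sumsq_out) {
        int64_t sum = 0;
        uint64_t sumsq = 0;
        for (size_t i = 0; i < count; ++i) {
            uint8_t raw = data[i * stride];
            int16_t magnitude = decode_e3m2_scaled((uint8_t)(raw & 0x1F));
            int16_t scaled = (raw & 0x20) ? (int16_t)-magnitude : magnitude;   /* sign bit 0x20 */
            sum += scaled;                                          /* vpaddlq_s16 widening in the SIMD path */
            sumsq += (uint64_t)((int32_t)scaled * (int32_t)scaled); /* vmull_s16 + vpaddlq_u32 in the SIMD path */
        }
        *sum_out = sum;
        *sumsq_out = sumsq;
    }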
@@ -3025,7 +3033,7 @@ NK_INTERNAL void nk_reduce_minmax_e3m2_neon_contiguous_( //
  first_comparable_u8x16 = nk_fp6x16_to_comparable_neon_(first_vec.u8x16);
  uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
      vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)count));
  first_comparable_u8x16 = vbslq_u8(valid_u8x16, first_comparable_u8x16, vdupq_n_u8(0));
  }
  else {
@@ -3034,7 +3042,7 @@ NK_INTERNAL void nk_reduce_minmax_e3m2_neon_contiguous_( //
  }
  uint8x16_t lane_indices_init_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
      vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
- uint8x16_t valid_init_u8x16 = vcltq_u8(lane_indices_init_u8x16, vdupq_n_u8((
+ uint8x16_t valid_init_u8x16 = vcltq_u8(lane_indices_init_u8x16, vdupq_n_u8((nk_u8_t)first_count));
  uint8x16_t min_u8x16 = vbslq_u8(valid_init_u8x16, first_comparable_u8x16, vdupq_n_u8(0xFF));
  uint8x16_t max_u8x16 = first_comparable_u8x16;
  uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
@@ -3058,7 +3066,7 @@ NK_INTERNAL void nk_reduce_minmax_e3m2_neon_contiguous_( //
  uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(tail_vec.u8x16);
  uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
      vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)remaining));
  uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
  uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
  uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
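The three one-line fixes in this function all complete the same truncated expression: the lane-validity mask compares a constant 0..15 index vector against the number of elements that are actually live, cast down to nk_u8_t. The idiom, with illustrative names rather than the library's, looks like this:

    #include <arm_neon.h>
    #include <stddef.h>
    #include <stdint.h>

    /* Lanes whose index is >= `remaining` are overwritten with `sentinel`,
     * so they cannot win a following unsigned min/max reduction. */
    static inline uint8x16_t mask_tail_lanes(uint8x16_t data, size_t remaining, uint8_t sentinel) {
        uint8x16_t lane_indices = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
                                              vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
        uint8x16_t valid = vcltq_u8(lane_indices, vdupq_n_u8((uint8_t)remaining)); /* remaining <= 16 */
        return vbslq_u8(valid, data, vdupq_n_u8(sentinel));
    }

The sentinel is 0xFF when feeding a running minimum and 0x00 when feeding a running maximum, matching the vbslq_u8 calls in the hunks above.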
@@ -3104,21 +3112,21 @@ NK_INTERNAL void nk_reduce_minmax_e3m2_neon_strided_( //
  uint8x16_t data_for_min_u8x16, data_for_max_u8x16;

  nk_reduce_minmax_e3m2_neon_cycle:
- if (stride_elements == 2 && idx + 16
+ if (stride_elements == 2 && idx + 16 < count) {
  uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
  uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
  data_for_min_u8x16 = comparable_u8x16;
  data_for_max_u8x16 = comparable_u8x16;
  idx += 16;
  }
- else if (stride_elements == 3 && idx + 16
+ else if (stride_elements == 3 && idx + 16 < count) {
  uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
  uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
  data_for_min_u8x16 = comparable_u8x16;
  data_for_max_u8x16 = comparable_u8x16;
  idx += 16;
  }
- else if (stride_elements == 4 && idx + 16
+ else if (stride_elements == 4 && idx + 16 < count) {
  uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
  uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
  data_for_min_u8x16 = comparable_u8x16;
@@ -3129,7 +3137,7 @@ nk_reduce_minmax_e3m2_neon_cycle:
  nk_b128_vec_t tail_vec;
  nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
  uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(tail_vec.u8x16);
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)(count - idx)));
  data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
  data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
  idx = count;
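The strided min/max cycle relies on NEON's de-interleaving loads (vld2q_u8 / vld3q_u8 / vld4q_u8) and keeps only val[0], i.e. the first byte of each group, which is exactly the strided column being reduced; with the bound written as `idx + 16 < count`, the final group always flows through the masked serial tail handled after the label. A minimal illustration of the stride-2 case, with an illustrative function name:

    #include <arm_neon.h>
    #include <stdint.h>

    /* vld2q_u8 splits 32 consecutive bytes into two vectors of every other byte;
     * val[0] is the column at even offsets, i.e. a stride-2 gather of 16 elements. */
    static inline uint8x16_t load_stride2_column(uint8_t const *ptr) {
        uint8x16x2_t pair = vld2q_u8(ptr);
        return pair.val[0];
    }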
@@ -3213,12 +3221,12 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neon_contiguous_( //
  nk_size_t idx = 0;
  for (; idx + 16 <= count; idx += 16) {
  uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
- float16x8_t
- nk_e4m3x16_to_f16x8x2_neon_(raw_u8x16, &
- float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(
- float32x4_t b_f32x4 = vcvt_high_f32_f16(
- float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(
- float32x4_t d_f32x4 = vcvt_high_f32_f16(
+ float16x8_t half_low_f16x8, half_high_f16x8;
+ nk_e4m3x16_to_f16x8x2_neon_(raw_u8x16, &half_low_f16x8, &half_high_f16x8);
+ float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(half_low_f16x8));
+ float32x4_t b_f32x4 = vcvt_high_f32_f16(half_low_f16x8);
+ float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(half_high_f16x8));
+ float32x4_t d_f32x4 = vcvt_high_f32_f16(half_high_f16x8);
  sum_f32x4 = vaddq_f32(vaddq_f32(sum_f32x4, vaddq_f32(a_f32x4, b_f32x4)), vaddq_f32(c_f32x4, d_f32x4));
  sumsq_f32x4 = vfmaq_f32(vfmaq_f32(vfmaq_f32(vfmaq_f32( //
      sumsq_f32x4, a_f32x4, a_f32x4),
@@ -3241,14 +3249,14 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neon_strided_( //
  float32x4_t sum_f32x4 = vdupq_n_f32(0), sumsq_f32x4 = vdupq_n_f32(0);
  nk_size_t idx = 0;
  if (stride_elements == 2) {
- for (; idx + 16
+ for (; idx + 16 < count; idx += 16) {
  uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
- float16x8_t
- nk_e4m3x16_to_f16x8x2_neon_(loaded_u8x16x2.val[0], &
- float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(
- float32x4_t b_f32x4 = vcvt_high_f32_f16(
- float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(
- float32x4_t d_f32x4 = vcvt_high_f32_f16(
+ float16x8_t half_low_f16x8, half_high_f16x8;
+ nk_e4m3x16_to_f16x8x2_neon_(loaded_u8x16x2.val[0], &half_low_f16x8, &half_high_f16x8);
+ float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(half_low_f16x8));
+ float32x4_t b_f32x4 = vcvt_high_f32_f16(half_low_f16x8);
+ float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(half_high_f16x8));
+ float32x4_t d_f32x4 = vcvt_high_f32_f16(half_high_f16x8);
  sum_f32x4 = vaddq_f32(vaddq_f32(sum_f32x4, vaddq_f32(a_f32x4, b_f32x4)), vaddq_f32(c_f32x4, d_f32x4));
  sumsq_f32x4 = vfmaq_f32(vfmaq_f32(vfmaq_f32(vfmaq_f32( //
      sumsq_f32x4, a_f32x4, a_f32x4),
@@ -3258,14 +3266,14 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neon_strided_( //
  }
  }
  else if (stride_elements == 3) {
- for (; idx + 16
+ for (; idx + 16 < count; idx += 16) {
  uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
- float16x8_t
- nk_e4m3x16_to_f16x8x2_neon_(loaded_u8x16x3.val[0], &
- float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(
- float32x4_t b_f32x4 = vcvt_high_f32_f16(
- float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(
- float32x4_t d_f32x4 = vcvt_high_f32_f16(
+ float16x8_t half_low_f16x8, half_high_f16x8;
+ nk_e4m3x16_to_f16x8x2_neon_(loaded_u8x16x3.val[0], &half_low_f16x8, &half_high_f16x8);
+ float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(half_low_f16x8));
+ float32x4_t b_f32x4 = vcvt_high_f32_f16(half_low_f16x8);
+ float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(half_high_f16x8));
+ float32x4_t d_f32x4 = vcvt_high_f32_f16(half_high_f16x8);
  sum_f32x4 = vaddq_f32(vaddq_f32(sum_f32x4, vaddq_f32(a_f32x4, b_f32x4)), vaddq_f32(c_f32x4, d_f32x4));
  sumsq_f32x4 = vfmaq_f32(vfmaq_f32(vfmaq_f32(vfmaq_f32( //
      sumsq_f32x4, a_f32x4, a_f32x4),
@@ -3275,14 +3283,14 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neon_strided_( //
  }
  }
  else {
- for (; idx + 16
+ for (; idx + 16 < count; idx += 16) {
  uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
- float16x8_t
- nk_e4m3x16_to_f16x8x2_neon_(loaded_u8x16x4.val[0], &
- float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(
- float32x4_t b_f32x4 = vcvt_high_f32_f16(
- float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(
- float32x4_t d_f32x4 = vcvt_high_f32_f16(
+ float16x8_t half_low_f16x8, half_high_f16x8;
+ nk_e4m3x16_to_f16x8x2_neon_(loaded_u8x16x4.val[0], &half_low_f16x8, &half_high_f16x8);
+ float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(half_low_f16x8));
+ float32x4_t b_f32x4 = vcvt_high_f32_f16(half_low_f16x8);
+ float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(half_high_f16x8));
+ float32x4_t d_f32x4 = vcvt_high_f32_f16(half_high_f16x8);
  sum_f32x4 = vaddq_f32(vaddq_f32(sum_f32x4, vaddq_f32(a_f32x4, b_f32x4)), vaddq_f32(c_f32x4, d_f32x4));
  sumsq_f32x4 = vfmaq_f32(vfmaq_f32(vfmaq_f32(vfmaq_f32( //
      sumsq_f32x4, a_f32x4, a_f32x4),
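Once the truncated temporaries are restored, all four e4m3 loops share the same inner step: 16 FP8 bytes are expanded to two f16 vectors by nk_e4m3x16_to_f16x8x2_neon_, widened to four f32 quartets, and folded into the running sum with adds and into the running sum of squares with fused multiply-adds. A compact sketch of just that step, assuming an AArch64 target and using illustrative names rather than the library's:

    #include <arm_neon.h>

    /* Fold 16 values (already widened to two f16 halves) into f32 accumulators,
     * mirroring the a/b/c/d split used in the hunks above. */
    static inline void accumulate_moments_f16x8x2(float16x8_t half_low, float16x8_t half_high,
                                                  float32x4_t *sum, float32x4_t *sumsq) {
        float32x4_t a = vcvt_f32_f16(vget_low_f16(half_low));
        float32x4_t b = vcvt_high_f32_f16(half_low);
        float32x4_t c = vcvt_f32_f16(vget_low_f16(half_high));
        float32x4_t d = vcvt_high_f32_f16(half_high);
        *sum = vaddq_f32(vaddq_f32(*sum, vaddq_f32(a, b)), vaddq_f32(c, d));
        *sumsq = vfmaq_f32(vfmaq_f32(vfmaq_f32(vfmaq_f32(*sumsq, a, a), b, b), c, c), d, d);
    }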
@@ -3355,7 +3363,7 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_neon_contiguous_( //
  uint8x16_t nan_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
  uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
      vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)remaining));
  uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, nan_min_u8x16, vdupq_n_u8(0xFF));
  uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, nan_max_u8x16, vdupq_n_u8(0));
  uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
@@ -3407,7 +3415,7 @@ NK_INTERNAL void nk_reduce_minmax_e4m3_neon_strided_( //
  uint8x16_t data_for_min_u8x16, data_for_max_u8x16;

  nk_reduce_minmax_e4m3_neon_cycle:
- if (stride_elements == 2 && idx + 16
+ if (stride_elements == 2 && idx + 16 < count) {
  uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
  uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
  uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
@@ -3416,7 +3424,7 @@ nk_reduce_minmax_e4m3_neon_cycle:
  data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
  idx += 16;
  }
- else if (stride_elements == 3 && idx + 16
+ else if (stride_elements == 3 && idx + 16 < count) {
  uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
  uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
  uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
@@ -3425,7 +3433,7 @@ nk_reduce_minmax_e4m3_neon_cycle:
  data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
  idx += 16;
  }
- else if (stride_elements == 4 && idx + 16
+ else if (stride_elements == 4 && idx + 16 < count) {
  uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
  uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
  uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
@@ -3440,7 +3448,7 @@ nk_reduce_minmax_e4m3_neon_cycle:
  uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
  uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
      vceqq_u8(comparable_u8x16, vdupq_n_u8(0xFF)));
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)(count - idx)));
  uint8x16_t invalid_or_nan_u8x16 = vornq_u8(is_nan_u8x16, valid_u8x16);
  data_for_min_u8x16 = vbslq_u8(invalid_or_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
  data_for_max_u8x16 = vbslq_u8(invalid_or_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
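The e4m3 min/max tail combines two reasons to ignore a lane, being past the valid count or carrying a NaN encoding (comparable key 0x00 or 0xFF after the order-preserving transform), through vornq_u8, which computes is_nan | ~valid. With illustrative names, the blend reduces to:

    #include <arm_neon.h>

    /* A lane is excluded if it is NaN-coded or beyond the valid element count;
     * excluded lanes become 0xFF for the min stream and 0x00 for the max stream. */
    static inline void mask_minmax_lanes(uint8x16_t comparable, uint8x16_t is_nan, uint8x16_t valid,
                                         uint8x16_t *for_min, uint8x16_t *for_max) {
        uint8x16_t invalid_or_nan = vornq_u8(is_nan, valid); /* is_nan | ~valid */
        *for_min = vbslq_u8(invalid_or_nan, vdupq_n_u8(0xFF), comparable);
        *for_max = vbslq_u8(invalid_or_nan, vdupq_n_u8(0x00), comparable);
    }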
@@ -3537,10 +3545,10 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neon_contiguous_( //
  for (; idx + 8 <= count; idx += 8) {
  uint8x8_t raw_u8x8 = vld1_u8((nk_u8_t const *)(data_ptr + idx));
  float16x8_t half_f16x8 = nk_e5m2x8_to_f16x8_neon_(raw_u8x8);
- float32x4_t
- float32x4_t
- sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(
- sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4,
+ float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(half_f16x8));
+ float32x4_t high_f32x4 = vcvt_high_f32_f16(half_f16x8);
+ sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(low_f32x4, high_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4), high_f32x4, high_f32x4);
  }
  nk_f32_t sum = vaddvq_f32(sum_f32x4), sumsq = vaddvq_f32(sumsq_f32x4);
  for (; idx < count; ++idx) {
@@ -3557,33 +3565,33 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neon_strided_( //
  float32x4_t sum_f32x4 = vdupq_n_f32(0), sumsq_f32x4 = vdupq_n_f32(0);
  nk_size_t idx = 0;
  if (stride_elements == 2) {
- for (; idx + 8
+ for (; idx + 8 < count; idx += 8) {
  uint8x8x2_t loaded_u8x8x2 = vld2_u8((nk_u8_t const *)(data_ptr + idx * 2));
  float16x8_t half_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded_u8x8x2.val[0]);
- float32x4_t
- float32x4_t
- sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(
- sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4,
+ float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(half_f16x8));
+ float32x4_t high_f32x4 = vcvt_high_f32_f16(half_f16x8);
+ sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(low_f32x4, high_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4), high_f32x4, high_f32x4);
  }
  }
  else if (stride_elements == 3) {
- for (; idx + 8
+ for (; idx + 8 < count; idx += 8) {
  uint8x8x3_t loaded_u8x8x3 = vld3_u8((nk_u8_t const *)(data_ptr + idx * 3));
  float16x8_t half_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded_u8x8x3.val[0]);
- float32x4_t
- float32x4_t
- sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(
- sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4,
+ float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(half_f16x8));
+ float32x4_t high_f32x4 = vcvt_high_f32_f16(half_f16x8);
+ sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(low_f32x4, high_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4), high_f32x4, high_f32x4);
  }
  }
  else {
- for (; idx + 8
+ for (; idx + 8 < count; idx += 8) {
  uint8x8x4_t loaded_u8x8x4 = vld4_u8((nk_u8_t const *)(data_ptr + idx * 4));
  float16x8_t half_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded_u8x8x4.val[0]);
- float32x4_t
- float32x4_t
- sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(
- sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4,
+ float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(half_f16x8));
+ float32x4_t high_f32x4 = vcvt_high_f32_f16(half_f16x8);
+ sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(low_f32x4, high_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4), high_f32x4, high_f32x4);
  }
  }
  nk_f32_t sum = vaddvq_f32(sum_f32x4), sumsq = vaddvq_f32(sumsq_f32x4);
@@ -3652,7 +3660,7 @@ NK_INTERNAL void nk_reduce_minmax_e5m2_neon_contiguous_( //
  uint8x16_t nan_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
  uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
      vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)remaining));
  uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, nan_min_u8x16, vdupq_n_u8(0xFF));
  uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, nan_max_u8x16, vdupq_n_u8(0));
  uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
@@ -3704,7 +3712,7 @@ NK_INTERNAL void nk_reduce_minmax_e5m2_neon_strided_( //
  uint8x16_t data_for_min_u8x16, data_for_max_u8x16;

  nk_reduce_minmax_e5m2_neon_cycle:
- if (stride_elements == 2 && idx + 16
+ if (stride_elements == 2 && idx + 16 < count) {
  uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
  uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
  uint8x16_t is_nan_u8x16 = vorrq_u8(vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02)),
@@ -3713,7 +3721,7 @@ nk_reduce_minmax_e5m2_neon_cycle:
  data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
  idx += 16;
  }
- else if (stride_elements == 3 && idx + 16
+ else if (stride_elements == 3 && idx + 16 < count) {
  uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
  uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
  uint8x16_t is_nan_u8x16 = vorrq_u8(vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02)),
@@ -3722,7 +3730,7 @@ nk_reduce_minmax_e5m2_neon_cycle:
  data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
  idx += 16;
  }
- else if (stride_elements == 4 && idx + 16
+ else if (stride_elements == 4 && idx + 16 < count) {
  uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
  uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
  uint8x16_t is_nan_u8x16 = vorrq_u8(vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02)),
@@ -3737,7 +3745,7 @@ nk_reduce_minmax_e5m2_neon_cycle:
  uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
  uint8x16_t is_nan_u8x16 = vorrq_u8(vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02)),
      vcgeq_u8(comparable_u8x16, vdupq_n_u8(0xFD)));
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((nk_u8_t)(count - idx)));
  uint8x16_t invalid_or_nan_u8x16 = vornq_u8(is_nan_u8x16, valid_u8x16);
  data_for_min_u8x16 = vbslq_u8(invalid_or_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
  data_for_max_u8x16 = vbslq_u8(invalid_or_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
@@ -3826,6 +3834,108 @@ NK_PUBLIC void nk_reduce_minmax_e5m2_neon( //
      max_index_ptr);
  }

+ NK_INTERNAL void nk_reduce_moments_f16_neon_contiguous_( //
+     nk_f16_t const *data_ptr, nk_size_t count, //
+     nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+     float32x4_t sum_f32x4 = vdupq_n_f32(0);
+     float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
+     nk_size_t idx = 0;
+
+     for (; idx + 8 <= count; idx += 8) {
+         float16x8_t data_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(data_ptr + idx)));
+         float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
+         float32x4_t high_f32x4 = vcvt_high_f32_f16(data_f16x8);
+         sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
+         sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
+         sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
+         sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
+     }
+
+     nk_f32_t sum = vaddvq_f32(sum_f32x4);
+     nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
+     for (; idx < count; ++idx) {
+         nk_f32_t value_f32;
+         nk_f16_to_f32_serial(data_ptr + idx, &value_f32);
+         sum += value_f32, sumsq += value_f32 * value_f32;
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_INTERNAL void nk_reduce_moments_f16_neon_strided_( //
+     nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+     float32x4_t sum_f32x4 = vdupq_n_f32(0);
+     float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
+     nk_size_t idx = 0;
+
+     if (stride_elements == 2) {
+         for (; idx + 8 < count; idx += 8) {
+             uint16x8x2_t loaded_u16x8x2 = vld2q_u16((nk_u16_t const *)(data_ptr + idx * 2));
+             float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x2.val[0]);
+             float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
+             float32x4_t high_f32x4 = vcvt_high_f32_f16(data_f16x8);
+             sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
+             sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
+             sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
+             sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
+         }
+     }
+     else if (stride_elements == 3) {
+         for (; idx + 8 < count; idx += 8) {
+             uint16x8x3_t loaded_u16x8x3 = vld3q_u16((nk_u16_t const *)(data_ptr + idx * 3));
+             float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x3.val[0]);
+             float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
+             float32x4_t high_f32x4 = vcvt_high_f32_f16(data_f16x8);
+             sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
+             sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
+             sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
+             sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
+         }
+     }
+     else if (stride_elements == 4) {
+         for (; idx + 8 < count; idx += 8) {
+             uint16x8x4_t loaded_u16x8x4 = vld4q_u16((nk_u16_t const *)(data_ptr + idx * 4));
+             float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x4.val[0]);
+             float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
+             float32x4_t high_f32x4 = vcvt_high_f32_f16(data_f16x8);
+             sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
+             sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
+             sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
+             sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
+         }
+     }
+
+     nk_f32_t sum = vaddvq_f32(sum_f32x4);
+     nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
+     for (; idx < count; ++idx) {
+         nk_f32_t value_f32;
+         nk_f16_to_f32_serial((nk_f16_t const *)(data_ptr + idx * stride_elements), &value_f32);
+         sum += value_f32, sumsq += value_f32 * value_f32;
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_f16_neon( //
+     nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
+     int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
+     if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+     else if (!aligned) nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+     else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
+         nk_size_t left_count = count / 2;
+         nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
+         nk_reduce_moments_f16_neon(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
+         nk_reduce_moments_f16_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+             &right_sum_value, &right_sumsq_value);
+         *sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
+     }
+     else if (stride_elements == 1) nk_reduce_moments_f16_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+     else if (stride_elements <= 4)
+         nk_reduce_moments_f16_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+     else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
  #if defined(__clang__)
  #pragma clang attribute pop
  #elif defined(__GNUC__)
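The block above is a newly added f16 moments kernel for NEON. Its public entry point dispatches on the call shape: zero counts short-circuit, byte strides that are not a multiple of sizeof(nk_f16_t) fall back to the serial kernel, counts above (NK_U16_MAX + 1) * 8 are split in half recursively so no single f32 accumulator chain absorbs too many terms, and only strides of 1 to 4 elements take the vectorised paths. A usage sketch follows; the include path and the variance arithmetic are editorial additions, while the function signature is taken from the diff itself.

    #include <stddef.h>
    #include <numkong/numkong.h> /* assumed umbrella header; adjust to the package's actual include */

    void f16_mean_and_variance(nk_f16_t const *values, size_t n) {
        if (n == 0) return;
        nk_f32_t sum, sumsq;
        /* Contiguous layout: the stride between elements is sizeof(nk_f16_t) bytes. */
        nk_reduce_moments_f16_neon(values, n, sizeof(nk_f16_t), &sum, &sumsq);
        nk_f32_t mean = sum / (nk_f32_t)n;
        nk_f32_t variance = sumsq / (nk_f32_t)n - mean * mean; /* population variance */
        (void)mean;
        (void)variance;
    }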