numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief NEON FP16 implementations for the redesigned reduction API (moments + minmax).
|
|
3
|
+
* @file include/numkong/reduce/neonhalf.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 13, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/reduce.h
|
|
8
|
+
*
|
|
9
|
+
* @section reduce_neonhalf_new_design Design Notes
|
|
10
|
+
*
|
|
11
|
+
* Moments (sum + sum-of-squares) accumulate in f32 via vcvt_f32_f16 widening, giving
|
|
12
|
+
* full f32 precision. The contiguous path processes 8 f16 elements per iteration, widening
|
|
13
|
+
* to two f32x4 halves and using vfmaq_f32 for fused multiply-accumulate of squares.
|
|
14
|
+
*
|
|
15
|
+
* Minmax tracks min/max values as native f16x8 with u16x8 iteration counters (same width
|
|
16
|
+
* as f16). The u16 counters wrap at 65536, so the dispatcher splits arrays larger than
|
|
17
|
+
* 65536 * 8 = 524288 elements via recursive halving.
|
|
18
|
+
*/
|
|
19
|
+
#ifndef NK_REDUCE_NEONHALF_H
|
|
20
|
+
#define NK_REDUCE_NEONHALF_H
|
|
21
|
+
|
|
22
|
+
#if NK_TARGET_ARM_
|
|
23
|
+
#if NK_TARGET_NEONHALF
|
|
24
|
+
|
|
25
|
+
#include "numkong/types.h"
|
|
26
|
+
#include "numkong/cast/neon.h"
|
|
27
|
+
#include "numkong/cast/serial.h"
|
|
28
|
+
#include "numkong/reduce/serial.h"
|
|
29
|
+
|
|
30
|
+
#if defined(__cplusplus)
|
|
31
|
+
extern "C" {
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
#if defined(__clang__)
|
|
35
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
36
|
+
#elif defined(__GNUC__)
|
|
37
|
+
#pragma GCC push_options
|
|
38
|
+
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
39
|
+
#endif
|
|
40
|
+
|
|
41
|
+
NK_INTERNAL void nk_reduce_moments_f16_neonhalf_contiguous_( //
|
|
42
|
+
nk_f16_t const *data_ptr, nk_size_t count, //
|
|
43
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
44
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
45
|
+
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
46
|
+
nk_size_t idx = 0;
|
|
47
|
+
|
|
48
|
+
for (; idx + 8 <= count; idx += 8) {
|
|
49
|
+
float16x8_t data_f16x8 = vld1q_f16((nk_f16_for_arm_simd_t const *)(data_ptr + idx));
|
|
50
|
+
float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
|
|
51
|
+
float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
|
|
52
|
+
sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
|
|
53
|
+
sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
|
|
54
|
+
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
|
|
55
|
+
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
// Scalar tail
|
|
59
|
+
nk_f32_t sum = vaddvq_f32(sum_f32x4);
|
|
60
|
+
nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
|
|
61
|
+
for (; idx < count; ++idx) {
|
|
62
|
+
nk_f32_t value_f32;
|
|
63
|
+
nk_f16_to_f32_serial(data_ptr + idx, &value_f32);
|
|
64
|
+
sum += value_f32, sumsq += value_f32 * value_f32;
|
|
65
|
+
}
|
|
66
|
+
*sum_ptr = sum, *sumsq_ptr = sumsq;
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
NK_INTERNAL void nk_reduce_moments_f16_neonhalf_strided_( //
|
|
70
|
+
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
71
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
72
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
73
|
+
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
74
|
+
nk_size_t idx = 0;
|
|
75
|
+
|
|
76
|
+
if (stride_elements == 2) {
|
|
77
|
+
for (; idx + 8 <= count; idx += 8) {
|
|
78
|
+
uint16x8x2_t loaded_u16x8x2 = vld2q_u16((uint16_t const *)(data_ptr + idx * 2));
|
|
79
|
+
float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x2.val[0]);
|
|
80
|
+
float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
|
|
81
|
+
float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
|
|
82
|
+
sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
|
|
83
|
+
sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
|
|
84
|
+
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
|
|
85
|
+
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
|
|
86
|
+
}
|
|
87
|
+
}
|
|
88
|
+
else if (stride_elements == 3) {
|
|
89
|
+
for (; idx + 8 <= count; idx += 8) {
|
|
90
|
+
uint16x8x3_t loaded_u16x8x3 = vld3q_u16((uint16_t const *)(data_ptr + idx * 3));
|
|
91
|
+
float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x3.val[0]);
|
|
92
|
+
float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
|
|
93
|
+
float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
|
|
94
|
+
sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
|
|
95
|
+
sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
|
|
96
|
+
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
|
|
97
|
+
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
|
|
98
|
+
}
|
|
99
|
+
}
|
|
100
|
+
else if (stride_elements == 4) {
|
|
101
|
+
for (; idx + 8 <= count; idx += 8) {
|
|
102
|
+
uint16x8x4_t loaded_u16x8x4 = vld4q_u16((uint16_t const *)(data_ptr + idx * 4));
|
|
103
|
+
float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x4.val[0]);
|
|
104
|
+
float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
|
|
105
|
+
float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
|
|
106
|
+
sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
|
|
107
|
+
sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
|
|
108
|
+
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
|
|
109
|
+
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
|
|
110
|
+
}
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
// Scalar tail for remaining elements
|
|
114
|
+
nk_f32_t sum = vaddvq_f32(sum_f32x4);
|
|
115
|
+
nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
|
|
116
|
+
for (; idx < count; ++idx) {
|
|
117
|
+
nk_f32_t value_f32;
|
|
118
|
+
nk_f16_to_f32_serial((nk_f16_t const *)(data_ptr + idx * stride_elements), &value_f32);
|
|
119
|
+
sum += value_f32, sumsq += value_f32 * value_f32;
|
|
120
|
+
}
|
|
121
|
+
*sum_ptr = sum, *sumsq_ptr = sumsq;
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
NK_PUBLIC void nk_reduce_moments_f16_neonhalf( //
|
|
125
|
+
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
126
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
127
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
|
|
128
|
+
int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
|
|
129
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
130
|
+
else if (!aligned) nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
131
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
|
|
132
|
+
nk_size_t left_count = count / 2;
|
|
133
|
+
nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
|
|
134
|
+
nk_reduce_moments_f16_neonhalf(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
|
|
135
|
+
nk_reduce_moments_f16_neonhalf(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
136
|
+
&right_sum_value, &right_sumsq_value);
|
|
137
|
+
*sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
|
|
138
|
+
}
|
|
139
|
+
else if (stride_elements == 1) nk_reduce_moments_f16_neonhalf_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
140
|
+
else if (stride_elements <= 4)
|
|
141
|
+
nk_reduce_moments_f16_neonhalf_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
142
|
+
else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
#if defined(__clang__)
|
|
146
|
+
#pragma clang attribute pop
|
|
147
|
+
#elif defined(__GNUC__)
|
|
148
|
+
#pragma GCC pop_options
|
|
149
|
+
#endif
|
|
150
|
+
|
|
151
|
+
#if defined(__cplusplus)
|
|
152
|
+
} // extern "C"
|
|
153
|
+
#endif
|
|
154
|
+
|
|
155
|
+
#endif // NK_TARGET_NEONHALF
|
|
156
|
+
#endif // NK_TARGET_ARM_
|
|
157
|
+
#endif // NK_REDUCE_NEONHALF_H
|
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief ARMv8.4-DotProd implementations for the redesigned reduction API (moments).
|
|
3
|
+
* @file include/numkong/reduce/neonsdot.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 13, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/reduce.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_REDUCE_NEONSDOT_H
|
|
10
|
+
#define NK_REDUCE_NEONSDOT_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_ARM_
|
|
13
|
+
#if NK_TARGET_NEONSDOT
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.h"
|
|
16
|
+
#include "numkong/cast/serial.h"
|
|
17
|
+
#include "numkong/reduce/serial.h"
|
|
18
|
+
|
|
19
|
+
#if defined(__cplusplus)
|
|
20
|
+
extern "C" {
|
|
21
|
+
#endif
|
|
22
|
+
|
|
23
|
+
#if defined(__clang__)
|
|
24
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+dotprod"))), apply_to = function)
|
|
25
|
+
#elif defined(__GNUC__)
|
|
26
|
+
#pragma GCC push_options
|
|
27
|
+
#pragma GCC target("arch=armv8.2-a+dotprod")
|
|
28
|
+
#endif
|
|
29
|
+
|
|
30
|
+
NK_INTERNAL void nk_reduce_moments_i8_neonsdot_contiguous_( //
|
|
31
|
+
nk_i8_t const *data_ptr, nk_size_t count, //
|
|
32
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
33
|
+
int8x16_t ones_i8x16 = vdupq_n_s8(1);
|
|
34
|
+
int32x4_t sum_i32x4 = vdupq_n_s32(0);
|
|
35
|
+
int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
|
|
36
|
+
nk_size_t idx = 0;
|
|
37
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
38
|
+
int8x16_t data_i8x16 = vld1q_s8(data_ptr + idx);
|
|
39
|
+
sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
|
|
40
|
+
sumsq_i32x4 = vdotq_s32(sumsq_i32x4, data_i8x16, data_i8x16);
|
|
41
|
+
}
|
|
42
|
+
// Widen i32 -> i64 and horizontal reduce
|
|
43
|
+
int64x2_t sum_i64x2 = vpaddlq_s32(sum_i32x4);
|
|
44
|
+
nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
|
|
45
|
+
uint64x2_t sumsq_u64x2 = vpaddlq_u32(vreinterpretq_u32_s32(sumsq_i32x4));
|
|
46
|
+
nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
|
|
47
|
+
for (; idx < count; ++idx) {
|
|
48
|
+
nk_i64_t value = (nk_i64_t)data_ptr[idx];
|
|
49
|
+
sum += value, sumsq += (nk_u64_t)(value * value);
|
|
50
|
+
}
|
|
51
|
+
*sum_ptr = sum;
|
|
52
|
+
*sumsq_ptr = sumsq;
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
NK_INTERNAL void nk_reduce_moments_i8_neonsdot_strided_( //
|
|
56
|
+
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
57
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
58
|
+
int8x16_t ones_i8x16 = vdupq_n_s8(1);
|
|
59
|
+
int32x4_t sum_i32x4 = vdupq_n_s32(0);
|
|
60
|
+
int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
|
|
61
|
+
nk_size_t idx = 0;
|
|
62
|
+
if (stride_elements == 2) {
|
|
63
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
64
|
+
int8x16x2_t loaded = vld2q_s8(data_ptr + idx * 2);
|
|
65
|
+
int8x16_t data_i8x16 = loaded.val[0];
|
|
66
|
+
sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
|
|
67
|
+
sumsq_i32x4 = vdotq_s32(sumsq_i32x4, data_i8x16, data_i8x16);
|
|
68
|
+
}
|
|
69
|
+
}
|
|
70
|
+
else if (stride_elements == 3) {
|
|
71
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
72
|
+
int8x16x3_t loaded = vld3q_s8(data_ptr + idx * 3);
|
|
73
|
+
int8x16_t data_i8x16 = loaded.val[0];
|
|
74
|
+
sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
|
|
75
|
+
sumsq_i32x4 = vdotq_s32(sumsq_i32x4, data_i8x16, data_i8x16);
|
|
76
|
+
}
|
|
77
|
+
}
|
|
78
|
+
else if (stride_elements == 4) {
|
|
79
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
80
|
+
int8x16x4_t loaded = vld4q_s8(data_ptr + idx * 4);
|
|
81
|
+
int8x16_t data_i8x16 = loaded.val[0];
|
|
82
|
+
sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
|
|
83
|
+
sumsq_i32x4 = vdotq_s32(sumsq_i32x4, data_i8x16, data_i8x16);
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
// Widen i32 -> i64 and horizontal reduce
|
|
87
|
+
int64x2_t sum_i64x2 = vpaddlq_s32(sum_i32x4);
|
|
88
|
+
nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
|
|
89
|
+
uint64x2_t sumsq_u64x2 = vpaddlq_u32(vreinterpretq_u32_s32(sumsq_i32x4));
|
|
90
|
+
nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
|
|
91
|
+
for (; idx < count; ++idx) {
|
|
92
|
+
nk_i64_t value = (nk_i64_t)data_ptr[idx * stride_elements];
|
|
93
|
+
sum += value, sumsq += (nk_u64_t)(value * value);
|
|
94
|
+
}
|
|
95
|
+
*sum_ptr = sum;
|
|
96
|
+
*sumsq_ptr = sumsq;
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
NK_PUBLIC void nk_reduce_moments_i8_neonsdot( //
|
|
100
|
+
nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
101
|
+
nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
102
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
|
|
103
|
+
int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
|
|
104
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
105
|
+
else if (!aligned) nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
106
|
+
else if (count > (nk_size_t)32768 * 16) {
|
|
107
|
+
nk_size_t left_count = count / 2;
|
|
108
|
+
nk_i64_t left_sum_value, right_sum_value;
|
|
109
|
+
nk_u64_t left_sumsq_value, right_sumsq_value;
|
|
110
|
+
nk_reduce_moments_i8_neonsdot(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
|
|
111
|
+
nk_reduce_moments_i8_neonsdot(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
112
|
+
&right_sum_value, &right_sumsq_value);
|
|
113
|
+
*sum_ptr = nk_i64_saturating_add_serial(left_sum_value, right_sum_value);
|
|
114
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq_value, right_sumsq_value);
|
|
115
|
+
}
|
|
116
|
+
else if (stride_elements == 1) nk_reduce_moments_i8_neonsdot_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
117
|
+
else if (stride_elements <= 4)
|
|
118
|
+
nk_reduce_moments_i8_neonsdot_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
119
|
+
else nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
NK_INTERNAL void nk_reduce_moments_u8_neonsdot_contiguous_( //
|
|
123
|
+
nk_u8_t const *data_ptr, nk_size_t count, //
|
|
124
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
125
|
+
uint8x16_t ones_u8x16 = vdupq_n_u8(1);
|
|
126
|
+
uint32x4_t sum_u32x4 = vdupq_n_u32(0);
|
|
127
|
+
uint32x4_t sumsq_u32x4 = vdupq_n_u32(0);
|
|
128
|
+
nk_size_t idx = 0;
|
|
129
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
130
|
+
uint8x16_t data_u8x16 = vld1q_u8(data_ptr + idx);
|
|
131
|
+
sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
|
|
132
|
+
sumsq_u32x4 = vdotq_u32(sumsq_u32x4, data_u8x16, data_u8x16);
|
|
133
|
+
}
|
|
134
|
+
uint64x2_t sum_u64x2 = vpaddlq_u32(sum_u32x4);
|
|
135
|
+
nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
|
|
136
|
+
uint64x2_t sumsq_u64x2 = vpaddlq_u32(sumsq_u32x4);
|
|
137
|
+
nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
|
|
138
|
+
for (; idx < count; ++idx) {
|
|
139
|
+
nk_u64_t value = (nk_u64_t)data_ptr[idx];
|
|
140
|
+
sum += value, sumsq += value * value;
|
|
141
|
+
}
|
|
142
|
+
*sum_ptr = sum;
|
|
143
|
+
*sumsq_ptr = sumsq;
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
NK_INTERNAL void nk_reduce_moments_u8_neonsdot_strided_( //
|
|
147
|
+
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
148
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
149
|
+
uint8x16_t ones_u8x16 = vdupq_n_u8(1);
|
|
150
|
+
uint32x4_t sum_u32x4 = vdupq_n_u32(0);
|
|
151
|
+
uint32x4_t sumsq_u32x4 = vdupq_n_u32(0);
|
|
152
|
+
nk_size_t idx = 0;
|
|
153
|
+
if (stride_elements == 2) {
|
|
154
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
155
|
+
uint8x16x2_t loaded = vld2q_u8(data_ptr + idx * 2);
|
|
156
|
+
uint8x16_t data_u8x16 = loaded.val[0];
|
|
157
|
+
sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
|
|
158
|
+
sumsq_u32x4 = vdotq_u32(sumsq_u32x4, data_u8x16, data_u8x16);
|
|
159
|
+
}
|
|
160
|
+
}
|
|
161
|
+
else if (stride_elements == 3) {
|
|
162
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
163
|
+
uint8x16x3_t loaded = vld3q_u8(data_ptr + idx * 3);
|
|
164
|
+
uint8x16_t data_u8x16 = loaded.val[0];
|
|
165
|
+
sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
|
|
166
|
+
sumsq_u32x4 = vdotq_u32(sumsq_u32x4, data_u8x16, data_u8x16);
|
|
167
|
+
}
|
|
168
|
+
}
|
|
169
|
+
else if (stride_elements == 4) {
|
|
170
|
+
for (; idx + 16 <= count; idx += 16) {
|
|
171
|
+
uint8x16x4_t loaded = vld4q_u8(data_ptr + idx * 4);
|
|
172
|
+
uint8x16_t data_u8x16 = loaded.val[0];
|
|
173
|
+
sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
|
|
174
|
+
sumsq_u32x4 = vdotq_u32(sumsq_u32x4, data_u8x16, data_u8x16);
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
uint64x2_t sum_u64x2 = vpaddlq_u32(sum_u32x4);
|
|
178
|
+
nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
|
|
179
|
+
uint64x2_t sumsq_u64x2 = vpaddlq_u32(sumsq_u32x4);
|
|
180
|
+
nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
|
|
181
|
+
for (; idx < count; ++idx) {
|
|
182
|
+
nk_u64_t value = (nk_u64_t)data_ptr[idx * stride_elements];
|
|
183
|
+
sum += value, sumsq += value * value;
|
|
184
|
+
}
|
|
185
|
+
*sum_ptr = sum;
|
|
186
|
+
*sumsq_ptr = sumsq;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
NK_PUBLIC void nk_reduce_moments_u8_neonsdot( //
|
|
190
|
+
nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
191
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
192
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
|
|
193
|
+
int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
|
|
194
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
195
|
+
else if (!aligned) nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
196
|
+
else if (count > (nk_size_t)16384 * 16) {
|
|
197
|
+
nk_size_t left_count = count / 2;
|
|
198
|
+
nk_u64_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
|
|
199
|
+
nk_reduce_moments_u8_neonsdot(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
|
|
200
|
+
nk_reduce_moments_u8_neonsdot(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
201
|
+
&right_sum_value, &right_sumsq_value);
|
|
202
|
+
*sum_ptr = nk_u64_saturating_add_serial(left_sum_value, right_sum_value);
|
|
203
|
+
*sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq_value, right_sumsq_value);
|
|
204
|
+
}
|
|
205
|
+
else if (stride_elements == 1) nk_reduce_moments_u8_neonsdot_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
206
|
+
else if (stride_elements <= 4)
|
|
207
|
+
nk_reduce_moments_u8_neonsdot_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
208
|
+
else nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
/**
 *  @brief Sum and sum of squares of `count` contiguous E2M3 bytes, accumulated as scaled integers.
 *
 *  Every byte is mapped through a 32-entry table indexed by its low five bits; bit 0x20 selects
 *  the negated table value. The resulting signed bytes are accumulated with `vdotq_s32`, once
 *  against a vector of ones (sum) and once against themselves (sum of squares). Elements left
 *  over after the last full 16-byte group are converted with `nk_e2m3_to_f32_serial` and scaled
 *  by 16 (sum) or 256 (sum of squares) before being added. The integer totals are divided by the
 *  same factors on output.
 *
 *  @param[in] data_ptr Address of the first element.
 *  @param[in] count Number of elements to reduce.
 *  @param[out] sum_ptr Receives the sum.
 *  @param[out] sumsq_ptr Receives the sum of squares.
 */
NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_contiguous_( //
    nk_e2m3_t const *data_ptr, nk_size_t count,               //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    // Table entries for indices 0..7: 0,2,...,14; for 8..15: 16,18,...,30.
    uint8x8_t const table_00_07_u8x8 = vreinterpret_u8_u64(vcreate_u64(0x0E0C0A0806040200ULL));
    uint8x8_t const table_08_15_u8x8 = vreinterpret_u8_u64(vcreate_u64(0x1E1C1A1816141210ULL));
    // Table entries for indices 16..23: 32,36,...,60; for 24..31: 64,72,...,120.
    uint8x8_t const table_16_23_u8x8 = vreinterpret_u8_u64(vcreate_u64(0x3C3834302C282420ULL));
    uint8x8_t const table_24_31_u8x8 = vreinterpret_u8_u64(vcreate_u64(0x7870686058504840ULL));
    uint8x16x2_t table_u8x16x2;
    table_u8x16x2.val[0] = vcombine_u8(table_00_07_u8x8, table_08_15_u8x8);
    table_u8x16x2.val[1] = vcombine_u8(table_16_23_u8x8, table_24_31_u8x8);

    uint8x16_t const index_mask_u8x16 = vdupq_n_u8(0x1F);
    uint8x16_t const sign_mask_u8x16 = vdupq_n_u8(0x20);
    int8x16_t const unit_i8x16 = vdupq_n_s8(1);

    int32x4_t total_i32x4 = vdupq_n_s32(0);
    int32x4_t total_sq_i32x4 = vdupq_n_s32(0);
    nk_size_t offset = 0;
    while (count - offset >= 16) {
        uint8x16_t const bytes_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + offset));
        // Table lookup on the low five bits gives the unsigned scaled magnitude.
        uint8x16_t const looked_up_u8x16 = vqtbl2q_u8(table_u8x16x2, vandq_u8(bytes_u8x16, index_mask_u8x16));
        // All-ones lanes where bit 0x20 is set.
        uint8x16_t const negate_u8x16 = vtstq_u8(bytes_u8x16, sign_mask_u8x16);
        int8x16_t const as_positive_i8x16 = vreinterpretq_s8_u8(looked_up_u8x16);
        int8x16_t const as_negative_i8x16 = vnegq_s8(as_positive_i8x16);
        int8x16_t const signed_i8x16 = vbslq_s8(negate_u8x16, as_negative_i8x16, as_positive_i8x16);
        total_i32x4 = vdotq_s32(total_i32x4, signed_i8x16, unit_i8x16);
        total_sq_i32x4 = vdotq_s32(total_sq_i32x4, signed_i8x16, signed_i8x16);
        offset += 16;
    }

    // Widen and fold the four lanes of each accumulator into one 64-bit scalar.
    int64x2_t const total_i64x2 = vpaddlq_s32(total_i32x4);
    nk_i64_t scaled_sum = vgetq_lane_s64(total_i64x2, 0) + vgetq_lane_s64(total_i64x2, 1);
    uint64x2_t const total_sq_u64x2 = vpaddlq_u32(vreinterpretq_u32_s32(total_sq_i32x4));
    nk_u64_t scaled_sumsq = vgetq_lane_u64(total_sq_u64x2, 0) + vgetq_lane_u64(total_sq_u64x2, 1);

    // Scalar tail: fewer than 16 elements remain.
    while (offset < count) {
        nk_f32_t decoded;
        nk_e2m3_to_f32_serial(&data_ptr[offset], &decoded);
        scaled_sum += (nk_i64_t)(decoded * 16.0f);
        scaled_sumsq += (nk_u64_t)(nk_i64_t)(decoded * decoded * 256.0f);
        ++offset;
    }

    // Undo the integer scaling.
    *sum_ptr = (nk_f32_t)scaled_sum / 16.0f;
    *sumsq_ptr = (nk_f32_t)scaled_sumsq / 256.0f;
}
|
|
251
|
+
|
|
252
|
+
/**
 *  @brief Maps 16 raw E2M3 bytes to signed 8-bit integers using the scaled magnitude table.
 *
 *  The low five bits of each byte index the 32-entry table; lanes whose bit 0x20 is set
 *  receive the negated table value.
 *
 *  @param[in] lut_e2m3_x16 The 32-entry table, as two 16-byte halves.
 *  @param[in] raw_u8x16 Sixteen raw element bytes.
 *  @return Sixteen signed scaled values.
 */
NK_INTERNAL int8x16_t nk_reduce_moments_e2m3_neonsdot_decode_x16_(uint8x16x2_t lut_e2m3_x16, uint8x16_t raw_u8x16) {
    uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
    uint8x16_t unsigned_u8x16 = vqtbl2q_u8(lut_e2m3_x16, magnitude_u8x16);
    uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
    int8x16_t positive_i8x16 = vreinterpretq_s8_u8(unsigned_u8x16);
    int8x16_t negative_i8x16 = vnegq_s8(positive_i8x16);
    return vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
}

/**
 *  @brief Sum and sum of squares of `count` E2M3 bytes spaced `stride_elements` elements apart.
 *
 *  For strides of 2, 3 and 4 elements, groups of 16 elements are gathered with the
 *  de-interleaving loads `vld2q_u8` / `vld3q_u8` / `vld4q_u8` (only the first de-interleaved
 *  vector is used), decoded to signed scaled integers and accumulated with `vdotq_s32`.
 *  Any other stride, and all elements not consumed by the vector loops, go through the scalar
 *  tail, which converts with `nk_e2m3_to_f32_serial` and scales by 16 (sum) or 256 (squares).
 *  The integer totals are divided by the same factors on output.
 *
 *  The vector loops deliberately stop while at least one element is still pending
 *  (`idx + 16 < count`): each de-interleaving load reads `16 * stride` consecutive bytes,
 *  i.e. `stride - 1` bytes past the 16th gathered element. Leaving the last element to the
 *  scalar tail guarantees every byte read lies before that element, so the load never runs
 *  past the end of a buffer that stops right after the final element.
 *
 *  @param[in] data_ptr Address of the first element.
 *  @param[in] count Number of elements to reduce.
 *  @param[in] stride_elements Distance between consecutive elements, in elements.
 *  @param[out] sum_ptr Receives the sum.
 *  @param[out] sumsq_ptr Receives the sum of squares.
 */
NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_strided_(                    //
    nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    uint8x16x2_t lut_e2m3_x16;
    // table[0]: entries for indices 0..15
    // 0x0E0C0A0806040200 → bytes [0..7] = 0,2,4,6,8,10,12,14
    // 0x1E1C1A1816141210 → bytes [8..15] = 16,18,20,22,24,26,28,30
    lut_e2m3_x16.val[0] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0E0C0A0806040200ULL)),
                                      vreinterpret_u8_u64(vcreate_u64(0x1E1C1A1816141210ULL)));
    // table[1]: entries for indices 16..31
    // 0x3C3834302C282420 → bytes [0..7] = 32,36,40,44,48,52,56,60
    // 0x7870686058504840 → bytes [8..15] = 64,72,80,88,96,104,112,120
    lut_e2m3_x16.val[1] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x3C3834302C282420ULL)),
                                      vreinterpret_u8_u64(vcreate_u64(0x7870686058504840ULL)));
    int8x16_t ones_i8x16 = vdupq_n_s8(1);
    int32x4_t sum_i32x4 = vdupq_n_s32(0);
    int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
    nk_size_t idx = 0;
    if (stride_elements == 2) {
        // Strict `<`: keep the last element for the scalar tail so the 32-byte load stays in bounds.
        for (; idx + 16 < count; idx += 16) {
            uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
            int8x16_t scaled_i8x16 = nk_reduce_moments_e2m3_neonsdot_decode_x16_(lut_e2m3_x16, loaded_u8x16x2.val[0]);
            sum_i32x4 = vdotq_s32(sum_i32x4, scaled_i8x16, ones_i8x16);
            sumsq_i32x4 = vdotq_s32(sumsq_i32x4, scaled_i8x16, scaled_i8x16);
        }
    }
    else if (stride_elements == 3) {
        // Strict `<`: keep the last element for the scalar tail so the 48-byte load stays in bounds.
        for (; idx + 16 < count; idx += 16) {
            uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
            int8x16_t scaled_i8x16 = nk_reduce_moments_e2m3_neonsdot_decode_x16_(lut_e2m3_x16, loaded_u8x16x3.val[0]);
            sum_i32x4 = vdotq_s32(sum_i32x4, scaled_i8x16, ones_i8x16);
            sumsq_i32x4 = vdotq_s32(sumsq_i32x4, scaled_i8x16, scaled_i8x16);
        }
    }
    else if (stride_elements == 4) {
        // Strict `<`: keep the last element for the scalar tail so the 64-byte load stays in bounds.
        for (; idx + 16 < count; idx += 16) {
            uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
            int8x16_t scaled_i8x16 = nk_reduce_moments_e2m3_neonsdot_decode_x16_(lut_e2m3_x16, loaded_u8x16x4.val[0]);
            sum_i32x4 = vdotq_s32(sum_i32x4, scaled_i8x16, ones_i8x16);
            sumsq_i32x4 = vdotq_s32(sumsq_i32x4, scaled_i8x16, scaled_i8x16);
        }
    }
    // Widen and fold the four lanes of each accumulator into one 64-bit scalar.
    int64x2_t sum_i64x2 = vpaddlq_s32(sum_i32x4);
    nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
    uint64x2_t sumsq_u64x2 = vpaddlq_u32(vreinterpretq_u32_s32(sumsq_i32x4));
    nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
    // Scalar tail: remaining elements, or every element when the stride has no vector path.
    for (; idx < count; ++idx) {
        nk_f32_t value;
        nk_e2m3_to_f32_serial(data_ptr + idx * stride_elements, &value);
        sum += (nk_i64_t)(value * 16.0f);
        sumsq += (nk_u64_t)(nk_i64_t)(value * value * 256.0f);
    }
    // Undo the integer scaling.
    *sum_ptr = (nk_f32_t)sum / 16.0f;
    *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
}
|
|
323
|
+
|
|
324
|
+
NK_PUBLIC void nk_reduce_moments_e2m3_neonsdot( //
|
|
325
|
+
nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
326
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
327
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
|
|
328
|
+
int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
|
|
329
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
330
|
+
else if (!aligned) nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
331
|
+
else if (count > (nk_size_t)(NK_I16_MAX + 1) * 16) {
|
|
332
|
+
nk_size_t left_count = count / 2;
|
|
333
|
+
nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
|
|
334
|
+
nk_reduce_moments_e2m3_neonsdot(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
|
|
335
|
+
nk_reduce_moments_e2m3_neonsdot(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
336
|
+
&right_sum_value, &right_sumsq_value);
|
|
337
|
+
*sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
|
|
338
|
+
}
|
|
339
|
+
else if (stride_elements == 1) nk_reduce_moments_e2m3_neonsdot_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
340
|
+
else if (stride_elements <= 4)
|
|
341
|
+
nk_reduce_moments_e2m3_neonsdot_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
342
|
+
else nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
#if defined(__clang__)
|
|
346
|
+
#pragma clang attribute pop
|
|
347
|
+
#elif defined(__GNUC__)
|
|
348
|
+
#pragma GCC pop_options
|
|
349
|
+
#endif
|
|
350
|
+
|
|
351
|
+
#if defined(__cplusplus)
|
|
352
|
+
} // extern "C"
|
|
353
|
+
#endif
|
|
354
|
+
|
|
355
|
+
#endif // NK_TARGET_NEONSDOT
|
|
356
|
+
#endif // NK_TARGET_ARM_
|
|
357
|
+
#endif // NK_REDUCE_NEONSDOT_H
|