numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,20 +8,19 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section elementwise_neon_instructions ARM NEON Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
*
|
|
13
|
-
*
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
*
|
|
20
|
-
*
|
|
21
|
-
*
|
|
22
|
-
*
|
|
23
|
-
*
|
|
24
|
-
* vqmovn_s32 SQXTN (V.4H, V.4S) 3cy 2/cy 2/cy
|
|
11
|
+
* Intrinsic Instruction A76 M5
|
|
12
|
+
* vld1q_f32 LD1 (V.4S) 4cy @ 2p 4cy @ 3p
|
|
13
|
+
* vst1q_f32 ST1 (V.4S) 2cy @ 2p 2cy @ 3p
|
|
14
|
+
* vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
|
|
15
|
+
* vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
16
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
|
|
17
|
+
* vaddq_f64 FADD (V.2D, V.2D, V.2D) 2cy @ 2p 2cy @ 4p
|
|
18
|
+
* vmulq_f64 FMUL (V.2D, V.2D, V.2D) 3cy @ 2p 3cy @ 4p
|
|
19
|
+
* vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy @ 2p 3cy @ 4p
|
|
20
|
+
* vqaddq_s16 SQADD (V.8H, V.8H, V.8H) 2cy @ 2p 3cy @ 2p
|
|
21
|
+
* vcvtq_f32_s32 SCVTF (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
22
|
+
* vcvtnq_s32_f32 FCVTNS (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
23
|
+
* vqmovn_s32 SQXTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
25
24
|
*
|
|
26
25
|
* Elementwise operations are throughput-bound rather than latency-bound. FP arithmetic
|
|
27
26
|
* throughput doubles on 4-pipe cores (Apple M4+, Graviton3+, Oryon) from 2/cy to 4/cy.
|
|
@@ -37,6 +36,7 @@
|
|
|
37
36
|
|
|
38
37
|
#include "numkong/types.h"
|
|
39
38
|
#include "numkong/cast/neon.h"
|
|
39
|
+
#include "numkong/cast/serial.h" // `nk_f32_to_u8_serial`, `nk_f32_to_i8_serial`
|
|
40
40
|
|
|
41
41
|
#if defined(__cplusplus)
|
|
42
42
|
extern "C" {
|
|
@@ -145,10 +145,10 @@ NK_PUBLIC void nk_each_sum_i16_neon(nk_i16_t const *a, nk_i16_t const *b, nk_siz
|
|
|
145
145
|
// The main loop:
|
|
146
146
|
nk_size_t i = 0;
|
|
147
147
|
for (; i + 8 <= n; i += 8) {
|
|
148
|
-
int16x8_t
|
|
149
|
-
int16x8_t
|
|
150
|
-
int16x8_t
|
|
151
|
-
vst1q_s16(result + i,
|
|
148
|
+
int16x8_t a_i16x8 = vld1q_s16(a + i);
|
|
149
|
+
int16x8_t b_i16x8 = vld1q_s16(b + i);
|
|
150
|
+
int16x8_t sum_i16x8 = vqaddq_s16(a_i16x8, b_i16x8);
|
|
151
|
+
vst1q_s16(result + i, sum_i16x8);
|
|
152
152
|
}
|
|
153
153
|
|
|
154
154
|
// The tail:
|
|
@@ -291,10 +291,10 @@ NK_PUBLIC void nk_each_sum_i32_neon(nk_i32_t const *a, nk_i32_t const *b, nk_siz
|
|
|
291
291
|
// The main loop:
|
|
292
292
|
nk_size_t i = 0;
|
|
293
293
|
for (; i + 4 <= n; i += 4) {
|
|
294
|
-
int32x4_t
|
|
295
|
-
int32x4_t
|
|
296
|
-
int32x4_t
|
|
297
|
-
vst1q_s32(result + i,
|
|
294
|
+
int32x4_t a_i32x4 = vld1q_s32(a + i);
|
|
295
|
+
int32x4_t b_i32x4 = vld1q_s32(b + i);
|
|
296
|
+
int32x4_t sum_i32x4 = vqaddq_s32(a_i32x4, b_i32x4);
|
|
297
|
+
vst1q_s32(result + i, sum_i32x4);
|
|
298
298
|
}
|
|
299
299
|
|
|
300
300
|
// The tail:
|
|
@@ -437,10 +437,10 @@ NK_PUBLIC void nk_each_sum_i64_neon(nk_i64_t const *a, nk_i64_t const *b, nk_siz
|
|
|
437
437
|
// The main loop:
|
|
438
438
|
nk_size_t i = 0;
|
|
439
439
|
for (; i + 2 <= n; i += 2) {
|
|
440
|
-
int64x2_t
|
|
441
|
-
int64x2_t
|
|
442
|
-
int64x2_t
|
|
443
|
-
vst1q_s64(result + i,
|
|
440
|
+
int64x2_t a_i64x2 = vld1q_s64(a + i);
|
|
441
|
+
int64x2_t b_i64x2 = vld1q_s64(b + i);
|
|
442
|
+
int64x2_t sum_i64x2 = vqaddq_s64(a_i64x2, b_i64x2);
|
|
443
|
+
vst1q_s64(result + i, sum_i64x2);
|
|
444
444
|
}
|
|
445
445
|
|
|
446
446
|
// The tail:
|
|
@@ -679,9 +679,9 @@ NK_PUBLIC void nk_each_sum_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_
|
|
|
679
679
|
float16x8_t a_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
|
|
680
680
|
float16x8_t b_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b + i));
|
|
681
681
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
682
|
-
float32x4_t a_high_f32x4 =
|
|
682
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
683
683
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
684
|
-
float32x4_t b_high_f32x4 =
|
|
684
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
685
685
|
float32x4_t result_low_f32x4 = vaddq_f32(a_low_f32x4, b_low_f32x4);
|
|
686
686
|
float32x4_t result_high_f32x4 = vaddq_f32(a_high_f32x4, b_high_f32x4);
|
|
687
687
|
nk_b32_vec_t result_low_vec = nk_f32x4_to_e4m3x4_neon_(result_low_f32x4);
|
|
@@ -703,9 +703,9 @@ NK_PUBLIC void nk_each_sum_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_
|
|
|
703
703
|
float16x8_t a_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
|
|
704
704
|
float16x8_t b_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b + i));
|
|
705
705
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
706
|
-
float32x4_t a_high_f32x4 =
|
|
706
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
707
707
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
708
|
-
float32x4_t b_high_f32x4 =
|
|
708
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
709
709
|
float32x4_t result_low_f32x4 = vaddq_f32(a_low_f32x4, b_low_f32x4);
|
|
710
710
|
float32x4_t result_high_f32x4 = vaddq_f32(a_high_f32x4, b_high_f32x4);
|
|
711
711
|
nk_b32_vec_t result_low_vec = nk_f32x4_to_e5m2x4_neon_(result_low_f32x4);
|
|
@@ -729,7 +729,7 @@ NK_PUBLIC void nk_each_scale_e4m3_neon(nk_e4m3_t const *a, nk_size_t n, nk_f32_t
|
|
|
729
729
|
for (; i + 8 <= n; i += 8) {
|
|
730
730
|
float16x8_t a_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
|
|
731
731
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
732
|
-
float32x4_t a_high_f32x4 =
|
|
732
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
733
733
|
float32x4_t result_low_f32x4 = vfmaq_f32(beta_f32x4, a_low_f32x4, alpha_f32x4);
|
|
734
734
|
float32x4_t result_high_f32x4 = vfmaq_f32(beta_f32x4, a_high_f32x4, alpha_f32x4);
|
|
735
735
|
nk_b32_vec_t result_low_vec = nk_f32x4_to_e4m3x4_neon_(result_low_f32x4);
|
|
@@ -752,7 +752,7 @@ NK_PUBLIC void nk_each_scale_e5m2_neon(nk_e5m2_t const *a, nk_size_t n, nk_f32_t
|
|
|
752
752
|
for (; i + 8 <= n; i += 8) {
|
|
753
753
|
float16x8_t a_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
|
|
754
754
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
755
|
-
float32x4_t a_high_f32x4 =
|
|
755
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
756
756
|
float32x4_t result_low_f32x4 = vfmaq_f32(beta_f32x4, a_low_f32x4, alpha_f32x4);
|
|
757
757
|
float32x4_t result_high_f32x4 = vfmaq_f32(beta_f32x4, a_high_f32x4, alpha_f32x4);
|
|
758
758
|
nk_b32_vec_t result_low_vec = nk_f32x4_to_e5m2x4_neon_(result_low_f32x4);
|
|
@@ -776,9 +776,9 @@ NK_PUBLIC void nk_each_blend_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, n
|
|
|
776
776
|
float16x8_t a_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
|
|
777
777
|
float16x8_t b_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b + i));
|
|
778
778
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
779
|
-
float32x4_t a_high_f32x4 =
|
|
779
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
780
780
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
781
|
-
float32x4_t b_high_f32x4 =
|
|
781
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
782
782
|
float32x4_t a_scaled_low_f32x4 = vmulq_f32(a_low_f32x4, alpha_f32x4);
|
|
783
783
|
float32x4_t a_scaled_high_f32x4 = vmulq_f32(a_high_f32x4, alpha_f32x4);
|
|
784
784
|
float32x4_t result_low_f32x4 = vfmaq_f32(a_scaled_low_f32x4, b_low_f32x4, beta_f32x4);
|
|
@@ -805,9 +805,9 @@ NK_PUBLIC void nk_each_blend_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, n
|
|
|
805
805
|
float16x8_t a_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
|
|
806
806
|
float16x8_t b_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b + i));
|
|
807
807
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
808
|
-
float32x4_t a_high_f32x4 =
|
|
808
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
809
809
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
810
|
-
float32x4_t b_high_f32x4 =
|
|
810
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
811
811
|
float32x4_t a_scaled_low_f32x4 = vmulq_f32(a_low_f32x4, alpha_f32x4);
|
|
812
812
|
float32x4_t a_scaled_high_f32x4 = vmulq_f32(a_high_f32x4, alpha_f32x4);
|
|
813
813
|
float32x4_t result_low_f32x4 = vfmaq_f32(a_scaled_low_f32x4, b_low_f32x4, beta_f32x4);
|
|
@@ -835,11 +835,11 @@ NK_PUBLIC void nk_each_fma_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_
|
|
|
835
835
|
float16x8_t b_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b + i));
|
|
836
836
|
float16x8_t c_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(c + i));
|
|
837
837
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
838
|
-
float32x4_t a_high_f32x4 =
|
|
838
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
839
839
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
840
|
-
float32x4_t b_high_f32x4 =
|
|
840
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
841
841
|
float32x4_t c_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_f16x8));
|
|
842
|
-
float32x4_t c_high_f32x4 =
|
|
842
|
+
float32x4_t c_high_f32x4 = vcvt_high_f32_f16(c_f16x8);
|
|
843
843
|
float32x4_t ab_low_f32x4 = vmulq_f32(a_low_f32x4, b_low_f32x4);
|
|
844
844
|
float32x4_t ab_high_f32x4 = vmulq_f32(a_high_f32x4, b_high_f32x4);
|
|
845
845
|
float32x4_t ab_scaled_low_f32x4 = vmulq_f32(ab_low_f32x4, alpha_f32x4);
|
|
@@ -870,11 +870,11 @@ NK_PUBLIC void nk_each_fma_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_
|
|
|
870
870
|
float16x8_t b_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b + i));
|
|
871
871
|
float16x8_t c_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(c + i));
|
|
872
872
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
873
|
-
float32x4_t a_high_f32x4 =
|
|
873
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
874
874
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
875
|
-
float32x4_t b_high_f32x4 =
|
|
875
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
876
876
|
float32x4_t c_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_f16x8));
|
|
877
|
-
float32x4_t c_high_f32x4 =
|
|
877
|
+
float32x4_t c_high_f32x4 = vcvt_high_f32_f16(c_f16x8);
|
|
878
878
|
float32x4_t ab_low_f32x4 = vmulq_f32(a_low_f32x4, b_low_f32x4);
|
|
879
879
|
float32x4_t ab_high_f32x4 = vmulq_f32(a_high_f32x4, b_high_f32x4);
|
|
880
880
|
float32x4_t ab_scaled_low_f32x4 = vmulq_f32(ab_low_f32x4, alpha_f32x4);
|
|
@@ -1089,6 +1089,40 @@ NK_PUBLIC void nk_each_fma_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_
|
|
|
1089
1089
|
}
|
|
1090
1090
|
}
|
|
1091
1091
|
|
|
1092
|
+
NK_PUBLIC void nk_each_sum_u8_neon(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result) {
|
|
1093
|
+
// The main loop:
|
|
1094
|
+
nk_size_t i = 0;
|
|
1095
|
+
for (; i + 16 <= n; i += 16) {
|
|
1096
|
+
uint8x16_t a_vec = vld1q_u8(a + i);
|
|
1097
|
+
uint8x16_t b_vec = vld1q_u8(b + i);
|
|
1098
|
+
uint8x16_t sum_vec = vqaddq_u8(a_vec, b_vec);
|
|
1099
|
+
vst1q_u8(result + i, sum_vec);
|
|
1100
|
+
}
|
|
1101
|
+
|
|
1102
|
+
// The tail:
|
|
1103
|
+
for (; i < n; ++i) {
|
|
1104
|
+
nk_f32_t sum = (nk_f32_t)a[i] + b[i];
|
|
1105
|
+
nk_f32_to_u8_serial(&sum, result + i);
|
|
1106
|
+
}
|
|
1107
|
+
}
|
|
1108
|
+
|
|
1109
|
+
NK_PUBLIC void nk_each_sum_i8_neon(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result) {
|
|
1110
|
+
// The main loop:
|
|
1111
|
+
nk_size_t i = 0;
|
|
1112
|
+
for (; i + 16 <= n; i += 16) {
|
|
1113
|
+
int8x16_t a_vec = vld1q_s8(a + i);
|
|
1114
|
+
int8x16_t b_vec = vld1q_s8(b + i);
|
|
1115
|
+
int8x16_t sum_vec = vqaddq_s8(a_vec, b_vec);
|
|
1116
|
+
vst1q_s8(result + i, sum_vec);
|
|
1117
|
+
}
|
|
1118
|
+
|
|
1119
|
+
// The tail:
|
|
1120
|
+
for (; i < n; ++i) {
|
|
1121
|
+
nk_f32_t sum = (nk_f32_t)a[i] + b[i];
|
|
1122
|
+
nk_f32_to_i8_serial(&sum, result + i);
|
|
1123
|
+
}
|
|
1124
|
+
}
|
|
1125
|
+
|
|
1092
1126
|
#if defined(__clang__)
|
|
1093
1127
|
#pragma clang attribute pop
|
|
1094
1128
|
#elif defined(__GNUC__)
|
|
@@ -8,18 +8,17 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section elementwise_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
*
|
|
13
|
-
*
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
*
|
|
20
|
-
*
|
|
21
|
-
*
|
|
22
|
-
* vdupq_n_f32 DUP (V.4S, scalar) 2cy 2/cy 4/cy
|
|
11
|
+
* Intrinsic Instruction A76 M5
|
|
12
|
+
* vld1_bf16 LD1 (V.4H) 4cy @ 2p 4cy @ 3p
|
|
13
|
+
* vst1_bf16 ST1 (V.4H) 2cy @ 2p 2cy @ 3p
|
|
14
|
+
* vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
15
|
+
* vcvt_bf16_f32 BFCVT (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
16
|
+
* vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
|
|
17
|
+
* vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
18
|
+
* vmulq_n_f32 FMUL (V.4S, V.4S, scalar) 3cy @ 2p 3cy @ 4p
|
|
19
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
|
|
20
|
+
* vfmaq_n_f32 FMLA (V.4S, V.4S, scalar) 4cy @ 2p 3cy @ 4p
|
|
21
|
+
* vdupq_n_f32 DUP (V.4S, scalar) 2cy @ 2p 2cy @ 4p
|
|
23
22
|
*
|
|
24
23
|
* The ARMv8.6-BF16 extension provides element-wise operations on BF16 data by converting to F32
|
|
25
24
|
* for arithmetic, then back to BF16 for storage. This preserves the dynamic range benefits of BF16
|
|
@@ -8,28 +8,27 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section elementwise_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
*
|
|
13
|
-
*
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
*
|
|
20
|
-
*
|
|
21
|
-
*
|
|
22
|
-
*
|
|
23
|
-
*
|
|
24
|
-
*
|
|
25
|
-
*
|
|
26
|
-
*
|
|
27
|
-
*
|
|
28
|
-
*
|
|
29
|
-
*
|
|
30
|
-
*
|
|
31
|
-
*
|
|
32
|
-
* vqaddq_s8 SQADD (V.16B, V.16B, V.16B) 2cy 2/cy 4/cy
|
|
11
|
+
* Intrinsic Instruction A76 M5
|
|
12
|
+
* vld1q_f16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
|
|
13
|
+
* vst1q_f16 ST1 (V.8H) 2cy @ 2p 2cy @ 3p
|
|
14
|
+
* vaddq_f16 FADD (V.8H, V.8H, V.8H) 2cy @ 2p 2cy @ 4p
|
|
15
|
+
* vmulq_f16 FMUL (V.8H, V.8H, V.8H) 3cy @ 2p 3cy @ 4p
|
|
16
|
+
* vmulq_n_f16 FMUL (V.8H, V.8H, scalar) 3cy @ 2p 3cy @ 4p
|
|
17
|
+
* vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
|
|
18
|
+
* vfmaq_n_f16 FMLA (V.8H, V.8H, scalar) 4cy @ 2p 4cy @ 4p
|
|
19
|
+
* vdupq_n_f16 DUP (V.8H, scalar) 2cy @ 2p 2cy @ 4p
|
|
20
|
+
* vld1_u8 LD1 (V.8B) 4cy @ 2p 4cy @ 3p
|
|
21
|
+
* vld1_s8 LD1 (V.8B) 4cy @ 2p 4cy @ 3p
|
|
22
|
+
* vmovl_u8 UXTL (V.8H, V.8B) 2cy @ 2p 2cy @ 4p
|
|
23
|
+
* vmovl_s8 SXTL (V.8H, V.8B) 2cy @ 2p 2cy @ 4p
|
|
24
|
+
* vcvtq_f16_u16 UCVTF (V.8H, V.8H) 3cy @ 2p 3cy @ 4p
|
|
25
|
+
* vcvtq_f16_s16 SCVTF (V.8H, V.8H) 3cy @ 2p 3cy @ 4p
|
|
26
|
+
* vcvtnq_u16_f16 FCVTNU (V.8H, V.8H) 3cy @ 2p 3cy @ 4p
|
|
27
|
+
* vcvtnq_s16_f16 FCVTNS (V.8H, V.8H) 3cy @ 2p 3cy @ 4p
|
|
28
|
+
* vqmovn_u16 UQXTN (V.8B, V.8H) 3cy @ 2p 3cy @ 4p
|
|
29
|
+
* vqmovn_s16 SQXTN (V.8B, V.8H) 3cy @ 2p 3cy @ 4p
|
|
30
|
+
* vqaddq_u8 UQADD (V.16B, V.16B, V.16B) 2cy @ 2p 3cy @ 2p
|
|
31
|
+
* vqaddq_s8 SQADD (V.16B, V.16B, V.16B) 2cy @ 2p 3cy @ 2p
|
|
33
32
|
*
|
|
34
33
|
* The ARMv8.2-FP16 extension enables native half-precision element-wise operations, processing 8
|
|
35
34
|
* F16 elements per instruction. Operations like sum, scale, blend, and fma work directly in F16,
|
|
@@ -46,6 +45,7 @@
|
|
|
46
45
|
|
|
47
46
|
#include "numkong/types.h"
|
|
48
47
|
#include "numkong/cast/serial.h" // `nk_f32_to_i8_serial`
|
|
48
|
+
#include "numkong/each/neon.h" // `nk_each_sum_u8_neon`, `nk_each_sum_i8_neon`
|
|
49
49
|
|
|
50
50
|
#if defined(__cplusplus)
|
|
51
51
|
extern "C" {
|
|
@@ -161,23 +161,6 @@ NK_PUBLIC void nk_each_fma_f16_neonhalf( //
|
|
|
161
161
|
beta_f16 * ((float16_t const *)c)[i];
|
|
162
162
|
}
|
|
163
163
|
|
|
164
|
-
NK_PUBLIC void nk_each_sum_u8_neonhalf(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result) {
|
|
165
|
-
// The main loop:
|
|
166
|
-
nk_size_t i = 0;
|
|
167
|
-
for (; i + 16 <= n; i += 16) {
|
|
168
|
-
uint8x16_t a_vec = vld1q_u8(a + i);
|
|
169
|
-
uint8x16_t b_vec = vld1q_u8(b + i);
|
|
170
|
-
uint8x16_t sum_vec = vqaddq_u8(a_vec, b_vec);
|
|
171
|
-
vst1q_u8(result + i, sum_vec);
|
|
172
|
-
}
|
|
173
|
-
|
|
174
|
-
// The tail:
|
|
175
|
-
for (; i < n; ++i) {
|
|
176
|
-
nk_f32_t sum = (nk_f32_t)a[i] + b[i];
|
|
177
|
-
nk_f32_to_u8_serial(&sum, result + i);
|
|
178
|
-
}
|
|
179
|
-
}
|
|
180
|
-
|
|
181
164
|
NK_PUBLIC void nk_each_scale_u8_neonhalf(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
182
165
|
nk_u8_t *result) {
|
|
183
166
|
float16_t alpha_f16 = (float16_t)*alpha;
|
|
@@ -213,7 +196,7 @@ NK_PUBLIC void nk_each_blend_u8_neonhalf( //
|
|
|
213
196
|
// 1. Simple addition, when both weights are equal to 1.0.
|
|
214
197
|
if (alpha_val == 1 && beta_val == 1) {
|
|
215
198
|
// In this case we can avoid expensive multiplications.
|
|
216
|
-
|
|
199
|
+
nk_each_sum_u8_neon(a, b, n, result);
|
|
217
200
|
return;
|
|
218
201
|
}
|
|
219
202
|
// 2. Just scaling, when one of the weights is equal to zero.
|
|
@@ -249,52 +232,6 @@ NK_PUBLIC void nk_each_blend_u8_neonhalf( //
|
|
|
249
232
|
}
|
|
250
233
|
}
|
|
251
234
|
|
|
252
|
-
NK_PUBLIC void nk_each_fma_u8_neonhalf( //
|
|
253
|
-
nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, //
|
|
254
|
-
nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result) {
|
|
255
|
-
float16_t alpha_f16 = (float16_t)*alpha;
|
|
256
|
-
float16_t beta_f16 = (float16_t)*beta;
|
|
257
|
-
|
|
258
|
-
// The main loop:
|
|
259
|
-
nk_size_t i = 0;
|
|
260
|
-
for (; i + 8 <= n; i += 8) {
|
|
261
|
-
uint8x8_t a_u8x8 = vld1_u8(a + i);
|
|
262
|
-
uint8x8_t b_u8x8 = vld1_u8(b + i);
|
|
263
|
-
uint8x8_t c_u8x8 = vld1_u8(c + i);
|
|
264
|
-
float16x8_t a_f16x8 = vcvtq_f16_u16(vmovl_u8(a_u8x8));
|
|
265
|
-
float16x8_t b_f16x8 = vcvtq_f16_u16(vmovl_u8(b_u8x8));
|
|
266
|
-
float16x8_t c_f16x8 = vcvtq_f16_u16(vmovl_u8(c_u8x8));
|
|
267
|
-
float16x8_t ab_f16x8 = vmulq_f16(a_f16x8, b_f16x8);
|
|
268
|
-
float16x8_t ab_scaled_f16x8 = vmulq_n_f16(ab_f16x8, alpha_f16);
|
|
269
|
-
float16x8_t result_f16x8 = vfmaq_n_f16(ab_scaled_f16x8, c_f16x8, beta_f16);
|
|
270
|
-
uint8x8_t result_u8x8 = vqmovn_u16(vcvtnq_u16_f16(result_f16x8));
|
|
271
|
-
vst1_u8(result + i, result_u8x8);
|
|
272
|
-
}
|
|
273
|
-
|
|
274
|
-
// The tail:
|
|
275
|
-
for (; i < n; ++i) {
|
|
276
|
-
nk_f32_t sum = alpha_f16 * a[i] * b[i] + beta_f16 * c[i];
|
|
277
|
-
nk_f32_to_u8_serial(&sum, result + i);
|
|
278
|
-
}
|
|
279
|
-
}
|
|
280
|
-
|
|
281
|
-
NK_PUBLIC void nk_each_sum_i8_neonhalf(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result) {
|
|
282
|
-
// The main loop:
|
|
283
|
-
nk_size_t i = 0;
|
|
284
|
-
for (; i + 16 <= n; i += 16) {
|
|
285
|
-
int8x16_t a_vec = vld1q_s8(a + i);
|
|
286
|
-
int8x16_t b_vec = vld1q_s8(b + i);
|
|
287
|
-
int8x16_t sum_vec = vqaddq_s8(a_vec, b_vec);
|
|
288
|
-
vst1q_s8(result + i, sum_vec);
|
|
289
|
-
}
|
|
290
|
-
|
|
291
|
-
// The tail:
|
|
292
|
-
for (; i < n; ++i) {
|
|
293
|
-
nk_f32_t sum = (nk_f32_t)a[i] + b[i];
|
|
294
|
-
nk_f32_to_i8_serial(&sum, result + i);
|
|
295
|
-
}
|
|
296
|
-
}
|
|
297
|
-
|
|
298
235
|
NK_PUBLIC void nk_each_scale_i8_neonhalf(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
299
236
|
nk_i8_t *result) {
|
|
300
237
|
float16_t alpha_f16 = (float16_t)*alpha;
|
|
@@ -330,7 +267,7 @@ NK_PUBLIC void nk_each_blend_i8_neonhalf( //
|
|
|
330
267
|
// 1. Simple addition, when both weights are equal to 1.0.
|
|
331
268
|
if (alpha_val == 1 && beta_val == 1) {
|
|
332
269
|
// In this case we can avoid expensive multiplications.
|
|
333
|
-
|
|
270
|
+
nk_each_sum_i8_neon(a, b, n, result);
|
|
334
271
|
return;
|
|
335
272
|
}
|
|
336
273
|
// 2. Just scaling, when one of the weights is equal to zero.
|
|
@@ -366,35 +303,6 @@ NK_PUBLIC void nk_each_blend_i8_neonhalf( //
|
|
|
366
303
|
}
|
|
367
304
|
}
|
|
368
305
|
|
|
369
|
-
NK_PUBLIC void nk_each_fma_i8_neonhalf( //
|
|
370
|
-
nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, //
|
|
371
|
-
nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result) {
|
|
372
|
-
float16_t alpha_f16 = (float16_t)*alpha;
|
|
373
|
-
float16_t beta_f16 = (float16_t)*beta;
|
|
374
|
-
|
|
375
|
-
// The main loop:
|
|
376
|
-
nk_size_t i = 0;
|
|
377
|
-
for (; i + 8 <= n; i += 8) {
|
|
378
|
-
int8x8_t a_i8x8 = vld1_s8(a + i);
|
|
379
|
-
int8x8_t b_i8x8 = vld1_s8(b + i);
|
|
380
|
-
int8x8_t c_i8x8 = vld1_s8(c + i);
|
|
381
|
-
float16x8_t a_f16x8 = vcvtq_f16_s16(vmovl_s8(a_i8x8));
|
|
382
|
-
float16x8_t b_f16x8 = vcvtq_f16_s16(vmovl_s8(b_i8x8));
|
|
383
|
-
float16x8_t c_f16x8 = vcvtq_f16_s16(vmovl_s8(c_i8x8));
|
|
384
|
-
float16x8_t ab_f16x8 = vmulq_f16(a_f16x8, b_f16x8);
|
|
385
|
-
float16x8_t ab_scaled_f16x8 = vmulq_n_f16(ab_f16x8, alpha_f16);
|
|
386
|
-
float16x8_t result_f16x8 = vfmaq_n_f16(ab_scaled_f16x8, c_f16x8, beta_f16);
|
|
387
|
-
int8x8_t result_i8x8 = vqmovn_s16(vcvtnq_s16_f16(result_f16x8));
|
|
388
|
-
vst1_s8(result + i, result_i8x8);
|
|
389
|
-
}
|
|
390
|
-
|
|
391
|
-
// The tail:
|
|
392
|
-
for (; i < n; ++i) {
|
|
393
|
-
nk_f32_t sum = alpha_f16 * a[i] * b[i] + beta_f16 * c[i];
|
|
394
|
-
nk_f32_to_i8_serial(&sum, result + i);
|
|
395
|
-
}
|
|
396
|
-
}
|
|
397
|
-
|
|
398
306
|
#if defined(__clang__)
|
|
399
307
|
#pragma clang attribute pop
|
|
400
308
|
#elif defined(__GNUC__)
|
|
@@ -185,8 +185,8 @@ NK_PUBLIC void nk_each_sum_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_s
|
|
|
185
185
|
NK_PUBLIC void nk_each_scale_f64_rvv(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
186
186
|
nk_f64_t *result) {
|
|
187
187
|
nk_f64_t alpha_val = *alpha, beta_val = *beta;
|
|
188
|
-
nk_size_t
|
|
189
|
-
vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val,
|
|
188
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
189
|
+
vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val, max_vector_length);
|
|
190
190
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
191
191
|
vector_length = __riscv_vsetvl_e64m4(n);
|
|
192
192
|
vfloat64m4_t a_f64m4 = __riscv_vle64_v_f64m4(a, vector_length);
|
|
@@ -198,8 +198,8 @@ NK_PUBLIC void nk_each_scale_f64_rvv(nk_f64_t const *a, nk_size_t n, nk_f64_t co
|
|
|
198
198
|
NK_PUBLIC void nk_each_scale_f32_rvv(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
199
199
|
nk_f32_t *result) {
|
|
200
200
|
nk_f32_t alpha_val = *alpha, beta_val = *beta;
|
|
201
|
-
nk_size_t
|
|
202
|
-
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val,
|
|
201
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
202
|
+
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
|
|
203
203
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
204
204
|
vector_length = __riscv_vsetvl_e32m4(n);
|
|
205
205
|
vfloat32m4_t a_f32m4 = __riscv_vle32_v_f32m4(a, vector_length);
|
|
@@ -211,8 +211,8 @@ NK_PUBLIC void nk_each_scale_f32_rvv(nk_f32_t const *a, nk_size_t n, nk_f32_t co
|
|
|
211
211
|
NK_PUBLIC void nk_each_scale_f16_rvv(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
212
212
|
nk_f16_t *result) {
|
|
213
213
|
nk_f32_t alpha_val = *alpha, beta_val = *beta;
|
|
214
|
-
nk_size_t
|
|
215
|
-
vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val,
|
|
214
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
215
|
+
vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, max_vector_length);
|
|
216
216
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
217
217
|
vector_length = __riscv_vsetvl_e16m1(n);
|
|
218
218
|
vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a, vector_length);
|
|
@@ -226,8 +226,8 @@ NK_PUBLIC void nk_each_scale_f16_rvv(nk_f16_t const *a, nk_size_t n, nk_f32_t co
|
|
|
226
226
|
NK_PUBLIC void nk_each_scale_bf16_rvv(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
227
227
|
nk_bf16_t *result) {
|
|
228
228
|
nk_f32_t alpha_val = *alpha, beta_val = *beta;
|
|
229
|
-
nk_size_t
|
|
230
|
-
vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val,
|
|
229
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
230
|
+
vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, max_vector_length);
|
|
231
231
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
232
232
|
vector_length = __riscv_vsetvl_e16m1(n);
|
|
233
233
|
vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)a, vector_length);
|
|
@@ -241,8 +241,8 @@ NK_PUBLIC void nk_each_scale_bf16_rvv(nk_bf16_t const *a, nk_size_t n, nk_f32_t
|
|
|
241
241
|
NK_PUBLIC void nk_each_scale_i8_rvv(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
242
242
|
nk_i8_t *result) {
|
|
243
243
|
nk_f32_t alpha_val = *alpha, beta_val = *beta;
|
|
244
|
-
nk_size_t
|
|
245
|
-
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val,
|
|
244
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
245
|
+
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
|
|
246
246
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
247
247
|
vector_length = __riscv_vsetvl_e8m1(n);
|
|
248
248
|
vint8m1_t a_i8m1 = __riscv_vle8_v_i8m1(a, vector_length);
|
|
@@ -262,8 +262,8 @@ NK_PUBLIC void nk_each_scale_i8_rvv(nk_i8_t const *a, nk_size_t n, nk_f32_t cons
|
|
|
262
262
|
NK_PUBLIC void nk_each_scale_u8_rvv(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
263
263
|
nk_u8_t *result) {
|
|
264
264
|
nk_f32_t alpha_val = *alpha, beta_val = *beta;
|
|
265
|
-
nk_size_t
|
|
266
|
-
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val,
|
|
265
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
266
|
+
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
|
|
267
267
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
268
268
|
vector_length = __riscv_vsetvl_e8m1(n);
|
|
269
269
|
vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1(a, vector_length);
|
|
@@ -283,8 +283,8 @@ NK_PUBLIC void nk_each_scale_u8_rvv(nk_u8_t const *a, nk_size_t n, nk_f32_t cons
|
|
|
283
283
|
NK_PUBLIC void nk_each_scale_i16_rvv(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
284
284
|
nk_i16_t *result) {
|
|
285
285
|
nk_f32_t alpha_val = *alpha, beta_val = *beta;
|
|
286
|
-
nk_size_t
|
|
287
|
-
vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val,
|
|
286
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
287
|
+
vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, max_vector_length);
|
|
288
288
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
289
289
|
vector_length = __riscv_vsetvl_e16m1(n);
|
|
290
290
|
vint16m1_t a_i16m1 = __riscv_vle16_v_i16m1(a, vector_length);
|
|
@@ -302,8 +302,8 @@ NK_PUBLIC void nk_each_scale_i16_rvv(nk_i16_t const *a, nk_size_t n, nk_f32_t co
|
|
|
302
302
|
NK_PUBLIC void nk_each_scale_u16_rvv(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
303
303
|
nk_u16_t *result) {
|
|
304
304
|
nk_f32_t alpha_val = *alpha, beta_val = *beta;
|
|
305
|
-
nk_size_t
|
|
306
|
-
vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val,
|
|
305
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
306
|
+
vfloat32m2_t beta_f32m2 = __riscv_vfmv_v_f_f32m2(beta_val, max_vector_length);
|
|
307
307
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
308
308
|
vector_length = __riscv_vsetvl_e16m1(n);
|
|
309
309
|
vuint16m1_t a_u16m1 = __riscv_vle16_v_u16m1(a, vector_length);
|
|
@@ -321,8 +321,8 @@ NK_PUBLIC void nk_each_scale_u16_rvv(nk_u16_t const *a, nk_size_t n, nk_f32_t co
|
|
|
321
321
|
NK_PUBLIC void nk_each_scale_i32_rvv(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
322
322
|
nk_i32_t *result) {
|
|
323
323
|
nk_f64_t alpha_val = *alpha, beta_val = *beta;
|
|
324
|
-
nk_size_t
|
|
325
|
-
vfloat64m2_t beta_f64m2 = __riscv_vfmv_v_f_f64m2(beta_val,
|
|
324
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
325
|
+
vfloat64m2_t beta_f64m2 = __riscv_vfmv_v_f_f64m2(beta_val, max_vector_length);
|
|
326
326
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
327
327
|
vector_length = __riscv_vsetvl_e32m1(n);
|
|
328
328
|
vint32m1_t a_i32m1 = __riscv_vle32_v_i32m1(a, vector_length);
|
|
@@ -338,8 +338,8 @@ NK_PUBLIC void nk_each_scale_i32_rvv(nk_i32_t const *a, nk_size_t n, nk_f64_t co
|
|
|
338
338
|
NK_PUBLIC void nk_each_scale_u32_rvv(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
339
339
|
nk_u32_t *result) {
|
|
340
340
|
nk_f64_t alpha_val = *alpha, beta_val = *beta;
|
|
341
|
-
nk_size_t
|
|
342
|
-
vfloat64m2_t beta_f64m2 = __riscv_vfmv_v_f_f64m2(beta_val,
|
|
341
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
|
|
342
|
+
vfloat64m2_t beta_f64m2 = __riscv_vfmv_v_f_f64m2(beta_val, max_vector_length);
|
|
343
343
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
344
344
|
vector_length = __riscv_vsetvl_e32m1(n);
|
|
345
345
|
vuint32m1_t a_u32m1 = __riscv_vle32_v_u32m1(a, vector_length);
|
|
@@ -355,8 +355,8 @@ NK_PUBLIC void nk_each_scale_u32_rvv(nk_u32_t const *a, nk_size_t n, nk_f64_t co
|
|
|
355
355
|
NK_PUBLIC void nk_each_scale_i64_rvv(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
356
356
|
nk_i64_t *result) {
|
|
357
357
|
nk_f64_t alpha_val = *alpha, beta_val = *beta;
|
|
358
|
-
nk_size_t
|
|
359
|
-
vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val,
|
|
358
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
359
|
+
vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val, max_vector_length);
|
|
360
360
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
361
361
|
vector_length = __riscv_vsetvl_e64m4(n);
|
|
362
362
|
vint64m4_t a_i64m4 = __riscv_vle64_v_i64m4(a, vector_length);
|
|
@@ -370,8 +370,8 @@ NK_PUBLIC void nk_each_scale_i64_rvv(nk_i64_t const *a, nk_size_t n, nk_f64_t co
|
|
|
370
370
|
NK_PUBLIC void nk_each_scale_u64_rvv(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
371
371
|
nk_u64_t *result) {
|
|
372
372
|
nk_f64_t alpha_val = *alpha, beta_val = *beta;
|
|
373
|
-
nk_size_t
|
|
374
|
-
vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val,
|
|
373
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
374
|
+
vfloat64m4_t beta_f64m4 = __riscv_vfmv_v_f_f64m4(beta_val, max_vector_length);
|
|
375
375
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
376
376
|
vector_length = __riscv_vsetvl_e64m4(n);
|
|
377
377
|
vuint64m4_t a_u64m4 = __riscv_vle64_v_u64m4(a, vector_length);
|
|
@@ -386,8 +386,8 @@ NK_PUBLIC void nk_each_scale_u64_rvv(nk_u64_t const *a, nk_size_t n, nk_f64_t co
|
|
|
386
386
|
NK_PUBLIC void nk_each_scale_e4m3_rvv(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
387
387
|
nk_e4m3_t *result) {
|
|
388
388
|
nk_f32_t alpha_val = *alpha, beta_val = *beta;
|
|
389
|
-
nk_size_t
|
|
390
|
-
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val,
|
|
389
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
390
|
+
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
|
|
391
391
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
392
392
|
vector_length = __riscv_vsetvl_e8m1(n);
|
|
393
393
|
vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a, vector_length);
|
|
@@ -401,8 +401,8 @@ NK_PUBLIC void nk_each_scale_e4m3_rvv(nk_e4m3_t const *a, nk_size_t n, nk_f32_t
|
|
|
401
401
|
NK_PUBLIC void nk_each_scale_e5m2_rvv(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
402
402
|
nk_e5m2_t *result) {
|
|
403
403
|
nk_f32_t alpha_val = *alpha, beta_val = *beta;
|
|
404
|
-
nk_size_t
|
|
405
|
-
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val,
|
|
404
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m4();
|
|
405
|
+
vfloat32m4_t beta_f32m4 = __riscv_vfmv_v_f_f32m4(beta_val, max_vector_length);
|
|
406
406
|
for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, result += vector_length) {
|
|
407
407
|
vector_length = __riscv_vsetvl_e8m1(n);
|
|
408
408
|
vuint8m1_t a_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)a, vector_length);
|