numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1104 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Elementwise Arithmetic for NEON.
|
|
3
|
+
* @file include/numkong/each/neon.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/each.h
|
|
8
|
+
*
|
|
9
|
+
* @section elementwise_neon_instructions ARM NEON Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput
|
|
12
|
+
* A76 M4+/V1+/Oryon
|
|
13
|
+
* vld1q_f32 LD1 (V.4S) 4cy 2/cy 2/cy
|
|
14
|
+
* vst1q_f32 ST1 (V.4S) 2cy 2/cy 2/cy
|
|
15
|
+
* vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
|
|
16
|
+
* vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy 2/cy 4/cy
|
|
17
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
|
|
18
|
+
* vaddq_f64 FADD (V.2D, V.2D, V.2D) 2cy 2/cy 4/cy
|
|
19
|
+
* vmulq_f64 FMUL (V.2D, V.2D, V.2D) 3cy 2/cy 4/cy
|
|
20
|
+
* vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy 2/cy 4/cy
|
|
21
|
+
* vqaddq_s16 SQADD (V.8H, V.8H, V.8H) 2cy 2/cy 4/cy
|
|
22
|
+
* vcvtq_f32_s32 SCVTF (V.4S, V.4S) 3cy 2/cy 2/cy
|
|
23
|
+
* vcvtnq_s32_f32 FCVTNS (V.4S, V.4S) 3cy 2/cy 2/cy
|
|
24
|
+
* vqmovn_s32 SQXTN (V.4H, V.4S) 3cy 2/cy 2/cy
|
|
25
|
+
*
|
|
26
|
+
* Elementwise operations are throughput-bound rather than latency-bound. FP arithmetic
|
|
27
|
+
* throughput doubles on 4-pipe cores (Apple M4+, Graviton3+, Oryon) from 2/cy to 4/cy.
|
|
28
|
+
*
|
|
29
|
+
* Memory bandwidth (LD1/ST1) typically becomes the bottleneck for large arrays, as load/store
|
|
30
|
+
* throughput remains at 2/cy across all cores.
|
|
31
|
+
*/
|
|
32
|
+
#ifndef NK_EACH_NEON_H
|
|
33
|
+
#define NK_EACH_NEON_H
|
|
34
|
+
|
|
35
|
+
#if NK_TARGET_ARM_
|
|
36
|
+
#if NK_TARGET_NEON
|
|
37
|
+
|
|
38
|
+
#include "numkong/types.h"
|
|
39
|
+
#include "numkong/cast/neon.h"
|
|
40
|
+
|
|
41
|
+
#if defined(__cplusplus)
|
|
42
|
+
extern "C" {
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
#if defined(__clang__)
|
|
46
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8-a+simd"))), apply_to = function)
|
|
47
|
+
#elif defined(__GNUC__)
|
|
48
|
+
#pragma GCC push_options
|
|
49
|
+
#pragma GCC target("arch=armv8-a+simd")
|
|
50
|
+
#endif
|
|
51
|
+
|
|
52
|
+
/**
 *  @brief Elementwise sum of two `f32` arrays: `result[i] = a[i] + b[i]`.
 *
 *  @param[in] a First addend array, `n` elements are read.
 *  @param[in] b Second addend array, `n` elements are read.
 *  @param[in] n Number of elements to process.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_sum_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_size_t idx = 0;

    // Vector body: four `f32` lanes per 128-bit register.
    while (idx + 4 <= n) {
        float32x4_t lhs_f32x4 = vld1q_f32(a + idx);
        float32x4_t rhs_f32x4 = vld1q_f32(b + idx);
        vst1q_f32(result + idx, vaddq_f32(lhs_f32x4, rhs_f32x4));
        idx += 4;
    }

    // Scalar remainder: at most three trailing elements.
    while (idx < n) {
        result[idx] = a[idx] + b[idx];
        ++idx;
    }
}
|
|
65
|
+
|
|
66
|
+
/**
 *  @brief Elementwise affine transform of an `f32` array: `result[i] = alpha * a[i] + beta`.
 *
 *  @param[in] a Input array, `n` elements are read.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the multiplicative coefficient.
 *  @param[in] beta Pointer to the additive offset.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_scale_f32_neon(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                      nk_f32_t *result) {
    // Read both coefficients once, before any output is written.
    nk_f32_t const scale = *alpha;
    nk_f32_t const shift = *beta;
    float32x4_t const shift_f32x4 = vdupq_n_f32(shift);

    nk_size_t idx = 0;

    // Vector body: one fused multiply-add per four lanes, accumulating onto the broadcast offset.
    while (idx + 4 <= n) {
        float32x4_t input_f32x4 = vld1q_f32(a + idx);
        vst1q_f32(result + idx, vfmaq_n_f32(shift_f32x4, input_f32x4, scale));
        idx += 4;
    }

    // Scalar remainder.
    while (idx < n) {
        result[idx] = scale * a[idx] + shift;
        ++idx;
    }
}
|
|
83
|
+
|
|
84
|
+
/**
 *  @brief Elementwise weighted sum of two `f32` arrays: `result[i] = alpha * a[i] + beta * b[i]`.
 *
 *  Cheaper kernels are dispatched to for particular weights:
 *  - both weights equal to one: plain addition via `nk_each_sum_f32_neon`;
 *  - `beta` equal to zero: scaling of `a` by `alpha` via `nk_each_scale_f32_neon`;
 *  - `alpha` equal to zero (and `beta` non-zero): scaling of `b` by `beta` via `nk_each_scale_f32_neon`.
 *
 *  @param[in] a First input array, `n` elements.
 *  @param[in] b Second input array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the weight applied to `a`.
 *  @param[in] beta Pointer to the weight applied to `b`.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_blend_f32_neon(                 //
    nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, //
    nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result) {

    nk_f32_t const weight_a = *alpha;
    nk_f32_t const weight_b = *beta;

    // Unit weights: no multiplications are needed at all.
    if (weight_a == 1 && weight_b == 1) {
        nk_each_sum_f32_neon(a, b, n, result);
        return;
    }

    // A zero weight means one of the inputs never has to be loaded.
    // The `beta == 0` check comes first, so it wins when both weights are zero.
    if (weight_b == 0) {
        nk_f32_t no_offset = 0;
        nk_each_scale_f32_neon(a, n, alpha, &no_offset, result);
        return;
    }
    if (weight_a == 0) {
        nk_f32_t no_offset = 0;
        nk_each_scale_f32_neon(b, n, beta, &no_offset, result);
        return;
    }

    // General case, vector body: scale `a`, then fuse-multiply-add the scaled `b` on top.
    nk_size_t idx = 0;
    while (idx + 4 <= n) {
        float32x4_t first_f32x4 = vld1q_f32(a + idx);
        float32x4_t second_f32x4 = vld1q_f32(b + idx);
        float32x4_t first_weighted_f32x4 = vmulq_n_f32(first_f32x4, weight_a);
        vst1q_f32(result + idx, vfmaq_n_f32(first_weighted_f32x4, second_f32x4, weight_b));
        idx += 4;
    }

    // General case, scalar remainder.
    while (idx < n) {
        result[idx] = weight_a * a[idx] + weight_b * b[idx];
        ++idx;
    }
}
|
|
121
|
+
|
|
122
|
+
/**
 *  @brief Elementwise scaled multiply-add over `f32` arrays: `result[i] = alpha * a[i] * b[i] + beta * c[i]`.
 *
 *  @param[in] a First multiplicand array, `n` elements.
 *  @param[in] b Second multiplicand array, `n` elements.
 *  @param[in] c Addend array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the weight applied to the product `a[i] * b[i]`.
 *  @param[in] beta Pointer to the weight applied to `c[i]`.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_fma_f32_neon(                         //
    nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, //
    nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result) {

    nk_f32_t const product_weight = *alpha;
    nk_f32_t const addend_weight = *beta;

    nk_size_t idx = 0;

    // Vector body: multiply, scale the product, then fuse-multiply-add the weighted addend.
    while (idx + 4 <= n) {
        float32x4_t first_f32x4 = vld1q_f32(a + idx);
        float32x4_t second_f32x4 = vld1q_f32(b + idx);
        float32x4_t third_f32x4 = vld1q_f32(c + idx);
        float32x4_t product_f32x4 = vmulq_f32(first_f32x4, second_f32x4);
        float32x4_t weighted_f32x4 = vmulq_n_f32(product_f32x4, product_weight);
        vst1q_f32(result + idx, vfmaq_n_f32(weighted_f32x4, third_f32x4, addend_weight));
        idx += 4;
    }

    // Scalar remainder.
    while (idx < n) {
        result[idx] = product_weight * a[idx] * b[idx] + addend_weight * c[idx];
        ++idx;
    }
}
|
|
143
|
+
|
|
144
|
+
/**
 *  @brief Elementwise sum of two `i16` arrays using saturating vector addition.
 *
 *  Full groups of eight go through `vqaddq_s16`; leftover elements are handled
 *  by `nk_i16_saturating_add_serial`.
 *
 *  @param[in] a First addend array, `n` elements.
 *  @param[in] b Second addend array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_sum_i16_neon(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result) {
    nk_size_t idx = 0;

    // Vector body: eight 16-bit lanes per 128-bit register.
    while (idx + 8 <= n) {
        int16x8_t lhs_s16x8 = vld1q_s16(a + idx);
        int16x8_t rhs_s16x8 = vld1q_s16(b + idx);
        vst1q_s16(result + idx, vqaddq_s16(lhs_s16x8, rhs_s16x8));
        idx += 8;
    }

    // Scalar remainder: at most seven trailing elements.
    while (idx < n) {
        result[idx] = nk_i16_saturating_add_serial(a[idx], b[idx]);
        ++idx;
    }
}
|
|
157
|
+
|
|
158
|
+
/**
 *  @brief Elementwise affine transform of an `i16` array, computed in `f32`: `result[i] = alpha * a[i] + beta`.
 *
 *  In the vector path each value is widened to `f32`, transformed with a fused multiply-add,
 *  clamped to [-32768, 32767], converted back with `vcvtnq_s32_f32`, and narrowed with `vqmovn_s32`.
 *  Leftover elements are converted by `nk_f32_to_i16_serial`.
 *
 *  @param[in] a Input array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the multiplicative coefficient.
 *  @param[in] beta Pointer to the additive offset.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_scale_i16_neon(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                      nk_i16_t *result) {
    float32_t scale = *alpha;
    float32_t shift = *beta;
    float32x4_t scale_f32x4 = vdupq_n_f32(scale);
    float32x4_t shift_f32x4 = vdupq_n_f32(shift);
    float32x4_t lower_f32x4 = vdupq_n_f32(-32768.0f);
    float32x4_t upper_f32x4 = vdupq_n_f32(32767.0f);

    nk_size_t idx = 0;

    // Vector body: four elements at a time, since each `i16` widens to a 32-bit lane.
    while (idx + 4 <= n) {
        int32x4_t wide_s32x4 = vmovl_s16(vld1_s16(a + idx));
        float32x4_t value_f32x4 = vfmaq_f32(shift_f32x4, vcvtq_f32_s32(wide_s32x4), scale_f32x4);
        // Cap from above first, then from below, before leaving the floating-point domain.
        float32x4_t capped_f32x4 = vminq_f32(value_f32x4, upper_f32x4);
        float32x4_t clamped_f32x4 = vmaxq_f32(capped_f32x4, lower_f32x4);
        vst1_s16(result + idx, vqmovn_s32(vcvtnq_s32_f32(clamped_f32x4)));
        idx += 4;
    }

    // Scalar remainder.
    while (idx < n) {
        nk_f32_t scaled = scale * a[idx] + shift;
        nk_f32_to_i16_serial(&scaled, result + idx);
        ++idx;
    }
}
|
|
184
|
+
|
|
185
|
+
/**
 *  @brief Elementwise scaled multiply-add over `i16` arrays, computed in `f32`:
 *         `result[i] = alpha * a[i] * b[i] + beta * c[i]`.
 *
 *  In the vector path inputs are widened to `f32`, combined, clamped to [-32768, 32767],
 *  converted back with `vcvtnq_s32_f32`, and narrowed with `vqmovn_s32`.
 *  Leftover elements are converted by `nk_f32_to_i16_serial`.
 *
 *  @param[in] a First multiplicand array, `n` elements.
 *  @param[in] b Second multiplicand array, `n` elements.
 *  @param[in] c Addend array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the weight applied to the product `a[i] * b[i]`.
 *  @param[in] beta Pointer to the weight applied to `c[i]`.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_fma_i16_neon(                         //
    nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, //
    nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result) {

    float32_t product_weight = *alpha;
    float32_t addend_weight = *beta;
    float32x4_t lower_f32x4 = vdupq_n_f32(-32768.0f);
    float32x4_t upper_f32x4 = vdupq_n_f32(32767.0f);

    nk_size_t idx = 0;

    // Vector body: four elements at a time, since each `i16` widens to a 32-bit lane.
    while (idx + 4 <= n) {
        int16x4_t first_i16x4 = vld1_s16(a + idx);
        int16x4_t second_i16x4 = vld1_s16(b + idx);
        int16x4_t third_i16x4 = vld1_s16(c + idx);
        float32x4_t first_f32x4 = vcvtq_f32_s32(vmovl_s16(first_i16x4));
        float32x4_t second_f32x4 = vcvtq_f32_s32(vmovl_s16(second_i16x4));
        float32x4_t third_f32x4 = vcvtq_f32_s32(vmovl_s16(third_i16x4));
        float32x4_t weighted_f32x4 = vmulq_n_f32(vmulq_f32(first_f32x4, second_f32x4), product_weight);
        float32x4_t value_f32x4 = vfmaq_n_f32(weighted_f32x4, third_f32x4, addend_weight);
        // Cap from above first, then from below, before leaving the floating-point domain.
        float32x4_t capped_f32x4 = vminq_f32(value_f32x4, upper_f32x4);
        float32x4_t clamped_f32x4 = vmaxq_f32(capped_f32x4, lower_f32x4);
        vst1_s16(result + idx, vqmovn_s32(vcvtnq_s32_f32(clamped_f32x4)));
        idx += 4;
    }

    // Scalar remainder.
    while (idx < n) {
        nk_f32_t combined = product_weight * a[idx] * b[idx] + addend_weight * c[idx];
        nk_f32_to_i16_serial(&combined, result + idx);
        ++idx;
    }
}
|
|
216
|
+
|
|
217
|
+
/**
 *  @brief Elementwise sum of two `u16` arrays using saturating vector addition.
 *
 *  Full groups of eight go through `vqaddq_u16`; leftover elements are handled
 *  by `nk_u16_saturating_add_serial`.
 *
 *  @param[in] a First addend array, `n` elements.
 *  @param[in] b Second addend array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_sum_u16_neon(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result) {
    nk_size_t idx = 0;

    // Vector body: eight 16-bit lanes per 128-bit register.
    while (idx + 8 <= n) {
        uint16x8_t lhs_u16x8 = vld1q_u16(a + idx);
        uint16x8_t rhs_u16x8 = vld1q_u16(b + idx);
        vst1q_u16(result + idx, vqaddq_u16(lhs_u16x8, rhs_u16x8));
        idx += 8;
    }

    // Scalar remainder: at most seven trailing elements.
    while (idx < n) {
        result[idx] = nk_u16_saturating_add_serial(a[idx], b[idx]);
        ++idx;
    }
}
|
|
230
|
+
|
|
231
|
+
/**
 *  @brief Elementwise affine transform of a `u16` array, computed in `f32`: `result[i] = alpha * a[i] + beta`.
 *
 *  In the vector path each value is widened to `f32`, transformed with a fused multiply-add,
 *  clamped to [0, 65535], converted back with `vcvtnq_u32_f32`, and narrowed with `vqmovn_u32`.
 *  Leftover elements are converted by `nk_f32_to_u16_serial`.
 *
 *  @param[in] a Input array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the multiplicative coefficient.
 *  @param[in] beta Pointer to the additive offset.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_scale_u16_neon(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                      nk_u16_t *result) {
    float32_t scale = *alpha;
    float32_t shift = *beta;
    float32x4_t scale_f32x4 = vdupq_n_f32(scale);
    float32x4_t shift_f32x4 = vdupq_n_f32(shift);
    float32x4_t lower_f32x4 = vdupq_n_f32(0.0f);
    float32x4_t upper_f32x4 = vdupq_n_f32(65535.0f);

    nk_size_t idx = 0;

    // Vector body: four elements at a time, since each `u16` widens to a 32-bit lane.
    while (idx + 4 <= n) {
        uint32x4_t wide_u32x4 = vmovl_u16(vld1_u16(a + idx));
        float32x4_t value_f32x4 = vfmaq_f32(shift_f32x4, vcvtq_f32_u32(wide_u32x4), scale_f32x4);
        // Cap from above first, then from below, before leaving the floating-point domain.
        float32x4_t capped_f32x4 = vminq_f32(value_f32x4, upper_f32x4);
        float32x4_t clamped_f32x4 = vmaxq_f32(capped_f32x4, lower_f32x4);
        vst1_u16(result + idx, vqmovn_u32(vcvtnq_u32_f32(clamped_f32x4)));
        idx += 4;
    }

    // Scalar remainder.
    while (idx < n) {
        nk_f32_t scaled = scale * a[idx] + shift;
        nk_f32_to_u16_serial(&scaled, result + idx);
        ++idx;
    }
}
|
|
257
|
+
|
|
258
|
+
/**
 *  @brief Elementwise scaled multiply-add over `u16` arrays, computed in `f32`:
 *         `result[i] = alpha * a[i] * b[i] + beta * c[i]`.
 *
 *  In the vector path inputs are widened to `f32`, combined, clamped to [0, 65535],
 *  converted back with `vcvtnq_u32_f32`, and narrowed with `vqmovn_u32`.
 *  Leftover elements are converted by `nk_f32_to_u16_serial`.
 *
 *  @param[in] a First multiplicand array, `n` elements.
 *  @param[in] b Second multiplicand array, `n` elements.
 *  @param[in] c Addend array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the weight applied to the product `a[i] * b[i]`.
 *  @param[in] beta Pointer to the weight applied to `c[i]`.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_fma_u16_neon(                         //
    nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, //
    nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result) {

    float32_t product_weight = *alpha;
    float32_t addend_weight = *beta;
    float32x4_t lower_f32x4 = vdupq_n_f32(0.0f);
    float32x4_t upper_f32x4 = vdupq_n_f32(65535.0f);

    nk_size_t idx = 0;

    // Vector body: four elements at a time, since each `u16` widens to a 32-bit lane.
    while (idx + 4 <= n) {
        uint16x4_t first_u16x4 = vld1_u16(a + idx);
        uint16x4_t second_u16x4 = vld1_u16(b + idx);
        uint16x4_t third_u16x4 = vld1_u16(c + idx);
        float32x4_t first_f32x4 = vcvtq_f32_u32(vmovl_u16(first_u16x4));
        float32x4_t second_f32x4 = vcvtq_f32_u32(vmovl_u16(second_u16x4));
        float32x4_t third_f32x4 = vcvtq_f32_u32(vmovl_u16(third_u16x4));
        float32x4_t weighted_f32x4 = vmulq_n_f32(vmulq_f32(first_f32x4, second_f32x4), product_weight);
        float32x4_t value_f32x4 = vfmaq_n_f32(weighted_f32x4, third_f32x4, addend_weight);
        // Cap from above first, then from below, before leaving the floating-point domain.
        float32x4_t capped_f32x4 = vminq_f32(value_f32x4, upper_f32x4);
        float32x4_t clamped_f32x4 = vmaxq_f32(capped_f32x4, lower_f32x4);
        vst1_u16(result + idx, vqmovn_u32(vcvtnq_u32_f32(clamped_f32x4)));
        idx += 4;
    }

    // Scalar remainder.
    while (idx < n) {
        nk_f32_t combined = product_weight * a[idx] * b[idx] + addend_weight * c[idx];
        nk_f32_to_u16_serial(&combined, result + idx);
        ++idx;
    }
}
|
|
289
|
+
|
|
290
|
+
/**
 *  @brief Elementwise sum of two `i32` arrays using saturating vector addition.
 *
 *  Full groups of four go through `vqaddq_s32`; leftover elements are handled
 *  by `nk_i32_saturating_add_serial`.
 *
 *  @param[in] a First addend array, `n` elements.
 *  @param[in] b Second addend array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_sum_i32_neon(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result) {
    nk_size_t idx = 0;

    // Vector body: four 32-bit lanes per 128-bit register.
    while (idx + 4 <= n) {
        int32x4_t lhs_s32x4 = vld1q_s32(a + idx);
        int32x4_t rhs_s32x4 = vld1q_s32(b + idx);
        vst1q_s32(result + idx, vqaddq_s32(lhs_s32x4, rhs_s32x4));
        idx += 4;
    }

    // Scalar remainder: at most three trailing elements.
    while (idx < n) {
        result[idx] = nk_i32_saturating_add_serial(a[idx], b[idx]);
        ++idx;
    }
}
|
|
303
|
+
|
|
304
|
+
/**
 *  @brief Elementwise affine transform of an `i32` array, computed in `f64`: `result[i] = alpha * a[i] + beta`.
 *
 *  In the vector path each value is widened to `f64`, transformed with a fused multiply-add,
 *  clamped to [-2147483648, 2147483647], converted back with `vcvtnq_s64_f64`, and narrowed
 *  with `vqmovn_s64`. Leftover elements are converted by `nk_f64_to_i32_serial`.
 *
 *  @param[in] a Input array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the multiplicative coefficient.
 *  @param[in] beta Pointer to the additive offset.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_scale_i32_neon(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                      nk_i32_t *result) {
    nk_f64_t scale = *alpha;
    nk_f64_t shift = *beta;
    float64x2_t scale_f64x2 = vdupq_n_f64(scale);
    float64x2_t shift_f64x2 = vdupq_n_f64(shift);
    float64x2_t lower_f64x2 = vdupq_n_f64(-2147483648.0);
    float64x2_t upper_f64x2 = vdupq_n_f64(2147483647.0);

    nk_size_t idx = 0;

    // Vector body: two elements at a time, since each `i32` widens to a 64-bit lane.
    while (idx + 2 <= n) {
        int64x2_t wide_s64x2 = vmovl_s32(vld1_s32(a + idx));
        float64x2_t value_f64x2 = vfmaq_f64(shift_f64x2, vcvtq_f64_s64(wide_s64x2), scale_f64x2);
        // Cap from above first, then from below, before leaving the floating-point domain.
        float64x2_t capped_f64x2 = vminq_f64(value_f64x2, upper_f64x2);
        float64x2_t clamped_f64x2 = vmaxq_f64(capped_f64x2, lower_f64x2);
        vst1_s32(result + idx, vqmovn_s64(vcvtnq_s64_f64(clamped_f64x2)));
        idx += 2;
    }

    // Scalar remainder: at most one trailing element.
    while (idx < n) {
        nk_f64_t scaled = scale * a[idx] + shift;
        nk_f64_to_i32_serial(&scaled, result + idx);
        ++idx;
    }
}
|
|
330
|
+
|
|
331
|
+
/**
 *  @brief Elementwise scaled multiply-add over `i32` arrays, computed in `f64`:
 *         `result[i] = alpha * a[i] * b[i] + beta * c[i]`.
 *
 *  In the vector path inputs are widened to `f64`, combined, clamped to
 *  [-2147483648, 2147483647], converted back with `vcvtnq_s64_f64`, and narrowed
 *  with `vqmovn_s64`. Leftover elements are converted by `nk_f64_to_i32_serial`.
 *
 *  @param[in] a First multiplicand array, `n` elements.
 *  @param[in] b Second multiplicand array, `n` elements.
 *  @param[in] c Addend array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the weight applied to the product `a[i] * b[i]`.
 *  @param[in] beta Pointer to the weight applied to `c[i]`.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_fma_i32_neon(                         //
    nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, //
    nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result) {

    nk_f64_t product_weight = *alpha;
    nk_f64_t addend_weight = *beta;
    float64x2_t lower_f64x2 = vdupq_n_f64(-2147483648.0);
    float64x2_t upper_f64x2 = vdupq_n_f64(2147483647.0);

    nk_size_t idx = 0;

    // Vector body: two elements at a time, since each `i32` widens to a 64-bit lane.
    while (idx + 2 <= n) {
        int32x2_t first_i32x2 = vld1_s32(a + idx);
        int32x2_t second_i32x2 = vld1_s32(b + idx);
        int32x2_t third_i32x2 = vld1_s32(c + idx);
        float64x2_t first_f64x2 = vcvtq_f64_s64(vmovl_s32(first_i32x2));
        float64x2_t second_f64x2 = vcvtq_f64_s64(vmovl_s32(second_i32x2));
        float64x2_t third_f64x2 = vcvtq_f64_s64(vmovl_s32(third_i32x2));
        float64x2_t weighted_f64x2 = vmulq_n_f64(vmulq_f64(first_f64x2, second_f64x2), product_weight);
        float64x2_t value_f64x2 = vfmaq_n_f64(weighted_f64x2, third_f64x2, addend_weight);
        // Cap from above first, then from below, before leaving the floating-point domain.
        float64x2_t capped_f64x2 = vminq_f64(value_f64x2, upper_f64x2);
        float64x2_t clamped_f64x2 = vmaxq_f64(capped_f64x2, lower_f64x2);
        vst1_s32(result + idx, vqmovn_s64(vcvtnq_s64_f64(clamped_f64x2)));
        idx += 2;
    }

    // Scalar remainder: at most one trailing element.
    while (idx < n) {
        nk_f64_t combined = product_weight * a[idx] * b[idx] + addend_weight * c[idx];
        nk_f64_to_i32_serial(&combined, result + idx);
        ++idx;
    }
}
|
|
362
|
+
|
|
363
|
+
/**
 *  @brief Elementwise sum of two `u32` arrays using saturating vector addition.
 *
 *  Full groups of four go through `vqaddq_u32`; leftover elements are handled
 *  by `nk_u32_saturating_add_serial`.
 *
 *  @param[in] a First addend array, `n` elements.
 *  @param[in] b Second addend array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_sum_u32_neon(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result) {
    nk_size_t idx = 0;

    // Vector body: four 32-bit lanes per 128-bit register.
    while (idx + 4 <= n) {
        uint32x4_t lhs_u32x4 = vld1q_u32(a + idx);
        uint32x4_t rhs_u32x4 = vld1q_u32(b + idx);
        vst1q_u32(result + idx, vqaddq_u32(lhs_u32x4, rhs_u32x4));
        idx += 4;
    }

    // Scalar remainder: at most three trailing elements.
    while (idx < n) {
        result[idx] = nk_u32_saturating_add_serial(a[idx], b[idx]);
        ++idx;
    }
}
|
|
376
|
+
|
|
377
|
+
/**
 *  @brief Elementwise affine transform of a `u32` array, computed in `f64`: `result[i] = alpha * a[i] + beta`.
 *
 *  In the vector path each value is widened to `f64`, transformed with a fused multiply-add,
 *  clamped to [0, 4294967295], converted back with `vcvtnq_u64_f64`, and narrowed
 *  with `vqmovn_u64`. Leftover elements are converted by `nk_f64_to_u32_serial`.
 *
 *  @param[in] a Input array, `n` elements.
 *  @param[in] n Number of elements to process.
 *  @param[in] alpha Pointer to the multiplicative coefficient.
 *  @param[in] beta Pointer to the additive offset.
 *  @param[out] result Destination array, `n` elements are written.
 */
NK_PUBLIC void nk_each_scale_u32_neon(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                      nk_u32_t *result) {
    nk_f64_t scale = *alpha;
    nk_f64_t shift = *beta;
    float64x2_t scale_f64x2 = vdupq_n_f64(scale);
    float64x2_t shift_f64x2 = vdupq_n_f64(shift);
    float64x2_t lower_f64x2 = vdupq_n_f64(0.0);
    float64x2_t upper_f64x2 = vdupq_n_f64(4294967295.0);

    nk_size_t idx = 0;

    // Vector body: two elements at a time, since each `u32` widens to a 64-bit lane.
    while (idx + 2 <= n) {
        uint64x2_t wide_u64x2 = vmovl_u32(vld1_u32(a + idx));
        float64x2_t value_f64x2 = vfmaq_f64(shift_f64x2, vcvtq_f64_u64(wide_u64x2), scale_f64x2);
        // Cap from above first, then from below, before leaving the floating-point domain.
        float64x2_t capped_f64x2 = vminq_f64(value_f64x2, upper_f64x2);
        float64x2_t clamped_f64x2 = vmaxq_f64(capped_f64x2, lower_f64x2);
        vst1_u32(result + idx, vqmovn_u64(vcvtnq_u64_f64(clamped_f64x2)));
        idx += 2;
    }

    // Scalar remainder: at most one trailing element.
    while (idx < n) {
        nk_f64_t scaled = scale * a[idx] + shift;
        nk_f64_to_u32_serial(&scaled, result + idx);
        ++idx;
    }
}
|
|
403
|
+
|
|
404
|
+
/**
 *  @brief Element-wise `alpha * a[i] * b[i] + beta * c[i]` over `u32` arrays.
 *
 *  All arithmetic is done in `f64`. The vector body forms `(a * b) * alpha`
 *  and then adds `c * beta` with a fused multiply-add, clamps the value to
 *  [0, 4294967295], converts with `vcvtnq_u64_f64` and narrows with
 *  `vqmovn_u64`. The scalar remainder is converted by `nk_f64_to_u32_serial`.
 *
 *  @param a,b     Arrays whose elements are multiplied together.
 *  @param c       Array that is scaled by `beta` and added.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the weight of the product term.
 *  @param beta    Pointer to the weight of the `c` term.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_fma_u32_neon(                         //
    nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, //
    nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result) {
    nk_f64_t product_weight = *alpha;
    nk_f64_t addend_weight = *beta;
    float64x2_t lower_f64x2 = vdupq_n_f64(0.0);
    float64x2_t upper_f64x2 = vdupq_n_f64(4294967295.0);

    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        uint32x2_t a_raw_u32x2 = vld1_u32(a + i);
        uint32x2_t b_raw_u32x2 = vld1_u32(b + i);
        uint32x2_t c_raw_u32x2 = vld1_u32(c + i);
        float64x2_t a_wide_f64x2 = vcvtq_f64_u64(vmovl_u32(a_raw_u32x2));
        float64x2_t b_wide_f64x2 = vcvtq_f64_u64(vmovl_u32(b_raw_u32x2));
        float64x2_t c_wide_f64x2 = vcvtq_f64_u64(vmovl_u32(c_raw_u32x2));
        float64x2_t weighted_f64x2 = vmulq_n_f64(vmulq_f64(a_wide_f64x2, b_wide_f64x2), product_weight);
        float64x2_t combined_f64x2 = vfmaq_n_f64(weighted_f64x2, c_wide_f64x2, addend_weight);
        combined_f64x2 = vmaxq_f64(vminq_f64(combined_f64x2, upper_f64x2), lower_f64x2);
        vst1_u32(result + i, vqmovn_u64(vcvtnq_u64_f64(combined_f64x2)));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f64_t combined = product_weight * a[i] * b[i] + addend_weight * c[i];
        nk_f64_to_u32_serial(&combined, result + i);
        ++i;
    }
}
|
|
435
|
+
|
|
436
|
+
/**
 *  @brief Element-wise saturating sum of two `i64` arrays.
 *
 *  Pairs of elements are added with `vqaddq_s64`; a trailing odd element is
 *  handled by `nk_i64_saturating_add_serial`.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_sum_i64_neon(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result) {
    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        int64x2_t lhs_s64x2 = vld1q_s64(a + i);
        int64x2_t rhs_s64x2 = vld1q_s64(b + i);
        vst1q_s64(result + i, vqaddq_s64(lhs_s64x2, rhs_s64x2));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        result[i] = nk_i64_saturating_add_serial(a[i], b[i]);
        ++i;
    }
}
|
|
449
|
+
|
|
450
|
+
/**
 *  @brief Element-wise affine transform of an `i64` array: `alpha * a[i] + beta`.
 *
 *  All arithmetic is done in `f64`. The vector body converts each pair with
 *  `vcvtq_f64_s64`, applies one fused multiply-add, clamps to the `f64` images
 *  of `NK_I64_MIN` / `NK_I64_MAX`, and converts back with `vcvtnq_s64_f64`.
 *  The scalar remainder is converted by `nk_f64_to_i64_serial`.
 *
 *  @param a       Input array of `n` elements.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the multiplicative coefficient.
 *  @param beta    Pointer to the additive coefficient.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_scale_i64_neon(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                      nk_i64_t *result) {
    nk_f64_t scale = *alpha;
    nk_f64_t shift = *beta;
    float64x2_t scale_f64x2 = vdupq_n_f64(scale);
    float64x2_t shift_f64x2 = vdupq_n_f64(shift);
    float64x2_t lower_f64x2 = vdupq_n_f64((nk_f64_t)NK_I64_MIN);
    float64x2_t upper_f64x2 = vdupq_n_f64((nk_f64_t)NK_I64_MAX);

    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        float64x2_t values_f64x2 = vcvtq_f64_s64(vld1q_s64(a + i));
        // shift + values * scale
        float64x2_t scaled_f64x2 = vfmaq_f64(shift_f64x2, values_f64x2, scale_f64x2);
        scaled_f64x2 = vmaxq_f64(vminq_f64(scaled_f64x2, upper_f64x2), lower_f64x2);
        vst1q_s64(result + i, vcvtnq_s64_f64(scaled_f64x2));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f64_t scaled = scale * a[i] + shift;
        nk_f64_to_i64_serial(&scaled, result + i);
        ++i;
    }
}
|
|
476
|
+
|
|
477
|
+
/**
 *  @brief Element-wise `alpha * a[i] * b[i] + beta * c[i]` over `i64` arrays.
 *
 *  All arithmetic is done in `f64`. The vector body forms `(a * b) * alpha`,
 *  adds `c * beta` with a fused multiply-add, clamps to the `f64` images of
 *  `NK_I64_MIN` / `NK_I64_MAX`, and converts back with `vcvtnq_s64_f64`.
 *  The scalar remainder is converted by `nk_f64_to_i64_serial`.
 *
 *  @param a,b     Arrays whose elements are multiplied together.
 *  @param c       Array that is scaled by `beta` and added.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the weight of the product term.
 *  @param beta    Pointer to the weight of the `c` term.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_fma_i64_neon(                         //
    nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, //
    nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result) {
    nk_f64_t product_weight = *alpha;
    nk_f64_t addend_weight = *beta;
    float64x2_t lower_f64x2 = vdupq_n_f64((nk_f64_t)NK_I64_MIN);
    float64x2_t upper_f64x2 = vdupq_n_f64((nk_f64_t)NK_I64_MAX);

    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        int64x2_t a_raw_s64x2 = vld1q_s64(a + i);
        int64x2_t b_raw_s64x2 = vld1q_s64(b + i);
        int64x2_t c_raw_s64x2 = vld1q_s64(c + i);
        float64x2_t a_wide_f64x2 = vcvtq_f64_s64(a_raw_s64x2);
        float64x2_t b_wide_f64x2 = vcvtq_f64_s64(b_raw_s64x2);
        float64x2_t c_wide_f64x2 = vcvtq_f64_s64(c_raw_s64x2);
        float64x2_t weighted_f64x2 = vmulq_n_f64(vmulq_f64(a_wide_f64x2, b_wide_f64x2), product_weight);
        float64x2_t combined_f64x2 = vfmaq_n_f64(weighted_f64x2, c_wide_f64x2, addend_weight);
        combined_f64x2 = vmaxq_f64(vminq_f64(combined_f64x2, upper_f64x2), lower_f64x2);
        vst1q_s64(result + i, vcvtnq_s64_f64(combined_f64x2));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f64_t combined = product_weight * a[i] * b[i] + addend_weight * c[i];
        nk_f64_to_i64_serial(&combined, result + i);
        ++i;
    }
}
|
|
508
|
+
|
|
509
|
+
/**
 *  @brief Element-wise saturating sum of two `u64` arrays.
 *
 *  Pairs of elements are added with `vqaddq_u64`; a trailing odd element is
 *  handled by `nk_u64_saturating_add_serial`.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_sum_u64_neon(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result) {
    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        uint64x2_t lhs_u64x2 = vld1q_u64(a + i);
        uint64x2_t rhs_u64x2 = vld1q_u64(b + i);
        vst1q_u64(result + i, vqaddq_u64(lhs_u64x2, rhs_u64x2));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        result[i] = nk_u64_saturating_add_serial(a[i], b[i]);
        ++i;
    }
}
|
|
522
|
+
|
|
523
|
+
/**
 *  @brief Element-wise affine transform of a `u64` array: `alpha * a[i] + beta`.
 *
 *  All arithmetic is done in `f64`. The vector body converts each pair with
 *  `vcvtq_f64_u64`, applies one fused multiply-add, clamps to
 *  [0, (f64)NK_U64_MAX], and converts back with `vcvtnq_u64_f64`.
 *  The scalar remainder is converted by `nk_f64_to_u64_serial`.
 *
 *  @param a       Input array of `n` elements.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the multiplicative coefficient.
 *  @param beta    Pointer to the additive coefficient.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_scale_u64_neon(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                      nk_u64_t *result) {
    nk_f64_t scale = *alpha;
    nk_f64_t shift = *beta;
    float64x2_t scale_f64x2 = vdupq_n_f64(scale);
    float64x2_t shift_f64x2 = vdupq_n_f64(shift);
    float64x2_t lower_f64x2 = vdupq_n_f64(0.0);
    float64x2_t upper_f64x2 = vdupq_n_f64((nk_f64_t)NK_U64_MAX);

    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        float64x2_t values_f64x2 = vcvtq_f64_u64(vld1q_u64(a + i));
        // shift + values * scale
        float64x2_t scaled_f64x2 = vfmaq_f64(shift_f64x2, values_f64x2, scale_f64x2);
        scaled_f64x2 = vmaxq_f64(vminq_f64(scaled_f64x2, upper_f64x2), lower_f64x2);
        vst1q_u64(result + i, vcvtnq_u64_f64(scaled_f64x2));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f64_t scaled = scale * a[i] + shift;
        nk_f64_to_u64_serial(&scaled, result + i);
        ++i;
    }
}
|
|
549
|
+
|
|
550
|
+
/**
 *  @brief Element-wise `alpha * a[i] * b[i] + beta * c[i]` over `u64` arrays.
 *
 *  All arithmetic is done in `f64`. The vector body forms `(a * b) * alpha`,
 *  adds `c * beta` with a fused multiply-add, clamps to [0, (f64)NK_U64_MAX],
 *  and converts back with `vcvtnq_u64_f64`. The scalar remainder is converted
 *  by `nk_f64_to_u64_serial`.
 *
 *  @param a,b     Arrays whose elements are multiplied together.
 *  @param c       Array that is scaled by `beta` and added.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the weight of the product term.
 *  @param beta    Pointer to the weight of the `c` term.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_fma_u64_neon(                         //
    nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, //
    nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result) {
    nk_f64_t product_weight = *alpha;
    nk_f64_t addend_weight = *beta;
    float64x2_t lower_f64x2 = vdupq_n_f64(0.0);
    float64x2_t upper_f64x2 = vdupq_n_f64((nk_f64_t)NK_U64_MAX);

    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        uint64x2_t a_raw_u64x2 = vld1q_u64(a + i);
        uint64x2_t b_raw_u64x2 = vld1q_u64(b + i);
        uint64x2_t c_raw_u64x2 = vld1q_u64(c + i);
        float64x2_t a_wide_f64x2 = vcvtq_f64_u64(a_raw_u64x2);
        float64x2_t b_wide_f64x2 = vcvtq_f64_u64(b_raw_u64x2);
        float64x2_t c_wide_f64x2 = vcvtq_f64_u64(c_raw_u64x2);
        float64x2_t weighted_f64x2 = vmulq_n_f64(vmulq_f64(a_wide_f64x2, b_wide_f64x2), product_weight);
        float64x2_t combined_f64x2 = vfmaq_n_f64(weighted_f64x2, c_wide_f64x2, addend_weight);
        combined_f64x2 = vmaxq_f64(vminq_f64(combined_f64x2, upper_f64x2), lower_f64x2);
        vst1q_u64(result + i, vcvtnq_u64_f64(combined_f64x2));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f64_t combined = product_weight * a[i] * b[i] + addend_weight * c[i];
        nk_f64_to_u64_serial(&combined, result + i);
        ++i;
    }
}
|
|
581
|
+
|
|
582
|
+
/**
 *  @brief Element-wise sum of two `f64` arrays: `result[i] = a[i] + b[i]`.
 *
 *  Pairs of elements are added with `vaddq_f64`; a trailing odd element is
 *  added with plain scalar arithmetic.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_sum_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        float64x2_t lhs_f64x2 = vld1q_f64(a + i);
        float64x2_t rhs_f64x2 = vld1q_f64(b + i);
        vst1q_f64(result + i, vaddq_f64(lhs_f64x2, rhs_f64x2));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        result[i] = a[i] + b[i];
        ++i;
    }
}
|
|
595
|
+
|
|
596
|
+
/**
 *  @brief Element-wise affine transform of an `f64` array: `alpha * a[i] + beta`.
 *
 *  The vector body evaluates `beta + a * alpha` with one fused multiply-add
 *  per pair of elements; a trailing odd element uses the scalar expression.
 *
 *  @param a       Input array of `n` elements.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the multiplicative coefficient.
 *  @param beta    Pointer to the additive coefficient.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_scale_f64_neon(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                      nk_f64_t *result) {
    nk_f64_t scale = *alpha;
    nk_f64_t shift = *beta;
    float64x2_t scale_f64x2 = vdupq_n_f64(scale);
    float64x2_t shift_f64x2 = vdupq_n_f64(shift);

    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        float64x2_t values_f64x2 = vld1q_f64(a + i);
        vst1q_f64(result + i, vfmaq_f64(shift_f64x2, values_f64x2, scale_f64x2));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        result[i] = scale * a[i] + shift;
        ++i;
    }
}
|
|
614
|
+
|
|
615
|
+
/**
 *  @brief Element-wise weighted sum of two `f64` arrays: `alpha * a[i] + beta * b[i]`.
 *
 *  Two shortcuts are taken before the general loop:
 *   - both weights equal to 1: forwards to `nk_each_sum_f64_neon`;
 *   - either weight equal to 0: forwards to `nk_each_scale_f64_neon` on the
 *     array whose weight is kept (the `a` array when `beta` is 0, otherwise
 *     the `b` array), with an additive term of 0.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the weight applied to `a`.
 *  @param beta    Pointer to the weight applied to `b`.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_blend_f64_neon(                 //
    nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, //
    nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result) {

    nk_f64_t a_weight = *alpha;
    nk_f64_t b_weight = *beta;

    // Shortcut 1: unit weights reduce to a plain sum, no multiplications needed.
    if (a_weight == 1 && b_weight == 1) {
        nk_each_sum_f64_neon(a, b, n, result);
        return;
    }

    // Shortcut 2: a zero weight means only one of the inputs has to be loaded.
    if (b_weight == 0) {
        nk_f64_t zero = 0;
        nk_each_scale_f64_neon(a, n, alpha, &zero, result);
        return;
    }
    if (a_weight == 0) {
        nk_f64_t zero = 0;
        nk_each_scale_f64_neon(b, n, beta, &zero, result);
        return;
    }

    // General case.
    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        float64x2_t lhs_f64x2 = vld1q_f64(a + i);
        float64x2_t rhs_f64x2 = vld1q_f64(b + i);
        float64x2_t lhs_weighted_f64x2 = vmulq_n_f64(lhs_f64x2, a_weight);
        float64x2_t rhs_weighted_f64x2 = vmulq_n_f64(rhs_f64x2, b_weight);
        vst1q_f64(result + i, vaddq_f64(lhs_weighted_f64x2, rhs_weighted_f64x2));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        result[i] = a_weight * a[i] + b_weight * b[i];
        ++i;
    }
}
|
|
653
|
+
|
|
654
|
+
/**
 *  @brief Element-wise `alpha * a[i] * b[i] + beta * c[i]` over `f64` arrays.
 *
 *  The vector body forms `(a * b) * alpha` and then adds `c * beta` with a
 *  fused multiply-add; a trailing odd element uses the scalar expression.
 *
 *  @param a,b     Arrays whose elements are multiplied together.
 *  @param c       Array that is scaled by `beta` and added.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the weight of the product term.
 *  @param beta    Pointer to the weight of the `c` term.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_fma_f64_neon(                         //
    nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, //
    nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result) {
    nk_f64_t product_weight = *alpha;
    nk_f64_t addend_weight = *beta;

    nk_size_t i = 0;

    // Vector body: two lanes per step.
    while (i + 2 <= n) {
        float64x2_t lhs_f64x2 = vld1q_f64(a + i);
        float64x2_t rhs_f64x2 = vld1q_f64(b + i);
        float64x2_t addend_f64x2 = vld1q_f64(c + i);
        float64x2_t weighted_f64x2 = vmulq_n_f64(vmulq_f64(lhs_f64x2, rhs_f64x2), product_weight);
        vst1q_f64(result + i, vfmaq_n_f64(weighted_f64x2, addend_f64x2, addend_weight));
        i += 2;
    }

    // Scalar remainder.
    while (i < n) {
        result[i] = product_weight * a[i] * b[i] + addend_weight * c[i];
        ++i;
    }
}
|
|
675
|
+
|
|
676
|
+
/**
 *  @brief Element-wise sum of two `e4m3` arrays.
 *
 *  Eight elements per step are widened to `f16` by `nk_e4m3x8_to_f16x8_neon_`
 *  and then to two `f32x4` halves, added in `f32`, converted back four at a
 *  time by `nk_f32x4_to_e4m3x4_neon_`, and written as one 8-byte store.
 *  Leftover elements go through the serial `e4m3` <-> `f32` converters.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_sum_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result) {
    nk_size_t i = 0;

    // Vector body: eight elements per step.
    while (i + 8 <= n) {
        float16x8_t lhs_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
        float16x8_t rhs_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b + i));
        float32x4_t lhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(lhs_f16x8));
        float32x4_t lhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(lhs_f16x8));
        float32x4_t rhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(rhs_f16x8));
        float32x4_t rhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(rhs_f16x8));
        nk_b32_vec_t packed_lo = nk_f32x4_to_e4m3x4_neon_(vaddq_f32(lhs_lo_f32x4, rhs_lo_f32x4));
        nk_b32_vec_t packed_hi = nk_f32x4_to_e4m3x4_neon_(vaddq_f32(lhs_hi_f32x4, rhs_hi_f32x4));
        // Low four bytes come from the low half, high four bytes from the high half.
        nk_u64_t packed = (nk_u64_t)packed_lo.u32 | ((nk_u64_t)packed_hi.u32 << 32);
        vst1_u8(result + i, vcreate_u8(packed));
        i += 8;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f32_t lhs, rhs;
        nk_e4m3_to_f32_serial(a + i, &lhs);
        nk_e4m3_to_f32_serial(b + i, &rhs);
        nk_f32_t total = lhs + rhs;
        nk_f32_to_e4m3_serial(&total, result + i);
        ++i;
    }
}
|
|
699
|
+
|
|
700
|
+
/**
 *  @brief Element-wise sum of two `e5m2` arrays.
 *
 *  Eight elements per step are widened to `f16` by `nk_e5m2x8_to_f16x8_neon_`
 *  and then to two `f32x4` halves, added in `f32`, converted back four at a
 *  time by `nk_f32x4_to_e5m2x4_neon_`, and written as one 8-byte store.
 *  Leftover elements go through the serial `e5m2` <-> `f32` converters.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_sum_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result) {
    nk_size_t i = 0;

    // Vector body: eight elements per step.
    while (i + 8 <= n) {
        float16x8_t lhs_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
        float16x8_t rhs_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b + i));
        float32x4_t lhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(lhs_f16x8));
        float32x4_t lhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(lhs_f16x8));
        float32x4_t rhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(rhs_f16x8));
        float32x4_t rhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(rhs_f16x8));
        nk_b32_vec_t packed_lo = nk_f32x4_to_e5m2x4_neon_(vaddq_f32(lhs_lo_f32x4, rhs_lo_f32x4));
        nk_b32_vec_t packed_hi = nk_f32x4_to_e5m2x4_neon_(vaddq_f32(lhs_hi_f32x4, rhs_hi_f32x4));
        // Low four bytes come from the low half, high four bytes from the high half.
        nk_u64_t packed = (nk_u64_t)packed_lo.u32 | ((nk_u64_t)packed_hi.u32 << 32);
        vst1_u8(result + i, vcreate_u8(packed));
        i += 8;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f32_t lhs, rhs;
        nk_e5m2_to_f32_serial(a + i, &lhs);
        nk_e5m2_to_f32_serial(b + i, &rhs);
        nk_f32_t total = lhs + rhs;
        nk_f32_to_e5m2_serial(&total, result + i);
        ++i;
    }
}
|
|
723
|
+
|
|
724
|
+
/**
 *  @brief Element-wise affine transform of an `e4m3` array: `alpha * a[i] + beta`.
 *
 *  Eight elements per step are widened (`e4m3` -> `f16` -> two `f32x4`
 *  halves), transformed with one fused multiply-add per half, converted back
 *  by `nk_f32x4_to_e4m3x4_neon_`, and written as one 8-byte store. Leftover
 *  elements go through the serial `e4m3` <-> `f32` converters.
 *
 *  @param a       Input array of `n` elements.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the multiplicative coefficient.
 *  @param beta    Pointer to the additive coefficient.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_scale_e4m3_neon(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                       nk_e4m3_t *result) {
    float32x4_t scale_f32x4 = vdupq_n_f32(*alpha);
    float32x4_t shift_f32x4 = vdupq_n_f32(*beta);
    nk_size_t i = 0;

    // Vector body: eight elements per step.
    while (i + 8 <= n) {
        float16x8_t values_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
        float32x4_t values_lo_f32x4 = vcvt_f32_f16(vget_low_f16(values_f16x8));
        float32x4_t values_hi_f32x4 = vcvt_f32_f16(vget_high_f16(values_f16x8));
        // shift + values * scale
        float32x4_t scaled_lo_f32x4 = vfmaq_f32(shift_f32x4, values_lo_f32x4, scale_f32x4);
        float32x4_t scaled_hi_f32x4 = vfmaq_f32(shift_f32x4, values_hi_f32x4, scale_f32x4);
        nk_b32_vec_t packed_lo = nk_f32x4_to_e4m3x4_neon_(scaled_lo_f32x4);
        nk_b32_vec_t packed_hi = nk_f32x4_to_e4m3x4_neon_(scaled_hi_f32x4);
        nk_u64_t packed = (nk_u64_t)packed_lo.u32 | ((nk_u64_t)packed_hi.u32 << 32);
        vst1_u8(result + i, vcreate_u8(packed));
        i += 8;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f32_t value;
        nk_e4m3_to_f32_serial(a + i, &value);
        nk_f32_t transformed = *alpha * value + *beta;
        nk_f32_to_e4m3_serial(&transformed, result + i);
        ++i;
    }
}
|
|
746
|
+
|
|
747
|
+
/**
 *  @brief Element-wise affine transform of an `e5m2` array: `alpha * a[i] + beta`.
 *
 *  Eight elements per step are widened (`e5m2` -> `f16` -> two `f32x4`
 *  halves), transformed with one fused multiply-add per half, converted back
 *  by `nk_f32x4_to_e5m2x4_neon_`, and written as one 8-byte store. Leftover
 *  elements go through the serial `e5m2` <-> `f32` converters.
 *
 *  @param a       Input array of `n` elements.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the multiplicative coefficient.
 *  @param beta    Pointer to the additive coefficient.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_scale_e5m2_neon(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                       nk_e5m2_t *result) {
    float32x4_t scale_f32x4 = vdupq_n_f32(*alpha);
    float32x4_t shift_f32x4 = vdupq_n_f32(*beta);
    nk_size_t i = 0;

    // Vector body: eight elements per step.
    while (i + 8 <= n) {
        float16x8_t values_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
        float32x4_t values_lo_f32x4 = vcvt_f32_f16(vget_low_f16(values_f16x8));
        float32x4_t values_hi_f32x4 = vcvt_f32_f16(vget_high_f16(values_f16x8));
        // shift + values * scale
        float32x4_t scaled_lo_f32x4 = vfmaq_f32(shift_f32x4, values_lo_f32x4, scale_f32x4);
        float32x4_t scaled_hi_f32x4 = vfmaq_f32(shift_f32x4, values_hi_f32x4, scale_f32x4);
        nk_b32_vec_t packed_lo = nk_f32x4_to_e5m2x4_neon_(scaled_lo_f32x4);
        nk_b32_vec_t packed_hi = nk_f32x4_to_e5m2x4_neon_(scaled_hi_f32x4);
        nk_u64_t packed = (nk_u64_t)packed_lo.u32 | ((nk_u64_t)packed_hi.u32 << 32);
        vst1_u8(result + i, vcreate_u8(packed));
        i += 8;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f32_t value;
        nk_e5m2_to_f32_serial(a + i, &value);
        nk_f32_t transformed = *alpha * value + *beta;
        nk_f32_to_e5m2_serial(&transformed, result + i);
        ++i;
    }
}
|
|
769
|
+
|
|
770
|
+
/**
 *  @brief Element-wise weighted sum of two `e4m3` arrays: `alpha * a[i] + beta * b[i]`.
 *
 *  Eight elements per step are widened (`e4m3` -> `f16` -> two `f32x4`
 *  halves); each half computes `a * alpha` and then adds `b * beta` with a
 *  fused multiply-add. Results are converted back by
 *  `nk_f32x4_to_e4m3x4_neon_` and written as one 8-byte store. Leftover
 *  elements go through the serial `e4m3` <-> `f32` converters.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the weight applied to `a`.
 *  @param beta    Pointer to the weight applied to `b`.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_blend_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                       nk_f32_t const *beta, nk_e4m3_t *result) {
    float32x4_t a_weight_f32x4 = vdupq_n_f32(*alpha);
    float32x4_t b_weight_f32x4 = vdupq_n_f32(*beta);
    nk_size_t i = 0;

    // Vector body: eight elements per step.
    while (i + 8 <= n) {
        float16x8_t lhs_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
        float16x8_t rhs_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b + i));
        float32x4_t lhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(lhs_f16x8));
        float32x4_t lhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(lhs_f16x8));
        float32x4_t rhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(rhs_f16x8));
        float32x4_t rhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(rhs_f16x8));
        float32x4_t lhs_weighted_lo_f32x4 = vmulq_f32(lhs_lo_f32x4, a_weight_f32x4);
        float32x4_t lhs_weighted_hi_f32x4 = vmulq_f32(lhs_hi_f32x4, a_weight_f32x4);
        // lhs_weighted + rhs * b_weight
        float32x4_t mixed_lo_f32x4 = vfmaq_f32(lhs_weighted_lo_f32x4, rhs_lo_f32x4, b_weight_f32x4);
        float32x4_t mixed_hi_f32x4 = vfmaq_f32(lhs_weighted_hi_f32x4, rhs_hi_f32x4, b_weight_f32x4);
        nk_b32_vec_t packed_lo = nk_f32x4_to_e4m3x4_neon_(mixed_lo_f32x4);
        nk_b32_vec_t packed_hi = nk_f32x4_to_e4m3x4_neon_(mixed_hi_f32x4);
        nk_u64_t packed = (nk_u64_t)packed_lo.u32 | ((nk_u64_t)packed_hi.u32 << 32);
        vst1_u8(result + i, vcreate_u8(packed));
        i += 8;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f32_t lhs, rhs;
        nk_e4m3_to_f32_serial(a + i, &lhs);
        nk_e4m3_to_f32_serial(b + i, &rhs);
        nk_f32_t mixed = *alpha * lhs + *beta * rhs;
        nk_f32_to_e4m3_serial(&mixed, result + i);
        ++i;
    }
}
|
|
798
|
+
|
|
799
|
+
/**
 *  @brief Element-wise weighted sum of two `e5m2` arrays: `alpha * a[i] + beta * b[i]`.
 *
 *  Eight elements per step are widened (`e5m2` -> `f16` -> two `f32x4`
 *  halves); each half computes `a * alpha` and then adds `b * beta` with a
 *  fused multiply-add. Results are converted back by
 *  `nk_f32x4_to_e5m2x4_neon_` and written as one 8-byte store. Leftover
 *  elements go through the serial `e5m2` <-> `f32` converters.
 *
 *  @param a       First input array of `n` elements.
 *  @param b       Second input array of `n` elements.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the weight applied to `a`.
 *  @param beta    Pointer to the weight applied to `b`.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_blend_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                       nk_f32_t const *beta, nk_e5m2_t *result) {
    float32x4_t a_weight_f32x4 = vdupq_n_f32(*alpha);
    float32x4_t b_weight_f32x4 = vdupq_n_f32(*beta);
    nk_size_t i = 0;

    // Vector body: eight elements per step.
    while (i + 8 <= n) {
        float16x8_t lhs_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
        float16x8_t rhs_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b + i));
        float32x4_t lhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(lhs_f16x8));
        float32x4_t lhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(lhs_f16x8));
        float32x4_t rhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(rhs_f16x8));
        float32x4_t rhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(rhs_f16x8));
        float32x4_t lhs_weighted_lo_f32x4 = vmulq_f32(lhs_lo_f32x4, a_weight_f32x4);
        float32x4_t lhs_weighted_hi_f32x4 = vmulq_f32(lhs_hi_f32x4, a_weight_f32x4);
        // lhs_weighted + rhs * b_weight
        float32x4_t mixed_lo_f32x4 = vfmaq_f32(lhs_weighted_lo_f32x4, rhs_lo_f32x4, b_weight_f32x4);
        float32x4_t mixed_hi_f32x4 = vfmaq_f32(lhs_weighted_hi_f32x4, rhs_hi_f32x4, b_weight_f32x4);
        nk_b32_vec_t packed_lo = nk_f32x4_to_e5m2x4_neon_(mixed_lo_f32x4);
        nk_b32_vec_t packed_hi = nk_f32x4_to_e5m2x4_neon_(mixed_hi_f32x4);
        nk_u64_t packed = (nk_u64_t)packed_lo.u32 | ((nk_u64_t)packed_hi.u32 << 32);
        vst1_u8(result + i, vcreate_u8(packed));
        i += 8;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f32_t lhs, rhs;
        nk_e5m2_to_f32_serial(a + i, &lhs);
        nk_e5m2_to_f32_serial(b + i, &rhs);
        nk_f32_t mixed = *alpha * lhs + *beta * rhs;
        nk_f32_to_e5m2_serial(&mixed, result + i);
        ++i;
    }
}
|
|
827
|
+
|
|
828
|
+
/**
 *  @brief Element-wise `alpha * a[i] * b[i] + beta * c[i]` over `e4m3` arrays.
 *
 *  Eight elements per step are widened (`e4m3` -> `f16` -> two `f32x4`
 *  halves); each half computes `(a * b) * alpha` and then adds `c * beta`
 *  with a fused multiply-add. Results are converted back by
 *  `nk_f32x4_to_e4m3x4_neon_` and written as one 8-byte store. Leftover
 *  elements go through the serial `e4m3` <-> `f32` converters.
 *
 *  @param a,b     Arrays whose elements are multiplied together.
 *  @param c       Array that is scaled by `beta` and added.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the weight of the product term.
 *  @param beta    Pointer to the weight of the `c` term.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_fma_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
                                     nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result) {
    float32x4_t product_weight_f32x4 = vdupq_n_f32(*alpha);
    float32x4_t addend_weight_f32x4 = vdupq_n_f32(*beta);
    nk_size_t i = 0;

    // Vector body: eight elements per step.
    while (i + 8 <= n) {
        float16x8_t lhs_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a + i));
        float16x8_t rhs_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b + i));
        float16x8_t addend_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(c + i));
        float32x4_t lhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(lhs_f16x8));
        float32x4_t lhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(lhs_f16x8));
        float32x4_t rhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(rhs_f16x8));
        float32x4_t rhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(rhs_f16x8));
        float32x4_t addend_lo_f32x4 = vcvt_f32_f16(vget_low_f16(addend_f16x8));
        float32x4_t addend_hi_f32x4 = vcvt_f32_f16(vget_high_f16(addend_f16x8));
        float32x4_t product_lo_f32x4 = vmulq_f32(lhs_lo_f32x4, rhs_lo_f32x4);
        float32x4_t product_hi_f32x4 = vmulq_f32(lhs_hi_f32x4, rhs_hi_f32x4);
        float32x4_t weighted_lo_f32x4 = vmulq_f32(product_lo_f32x4, product_weight_f32x4);
        float32x4_t weighted_hi_f32x4 = vmulq_f32(product_hi_f32x4, product_weight_f32x4);
        // weighted + addend * addend_weight
        float32x4_t combined_lo_f32x4 = vfmaq_f32(weighted_lo_f32x4, addend_lo_f32x4, addend_weight_f32x4);
        float32x4_t combined_hi_f32x4 = vfmaq_f32(weighted_hi_f32x4, addend_hi_f32x4, addend_weight_f32x4);
        nk_b32_vec_t packed_lo = nk_f32x4_to_e4m3x4_neon_(combined_lo_f32x4);
        nk_b32_vec_t packed_hi = nk_f32x4_to_e4m3x4_neon_(combined_hi_f32x4);
        nk_u64_t packed = (nk_u64_t)packed_lo.u32 | ((nk_u64_t)packed_hi.u32 << 32);
        vst1_u8(result + i, vcreate_u8(packed));
        i += 8;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f32_t lhs, rhs, addend;
        nk_e4m3_to_f32_serial(a + i, &lhs);
        nk_e4m3_to_f32_serial(b + i, &rhs);
        nk_e4m3_to_f32_serial(c + i, &addend);
        nk_f32_t combined = *alpha * lhs * rhs + *beta * addend;
        nk_f32_to_e4m3_serial(&combined, result + i);
        ++i;
    }
}
|
|
862
|
+
|
|
863
|
+
/**
 *  @brief Element-wise `alpha * a[i] * b[i] + beta * c[i]` over `e5m2` arrays.
 *
 *  Eight elements per step are widened (`e5m2` -> `f16` -> two `f32x4`
 *  halves); each half computes `(a * b) * alpha` and then adds `c * beta`
 *  with a fused multiply-add. Results are converted back by
 *  `nk_f32x4_to_e5m2x4_neon_` and written as one 8-byte store. Leftover
 *  elements go through the serial `e5m2` <-> `f32` converters.
 *
 *  @param a,b     Arrays whose elements are multiplied together.
 *  @param c       Array that is scaled by `beta` and added.
 *  @param n       Number of elements.
 *  @param alpha   Pointer to the weight of the product term.
 *  @param beta    Pointer to the weight of the `c` term.
 *  @param result  Output array of `n` elements.
 */
NK_PUBLIC void nk_each_fma_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
                                     nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result) {
    float32x4_t product_weight_f32x4 = vdupq_n_f32(*alpha);
    float32x4_t addend_weight_f32x4 = vdupq_n_f32(*beta);
    nk_size_t i = 0;

    // Vector body: eight elements per step.
    while (i + 8 <= n) {
        float16x8_t lhs_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a + i));
        float16x8_t rhs_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b + i));
        float16x8_t addend_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(c + i));
        float32x4_t lhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(lhs_f16x8));
        float32x4_t lhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(lhs_f16x8));
        float32x4_t rhs_lo_f32x4 = vcvt_f32_f16(vget_low_f16(rhs_f16x8));
        float32x4_t rhs_hi_f32x4 = vcvt_f32_f16(vget_high_f16(rhs_f16x8));
        float32x4_t addend_lo_f32x4 = vcvt_f32_f16(vget_low_f16(addend_f16x8));
        float32x4_t addend_hi_f32x4 = vcvt_f32_f16(vget_high_f16(addend_f16x8));
        float32x4_t product_lo_f32x4 = vmulq_f32(lhs_lo_f32x4, rhs_lo_f32x4);
        float32x4_t product_hi_f32x4 = vmulq_f32(lhs_hi_f32x4, rhs_hi_f32x4);
        float32x4_t weighted_lo_f32x4 = vmulq_f32(product_lo_f32x4, product_weight_f32x4);
        float32x4_t weighted_hi_f32x4 = vmulq_f32(product_hi_f32x4, product_weight_f32x4);
        // weighted + addend * addend_weight
        float32x4_t combined_lo_f32x4 = vfmaq_f32(weighted_lo_f32x4, addend_lo_f32x4, addend_weight_f32x4);
        float32x4_t combined_hi_f32x4 = vfmaq_f32(weighted_hi_f32x4, addend_hi_f32x4, addend_weight_f32x4);
        nk_b32_vec_t packed_lo = nk_f32x4_to_e5m2x4_neon_(combined_lo_f32x4);
        nk_b32_vec_t packed_hi = nk_f32x4_to_e5m2x4_neon_(combined_hi_f32x4);
        nk_u64_t packed = (nk_u64_t)packed_lo.u32 | ((nk_u64_t)packed_hi.u32 << 32);
        vst1_u8(result + i, vcreate_u8(packed));
        i += 8;
    }

    // Scalar remainder.
    while (i < n) {
        nk_f32_t lhs, rhs, addend;
        nk_e5m2_to_f32_serial(a + i, &lhs);
        nk_e5m2_to_f32_serial(b + i, &rhs);
        nk_e5m2_to_f32_serial(c + i, &addend);
        nk_f32_t combined = *alpha * lhs * rhs + *beta * addend;
        nk_f32_to_e5m2_serial(&combined, result + i);
        ++i;
    }
}
|
|
897
|
+
|
|
898
|
+
/**
 *  @brief Element-wise complex affine transform: `alpha * a[i] + beta` over `f32` complex numbers.
 *
 *  Four complex values per step are loaded with `vld2q_f32`, whose first
 *  output vector is used as the real parts and the second as the imaginary
 *  parts. The result is
 *      real = beta.real + alpha.real * a.real - alpha.imag * a.imag
 *      imag = beta.imag + alpha.real * a.imag + alpha.imag * a.real
 *  built from fused multiply-add / multiply-subtract steps and stored with
 *  `vst2q_f32`. Leftover elements use the equivalent scalar expressions.
 *
 *  @param a       Input array of `n` complex elements.
 *  @param n       Number of complex elements.
 *  @param alpha   Pointer to the complex multiplicative coefficient.
 *  @param beta    Pointer to the complex additive coefficient.
 *  @param result  Output array of `n` complex elements.
 */
NK_PUBLIC void nk_each_scale_f32c_neon(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
                                       nk_f32c_t *result) {
    float32x4_t scale_re_f32x4 = vdupq_n_f32(alpha->real);
    float32x4_t scale_im_f32x4 = vdupq_n_f32(alpha->imag);
    float32x4_t shift_re_f32x4 = vdupq_n_f32(beta->real);
    float32x4_t shift_im_f32x4 = vdupq_n_f32(beta->imag);
    nk_size_t i = 0;

    // Vector body: four complex numbers per step.
    while (i + 4 <= n) {
        float32x4x2_t input_f32x4x2 = vld2q_f32((nk_f32_t const *)(a + i));
        float32x4_t out_re_f32x4 = vfmaq_f32(shift_re_f32x4, scale_re_f32x4, input_f32x4x2.val[0]);
        out_re_f32x4 = vfmsq_f32(out_re_f32x4, scale_im_f32x4, input_f32x4x2.val[1]);
        float32x4_t out_im_f32x4 = vfmaq_f32(shift_im_f32x4, scale_re_f32x4, input_f32x4x2.val[1]);
        out_im_f32x4 = vfmaq_f32(out_im_f32x4, scale_im_f32x4, input_f32x4x2.val[0]);
        float32x4x2_t output_f32x4x2;
        output_f32x4x2.val[0] = out_re_f32x4;
        output_f32x4x2.val[1] = out_im_f32x4;
        vst2q_f32((nk_f32_t *)(result + i), output_f32x4x2);
        i += 4;
    }

    // Scalar remainder. Both parts are read before either is written.
    while (i < n) {
        nk_f32_t in_re = a[i].real, in_im = a[i].imag;
        result[i].real = alpha->real * in_re - alpha->imag * in_im + beta->real;
        result[i].imag = alpha->real * in_im + alpha->imag * in_re + beta->imag;
        i++;
    }
}
|
|
920
|
+
|
|
921
|
+
/**
 *  @brief Element-wise complex affine transform: `alpha * a[i] + beta` over `f64` complex numbers.
 *
 *  Two complex values per step are loaded with `vld2q_f64`, whose first
 *  output vector is used as the real parts and the second as the imaginary
 *  parts. The result is
 *      real = beta.real + alpha.real * a.real - alpha.imag * a.imag
 *      imag = beta.imag + alpha.real * a.imag + alpha.imag * a.real
 *  built from fused multiply-add / multiply-subtract steps and stored with
 *  `vst2q_f64`. A leftover element uses the equivalent scalar expressions.
 *
 *  @param a       Input array of `n` complex elements.
 *  @param n       Number of complex elements.
 *  @param alpha   Pointer to the complex multiplicative coefficient.
 *  @param beta    Pointer to the complex additive coefficient.
 *  @param result  Output array of `n` complex elements.
 */
NK_PUBLIC void nk_each_scale_f64c_neon(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
                                       nk_f64c_t *result) {
    float64x2_t scale_re_f64x2 = vdupq_n_f64(alpha->real);
    float64x2_t scale_im_f64x2 = vdupq_n_f64(alpha->imag);
    float64x2_t shift_re_f64x2 = vdupq_n_f64(beta->real);
    float64x2_t shift_im_f64x2 = vdupq_n_f64(beta->imag);
    nk_size_t i = 0;

    // Vector body: two complex numbers per step.
    while (i + 2 <= n) {
        float64x2x2_t input_f64x2x2 = vld2q_f64((nk_f64_t const *)(a + i));
        float64x2_t out_re_f64x2 = vfmaq_f64(shift_re_f64x2, scale_re_f64x2, input_f64x2x2.val[0]);
        out_re_f64x2 = vfmsq_f64(out_re_f64x2, scale_im_f64x2, input_f64x2x2.val[1]);
        float64x2_t out_im_f64x2 = vfmaq_f64(shift_im_f64x2, scale_re_f64x2, input_f64x2x2.val[1]);
        out_im_f64x2 = vfmaq_f64(out_im_f64x2, scale_im_f64x2, input_f64x2x2.val[0]);
        float64x2x2_t output_f64x2x2;
        output_f64x2x2.val[0] = out_re_f64x2;
        output_f64x2x2.val[1] = out_im_f64x2;
        vst2q_f64((nk_f64_t *)(result + i), output_f64x2x2);
        i += 2;
    }

    // Scalar remainder. Both parts are read before either is written.
    while (i < n) {
        nk_f64_t in_re = a[i].real, in_im = a[i].imag;
        result[i].real = alpha->real * in_re - alpha->imag * in_im + beta->real;
        result[i].imag = alpha->real * in_im + alpha->imag * in_re + beta->imag;
        i++;
    }
}
|
|
943
|
+
|
|
944
|
+
/**
 *  @brief Element-wise complex weighted sum: `alpha * a[i] + beta * b[i]` over `f32` complex numbers.
 *
 *  Four complex values per step are loaded from each input with `vld2q_f32`
 *  (first vector used as real parts, second as imaginary parts). The complex
 *  product `alpha * a` is formed first, then `beta * b` is accumulated onto
 *  it with fused multiply-add / multiply-subtract steps, and the result is
 *  stored with `vst2q_f32`. Leftover elements compute both complex products
 *  separately and add them.
 *
 *  @param a       First input array of `n` complex elements.
 *  @param b       Second input array of `n` complex elements.
 *  @param n       Number of complex elements.
 *  @param alpha   Pointer to the complex weight applied to `a`.
 *  @param beta    Pointer to the complex weight applied to `b`.
 *  @param result  Output array of `n` complex elements.
 */
NK_PUBLIC void nk_each_blend_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
                                       nk_f32c_t const *beta, nk_f32c_t *result) {
    float32x4_t a_weight_re_f32x4 = vdupq_n_f32(alpha->real);
    float32x4_t a_weight_im_f32x4 = vdupq_n_f32(alpha->imag);
    float32x4_t b_weight_re_f32x4 = vdupq_n_f32(beta->real);
    float32x4_t b_weight_im_f32x4 = vdupq_n_f32(beta->imag);
    nk_size_t i = 0;

    // Vector body: four complex numbers per step.
    while (i + 4 <= n) {
        float32x4x2_t lhs_f32x4x2 = vld2q_f32((nk_f32_t const *)(a + i));
        float32x4x2_t rhs_f32x4x2 = vld2q_f32((nk_f32_t const *)(b + i));

        // alpha * a
        float32x4_t lhs_term_re_f32x4 = vmulq_f32(a_weight_re_f32x4, lhs_f32x4x2.val[0]);
        lhs_term_re_f32x4 = vfmsq_f32(lhs_term_re_f32x4, a_weight_im_f32x4, lhs_f32x4x2.val[1]);
        float32x4_t lhs_term_im_f32x4 = vmulq_f32(a_weight_re_f32x4, lhs_f32x4x2.val[1]);
        lhs_term_im_f32x4 = vfmaq_f32(lhs_term_im_f32x4, a_weight_im_f32x4, lhs_f32x4x2.val[0]);

        // ... + beta * b
        float32x4_t out_re_f32x4 = vfmaq_f32(lhs_term_re_f32x4, b_weight_re_f32x4, rhs_f32x4x2.val[0]);
        out_re_f32x4 = vfmsq_f32(out_re_f32x4, b_weight_im_f32x4, rhs_f32x4x2.val[1]);
        float32x4_t out_im_f32x4 = vfmaq_f32(lhs_term_im_f32x4, b_weight_re_f32x4, rhs_f32x4x2.val[1]);
        out_im_f32x4 = vfmaq_f32(out_im_f32x4, b_weight_im_f32x4, rhs_f32x4x2.val[0]);

        float32x4x2_t output_f32x4x2;
        output_f32x4x2.val[0] = out_re_f32x4;
        output_f32x4x2.val[1] = out_im_f32x4;
        vst2q_f32((nk_f32_t *)(result + i), output_f32x4x2);
        i += 4;
    }

    // Scalar remainder. All inputs are read before the output is written.
    while (i < n) {
        nk_f32_t lhs_re = a[i].real, lhs_im = a[i].imag;
        nk_f32_t rhs_re = b[i].real, rhs_im = b[i].imag;
        nk_f32_t lhs_term_re = alpha->real * lhs_re - alpha->imag * lhs_im;
        nk_f32_t lhs_term_im = alpha->real * lhs_im + alpha->imag * lhs_re;
        nk_f32_t rhs_term_re = beta->real * rhs_re - beta->imag * rhs_im;
        nk_f32_t rhs_term_im = beta->real * rhs_im + beta->imag * rhs_re;
        result[i].real = lhs_term_re + rhs_term_re;
        result[i].imag = lhs_term_im + rhs_term_im;
        i++;
    }
}
|
|
976
|
+
|
|
977
|
+
/**
 *  @brief Element-wise complex weighted sum: `alpha * a[i] + beta * b[i]` over `f64` complex numbers.
 *
 *  Two complex values per step are loaded from each input with `vld2q_f64`
 *  (first vector used as real parts, second as imaginary parts). The complex
 *  product `alpha * a` is formed first, then `beta * b` is accumulated onto
 *  it with fused multiply-add / multiply-subtract steps, and the result is
 *  stored with `vst2q_f64`. A leftover element computes both complex
 *  products separately and adds them.
 *
 *  @param a       First input array of `n` complex elements.
 *  @param b       Second input array of `n` complex elements.
 *  @param n       Number of complex elements.
 *  @param alpha   Pointer to the complex weight applied to `a`.
 *  @param beta    Pointer to the complex weight applied to `b`.
 *  @param result  Output array of `n` complex elements.
 */
NK_PUBLIC void nk_each_blend_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
                                       nk_f64c_t const *beta, nk_f64c_t *result) {
    float64x2_t a_weight_re_f64x2 = vdupq_n_f64(alpha->real);
    float64x2_t a_weight_im_f64x2 = vdupq_n_f64(alpha->imag);
    float64x2_t b_weight_re_f64x2 = vdupq_n_f64(beta->real);
    float64x2_t b_weight_im_f64x2 = vdupq_n_f64(beta->imag);
    nk_size_t i = 0;

    // Vector body: two complex numbers per step.
    while (i + 2 <= n) {
        float64x2x2_t lhs_f64x2x2 = vld2q_f64((nk_f64_t const *)(a + i));
        float64x2x2_t rhs_f64x2x2 = vld2q_f64((nk_f64_t const *)(b + i));

        // alpha * a
        float64x2_t lhs_term_re_f64x2 = vmulq_f64(a_weight_re_f64x2, lhs_f64x2x2.val[0]);
        lhs_term_re_f64x2 = vfmsq_f64(lhs_term_re_f64x2, a_weight_im_f64x2, lhs_f64x2x2.val[1]);
        float64x2_t lhs_term_im_f64x2 = vmulq_f64(a_weight_re_f64x2, lhs_f64x2x2.val[1]);
        lhs_term_im_f64x2 = vfmaq_f64(lhs_term_im_f64x2, a_weight_im_f64x2, lhs_f64x2x2.val[0]);

        // ... + beta * b
        float64x2_t out_re_f64x2 = vfmaq_f64(lhs_term_re_f64x2, b_weight_re_f64x2, rhs_f64x2x2.val[0]);
        out_re_f64x2 = vfmsq_f64(out_re_f64x2, b_weight_im_f64x2, rhs_f64x2x2.val[1]);
        float64x2_t out_im_f64x2 = vfmaq_f64(lhs_term_im_f64x2, b_weight_re_f64x2, rhs_f64x2x2.val[1]);
        out_im_f64x2 = vfmaq_f64(out_im_f64x2, b_weight_im_f64x2, rhs_f64x2x2.val[0]);

        float64x2x2_t output_f64x2x2;
        output_f64x2x2.val[0] = out_re_f64x2;
        output_f64x2x2.val[1] = out_im_f64x2;
        vst2q_f64((nk_f64_t *)(result + i), output_f64x2x2);
        i += 2;
    }

    // Scalar remainder. All inputs are read before the output is written.
    while (i < n) {
        nk_f64_t lhs_re = a[i].real, lhs_im = a[i].imag;
        nk_f64_t rhs_re = b[i].real, rhs_im = b[i].imag;
        nk_f64_t lhs_term_re = alpha->real * lhs_re - alpha->imag * lhs_im;
        nk_f64_t lhs_term_im = alpha->real * lhs_im + alpha->imag * lhs_re;
        nk_f64_t rhs_term_re = beta->real * rhs_re - beta->imag * rhs_im;
        nk_f64_t rhs_term_im = beta->real * rhs_im + beta->imag * rhs_re;
        result[i].real = lhs_term_re + rhs_term_re;
        result[i].imag = lhs_term_im + rhs_term_im;
        i++;
    }
}
|
|
1009
|
+
|
|
1010
|
+
/**
 *  @brief Element-wise complex fused multiply-add in single precision:
 *         `result[i] = alpha * (a[i] * b[i]) + beta * c[i]`.
 *
 *  The vector loop consumes four complex numbers per step. `vld2q_f32` de-interleaves the
 *  (real, imag) pairs into separate real and imaginary registers, the complex product
 *  `a * b` is formed first, then scaled by `alpha`, and finally `beta * c` is accumulated
 *  on top with fused multiply-add / multiply-subtract. `vst2q_f32` re-interleaves the
 *  output. A scalar loop handles the remaining `n % 4` elements.
 *
 *  @param a      First multiplicand array of `n` complex numbers.
 *  @param b      Second multiplicand array of `n` complex numbers.
 *  @param c      Addend array of `n` complex numbers.
 *  @param n      Number of complex elements to process.
 *  @param alpha  Complex factor applied to the product `a[i] * b[i]`.
 *  @param beta   Complex factor applied to `c[i]`.
 *  @param result Output array receiving `n` complex numbers.
 */
NK_PUBLIC void nk_each_fma_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
                                     nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result) {
    // Broadcast the four scalar coefficients once, outside the loop.
    float32x4_t const alpha_re = vdupq_n_f32(alpha->real);
    float32x4_t const alpha_im = vdupq_n_f32(alpha->imag);
    float32x4_t const beta_re = vdupq_n_f32(beta->real);
    float32x4_t const beta_im = vdupq_n_f32(beta->imag);

    nk_size_t idx = 0;
    while (idx + 4 <= n) {
        // val[0] holds the real parts, val[1] the imaginary parts.
        float32x4x2_t const a_quad = vld2q_f32((nk_f32_t const *)(a + idx));
        float32x4x2_t const b_quad = vld2q_f32((nk_f32_t const *)(b + idx));
        float32x4x2_t const c_quad = vld2q_f32((nk_f32_t const *)(c + idx));

        // Complex product a * b.
        float32x4_t prod_re = vmulq_f32(a_quad.val[0], b_quad.val[0]);
        prod_re = vfmsq_f32(prod_re, a_quad.val[1], b_quad.val[1]);
        float32x4_t prod_im = vmulq_f32(a_quad.val[0], b_quad.val[1]);
        prod_im = vfmaq_f32(prod_im, a_quad.val[1], b_quad.val[0]);

        // Real part: Re(alpha * prod) followed by Re(beta * c).
        float32x4_t sum_re = vmulq_f32(alpha_re, prod_re);
        sum_re = vfmsq_f32(sum_re, alpha_im, prod_im);
        sum_re = vfmaq_f32(sum_re, beta_re, c_quad.val[0]);
        sum_re = vfmsq_f32(sum_re, beta_im, c_quad.val[1]);

        // Imaginary part: Im(alpha * prod) followed by Im(beta * c).
        float32x4_t sum_im = vmulq_f32(alpha_re, prod_im);
        sum_im = vfmaq_f32(sum_im, alpha_im, prod_re);
        sum_im = vfmaq_f32(sum_im, beta_re, c_quad.val[1]);
        sum_im = vfmaq_f32(sum_im, beta_im, c_quad.val[0]);

        float32x4x2_t packed;
        packed.val[0] = sum_re;
        packed.val[1] = sum_im;
        vst2q_f32((nk_f32_t *)(result + idx), packed);
        idx += 4;
    }

    // Scalar tail: every input is read before the output element is written.
    for (; idx < n; ++idx) {
        nk_f32_t const a_re = a[idx].real, a_im = a[idx].imag;
        nk_f32_t const b_re = b[idx].real, b_im = b[idx].imag;
        nk_f32_t const c_re = c[idx].real, c_im = c[idx].imag;
        nk_f32_t const prod_re = a_re * b_re - a_im * b_im;
        nk_f32_t const prod_im = a_re * b_im + a_im * b_re;
        nk_f32_t const scaled_prod_re = alpha->real * prod_re - alpha->imag * prod_im;
        nk_f32_t const scaled_prod_im = alpha->real * prod_im + alpha->imag * prod_re;
        nk_f32_t const scaled_c_re = beta->real * c_re - beta->imag * c_im;
        nk_f32_t const scaled_c_im = beta->real * c_im + beta->imag * c_re;
        result[idx].real = scaled_prod_re + scaled_c_re;
        result[idx].imag = scaled_prod_im + scaled_c_im;
    }
}
|
|
1050
|
+
|
|
1051
|
+
/**
 *  @brief Element-wise complex fused multiply-add in double precision:
 *         `result[i] = alpha * (a[i] * b[i]) + beta * c[i]`.
 *
 *  The vector loop consumes two complex numbers per step. `vld2q_f64` de-interleaves the
 *  (real, imag) pairs into separate real and imaginary registers, the complex product
 *  `a * b` is formed first, then scaled by `alpha`, and finally `beta * c` is accumulated
 *  on top with fused multiply-add / multiply-subtract. `vst2q_f64` re-interleaves the
 *  output. A scalar loop handles the last element when `n` is odd.
 *
 *  @param a      First multiplicand array of `n` complex numbers.
 *  @param b      Second multiplicand array of `n` complex numbers.
 *  @param c      Addend array of `n` complex numbers.
 *  @param n      Number of complex elements to process.
 *  @param alpha  Complex factor applied to the product `a[i] * b[i]`.
 *  @param beta   Complex factor applied to `c[i]`.
 *  @param result Output array receiving `n` complex numbers.
 */
NK_PUBLIC void nk_each_fma_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
                                     nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result) {
    // Broadcast the four scalar coefficients once, outside the loop.
    float64x2_t const alpha_re = vdupq_n_f64(alpha->real);
    float64x2_t const alpha_im = vdupq_n_f64(alpha->imag);
    float64x2_t const beta_re = vdupq_n_f64(beta->real);
    float64x2_t const beta_im = vdupq_n_f64(beta->imag);

    nk_size_t idx = 0;
    while (idx + 2 <= n) {
        // val[0] holds the real parts, val[1] the imaginary parts.
        float64x2x2_t const a_pair = vld2q_f64((nk_f64_t const *)(a + idx));
        float64x2x2_t const b_pair = vld2q_f64((nk_f64_t const *)(b + idx));
        float64x2x2_t const c_pair = vld2q_f64((nk_f64_t const *)(c + idx));

        // Complex product a * b.
        float64x2_t prod_re = vmulq_f64(a_pair.val[0], b_pair.val[0]);
        prod_re = vfmsq_f64(prod_re, a_pair.val[1], b_pair.val[1]);
        float64x2_t prod_im = vmulq_f64(a_pair.val[0], b_pair.val[1]);
        prod_im = vfmaq_f64(prod_im, a_pair.val[1], b_pair.val[0]);

        // Real part: Re(alpha * prod) followed by Re(beta * c).
        float64x2_t sum_re = vmulq_f64(alpha_re, prod_re);
        sum_re = vfmsq_f64(sum_re, alpha_im, prod_im);
        sum_re = vfmaq_f64(sum_re, beta_re, c_pair.val[0]);
        sum_re = vfmsq_f64(sum_re, beta_im, c_pair.val[1]);

        // Imaginary part: Im(alpha * prod) followed by Im(beta * c).
        float64x2_t sum_im = vmulq_f64(alpha_re, prod_im);
        sum_im = vfmaq_f64(sum_im, alpha_im, prod_re);
        sum_im = vfmaq_f64(sum_im, beta_re, c_pair.val[1]);
        sum_im = vfmaq_f64(sum_im, beta_im, c_pair.val[0]);

        float64x2x2_t packed;
        packed.val[0] = sum_re;
        packed.val[1] = sum_im;
        vst2q_f64((nk_f64_t *)(result + idx), packed);
        idx += 2;
    }

    // Scalar tail: every input is read before the output element is written.
    for (; idx < n; ++idx) {
        nk_f64_t const a_re = a[idx].real, a_im = a[idx].imag;
        nk_f64_t const b_re = b[idx].real, b_im = b[idx].imag;
        nk_f64_t const c_re = c[idx].real, c_im = c[idx].imag;
        nk_f64_t const prod_re = a_re * b_re - a_im * b_im;
        nk_f64_t const prod_im = a_re * b_im + a_im * b_re;
        nk_f64_t const scaled_prod_re = alpha->real * prod_re - alpha->imag * prod_im;
        nk_f64_t const scaled_prod_im = alpha->real * prod_im + alpha->imag * prod_re;
        nk_f64_t const scaled_c_re = beta->real * c_re - beta->imag * c_im;
        nk_f64_t const scaled_c_im = beta->real * c_im + beta->imag * c_re;
        result[idx].real = scaled_prod_re + scaled_c_re;
        result[idx].imag = scaled_prod_im + scaled_c_im;
    }
}
|
|
1091
|
+
|
|
1092
|
+
#if defined(__clang__)
|
|
1093
|
+
#pragma clang attribute pop
|
|
1094
|
+
#elif defined(__GNUC__)
|
|
1095
|
+
#pragma GCC pop_options
|
|
1096
|
+
#endif
|
|
1097
|
+
|
|
1098
|
+
#if defined(__cplusplus)
|
|
1099
|
+
} // extern "C"
|
|
1100
|
+
#endif
|
|
1101
|
+
|
|
1102
|
+
#endif // NK_TARGET_NEON
|
|
1103
|
+
#endif // NK_TARGET_ARM_
|
|
1104
|
+
#endif // NK_EACH_NEON_H
|