numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,22 +8,22 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_sve_instructions ARM SVE Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* svld1_f32
|
|
13
|
-
* svld2_f32
|
|
14
|
-
* svmla_f32_x
|
|
15
|
-
* svmls_f32_x
|
|
16
|
-
* svaddv_f32
|
|
17
|
-
* svdup_f32
|
|
18
|
-
* svwhilelt_b32
|
|
19
|
-
* svptrue_b32
|
|
20
|
-
* svcntw
|
|
21
|
-
* svcntd
|
|
22
|
-
* svld1_f64
|
|
23
|
-
* svld2_f64
|
|
24
|
-
* svmla_f64_x
|
|
25
|
-
* svmls_f64_x
|
|
26
|
-
* svaddv_f64
|
|
11
|
+
* Intrinsic Instruction V1
|
|
12
|
+
* svld1_f32 LD1W (Z.S, P/Z, [Xn]) 4-6cy @ 2p
|
|
13
|
+
* svld2_f32 LD2W (Z.S, P/Z, [Xn]) 6-8cy @ 1p
|
|
14
|
+
* svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy @ 2p
|
|
15
|
+
* svmls_f32_x FMLS (Z.S, P/M, Z.S, Z.S) 4cy @ 2p
|
|
16
|
+
* svaddv_f32 FADDV (S, P, Z.S) 6cy @ 1p
|
|
17
|
+
* svdup_f32 DUP (Z.S, #imm) 1cy @ 2p
|
|
18
|
+
* svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy @ 1p
|
|
19
|
+
* svptrue_b32 PTRUE (P.S, pattern) 1cy @ 2p
|
|
20
|
+
* svcntw CNTW (Xd) 1cy @ 2p
|
|
21
|
+
* svcntd CNTD (Xd) 1cy @ 2p
|
|
22
|
+
* svld1_f64 LD1D (Z.D, P/Z, [Xn]) 4-6cy @ 2p
|
|
23
|
+
* svld2_f64 LD2D (Z.D, P/Z, [Xn]) 6-8cy @ 1p
|
|
24
|
+
* svmla_f64_x FMLA (Z.D, P/M, Z.D, Z.D) 4cy @ 2p
|
|
25
|
+
* svmls_f64_x FMLS (Z.D, P/M, Z.D, Z.D) 4cy @ 2p
|
|
26
|
+
* svaddv_f64 FADDV (D, P, Z.D) 6cy @ 1p
|
|
27
27
|
*
|
|
28
28
|
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
29
29
|
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
@@ -58,49 +58,57 @@ extern "C" {
|
|
|
58
58
|
* return 0 (SVE spec), which is harmless since only the lower half is meaningful
|
|
59
59
|
* after each halving stage.
|
|
60
60
|
*/
|
|
61
|
-
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64_sve_(svbool_t
|
|
61
|
+
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64_sve_(svbool_t predicate_b64x, svfloat64_t sum, svfloat64_t compensation) {
|
|
62
62
|
// Stage 0: TwoSum merge of sum + compensation (parallel across all active lanes)
|
|
63
|
-
svfloat64_t tentative_sum_f64x = svadd_f64_x(
|
|
64
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
63
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_b64x, sum, compensation);
|
|
64
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum);
|
|
65
65
|
svfloat64_t accumulated_error_f64x = svadd_f64_x(
|
|
66
|
-
|
|
67
|
-
svsub_f64_x(
|
|
66
|
+
predicate_b64x,
|
|
67
|
+
svsub_f64_x(predicate_b64x, sum, svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
68
|
+
svsub_f64_x(predicate_b64x, compensation, virtual_addend_f64x));
|
|
68
69
|
|
|
69
70
|
// Tree reduction: TwoSum halving at each level, log2(VL) iterations
|
|
70
71
|
for (unsigned int half = (unsigned int)svcntd() / 2; half > 0; half >>= 1) {
|
|
71
|
-
svuint64_t upper_indices_u64x = svadd_n_u64_x(
|
|
72
|
+
svuint64_t upper_indices_u64x = svadd_n_u64_x(predicate_b64x, svindex_u64(0, 1), half);
|
|
72
73
|
svfloat64_t upper_sum_f64x = svtbl_f64(tentative_sum_f64x, upper_indices_u64x);
|
|
73
74
|
svfloat64_t upper_error_f64x = svtbl_f64(accumulated_error_f64x, upper_indices_u64x);
|
|
74
75
|
// TwoSum: lower_half + upper_half
|
|
75
|
-
svfloat64_t halved_tentative_sum_f64x = svadd_f64_x(
|
|
76
|
-
svfloat64_t halved_virtual_addend_f64x = svsub_f64_x(
|
|
76
|
+
svfloat64_t halved_tentative_sum_f64x = svadd_f64_x(predicate_b64x, tentative_sum_f64x, upper_sum_f64x);
|
|
77
|
+
svfloat64_t halved_virtual_addend_f64x = svsub_f64_x(predicate_b64x, halved_tentative_sum_f64x,
|
|
78
|
+
tentative_sum_f64x);
|
|
77
79
|
svfloat64_t rounding_error_f64x = svadd_f64_x(
|
|
78
|
-
|
|
79
|
-
svsub_f64_x(
|
|
80
|
-
svsub_f64_x(
|
|
81
|
-
svsub_f64_x(
|
|
80
|
+
predicate_b64x,
|
|
81
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x,
|
|
82
|
+
svsub_f64_x(predicate_b64x, halved_tentative_sum_f64x, halved_virtual_addend_f64x)),
|
|
83
|
+
svsub_f64_x(predicate_b64x, upper_sum_f64x, halved_virtual_addend_f64x));
|
|
82
84
|
tentative_sum_f64x = halved_tentative_sum_f64x;
|
|
83
85
|
accumulated_error_f64x = svadd_f64_x(
|
|
84
|
-
|
|
86
|
+
predicate_b64x, svadd_f64_x(predicate_b64x, accumulated_error_f64x, upper_error_f64x), rounding_error_f64x);
|
|
85
87
|
}
|
|
86
88
|
// Result is in lane 0
|
|
87
|
-
svbool_t
|
|
88
|
-
return svlastb_f64(
|
|
89
|
-
svlastb_f64(
|
|
89
|
+
svbool_t predicate_first_b64x = svwhilelt_b64_u64(0u, 1);
|
|
90
|
+
return svlastb_f64(predicate_first_b64x, tentative_sum_f64x) +
|
|
91
|
+
svlastb_f64(predicate_first_b64x, accumulated_error_f64x);
|
|
90
92
|
}
|
|
91
93
|
|
|
92
94
|
NK_PUBLIC void nk_dot_f32_sve(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
93
95
|
nk_f64_t *result) {
|
|
94
96
|
nk_size_t idx_scalars = 0;
|
|
95
|
-
nk_size_t const vector_length = svcntd();
|
|
96
97
|
svfloat64_t ab_f64x = svdup_f64(0.);
|
|
97
|
-
for (; idx_scalars < count_scalars; idx_scalars +=
|
|
98
|
-
svbool_t
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
98
|
+
for (; idx_scalars < count_scalars; idx_scalars += svcntw()) {
|
|
99
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(idx_scalars, count_scalars);
|
|
100
|
+
svfloat32_t a_f32x = svld1_f32(predicate_b32x, a_scalars + idx_scalars);
|
|
101
|
+
svfloat32_t b_f32x = svld1_f32(predicate_b32x, b_scalars + idx_scalars);
|
|
102
|
+
nk_size_t remaining = count_scalars - idx_scalars < svcntw() ? count_scalars - idx_scalars : svcntw();
|
|
103
|
+
|
|
104
|
+
// svcvt_f64_f32_x widens only even-indexed f32 elements; svext by 1 shifts odd into even.
|
|
105
|
+
svbool_t pred_even_b64x = svwhilelt_b64_u64(0u, (remaining + 1) / 2);
|
|
106
|
+
ab_f64x = svmla_f64_m(pred_even_b64x, ab_f64x, svcvt_f64_f32_x(pred_even_b64x, a_f32x),
|
|
107
|
+
svcvt_f64_f32_x(pred_even_b64x, b_f32x));
|
|
108
|
+
|
|
109
|
+
svbool_t pred_odd_b64x = svwhilelt_b64_u64(0u, remaining / 2);
|
|
110
|
+
ab_f64x = svmla_f64_m(pred_odd_b64x, ab_f64x, svcvt_f64_f32_x(pred_odd_b64x, svext_f32(a_f32x, a_f32x, 1)),
|
|
111
|
+
svcvt_f64_f32_x(pred_odd_b64x, svext_f32(b_f32x, b_f32x, 1)));
|
|
104
112
|
}
|
|
105
113
|
*result = svaddv_f64(svptrue_b64(), ab_f64x);
|
|
106
114
|
}
|
|
@@ -108,22 +116,38 @@ NK_PUBLIC void nk_dot_f32_sve(nk_f32_t const *a_scalars, nk_f32_t const *b_scala
|
|
|
108
116
|
NK_PUBLIC void nk_dot_f32c_sve(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
109
117
|
nk_f64c_t *results) {
|
|
110
118
|
nk_size_t idx_pairs = 0;
|
|
111
|
-
nk_size_t const vector_length = svcntd();
|
|
112
119
|
svfloat64_t ab_real_f64x = svdup_f64(0.);
|
|
113
120
|
svfloat64_t ab_imag_f64x = svdup_f64(0.);
|
|
114
|
-
for (; idx_pairs < count_pairs; idx_pairs +=
|
|
115
|
-
svbool_t
|
|
116
|
-
|
|
117
|
-
svfloat32x2_t
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
121
|
+
for (; idx_pairs < count_pairs; idx_pairs += svcntw()) {
|
|
122
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(idx_pairs, count_pairs);
|
|
123
|
+
svfloat32x2_t a_f32x2 = svld2_f32(predicate_b32x, (nk_f32_t const *)(a_pairs + idx_pairs));
|
|
124
|
+
svfloat32x2_t b_f32x2 = svld2_f32(predicate_b32x, (nk_f32_t const *)(b_pairs + idx_pairs));
|
|
125
|
+
svfloat32_t a_real_f32x = svget2_f32(a_f32x2, 0);
|
|
126
|
+
svfloat32_t a_imag_f32x = svget2_f32(a_f32x2, 1);
|
|
127
|
+
svfloat32_t b_real_f32x = svget2_f32(b_f32x2, 0);
|
|
128
|
+
svfloat32_t b_imag_f32x = svget2_f32(b_f32x2, 1);
|
|
129
|
+
nk_size_t remaining = count_pairs - idx_pairs < svcntw() ? count_pairs - idx_pairs : svcntw();
|
|
130
|
+
|
|
131
|
+
// svcvt_f64_f32_x widens only even-indexed f32 elements; svext by 1 shifts odd into even.
|
|
132
|
+
svbool_t pred_even_b64x = svwhilelt_b64_u64(0u, (remaining + 1) / 2);
|
|
133
|
+
svfloat64_t a_real_even_f64x = svcvt_f64_f32_x(pred_even_b64x, a_real_f32x);
|
|
134
|
+
svfloat64_t a_imag_even_f64x = svcvt_f64_f32_x(pred_even_b64x, a_imag_f32x);
|
|
135
|
+
svfloat64_t b_real_even_f64x = svcvt_f64_f32_x(pred_even_b64x, b_real_f32x);
|
|
136
|
+
svfloat64_t b_imag_even_f64x = svcvt_f64_f32_x(pred_even_b64x, b_imag_f32x);
|
|
137
|
+
ab_real_f64x = svmla_f64_m(pred_even_b64x, ab_real_f64x, a_real_even_f64x, b_real_even_f64x);
|
|
138
|
+
ab_real_f64x = svmls_f64_m(pred_even_b64x, ab_real_f64x, a_imag_even_f64x, b_imag_even_f64x);
|
|
139
|
+
ab_imag_f64x = svmla_f64_m(pred_even_b64x, ab_imag_f64x, a_real_even_f64x, b_imag_even_f64x);
|
|
140
|
+
ab_imag_f64x = svmla_f64_m(pred_even_b64x, ab_imag_f64x, a_imag_even_f64x, b_real_even_f64x);
|
|
141
|
+
|
|
142
|
+
svbool_t pred_odd_b64x = svwhilelt_b64_u64(0u, remaining / 2);
|
|
143
|
+
svfloat64_t a_real_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(a_real_f32x, a_real_f32x, 1));
|
|
144
|
+
svfloat64_t a_imag_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(a_imag_f32x, a_imag_f32x, 1));
|
|
145
|
+
svfloat64_t b_real_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(b_real_f32x, b_real_f32x, 1));
|
|
146
|
+
svfloat64_t b_imag_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(b_imag_f32x, b_imag_f32x, 1));
|
|
147
|
+
ab_real_f64x = svmla_f64_m(pred_odd_b64x, ab_real_f64x, a_real_odd_f64x, b_real_odd_f64x);
|
|
148
|
+
ab_real_f64x = svmls_f64_m(pred_odd_b64x, ab_real_f64x, a_imag_odd_f64x, b_imag_odd_f64x);
|
|
149
|
+
ab_imag_f64x = svmla_f64_m(pred_odd_b64x, ab_imag_f64x, a_real_odd_f64x, b_imag_odd_f64x);
|
|
150
|
+
ab_imag_f64x = svmla_f64_m(pred_odd_b64x, ab_imag_f64x, a_imag_odd_f64x, b_real_odd_f64x);
|
|
127
151
|
}
|
|
128
152
|
results->real = svaddv_f64(svptrue_b64(), ab_real_f64x);
|
|
129
153
|
results->imag = svaddv_f64(svptrue_b64(), ab_imag_f64x);
|
|
@@ -132,22 +156,38 @@ NK_PUBLIC void nk_dot_f32c_sve(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pair
|
|
|
132
156
|
NK_PUBLIC void nk_vdot_f32c_sve(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
133
157
|
nk_f64c_t *results) {
|
|
134
158
|
nk_size_t idx_pairs = 0;
|
|
135
|
-
nk_size_t const vector_length = svcntd();
|
|
136
159
|
svfloat64_t ab_real_f64x = svdup_f64(0.);
|
|
137
160
|
svfloat64_t ab_imag_f64x = svdup_f64(0.);
|
|
138
|
-
for (; idx_pairs < count_pairs; idx_pairs +=
|
|
139
|
-
svbool_t
|
|
140
|
-
|
|
141
|
-
svfloat32x2_t
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
161
|
+
for (; idx_pairs < count_pairs; idx_pairs += svcntw()) {
|
|
162
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(idx_pairs, count_pairs);
|
|
163
|
+
svfloat32x2_t a_f32x2 = svld2_f32(predicate_b32x, (nk_f32_t const *)(a_pairs + idx_pairs));
|
|
164
|
+
svfloat32x2_t b_f32x2 = svld2_f32(predicate_b32x, (nk_f32_t const *)(b_pairs + idx_pairs));
|
|
165
|
+
svfloat32_t a_real_f32x = svget2_f32(a_f32x2, 0);
|
|
166
|
+
svfloat32_t a_imag_f32x = svget2_f32(a_f32x2, 1);
|
|
167
|
+
svfloat32_t b_real_f32x = svget2_f32(b_f32x2, 0);
|
|
168
|
+
svfloat32_t b_imag_f32x = svget2_f32(b_f32x2, 1);
|
|
169
|
+
nk_size_t remaining = count_pairs - idx_pairs < svcntw() ? count_pairs - idx_pairs : svcntw();
|
|
170
|
+
|
|
171
|
+
// svcvt_f64_f32_x widens only even-indexed f32 elements; svext by 1 shifts odd into even.
|
|
172
|
+
svbool_t pred_even_b64x = svwhilelt_b64_u64(0u, (remaining + 1) / 2);
|
|
173
|
+
svfloat64_t a_real_even_f64x = svcvt_f64_f32_x(pred_even_b64x, a_real_f32x);
|
|
174
|
+
svfloat64_t a_imag_even_f64x = svcvt_f64_f32_x(pred_even_b64x, a_imag_f32x);
|
|
175
|
+
svfloat64_t b_real_even_f64x = svcvt_f64_f32_x(pred_even_b64x, b_real_f32x);
|
|
176
|
+
svfloat64_t b_imag_even_f64x = svcvt_f64_f32_x(pred_even_b64x, b_imag_f32x);
|
|
177
|
+
ab_real_f64x = svmla_f64_m(pred_even_b64x, ab_real_f64x, a_real_even_f64x, b_real_even_f64x);
|
|
178
|
+
ab_real_f64x = svmla_f64_m(pred_even_b64x, ab_real_f64x, a_imag_even_f64x, b_imag_even_f64x);
|
|
179
|
+
ab_imag_f64x = svmla_f64_m(pred_even_b64x, ab_imag_f64x, a_real_even_f64x, b_imag_even_f64x);
|
|
180
|
+
ab_imag_f64x = svmls_f64_m(pred_even_b64x, ab_imag_f64x, a_imag_even_f64x, b_real_even_f64x);
|
|
181
|
+
|
|
182
|
+
svbool_t pred_odd_b64x = svwhilelt_b64_u64(0u, remaining / 2);
|
|
183
|
+
svfloat64_t a_real_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(a_real_f32x, a_real_f32x, 1));
|
|
184
|
+
svfloat64_t a_imag_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(a_imag_f32x, a_imag_f32x, 1));
|
|
185
|
+
svfloat64_t b_real_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(b_real_f32x, b_real_f32x, 1));
|
|
186
|
+
svfloat64_t b_imag_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(b_imag_f32x, b_imag_f32x, 1));
|
|
187
|
+
ab_real_f64x = svmla_f64_m(pred_odd_b64x, ab_real_f64x, a_real_odd_f64x, b_real_odd_f64x);
|
|
188
|
+
ab_real_f64x = svmla_f64_m(pred_odd_b64x, ab_real_f64x, a_imag_odd_f64x, b_imag_odd_f64x);
|
|
189
|
+
ab_imag_f64x = svmla_f64_m(pred_odd_b64x, ab_imag_f64x, a_real_odd_f64x, b_imag_odd_f64x);
|
|
190
|
+
ab_imag_f64x = svmls_f64_m(pred_odd_b64x, ab_imag_f64x, a_imag_odd_f64x, b_real_odd_f64x);
|
|
151
191
|
}
|
|
152
192
|
results->real = svaddv_f64(svptrue_b64(), ab_real_f64x);
|
|
153
193
|
results->imag = svaddv_f64(svptrue_b64(), ab_imag_f64x);
|
|
@@ -160,23 +200,23 @@ NK_PUBLIC void nk_dot_f64_sve(nk_f64_t const *a_scalars, nk_f64_t const *b_scala
|
|
|
160
200
|
svfloat64_t sum_f64x = svdup_f64(0.);
|
|
161
201
|
svfloat64_t compensation_f64x = svdup_f64(0.);
|
|
162
202
|
do {
|
|
163
|
-
svbool_t
|
|
164
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
165
|
-
svfloat64_t b_f64x = svld1_f64(
|
|
203
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(idx_scalars, count_scalars);
|
|
204
|
+
svfloat64_t a_f64x = svld1_f64(predicate_b64x, a_scalars + idx_scalars);
|
|
205
|
+
svfloat64_t b_f64x = svld1_f64(predicate_b64x, b_scalars + idx_scalars);
|
|
166
206
|
// TwoProd: product = a*b, error = -(product - a*b) negated
|
|
167
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
168
|
-
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
169
|
-
svnmls_f64_x(
|
|
207
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_f64x, b_f64x);
|
|
208
|
+
svfloat64_t product_error_f64x = svneg_f64_x(predicate_b64x,
|
|
209
|
+
svnmls_f64_x(predicate_b64x, product_f64x, a_f64x, b_f64x));
|
|
170
210
|
// TwoSum: tentative_sum = sum + product
|
|
171
|
-
svfloat64_t tentative_sum_f64x =
|
|
172
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
211
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, sum_f64x, product_f64x);
|
|
212
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum_f64x);
|
|
173
213
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
174
|
-
|
|
175
|
-
svsub_f64_x(
|
|
176
|
-
svsub_f64_x(
|
|
214
|
+
predicate_b64x,
|
|
215
|
+
svsub_f64_x(predicate_b64x, sum_f64x, svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
216
|
+
svsub_f64_x(predicate_b64x, product_f64x, virtual_addend_f64x));
|
|
177
217
|
sum_f64x = tentative_sum_f64x;
|
|
178
|
-
compensation_f64x =
|
|
179
|
-
svadd_f64_x(
|
|
218
|
+
compensation_f64x = svadd_f64_m(predicate_b64x, compensation_f64x,
|
|
219
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
|
|
180
220
|
idx_scalars += svcntd();
|
|
181
221
|
} while (idx_scalars < count_scalars);
|
|
182
222
|
*result = nk_dot_stable_sum_f64_sve_(svptrue_b64(), sum_f64x, compensation_f64x);
|
|
@@ -192,9 +232,9 @@ NK_PUBLIC void nk_dot_f64c_sve(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pair
|
|
|
192
232
|
svfloat64_t sum_imag_f64x = svdup_f64(0.);
|
|
193
233
|
svfloat64_t comp_imag_f64x = svdup_f64(0.);
|
|
194
234
|
do {
|
|
195
|
-
svbool_t
|
|
196
|
-
svfloat64x2_t a_f64x2 = svld2_f64(
|
|
197
|
-
svfloat64x2_t b_f64x2 = svld2_f64(
|
|
235
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
|
|
236
|
+
svfloat64x2_t a_f64x2 = svld2_f64(predicate_b64x, (nk_f64_t const *)(a_pairs + idx_pairs));
|
|
237
|
+
svfloat64x2_t b_f64x2 = svld2_f64(predicate_b64x, (nk_f64_t const *)(b_pairs + idx_pairs));
|
|
198
238
|
svfloat64_t a_real_f64x = svget2_f64(a_f64x2, 0);
|
|
199
239
|
svfloat64_t a_imag_f64x = svget2_f64(a_f64x2, 1);
|
|
200
240
|
svfloat64_t b_real_f64x = svget2_f64(b_f64x2, 0);
|
|
@@ -202,75 +242,75 @@ NK_PUBLIC void nk_dot_f64c_sve(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pair
|
|
|
202
242
|
|
|
203
243
|
// TwoProd + TwoSum for real part: sum_real += a_real*b_real
|
|
204
244
|
{
|
|
205
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
245
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_real_f64x, b_real_f64x);
|
|
206
246
|
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
207
|
-
|
|
208
|
-
svfloat64_t tentative_sum_f64x =
|
|
209
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
247
|
+
predicate_b64x, svnmls_f64_x(predicate_b64x, product_f64x, a_real_f64x, b_real_f64x));
|
|
248
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, sum_real_f64x, product_f64x);
|
|
249
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum_real_f64x);
|
|
210
250
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
211
|
-
|
|
212
|
-
svsub_f64_x(
|
|
213
|
-
svsub_f64_x(
|
|
214
|
-
svsub_f64_x(
|
|
251
|
+
predicate_b64x,
|
|
252
|
+
svsub_f64_x(predicate_b64x, sum_real_f64x,
|
|
253
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
254
|
+
svsub_f64_x(predicate_b64x, product_f64x, virtual_addend_f64x));
|
|
215
255
|
sum_real_f64x = tentative_sum_f64x;
|
|
216
|
-
comp_real_f64x =
|
|
217
|
-
svadd_f64_x(
|
|
256
|
+
comp_real_f64x = svadd_f64_m(predicate_b64x, comp_real_f64x,
|
|
257
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
|
|
218
258
|
}
|
|
219
259
|
// TwoProd + TwoSum for real part: sum_real -= a_imag*b_imag
|
|
220
260
|
{
|
|
221
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
261
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_imag_f64x, b_imag_f64x);
|
|
222
262
|
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
223
|
-
|
|
224
|
-
svfloat64_t neg_product_f64x = svneg_f64_x(
|
|
225
|
-
svfloat64_t neg_product_error_f64x = svneg_f64_x(
|
|
226
|
-
svfloat64_t tentative_sum_f64x =
|
|
227
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
263
|
+
predicate_b64x, svnmls_f64_x(predicate_b64x, product_f64x, a_imag_f64x, b_imag_f64x));
|
|
264
|
+
svfloat64_t neg_product_f64x = svneg_f64_x(predicate_b64x, product_f64x);
|
|
265
|
+
svfloat64_t neg_product_error_f64x = svneg_f64_x(predicate_b64x, product_error_f64x);
|
|
266
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, sum_real_f64x, neg_product_f64x);
|
|
267
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum_real_f64x);
|
|
228
268
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
229
|
-
|
|
230
|
-
svsub_f64_x(
|
|
231
|
-
svsub_f64_x(
|
|
232
|
-
svsub_f64_x(
|
|
269
|
+
predicate_b64x,
|
|
270
|
+
svsub_f64_x(predicate_b64x, sum_real_f64x,
|
|
271
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
272
|
+
svsub_f64_x(predicate_b64x, neg_product_f64x, virtual_addend_f64x));
|
|
233
273
|
sum_real_f64x = tentative_sum_f64x;
|
|
234
|
-
comp_real_f64x =
|
|
235
|
-
svadd_f64_x(
|
|
274
|
+
comp_real_f64x = svadd_f64_m(predicate_b64x, comp_real_f64x,
|
|
275
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, neg_product_error_f64x));
|
|
236
276
|
}
|
|
237
277
|
// TwoProd + TwoSum for imaginary part: sum_imag += a_real*b_imag
|
|
238
278
|
{
|
|
239
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
279
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_real_f64x, b_imag_f64x);
|
|
240
280
|
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
241
|
-
|
|
242
|
-
svfloat64_t tentative_sum_f64x =
|
|
243
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
281
|
+
predicate_b64x, svnmls_f64_x(predicate_b64x, product_f64x, a_real_f64x, b_imag_f64x));
|
|
282
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, sum_imag_f64x, product_f64x);
|
|
283
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum_imag_f64x);
|
|
244
284
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
245
|
-
|
|
246
|
-
svsub_f64_x(
|
|
247
|
-
svsub_f64_x(
|
|
248
|
-
svsub_f64_x(
|
|
285
|
+
predicate_b64x,
|
|
286
|
+
svsub_f64_x(predicate_b64x, sum_imag_f64x,
|
|
287
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
288
|
+
svsub_f64_x(predicate_b64x, product_f64x, virtual_addend_f64x));
|
|
249
289
|
sum_imag_f64x = tentative_sum_f64x;
|
|
250
|
-
comp_imag_f64x =
|
|
251
|
-
svadd_f64_x(
|
|
290
|
+
comp_imag_f64x = svadd_f64_m(predicate_b64x, comp_imag_f64x,
|
|
291
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
|
|
252
292
|
}
|
|
253
293
|
// TwoProd + TwoSum for imaginary part: sum_imag += a_imag*b_real
|
|
254
294
|
{
|
|
255
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
295
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_imag_f64x, b_real_f64x);
|
|
256
296
|
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
257
|
-
|
|
258
|
-
svfloat64_t tentative_sum_f64x =
|
|
259
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
297
|
+
predicate_b64x, svnmls_f64_x(predicate_b64x, product_f64x, a_imag_f64x, b_real_f64x));
|
|
298
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, sum_imag_f64x, product_f64x);
|
|
299
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum_imag_f64x);
|
|
260
300
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
261
|
-
|
|
262
|
-
svsub_f64_x(
|
|
263
|
-
svsub_f64_x(
|
|
264
|
-
svsub_f64_x(
|
|
301
|
+
predicate_b64x,
|
|
302
|
+
svsub_f64_x(predicate_b64x, sum_imag_f64x,
|
|
303
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
304
|
+
svsub_f64_x(predicate_b64x, product_f64x, virtual_addend_f64x));
|
|
265
305
|
sum_imag_f64x = tentative_sum_f64x;
|
|
266
|
-
comp_imag_f64x =
|
|
267
|
-
svadd_f64_x(
|
|
306
|
+
comp_imag_f64x = svadd_f64_m(predicate_b64x, comp_imag_f64x,
|
|
307
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
|
|
268
308
|
}
|
|
269
309
|
idx_pairs += svcntd();
|
|
270
310
|
} while (idx_pairs < count_pairs);
|
|
271
|
-
svbool_t
|
|
272
|
-
results->real = nk_dot_stable_sum_f64_sve_(
|
|
273
|
-
results->imag = nk_dot_stable_sum_f64_sve_(
|
|
311
|
+
svbool_t predicate_all_b64x = svptrue_b64();
|
|
312
|
+
results->real = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, sum_real_f64x, comp_real_f64x);
|
|
313
|
+
results->imag = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, sum_imag_f64x, comp_imag_f64x);
|
|
274
314
|
}
|
|
275
315
|
|
|
276
316
|
NK_PUBLIC void nk_vdot_f64c_sve(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
|
|
@@ -283,9 +323,9 @@ NK_PUBLIC void nk_vdot_f64c_sve(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pai
|
|
|
283
323
|
svfloat64_t sum_imag_f64x = svdup_f64(0.);
|
|
284
324
|
svfloat64_t comp_imag_f64x = svdup_f64(0.);
|
|
285
325
|
do {
|
|
286
|
-
svbool_t
|
|
287
|
-
svfloat64x2_t a_f64x2 = svld2_f64(
|
|
288
|
-
svfloat64x2_t b_f64x2 = svld2_f64(
|
|
326
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
|
|
327
|
+
svfloat64x2_t a_f64x2 = svld2_f64(predicate_b64x, (nk_f64_t const *)(a_pairs + idx_pairs));
|
|
328
|
+
svfloat64x2_t b_f64x2 = svld2_f64(predicate_b64x, (nk_f64_t const *)(b_pairs + idx_pairs));
|
|
289
329
|
svfloat64_t a_real_f64x = svget2_f64(a_f64x2, 0);
|
|
290
330
|
svfloat64_t a_imag_f64x = svget2_f64(a_f64x2, 1);
|
|
291
331
|
svfloat64_t b_real_f64x = svget2_f64(b_f64x2, 0);
|
|
@@ -293,75 +333,75 @@ NK_PUBLIC void nk_vdot_f64c_sve(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pai
|
|
|
293
333
|
|
|
294
334
|
// TwoProd + TwoSum for real part: sum_real += a_real*b_real
|
|
295
335
|
{
|
|
296
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
336
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_real_f64x, b_real_f64x);
|
|
297
337
|
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
298
|
-
|
|
299
|
-
svfloat64_t tentative_sum_f64x =
|
|
300
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
338
|
+
predicate_b64x, svnmls_f64_x(predicate_b64x, product_f64x, a_real_f64x, b_real_f64x));
|
|
339
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, sum_real_f64x, product_f64x);
|
|
340
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum_real_f64x);
|
|
301
341
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
302
|
-
|
|
303
|
-
svsub_f64_x(
|
|
304
|
-
svsub_f64_x(
|
|
305
|
-
svsub_f64_x(
|
|
342
|
+
predicate_b64x,
|
|
343
|
+
svsub_f64_x(predicate_b64x, sum_real_f64x,
|
|
344
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
345
|
+
svsub_f64_x(predicate_b64x, product_f64x, virtual_addend_f64x));
|
|
306
346
|
sum_real_f64x = tentative_sum_f64x;
|
|
307
|
-
comp_real_f64x =
|
|
308
|
-
svadd_f64_x(
|
|
347
|
+
comp_real_f64x = svadd_f64_m(predicate_b64x, comp_real_f64x,
|
|
348
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
|
|
309
349
|
}
|
|
310
350
|
// TwoProd + TwoSum for real part: sum_real += a_imag*b_imag (conjugate: + not -)
|
|
311
351
|
{
|
|
312
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
352
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_imag_f64x, b_imag_f64x);
|
|
313
353
|
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
314
|
-
|
|
315
|
-
svfloat64_t tentative_sum_f64x =
|
|
316
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
354
|
+
predicate_b64x, svnmls_f64_x(predicate_b64x, product_f64x, a_imag_f64x, b_imag_f64x));
|
|
355
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, sum_real_f64x, product_f64x);
|
|
356
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum_real_f64x);
|
|
317
357
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
318
|
-
|
|
319
|
-
svsub_f64_x(
|
|
320
|
-
svsub_f64_x(
|
|
321
|
-
svsub_f64_x(
|
|
358
|
+
predicate_b64x,
|
|
359
|
+
svsub_f64_x(predicate_b64x, sum_real_f64x,
|
|
360
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
361
|
+
svsub_f64_x(predicate_b64x, product_f64x, virtual_addend_f64x));
|
|
322
362
|
sum_real_f64x = tentative_sum_f64x;
|
|
323
|
-
comp_real_f64x =
|
|
324
|
-
svadd_f64_x(
|
|
363
|
+
comp_real_f64x = svadd_f64_m(predicate_b64x, comp_real_f64x,
|
|
364
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
|
|
325
365
|
}
|
|
326
366
|
// TwoProd + TwoSum for imaginary part: sum_imag += a_real*b_imag
|
|
327
367
|
{
|
|
328
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
368
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_real_f64x, b_imag_f64x);
|
|
329
369
|
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
330
|
-
|
|
331
|
-
svfloat64_t tentative_sum_f64x =
|
|
332
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
370
|
+
predicate_b64x, svnmls_f64_x(predicate_b64x, product_f64x, a_real_f64x, b_imag_f64x));
|
|
371
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, sum_imag_f64x, product_f64x);
|
|
372
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum_imag_f64x);
|
|
333
373
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
334
|
-
|
|
335
|
-
svsub_f64_x(
|
|
336
|
-
svsub_f64_x(
|
|
337
|
-
svsub_f64_x(
|
|
374
|
+
predicate_b64x,
|
|
375
|
+
svsub_f64_x(predicate_b64x, sum_imag_f64x,
|
|
376
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
377
|
+
svsub_f64_x(predicate_b64x, product_f64x, virtual_addend_f64x));
|
|
338
378
|
sum_imag_f64x = tentative_sum_f64x;
|
|
339
|
-
comp_imag_f64x =
|
|
340
|
-
svadd_f64_x(
|
|
379
|
+
comp_imag_f64x = svadd_f64_m(predicate_b64x, comp_imag_f64x,
|
|
380
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
|
|
341
381
|
}
|
|
342
382
|
// TwoProd + TwoSum for imaginary part: sum_imag -= a_imag*b_real (conjugate: - not +)
|
|
343
383
|
{
|
|
344
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
384
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_imag_f64x, b_real_f64x);
|
|
345
385
|
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
346
|
-
|
|
347
|
-
svfloat64_t neg_product_f64x = svneg_f64_x(
|
|
348
|
-
svfloat64_t neg_product_error_f64x = svneg_f64_x(
|
|
349
|
-
svfloat64_t tentative_sum_f64x =
|
|
350
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
386
|
+
predicate_b64x, svnmls_f64_x(predicate_b64x, product_f64x, a_imag_f64x, b_real_f64x));
|
|
387
|
+
svfloat64_t neg_product_f64x = svneg_f64_x(predicate_b64x, product_f64x);
|
|
388
|
+
svfloat64_t neg_product_error_f64x = svneg_f64_x(predicate_b64x, product_error_f64x);
|
|
389
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, sum_imag_f64x, neg_product_f64x);
|
|
390
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, sum_imag_f64x);
|
|
351
391
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
352
|
-
|
|
353
|
-
svsub_f64_x(
|
|
354
|
-
svsub_f64_x(
|
|
355
|
-
svsub_f64_x(
|
|
392
|
+
predicate_b64x,
|
|
393
|
+
svsub_f64_x(predicate_b64x, sum_imag_f64x,
|
|
394
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
395
|
+
svsub_f64_x(predicate_b64x, neg_product_f64x, virtual_addend_f64x));
|
|
356
396
|
sum_imag_f64x = tentative_sum_f64x;
|
|
357
|
-
comp_imag_f64x =
|
|
358
|
-
svadd_f64_x(
|
|
397
|
+
comp_imag_f64x = svadd_f64_m(predicate_b64x, comp_imag_f64x,
|
|
398
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, neg_product_error_f64x));
|
|
359
399
|
}
|
|
360
400
|
idx_pairs += svcntd();
|
|
361
401
|
} while (idx_pairs < count_pairs);
|
|
362
|
-
svbool_t
|
|
363
|
-
results->real = nk_dot_stable_sum_f64_sve_(
|
|
364
|
-
results->imag = nk_dot_stable_sum_f64_sve_(
|
|
402
|
+
svbool_t predicate_all_b64x = svptrue_b64();
|
|
403
|
+
results->real = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, sum_real_f64x, comp_real_f64x);
|
|
404
|
+
results->imag = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, sum_imag_f64x, comp_imag_f64x);
|
|
365
405
|
}
|
|
366
406
|
|
|
367
407
|
#if defined(__clang__)
|
|
@@ -8,13 +8,13 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_svebfdot_instructions ARM SVE+BF16 Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* svld1_bf16
|
|
13
|
-
* svbfdot_f32
|
|
14
|
-
* svaddv_f32
|
|
15
|
-
* svdup_f32
|
|
16
|
-
* svwhilelt_b16
|
|
17
|
-
* svcnth
|
|
11
|
+
* Intrinsic Instruction V1
|
|
12
|
+
* svld1_bf16 LD1H (Z.H, P/Z, [Xn]) 4-6cy @ 2p
|
|
13
|
+
* svbfdot_f32 BFDOT (Z.S, Z.H, Z.H) 4cy @ 2p
|
|
14
|
+
* svaddv_f32 FADDV (S, P, Z.S) 6cy @ 1p
|
|
15
|
+
* svdup_f32 DUP (Z.S, #imm) 1cy @ 2p
|
|
16
|
+
* svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy @ 1p
|
|
17
|
+
* svcnth CNTH (Xd) 1cy @ 2p
|
|
18
18
|
*
|
|
19
19
|
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
20
20
|
* and Apple M4+ use 128-bit. Code using svcnth() adapts automatically, but wider vectors
|
|
@@ -50,9 +50,9 @@ NK_PUBLIC void nk_dot_bf16_svebfdot(nk_bf16_t const *a_scalars, nk_bf16_t const
|
|
|
50
50
|
nk_bf16_for_arm_simd_t const *a = (nk_bf16_for_arm_simd_t const *)(a_scalars);
|
|
51
51
|
nk_bf16_for_arm_simd_t const *b = (nk_bf16_for_arm_simd_t const *)(b_scalars);
|
|
52
52
|
do {
|
|
53
|
-
svbool_t
|
|
54
|
-
svbfloat16_t a_bf16x = svld1_bf16(
|
|
55
|
-
svbfloat16_t b_bf16x = svld1_bf16(
|
|
53
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(idx_scalars, count_scalars);
|
|
54
|
+
svbfloat16_t a_bf16x = svld1_bf16(predicate_b16x, a + idx_scalars);
|
|
55
|
+
svbfloat16_t b_bf16x = svld1_bf16(predicate_b16x, b + idx_scalars);
|
|
56
56
|
sum_f32x = svbfdot_f32(sum_f32x, a_bf16x, b_bf16x);
|
|
57
57
|
idx_scalars += svcnth();
|
|
58
58
|
} while (idx_scalars < count_scalars);
|