numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for SVE.
|
|
3
|
+
* @file include/numkong/dot/sve.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_sve_instructions ARM SVE Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput
|
|
12
|
+
* svld1_f32 LD1W (Z.S, P/Z, [Xn]) 4-6cy 2/cy
|
|
13
|
+
* svld2_f32 LD2W (Z.S, P/Z, [Xn]) 6-8cy 1/cy
|
|
14
|
+
* svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy 2/cy
|
|
15
|
+
* svmls_f32_x FMLS (Z.S, P/M, Z.S, Z.S) 4cy 2/cy
|
|
16
|
+
* svaddv_f32 FADDV (S, P, Z.S) 6cy 1/cy
|
|
17
|
+
* svdup_f32 DUP (Z.S, #imm) 1cy 2/cy
|
|
18
|
+
* svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy 1/cy
|
|
19
|
+
* svptrue_b32 PTRUE (P.S, pattern) 1cy 2/cy
|
|
20
|
+
* svcntw CNTW (Xd) 1cy 2/cy
|
|
21
|
+
* svcntd CNTD (Xd) 1cy 2/cy
|
|
22
|
+
* svld1_f64 LD1D (Z.D, P/Z, [Xn]) 4-6cy 2/cy
|
|
23
|
+
* svld2_f64 LD2D (Z.D, P/Z, [Xn]) 6-8cy 1/cy
|
|
24
|
+
* svmla_f64_x FMLA (Z.D, P/M, Z.D, Z.D) 4cy 2/cy
|
|
25
|
+
* svmls_f64_x FMLS (Z.D, P/M, Z.D, Z.D) 4cy 2/cy
|
|
26
|
+
* svaddv_f64 FADDV (D, P, Z.D) 6cy 1/cy
|
|
27
|
+
*
|
|
28
|
+
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
29
|
+
* use 128-bit (Apple M4 exposes SVE only inside SME streaming mode). Code sized by svcntw()/svcntd() adapts automatically, but wider vectors
|
|
30
|
+
* process more elements per iteration with identical latencies.
|
|
31
|
+
*
|
|
32
|
+
* The FADDV horizontal reduction has higher latency (6cy) compared to vertical operations,
|
|
33
|
+
* making it beneficial to accumulate in vector registers and reduce only at the end.
|
|
34
|
+
*/
|
|
35
|
+
#ifndef NK_DOT_SVE_H
|
|
36
|
+
#define NK_DOT_SVE_H
|
|
37
|
+
|
|
38
|
+
#if NK_TARGET_ARM_
|
|
39
|
+
#if NK_TARGET_SVE
|
|
40
|
+
|
|
41
|
+
#include "numkong/types.h" // `nk_f32_t`
|
|
42
|
+
#include "numkong/dot/serial.h" // `nk_u1x8_popcount_`
|
|
43
|
+
|
|
44
|
+
#if defined(__cplusplus)
|
|
45
|
+
extern "C" {
|
|
46
|
+
#endif
|
|
47
|
+
|
|
48
|
+
#if defined(__clang__)
|
|
49
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function)
|
|
50
|
+
#elif defined(__GNUC__)
|
|
51
|
+
#pragma GCC push_options
|
|
52
|
+
#pragma GCC target("arch=armv8.2-a+sve")
|
|
53
|
+
#endif
|
|
54
|
+
|
|
55
|
+
/** @brief Compensated horizontal sum of SVE f64 lanes via TwoSum tree reduction.
|
|
56
|
+
*
|
|
57
|
+
* Uses svtbl to extract the upper half at each tree level. Out-of-range indices
|
|
58
|
+
* return 0 (SVE spec), which is harmless since only the lower half is meaningful
|
|
59
|
+
* after each halving stage.
|
|
60
|
+
*/
|
|
61
|
+
NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64_sve_(svbool_t predicate, svfloat64_t sum, svfloat64_t compensation) {
|
|
62
|
+
// Stage 0: TwoSum merge of sum + compensation (parallel across all active lanes)
|
|
63
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate, sum, compensation);
|
|
64
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate, tentative_sum_f64x, sum);
|
|
65
|
+
svfloat64_t accumulated_error_f64x = svadd_f64_x(
|
|
66
|
+
predicate, svsub_f64_x(predicate, sum, svsub_f64_x(predicate, tentative_sum_f64x, virtual_addend_f64x)),
|
|
67
|
+
svsub_f64_x(predicate, compensation, virtual_addend_f64x));
|
|
68
|
+
|
|
69
|
+
// Tree reduction: TwoSum halving at each level, log2(VL) iterations
|
|
70
|
+
for (unsigned int half = (unsigned int)svcntd() / 2; half > 0; half >>= 1) {
|
|
71
|
+
svuint64_t upper_indices_u64x = svadd_n_u64_x(predicate, svindex_u64(0, 1), half);
|
|
72
|
+
svfloat64_t upper_sum_f64x = svtbl_f64(tentative_sum_f64x, upper_indices_u64x);
|
|
73
|
+
svfloat64_t upper_error_f64x = svtbl_f64(accumulated_error_f64x, upper_indices_u64x);
|
|
74
|
+
// TwoSum: lower_half + upper_half
|
|
75
|
+
svfloat64_t halved_tentative_sum_f64x = svadd_f64_x(predicate, tentative_sum_f64x, upper_sum_f64x);
|
|
76
|
+
svfloat64_t halved_virtual_addend_f64x = svsub_f64_x(predicate, halved_tentative_sum_f64x, tentative_sum_f64x);
|
|
77
|
+
svfloat64_t rounding_error_f64x = svadd_f64_x(
|
|
78
|
+
predicate,
|
|
79
|
+
svsub_f64_x(predicate, tentative_sum_f64x,
|
|
80
|
+
svsub_f64_x(predicate, halved_tentative_sum_f64x, halved_virtual_addend_f64x)),
|
|
81
|
+
svsub_f64_x(predicate, upper_sum_f64x, halved_virtual_addend_f64x));
|
|
82
|
+
tentative_sum_f64x = halved_tentative_sum_f64x;
|
|
83
|
+
accumulated_error_f64x = svadd_f64_x(
|
|
84
|
+
predicate, svadd_f64_x(predicate, accumulated_error_f64x, upper_error_f64x), rounding_error_f64x);
|
|
85
|
+
}
|
|
86
|
+
// Result is in lane 0
|
|
87
|
+
svbool_t predicate_first_f64x = svwhilelt_b64_u64(0u, 1);
|
|
88
|
+
return svlastb_f64(predicate_first_f64x, tentative_sum_f64x) +
|
|
89
|
+
svlastb_f64(predicate_first_f64x, accumulated_error_f64x);
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
NK_PUBLIC void nk_dot_f32_sve(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
93
|
+
nk_f64_t *result) {
|
|
94
|
+
nk_size_t idx_scalars = 0;
|
|
95
|
+
nk_size_t const vector_length = svcntd();
|
|
96
|
+
svfloat64_t ab_f64x = svdup_f64(0.);
|
|
97
|
+
for (; idx_scalars < count_scalars; idx_scalars += vector_length) {
|
|
98
|
+
svbool_t predicate_f64x = svwhilelt_b64_u64(idx_scalars, count_scalars);
|
|
99
|
+
svfloat64_t a_f64x = svcvt_f64_f32_x(
|
|
100
|
+
predicate_f64x, svld1_f32(svwhilelt_b32_u64(idx_scalars, count_scalars), a_scalars + idx_scalars));
|
|
101
|
+
svfloat64_t b_f64x = svcvt_f64_f32_x(
|
|
102
|
+
predicate_f64x, svld1_f32(svwhilelt_b32_u64(idx_scalars, count_scalars), b_scalars + idx_scalars));
|
|
103
|
+
ab_f64x = svmla_f64_x(predicate_f64x, ab_f64x, a_f64x, b_f64x);
|
|
104
|
+
}
|
|
105
|
+
*result = svaddv_f64(svptrue_b64(), ab_f64x);
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
NK_PUBLIC void nk_dot_f32c_sve(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
109
|
+
nk_f64c_t *results) {
|
|
110
|
+
nk_size_t idx_pairs = 0;
|
|
111
|
+
nk_size_t const vector_length = svcntd();
|
|
112
|
+
svfloat64_t ab_real_f64x = svdup_f64(0.);
|
|
113
|
+
svfloat64_t ab_imag_f64x = svdup_f64(0.);
|
|
114
|
+
for (; idx_pairs < count_pairs; idx_pairs += vector_length) {
|
|
115
|
+
svbool_t predicate_f64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
|
|
116
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(idx_pairs, count_pairs);
|
|
117
|
+
svfloat32x2_t a_f32x2 = svld2_f32(predicate_f32x, (nk_f32_t const *)(a_pairs + idx_pairs));
|
|
118
|
+
svfloat32x2_t b_f32x2 = svld2_f32(predicate_f32x, (nk_f32_t const *)(b_pairs + idx_pairs));
|
|
119
|
+
svfloat64_t a_real_f64x = svcvt_f64_f32_x(predicate_f64x, svget2_f32(a_f32x2, 0));
|
|
120
|
+
svfloat64_t a_imag_f64x = svcvt_f64_f32_x(predicate_f64x, svget2_f32(a_f32x2, 1));
|
|
121
|
+
svfloat64_t b_real_f64x = svcvt_f64_f32_x(predicate_f64x, svget2_f32(b_f32x2, 0));
|
|
122
|
+
svfloat64_t b_imag_f64x = svcvt_f64_f32_x(predicate_f64x, svget2_f32(b_f32x2, 1));
|
|
123
|
+
ab_real_f64x = svmla_f64_x(predicate_f64x, ab_real_f64x, a_real_f64x, b_real_f64x);
|
|
124
|
+
ab_real_f64x = svmls_f64_x(predicate_f64x, ab_real_f64x, a_imag_f64x, b_imag_f64x);
|
|
125
|
+
ab_imag_f64x = svmla_f64_x(predicate_f64x, ab_imag_f64x, a_real_f64x, b_imag_f64x);
|
|
126
|
+
ab_imag_f64x = svmla_f64_x(predicate_f64x, ab_imag_f64x, a_imag_f64x, b_real_f64x);
|
|
127
|
+
}
|
|
128
|
+
results->real = svaddv_f64(svptrue_b64(), ab_real_f64x);
|
|
129
|
+
results->imag = svaddv_f64(svptrue_b64(), ab_imag_f64x);
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
NK_PUBLIC void nk_vdot_f32c_sve(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
133
|
+
nk_f64c_t *results) {
|
|
134
|
+
nk_size_t idx_pairs = 0;
|
|
135
|
+
nk_size_t const vector_length = svcntd();
|
|
136
|
+
svfloat64_t ab_real_f64x = svdup_f64(0.);
|
|
137
|
+
svfloat64_t ab_imag_f64x = svdup_f64(0.);
|
|
138
|
+
for (; idx_pairs < count_pairs; idx_pairs += vector_length) {
|
|
139
|
+
svbool_t predicate_f64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
|
|
140
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(idx_pairs, count_pairs);
|
|
141
|
+
svfloat32x2_t a_f32x2 = svld2_f32(predicate_f32x, (nk_f32_t const *)(a_pairs + idx_pairs));
|
|
142
|
+
svfloat32x2_t b_f32x2 = svld2_f32(predicate_f32x, (nk_f32_t const *)(b_pairs + idx_pairs));
|
|
143
|
+
svfloat64_t a_real_f64x = svcvt_f64_f32_x(predicate_f64x, svget2_f32(a_f32x2, 0));
|
|
144
|
+
svfloat64_t a_imag_f64x = svcvt_f64_f32_x(predicate_f64x, svget2_f32(a_f32x2, 1));
|
|
145
|
+
svfloat64_t b_real_f64x = svcvt_f64_f32_x(predicate_f64x, svget2_f32(b_f32x2, 0));
|
|
146
|
+
svfloat64_t b_imag_f64x = svcvt_f64_f32_x(predicate_f64x, svget2_f32(b_f32x2, 1));
|
|
147
|
+
ab_real_f64x = svmla_f64_x(predicate_f64x, ab_real_f64x, a_real_f64x, b_real_f64x);
|
|
148
|
+
ab_real_f64x = svmla_f64_x(predicate_f64x, ab_real_f64x, a_imag_f64x, b_imag_f64x);
|
|
149
|
+
ab_imag_f64x = svmla_f64_x(predicate_f64x, ab_imag_f64x, a_real_f64x, b_imag_f64x);
|
|
150
|
+
ab_imag_f64x = svmls_f64_x(predicate_f64x, ab_imag_f64x, a_imag_f64x, b_real_f64x);
|
|
151
|
+
}
|
|
152
|
+
results->real = svaddv_f64(svptrue_b64(), ab_real_f64x);
|
|
153
|
+
results->imag = svaddv_f64(svptrue_b64(), ab_imag_f64x);
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
/** @brief Compensated f64 dot product using the Dot2 algorithm (Ogita, Rump & Oishi).
 *
 *  Every product is split exactly with TwoProd: `product + product_error == a * b`. The error is
 *  obtained with one fused `svnmls`, which evaluates `a * b - product`, so it is used as-is
 *  without negation. Every addition is split with TwoSum. All rounding errors collect in a
 *  per-lane compensation vector that `nk_dot_stable_sum_f64_sve_` folds back in at the end.
 *
 *  Tail loads are zero-predicated, so padding lanes contribute exact zeros. The arithmetic
 *  therefore runs under an all-true predicate, keeping every lane defined for the all-lane
 *  reduction.
 */
NK_PUBLIC void nk_dot_f64_sve(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
                              nk_f64_t *result) {
    svbool_t const predicate_all_f64x = svptrue_b64();
    nk_size_t const step_scalars = svcntd();
    svfloat64_t sum_f64x = svdup_f64(0.);
    svfloat64_t compensation_f64x = svdup_f64(0.);
    for (nk_size_t idx_scalars = 0; idx_scalars < count_scalars; idx_scalars += step_scalars) {
        svbool_t predicate_f64x = svwhilelt_b64_u64(idx_scalars, count_scalars);
        svfloat64_t a_f64x = svld1_f64(predicate_f64x, a_scalars + idx_scalars);
        svfloat64_t b_f64x = svld1_f64(predicate_f64x, b_scalars + idx_scalars);
        // TwoProd: FNMLS computes a*b - product with a single rounding, i.e. the exact product error
        svfloat64_t product_f64x = svmul_f64_x(predicate_all_f64x, a_f64x, b_f64x);
        svfloat64_t product_error_f64x = svnmls_f64_x(predicate_all_f64x, product_f64x, a_f64x, b_f64x);
        // TwoSum: sum + product == tentative_sum + sum_error exactly
        svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_all_f64x, sum_f64x, product_f64x);
        svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_all_f64x, tentative_sum_f64x, sum_f64x);
        svfloat64_t sum_error_f64x = svadd_f64_x(
            predicate_all_f64x,
            svsub_f64_x(predicate_all_f64x, sum_f64x,
                        svsub_f64_x(predicate_all_f64x, tentative_sum_f64x, virtual_addend_f64x)),
            svsub_f64_x(predicate_all_f64x, product_f64x, virtual_addend_f64x));
        sum_f64x = tentative_sum_f64x;
        compensation_f64x = svadd_f64_x(predicate_all_f64x, compensation_f64x,
                                        svadd_f64_x(predicate_all_f64x, sum_error_f64x, product_error_f64x));
    }
    *result = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_f64x, compensation_f64x);
}
|
|
184
|
+
|
|
185
|
+
NK_PUBLIC void nk_dot_f64c_sve(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
|
|
186
|
+
nk_f64c_t *results) {
|
|
187
|
+
// Dot2 compensated accumulation for complex dot product: (a_real + i*a_imag)(b_real + i*b_imag)
|
|
188
|
+
// real = a_real*b_real - a_imag*b_imag, imag = a_real*b_imag + a_imag*b_real
|
|
189
|
+
nk_size_t idx_pairs = 0;
|
|
190
|
+
svfloat64_t sum_real_f64x = svdup_f64(0.);
|
|
191
|
+
svfloat64_t comp_real_f64x = svdup_f64(0.);
|
|
192
|
+
svfloat64_t sum_imag_f64x = svdup_f64(0.);
|
|
193
|
+
svfloat64_t comp_imag_f64x = svdup_f64(0.);
|
|
194
|
+
do {
|
|
195
|
+
svbool_t predicate_f64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
|
|
196
|
+
svfloat64x2_t a_f64x2 = svld2_f64(predicate_f64x, (nk_f64_t const *)(a_pairs + idx_pairs));
|
|
197
|
+
svfloat64x2_t b_f64x2 = svld2_f64(predicate_f64x, (nk_f64_t const *)(b_pairs + idx_pairs));
|
|
198
|
+
svfloat64_t a_real_f64x = svget2_f64(a_f64x2, 0);
|
|
199
|
+
svfloat64_t a_imag_f64x = svget2_f64(a_f64x2, 1);
|
|
200
|
+
svfloat64_t b_real_f64x = svget2_f64(b_f64x2, 0);
|
|
201
|
+
svfloat64_t b_imag_f64x = svget2_f64(b_f64x2, 1);
|
|
202
|
+
|
|
203
|
+
// TwoProd + TwoSum for real part: sum_real += a_real*b_real
|
|
204
|
+
{
|
|
205
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_real_f64x, b_real_f64x);
|
|
206
|
+
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
207
|
+
predicate_f64x, svnmls_f64_x(predicate_f64x, product_f64x, a_real_f64x, b_real_f64x));
|
|
208
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_real_f64x, product_f64x);
|
|
209
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_real_f64x);
|
|
210
|
+
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
211
|
+
predicate_f64x,
|
|
212
|
+
svsub_f64_x(predicate_f64x, sum_real_f64x,
|
|
213
|
+
svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
214
|
+
svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
|
|
215
|
+
sum_real_f64x = tentative_sum_f64x;
|
|
216
|
+
comp_real_f64x = svadd_f64_x(predicate_f64x, comp_real_f64x,
|
|
217
|
+
svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
|
|
218
|
+
}
|
|
219
|
+
// TwoProd + TwoSum for real part: sum_real -= a_imag*b_imag
|
|
220
|
+
{
|
|
221
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_imag_f64x, b_imag_f64x);
|
|
222
|
+
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
223
|
+
predicate_f64x, svnmls_f64_x(predicate_f64x, product_f64x, a_imag_f64x, b_imag_f64x));
|
|
224
|
+
svfloat64_t neg_product_f64x = svneg_f64_x(predicate_f64x, product_f64x);
|
|
225
|
+
svfloat64_t neg_product_error_f64x = svneg_f64_x(predicate_f64x, product_error_f64x);
|
|
226
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_real_f64x, neg_product_f64x);
|
|
227
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_real_f64x);
|
|
228
|
+
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
229
|
+
predicate_f64x,
|
|
230
|
+
svsub_f64_x(predicate_f64x, sum_real_f64x,
|
|
231
|
+
svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
232
|
+
svsub_f64_x(predicate_f64x, neg_product_f64x, virtual_addend_f64x));
|
|
233
|
+
sum_real_f64x = tentative_sum_f64x;
|
|
234
|
+
comp_real_f64x = svadd_f64_x(predicate_f64x, comp_real_f64x,
|
|
235
|
+
svadd_f64_x(predicate_f64x, sum_error_f64x, neg_product_error_f64x));
|
|
236
|
+
}
|
|
237
|
+
// TwoProd + TwoSum for imaginary part: sum_imag += a_real*b_imag
|
|
238
|
+
{
|
|
239
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_real_f64x, b_imag_f64x);
|
|
240
|
+
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
241
|
+
predicate_f64x, svnmls_f64_x(predicate_f64x, product_f64x, a_real_f64x, b_imag_f64x));
|
|
242
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_imag_f64x, product_f64x);
|
|
243
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_imag_f64x);
|
|
244
|
+
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
245
|
+
predicate_f64x,
|
|
246
|
+
svsub_f64_x(predicate_f64x, sum_imag_f64x,
|
|
247
|
+
svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
248
|
+
svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
|
|
249
|
+
sum_imag_f64x = tentative_sum_f64x;
|
|
250
|
+
comp_imag_f64x = svadd_f64_x(predicate_f64x, comp_imag_f64x,
|
|
251
|
+
svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
|
|
252
|
+
}
|
|
253
|
+
// TwoProd + TwoSum for imaginary part: sum_imag += a_imag*b_real
|
|
254
|
+
{
|
|
255
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_imag_f64x, b_real_f64x);
|
|
256
|
+
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
257
|
+
predicate_f64x, svnmls_f64_x(predicate_f64x, product_f64x, a_imag_f64x, b_real_f64x));
|
|
258
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_imag_f64x, product_f64x);
|
|
259
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_imag_f64x);
|
|
260
|
+
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
261
|
+
predicate_f64x,
|
|
262
|
+
svsub_f64_x(predicate_f64x, sum_imag_f64x,
|
|
263
|
+
svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
264
|
+
svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
|
|
265
|
+
sum_imag_f64x = tentative_sum_f64x;
|
|
266
|
+
comp_imag_f64x = svadd_f64_x(predicate_f64x, comp_imag_f64x,
|
|
267
|
+
svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
|
|
268
|
+
}
|
|
269
|
+
idx_pairs += svcntd();
|
|
270
|
+
} while (idx_pairs < count_pairs);
|
|
271
|
+
svbool_t predicate_all_f64x = svptrue_b64();
|
|
272
|
+
results->real = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_real_f64x, comp_real_f64x);
|
|
273
|
+
results->imag = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_imag_f64x, comp_imag_f64x);
|
|
274
|
+
}
|
|
275
|
+
|
|
276
|
+
NK_PUBLIC void nk_vdot_f64c_sve(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
                                nk_f64c_t *results) {
    // Dot2 compensated conjugate dot product: conj(a) · b = (a_real - i*a_imag)(b_real + i*b_imag)
    // real = a_real*b_real + a_imag*b_imag, imag = a_real*b_imag - a_imag*b_real
    //
    // Every term is folded in with TwoProd + TwoSum. Both error terms use the same convention,
    // "exact value minus rounded value", so they share one compensation accumulator per component:
    //   - TwoProd: `svnmls_f64_x(pg, p, x, y)` (FNMLS) evaluates `x*y - p` with a single rounding, which is
    //     exactly the product error when `p = fl(x*y)`. It is used as-is: negating it would flip its sign
    //     and make the compensation add the product error a second time instead of cancelling it.
    //   - TwoSum: with `t = fl(s + addend)` and `v = t - s`, `(s - (t - v)) + (addend - v)` is `s + addend - t`.
    nk_size_t idx_pairs = 0;
    svfloat64_t sum_real_f64x = svdup_f64(0.);
    svfloat64_t comp_real_f64x = svdup_f64(0.);
    svfloat64_t sum_imag_f64x = svdup_f64(0.);
    svfloat64_t comp_imag_f64x = svdup_f64(0.);
    do {
        svbool_t predicate_f64x = svwhilelt_b64_u64(idx_pairs, count_pairs);
        // LD2D de-interleaves (real, imag) pairs; inactive tail lanes are loaded as zero.
        svfloat64x2_t a_f64x2 = svld2_f64(predicate_f64x, (nk_f64_t const *)(a_pairs + idx_pairs));
        svfloat64x2_t b_f64x2 = svld2_f64(predicate_f64x, (nk_f64_t const *)(b_pairs + idx_pairs));
        svfloat64_t a_real_f64x = svget2_f64(a_f64x2, 0);
        svfloat64_t a_imag_f64x = svget2_f64(a_f64x2, 1);
        svfloat64_t b_real_f64x = svget2_f64(b_f64x2, 0);
        svfloat64_t b_imag_f64x = svget2_f64(b_f64x2, 1);

        // TwoProd + TwoSum for real part: sum_real += a_real*b_real
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_real_f64x, b_real_f64x);
            svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_real_f64x, b_real_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_real_f64x, product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_real_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_real_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
            sum_real_f64x = tentative_sum_f64x;
            comp_real_f64x = svadd_f64_x(predicate_f64x, comp_real_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        }
        // TwoProd + TwoSum for real part: sum_real += a_imag*b_imag (conjugate: + not -)
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_imag_f64x, b_imag_f64x);
            svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_imag_f64x, b_imag_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_real_f64x, product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_real_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_real_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
            sum_real_f64x = tentative_sum_f64x;
            comp_real_f64x = svadd_f64_x(predicate_f64x, comp_real_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        }
        // TwoProd + TwoSum for imaginary part: sum_imag += a_real*b_imag
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_real_f64x, b_imag_f64x);
            svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_real_f64x, b_imag_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_imag_f64x, product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_imag_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_imag_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
            sum_imag_f64x = tentative_sum_f64x;
            comp_imag_f64x = svadd_f64_x(predicate_f64x, comp_imag_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        }
        // TwoProd + TwoSum for imaginary part: sum_imag -= a_imag*b_real (conjugate: - not +)
        {
            svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_imag_f64x, b_real_f64x);
            svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_imag_f64x, b_real_f64x);
            // -(x*y) == (-product) + (-product_error) exactly, so negating both halves keeps TwoProd exact.
            svfloat64_t neg_product_f64x = svneg_f64_x(predicate_f64x, product_f64x);
            svfloat64_t neg_product_error_f64x = svneg_f64_x(predicate_f64x, product_error_f64x);
            svfloat64_t tentative_sum_f64x = svadd_f64_x(predicate_f64x, sum_imag_f64x, neg_product_f64x);
            svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, sum_imag_f64x);
            svfloat64_t sum_error_f64x = svadd_f64_x(
                predicate_f64x,
                svsub_f64_x(predicate_f64x, sum_imag_f64x,
                            svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
                svsub_f64_x(predicate_f64x, neg_product_f64x, virtual_addend_f64x));
            sum_imag_f64x = tentative_sum_f64x;
            comp_imag_f64x = svadd_f64_x(predicate_f64x, comp_imag_f64x,
                                         svadd_f64_x(predicate_f64x, sum_error_f64x, neg_product_error_f64x));
        }
        idx_pairs += svcntd();
    } while (idx_pairs < count_pairs);
    svbool_t predicate_all_f64x = svptrue_b64();
    results->real = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_real_f64x, comp_real_f64x);
    results->imag = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_imag_f64x, comp_imag_f64x);
}
|
|
366
|
+
|
|
367
|
+
#if defined(__clang__)
|
|
368
|
+
#pragma clang attribute pop
|
|
369
|
+
#elif defined(__GNUC__)
|
|
370
|
+
#pragma GCC pop_options
|
|
371
|
+
#endif
|
|
372
|
+
|
|
373
|
+
#if defined(__cplusplus)
|
|
374
|
+
} // extern "C"
|
|
375
|
+
#endif
|
|
376
|
+
|
|
377
|
+
#endif // NK_TARGET_SVE
|
|
378
|
+
#endif // NK_TARGET_ARM_
|
|
379
|
+
#endif // NK_DOT_SVE_H
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for SVE BF16.
|
|
3
|
+
* @file include/numkong/dot/svebfdot.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 16, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_svebfdot_instructions ARM SVE+BF16 Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput
|
|
12
|
+
* svld1_bf16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
|
|
13
|
+
* svbfdot_f32 BFDOT (Z.S, Z.H, Z.H) 4cy 2/cy
|
|
14
|
+
* svaddv_f32 FADDV (S, P, Z.S) 6cy 1/cy
|
|
15
|
+
* svdup_f32 DUP (Z.S, #imm) 1cy 2/cy
|
|
16
|
+
* svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy 1/cy
|
|
17
|
+
* svcnth CNTH (Xd) 1cy 2/cy
|
|
18
|
+
*
|
|
19
|
+
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
20
|
+
* and Apple M4+ use 128-bit. Code using svcnth() adapts automatically, but wider vectors
|
|
21
|
+
* process more elements per iteration with identical latencies.
|
|
22
|
+
*
|
|
23
|
+
* The BFDOT instruction fuses two BF16 multiplications with FP32 accumulation per lane,
|
|
24
|
+
* providing 4x the throughput of convert-then-FMA sequences. Each BFDOT processes
|
|
25
|
+
* pairs of BF16 values, accumulating directly into FP32 without explicit conversion.
|
|
26
|
+
*/
|
|
27
|
+
#ifndef NK_DOT_SVEBFDOT_H
|
|
28
|
+
#define NK_DOT_SVEBFDOT_H
|
|
29
|
+
|
|
30
|
+
#if NK_TARGET_ARM_
|
|
31
|
+
#if NK_TARGET_SVEBFDOT
|
|
32
|
+
|
|
33
|
+
#include "numkong/types.h"
|
|
34
|
+
|
|
35
|
+
#if defined(__cplusplus)
|
|
36
|
+
extern "C" {
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
#if defined(__clang__)
|
|
40
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+bf16"))), apply_to = function)
|
|
41
|
+
#elif defined(__GNUC__)
|
|
42
|
+
#pragma GCC push_options
|
|
43
|
+
#pragma GCC target("arch=armv8.2-a+sve+bf16")
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
NK_PUBLIC void nk_dot_bf16_svebfdot(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
47
|
+
nk_f32_t *result) {
|
|
48
|
+
nk_size_t idx_scalars = 0;
|
|
49
|
+
svfloat32_t sum_f32x = svdup_f32(0);
|
|
50
|
+
nk_bf16_for_arm_simd_t const *a = (nk_bf16_for_arm_simd_t const *)(a_scalars);
|
|
51
|
+
nk_bf16_for_arm_simd_t const *b = (nk_bf16_for_arm_simd_t const *)(b_scalars);
|
|
52
|
+
do {
|
|
53
|
+
svbool_t predicate_bf16x = svwhilelt_b16_u64(idx_scalars, count_scalars);
|
|
54
|
+
svbfloat16_t a_bf16x = svld1_bf16(predicate_bf16x, a + idx_scalars);
|
|
55
|
+
svbfloat16_t b_bf16x = svld1_bf16(predicate_bf16x, b + idx_scalars);
|
|
56
|
+
sum_f32x = svbfdot_f32(sum_f32x, a_bf16x, b_bf16x);
|
|
57
|
+
idx_scalars += svcnth();
|
|
58
|
+
} while (idx_scalars < count_scalars);
|
|
59
|
+
*result = svaddv_f32(svptrue_b32(), sum_f32x);
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
#if defined(__clang__)
|
|
63
|
+
#pragma clang attribute pop
|
|
64
|
+
#elif defined(__GNUC__)
|
|
65
|
+
#pragma GCC pop_options
|
|
66
|
+
#endif
|
|
67
|
+
|
|
68
|
+
#if defined(__cplusplus)
|
|
69
|
+
} // extern "C"
|
|
70
|
+
#endif
|
|
71
|
+
|
|
72
|
+
#endif // NK_TARGET_SVEBFDOT
|
|
73
|
+
#endif // NK_TARGET_ARM_
|
|
74
|
+
#endif // NK_DOT_SVEBFDOT_H
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for SVE FP16.
|
|
3
|
+
* @file include/numkong/dot/svehalf.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_svehalf_instructions ARM SVE+FP16 Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput
|
|
12
|
+
* svld1_f16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
|
|
13
|
+
* svld2_f16 LD2H (Z.H, P/Z, [Xn]) 6-8cy 1/cy
|
|
14
|
+
* svmla_f16_x FMLA (Z.H, P/M, Z.H, Z.H) 4cy 2/cy
|
|
15
|
+
* svmls_f16_x FMLS (Z.H, P/M, Z.H, Z.H) 4cy 2/cy
|
|
16
|
+
* svaddv_f16 FADDV (H, P, Z.H) 6cy 1/cy
|
|
17
|
+
* svdup_f16 DUP (Z.H, #imm) 1cy 2/cy
|
|
18
|
+
* svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy 1/cy
|
|
19
|
+
* svptrue_b16 PTRUE (P.H, pattern) 1cy 2/cy
|
|
20
|
+
* svcnth CNTH (Xd) 1cy 2/cy
|
|
21
|
+
*
|
|
22
|
+
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
23
|
+
* and Apple M4+ use 128-bit. Code using svcntw() adapts automatically, but wider vectors
|
|
24
|
+
* process more elements per iteration with identical latencies.
|
|
25
|
+
*
|
|
26
|
+
* These kernels widen every FP16 input to FP32 (FCVT) and multiply-accumulate in FP32, trading part of the
* native FP16 throughput for FP32-precision products and sums. One FADDV per accumulator ends the kernel.
|
|
28
|
+
*/
|
|
29
|
+
#ifndef NK_DOT_SVEHALF_H
|
|
30
|
+
#define NK_DOT_SVEHALF_H
|
|
31
|
+
|
|
32
|
+
#if NK_TARGET_ARM_
|
|
33
|
+
#if NK_TARGET_SVEHALF
|
|
34
|
+
|
|
35
|
+
#include "numkong/types.h" // `nk_f16_t`
|
|
36
|
+
#include "numkong/dot/serial.h" // `nk_u1x8_popcount_`
|
|
37
|
+
|
|
38
|
+
#if defined(__cplusplus)
|
|
39
|
+
extern "C" {
|
|
40
|
+
#endif
|
|
41
|
+
|
|
42
|
+
#if defined(__clang__)
|
|
43
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+fp16"))), apply_to = function)
|
|
44
|
+
#elif defined(__GNUC__)
|
|
45
|
+
#pragma GCC push_options
|
|
46
|
+
#pragma GCC target("arch=armv8.2-a+sve+fp16")
|
|
47
|
+
#endif
|
|
48
|
+
|
|
49
|
+
NK_PUBLIC void nk_dot_f16_svehalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
50
|
+
nk_f32_t *result) {
|
|
51
|
+
nk_size_t idx_scalars = 0;
|
|
52
|
+
svfloat32_t ab_f32x = svdup_f32(0);
|
|
53
|
+
do {
|
|
54
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(idx_scalars, count_scalars);
|
|
55
|
+
svfloat16_t a_f16x = svld1_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(a_scalars) + idx_scalars);
|
|
56
|
+
svfloat16_t b_f16x = svld1_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(b_scalars) + idx_scalars);
|
|
57
|
+
svfloat32_t a_f32x = svcvt_f32_f16_x(predicate_f32x, a_f16x);
|
|
58
|
+
svfloat32_t b_f32x = svcvt_f32_f16_x(predicate_f32x, b_f16x);
|
|
59
|
+
ab_f32x = svmla_f32_x(predicate_f32x, ab_f32x, a_f32x, b_f32x);
|
|
60
|
+
idx_scalars += svcntw();
|
|
61
|
+
} while (idx_scalars < count_scalars);
|
|
62
|
+
*result = svaddv_f32(svptrue_b32(), ab_f32x);
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
NK_PUBLIC void nk_dot_f16c_svehalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
|
|
66
|
+
nk_f32c_t *results) {
|
|
67
|
+
nk_size_t idx_scalars = 0;
|
|
68
|
+
svfloat32_t ab_real_f32x = svdup_f32(0);
|
|
69
|
+
svfloat32_t ab_imag_f32x = svdup_f32(0);
|
|
70
|
+
do {
|
|
71
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(idx_scalars, count_pairs);
|
|
72
|
+
svfloat16x2_t a_f16x2 = svld2_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(a_pairs) + idx_scalars * 2);
|
|
73
|
+
svfloat16x2_t b_f16x2 = svld2_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(b_pairs) + idx_scalars * 2);
|
|
74
|
+
svfloat32_t a_real_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(a_f16x2, 0));
|
|
75
|
+
svfloat32_t a_imag_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(a_f16x2, 1));
|
|
76
|
+
svfloat32_t b_real_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(b_f16x2, 0));
|
|
77
|
+
svfloat32_t b_imag_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(b_f16x2, 1));
|
|
78
|
+
ab_real_f32x = svmla_f32_x(predicate_f32x, ab_real_f32x, a_real_f32x, b_real_f32x);
|
|
79
|
+
ab_real_f32x = svmls_f32_x(predicate_f32x, ab_real_f32x, a_imag_f32x, b_imag_f32x);
|
|
80
|
+
ab_imag_f32x = svmla_f32_x(predicate_f32x, ab_imag_f32x, a_real_f32x, b_imag_f32x);
|
|
81
|
+
ab_imag_f32x = svmla_f32_x(predicate_f32x, ab_imag_f32x, a_imag_f32x, b_real_f32x);
|
|
82
|
+
idx_scalars += svcntw();
|
|
83
|
+
} while (idx_scalars < count_pairs);
|
|
84
|
+
results->real = svaddv_f32(svptrue_b32(), ab_real_f32x);
|
|
85
|
+
results->imag = svaddv_f32(svptrue_b32(), ab_imag_f32x);
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
NK_PUBLIC void nk_vdot_f16c_svehalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
|
|
89
|
+
nk_f32c_t *results) {
|
|
90
|
+
nk_size_t idx_scalars = 0;
|
|
91
|
+
svfloat32_t ab_real_f32x = svdup_f32(0);
|
|
92
|
+
svfloat32_t ab_imag_f32x = svdup_f32(0);
|
|
93
|
+
do {
|
|
94
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(idx_scalars, count_pairs);
|
|
95
|
+
svfloat16x2_t a_f16x2 = svld2_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(a_pairs) + idx_scalars * 2);
|
|
96
|
+
svfloat16x2_t b_f16x2 = svld2_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(b_pairs) + idx_scalars * 2);
|
|
97
|
+
svfloat32_t a_real_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(a_f16x2, 0));
|
|
98
|
+
svfloat32_t a_imag_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(a_f16x2, 1));
|
|
99
|
+
svfloat32_t b_real_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(b_f16x2, 0));
|
|
100
|
+
svfloat32_t b_imag_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(b_f16x2, 1));
|
|
101
|
+
ab_real_f32x = svmla_f32_x(predicate_f32x, ab_real_f32x, a_real_f32x, b_real_f32x);
|
|
102
|
+
ab_real_f32x = svmla_f32_x(predicate_f32x, ab_real_f32x, a_imag_f32x, b_imag_f32x);
|
|
103
|
+
ab_imag_f32x = svmla_f32_x(predicate_f32x, ab_imag_f32x, a_real_f32x, b_imag_f32x);
|
|
104
|
+
ab_imag_f32x = svmls_f32_x(predicate_f32x, ab_imag_f32x, a_imag_f32x, b_real_f32x);
|
|
105
|
+
idx_scalars += svcntw();
|
|
106
|
+
} while (idx_scalars < count_pairs);
|
|
107
|
+
results->real = svaddv_f32(svptrue_b32(), ab_real_f32x);
|
|
108
|
+
results->imag = svaddv_f32(svptrue_b32(), ab_imag_f32x);
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
#if defined(__clang__)
|
|
112
|
+
#pragma clang attribute pop
|
|
113
|
+
#elif defined(__GNUC__)
|
|
114
|
+
#pragma GCC pop_options
|
|
115
|
+
#endif
|
|
116
|
+
|
|
117
|
+
#if defined(__cplusplus)
|
|
118
|
+
} // extern "C"
|
|
119
|
+
#endif
|
|
120
|
+
|
|
121
|
+
#endif // NK_TARGET_SVEHALF
|
|
122
|
+
#endif // NK_TARGET_ARM_
|
|
123
|
+
#endif // NK_DOT_SVEHALF_H
|