numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,506 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Curved Space Similarity for SME F64.
|
|
3
|
+
* @file include/numkong/curved/smef64.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 14, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/curved.h
|
|
8
|
+
*
|
|
9
|
+
* Implements bilinear forms and Mahalanobis distance using ARM SME:
|
|
10
|
+
* - f32 inputs: GEMV via f64 FMOPA (widening load f32→f64, exact accumulation)
|
|
11
|
+
* - f64 inputs: row-by-row streaming SVE with Dot2 (Ogita-Rump-Oishi 2005)
|
|
12
|
+
* - f32c complex: 4-FMOPA complex GEMV with FMOPS for the cᵢₘ×bᵢₘ subtraction
|
|
13
|
+
* - f64c complex: interleaved Dot2 with permute + deferred XOR sign-flip
|
|
14
|
+
*
|
|
15
|
+
* Complex number history — approaches tried and abandoned:
|
|
16
|
+
*
|
|
17
|
+
* 1. Ozaki 3-way FMOPA (f64c): Split each f64 into 19+17+17 bit mantissa parts,
|
|
18
|
+
* compute 7 ZA tiles of FMOPA/FMOPS per inner step (24 tile ops total).
|
|
19
|
+
* Abandoned: the 3× split + 4× complex cross-terms = 12× tile ops vs real,
|
|
20
|
+
* with staging overhead dominating at GEMV (not GEMM) granularity.
|
|
21
|
+
*
|
|
22
|
+
* 2. Deinterleaved 4-accumulator SVE Dot2 (f64c): Separate real/imaginary via
|
|
23
|
+
* UZP1/UZP2, run 4 independent Dot2 chains (rr, ii, ri, ir). Theoretically
|
|
24
|
+
* matches the serial kernel's arithmetic intensity, but UZP on SVE requires
|
|
25
|
+
* loading 2 vectors to produce 1 full-width deinterleaved vector, and the
|
|
26
|
+
* total ops/byte is identical to the interleaved approach (~28 SVE ops/iter).
|
|
27
|
+
*
|
|
28
|
+
* 3. Simple (non-Ozaki) f64 FMOPA for complex: Would give ~5-10 GFLOP/s but
|
|
29
|
+
* drops Dot2 compensation entirely (naive f64 accumulation, ~BLAS precision).
|
|
30
|
+
* Not implemented because precision is a core requirement for f64 kernels.
|
|
31
|
+
*
|
|
32
|
+
* The current interleaved Dot2 approach (2 accumulators + svtbl swap + XOR sign
|
|
33
|
+
* flip) is the best balance found: ~15 SVE ops/iter vs ~28 for deinterleaved,
|
|
34
|
+
* with identical Dot2 precision. The ~1.5 GFLOP/s throughput is limited by the
|
|
35
|
+
* SME coprocessor's slow per-instruction pipeline — the serial version achieves
|
|
36
|
+
* ~2.2 GFLOP/s despite using software Dekker FMA (~20 ops/TwoProd vs SVE's 3)
|
|
37
|
+
* because it runs on the faster main core.
|
|
38
|
+
*
|
|
39
|
+
* On Apple M4, SVE instructions are only available inside SME streaming mode.
|
|
40
|
+
* Functions using SVE intrinsics are marked `__arm_locally_streaming` in a
|
|
41
|
+
* `_streaming_` helper; the NK_PUBLIC entry point is a thin non-streaming
|
|
42
|
+
* wrapper. NEON intrinsics cannot be called from streaming mode, so Mahalanobis
|
|
43
|
+
* functions split into a streaming helper (SVE) and a non-streaming wrapper
|
|
44
|
+
* (NEON sqrt).
|
|
45
|
+
*
|
|
46
|
+
* @see Ogita, T., Rump, S.M., Oishi, S. (2005). "Accurate Sum and Dot Product"
|
|
47
|
+
*/
|
|
48
|
+
#ifndef NK_CURVED_SMEF64_H
|
|
49
|
+
#define NK_CURVED_SMEF64_H
|
|
50
|
+
|
|
51
|
+
#if NK_TARGET_ARM_
|
|
52
|
+
#if NK_TARGET_SMEF64
|
|
53
|
+
|
|
54
|
+
#include "numkong/types.h"
|
|
55
|
+
#include "numkong/spatial/neon.h" // `nk_f64_sqrt_neon`
|
|
56
|
+
#include "numkong/dots/sme.h" // nk_sme_zero_za64_tile_0_, etc. (for f32 FMOPA)
|
|
57
|
+
#include "numkong/curved/serial.h" // `nk_bilinear_f64_serial`, etc.
|
|
58
|
+
|
|
59
|
+
#if defined(__cplusplus)
|
|
60
|
+
extern "C" {
|
|
61
|
+
#endif
|
|
62
|
+
|
|
63
|
+
#if defined(__clang__)
|
|
64
|
+
#pragma clang attribute push(__attribute__((target("sme,sve,sme-f64f64"))), apply_to = function)
|
|
65
|
+
#elif defined(__GNUC__)
|
|
66
|
+
#pragma GCC push_options
|
|
67
|
+
#pragma GCC target("+sme+sme-f64f64")
|
|
68
|
+
#endif
|
|
69
|
+
|
|
70
|
+
/**
 * @brief SVE Dot2 accumulator: sum += a × b with error compensation.
 * Uses TwoProd (svneg+svnmls) and TwoSum error-free transformations.
 *
 * Implements one step of the Ogita-Rump-Oishi Dot2 algorithm: the product and
 * the running sum are each split into a rounded result plus an exactly
 * representable error term; the errors are folded into @p comp so that
 * `*sum + *comp` carries roughly twice the working precision.
 *
 * @param predicate_f64x  Lane predicate; inactive lanes of `_x` ops are left
 *                        unspecified, so callers must reduce under the same
 *                        (or a narrower) predicate.
 * @param sum   In/out running high-order sum (vector accumulator).
 * @param comp  In/out compensation accumulating the rounding errors.
 * @param a_f64x, b_f64x  Factors to multiply-accumulate.
 */
NK_PUBLIC void nk_dot2_f64_sve_accumulate_(svbool_t predicate_f64x, svfloat64_t *sum, svfloat64_t *comp,
                                           svfloat64_t a_f64x, svfloat64_t b_f64x) NK_STREAMING_COMPATIBLE_ {
    // TwoProd: p = fl(a*b); error = fma(a, b, -p) recovered via negated NMLS,
    // i.e. product_error = a*b - p exactly (FMA computes the unrounded term).
    svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_f64x, b_f64x);
    svfloat64_t product_error_f64x = svneg_f64_x(predicate_f64x,
                                                 svnmls_f64_x(predicate_f64x, product_f64x, a_f64x, b_f64x));
    // TwoSum: s = fl(sum + p), then reconstruct the exact rounding error of the
    // addition from the recovered addend (Knuth's branch-free variant).
    svfloat64_t running_sum_f64x = svadd_f64_x(predicate_f64x, *sum, product_f64x);
    svfloat64_t recovered_addend_f64x = svsub_f64_x(predicate_f64x, running_sum_f64x, *sum);
    svfloat64_t sum_error_f64x = svadd_f64_x(
        predicate_f64x,
        svsub_f64_x(predicate_f64x, *sum, svsub_f64_x(predicate_f64x, running_sum_f64x, recovered_addend_f64x)),
        svsub_f64_x(predicate_f64x, product_f64x, recovered_addend_f64x));
    // Commit: high part advances, both error terms flow into the compensation.
    *sum = running_sum_f64x;
    *comp = svadd_f64_x(predicate_f64x, *comp, svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
}
|
|
88
|
+
|
|
89
|
+
/**
|
|
90
|
+
* @brief f32 bilinear: GEMV via FMOPA (widening f32→f64, exact accumulation).
|
|
91
|
+
* ZA0.D = C staging, ZA1.D = GEMV accumulator.
|
|
92
|
+
*/
|
|
93
|
+
__arm_locally_streaming __arm_new("za") static void nk_bilinear_f32_smef64_streaming_(nk_f32_t const *a,
|
|
94
|
+
nk_f32_t const *b,
|
|
95
|
+
nk_f32_t const *c, nk_size_t n,
|
|
96
|
+
nk_f64_t *result) {
|
|
97
|
+
svbool_t predicate_body_f64x = svptrue_b64();
|
|
98
|
+
nk_size_t tile_dimension = svcntd();
|
|
99
|
+
nk_f64_t outer_sum_f64 = 0.0;
|
|
100
|
+
|
|
101
|
+
for (nk_size_t row = 0; row < n; row += tile_dimension) {
|
|
102
|
+
nk_size_t rows_remaining = (row + tile_dimension <= n) ? tile_dimension : (n - row);
|
|
103
|
+
svbool_t row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);
|
|
104
|
+
|
|
105
|
+
svzero_mask_za(nk_sme_zero_za64_tile_1_);
|
|
106
|
+
|
|
107
|
+
for (nk_size_t j = 0; j < n; j += tile_dimension) {
|
|
108
|
+
nk_size_t batch_size = (j + tile_dimension <= n) ? tile_dimension : (n - j);
|
|
109
|
+
svbool_t batch_predicate_f64x = svwhilelt_b64_u64(0u, batch_size);
|
|
110
|
+
|
|
111
|
+
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
112
|
+
for (nk_size_t r = 0; r < rows_remaining; r++) {
|
|
113
|
+
svfloat64_t c_row_f64x = svcvt_f64_f32_x(
|
|
114
|
+
batch_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(
|
|
115
|
+
batch_predicate_f64x, (nk_u32_t const *)(c + (row + r) * n + j))));
|
|
116
|
+
svwrite_hor_za64_f64_m(0, r, batch_predicate_f64x, c_row_f64x);
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
for (nk_size_t k = 0; k < batch_size; k++) {
|
|
120
|
+
svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, k);
|
|
121
|
+
svmopa_za64_f64_m(1, row_predicate_f64x, row_predicate_f64x, c_col_f64x, svdup_f64((nk_f64_t)b[j + k]));
|
|
122
|
+
}
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 1, 0);
|
|
126
|
+
svfloat64_t a_f64x = svcvt_f64_f32_x(
|
|
127
|
+
row_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_f64x, (nk_u32_t const *)(a + row))));
|
|
128
|
+
outer_sum_f64 += svaddv_f64(predicate_body_f64x, svmul_f64_x(row_predicate_f64x, a_f64x, v_f64x));
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
*result = outer_sum_f64;
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
/**
 * @brief Public entry for the f32 bilinear form aᵀ·C·b on SME (f64 FMOPA path).
 *
 * Thin non-streaming wrapper: the `__arm_locally_streaming` helper owns the
 * streaming-mode transition, so callers need no SME-awareness.
 *
 * @param a, b  Vectors of @p n f32 scalars.
 * @param c     Row-major n×n f32 matrix.
 * @param result  Out: the bilinear form, accumulated in f64.
 */
NK_PUBLIC void nk_bilinear_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
                                      nk_f64_t *result) {
    nk_bilinear_f32_smef64_streaming_(a, b, c, n, result);
}
|
|
138
|
+
|
|
139
|
+
/**
|
|
140
|
+
* @brief f32 Mahalanobis: GEMV v = C×d via FMOPA, where d = a − b (exact in f64).
|
|
141
|
+
* ZA0.D = C staging, ZA1.D = GEMV accumulator.
|
|
142
|
+
*/
|
|
143
|
+
__arm_locally_streaming __arm_new("za") static inline nk_f64_t
|
|
144
|
+
nk_mahalanobis_f32_smef64_streaming_(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n) {
|
|
145
|
+
|
|
146
|
+
svbool_t predicate_body_f64x = svptrue_b64();
|
|
147
|
+
nk_size_t tile_dimension = svcntd();
|
|
148
|
+
nk_f64_t outer_sum_f64 = 0.0;
|
|
149
|
+
|
|
150
|
+
for (nk_size_t row = 0; row < n; row += tile_dimension) {
|
|
151
|
+
nk_size_t rows_remaining = (row + tile_dimension <= n) ? tile_dimension : (n - row);
|
|
152
|
+
svbool_t row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);
|
|
153
|
+
|
|
154
|
+
svzero_mask_za(nk_sme_zero_za64_tile_1_);
|
|
155
|
+
|
|
156
|
+
for (nk_size_t j = 0; j < n; j += tile_dimension) {
|
|
157
|
+
nk_size_t batch_size = (j + tile_dimension <= n) ? tile_dimension : (n - j);
|
|
158
|
+
svbool_t batch_predicate_f64x = svwhilelt_b64_u64(0u, batch_size);
|
|
159
|
+
|
|
160
|
+
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
161
|
+
for (nk_size_t r = 0; r < rows_remaining; r++) {
|
|
162
|
+
svfloat64_t c_row_f64x = svcvt_f64_f32_x(
|
|
163
|
+
batch_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(
|
|
164
|
+
batch_predicate_f64x, (nk_u32_t const *)(c + (row + r) * n + j))));
|
|
165
|
+
svwrite_hor_za64_f64_m(0, r, batch_predicate_f64x, c_row_f64x);
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
for (nk_size_t k = 0; k < batch_size; k++) {
|
|
169
|
+
svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, k);
|
|
170
|
+
nk_f64_t d_k = (nk_f64_t)a[j + k] - (nk_f64_t)b[j + k];
|
|
171
|
+
svmopa_za64_f64_m(1, row_predicate_f64x, row_predicate_f64x, c_col_f64x, svdup_f64(d_k));
|
|
172
|
+
}
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 1, 0);
|
|
176
|
+
svfloat64_t a_f64x = svcvt_f64_f32_x(
|
|
177
|
+
row_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_f64x, (nk_u32_t const *)(a + row))));
|
|
178
|
+
svfloat64_t b_f64x = svcvt_f64_f32_x(
|
|
179
|
+
row_predicate_f64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_f64x, (nk_u32_t const *)(b + row))));
|
|
180
|
+
svfloat64_t d_f64x = svsub_f64_x(row_predicate_f64x, a_f64x, b_f64x);
|
|
181
|
+
outer_sum_f64 += svaddv_f64(predicate_body_f64x, svmul_f64_x(row_predicate_f64x, d_f64x, v_f64x));
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
return outer_sum_f64;
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
188
|
+
nk_f64_t *result) {
|
|
189
|
+
nk_f64_t quadratic = nk_mahalanobis_f32_smef64_streaming_(a, b, c, n);
|
|
190
|
+
*result = nk_f64_sqrt_neon(quadratic > 0 ? quadratic : 0);
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
/**
 * @brief f64 bilinear: row-by-row streaming SVE with Dot2 compensation.
 * 4-row fast path shares b_f64x loads; 1-row tail for remainder.
 *
 * Computes aᵀ·C·b as n dot products row(C)·b, each carried with Dot2
 * (sum + compensation) precision, then folds a[row]·(row(C)·b) into a scalar
 * Dot2 accumulator. No ZA tiles are used — this path relies on plain
 * streaming-mode SVE.
 */
__arm_locally_streaming static void nk_bilinear_f64_smef64_streaming_(nk_f64_t const *a, nk_f64_t const *b,
                                                                      nk_f64_t const *c, nk_size_t n,
                                                                      nk_f64_t *result) {
    svbool_t predicate_all_f64x = svptrue_b64();
    // Scalar Dot2 state for the outer reduction over rows.
    nk_f64_t outer_sum = 0.0, outer_comp = 0.0;
    nk_size_t row = 0;

    // 4-row fast path: share b_f64x load across 4 rows
    for (; row + 4 <= n; row += 4) {
        nk_f64_t a0 = a[row + 0], a1 = a[row + 1], a2 = a[row + 2], a3 = a[row + 3];
        // Independent Dot2 chains per row keep the FMA pipelines busy.
        svfloat64_t sum_0_f64x = svdup_f64(0), compensation_0_f64x = svdup_f64(0);
        svfloat64_t sum_1_f64x = svdup_f64(0), compensation_1_f64x = svdup_f64(0);
        svfloat64_t sum_2_f64x = svdup_f64(0), compensation_2_f64x = svdup_f64(0);
        svfloat64_t sum_3_f64x = svdup_f64(0), compensation_3_f64x = svdup_f64(0);
        nk_size_t j = 0;
        svbool_t predicate_f64x = svwhilelt_b64(j, n);

        while (svptest_first(predicate_all_f64x, predicate_f64x)) {
            // One b load is amortized over four C-row accumulations.
            svfloat64_t b_f64x = svld1_f64(predicate_f64x, b + j);
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_0_f64x, &compensation_0_f64x,
                                        svld1_f64(predicate_f64x, c + (row + 0) * n + j), b_f64x);
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_1_f64x, &compensation_1_f64x,
                                        svld1_f64(predicate_f64x, c + (row + 1) * n + j), b_f64x);
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_2_f64x, &compensation_2_f64x,
                                        svld1_f64(predicate_f64x, c + (row + 2) * n + j), b_f64x);
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_3_f64x, &compensation_3_f64x,
                                        svld1_f64(predicate_f64x, c + (row + 3) * n + j), b_f64x);
            j += svcntd();
            predicate_f64x = svwhilelt_b64(j, n);
        }

        // Horizontal reductions: sum + compensation recovers the Dot2 result.
        // (Accumulators start at zero and svld1 zeroes inactive lanes, so the
        // all-true reduction is safe here.)
        nk_f64_t cb[4] = {
            svaddv_f64(predicate_all_f64x, sum_0_f64x) + svaddv_f64(predicate_all_f64x, compensation_0_f64x),
            svaddv_f64(predicate_all_f64x, sum_1_f64x) + svaddv_f64(predicate_all_f64x, compensation_1_f64x),
            svaddv_f64(predicate_all_f64x, sum_2_f64x) + svaddv_f64(predicate_all_f64x, compensation_2_f64x),
            svaddv_f64(predicate_all_f64x, sum_3_f64x) + svaddv_f64(predicate_all_f64x, compensation_3_f64x),
        };
        nk_f64_t av[4] = {a0, a1, a2, a3};
        for (int r = 0; r < 4; ++r) nk_f64_dot2_(&outer_sum, &outer_comp, av[r], cb[r]);
    }

    // 1-row tail
    for (; row < n; ++row) {
        svfloat64_t sum_f64x = svdup_f64(0.0), compensation_f64x = svdup_f64(0.0);
        nk_size_t j = 0;
        svbool_t predicate_f64x = svwhilelt_b64(j, n);

        while (svptest_first(predicate_all_f64x, predicate_f64x)) {
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_f64x, &compensation_f64x,
                                        svld1_f64(predicate_f64x, c + row * n + j), svld1_f64(predicate_f64x, b + j));
            j += svcntd();
            predicate_f64x = svwhilelt_b64(j, n);
        }

        nk_f64_t cb_j = svaddv_f64(predicate_all_f64x, sum_f64x) + svaddv_f64(predicate_all_f64x, compensation_f64x);
        nk_f64_dot2_(&outer_sum, &outer_comp, a[row], cb_j);
    }

    *result = outer_sum + outer_comp;
}
|
|
257
|
+
|
|
258
|
+
/**
 * @brief Public entry for the f64 bilinear form aᵀ·C·b on SME (streaming SVE Dot2 path).
 *
 * Thin non-streaming wrapper: the `__arm_locally_streaming` helper owns the
 * streaming-mode transition, so callers need no SME-awareness.
 *
 * @param a, b  Vectors of @p n f64 scalars.
 * @param c     Row-major n×n f64 matrix.
 * @param result  Out: the bilinear form with Dot2-compensated accumulation.
 */
NK_PUBLIC void nk_bilinear_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                      nk_f64_t *result) {
    nk_bilinear_f64_smef64_streaming_(a, b, c, n, result);
}
|
|
262
|
+
|
|
263
|
+
/**
 * @brief f64 Mahalanobis: row-by-row streaming SVE with Dot2 compensation.
 * 4-row fast path shares (a−b) column vector; 1-row tail for remainder.
 *
 * Computes the quadratic form dᵀ·C·d with d = a − b: each row(C)·d dot product
 * runs with Dot2 (sum + compensation) precision, and d[row]·(row(C)·d) is
 * folded into a scalar Dot2 accumulator. The difference vector is recomputed
 * per stripe rather than materialized.
 *
 * @return The quadratic form dᵀ·C·d (not yet square-rooted or clamped).
 */
__arm_locally_streaming static inline nk_f64_t nk_mahalanobis_f64_smef64_streaming_(nk_f64_t const *a,
                                                                                    nk_f64_t const *b,
                                                                                    nk_f64_t const *c, nk_size_t n) {
    svbool_t predicate_all_f64x = svptrue_b64();
    // Scalar Dot2 state for the outer reduction over rows.
    nk_f64_t outer_sum = 0.0, outer_comp = 0.0;
    nk_size_t row = 0;

    // 4-row fast path: share (a−b) column vector across 4 rows
    for (; row + 4 <= n; row += 4) {
        nk_f64_t d0 = a[row + 0] - b[row + 0], d1 = a[row + 1] - b[row + 1];
        nk_f64_t d2 = a[row + 2] - b[row + 2], d3 = a[row + 3] - b[row + 3];
        // Independent Dot2 chains per row keep the FMA pipelines busy.
        svfloat64_t sum_0_f64x = svdup_f64(0), compensation_0_f64x = svdup_f64(0);
        svfloat64_t sum_1_f64x = svdup_f64(0), compensation_1_f64x = svdup_f64(0);
        svfloat64_t sum_2_f64x = svdup_f64(0), compensation_2_f64x = svdup_f64(0);
        svfloat64_t sum_3_f64x = svdup_f64(0), compensation_3_f64x = svdup_f64(0);
        nk_size_t j = 0;
        svbool_t predicate_f64x = svwhilelt_b64(j, n);

        while (svptest_first(predicate_all_f64x, predicate_f64x)) {
            // One difference vector is amortized over four C-row accumulations.
            svfloat64_t diff_col_f64x = svsub_f64_x(predicate_f64x, svld1_f64(predicate_f64x, a + j),
                                                    svld1_f64(predicate_f64x, b + j));
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_0_f64x, &compensation_0_f64x,
                                        svld1_f64(predicate_f64x, c + (row + 0) * n + j), diff_col_f64x);
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_1_f64x, &compensation_1_f64x,
                                        svld1_f64(predicate_f64x, c + (row + 1) * n + j), diff_col_f64x);
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_2_f64x, &compensation_2_f64x,
                                        svld1_f64(predicate_f64x, c + (row + 2) * n + j), diff_col_f64x);
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_3_f64x, &compensation_3_f64x,
                                        svld1_f64(predicate_f64x, c + (row + 3) * n + j), diff_col_f64x);
            j += svcntd();
            predicate_f64x = svwhilelt_b64(j, n);
        }

        // Horizontal reductions: sum + compensation recovers the Dot2 result.
        nk_f64_t cb[4] = {
            svaddv_f64(predicate_all_f64x, sum_0_f64x) + svaddv_f64(predicate_all_f64x, compensation_0_f64x),
            svaddv_f64(predicate_all_f64x, sum_1_f64x) + svaddv_f64(predicate_all_f64x, compensation_1_f64x),
            svaddv_f64(predicate_all_f64x, sum_2_f64x) + svaddv_f64(predicate_all_f64x, compensation_2_f64x),
            svaddv_f64(predicate_all_f64x, sum_3_f64x) + svaddv_f64(predicate_all_f64x, compensation_3_f64x),
        };
        nk_f64_t dv[4] = {d0, d1, d2, d3};
        for (int r = 0; r < 4; ++r) nk_f64_dot2_(&outer_sum, &outer_comp, dv[r], cb[r]);
    }

    // 1-row tail
    for (; row < n; ++row) {
        nk_f64_t diff_row = a[row] - b[row];
        svfloat64_t sum_f64x = svdup_f64(0.0), compensation_f64x = svdup_f64(0.0);
        nk_size_t j = 0;
        svbool_t predicate_f64x = svwhilelt_b64(j, n);

        while (svptest_first(predicate_all_f64x, predicate_f64x)) {
            svfloat64_t diff_col_f64x = svsub_f64_x(predicate_f64x, svld1_f64(predicate_f64x, a + j),
                                                    svld1_f64(predicate_f64x, b + j));
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_f64x, &compensation_f64x,
                                        svld1_f64(predicate_f64x, c + row * n + j), diff_col_f64x);
            j += svcntd();
            predicate_f64x = svwhilelt_b64(j, n);
        }

        nk_f64_t cb_j = svaddv_f64(predicate_all_f64x, sum_f64x) + svaddv_f64(predicate_all_f64x, compensation_f64x);
        nk_f64_dot2_(&outer_sum, &outer_comp, diff_row, cb_j);
    }

    return outer_sum + outer_comp;
}
|
|
332
|
+
|
|
333
|
+
/**
 * @brief Mahalanobis distance for f64 vectors via the SME/SVE streaming kernel.
 *
 * Computes the quadratic form (a-b)^T * c * (a-b) in streaming mode, clamps it
 * at zero, and takes the square root with the NEON sqrt helper.
 *
 * @param a First input vector of length @p n.
 * @param b Second input vector of length @p n.
 * @param c Row-major n-by-n matrix (inverse covariance for a true Mahalanobis distance).
 * @param n Vector length / matrix dimension.
 * @param result Receives the distance; never negative.
 */
NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                                         nk_f64_t *result) {
    nk_f64_t quadratic_form = nk_mahalanobis_f64_smef64_streaming_(a, b, c, n);
    // Rounding in the accumulation can push the quadratic form slightly below zero;
    // clamp before the sqrt. The ternary also maps NaN to 0, matching the original check.
    nk_f64_t clamped = quadratic_form > 0 ? quadratic_form : 0;
    *result = nk_f64_sqrt_neon(clamped);
}
|
|
338
|
+
|
|
339
|
+
/**
 * @brief f32c bilinear: complex GEMV via FMOPA (widening f32→f64).
 *
 * Computes the complex bilinear form of @p a_pairs with the row-major n×n
 * interleaved-complex matrix @p c_pairs and @p b_pairs, accumulating in f64.
 * Tile roles: ZA0.D = C staging, ZA1.D = v_real accumulator, ZA2.D = v_imag accumulator.
 *
 * Complex multiply is decomposed into two staging passes per column batch:
 *   pass 1 stages C_real and contributes  c_real×b_real (→v_real) and c_real×b_imag (→v_imag);
 *   pass 2 stages C_imag and contributes  c_imag×b_real (→v_imag) and subtracts c_imag×b_imag
 *   from v_real via FMOPS — together the full (c_re+i·c_im)(b_re+i·b_im) product.
 *
 * @param a_pairs  Left vector, n interleaved f32 complex pairs.
 * @param b_pairs  Right vector, n interleaved f32 complex pairs.
 * @param c_pairs  Row-major n×n matrix of interleaved f32 complex pairs.
 * @param n        Dimension.
 * @param results  Single f64 complex output.
 */
__arm_locally_streaming __arm_new("za") static void nk_bilinear_f32c_smef64_streaming_(
    nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_f32c_t const *c_pairs, nk_size_t n, nk_f64c_t *results) {
    svbool_t predicate_body_f64x = svptrue_b64();
    nk_size_t tile_dimension = svcntd(); // f64 lanes per vector == ZA64 tile side
    nk_f64_t outer_sum_real_f64 = 0.0, outer_sum_imag_f64 = 0.0;

    // Process `tile_dimension` rows of C at a time; v_real/v_imag persist across column batches.
    for (nk_size_t row = 0; row < n; row += tile_dimension) {
        nk_size_t rows_remaining = (row + tile_dimension <= n) ? tile_dimension : (n - row);
        svbool_t row_predicate_f64x = svwhilelt_b64_u64(0u, rows_remaining);

        // Reset the v_real (ZA1) and v_imag (ZA2) accumulators for this row block.
        svzero_mask_za(nk_sme_zero_za64_tile_1_);
        svzero_mask_za(nk_sme_zero_za64_tile_2_);

        for (nk_size_t j = 0; j < n; j += tile_dimension) {
            nk_size_t batch_size = (j + tile_dimension <= n) ? tile_dimension : (n - j);
            svbool_t batch_predicate_f64x = svwhilelt_b64_u64(0u, batch_size);
            // Interleaved re/im pairs: twice as many f32 lanes as complex elements.
            svbool_t batch_predicate_f32x = svwhilelt_b32_u64(0u, batch_size + batch_size);

            // Pass 1: Stage C_real into ZA0
            svzero_mask_za(nk_sme_zero_za64_tile_0_);
            for (nk_size_t r = 0; r < rows_remaining; r++) {
                svfloat32_t c_f32x = svld1_f32(batch_predicate_f32x,
                                              (nk_f32_t const *)c_pairs + ((row + r) * n + j) * 2);
                // trn1(c,c) puts the real parts in the even f32 lanes, which FCVT widens to f64.
                svfloat64_t c_real_f64x = svcvt_f64_f32_x(batch_predicate_f64x, svtrn1_f32(c_f32x, c_f32x));
                svwrite_hor_za64_f64_m(0, r, batch_predicate_f64x, c_real_f64x);
            }

            for (nk_size_t k = 0; k < batch_size; k++) {
                // Vertical slice k of ZA0 = column C_real[row..row+rows, j+k].
                svfloat64_t c_re_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, k);
                // Zm is a broadcast scalar, so every ZA column accumulates the same value.
                svmopa_za64_f64_m(1, row_predicate_f64x, row_predicate_f64x, c_re_col_f64x,
                                  svdup_f64((nk_f64_t)b_pairs[j + k].real)); // v_real += c_real × b_real
                svmopa_za64_f64_m(2, row_predicate_f64x, row_predicate_f64x, c_re_col_f64x,
                                  svdup_f64((nk_f64_t)b_pairs[j + k].imag)); // v_imag += c_real × b_imag
            }

            // Pass 2: Stage C_imag into ZA0
            svzero_mask_za(nk_sme_zero_za64_tile_0_);
            for (nk_size_t r = 0; r < rows_remaining; r++) {
                svfloat32_t c_f32x = svld1_f32(batch_predicate_f32x,
                                              (nk_f32_t const *)c_pairs + ((row + r) * n + j) * 2);
                // trn2(c,c) puts the imaginary parts in the even f32 lanes for widening.
                svfloat64_t c_imag_f64x = svcvt_f64_f32_x(batch_predicate_f64x, svtrn2_f32(c_f32x, c_f32x));
                svwrite_hor_za64_f64_m(0, r, batch_predicate_f64x, c_imag_f64x);
            }

            for (nk_size_t k = 0; k < batch_size; k++) {
                svfloat64_t c_im_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 0, k);
                svmopa_za64_f64_m(2, row_predicate_f64x, row_predicate_f64x, c_im_col_f64x,
                                  svdup_f64((nk_f64_t)b_pairs[j + k].real)); // v_imag += c_imag × b_real
                // FMOPS: subtracting outer product completes the complex-product real part.
                svmops_za64_f64_m(1, row_predicate_f64x, row_predicate_f64x, c_im_col_f64x,
                                  svdup_f64((nk_f64_t)b_pairs[j + k].imag)); // v_real -= c_imag × b_imag
            }
        }

        // All ZA1/ZA2 columns are identical (Zm was a broadcast), so slice 0 holds v for this row block.
        svfloat64_t v_re_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 1, 0);
        svfloat64_t v_im_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_f64x, 2, 0);

        // Deinterleave a[row:row+tile]
        svbool_t row_predicate_f32x = svwhilelt_b32_u64(0u, rows_remaining + rows_remaining);
        svfloat32_t a_f32x = svld1_f32(row_predicate_f32x, (nk_f32_t const *)a_pairs + row * 2);
        svfloat64_t a_re_f64x = svcvt_f64_f32_x(row_predicate_f64x, svtrn1_f32(a_f32x, a_f32x));
        svfloat64_t a_im_f64x = svcvt_f64_f32_x(row_predicate_f64x, svtrn2_f32(a_f32x, a_f32x));

        // Complex dot: a × v (non-conjugated): re = a_re·v_re − a_im·v_im, im = a_re·v_im + a_im·v_re.
        outer_sum_real_f64 += svaddv_f64(
            predicate_body_f64x, svsub_f64_x(row_predicate_f64x, svmul_f64_x(row_predicate_f64x, a_re_f64x, v_re_f64x),
                                             svmul_f64_x(row_predicate_f64x, a_im_f64x, v_im_f64x)));
        outer_sum_imag_f64 += svaddv_f64(
            predicate_body_f64x, svadd_f64_x(row_predicate_f64x, svmul_f64_x(row_predicate_f64x, a_re_f64x, v_im_f64x),
                                             svmul_f64_x(row_predicate_f64x, a_im_f64x, v_re_f64x)));
    }

    results->real = outer_sum_real_f64;
    results->imag = outer_sum_imag_f64;
}
|
|
417
|
+
|
|
418
|
+
/**
 * @brief Public entry point for the f32c complex bilinear form with f64 accumulation.
 *
 * Thin pass-through into the __arm_locally_streaming SME kernel; kept separate so
 * callers never enter streaming mode themselves.
 *
 * @param a_pairs  Left vector, n interleaved f32 complex pairs.
 * @param b_pairs  Right vector, n interleaved f32 complex pairs.
 * @param c_pairs  Row-major n×n matrix of interleaved f32 complex pairs.
 * @param n        Dimension.
 * @param results  Receives the single f64 complex result.
 */
NK_PUBLIC void nk_bilinear_f32c_smef64(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_f32c_t const *c_pairs,
                                       nk_size_t n, nk_f64c_t *results) {
    nk_bilinear_f32c_smef64_streaming_(a_pairs, b_pairs, c_pairs, n, results);
}
|
|
422
|
+
|
|
423
|
+
/**
 * @brief f64c bilinear: interleaved Dot2 with permute + deferred XOR sign-flip.
 * 2 accumulators instead of 4, halving inner loop work (~15 vs ~28 SVE ops).
 *
 * Data stays interleaved [re₀, im₀, re₁, im₁, ...] throughout the inner loop;
 * the complex product's sign on the c_imag×b_imag terms is applied once per row
 * by XOR-ing the sign bit into the odd lanes after accumulation, instead of
 * per-iteration. Accumulation uses compensated (Dot2/Kahan-style) sums both in
 * the vector loop and in the scalar outer reduction.
 *
 * @param a_pairs  Left vector, n f64 complex pairs.
 * @param b_pairs  Right vector, n f64 complex pairs.
 * @param c_pairs  Row-major n×n matrix of f64 complex pairs.
 * @param n        Dimension.
 * @param results  Single f64 complex output.
 */
__arm_locally_streaming static void nk_bilinear_f64c_smef64_streaming_(nk_f64c_t const *a_pairs,
                                                                       nk_f64c_t const *b_pairs,
                                                                       nk_f64c_t const *c_pairs, nk_size_t n,
                                                                       nk_f64c_t *results) {
    svbool_t predicate_all_f64x = svptrue_b64();
    nk_f64_t outer_sum_real = 0.0, outer_comp_real = 0.0; // compensated scalar accumulators
    nk_f64_t outer_sum_imag = 0.0, outer_comp_imag = 0.0;
    nk_size_t const n2 = n * 2; // total f64 elements in interleaved layout

    // swap_idx_u64x = [1,0,3,2,5,4,...] — swap adjacent f64 lanes
    svuint64_t swap_idx_u64x = sveor_u64_x(predicate_all_f64x, svindex_u64(0, 1), svdup_u64(1));
    // sign_mask_u64x = [0, 0x8000..., 0, 0x8000..., ...] — sign bit in odd positions
    svuint64_t sign_mask_u64x = svlsl_u64_x(
        predicate_all_f64x, svand_u64_x(predicate_all_f64x, svindex_u64(0, 1), svdup_u64(1)), svdup_u64(63));

    for (nk_size_t row = 0; row < n; ++row) {
        nk_f64_t a_real = a_pairs[row].real;
        nk_f64_t a_imag = a_pairs[row].imag;

        // 2 interleaved Dot2 accumulators (instead of 4 deinterleaved)
        svfloat64_t sum_real_f64x = svdup_f64(0), comp_real_f64x = svdup_f64(0);
        svfloat64_t sum_imag_f64x = svdup_f64(0), comp_imag_f64x = svdup_f64(0);
        nk_size_t j = 0;
        svbool_t predicate_f64x = svwhilelt_b64(j, n2); // j counts f64 lanes, not complex pairs

        while (svptest_first(predicate_all_f64x, predicate_f64x)) {
            // Load interleaved [re₀, im₀, re₁, im₁, ...] — no deinterleave needed
            svfloat64_t b_f64x = svld1_f64(predicate_f64x, (nk_f64_t const *)b_pairs + j);
            svfloat64_t c_f64x = svld1_f64(predicate_f64x, (nk_f64_t const *)c_pairs + row * n2 + j);
            // TBL with swap_idx pairs each c_imag with b_real and c_real with b_imag.
            svfloat64_t c_swapped_f64x = svtbl_f64(c_f64x, swap_idx_u64x);

            // 2 Dot2 accumulators instead of 4:
            // sum_real_f64x accumulates [c_real×b_real, c_imag×b_imag, ...] (sign-flip deferred)
            // sum_imag_f64x accumulates [c_imag×b_real, c_real×b_imag, ...] (all positive)
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_real_f64x, &comp_real_f64x, c_f64x, b_f64x);
            nk_dot2_f64_sve_accumulate_(predicate_f64x, &sum_imag_f64x, &comp_imag_f64x, c_swapped_f64x, b_f64x);

            j += svcntd();
            predicate_f64x = svwhilelt_b64(j, n2);
        }

        // Flip sign of odd positions in sum_real_f64x: [c_real×b_real, -(c_imag×b_imag), ...]
        sum_real_f64x = svreinterpret_f64_u64(
            sveor_u64_x(predicate_all_f64x, svreinterpret_u64_f64(sum_real_f64x), sign_mask_u64x));
        // The compensation term tracks the same lane's products, so it gets the same flip:
        // -(s + c) == (-s) + (-c).
        comp_real_f64x = svreinterpret_f64_u64(
            sveor_u64_x(predicate_all_f64x, svreinterpret_u64_f64(comp_real_f64x), sign_mask_u64x));
        nk_f64_t inner_real = svaddv_f64(predicate_all_f64x,
                                         svadd_f64_x(predicate_all_f64x, sum_real_f64x, comp_real_f64x));
        nk_f64_t inner_imag = svaddv_f64(predicate_all_f64x,
                                         svadd_f64_x(predicate_all_f64x, sum_imag_f64x, comp_imag_f64x));

        // Outer Dot2 complex multiply: a × inner
        // re += a_re·in_re − a_im·in_im; im += a_re·in_im + a_im·in_re.
        nk_f64_dot2_(&outer_sum_real, &outer_comp_real, a_real, inner_real);
        nk_f64_dot2_(&outer_sum_real, &outer_comp_real, -a_imag, inner_imag);
        nk_f64_dot2_(&outer_sum_imag, &outer_comp_imag, a_real, inner_imag);
        nk_f64_dot2_(&outer_sum_imag, &outer_comp_imag, a_imag, inner_real);
    }

    results->real = outer_sum_real + outer_comp_real;
    results->imag = outer_sum_imag + outer_comp_imag;
}
|
|
488
|
+
|
|
489
|
+
/**
 * @brief Public entry point for the f64c complex bilinear form.
 *
 * Thin pass-through into the __arm_locally_streaming kernel; kept separate so
 * callers never enter streaming mode themselves.
 *
 * @param a_pairs  Left vector, n f64 complex pairs.
 * @param b_pairs  Right vector, n f64 complex pairs.
 * @param c_pairs  Row-major n×n matrix of f64 complex pairs.
 * @param n        Dimension.
 * @param results  Receives the single f64 complex result.
 */
NK_PUBLIC void nk_bilinear_f64c_smef64(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_f64c_t const *c_pairs,
                                       nk_size_t n, nk_f64c_t *results) {
    nk_bilinear_f64c_smef64_streaming_(a_pairs, b_pairs, c_pairs, n, results);
}
|
|
493
|
+
|
|
494
|
+
#if defined(__clang__)
|
|
495
|
+
#pragma clang attribute pop
|
|
496
|
+
#elif defined(__GNUC__)
|
|
497
|
+
#pragma GCC pop_options
|
|
498
|
+
#endif
|
|
499
|
+
|
|
500
|
+
#if defined(__cplusplus)
|
|
501
|
+
} // extern "C"
|
|
502
|
+
#endif
|
|
503
|
+
|
|
504
|
+
#endif // NK_TARGET_SMEF64
|
|
505
|
+
#endif // NK_TARGET_ARM_
|
|
506
|
+
#endif // NK_CURVED_SMEF64_H
|