numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -61,7 +61,7 @@ extern "C" {
|
|
|
61
61
|
#endif
|
|
62
62
|
|
|
63
63
|
#if defined(__clang__)
|
|
64
|
-
#pragma clang attribute push(__attribute__((target("sme,
|
|
64
|
+
#pragma clang attribute push(__attribute__((target("sme,sme-f64f64"))), apply_to = function)
|
|
65
65
|
#elif defined(__GNUC__)
|
|
66
66
|
#pragma GCC push_options
|
|
67
67
|
#pragma GCC target("+sme+sme-f64f64")
|
|
@@ -71,122 +71,123 @@ extern "C" {
|
|
|
71
71
|
* @brief SVE Dot2 accumulator: sum += a × b with error compensation.
|
|
72
72
|
* Uses TwoProd (svneg+svnmls) and TwoSum error-free transformations.
|
|
73
73
|
*/
|
|
74
|
-
NK_PUBLIC void nk_dot2_f64_sve_accumulate_(svbool_t
|
|
75
|
-
svfloat64_t a_f64x, svfloat64_t b_f64x)
|
|
76
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
77
|
-
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
78
|
-
svnmls_f64_x(
|
|
79
|
-
svfloat64_t running_sum_f64x =
|
|
80
|
-
svfloat64_t recovered_addend_f64x = svsub_f64_x(
|
|
74
|
+
NK_PUBLIC void nk_dot2_f64_sve_accumulate_(svbool_t predicate_b64x, svfloat64_t *sum, svfloat64_t *comp,
|
|
75
|
+
svfloat64_t a_f64x, svfloat64_t b_f64x) NK_STREAMING_ {
|
|
76
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_f64x, b_f64x);
|
|
77
|
+
svfloat64_t product_error_f64x = svneg_f64_x(predicate_b64x,
|
|
78
|
+
svnmls_f64_x(predicate_b64x, product_f64x, a_f64x, b_f64x));
|
|
79
|
+
svfloat64_t running_sum_f64x = svadd_f64_m(predicate_b64x, *sum, product_f64x);
|
|
80
|
+
svfloat64_t recovered_addend_f64x = svsub_f64_x(predicate_b64x, running_sum_f64x, *sum);
|
|
81
81
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
82
|
-
|
|
83
|
-
svsub_f64_x(
|
|
84
|
-
svsub_f64_x(
|
|
82
|
+
predicate_b64x,
|
|
83
|
+
svsub_f64_x(predicate_b64x, *sum, svsub_f64_x(predicate_b64x, running_sum_f64x, recovered_addend_f64x)),
|
|
84
|
+
svsub_f64_x(predicate_b64x, product_f64x, recovered_addend_f64x));
|
|
85
85
|
*sum = running_sum_f64x;
|
|
86
|
-
*comp =
|
|
86
|
+
*comp = svadd_f64_m(predicate_b64x, *comp, svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
|
|
87
87
|
}
|
|
88
88
|
|
|
89
89
|
/**
|
|
90
90
|
* @brief f32 bilinear: GEMV via FMOPA (widening f32→f64, exact accumulation).
|
|
91
91
|
* ZA0.D = C staging, ZA1.D = GEMV accumulator.
|
|
92
92
|
*/
|
|
93
|
-
__arm_locally_streaming __arm_new("za") static void nk_bilinear_f32_smef64_streaming_(
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
nk_f64_t *result) {
|
|
97
|
-
svbool_t predicate_body_f64x = svptrue_b64();
|
|
93
|
+
__arm_locally_streaming __arm_new("za") static void nk_bilinear_f32_smef64_streaming_(
|
|
94
|
+
nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions, nk_f64_t *result) {
|
|
95
|
+
svbool_t predicate_body_b64x = svptrue_b64();
|
|
98
96
|
nk_size_t tile_dimension = svcntd();
|
|
99
97
|
nk_f64_t outer_sum_f64 = 0.0;
|
|
100
98
|
|
|
101
|
-
for (nk_size_t row = 0; row <
|
|
102
|
-
nk_size_t rows_remaining = (row + tile_dimension <=
|
|
103
|
-
svbool_t
|
|
99
|
+
for (nk_size_t row = 0; row < dimensions; row += tile_dimension) {
|
|
100
|
+
nk_size_t rows_remaining = (row + tile_dimension <= dimensions) ? tile_dimension : (dimensions - row);
|
|
101
|
+
svbool_t row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
|
|
104
102
|
|
|
105
103
|
svzero_mask_za(nk_sme_zero_za64_tile_1_);
|
|
106
104
|
|
|
107
|
-
for (nk_size_t j = 0; j <
|
|
108
|
-
nk_size_t batch_size = (j + tile_dimension <=
|
|
109
|
-
svbool_t
|
|
105
|
+
for (nk_size_t j = 0; j < dimensions; j += tile_dimension) {
|
|
106
|
+
nk_size_t batch_size = (j + tile_dimension <= dimensions) ? tile_dimension : (dimensions - j);
|
|
107
|
+
svbool_t batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
|
|
110
108
|
|
|
111
109
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
112
110
|
for (nk_size_t r = 0; r < rows_remaining; r++) {
|
|
113
111
|
svfloat64_t c_row_f64x = svcvt_f64_f32_x(
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
112
|
+
batch_predicate_b64x,
|
|
113
|
+
svreinterpret_f32_u64(
|
|
114
|
+
svld1uw_u64(batch_predicate_b64x, (nk_u32_t const *)(c + (row + r) * dimensions + j))));
|
|
115
|
+
svwrite_hor_za64_f64_m(0, r, batch_predicate_b64x, c_row_f64x);
|
|
117
116
|
}
|
|
118
117
|
|
|
119
118
|
for (nk_size_t k = 0; k < batch_size; k++) {
|
|
120
|
-
svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
121
|
-
svmopa_za64_f64_m(1,
|
|
119
|
+
svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, k);
|
|
120
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, row_predicate_b64x, c_col_f64x, svdup_f64((nk_f64_t)b[j + k]));
|
|
122
121
|
}
|
|
123
122
|
}
|
|
124
123
|
|
|
125
|
-
svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
124
|
+
svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 1, 0);
|
|
126
125
|
svfloat64_t a_f64x = svcvt_f64_f32_x(
|
|
127
|
-
|
|
128
|
-
outer_sum_f64 += svaddv_f64(
|
|
126
|
+
row_predicate_b64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_b64x, (nk_u32_t const *)(a + row))));
|
|
127
|
+
outer_sum_f64 += svaddv_f64(predicate_body_b64x, svmul_f64_x(row_predicate_b64x, a_f64x, v_f64x));
|
|
129
128
|
}
|
|
130
129
|
|
|
131
130
|
*result = outer_sum_f64;
|
|
132
131
|
}
|
|
133
132
|
|
|
134
|
-
NK_PUBLIC void nk_bilinear_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t
|
|
133
|
+
NK_PUBLIC void nk_bilinear_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions,
|
|
135
134
|
nk_f64_t *result) {
|
|
136
|
-
nk_bilinear_f32_smef64_streaming_(a, b, c,
|
|
135
|
+
nk_bilinear_f32_smef64_streaming_(a, b, c, dimensions, result);
|
|
137
136
|
}
|
|
138
137
|
|
|
139
138
|
/**
|
|
140
139
|
* @brief f32 Mahalanobis: GEMV v = C×d via FMOPA, where d = a − b (exact in f64).
|
|
141
140
|
* ZA0.D = C staging, ZA1.D = GEMV accumulator.
|
|
142
141
|
*/
|
|
143
|
-
__arm_locally_streaming __arm_new("za") static
|
|
144
|
-
nk_mahalanobis_f32_smef64_streaming_(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c,
|
|
142
|
+
__arm_locally_streaming __arm_new("za") static nk_f64_t
|
|
143
|
+
nk_mahalanobis_f32_smef64_streaming_(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c,
|
|
144
|
+
nk_size_t dimensions) {
|
|
145
145
|
|
|
146
|
-
svbool_t
|
|
146
|
+
svbool_t predicate_body_b64x = svptrue_b64();
|
|
147
147
|
nk_size_t tile_dimension = svcntd();
|
|
148
148
|
nk_f64_t outer_sum_f64 = 0.0;
|
|
149
149
|
|
|
150
|
-
for (nk_size_t row = 0; row <
|
|
151
|
-
nk_size_t rows_remaining = (row + tile_dimension <=
|
|
152
|
-
svbool_t
|
|
150
|
+
for (nk_size_t row = 0; row < dimensions; row += tile_dimension) {
|
|
151
|
+
nk_size_t rows_remaining = (row + tile_dimension <= dimensions) ? tile_dimension : (dimensions - row);
|
|
152
|
+
svbool_t row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
|
|
153
153
|
|
|
154
154
|
svzero_mask_za(nk_sme_zero_za64_tile_1_);
|
|
155
155
|
|
|
156
|
-
for (nk_size_t j = 0; j <
|
|
157
|
-
nk_size_t batch_size = (j + tile_dimension <=
|
|
158
|
-
svbool_t
|
|
156
|
+
for (nk_size_t j = 0; j < dimensions; j += tile_dimension) {
|
|
157
|
+
nk_size_t batch_size = (j + tile_dimension <= dimensions) ? tile_dimension : (dimensions - j);
|
|
158
|
+
svbool_t batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
|
|
159
159
|
|
|
160
160
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
161
161
|
for (nk_size_t r = 0; r < rows_remaining; r++) {
|
|
162
162
|
svfloat64_t c_row_f64x = svcvt_f64_f32_x(
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
163
|
+
batch_predicate_b64x,
|
|
164
|
+
svreinterpret_f32_u64(
|
|
165
|
+
svld1uw_u64(batch_predicate_b64x, (nk_u32_t const *)(c + (row + r) * dimensions + j))));
|
|
166
|
+
svwrite_hor_za64_f64_m(0, r, batch_predicate_b64x, c_row_f64x);
|
|
166
167
|
}
|
|
167
168
|
|
|
168
169
|
for (nk_size_t k = 0; k < batch_size; k++) {
|
|
169
|
-
svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
170
|
+
svfloat64_t c_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, k);
|
|
170
171
|
nk_f64_t d_k = (nk_f64_t)a[j + k] - (nk_f64_t)b[j + k];
|
|
171
|
-
svmopa_za64_f64_m(1,
|
|
172
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, row_predicate_b64x, c_col_f64x, svdup_f64(d_k));
|
|
172
173
|
}
|
|
173
174
|
}
|
|
174
175
|
|
|
175
|
-
svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
176
|
+
svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 1, 0);
|
|
176
177
|
svfloat64_t a_f64x = svcvt_f64_f32_x(
|
|
177
|
-
|
|
178
|
+
row_predicate_b64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_b64x, (nk_u32_t const *)(a + row))));
|
|
178
179
|
svfloat64_t b_f64x = svcvt_f64_f32_x(
|
|
179
|
-
|
|
180
|
-
svfloat64_t d_f64x = svsub_f64_x(
|
|
181
|
-
outer_sum_f64 += svaddv_f64(
|
|
180
|
+
row_predicate_b64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_b64x, (nk_u32_t const *)(b + row))));
|
|
181
|
+
svfloat64_t d_f64x = svsub_f64_x(row_predicate_b64x, a_f64x, b_f64x);
|
|
182
|
+
outer_sum_f64 += svaddv_f64(predicate_body_b64x, svmul_f64_x(row_predicate_b64x, d_f64x, v_f64x));
|
|
182
183
|
}
|
|
183
184
|
|
|
184
185
|
return outer_sum_f64;
|
|
185
186
|
}
|
|
186
187
|
|
|
187
|
-
NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t
|
|
188
|
+
NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions,
|
|
188
189
|
nk_f64_t *result) {
|
|
189
|
-
nk_f64_t quadratic = nk_mahalanobis_f32_smef64_streaming_(a, b, c,
|
|
190
|
+
nk_f64_t quadratic = nk_mahalanobis_f32_smef64_streaming_(a, b, c, dimensions);
|
|
190
191
|
*result = nk_f64_sqrt_neon(quadratic > 0 ? quadratic : 0);
|
|
191
192
|
}
|
|
192
193
|
|
|
@@ -195,84 +196,84 @@ NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, n
|
|
|
195
196
|
* 4-row fast path shares b_f64x loads; 1-row tail for remainder.
|
|
196
197
|
*/
|
|
197
198
|
__arm_locally_streaming static void nk_bilinear_f64_smef64_streaming_(nk_f64_t const *a, nk_f64_t const *b,
|
|
198
|
-
nk_f64_t const *c, nk_size_t
|
|
199
|
+
nk_f64_t const *c, nk_size_t dimensions,
|
|
199
200
|
nk_f64_t *result) {
|
|
200
|
-
svbool_t
|
|
201
|
+
svbool_t predicate_all_b64x = svptrue_b64();
|
|
201
202
|
nk_f64_t outer_sum = 0.0, outer_comp = 0.0;
|
|
202
203
|
nk_size_t row = 0;
|
|
203
204
|
|
|
204
205
|
// 4-row fast path: share b_f64x load across 4 rows
|
|
205
|
-
for (; row + 4 <=
|
|
206
|
+
for (; row + 4 <= dimensions; row += 4) {
|
|
206
207
|
nk_f64_t a0 = a[row + 0], a1 = a[row + 1], a2 = a[row + 2], a3 = a[row + 3];
|
|
207
208
|
svfloat64_t sum_0_f64x = svdup_f64(0), compensation_0_f64x = svdup_f64(0);
|
|
208
209
|
svfloat64_t sum_1_f64x = svdup_f64(0), compensation_1_f64x = svdup_f64(0);
|
|
209
210
|
svfloat64_t sum_2_f64x = svdup_f64(0), compensation_2_f64x = svdup_f64(0);
|
|
210
211
|
svfloat64_t sum_3_f64x = svdup_f64(0), compensation_3_f64x = svdup_f64(0);
|
|
211
212
|
nk_size_t j = 0;
|
|
212
|
-
svbool_t
|
|
213
|
-
|
|
214
|
-
while (svptest_first(
|
|
215
|
-
svfloat64_t b_f64x = svld1_f64(
|
|
216
|
-
nk_dot2_f64_sve_accumulate_(
|
|
217
|
-
svld1_f64(
|
|
218
|
-
nk_dot2_f64_sve_accumulate_(
|
|
219
|
-
svld1_f64(
|
|
220
|
-
nk_dot2_f64_sve_accumulate_(
|
|
221
|
-
svld1_f64(
|
|
222
|
-
nk_dot2_f64_sve_accumulate_(
|
|
223
|
-
svld1_f64(
|
|
213
|
+
svbool_t predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
214
|
+
|
|
215
|
+
while (svptest_first(predicate_all_b64x, predicate_b64x)) {
|
|
216
|
+
svfloat64_t b_f64x = svld1_f64(predicate_b64x, b + j);
|
|
217
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_0_f64x, &compensation_0_f64x,
|
|
218
|
+
svld1_f64(predicate_b64x, c + (row + 0) * dimensions + j), b_f64x);
|
|
219
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_1_f64x, &compensation_1_f64x,
|
|
220
|
+
svld1_f64(predicate_b64x, c + (row + 1) * dimensions + j), b_f64x);
|
|
221
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_2_f64x, &compensation_2_f64x,
|
|
222
|
+
svld1_f64(predicate_b64x, c + (row + 2) * dimensions + j), b_f64x);
|
|
223
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_3_f64x, &compensation_3_f64x,
|
|
224
|
+
svld1_f64(predicate_b64x, c + (row + 3) * dimensions + j), b_f64x);
|
|
224
225
|
j += svcntd();
|
|
225
|
-
|
|
226
|
+
predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
226
227
|
}
|
|
227
228
|
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
229
|
+
nk_f64_dot2_(&outer_sum, &outer_comp, a0,
|
|
230
|
+
svaddv_f64(predicate_all_b64x, sum_0_f64x) + svaddv_f64(predicate_all_b64x, compensation_0_f64x));
|
|
231
|
+
nk_f64_dot2_(&outer_sum, &outer_comp, a1,
|
|
232
|
+
svaddv_f64(predicate_all_b64x, sum_1_f64x) + svaddv_f64(predicate_all_b64x, compensation_1_f64x));
|
|
233
|
+
nk_f64_dot2_(&outer_sum, &outer_comp, a2,
|
|
234
|
+
svaddv_f64(predicate_all_b64x, sum_2_f64x) + svaddv_f64(predicate_all_b64x, compensation_2_f64x));
|
|
235
|
+
nk_f64_dot2_(&outer_sum, &outer_comp, a3,
|
|
236
|
+
svaddv_f64(predicate_all_b64x, sum_3_f64x) + svaddv_f64(predicate_all_b64x, compensation_3_f64x));
|
|
236
237
|
}
|
|
237
238
|
|
|
238
239
|
// 1-row tail
|
|
239
|
-
for (; row <
|
|
240
|
+
for (; row < dimensions; ++row) {
|
|
240
241
|
svfloat64_t sum_f64x = svdup_f64(0.0), compensation_f64x = svdup_f64(0.0);
|
|
241
242
|
nk_size_t j = 0;
|
|
242
|
-
svbool_t
|
|
243
|
+
svbool_t predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
243
244
|
|
|
244
|
-
while (svptest_first(
|
|
245
|
-
nk_dot2_f64_sve_accumulate_(
|
|
246
|
-
svld1_f64(
|
|
245
|
+
while (svptest_first(predicate_all_b64x, predicate_b64x)) {
|
|
246
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_f64x, &compensation_f64x,
|
|
247
|
+
svld1_f64(predicate_b64x, c + row * dimensions + j),
|
|
248
|
+
svld1_f64(predicate_b64x, b + j));
|
|
247
249
|
j += svcntd();
|
|
248
|
-
|
|
250
|
+
predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
249
251
|
}
|
|
250
252
|
|
|
251
|
-
nk_f64_t cb_j = svaddv_f64(
|
|
253
|
+
nk_f64_t cb_j = svaddv_f64(predicate_all_b64x, sum_f64x) + svaddv_f64(predicate_all_b64x, compensation_f64x);
|
|
252
254
|
nk_f64_dot2_(&outer_sum, &outer_comp, a[row], cb_j);
|
|
253
255
|
}
|
|
254
256
|
|
|
255
257
|
*result = outer_sum + outer_comp;
|
|
256
258
|
}
|
|
257
259
|
|
|
258
|
-
NK_PUBLIC void nk_bilinear_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t
|
|
260
|
+
NK_PUBLIC void nk_bilinear_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t dimensions,
|
|
259
261
|
nk_f64_t *result) {
|
|
260
|
-
nk_bilinear_f64_smef64_streaming_(a, b, c,
|
|
262
|
+
nk_bilinear_f64_smef64_streaming_(a, b, c, dimensions, result);
|
|
261
263
|
}
|
|
262
264
|
|
|
263
265
|
/**
|
|
264
266
|
* @brief f64 Mahalanobis: row-by-row streaming SVE with Dot2 compensation.
|
|
265
267
|
* 4-row fast path shares (a−b) column vector; 1-row tail for remainder.
|
|
266
268
|
*/
|
|
267
|
-
__arm_locally_streaming static
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
svbool_t predicate_all_f64x = svptrue_b64();
|
|
269
|
+
__arm_locally_streaming static nk_f64_t nk_mahalanobis_f64_smef64_streaming_(nk_f64_t const *a, nk_f64_t const *b,
|
|
270
|
+
nk_f64_t const *c, nk_size_t dimensions) {
|
|
271
|
+
svbool_t predicate_all_b64x = svptrue_b64();
|
|
271
272
|
nk_f64_t outer_sum = 0.0, outer_comp = 0.0;
|
|
272
273
|
nk_size_t row = 0;
|
|
273
274
|
|
|
274
275
|
// 4-row fast path: share (a−b) column vector across 4 rows
|
|
275
|
-
for (; row + 4 <=
|
|
276
|
+
for (; row + 4 <= dimensions; row += 4) {
|
|
276
277
|
nk_f64_t d0 = a[row + 0] - b[row + 0], d1 = a[row + 1] - b[row + 1];
|
|
277
278
|
nk_f64_t d2 = a[row + 2] - b[row + 2], d3 = a[row + 3] - b[row + 3];
|
|
278
279
|
svfloat64_t sum_0_f64x = svdup_f64(0), compensation_0_f64x = svdup_f64(0);
|
|
@@ -280,59 +281,59 @@ __arm_locally_streaming static inline nk_f64_t nk_mahalanobis_f64_smef64_streami
|
|
|
280
281
|
svfloat64_t sum_2_f64x = svdup_f64(0), compensation_2_f64x = svdup_f64(0);
|
|
281
282
|
svfloat64_t sum_3_f64x = svdup_f64(0), compensation_3_f64x = svdup_f64(0);
|
|
282
283
|
nk_size_t j = 0;
|
|
283
|
-
svbool_t
|
|
284
|
-
|
|
285
|
-
while (svptest_first(
|
|
286
|
-
svfloat64_t diff_col_f64x = svsub_f64_x(
|
|
287
|
-
svld1_f64(
|
|
288
|
-
nk_dot2_f64_sve_accumulate_(
|
|
289
|
-
svld1_f64(
|
|
290
|
-
nk_dot2_f64_sve_accumulate_(
|
|
291
|
-
svld1_f64(
|
|
292
|
-
nk_dot2_f64_sve_accumulate_(
|
|
293
|
-
svld1_f64(
|
|
294
|
-
nk_dot2_f64_sve_accumulate_(
|
|
295
|
-
svld1_f64(
|
|
284
|
+
svbool_t predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
285
|
+
|
|
286
|
+
while (svptest_first(predicate_all_b64x, predicate_b64x)) {
|
|
287
|
+
svfloat64_t diff_col_f64x = svsub_f64_x(predicate_b64x, svld1_f64(predicate_b64x, a + j),
|
|
288
|
+
svld1_f64(predicate_b64x, b + j));
|
|
289
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_0_f64x, &compensation_0_f64x,
|
|
290
|
+
svld1_f64(predicate_b64x, c + (row + 0) * dimensions + j), diff_col_f64x);
|
|
291
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_1_f64x, &compensation_1_f64x,
|
|
292
|
+
svld1_f64(predicate_b64x, c + (row + 1) * dimensions + j), diff_col_f64x);
|
|
293
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_2_f64x, &compensation_2_f64x,
|
|
294
|
+
svld1_f64(predicate_b64x, c + (row + 2) * dimensions + j), diff_col_f64x);
|
|
295
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_3_f64x, &compensation_3_f64x,
|
|
296
|
+
svld1_f64(predicate_b64x, c + (row + 3) * dimensions + j), diff_col_f64x);
|
|
296
297
|
j += svcntd();
|
|
297
|
-
|
|
298
|
+
predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
298
299
|
}
|
|
299
300
|
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
301
|
+
nk_f64_dot2_(&outer_sum, &outer_comp, d0,
|
|
302
|
+
svaddv_f64(predicate_all_b64x, sum_0_f64x) + svaddv_f64(predicate_all_b64x, compensation_0_f64x));
|
|
303
|
+
nk_f64_dot2_(&outer_sum, &outer_comp, d1,
|
|
304
|
+
svaddv_f64(predicate_all_b64x, sum_1_f64x) + svaddv_f64(predicate_all_b64x, compensation_1_f64x));
|
|
305
|
+
nk_f64_dot2_(&outer_sum, &outer_comp, d2,
|
|
306
|
+
svaddv_f64(predicate_all_b64x, sum_2_f64x) + svaddv_f64(predicate_all_b64x, compensation_2_f64x));
|
|
307
|
+
nk_f64_dot2_(&outer_sum, &outer_comp, d3,
|
|
308
|
+
svaddv_f64(predicate_all_b64x, sum_3_f64x) + svaddv_f64(predicate_all_b64x, compensation_3_f64x));
|
|
308
309
|
}
|
|
309
310
|
|
|
310
311
|
// 1-row tail
|
|
311
|
-
for (; row <
|
|
312
|
+
for (; row < dimensions; ++row) {
|
|
312
313
|
nk_f64_t diff_row = a[row] - b[row];
|
|
313
314
|
svfloat64_t sum_f64x = svdup_f64(0.0), compensation_f64x = svdup_f64(0.0);
|
|
314
315
|
nk_size_t j = 0;
|
|
315
|
-
svbool_t
|
|
316
|
+
svbool_t predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
316
317
|
|
|
317
|
-
while (svptest_first(
|
|
318
|
-
svfloat64_t diff_col_f64x = svsub_f64_x(
|
|
319
|
-
svld1_f64(
|
|
320
|
-
nk_dot2_f64_sve_accumulate_(
|
|
321
|
-
svld1_f64(
|
|
318
|
+
while (svptest_first(predicate_all_b64x, predicate_b64x)) {
|
|
319
|
+
svfloat64_t diff_col_f64x = svsub_f64_x(predicate_b64x, svld1_f64(predicate_b64x, a + j),
|
|
320
|
+
svld1_f64(predicate_b64x, b + j));
|
|
321
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_f64x, &compensation_f64x,
|
|
322
|
+
svld1_f64(predicate_b64x, c + row * dimensions + j), diff_col_f64x);
|
|
322
323
|
j += svcntd();
|
|
323
|
-
|
|
324
|
+
predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
324
325
|
}
|
|
325
326
|
|
|
326
|
-
nk_f64_t cb_j = svaddv_f64(
|
|
327
|
+
nk_f64_t cb_j = svaddv_f64(predicate_all_b64x, sum_f64x) + svaddv_f64(predicate_all_b64x, compensation_f64x);
|
|
327
328
|
nk_f64_dot2_(&outer_sum, &outer_comp, diff_row, cb_j);
|
|
328
329
|
}
|
|
329
330
|
|
|
330
331
|
return outer_sum + outer_comp;
|
|
331
332
|
}
|
|
332
333
|
|
|
333
|
-
NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t
|
|
334
|
+
NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t dimensions,
|
|
334
335
|
nk_f64_t *result) {
|
|
335
|
-
nk_f64_t quadratic = nk_mahalanobis_f64_smef64_streaming_(a, b, c,
|
|
336
|
+
nk_f64_t quadratic = nk_mahalanobis_f64_smef64_streaming_(a, b, c, dimensions);
|
|
336
337
|
*result = nk_f64_sqrt_neon(quadratic > 0 ? quadratic : 0);
|
|
337
338
|
}
|
|
338
339
|
|
|
@@ -340,75 +341,78 @@ NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, n
|
|
|
340
341
|
* @brief f32c bilinear: complex GEMV via FMOPA (widening f32→f64).
|
|
341
342
|
* ZA0.D = C staging, ZA1.D = v_real accumulator, ZA2.D = v_imag accumulator.
|
|
342
343
|
*/
|
|
343
|
-
__arm_locally_streaming __arm_new("za") static void nk_bilinear_f32c_smef64_streaming_(
|
|
344
|
-
|
|
345
|
-
|
|
344
|
+
__arm_locally_streaming __arm_new("za") static void nk_bilinear_f32c_smef64_streaming_(nk_f32c_t const *a_pairs,
|
|
345
|
+
nk_f32c_t const *b_pairs,
|
|
346
|
+
nk_f32c_t const *c_pairs,
|
|
347
|
+
nk_size_t dimensions,
|
|
348
|
+
nk_f64c_t *results) {
|
|
349
|
+
svbool_t predicate_body_b64x = svptrue_b64();
|
|
346
350
|
nk_size_t tile_dimension = svcntd();
|
|
347
351
|
nk_f64_t outer_sum_real_f64 = 0.0, outer_sum_imag_f64 = 0.0;
|
|
348
352
|
|
|
349
|
-
for (nk_size_t row = 0; row <
|
|
350
|
-
nk_size_t rows_remaining = (row + tile_dimension <=
|
|
351
|
-
svbool_t
|
|
353
|
+
for (nk_size_t row = 0; row < dimensions; row += tile_dimension) {
|
|
354
|
+
nk_size_t rows_remaining = (row + tile_dimension <= dimensions) ? tile_dimension : (dimensions - row);
|
|
355
|
+
svbool_t row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
|
|
352
356
|
|
|
353
357
|
svzero_mask_za(nk_sme_zero_za64_tile_1_);
|
|
354
358
|
svzero_mask_za(nk_sme_zero_za64_tile_2_);
|
|
355
359
|
|
|
356
|
-
for (nk_size_t j = 0; j <
|
|
357
|
-
nk_size_t batch_size = (j + tile_dimension <=
|
|
358
|
-
svbool_t
|
|
359
|
-
svbool_t
|
|
360
|
+
for (nk_size_t j = 0; j < dimensions; j += tile_dimension) {
|
|
361
|
+
nk_size_t batch_size = (j + tile_dimension <= dimensions) ? tile_dimension : (dimensions - j);
|
|
362
|
+
svbool_t batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
|
|
363
|
+
svbool_t batch_predicate_b32x = svwhilelt_b32_u64(0u, batch_size + batch_size);
|
|
360
364
|
|
|
361
365
|
// Pass 1: Stage C_real into ZA0
|
|
362
366
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
363
367
|
for (nk_size_t r = 0; r < rows_remaining; r++) {
|
|
364
|
-
svfloat32_t c_f32x = svld1_f32(
|
|
365
|
-
(nk_f32_t const *)c_pairs + ((row + r) *
|
|
366
|
-
svfloat64_t c_real_f64x = svcvt_f64_f32_x(
|
|
367
|
-
svwrite_hor_za64_f64_m(0, r,
|
|
368
|
+
svfloat32_t c_f32x = svld1_f32(batch_predicate_b32x,
|
|
369
|
+
(nk_f32_t const *)c_pairs + ((row + r) * dimensions + j) * 2);
|
|
370
|
+
svfloat64_t c_real_f64x = svcvt_f64_f32_x(batch_predicate_b64x, svtrn1_f32(c_f32x, c_f32x));
|
|
371
|
+
svwrite_hor_za64_f64_m(0, r, batch_predicate_b64x, c_real_f64x);
|
|
368
372
|
}
|
|
369
373
|
|
|
370
374
|
for (nk_size_t k = 0; k < batch_size; k++) {
|
|
371
|
-
svfloat64_t c_re_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
372
|
-
svmopa_za64_f64_m(1,
|
|
375
|
+
svfloat64_t c_re_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, k);
|
|
376
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, row_predicate_b64x, c_re_col_f64x,
|
|
373
377
|
svdup_f64((nk_f64_t)b_pairs[j + k].real)); // v_real += c_real × b_real
|
|
374
|
-
svmopa_za64_f64_m(2,
|
|
378
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, row_predicate_b64x, c_re_col_f64x,
|
|
375
379
|
svdup_f64((nk_f64_t)b_pairs[j + k].imag)); // v_imag += c_real × b_imag
|
|
376
380
|
}
|
|
377
381
|
|
|
378
382
|
// Pass 2: Stage C_imag into ZA0
|
|
379
383
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
380
384
|
for (nk_size_t r = 0; r < rows_remaining; r++) {
|
|
381
|
-
svfloat32_t c_f32x = svld1_f32(
|
|
382
|
-
(nk_f32_t const *)c_pairs + ((row + r) *
|
|
383
|
-
svfloat64_t c_imag_f64x = svcvt_f64_f32_x(
|
|
384
|
-
svwrite_hor_za64_f64_m(0, r,
|
|
385
|
+
svfloat32_t c_f32x = svld1_f32(batch_predicate_b32x,
|
|
386
|
+
(nk_f32_t const *)c_pairs + ((row + r) * dimensions + j) * 2);
|
|
387
|
+
svfloat64_t c_imag_f64x = svcvt_f64_f32_x(batch_predicate_b64x, svtrn2_f32(c_f32x, c_f32x));
|
|
388
|
+
svwrite_hor_za64_f64_m(0, r, batch_predicate_b64x, c_imag_f64x);
|
|
385
389
|
}
|
|
386
390
|
|
|
387
391
|
for (nk_size_t k = 0; k < batch_size; k++) {
|
|
388
|
-
svfloat64_t c_im_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
389
|
-
svmopa_za64_f64_m(2,
|
|
392
|
+
svfloat64_t c_im_col_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, k);
|
|
393
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, row_predicate_b64x, c_im_col_f64x,
|
|
390
394
|
svdup_f64((nk_f64_t)b_pairs[j + k].real)); // v_imag += c_imag × b_real
|
|
391
|
-
svmops_za64_f64_m(1,
|
|
395
|
+
svmops_za64_f64_m(1, row_predicate_b64x, row_predicate_b64x, c_im_col_f64x,
|
|
392
396
|
svdup_f64((nk_f64_t)b_pairs[j + k].imag)); // v_real -= c_imag × b_imag
|
|
393
397
|
}
|
|
394
398
|
}
|
|
395
399
|
|
|
396
|
-
svfloat64_t v_re_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
397
|
-
svfloat64_t v_im_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
400
|
+
svfloat64_t v_re_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 1, 0);
|
|
401
|
+
svfloat64_t v_im_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 2, 0);
|
|
398
402
|
|
|
399
403
|
// Deinterleave a[row:row+tile]
|
|
400
|
-
svbool_t
|
|
401
|
-
svfloat32_t a_f32x = svld1_f32(
|
|
402
|
-
svfloat64_t a_re_f64x = svcvt_f64_f32_x(
|
|
403
|
-
svfloat64_t a_im_f64x = svcvt_f64_f32_x(
|
|
404
|
+
svbool_t row_predicate_b32x = svwhilelt_b32_u64(0u, rows_remaining + rows_remaining);
|
|
405
|
+
svfloat32_t a_f32x = svld1_f32(row_predicate_b32x, (nk_f32_t const *)a_pairs + row * 2);
|
|
406
|
+
svfloat64_t a_re_f64x = svcvt_f64_f32_x(row_predicate_b64x, svtrn1_f32(a_f32x, a_f32x));
|
|
407
|
+
svfloat64_t a_im_f64x = svcvt_f64_f32_x(row_predicate_b64x, svtrn2_f32(a_f32x, a_f32x));
|
|
404
408
|
|
|
405
409
|
// Complex dot: a × v
|
|
406
410
|
outer_sum_real_f64 += svaddv_f64(
|
|
407
|
-
|
|
408
|
-
svmul_f64_x(
|
|
411
|
+
predicate_body_b64x, svsub_f64_x(row_predicate_b64x, svmul_f64_x(row_predicate_b64x, a_re_f64x, v_re_f64x),
|
|
412
|
+
svmul_f64_x(row_predicate_b64x, a_im_f64x, v_im_f64x)));
|
|
409
413
|
outer_sum_imag_f64 += svaddv_f64(
|
|
410
|
-
|
|
411
|
-
svmul_f64_x(
|
|
414
|
+
predicate_body_b64x, svadd_f64_x(row_predicate_b64x, svmul_f64_x(row_predicate_b64x, a_re_f64x, v_im_f64x),
|
|
415
|
+
svmul_f64_x(row_predicate_b64x, a_im_f64x, v_re_f64x)));
|
|
412
416
|
}
|
|
413
417
|
|
|
414
418
|
results->real = outer_sum_real_f64;
|
|
@@ -416,8 +420,8 @@ __arm_locally_streaming __arm_new("za") static void nk_bilinear_f32c_smef64_stre
|
|
|
416
420
|
}
|
|
417
421
|
|
|
418
422
|
NK_PUBLIC void nk_bilinear_f32c_smef64(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_f32c_t const *c_pairs,
|
|
419
|
-
nk_size_t
|
|
420
|
-
nk_bilinear_f32c_smef64_streaming_(a_pairs, b_pairs, c_pairs,
|
|
423
|
+
nk_size_t dimensions, nk_f64c_t *results) {
|
|
424
|
+
nk_bilinear_f32c_smef64_streaming_(a_pairs, b_pairs, c_pairs, dimensions, results);
|
|
421
425
|
}
|
|
422
426
|
|
|
423
427
|
/**
|
|
@@ -426,20 +430,20 @@ NK_PUBLIC void nk_bilinear_f32c_smef64(nk_f32c_t const *a_pairs, nk_f32c_t const
|
|
|
426
430
|
*/
|
|
427
431
|
__arm_locally_streaming static void nk_bilinear_f64c_smef64_streaming_(nk_f64c_t const *a_pairs,
|
|
428
432
|
nk_f64c_t const *b_pairs,
|
|
429
|
-
nk_f64c_t const *c_pairs, nk_size_t
|
|
433
|
+
nk_f64c_t const *c_pairs, nk_size_t dimensions,
|
|
430
434
|
nk_f64c_t *results) {
|
|
431
|
-
svbool_t
|
|
435
|
+
svbool_t predicate_all_b64x = svptrue_b64();
|
|
432
436
|
nk_f64_t outer_sum_real = 0.0, outer_comp_real = 0.0;
|
|
433
437
|
nk_f64_t outer_sum_imag = 0.0, outer_comp_imag = 0.0;
|
|
434
|
-
nk_size_t const n2 =
|
|
438
|
+
nk_size_t const n2 = dimensions * 2; // total f64 elements in interleaved layout
|
|
435
439
|
|
|
436
440
|
// swap_idx_u64x = [1,0,3,2,5,4,...] — swap adjacent f64 lanes
|
|
437
|
-
svuint64_t swap_idx_u64x = sveor_u64_x(
|
|
441
|
+
svuint64_t swap_idx_u64x = sveor_u64_x(predicate_all_b64x, svindex_u64(0, 1), svdup_u64(1));
|
|
438
442
|
// sign_mask_u64x = [0, 0x8000..., 0, 0x8000..., ...] — sign bit in odd positions
|
|
439
443
|
svuint64_t sign_mask_u64x = svlsl_u64_x(
|
|
440
|
-
|
|
444
|
+
predicate_all_b64x, svand_u64_x(predicate_all_b64x, svindex_u64(0, 1), svdup_u64(1)), svdup_u64(63));
|
|
441
445
|
|
|
442
|
-
for (nk_size_t row = 0; row <
|
|
446
|
+
for (nk_size_t row = 0; row < dimensions; ++row) {
|
|
443
447
|
nk_f64_t a_real = a_pairs[row].real;
|
|
444
448
|
nk_f64_t a_imag = a_pairs[row].imag;
|
|
445
449
|
|
|
@@ -447,33 +451,33 @@ __arm_locally_streaming static void nk_bilinear_f64c_smef64_streaming_(nk_f64c_t
|
|
|
447
451
|
svfloat64_t sum_real_f64x = svdup_f64(0), comp_real_f64x = svdup_f64(0);
|
|
448
452
|
svfloat64_t sum_imag_f64x = svdup_f64(0), comp_imag_f64x = svdup_f64(0);
|
|
449
453
|
nk_size_t j = 0;
|
|
450
|
-
svbool_t
|
|
454
|
+
svbool_t predicate_b64x = svwhilelt_b64(j, n2);
|
|
451
455
|
|
|
452
|
-
while (svptest_first(
|
|
456
|
+
while (svptest_first(predicate_all_b64x, predicate_b64x)) {
|
|
453
457
|
// Load interleaved [re₀, im₀, re₁, im₁, ...] — no deinterleave needed
|
|
454
|
-
svfloat64_t b_f64x = svld1_f64(
|
|
455
|
-
svfloat64_t c_f64x = svld1_f64(
|
|
458
|
+
svfloat64_t b_f64x = svld1_f64(predicate_b64x, (nk_f64_t const *)b_pairs + j);
|
|
459
|
+
svfloat64_t c_f64x = svld1_f64(predicate_b64x, (nk_f64_t const *)c_pairs + row * n2 + j);
|
|
456
460
|
svfloat64_t c_swapped_f64x = svtbl_f64(c_f64x, swap_idx_u64x);
|
|
457
461
|
|
|
458
462
|
// 2 Dot2 accumulators instead of 4:
|
|
459
463
|
// sum_real_f64x accumulates [c_real×b_real, c_imag×b_imag, ...] (sign-flip deferred)
|
|
460
464
|
// sum_imag_f64x accumulates [c_imag×b_real, c_real×b_imag, ...] (all positive)
|
|
461
|
-
nk_dot2_f64_sve_accumulate_(
|
|
462
|
-
nk_dot2_f64_sve_accumulate_(
|
|
465
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_real_f64x, &comp_real_f64x, c_f64x, b_f64x);
|
|
466
|
+
nk_dot2_f64_sve_accumulate_(predicate_b64x, &sum_imag_f64x, &comp_imag_f64x, c_swapped_f64x, b_f64x);
|
|
463
467
|
|
|
464
468
|
j += svcntd();
|
|
465
|
-
|
|
469
|
+
predicate_b64x = svwhilelt_b64(j, n2);
|
|
466
470
|
}
|
|
467
471
|
|
|
468
472
|
// Flip sign of odd positions in sum_real_f64x: [c_real×b_real, -(c_imag×b_imag), ...]
|
|
469
473
|
sum_real_f64x = svreinterpret_f64_u64(
|
|
470
|
-
sveor_u64_x(
|
|
474
|
+
sveor_u64_x(predicate_all_b64x, svreinterpret_u64_f64(sum_real_f64x), sign_mask_u64x));
|
|
471
475
|
comp_real_f64x = svreinterpret_f64_u64(
|
|
472
|
-
sveor_u64_x(
|
|
473
|
-
nk_f64_t inner_real = svaddv_f64(
|
|
474
|
-
svadd_f64_x(
|
|
475
|
-
nk_f64_t inner_imag = svaddv_f64(
|
|
476
|
-
svadd_f64_x(
|
|
476
|
+
sveor_u64_x(predicate_all_b64x, svreinterpret_u64_f64(comp_real_f64x), sign_mask_u64x));
|
|
477
|
+
nk_f64_t inner_real = svaddv_f64(predicate_all_b64x,
|
|
478
|
+
svadd_f64_x(predicate_all_b64x, sum_real_f64x, comp_real_f64x));
|
|
479
|
+
nk_f64_t inner_imag = svaddv_f64(predicate_all_b64x,
|
|
480
|
+
svadd_f64_x(predicate_all_b64x, sum_imag_f64x, comp_imag_f64x));
|
|
477
481
|
|
|
478
482
|
// Outer Dot2 complex multiply: a × inner
|
|
479
483
|
nk_f64_dot2_(&outer_sum_real, &outer_comp_real, a_real, inner_real);
|
|
@@ -487,8 +491,8 @@ __arm_locally_streaming static void nk_bilinear_f64c_smef64_streaming_(nk_f64c_t
|
|
|
487
491
|
}
|
|
488
492
|
|
|
489
493
|
NK_PUBLIC void nk_bilinear_f64c_smef64(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_f64c_t const *c_pairs,
|
|
490
|
-
nk_size_t
|
|
491
|
-
nk_bilinear_f64c_smef64_streaming_(a_pairs, b_pairs, c_pairs,
|
|
494
|
+
nk_size_t dimensions, nk_f64c_t *results) {
|
|
495
|
+
nk_bilinear_f64c_smef64_streaming_(a_pairs, b_pairs, c_pairs, dimensions, results);
|
|
492
496
|
}
|
|
493
497
|
|
|
494
498
|
#if defined(__clang__)
|