numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -45,7 +45,7 @@ extern "C" {
|
|
|
45
45
|
#endif
|
|
46
46
|
|
|
47
47
|
#if defined(__clang__)
|
|
48
|
-
#pragma clang attribute push(__attribute__((target("sme,
|
|
48
|
+
#pragma clang attribute push(__attribute__((target("sme,sme-f64f64"))), apply_to = function)
|
|
49
49
|
#elif defined(__GNUC__)
|
|
50
50
|
#pragma GCC push_options
|
|
51
51
|
#pragma GCC target("+sme+sme-f64f64")
|
|
@@ -72,11 +72,11 @@ extern "C" {
|
|
|
72
72
|
* for higher-than-f32 accumulation precision; replacing it with f32 FMOPA would be
|
|
73
73
|
* counterproductive. Apple M4 has `hw.optional.arm.SME_F32F32: 1` but we don't use it here.
|
|
74
74
|
*/
|
|
75
|
-
#pragma region
|
|
75
|
+
#pragma region F32 Floats
|
|
76
76
|
|
|
77
77
|
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_smef64(nk_size_t columns, nk_size_t depth) {
|
|
78
|
-
nk_size_t const tile_dimension =
|
|
79
|
-
nk_size_t const depth_tile_size =
|
|
78
|
+
nk_size_t const tile_dimension = nk_sme_cntd_(); // rows per `ZA64` tile (8 for SVL=512)
|
|
79
|
+
nk_size_t const depth_tile_size = nk_sme_cntw_(); // `f32` depth elements per tile (16 for SVL=512)
|
|
80
80
|
|
|
81
81
|
nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
|
|
82
82
|
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
|
|
@@ -88,13 +88,13 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_f32_smef64(nk_size_t columns, nk_size_t
|
|
|
88
88
|
return size;
|
|
89
89
|
}
|
|
90
90
|
|
|
91
|
-
NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_size_t depth,
|
|
92
|
-
void *b_packed) {
|
|
91
|
+
NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_size_t depth,
|
|
92
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
93
93
|
|
|
94
|
-
nk_size_t const tile_dimension =
|
|
95
|
-
nk_size_t const depth_tile_size =
|
|
94
|
+
nk_size_t const tile_dimension = nk_sme_cntd_(); // rows per `ZA64` tile (8 for SVL=512)
|
|
95
|
+
nk_size_t const depth_tile_size = nk_sme_cntw_(); // `f32` depth elements per tile (16 for SVL=512)
|
|
96
96
|
nk_size_t const tile_elements = tile_dimension * depth_tile_size; // 128
|
|
97
|
-
nk_size_t const b_stride_elements =
|
|
97
|
+
nk_size_t const b_stride_elements = b_stride_in_bytes / sizeof(nk_f32_t);
|
|
98
98
|
|
|
99
99
|
nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
|
|
100
100
|
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
|
|
@@ -106,7 +106,7 @@ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_
|
|
|
106
106
|
header->depth_tile_count = (nk_u32_t)depth_tile_count;
|
|
107
107
|
header->columns = (nk_u32_t)columns;
|
|
108
108
|
header->depth = (nk_u32_t)depth;
|
|
109
|
-
header->svl_bytes = (nk_u32_t)
|
|
109
|
+
header->svl_bytes = (nk_u32_t)nk_sme_cntb_(); // streaming vector length in bytes
|
|
110
110
|
|
|
111
111
|
nk_f32_t *tiles = (nk_f32_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
112
112
|
|
|
@@ -148,7 +148,7 @@ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_
|
|
|
148
148
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
149
149
|
nk_f64_t *norms_ptr = (nk_f64_t *)((char *)b_packed + header->norms_offset);
|
|
150
150
|
for (nk_size_t col = 0; col < columns; col++) {
|
|
151
|
-
nk_f32_t const *col_data = (nk_f32_t const *)((char const *)b + col *
|
|
151
|
+
nk_f32_t const *col_data = (nk_f32_t const *)((char const *)b + col * b_stride_in_bytes);
|
|
152
152
|
norms_ptr[col] = nk_dots_reduce_sumsq_f32_(col_data, depth);
|
|
153
153
|
}
|
|
154
154
|
}
|
|
@@ -168,14 +168,14 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
|
|
|
168
168
|
|
|
169
169
|
nk_f32_t const *b_tiles = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
170
170
|
|
|
171
|
-
svbool_t const
|
|
171
|
+
svbool_t const predicate_all_b64x = svptrue_b64();
|
|
172
172
|
|
|
173
173
|
// ZA0.D = staging, ZA1-7.D = accumulation (7-tile fast path)
|
|
174
174
|
for (nk_size_t row_tile_index = 0; row_tile_index < nk_size_divide_round_up_(rows, tile_dimension);
|
|
175
175
|
row_tile_index++) {
|
|
176
176
|
nk_size_t const row_start = row_tile_index * tile_dimension;
|
|
177
177
|
nk_size_t const rows_remaining = (row_start + tile_dimension <= rows) ? tile_dimension : (rows - row_start);
|
|
178
|
-
svbool_t const
|
|
178
|
+
svbool_t const row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
|
|
179
179
|
|
|
180
180
|
nk_size_t column_tile_index = 0;
|
|
181
181
|
|
|
@@ -200,18 +200,17 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
|
|
|
200
200
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
201
201
|
|
|
202
202
|
// Load A rows into ZA0.D: extending load f32→u64 + convert to f64
|
|
203
|
-
svbool_t const
|
|
204
|
-
svbool_t const
|
|
205
|
-
(uint64_t)depth);
|
|
203
|
+
svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
|
|
204
|
+
svbool_t const a_depth_predicate_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start, depth);
|
|
206
205
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
|
|
207
206
|
nk_size_t const a_row = row_start + row_in_tile;
|
|
208
207
|
// Extending load: svld1uw_u64 loads f32 bits into lower 32 of each u64 lane
|
|
209
208
|
svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
|
|
210
|
-
|
|
209
|
+
batch_predicate_b64x,
|
|
211
210
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
212
|
-
|
|
211
|
+
a_depth_predicate_b64x,
|
|
213
212
|
(nk_u32_t const *)&a[a_row * a_stride_elements + depth_offset + depth_batch_start])));
|
|
214
|
-
svwrite_hor_za64_f64_m(0, row_in_tile,
|
|
213
|
+
svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_b64x, a_row_widened_f64x);
|
|
215
214
|
}
|
|
216
215
|
|
|
217
216
|
// Vertical read + MOPA for each depth step in batch
|
|
@@ -219,110 +218,110 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
|
|
|
219
218
|
nk_size_t const k_abs = depth_offset + depth_batch_start + step;
|
|
220
219
|
if (k_abs >= depth) break;
|
|
221
220
|
|
|
222
|
-
svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
221
|
+
svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, step);
|
|
223
222
|
|
|
224
223
|
nk_size_t const b_k = depth_batch_start + step;
|
|
225
224
|
|
|
226
225
|
// Extending load f32→u64 + convert to f64: svld1uw_u64 replaces svld1_f32 + svunpklo_u64
|
|
227
226
|
svfloat64_t b_column_tile_1_f64x = svcvt_f64_f32_x(
|
|
228
|
-
|
|
227
|
+
predicate_all_b64x,
|
|
229
228
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
230
|
-
|
|
229
|
+
predicate_all_b64x,
|
|
231
230
|
(nk_u32_t const *)(b_tiles +
|
|
232
231
|
((column_tile_index + 0) * depth_tile_count + depth_tile_idx) *
|
|
233
232
|
tile_elements +
|
|
234
233
|
b_k * tile_dimension))));
|
|
235
234
|
svfloat64_t b_column_tile_2_f64x = svcvt_f64_f32_x(
|
|
236
|
-
|
|
235
|
+
predicate_all_b64x,
|
|
237
236
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
238
|
-
|
|
237
|
+
predicate_all_b64x,
|
|
239
238
|
(nk_u32_t const *)(b_tiles +
|
|
240
239
|
((column_tile_index + 1) * depth_tile_count + depth_tile_idx) *
|
|
241
240
|
tile_elements +
|
|
242
241
|
b_k * tile_dimension))));
|
|
243
242
|
svfloat64_t b_column_tile_3_f64x = svcvt_f64_f32_x(
|
|
244
|
-
|
|
243
|
+
predicate_all_b64x,
|
|
245
244
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
246
|
-
|
|
245
|
+
predicate_all_b64x,
|
|
247
246
|
(nk_u32_t const *)(b_tiles +
|
|
248
247
|
((column_tile_index + 2) * depth_tile_count + depth_tile_idx) *
|
|
249
248
|
tile_elements +
|
|
250
249
|
b_k * tile_dimension))));
|
|
251
250
|
svfloat64_t b_column_tile_4_f64x = svcvt_f64_f32_x(
|
|
252
|
-
|
|
251
|
+
predicate_all_b64x,
|
|
253
252
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
254
|
-
|
|
253
|
+
predicate_all_b64x,
|
|
255
254
|
(nk_u32_t const *)(b_tiles +
|
|
256
255
|
((column_tile_index + 3) * depth_tile_count + depth_tile_idx) *
|
|
257
256
|
tile_elements +
|
|
258
257
|
b_k * tile_dimension))));
|
|
259
258
|
svfloat64_t b_column_tile_5_f64x = svcvt_f64_f32_x(
|
|
260
|
-
|
|
259
|
+
predicate_all_b64x,
|
|
261
260
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
262
|
-
|
|
261
|
+
predicate_all_b64x,
|
|
263
262
|
(nk_u32_t const *)(b_tiles +
|
|
264
263
|
((column_tile_index + 4) * depth_tile_count + depth_tile_idx) *
|
|
265
264
|
tile_elements +
|
|
266
265
|
b_k * tile_dimension))));
|
|
267
266
|
svfloat64_t b_column_tile_6_f64x = svcvt_f64_f32_x(
|
|
268
|
-
|
|
267
|
+
predicate_all_b64x,
|
|
269
268
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
270
|
-
|
|
269
|
+
predicate_all_b64x,
|
|
271
270
|
(nk_u32_t const *)(b_tiles +
|
|
272
271
|
((column_tile_index + 5) * depth_tile_count + depth_tile_idx) *
|
|
273
272
|
tile_elements +
|
|
274
273
|
b_k * tile_dimension))));
|
|
275
274
|
svfloat64_t b_column_tile_7_f64x = svcvt_f64_f32_x(
|
|
276
|
-
|
|
275
|
+
predicate_all_b64x,
|
|
277
276
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
278
|
-
|
|
277
|
+
predicate_all_b64x,
|
|
279
278
|
(nk_u32_t const *)(b_tiles +
|
|
280
279
|
((column_tile_index + 6) * depth_tile_count + depth_tile_idx) *
|
|
281
280
|
tile_elements +
|
|
282
281
|
b_k * tile_dimension))));
|
|
283
282
|
|
|
284
|
-
svmopa_za64_f64_m(1,
|
|
285
|
-
svmopa_za64_f64_m(2,
|
|
286
|
-
svmopa_za64_f64_m(3,
|
|
287
|
-
svmopa_za64_f64_m(4,
|
|
288
|
-
svmopa_za64_f64_m(5,
|
|
289
|
-
svmopa_za64_f64_m(6,
|
|
290
|
-
svmopa_za64_f64_m(7,
|
|
283
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_1_f64x);
|
|
284
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_2_f64x);
|
|
285
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_3_f64x);
|
|
286
|
+
svmopa_za64_f64_m(4, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_4_f64x);
|
|
287
|
+
svmopa_za64_f64_m(5, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_5_f64x);
|
|
288
|
+
svmopa_za64_f64_m(6, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_6_f64x);
|
|
289
|
+
svmopa_za64_f64_m(7, row_predicate_b64x, predicate_all_b64x, a_f64x, b_column_tile_7_f64x);
|
|
291
290
|
}
|
|
292
291
|
}
|
|
293
292
|
}
|
|
294
293
|
|
|
295
294
|
// Extract from ZA1-7 and store native f64 outputs.
|
|
296
|
-
svbool_t const
|
|
295
|
+
svbool_t const predicate_tile_b64x = svwhilelt_b64_u64(0u, tile_dimension);
|
|
297
296
|
// The 7th tile (index 6) may be partial when it's the last column tile
|
|
298
297
|
nk_size_t const last_fast_col_start = (column_tile_index + 6) * tile_dimension;
|
|
299
298
|
nk_size_t const last_fast_cols = (last_fast_col_start + tile_dimension <= columns)
|
|
300
299
|
? tile_dimension
|
|
301
300
|
: (columns - last_fast_col_start);
|
|
302
|
-
svbool_t const
|
|
301
|
+
svbool_t const last_tile_pred_b64x = svwhilelt_b64_u64(0u, last_fast_cols);
|
|
303
302
|
for (nk_size_t row_idx = 0; row_idx < rows_remaining; row_idx++) {
|
|
304
303
|
nk_f64_t *c_row = c + (row_start + row_idx) * c_stride_elements;
|
|
305
304
|
|
|
306
|
-
svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
307
|
-
svst1_f64(
|
|
305
|
+
svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 1, row_idx);
|
|
306
|
+
svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 0) * tile_dimension, za_row_f64x);
|
|
308
307
|
|
|
309
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
310
|
-
svst1_f64(
|
|
308
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 2, row_idx);
|
|
309
|
+
svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 1) * tile_dimension, za_row_f64x);
|
|
311
310
|
|
|
312
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
313
|
-
svst1_f64(
|
|
311
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 3, row_idx);
|
|
312
|
+
svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 2) * tile_dimension, za_row_f64x);
|
|
314
313
|
|
|
315
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
316
|
-
svst1_f64(
|
|
314
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 4, row_idx);
|
|
315
|
+
svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 3) * tile_dimension, za_row_f64x);
|
|
317
316
|
|
|
318
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
319
|
-
svst1_f64(
|
|
317
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 5, row_idx);
|
|
318
|
+
svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 4) * tile_dimension, za_row_f64x);
|
|
320
319
|
|
|
321
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
322
|
-
svst1_f64(
|
|
320
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 6, row_idx);
|
|
321
|
+
svst1_f64(predicate_tile_b64x, c_row + (column_tile_index + 5) * tile_dimension, za_row_f64x);
|
|
323
322
|
|
|
324
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
325
|
-
svst1_f64(
|
|
323
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 7, row_idx);
|
|
324
|
+
svst1_f64(last_tile_pred_b64x, c_row + (column_tile_index + 6) * tile_dimension, za_row_f64x);
|
|
326
325
|
}
|
|
327
326
|
}
|
|
328
327
|
|
|
@@ -331,7 +330,7 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
|
|
|
331
330
|
nk_size_t const column_start = column_tile_index * tile_dimension;
|
|
332
331
|
nk_size_t const columns_remaining = (column_start + tile_dimension <= columns) ? tile_dimension
|
|
333
332
|
: (columns - column_start);
|
|
334
|
-
svbool_t const
|
|
333
|
+
svbool_t const column_predicate_b64x = svwhilelt_b64_u64(0u, columns_remaining);
|
|
335
334
|
|
|
336
335
|
svzero_mask_za(nk_sme_zero_za64_tile_1_);
|
|
337
336
|
|
|
@@ -349,54 +348,54 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
|
|
|
349
348
|
|
|
350
349
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
351
350
|
|
|
352
|
-
svbool_t const
|
|
353
|
-
svbool_t const
|
|
354
|
-
(uint64_t)depth);
|
|
351
|
+
svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
|
|
352
|
+
svbool_t const a_depth_pred_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start, depth);
|
|
355
353
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
|
|
356
354
|
nk_size_t const a_row = row_start + row_in_tile;
|
|
357
355
|
svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
|
|
358
|
-
|
|
356
|
+
batch_predicate_b64x,
|
|
359
357
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
360
|
-
|
|
358
|
+
a_depth_pred_b64x,
|
|
361
359
|
(nk_u32_t const *)&a[a_row * a_stride_elements + depth_offset + depth_batch_start])));
|
|
362
|
-
svwrite_hor_za64_f64_m(0, row_in_tile,
|
|
360
|
+
svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_b64x, a_row_widened_f64x);
|
|
363
361
|
}
|
|
364
362
|
|
|
365
363
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
366
364
|
nk_size_t const k_abs = depth_offset + depth_batch_start + step;
|
|
367
365
|
if (k_abs >= depth) break;
|
|
368
366
|
|
|
369
|
-
svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
367
|
+
svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, step);
|
|
370
368
|
|
|
371
369
|
nk_size_t const b_k = depth_batch_start + step;
|
|
372
370
|
nk_f32_t const *b_tile = b_tiles + (column_tile_index * depth_tile_count + depth_tile_idx) *
|
|
373
371
|
tile_elements;
|
|
374
372
|
// Extending load f32→u64 + convert to f64
|
|
375
373
|
svfloat64_t b_f64x = svcvt_f64_f32_x(
|
|
376
|
-
|
|
374
|
+
predicate_all_b64x,
|
|
377
375
|
svreinterpret_f32_u64(
|
|
378
|
-
svld1uw_u64(
|
|
376
|
+
svld1uw_u64(predicate_all_b64x, (nk_u32_t const *)(b_tile + b_k * tile_dimension))));
|
|
379
377
|
|
|
380
|
-
svmopa_za64_f64_m(1,
|
|
378
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_b64x, a_f64x, b_f64x);
|
|
381
379
|
}
|
|
382
380
|
}
|
|
383
381
|
}
|
|
384
382
|
|
|
385
383
|
// Store native f64 outputs for the tail column tile.
|
|
386
384
|
for (nk_size_t row_idx = 0; row_idx < rows_remaining; row_idx++) {
|
|
387
|
-
svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
385
|
+
svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 1, row_idx);
|
|
388
386
|
nk_f64_t *c_row = c + (row_start + row_idx) * c_stride_elements + column_start;
|
|
389
|
-
svst1_f64(
|
|
387
|
+
svst1_f64(column_predicate_b64x, c_row, za_row_f64x);
|
|
390
388
|
}
|
|
391
389
|
}
|
|
392
390
|
}
|
|
393
391
|
}
|
|
394
392
|
|
|
395
393
|
NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
|
|
396
|
-
nk_size_t columns, nk_size_t depth, nk_size_t
|
|
394
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
395
|
+
nk_size_t c_stride_in_bytes) {
|
|
397
396
|
|
|
398
|
-
nk_size_t const a_stride_elements =
|
|
399
|
-
nk_size_t const c_stride_elements =
|
|
397
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
|
|
398
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
|
|
400
399
|
|
|
401
400
|
nk_dots_packed_f32_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
402
401
|
}
|
|
@@ -408,30 +407,32 @@ NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed
|
|
|
408
407
|
* per column tile. Eliminates all scalar B-packing loops.
|
|
409
408
|
*/
|
|
410
409
|
__arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64_streaming_(
|
|
411
|
-
nk_f32_t const *vectors, nk_size_t
|
|
410
|
+
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
412
411
|
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
413
412
|
|
|
414
413
|
nk_size_t const tile_dimension = svcntd(); // 8 for SVL=512
|
|
415
414
|
nk_size_t const depth_tile_size = svcntw(); // 16 for SVL=512
|
|
416
415
|
nk_size_t const depth_steps_per_batch = tile_dimension; // 8
|
|
417
416
|
|
|
418
|
-
svbool_t const
|
|
417
|
+
svbool_t const predicate_all_b64x = svptrue_b64();
|
|
419
418
|
|
|
420
419
|
NK_ALIGN64 nk_f64_t a_buffer[8][8];
|
|
421
420
|
|
|
422
421
|
nk_size_t const row_end = row_start + row_count;
|
|
423
|
-
nk_size_t const column_tile_count = nk_size_divide_round_up_(
|
|
422
|
+
nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dimension);
|
|
424
423
|
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
|
|
425
424
|
|
|
426
|
-
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start <
|
|
425
|
+
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
|
|
427
426
|
row_tile_start += tile_dimension) {
|
|
428
427
|
nk_size_t const rows_clamped = (row_tile_start + tile_dimension <= row_end) ? tile_dimension
|
|
429
428
|
: (row_end - row_tile_start);
|
|
430
|
-
nk_size_t const rows_actual = (row_tile_start + rows_clamped <=
|
|
431
|
-
|
|
432
|
-
|
|
429
|
+
nk_size_t const rows_actual = (row_tile_start + rows_clamped <= vectors_count)
|
|
430
|
+
? rows_clamped
|
|
431
|
+
: (vectors_count - row_tile_start);
|
|
432
|
+
svbool_t const row_predicate_b64x = svwhilelt_b64_u64(0u, rows_actual);
|
|
433
433
|
|
|
434
|
-
|
|
434
|
+
// Upper triangle: start from this row tile's column
|
|
435
|
+
nk_size_t column_tile_index = row_tile_start / tile_dimension;
|
|
435
436
|
|
|
436
437
|
// Fast path: 7 column tiles at a time
|
|
437
438
|
for (; column_tile_index + 7 <= column_tile_count; column_tile_index += 7) {
|
|
@@ -451,209 +452,208 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64
|
|
|
451
452
|
if (depth_offset + depth_batch_start >= depth) break;
|
|
452
453
|
|
|
453
454
|
// ZA transpose for A rows: extending load f32→f64, MOVA directly into ZA0
|
|
454
|
-
svbool_t const
|
|
455
|
-
svbool_t const
|
|
456
|
-
(uint64_t)depth);
|
|
455
|
+
svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
|
|
456
|
+
svbool_t const a_depth_predicate_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start, depth);
|
|
457
457
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
458
458
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_actual; row_in_tile++) {
|
|
459
459
|
nk_size_t const row_abs = row_tile_start + row_in_tile;
|
|
460
460
|
svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
|
|
461
|
-
|
|
461
|
+
batch_predicate_b64x,
|
|
462
462
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
463
|
-
|
|
463
|
+
a_depth_predicate_b64x, (nk_u32_t const *)&vectors[row_abs * stride_elements +
|
|
464
464
|
depth_offset + depth_batch_start])));
|
|
465
|
-
svwrite_hor_za64_f64_m(0, row_in_tile,
|
|
465
|
+
svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_b64x, a_row_widened_f64x);
|
|
466
466
|
}
|
|
467
467
|
|
|
468
468
|
// Save A columns from ZA0 to stack buffer
|
|
469
469
|
for (nk_size_t s = 0; s < batch_size; s++)
|
|
470
|
-
svst1_f64(
|
|
471
|
-
svread_ver_za64_f64_m(svdup_f64(0),
|
|
470
|
+
svst1_f64(predicate_all_b64x, a_buffer[s],
|
|
471
|
+
svread_ver_za64_f64_m(svdup_f64(0), row_predicate_b64x, 0, s));
|
|
472
472
|
|
|
473
473
|
// Column tile 0 → ZA1 via MOVA
|
|
474
474
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
475
475
|
for (nk_size_t column = 0; column < tile_dimension; column++) {
|
|
476
476
|
nk_size_t const column_abs = (column_tile_index + 0) * tile_dimension + column;
|
|
477
|
-
if (column_abs <
|
|
477
|
+
if (column_abs < vectors_count) {
|
|
478
478
|
svfloat64_t widened_f64x = svcvt_f64_f32_x(
|
|
479
|
-
|
|
479
|
+
batch_predicate_b64x,
|
|
480
480
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
481
|
-
|
|
481
|
+
a_depth_predicate_b64x,
|
|
482
482
|
(nk_u32_t const
|
|
483
483
|
*)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
|
|
484
|
-
svwrite_hor_za64_f64_m(0, column,
|
|
484
|
+
svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
|
|
485
485
|
}
|
|
486
486
|
}
|
|
487
487
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
488
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
489
|
-
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
490
|
-
svmopa_za64_f64_m(1,
|
|
488
|
+
svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
|
|
489
|
+
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
|
|
490
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
|
|
491
491
|
}
|
|
492
492
|
|
|
493
493
|
// Column tile 1 → ZA2 via MOVA
|
|
494
494
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
495
495
|
for (nk_size_t column = 0; column < tile_dimension; column++) {
|
|
496
496
|
nk_size_t const column_abs = (column_tile_index + 1) * tile_dimension + column;
|
|
497
|
-
if (column_abs <
|
|
497
|
+
if (column_abs < vectors_count) {
|
|
498
498
|
svfloat64_t widened_f64x = svcvt_f64_f32_x(
|
|
499
|
-
|
|
499
|
+
batch_predicate_b64x,
|
|
500
500
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
501
|
-
|
|
501
|
+
a_depth_predicate_b64x,
|
|
502
502
|
(nk_u32_t const
|
|
503
503
|
*)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
|
|
504
|
-
svwrite_hor_za64_f64_m(0, column,
|
|
504
|
+
svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
|
|
505
505
|
}
|
|
506
506
|
}
|
|
507
507
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
508
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
509
|
-
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
510
|
-
svmopa_za64_f64_m(2,
|
|
508
|
+
svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
|
|
509
|
+
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
|
|
510
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
|
|
511
511
|
}
|
|
512
512
|
|
|
513
513
|
// Column tile 2 → ZA3 via MOVA
|
|
514
514
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
515
515
|
for (nk_size_t column = 0; column < tile_dimension; column++) {
|
|
516
516
|
nk_size_t const column_abs = (column_tile_index + 2) * tile_dimension + column;
|
|
517
|
-
if (column_abs <
|
|
517
|
+
if (column_abs < vectors_count) {
|
|
518
518
|
svfloat64_t widened_f64x = svcvt_f64_f32_x(
|
|
519
|
-
|
|
519
|
+
batch_predicate_b64x,
|
|
520
520
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
521
|
-
|
|
521
|
+
a_depth_predicate_b64x,
|
|
522
522
|
(nk_u32_t const
|
|
523
523
|
*)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
|
|
524
|
-
svwrite_hor_za64_f64_m(0, column,
|
|
524
|
+
svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
|
|
525
525
|
}
|
|
526
526
|
}
|
|
527
527
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
528
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
529
|
-
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
530
|
-
svmopa_za64_f64_m(3,
|
|
528
|
+
svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
|
|
529
|
+
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
|
|
530
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
|
|
531
531
|
}
|
|
532
532
|
|
|
533
533
|
// Column tile 3 → ZA4 via MOVA
|
|
534
534
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
535
535
|
for (nk_size_t column = 0; column < tile_dimension; column++) {
|
|
536
536
|
nk_size_t const column_abs = (column_tile_index + 3) * tile_dimension + column;
|
|
537
|
-
if (column_abs <
|
|
537
|
+
if (column_abs < vectors_count) {
|
|
538
538
|
svfloat64_t widened_f64x = svcvt_f64_f32_x(
|
|
539
|
-
|
|
539
|
+
batch_predicate_b64x,
|
|
540
540
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
541
|
-
|
|
541
|
+
a_depth_predicate_b64x,
|
|
542
542
|
(nk_u32_t const
|
|
543
543
|
*)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
|
|
544
|
-
svwrite_hor_za64_f64_m(0, column,
|
|
544
|
+
svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
|
|
545
545
|
}
|
|
546
546
|
}
|
|
547
547
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
548
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
549
|
-
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
550
|
-
svmopa_za64_f64_m(4,
|
|
548
|
+
svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
|
|
549
|
+
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
|
|
550
|
+
svmopa_za64_f64_m(4, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
|
|
551
551
|
}
|
|
552
552
|
|
|
553
553
|
// Column tile 4 → ZA5 via MOVA
|
|
554
554
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
555
555
|
for (nk_size_t column = 0; column < tile_dimension; column++) {
|
|
556
556
|
nk_size_t const column_abs = (column_tile_index + 4) * tile_dimension + column;
|
|
557
|
-
if (column_abs <
|
|
557
|
+
if (column_abs < vectors_count) {
|
|
558
558
|
svfloat64_t widened_f64x = svcvt_f64_f32_x(
|
|
559
|
-
|
|
559
|
+
batch_predicate_b64x,
|
|
560
560
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
561
|
-
|
|
561
|
+
a_depth_predicate_b64x,
|
|
562
562
|
(nk_u32_t const
|
|
563
563
|
*)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
|
|
564
|
-
svwrite_hor_za64_f64_m(0, column,
|
|
564
|
+
svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
|
|
565
565
|
}
|
|
566
566
|
}
|
|
567
567
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
568
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
569
|
-
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
570
|
-
svmopa_za64_f64_m(5,
|
|
568
|
+
svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
|
|
569
|
+
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
|
|
570
|
+
svmopa_za64_f64_m(5, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
|
|
571
571
|
}
|
|
572
572
|
|
|
573
573
|
// Column tile 5 → ZA6 via MOVA
|
|
574
574
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
575
575
|
for (nk_size_t column = 0; column < tile_dimension; column++) {
|
|
576
576
|
nk_size_t const column_abs = (column_tile_index + 5) * tile_dimension + column;
|
|
577
|
-
if (column_abs <
|
|
577
|
+
if (column_abs < vectors_count) {
|
|
578
578
|
svfloat64_t widened_f64x = svcvt_f64_f32_x(
|
|
579
|
-
|
|
579
|
+
batch_predicate_b64x,
|
|
580
580
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
581
|
-
|
|
581
|
+
a_depth_predicate_b64x,
|
|
582
582
|
(nk_u32_t const
|
|
583
583
|
*)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
|
|
584
|
-
svwrite_hor_za64_f64_m(0, column,
|
|
584
|
+
svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
|
|
585
585
|
}
|
|
586
586
|
}
|
|
587
587
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
588
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
589
|
-
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
590
|
-
svmopa_za64_f64_m(6,
|
|
588
|
+
svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
|
|
589
|
+
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
|
|
590
|
+
svmopa_za64_f64_m(6, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
|
|
591
591
|
}
|
|
592
592
|
|
|
593
593
|
// Column tile 6 → ZA7 via MOVA
|
|
594
594
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
595
595
|
for (nk_size_t column = 0; column < tile_dimension; column++) {
|
|
596
596
|
nk_size_t const column_abs = (column_tile_index + 6) * tile_dimension + column;
|
|
597
|
-
if (column_abs <
|
|
597
|
+
if (column_abs < vectors_count) {
|
|
598
598
|
svfloat64_t widened_f64x = svcvt_f64_f32_x(
|
|
599
|
-
|
|
599
|
+
batch_predicate_b64x,
|
|
600
600
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
601
|
-
|
|
601
|
+
a_depth_predicate_b64x,
|
|
602
602
|
(nk_u32_t const
|
|
603
603
|
*)&vectors[column_abs * stride_elements + depth_offset + depth_batch_start])));
|
|
604
|
-
svwrite_hor_za64_f64_m(0, column,
|
|
604
|
+
svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
|
|
605
605
|
}
|
|
606
606
|
}
|
|
607
607
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
608
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
609
|
-
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
610
|
-
svmopa_za64_f64_m(7,
|
|
608
|
+
svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
|
|
609
|
+
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 0, step);
|
|
610
|
+
svmopa_za64_f64_m(7, row_predicate_b64x, predicate_all_b64x, a_f64x, b_f64x);
|
|
611
611
|
}
|
|
612
612
|
}
|
|
613
613
|
}
|
|
614
614
|
|
|
615
615
|
// Extract results and store native f64 outputs.
|
|
616
|
-
svbool_t const
|
|
616
|
+
svbool_t const predicate_tile_b64x = svwhilelt_b64_u64(0u, tile_dimension);
|
|
617
617
|
// The 7th tile (index 6) may be partial when it's the last column tile
|
|
618
618
|
nk_size_t const last_fast_col_start = (column_tile_index + 6) * tile_dimension;
|
|
619
|
-
nk_size_t const last_fast_cols = (last_fast_col_start + tile_dimension <=
|
|
619
|
+
nk_size_t const last_fast_cols = (last_fast_col_start + tile_dimension <= vectors_count)
|
|
620
620
|
? tile_dimension
|
|
621
|
-
: (
|
|
622
|
-
svbool_t const
|
|
621
|
+
: (vectors_count - last_fast_col_start);
|
|
622
|
+
svbool_t const last_tile_pred_b64x = svwhilelt_b64_u64(0u, last_fast_cols);
|
|
623
623
|
for (nk_size_t row = 0; row < rows_actual; row++) {
|
|
624
624
|
nk_size_t const row_abs = row_tile_start + row;
|
|
625
625
|
nk_f64_t *result_row = result + row_abs * result_stride_elements;
|
|
626
626
|
|
|
627
|
-
svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
628
|
-
svst1_f64(
|
|
627
|
+
svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 1, row);
|
|
628
|
+
svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 0) * tile_dimension, za_row_f64x);
|
|
629
629
|
|
|
630
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
631
|
-
svst1_f64(
|
|
630
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 2, row);
|
|
631
|
+
svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 1) * tile_dimension, za_row_f64x);
|
|
632
632
|
|
|
633
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
634
|
-
svst1_f64(
|
|
633
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 3, row);
|
|
634
|
+
svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 2) * tile_dimension, za_row_f64x);
|
|
635
635
|
|
|
636
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
637
|
-
svst1_f64(
|
|
636
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 4, row);
|
|
637
|
+
svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 3) * tile_dimension, za_row_f64x);
|
|
638
638
|
|
|
639
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
640
|
-
svst1_f64(
|
|
639
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 5, row);
|
|
640
|
+
svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 4) * tile_dimension, za_row_f64x);
|
|
641
641
|
|
|
642
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
643
|
-
svst1_f64(
|
|
642
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 6, row);
|
|
643
|
+
svst1_f64(predicate_tile_b64x, result_row + (column_tile_index + 5) * tile_dimension, za_row_f64x);
|
|
644
644
|
|
|
645
|
-
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
646
|
-
svst1_f64(
|
|
645
|
+
za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 7, row);
|
|
646
|
+
svst1_f64(last_tile_pred_b64x, result_row + (column_tile_index + 6) * tile_dimension, za_row_f64x);
|
|
647
647
|
}
|
|
648
648
|
}
|
|
649
649
|
|
|
650
650
|
// Remainder: 1 column tile at a time
|
|
651
651
|
for (; column_tile_index < column_tile_count; column_tile_index++) {
|
|
652
652
|
nk_size_t const column_tile_start = column_tile_index * tile_dimension;
|
|
653
|
-
nk_size_t const columns_remaining = (column_tile_start + tile_dimension <=
|
|
653
|
+
nk_size_t const columns_remaining = (column_tile_start + tile_dimension <= vectors_count)
|
|
654
654
|
? tile_dimension
|
|
655
|
-
: (
|
|
656
|
-
svbool_t const
|
|
655
|
+
: (vectors_count - column_tile_start);
|
|
656
|
+
svbool_t const column_predicate_b64x = svwhilelt_b64_u64(0u, columns_remaining);
|
|
657
657
|
|
|
658
658
|
svzero_mask_za(nk_sme_zero_za64_tile_1_);
|
|
659
659
|
|
|
@@ -669,44 +669,43 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64
|
|
|
669
669
|
|
|
670
670
|
if (depth_offset + depth_batch_start >= depth) break;
|
|
671
671
|
|
|
672
|
-
svbool_t const
|
|
673
|
-
svbool_t const
|
|
674
|
-
(uint64_t)depth);
|
|
672
|
+
svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
|
|
673
|
+
svbool_t const a_depth_pred_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start, depth);
|
|
675
674
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
676
675
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_actual; row_in_tile++) {
|
|
677
676
|
nk_size_t const row_abs = row_tile_start + row_in_tile;
|
|
678
677
|
svfloat64_t a_row_widened_f64x = svcvt_f64_f32_x(
|
|
679
|
-
|
|
678
|
+
batch_predicate_b64x,
|
|
680
679
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
681
|
-
|
|
680
|
+
a_depth_pred_b64x, (nk_u32_t const *)&vectors[row_abs * stride_elements + depth_offset +
|
|
682
681
|
depth_batch_start])));
|
|
683
|
-
svwrite_hor_za64_f64_m(0, row_in_tile,
|
|
682
|
+
svwrite_hor_za64_f64_m(0, row_in_tile, batch_predicate_b64x, a_row_widened_f64x);
|
|
684
683
|
}
|
|
685
684
|
|
|
686
685
|
// Save A columns from ZA0 to stack buffer
|
|
687
686
|
for (nk_size_t s = 0; s < batch_size; s++)
|
|
688
|
-
svst1_f64(
|
|
689
|
-
svread_ver_za64_f64_m(svdup_f64(0),
|
|
687
|
+
svst1_f64(predicate_all_b64x, a_buffer[s],
|
|
688
|
+
svread_ver_za64_f64_m(svdup_f64(0), row_predicate_b64x, 0, s));
|
|
690
689
|
|
|
691
690
|
// Load B column tile into ZA0 via MOVA, vertical read + FMOPA into ZA1
|
|
692
691
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
693
692
|
for (nk_size_t column = 0; column < tile_dimension; column++) {
|
|
694
693
|
nk_size_t const column_abs = column_tile_start + column;
|
|
695
|
-
if (column_abs <
|
|
694
|
+
if (column_abs < vectors_count) {
|
|
696
695
|
svfloat64_t widened_f64x = svcvt_f64_f32_x(
|
|
697
|
-
|
|
696
|
+
batch_predicate_b64x,
|
|
698
697
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
699
|
-
|
|
698
|
+
a_depth_pred_b64x, (nk_u32_t const *)&vectors[column_abs * stride_elements +
|
|
700
699
|
depth_offset + depth_batch_start])));
|
|
701
|
-
svwrite_hor_za64_f64_m(0, column,
|
|
700
|
+
svwrite_hor_za64_f64_m(0, column, batch_predicate_b64x, widened_f64x);
|
|
702
701
|
}
|
|
703
702
|
}
|
|
704
703
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
705
704
|
nk_size_t const k_abs = depth_offset + depth_batch_start + step;
|
|
706
705
|
if (k_abs >= depth) break;
|
|
707
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
708
|
-
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
709
|
-
svmopa_za64_f64_m(1,
|
|
706
|
+
svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
|
|
707
|
+
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), column_predicate_b64x, 0, step);
|
|
708
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_b64x, a_f64x, b_f64x);
|
|
710
709
|
}
|
|
711
710
|
}
|
|
712
711
|
}
|
|
@@ -714,25 +713,26 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64
|
|
|
714
713
|
// Store native f64 outputs for the tail column tile.
|
|
715
714
|
for (nk_size_t row = 0; row < rows_actual; row++) {
|
|
716
715
|
nk_size_t const row_abs = row_tile_start + row;
|
|
717
|
-
svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0),
|
|
718
|
-
svst1_f64(
|
|
716
|
+
svfloat64_t za_row_f64x = svread_hor_za64_f64_m(svdup_f64(0), predicate_all_b64x, 1, row);
|
|
717
|
+
svst1_f64(column_predicate_b64x, result + row_abs * result_stride_elements + column_tile_start,
|
|
719
718
|
za_row_f64x);
|
|
720
719
|
}
|
|
721
720
|
}
|
|
722
721
|
}
|
|
723
722
|
}
|
|
724
723
|
|
|
725
|
-
NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t
|
|
726
|
-
nk_size_t
|
|
727
|
-
nk_size_t
|
|
724
|
+
NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
725
|
+
nk_size_t stride_in_bytes, nk_f64_t *result,
|
|
726
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start,
|
|
727
|
+
nk_size_t row_count) {
|
|
728
728
|
|
|
729
|
-
nk_size_t const stride_elements =
|
|
730
|
-
nk_size_t const result_stride_elements =
|
|
731
|
-
nk_dots_symmetric_f32_smef64_streaming_(vectors,
|
|
732
|
-
row_start, row_count);
|
|
729
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
|
|
730
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
731
|
+
nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
732
|
+
result_stride_elements, row_start, row_count);
|
|
733
733
|
}
|
|
734
734
|
|
|
735
|
-
#pragma endregion
|
|
735
|
+
#pragma endregion F32 Floats
|
|
736
736
|
|
|
737
737
|
/*
|
|
738
738
|
* f64 GEMM via 3-way Ozaki splitting using FMOPA with ZA64 tiles.
|
|
@@ -768,7 +768,7 @@ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n
|
|
|
768
768
|
* - f64 input vectors: 8 elements (SVL/64)
|
|
769
769
|
* - FMOPA predicates: b64 (native f64 granularity)
|
|
770
770
|
*/
|
|
771
|
-
#pragma region
|
|
771
|
+
#pragma region F64 Floats
|
|
772
772
|
|
|
773
773
|
/* Mantissa bit masks for 3-way Ozaki splitting of f64 values.
|
|
774
774
|
*
|
|
@@ -783,17 +783,17 @@ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n
|
|
|
783
783
|
*
|
|
784
784
|
* All slices fit in f32 (24-bit significand). Products: max 19+19 = 38 ≤ 53, exact in f64.
|
|
785
785
|
*/
|
|
786
|
-
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void)
|
|
786
|
+
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void) NK_STREAMING_ {
|
|
787
787
|
return 0xFFFFFFFC00000000ULL; // keep top 19 sig bits
|
|
788
788
|
}
|
|
789
|
-
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void)
|
|
789
|
+
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void) NK_STREAMING_ {
|
|
790
790
|
return 0xFFFFFFF000000000ULL; // keep top 17 sig bits
|
|
791
791
|
}
|
|
792
792
|
|
|
793
793
|
/* Split a scalar f64 into 3 non-overlapping Ozaki slices (19+17+17 mantissa bits).
|
|
794
794
|
* Each slice fits in f32. Outputs stored via pointers. */
|
|
795
795
|
NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, nk_f64_t *slice_1,
|
|
796
|
-
nk_f64_t *slice_2)
|
|
796
|
+
nk_f64_t *slice_2) NK_STREAMING_ {
|
|
797
797
|
nk_fui64_t pun;
|
|
798
798
|
pun.f = val;
|
|
799
799
|
pun.u &= nk_f64_smef64_ozaki_mask_19_bits_();
|
|
@@ -806,36 +806,39 @@ NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, n
|
|
|
806
806
|
}
|
|
807
807
|
|
|
808
808
|
__arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64_streaming_(
|
|
809
|
-
nk_f64_t const *vectors, nk_size_t
|
|
809
|
+
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
810
810
|
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
811
811
|
|
|
812
812
|
nk_size_t const tile_dimension = svcntd();
|
|
813
813
|
nk_size_t const depth_steps_per_batch = tile_dimension;
|
|
814
814
|
|
|
815
|
-
svbool_t const
|
|
815
|
+
svbool_t const predicate_all_b64x = svptrue_b64();
|
|
816
816
|
svuint64_t const ozaki_mask_19_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_19_bits_());
|
|
817
817
|
svuint64_t const ozaki_mask_17_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_17_bits_());
|
|
818
818
|
|
|
819
819
|
NK_ALIGN64 nk_f64_t a_buffer[8][8]; // save A columns before reusing ZA0 for B
|
|
820
820
|
|
|
821
821
|
nk_size_t const row_end = row_start + row_count;
|
|
822
|
-
nk_size_t const column_tile_count = nk_size_divide_round_up_(
|
|
822
|
+
nk_size_t const column_tile_count = nk_size_divide_round_up_(vectors_count, tile_dimension);
|
|
823
823
|
|
|
824
824
|
// ZA0.D = staging (A then B), ZA1-3.D = merged Ozaki accumulators (i+j=0,1,2)
|
|
825
|
-
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start <
|
|
825
|
+
for (nk_size_t row_tile_start = row_start; row_tile_start < row_end && row_tile_start < vectors_count;
|
|
826
826
|
row_tile_start += tile_dimension) {
|
|
827
827
|
nk_size_t const rows_remaining = (row_tile_start + tile_dimension <= row_end) ? tile_dimension
|
|
828
828
|
: (row_end - row_tile_start);
|
|
829
|
-
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <=
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
829
|
+
nk_size_t const rows_clamped = (row_tile_start + rows_remaining <= vectors_count)
|
|
830
|
+
? rows_remaining
|
|
831
|
+
: (vectors_count - row_tile_start);
|
|
832
|
+
svbool_t const row_predicate_b64x = svwhilelt_b64_u64(0u, rows_clamped);
|
|
833
|
+
|
|
834
|
+
// Upper triangle: start from this row tile's column
|
|
835
|
+
for (nk_size_t column_tile_index = row_tile_start / tile_dimension; column_tile_index < column_tile_count;
|
|
836
|
+
column_tile_index++) {
|
|
834
837
|
nk_size_t const column_tile_start = column_tile_index * tile_dimension;
|
|
835
|
-
nk_size_t const columns_remaining = (column_tile_start + tile_dimension <=
|
|
838
|
+
nk_size_t const columns_remaining = (column_tile_start + tile_dimension <= vectors_count)
|
|
836
839
|
? tile_dimension
|
|
837
|
-
: (
|
|
838
|
-
svbool_t const
|
|
840
|
+
: (vectors_count - column_tile_start);
|
|
841
|
+
svbool_t const column_predicate_b64x = svwhilelt_b64_u64(0u, columns_remaining);
|
|
839
842
|
|
|
840
843
|
// Zero ZA1-3 (3 merged Ozaki accumulators)
|
|
841
844
|
svzero_mask_za(nk_sme_zero_za64_tiles_1_3_);
|
|
@@ -846,67 +849,67 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64
|
|
|
846
849
|
? depth_batch_start + depth_steps_per_batch
|
|
847
850
|
: depth;
|
|
848
851
|
nk_size_t const batch_size = depth_batch_end - depth_batch_start;
|
|
849
|
-
svbool_t const
|
|
852
|
+
svbool_t const batch_predicate_b64x = svwhilelt_b64_u64(0u, batch_size);
|
|
850
853
|
|
|
851
854
|
// Load A rows into ZA0
|
|
852
855
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
853
856
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_clamped; row_in_tile++) {
|
|
854
857
|
nk_size_t const row_abs = row_tile_start + row_in_tile;
|
|
855
|
-
svld1_hor_za64(0, row_in_tile,
|
|
858
|
+
svld1_hor_za64(0, row_in_tile, batch_predicate_b64x,
|
|
856
859
|
vectors + row_abs * stride_elements + depth_batch_start);
|
|
857
860
|
}
|
|
858
861
|
|
|
859
862
|
// Save A columns to buffer before reusing ZA0 for B
|
|
860
863
|
for (nk_size_t s = 0; s < batch_size; s++)
|
|
861
|
-
svst1_f64(
|
|
862
|
-
svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
864
|
+
svst1_f64(predicate_all_b64x, a_buffer[s],
|
|
865
|
+
svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, s));
|
|
863
866
|
|
|
864
867
|
// Load B columns into ZA0 (reuse)
|
|
865
868
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
866
869
|
for (nk_size_t column = 0; column < tile_dimension; column++) {
|
|
867
870
|
nk_size_t const column_abs = column_tile_start + column;
|
|
868
|
-
if (column_abs <
|
|
869
|
-
svld1_hor_za64(0, column,
|
|
871
|
+
if (column_abs < vectors_count)
|
|
872
|
+
svld1_hor_za64(0, column, batch_predicate_b64x,
|
|
870
873
|
vectors + column_abs * stride_elements + depth_batch_start);
|
|
871
874
|
}
|
|
872
875
|
|
|
873
876
|
// Split both A and B into 3 Ozaki slices, 6 FMOPAs per step
|
|
874
877
|
for (nk_size_t step = 0; step < batch_size; step++) {
|
|
875
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
878
|
+
svfloat64_t a_f64x = svld1_f64(predicate_all_b64x, a_buffer[step]);
|
|
876
879
|
svuint64_t a_bits_u64x = svreinterpret_u64_f64(a_f64x);
|
|
877
880
|
svfloat64_t a_slice_0_f64x = svreinterpret_f64_u64(
|
|
878
|
-
svand_u64_x(
|
|
879
|
-
svfloat64_t residual_a_f64x = svsub_f64_x(
|
|
881
|
+
svand_u64_x(predicate_all_b64x, a_bits_u64x, ozaki_mask_19_u64x));
|
|
882
|
+
svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_b64x, a_f64x, a_slice_0_f64x);
|
|
880
883
|
svuint64_t residual_a_bits_u64x = svreinterpret_u64_f64(residual_a_f64x);
|
|
881
884
|
svfloat64_t a_slice_1_f64x = svreinterpret_f64_u64(
|
|
882
|
-
svand_u64_x(
|
|
883
|
-
svfloat64_t a_slice_2_f64x = svsub_f64_x(
|
|
885
|
+
svand_u64_x(predicate_all_b64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
|
|
886
|
+
svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_b64x, residual_a_f64x, a_slice_1_f64x);
|
|
884
887
|
|
|
885
|
-
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
888
|
+
svfloat64_t b_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), column_predicate_b64x, 0, step);
|
|
886
889
|
svuint64_t b_bits_u64x = svreinterpret_u64_f64(b_f64x);
|
|
887
890
|
svfloat64_t b_slice_0_f64x = svreinterpret_f64_u64(
|
|
888
|
-
svand_u64_x(
|
|
889
|
-
svfloat64_t residual_b_f64x = svsub_f64_x(
|
|
891
|
+
svand_u64_x(predicate_all_b64x, b_bits_u64x, ozaki_mask_19_u64x));
|
|
892
|
+
svfloat64_t residual_b_f64x = svsub_f64_x(predicate_all_b64x, b_f64x, b_slice_0_f64x);
|
|
890
893
|
svuint64_t residual_b_bits_u64x = svreinterpret_u64_f64(residual_b_f64x);
|
|
891
894
|
svfloat64_t b_slice_1_f64x = svreinterpret_f64_u64(
|
|
892
|
-
svand_u64_x(
|
|
893
|
-
svfloat64_t b_slice_2_f64x = svsub_f64_x(
|
|
895
|
+
svand_u64_x(predicate_all_b64x, residual_b_bits_u64x, ozaki_mask_17_u64x));
|
|
896
|
+
svfloat64_t b_slice_2_f64x = svsub_f64_x(predicate_all_b64x, residual_b_f64x, b_slice_1_f64x);
|
|
894
897
|
|
|
895
898
|
// 6 FMOPAs reordered to minimize WAW pipeline stalls on 3 tiles.
|
|
896
899
|
// Same-tile accumulation order preserved (bit-identical output).
|
|
897
900
|
// Tile schedule: ZA3(0), ZA2(1), ZA1(2), ZA3(4), ZA2(5), ZA3(8).
|
|
898
901
|
// 9 cycles vs 15 original (3 unavoidable bubbles with only 3 tiles).
|
|
899
|
-
svmopa_za64_f64_m(3,
|
|
902
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
|
|
900
903
|
b_slice_2_f64x); // ZA3: i+j=2 (1/3)
|
|
901
|
-
svmopa_za64_f64_m(2,
|
|
904
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
|
|
902
905
|
b_slice_1_f64x); // ZA2: i+j=1 (1/2)
|
|
903
|
-
svmopa_za64_f64_m(1,
|
|
906
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
|
|
904
907
|
b_slice_0_f64x); // ZA1: i+j=0
|
|
905
|
-
svmopa_za64_f64_m(3,
|
|
908
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_1_f64x,
|
|
906
909
|
b_slice_1_f64x); // ZA3: i+j=2 (2/3)
|
|
907
|
-
svmopa_za64_f64_m(2,
|
|
910
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_b64x, a_slice_1_f64x,
|
|
908
911
|
b_slice_0_f64x); // ZA2: i+j=1 (2/2)
|
|
909
|
-
svmopa_za64_f64_m(3,
|
|
912
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_2_f64x,
|
|
910
913
|
b_slice_0_f64x); // ZA3: i+j=2 (3/3)
|
|
911
914
|
}
|
|
912
915
|
}
|
|
@@ -914,31 +917,32 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64
|
|
|
914
917
|
// Sum ZA3 + ZA2 + ZA1 (smallest to largest)
|
|
915
918
|
for (nk_size_t row = 0; row < rows_clamped; row++) {
|
|
916
919
|
nk_size_t const row_abs = row_tile_start + row;
|
|
917
|
-
svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
918
|
-
result_f64x = svadd_f64_x(
|
|
919
|
-
svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
920
|
-
result_f64x = svadd_f64_x(
|
|
921
|
-
svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
922
|
-
svst1_f64(
|
|
920
|
+
svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 3, row);
|
|
921
|
+
result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
|
|
922
|
+
svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 2, row));
|
|
923
|
+
result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
|
|
924
|
+
svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 1, row));
|
|
925
|
+
svst1_f64(column_predicate_b64x, result + row_abs * result_stride_elements + column_tile_start,
|
|
923
926
|
result_f64x);
|
|
924
927
|
}
|
|
925
928
|
}
|
|
926
929
|
}
|
|
927
930
|
}
|
|
928
931
|
|
|
929
|
-
NK_PUBLIC void nk_dots_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t
|
|
930
|
-
nk_size_t
|
|
931
|
-
nk_size_t
|
|
932
|
+
NK_PUBLIC void nk_dots_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
933
|
+
nk_size_t stride_in_bytes, nk_f64_t *result,
|
|
934
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start,
|
|
935
|
+
nk_size_t row_count) {
|
|
932
936
|
|
|
933
|
-
nk_size_t const stride_elements =
|
|
934
|
-
nk_size_t const result_stride_elements =
|
|
935
|
-
nk_dots_symmetric_f64_smef64_streaming_(vectors,
|
|
936
|
-
row_start, row_count);
|
|
937
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
|
|
938
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
939
|
+
nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
940
|
+
result_stride_elements, row_start, row_count);
|
|
937
941
|
}
|
|
938
942
|
|
|
939
943
|
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t columns, nk_size_t depth) {
|
|
940
|
-
nk_size_t const tile_dimension =
|
|
941
|
-
nk_size_t const depth_tile_size =
|
|
944
|
+
nk_size_t const tile_dimension = nk_sme_cntd_();
|
|
945
|
+
nk_size_t const depth_tile_size = nk_sme_cntw_();
|
|
942
946
|
nk_size_t const column_tile_count = nk_size_divide_round_up_(columns, tile_dimension);
|
|
943
947
|
nk_size_t const depth_tile_count = nk_size_divide_round_up_(depth, depth_tile_size);
|
|
944
948
|
// Single header + interleaved 3-slice data (3× tile_dimension elements per depth step)
|
|
@@ -948,13 +952,13 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t columns, nk_size_t
|
|
|
948
952
|
return size;
|
|
949
953
|
}
|
|
950
954
|
|
|
951
|
-
NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_size_t depth,
|
|
952
|
-
void *b_packed) {
|
|
955
|
+
NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_size_t depth,
|
|
956
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
953
957
|
|
|
954
|
-
nk_size_t const b_stride_elements =
|
|
958
|
+
nk_size_t const b_stride_elements = b_stride_in_bytes / sizeof(nk_f64_t);
|
|
955
959
|
|
|
956
|
-
nk_size_t const tile_dimension =
|
|
957
|
-
nk_size_t const depth_tile_size =
|
|
960
|
+
nk_size_t const tile_dimension = nk_sme_cntd_();
|
|
961
|
+
nk_size_t const depth_tile_size = nk_sme_cntw_();
|
|
958
962
|
nk_size_t const interleaved_stride = 3 * tile_dimension;
|
|
959
963
|
nk_size_t const interleaved_tile_elements = depth_tile_size * interleaved_stride;
|
|
960
964
|
|
|
@@ -968,7 +972,7 @@ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_
|
|
|
968
972
|
header->depth_tile_count = (nk_u32_t)depth_tile_count;
|
|
969
973
|
header->columns = (nk_u32_t)columns;
|
|
970
974
|
header->depth = (nk_u32_t)depth;
|
|
971
|
-
header->svl_bytes = (nk_u32_t)
|
|
975
|
+
header->svl_bytes = (nk_u32_t)nk_sme_cntb_();
|
|
972
976
|
|
|
973
977
|
nk_f32_t *tiles = (nk_f32_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
974
978
|
|
|
@@ -1009,7 +1013,7 @@ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_
|
|
|
1009
1013
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
1010
1014
|
nk_f64_t *norms_ptr = (nk_f64_t *)((char *)b_packed + header->norms_offset);
|
|
1011
1015
|
for (nk_size_t col = 0; col < columns; col++) {
|
|
1012
|
-
nk_f64_t const *col_data = (nk_f64_t const *)((char const *)b + col *
|
|
1016
|
+
nk_f64_t const *col_data = (nk_f64_t const *)((char const *)b + col * b_stride_in_bytes);
|
|
1013
1017
|
norms_ptr[col] = nk_dots_reduce_sumsq_f64_(col_data, depth);
|
|
1014
1018
|
}
|
|
1015
1019
|
}
|
|
@@ -1032,7 +1036,7 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1032
1036
|
// B tile data pointer (f32, interleaved slices)
|
|
1033
1037
|
nk_f32_t const *b_tiles = (nk_f32_t const *)((char const *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
1034
1038
|
|
|
1035
|
-
svbool_t const
|
|
1039
|
+
svbool_t const predicate_all_b64x = svptrue_b64();
|
|
1036
1040
|
|
|
1037
1041
|
// Mantissa masks for in-register Ozaki splitting (19+17+17 bits)
|
|
1038
1042
|
svuint64_t const ozaki_mask_19_u64x = svdup_u64(nk_f64_smef64_ozaki_mask_19_bits_());
|
|
@@ -1045,7 +1049,7 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1045
1049
|
row_tile_index++) {
|
|
1046
1050
|
nk_size_t const row_start = row_tile_index * tile_dimension;
|
|
1047
1051
|
nk_size_t const rows_remaining = (row_start + tile_dimension <= rows) ? tile_dimension : (rows - row_start);
|
|
1048
|
-
svbool_t const
|
|
1052
|
+
svbool_t const row_predicate_b64x = svwhilelt_b64_u64(0u, rows_remaining);
|
|
1049
1053
|
|
|
1050
1054
|
nk_size_t column_tile_index = 0;
|
|
1051
1055
|
|
|
@@ -1059,8 +1063,8 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1059
1063
|
nk_size_t const columns_remaining_1 = (column_start_1 + tile_dimension <= columns)
|
|
1060
1064
|
? tile_dimension
|
|
1061
1065
|
: (columns - column_start_1);
|
|
1062
|
-
svbool_t const
|
|
1063
|
-
svbool_t const
|
|
1066
|
+
svbool_t const column_predicate_0_b64x = svwhilelt_b64_u64(0u, columns_remaining_0);
|
|
1067
|
+
svbool_t const column_predicate_1_b64x = svwhilelt_b64_u64(0u, columns_remaining_1);
|
|
1064
1068
|
|
|
1065
1069
|
// Zero ZA1-6 (3 accumulators × 2 column tiles)
|
|
1066
1070
|
svzero_mask_za(nk_sme_zero_za64_tiles_1_6_);
|
|
@@ -1081,9 +1085,9 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1081
1085
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
1082
1086
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
|
|
1083
1087
|
nk_size_t const a_row = row_start + row_in_tile;
|
|
1084
|
-
svbool_t const
|
|
1085
|
-
|
|
1086
|
-
svld1_hor_za64(0, row_in_tile,
|
|
1088
|
+
svbool_t const a_depth_predicate_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
|
|
1089
|
+
depth);
|
|
1090
|
+
svld1_hor_za64(0, row_in_tile, a_depth_predicate_b64x,
|
|
1087
1091
|
&a[a_row * a_stride_elements + depth_offset + depth_batch_start]);
|
|
1088
1092
|
}
|
|
1089
1093
|
|
|
@@ -1100,71 +1104,71 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1100
1104
|
if (k_abs >= depth) break;
|
|
1101
1105
|
|
|
1102
1106
|
// Read A column from ZA0 and split into 3 Ozaki slices
|
|
1103
|
-
svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
1107
|
+
svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, step);
|
|
1104
1108
|
svuint64_t a_bits_u64x = svreinterpret_u64_f64(a_f64x);
|
|
1105
1109
|
svfloat64_t a_slice_0_f64x = svreinterpret_f64_u64(
|
|
1106
|
-
svand_u64_x(
|
|
1107
|
-
svfloat64_t residual_a_f64x = svsub_f64_x(
|
|
1110
|
+
svand_u64_x(predicate_all_b64x, a_bits_u64x, ozaki_mask_19_u64x));
|
|
1111
|
+
svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_b64x, a_f64x, a_slice_0_f64x);
|
|
1108
1112
|
svuint64_t residual_a_bits_u64x = svreinterpret_u64_f64(residual_a_f64x);
|
|
1109
1113
|
svfloat64_t a_slice_1_f64x = svreinterpret_f64_u64(
|
|
1110
|
-
svand_u64_x(
|
|
1111
|
-
svfloat64_t a_slice_2_f64x = svsub_f64_x(
|
|
1114
|
+
svand_u64_x(predicate_all_b64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
|
|
1115
|
+
svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_b64x, residual_a_f64x, a_slice_1_f64x);
|
|
1112
1116
|
|
|
1113
1117
|
// Load all 6 B slices upfront (3 per column tile) for pipeline interleaving
|
|
1114
1118
|
nk_size_t const b_tile_offset_0 = b_batch_offset_0 + step * interleaved_stride;
|
|
1115
1119
|
nk_size_t const b_tile_offset_1 = b_batch_offset_1 + step * interleaved_stride;
|
|
1116
1120
|
svfloat64_t b_column_0_slice_0_f64x = svcvt_f64_f32_x(
|
|
1117
|
-
|
|
1121
|
+
predicate_all_b64x,
|
|
1118
1122
|
svreinterpret_f32_u64(
|
|
1119
|
-
svld1uw_u64(
|
|
1123
|
+
svld1uw_u64(predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0))));
|
|
1120
1124
|
svfloat64_t b_column_0_slice_1_f64x = svcvt_f64_f32_x(
|
|
1121
|
-
|
|
1125
|
+
predicate_all_b64x,
|
|
1122
1126
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
1123
|
-
|
|
1127
|
+
predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0 + tile_dimension))));
|
|
1124
1128
|
svfloat64_t b_column_0_slice_2_f64x = svcvt_f64_f32_x(
|
|
1125
|
-
|
|
1126
|
-
|
|
1129
|
+
predicate_all_b64x, svreinterpret_f32_u64(svld1uw_u64(
|
|
1130
|
+
predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_0 +
|
|
1127
1131
|
2 * tile_dimension))));
|
|
1128
1132
|
svfloat64_t b_column_1_slice_0_f64x = svcvt_f64_f32_x(
|
|
1129
|
-
|
|
1133
|
+
predicate_all_b64x,
|
|
1130
1134
|
svreinterpret_f32_u64(
|
|
1131
|
-
svld1uw_u64(
|
|
1135
|
+
svld1uw_u64(predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1))));
|
|
1132
1136
|
svfloat64_t b_column_1_slice_1_f64x = svcvt_f64_f32_x(
|
|
1133
|
-
|
|
1137
|
+
predicate_all_b64x,
|
|
1134
1138
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
1135
|
-
|
|
1139
|
+
predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1 + tile_dimension))));
|
|
1136
1140
|
svfloat64_t b_column_1_slice_2_f64x = svcvt_f64_f32_x(
|
|
1137
|
-
|
|
1138
|
-
|
|
1141
|
+
predicate_all_b64x, svreinterpret_f32_u64(svld1uw_u64(
|
|
1142
|
+
predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset_1 +
|
|
1139
1143
|
2 * tile_dimension))));
|
|
1140
1144
|
|
|
1141
1145
|
// 12 FMOPAs interleaved across 6 tiles to eliminate WAW pipeline stalls.
|
|
1142
1146
|
// Same-tile accumulation order preserved (bit-identical output).
|
|
1143
1147
|
// Tile gaps: ZA3 at 0,6,10 (6,4); ZA6 at 1,7,11 (6,4); ZA2 at 4,8 (4);
|
|
1144
1148
|
// ZA5 at 5,9 (4); ZA1 at 2; ZA4 at 3. All gaps >= 4-cycle latency.
|
|
1145
|
-
svmopa_za64_f64_m(3,
|
|
1149
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_0_b64x, a_slice_0_f64x,
|
|
1146
1150
|
b_column_0_slice_2_f64x); // ZA3: i+j=2 (1/3)
|
|
1147
|
-
svmopa_za64_f64_m(6,
|
|
1151
|
+
svmopa_za64_f64_m(6, row_predicate_b64x, column_predicate_1_b64x, a_slice_0_f64x,
|
|
1148
1152
|
b_column_1_slice_2_f64x); // ZA6: i+j=2 (1/3)
|
|
1149
|
-
svmopa_za64_f64_m(1,
|
|
1153
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_0_b64x, a_slice_0_f64x,
|
|
1150
1154
|
b_column_0_slice_0_f64x); // ZA1: i+j=0
|
|
1151
|
-
svmopa_za64_f64_m(4,
|
|
1155
|
+
svmopa_za64_f64_m(4, row_predicate_b64x, column_predicate_1_b64x, a_slice_0_f64x,
|
|
1152
1156
|
b_column_1_slice_0_f64x); // ZA4: i+j=0
|
|
1153
|
-
svmopa_za64_f64_m(2,
|
|
1157
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_0_b64x, a_slice_0_f64x,
|
|
1154
1158
|
b_column_0_slice_1_f64x); // ZA2: i+j=1 (1/2)
|
|
1155
|
-
svmopa_za64_f64_m(5,
|
|
1159
|
+
svmopa_za64_f64_m(5, row_predicate_b64x, column_predicate_1_b64x, a_slice_0_f64x,
|
|
1156
1160
|
b_column_1_slice_1_f64x); // ZA5: i+j=1 (1/2)
|
|
1157
|
-
svmopa_za64_f64_m(3,
|
|
1161
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_0_b64x, a_slice_1_f64x,
|
|
1158
1162
|
b_column_0_slice_1_f64x); // ZA3: i+j=2 (2/3)
|
|
1159
|
-
svmopa_za64_f64_m(6,
|
|
1163
|
+
svmopa_za64_f64_m(6, row_predicate_b64x, column_predicate_1_b64x, a_slice_1_f64x,
|
|
1160
1164
|
b_column_1_slice_1_f64x); // ZA6: i+j=2 (2/3)
|
|
1161
|
-
svmopa_za64_f64_m(2,
|
|
1165
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_0_b64x, a_slice_1_f64x,
|
|
1162
1166
|
b_column_0_slice_0_f64x); // ZA2: i+j=1 (2/2)
|
|
1163
|
-
svmopa_za64_f64_m(5,
|
|
1167
|
+
svmopa_za64_f64_m(5, row_predicate_b64x, column_predicate_1_b64x, a_slice_1_f64x,
|
|
1164
1168
|
b_column_1_slice_0_f64x); // ZA5: i+j=1 (2/2)
|
|
1165
|
-
svmopa_za64_f64_m(3,
|
|
1169
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_0_b64x, a_slice_2_f64x,
|
|
1166
1170
|
b_column_0_slice_0_f64x); // ZA3: i+j=2 (3/3)
|
|
1167
|
-
svmopa_za64_f64_m(6,
|
|
1171
|
+
svmopa_za64_f64_m(6, row_predicate_b64x, column_predicate_1_b64x, a_slice_2_f64x,
|
|
1168
1172
|
b_column_1_slice_0_f64x); // ZA6: i+j=2 (3/3)
|
|
1169
1173
|
}
|
|
1170
1174
|
}
|
|
@@ -1173,23 +1177,23 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1173
1177
|
// Simple summation for col tile 0: ZA3 + ZA2 + ZA1 (smallest to largest)
|
|
1174
1178
|
for (nk_size_t row = 0; row < rows_remaining; row++) {
|
|
1175
1179
|
nk_f64_t *c_row = c + (row_start + row) * c_stride_elements + column_start_0;
|
|
1176
|
-
svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
1177
|
-
result_f64x = svadd_f64_x(
|
|
1178
|
-
svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
1179
|
-
result_f64x = svadd_f64_x(
|
|
1180
|
-
svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
1181
|
-
svst1_f64(
|
|
1180
|
+
svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 3, row);
|
|
1181
|
+
result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
|
|
1182
|
+
svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 2, row));
|
|
1183
|
+
result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
|
|
1184
|
+
svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 1, row));
|
|
1185
|
+
svst1_f64(column_predicate_0_b64x, c_row, result_f64x);
|
|
1182
1186
|
}
|
|
1183
1187
|
|
|
1184
1188
|
// Simple summation for col tile 1: ZA6 + ZA5 + ZA4 (smallest to largest)
|
|
1185
1189
|
for (nk_size_t row = 0; row < rows_remaining; row++) {
|
|
1186
1190
|
nk_f64_t *c_row = c + (row_start + row) * c_stride_elements + column_start_1;
|
|
1187
|
-
svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
1188
|
-
result_f64x = svadd_f64_x(
|
|
1189
|
-
svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
1190
|
-
result_f64x = svadd_f64_x(
|
|
1191
|
-
svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
1192
|
-
svst1_f64(
|
|
1191
|
+
svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 6, row);
|
|
1192
|
+
result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
|
|
1193
|
+
svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 5, row));
|
|
1194
|
+
result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
|
|
1195
|
+
svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 4, row));
|
|
1196
|
+
svst1_f64(column_predicate_1_b64x, c_row, result_f64x);
|
|
1193
1197
|
}
|
|
1194
1198
|
}
|
|
1195
1199
|
|
|
@@ -1198,7 +1202,7 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1198
1202
|
nk_size_t const column_start = column_tile_index * tile_dimension;
|
|
1199
1203
|
nk_size_t const columns_remaining = (column_start + tile_dimension <= columns) ? tile_dimension
|
|
1200
1204
|
: (columns - column_start);
|
|
1201
|
-
svbool_t const
|
|
1205
|
+
svbool_t const column_predicate_b64x = svwhilelt_b64_u64(0u, columns_remaining);
|
|
1202
1206
|
|
|
1203
1207
|
// Zero ZA1-3 (3 merged accumulators)
|
|
1204
1208
|
svzero_mask_za(nk_sme_zero_za64_tiles_1_3_);
|
|
@@ -1219,9 +1223,9 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1219
1223
|
svzero_mask_za(nk_sme_zero_za64_tile_0_);
|
|
1220
1224
|
for (nk_size_t row_in_tile = 0; row_in_tile < rows_remaining; row_in_tile++) {
|
|
1221
1225
|
nk_size_t const a_row = row_start + row_in_tile;
|
|
1222
|
-
svbool_t const
|
|
1223
|
-
|
|
1224
|
-
svld1_hor_za64(0, row_in_tile,
|
|
1226
|
+
svbool_t const a_depth_predicate_b64x = svwhilelt_b64_u64(depth_offset + depth_batch_start,
|
|
1227
|
+
depth);
|
|
1228
|
+
svld1_hor_za64(0, row_in_tile, a_depth_predicate_b64x,
|
|
1225
1229
|
&a[a_row * a_stride_elements + depth_offset + depth_batch_start]);
|
|
1226
1230
|
}
|
|
1227
1231
|
|
|
@@ -1234,45 +1238,45 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1234
1238
|
if (k_abs >= depth) break;
|
|
1235
1239
|
|
|
1236
1240
|
// Read A column from ZA0 and split into 3 Ozaki slices
|
|
1237
|
-
svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0),
|
|
1241
|
+
svfloat64_t a_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 0, step);
|
|
1238
1242
|
svuint64_t a_bits_u64x = svreinterpret_u64_f64(a_f64x);
|
|
1239
1243
|
svfloat64_t a_slice_0_f64x = svreinterpret_f64_u64(
|
|
1240
|
-
svand_u64_x(
|
|
1241
|
-
svfloat64_t residual_a_f64x = svsub_f64_x(
|
|
1244
|
+
svand_u64_x(predicate_all_b64x, a_bits_u64x, ozaki_mask_19_u64x));
|
|
1245
|
+
svfloat64_t residual_a_f64x = svsub_f64_x(predicate_all_b64x, a_f64x, a_slice_0_f64x);
|
|
1242
1246
|
svuint64_t residual_a_bits_u64x = svreinterpret_u64_f64(residual_a_f64x);
|
|
1243
1247
|
svfloat64_t a_slice_1_f64x = svreinterpret_f64_u64(
|
|
1244
|
-
svand_u64_x(
|
|
1245
|
-
svfloat64_t a_slice_2_f64x = svsub_f64_x(
|
|
1248
|
+
svand_u64_x(predicate_all_b64x, residual_a_bits_u64x, ozaki_mask_17_u64x));
|
|
1249
|
+
svfloat64_t a_slice_2_f64x = svsub_f64_x(predicate_all_b64x, residual_a_f64x, a_slice_1_f64x);
|
|
1246
1250
|
|
|
1247
1251
|
// Load 3 B slices (contiguous in interleaved layout)
|
|
1248
1252
|
nk_size_t const b_tile_offset = b_batch_offset + step * interleaved_stride;
|
|
1249
1253
|
svfloat64_t b_slice_0_f64x = svcvt_f64_f32_x(
|
|
1250
|
-
|
|
1251
|
-
|
|
1254
|
+
predicate_all_b64x, svreinterpret_f32_u64(svld1uw_u64(
|
|
1255
|
+
predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset))));
|
|
1252
1256
|
svfloat64_t b_slice_1_f64x = svcvt_f64_f32_x(
|
|
1253
|
-
|
|
1257
|
+
predicate_all_b64x,
|
|
1254
1258
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
1255
|
-
|
|
1259
|
+
predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset + tile_dimension))));
|
|
1256
1260
|
svfloat64_t b_slice_2_f64x = svcvt_f64_f32_x(
|
|
1257
|
-
|
|
1261
|
+
predicate_all_b64x,
|
|
1258
1262
|
svreinterpret_f32_u64(svld1uw_u64(
|
|
1259
|
-
|
|
1263
|
+
predicate_all_b64x, (nk_u32_t const *)(b_tiles + b_tile_offset + 2 * tile_dimension))));
|
|
1260
1264
|
|
|
1261
1265
|
// 6 FMOPAs reordered to minimize WAW pipeline stalls on 3 tiles.
|
|
1262
1266
|
// Same-tile accumulation order preserved (bit-identical output).
|
|
1263
1267
|
// Tile schedule: ZA3(0), ZA2(1), ZA1(2), ZA3(4), ZA2(5), ZA3(8).
|
|
1264
1268
|
// 9 cycles vs 15 original (3 unavoidable bubbles with only 3 tiles).
|
|
1265
|
-
svmopa_za64_f64_m(3,
|
|
1269
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
|
|
1266
1270
|
b_slice_2_f64x); // ZA3: i+j=2 (1/3)
|
|
1267
|
-
svmopa_za64_f64_m(2,
|
|
1271
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
|
|
1268
1272
|
b_slice_1_f64x); // ZA2: i+j=1 (1/2)
|
|
1269
|
-
svmopa_za64_f64_m(1,
|
|
1273
|
+
svmopa_za64_f64_m(1, row_predicate_b64x, column_predicate_b64x, a_slice_0_f64x,
|
|
1270
1274
|
b_slice_0_f64x); // ZA1: i+j=0
|
|
1271
|
-
svmopa_za64_f64_m(3,
|
|
1275
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_1_f64x,
|
|
1272
1276
|
b_slice_1_f64x); // ZA3: i+j=2 (2/3)
|
|
1273
|
-
svmopa_za64_f64_m(2,
|
|
1277
|
+
svmopa_za64_f64_m(2, row_predicate_b64x, column_predicate_b64x, a_slice_1_f64x,
|
|
1274
1278
|
b_slice_0_f64x); // ZA2: i+j=1 (2/2)
|
|
1275
|
-
svmopa_za64_f64_m(3,
|
|
1279
|
+
svmopa_za64_f64_m(3, row_predicate_b64x, column_predicate_b64x, a_slice_2_f64x,
|
|
1276
1280
|
b_slice_0_f64x); // ZA3: i+j=2 (3/3)
|
|
1277
1281
|
}
|
|
1278
1282
|
}
|
|
@@ -1281,27 +1285,28 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1281
1285
|
// Simple summation: ZA3 + ZA2 + ZA1 (smallest to largest)
|
|
1282
1286
|
for (nk_size_t row = 0; row < rows_remaining; row++) {
|
|
1283
1287
|
nk_f64_t *c_row = c + (row_start + row) * c_stride_elements + column_start;
|
|
1284
|
-
svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
1285
|
-
result_f64x = svadd_f64_x(
|
|
1286
|
-
svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
1287
|
-
result_f64x = svadd_f64_x(
|
|
1288
|
-
svread_hor_za64_f64_m(svdup_f64(0.0),
|
|
1289
|
-
svst1_f64(
|
|
1288
|
+
svfloat64_t result_f64x = svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 3, row);
|
|
1289
|
+
result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
|
|
1290
|
+
svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 2, row));
|
|
1291
|
+
result_f64x = svadd_f64_x(predicate_all_b64x, result_f64x,
|
|
1292
|
+
svread_hor_za64_f64_m(svdup_f64(0.0), predicate_all_b64x, 1, row));
|
|
1293
|
+
svst1_f64(column_predicate_b64x, c_row, result_f64x);
|
|
1290
1294
|
}
|
|
1291
1295
|
}
|
|
1292
1296
|
}
|
|
1293
1297
|
}
|
|
1294
1298
|
|
|
1295
1299
|
NK_PUBLIC void nk_dots_packed_f64_smef64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
|
|
1296
|
-
nk_size_t columns, nk_size_t depth, nk_size_t
|
|
1300
|
+
nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
|
|
1301
|
+
nk_size_t c_stride_in_bytes) {
|
|
1297
1302
|
|
|
1298
|
-
nk_size_t const a_stride_elements =
|
|
1299
|
-
nk_size_t const c_stride_elements =
|
|
1303
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f64_t);
|
|
1304
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
|
|
1300
1305
|
|
|
1301
1306
|
nk_dots_packed_f64_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1302
1307
|
}
|
|
1303
1308
|
|
|
1304
|
-
#pragma endregion
|
|
1309
|
+
#pragma endregion F64 Floats
|
|
1305
1310
|
|
|
1306
1311
|
#if defined(__clang__)
|
|
1307
1312
|
#pragma clang attribute pop
|