numkong 7.4.5 → 7.5.0
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects those changes as they appear in the respective public registries.
- package/README.md +1 -0
- package/binding.gyp +81 -5
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/graniteamx.h +733 -0
- package/include/numkong/dots/serial.h +11 -4
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +29 -3
- package/include/numkong/each/serial.h +22 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +1 -1
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/sme.h +34 -33
- package/include/numkong/mesh/serial.h +22 -0
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatials/graniteamx.h +173 -0
- package/include/numkong/spatials/serial.h +22 -0
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +37 -4
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +56 -12
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -63,7 +63,7 @@
|
|
|
63
63
|
|
|
64
64
|
#include "numkong/types.h"
|
|
65
65
|
#include "numkong/cast/serial.h" // `nk_e4m3_to_f16_serial`, `nk_e5m2_to_f16_serial`
|
|
66
|
-
#include "numkong/dots/serial.h" // `nk_dots_reduce_sumsq_f16_`, `nk_dots_reduce_sumsq_i8_
|
|
66
|
+
#include "numkong/dots/serial.h" // `nk_dots_reduce_sumsq_f16_`, `nk_dots_reduce_sumsq_i8_`
|
|
67
67
|
|
|
68
68
|
#if defined(__cplusplus)
|
|
69
69
|
extern "C" {
|
|
@@ -146,9 +146,9 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sme(nk_size_t columns, nk_size_t de
|
|
|
146
146
|
*
|
|
147
147
|
* Replaces the scalar scatter loop with hardware-accelerated tile transpose.
|
|
148
148
|
*/
|
|
149
|
-
|
|
150
|
-
void const *b, nk_size_t columns, nk_size_t depth,
|
|
151
|
-
nk_size_t b_stride_bytes, void *tiles_ptr) {
|
|
149
|
+
__arm_new("za") static void nk_dots_pack_b16_sme_streaming_( //
|
|
150
|
+
void const *b, nk_size_t columns, nk_size_t depth, //
|
|
151
|
+
nk_size_t b_stride_bytes, void *tiles_ptr) NK_STREAMING_ {
|
|
152
152
|
|
|
153
153
|
nk_size_t const expansion = 2;
|
|
154
154
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -200,9 +200,8 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_pack_b16_sme_streami
|
|
|
200
200
|
* required by SMOPA/UMOPA: each i32 word in the output vector holds four 8-bit values
|
|
201
201
|
* from the same column but adjacent depth positions.
|
|
202
202
|
*/
|
|
203
|
-
|
|
204
|
-
void const *b, nk_size_t columns, nk_size_t depth,
|
|
205
|
-
nk_size_t b_stride_bytes, void *tiles_ptr) {
|
|
203
|
+
__arm_new("za") static void nk_dots_pack_b8_sme_streaming_( //
|
|
204
|
+
void const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_bytes, void *tiles_ptr) NK_STREAMING_ {
|
|
206
205
|
|
|
207
206
|
nk_size_t const expansion = 4;
|
|
208
207
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -267,7 +266,9 @@ NK_PUBLIC void nk_dots_pack_f16_sme( //
|
|
|
267
266
|
|
|
268
267
|
nk_f16_t *tiles_ptr = (nk_f16_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
269
268
|
|
|
269
|
+
nk_sme_start_streaming_();
|
|
270
270
|
nk_dots_pack_b16_sme_streaming_(b, columns, depth, b_stride_in_bytes, tiles_ptr);
|
|
271
|
+
nk_sme_stop_streaming_();
|
|
271
272
|
|
|
272
273
|
nk_size_t const data_size = total_vectors * vector_elements * sizeof(nk_f16_t);
|
|
273
274
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
@@ -299,7 +300,9 @@ NK_PUBLIC void nk_dots_pack_bf16_sme( //
|
|
|
299
300
|
|
|
300
301
|
nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
301
302
|
|
|
303
|
+
nk_sme_start_streaming_();
|
|
302
304
|
nk_dots_pack_b16_sme_streaming_(b, columns, depth, b_stride_in_bytes, tiles_ptr);
|
|
305
|
+
nk_sme_stop_streaming_();
|
|
303
306
|
|
|
304
307
|
nk_size_t const data_size = total_vectors * vector_elements * sizeof(nk_bf16_t);
|
|
305
308
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
@@ -321,10 +324,9 @@ NK_PUBLIC void nk_dots_pack_bf16_sme( //
|
|
|
321
324
|
* - Zm[2*d+k] = B[column_start+d, depth_base+k] (pre-packed interleaved)
|
|
322
325
|
* - Loop over depth in steps of 2 (expansion factor)
|
|
323
326
|
*/
|
|
324
|
-
|
|
325
|
-
nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
326
|
-
nk_size_t
|
|
327
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
327
|
+
__arm_new("za") static void nk_dots_packed_f16_sme_streaming_( //
|
|
328
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
329
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
328
330
|
|
|
329
331
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
330
332
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -469,10 +471,9 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f16_sme_strea
|
|
|
469
471
|
* `bf16` → `f32` GEMM core kernel using SME outer products.
|
|
470
472
|
* Same interleaved algorithm as f16 kernel, using BFMOPA bf16 → f32.
|
|
471
473
|
*/
|
|
472
|
-
|
|
473
|
-
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
474
|
-
nk_size_t
|
|
475
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
474
|
+
__arm_new("za") static void nk_dots_packed_bf16_sme_streaming_( //
|
|
475
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
476
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
476
477
|
|
|
477
478
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
478
479
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -614,7 +615,9 @@ NK_PUBLIC void nk_dots_packed_f16_sme( //
|
|
|
614
615
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
|
|
615
616
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
616
617
|
|
|
618
|
+
nk_sme_start_streaming_();
|
|
617
619
|
nk_dots_packed_f16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
620
|
+
nk_sme_stop_streaming_();
|
|
618
621
|
}
|
|
619
622
|
|
|
620
623
|
NK_PUBLIC void nk_dots_packed_bf16_sme( //
|
|
@@ -625,7 +628,9 @@ NK_PUBLIC void nk_dots_packed_bf16_sme( //
|
|
|
625
628
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
|
|
626
629
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
627
630
|
|
|
631
|
+
nk_sme_start_streaming_();
|
|
628
632
|
nk_dots_packed_bf16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
633
|
+
nk_sme_stop_streaming_();
|
|
629
634
|
}
|
|
630
635
|
|
|
631
636
|
/**
|
|
@@ -634,9 +639,9 @@ NK_PUBLIC void nk_dots_packed_bf16_sme( //
|
|
|
634
639
|
* pre-reads A columns into Z registers, then reloads ZA0 with B data
|
|
635
640
|
* per column tile. Eliminates all scalar B-packing loops.
|
|
636
641
|
*/
|
|
637
|
-
|
|
642
|
+
__arm_new("za") static void nk_dots_symmetric_f16_sme_streaming_( //
|
|
638
643
|
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
639
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
644
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
640
645
|
|
|
641
646
|
nk_size_t const expansion = 2;
|
|
642
647
|
nk_size_t const tile_dimension = svcntw(); // 16
|
|
@@ -827,19 +832,21 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f16_sme_st
|
|
|
827
832
|
}
|
|
828
833
|
}
|
|
829
834
|
|
|
830
|
-
NK_PUBLIC void nk_dots_symmetric_f16_sme(
|
|
831
|
-
|
|
832
|
-
|
|
835
|
+
NK_PUBLIC void nk_dots_symmetric_f16_sme( //
|
|
836
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
837
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
833
838
|
|
|
834
839
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
835
840
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
841
|
+
nk_sme_start_streaming_();
|
|
836
842
|
nk_dots_symmetric_f16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
|
|
837
843
|
row_start, row_count);
|
|
844
|
+
nk_sme_stop_streaming_();
|
|
838
845
|
}
|
|
839
846
|
|
|
840
|
-
|
|
847
|
+
__arm_new("za") static void nk_dots_symmetric_bf16_sme_streaming_( //
|
|
841
848
|
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
842
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
849
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
843
850
|
|
|
844
851
|
nk_size_t const expansion = 2;
|
|
845
852
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -1020,14 +1027,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_bf16_sme_s
|
|
|
1020
1027
|
}
|
|
1021
1028
|
}
|
|
1022
1029
|
|
|
1023
|
-
NK_PUBLIC void nk_dots_symmetric_bf16_sme(
|
|
1024
|
-
|
|
1025
|
-
|
|
1030
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_sme( //
|
|
1031
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1032
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1026
1033
|
|
|
1027
1034
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
1028
1035
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1036
|
+
nk_sme_start_streaming_();
|
|
1029
1037
|
nk_dots_symmetric_bf16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1030
1038
|
result_stride_elements, row_start, row_count);
|
|
1039
|
+
nk_sme_stop_streaming_();
|
|
1031
1040
|
}
|
|
1032
1041
|
|
|
1033
1042
|
#pragma endregion F16 Floats
|
|
@@ -1061,9 +1070,8 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sme(nk_size_t columns, nk_size_t dept
|
|
|
1061
1070
|
return size;
|
|
1062
1071
|
}
|
|
1063
1072
|
|
|
1064
|
-
NK_PUBLIC void nk_dots_pack_i8_sme(
|
|
1065
|
-
nk_i8_t const *b, nk_size_t columns, nk_size_t depth,
|
|
1066
|
-
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1073
|
+
NK_PUBLIC void nk_dots_pack_i8_sme( //
|
|
1074
|
+
nk_i8_t const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1067
1075
|
|
|
1068
1076
|
nk_size_t const expansion = 4;
|
|
1069
1077
|
nk_size_t const tile_dimension = nk_sme_cntw_();
|
|
@@ -1082,7 +1090,9 @@ NK_PUBLIC void nk_dots_pack_i8_sme( //
|
|
|
1082
1090
|
|
|
1083
1091
|
nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
1084
1092
|
|
|
1093
|
+
nk_sme_start_streaming_();
|
|
1085
1094
|
nk_dots_pack_b8_sme_streaming_(b, columns, depth, b_stride_in_bytes, tiles_ptr);
|
|
1095
|
+
nk_sme_stop_streaming_();
|
|
1086
1096
|
|
|
1087
1097
|
nk_size_t const data_size = total_vectors * vector_elements * sizeof(nk_i8_t);
|
|
1088
1098
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
@@ -1104,10 +1114,9 @@ NK_PUBLIC void nk_dots_pack_i8_sme( //
|
|
|
1104
1114
|
* - Zm[4*d+k] = B[column_start+d, depth_base+k] (pre-packed interleaved)
|
|
1105
1115
|
* - Loop over depth in steps of 4 (expansion factor)
|
|
1106
1116
|
*/
|
|
1107
|
-
|
|
1108
|
-
nk_i8_t const *a, void const *b_packed, nk_i32_t *c,
|
|
1109
|
-
nk_size_t
|
|
1110
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1117
|
+
__arm_new("za") static void nk_dots_packed_i8_sme_streaming_( //
|
|
1118
|
+
nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1119
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1111
1120
|
|
|
1112
1121
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1113
1122
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -1244,13 +1253,14 @@ NK_PUBLIC void nk_dots_packed_i8_sme( //
|
|
|
1244
1253
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
|
|
1245
1254
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_i32_t);
|
|
1246
1255
|
|
|
1256
|
+
nk_sme_start_streaming_();
|
|
1247
1257
|
nk_dots_packed_i8_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1258
|
+
nk_sme_stop_streaming_();
|
|
1248
1259
|
}
|
|
1249
1260
|
|
|
1250
|
-
|
|
1251
|
-
__arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_i8_sme_streaming_(
|
|
1261
|
+
__arm_new("za") static void nk_dots_symmetric_i8_sme_streaming_( //
|
|
1252
1262
|
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_i32_t *result,
|
|
1253
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1263
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1254
1264
|
|
|
1255
1265
|
nk_size_t const expansion = 4;
|
|
1256
1266
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -1417,14 +1427,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_i8_sme_str
|
|
|
1417
1427
|
}
|
|
1418
1428
|
}
|
|
1419
1429
|
|
|
1420
|
-
NK_PUBLIC void nk_dots_symmetric_i8_sme(
|
|
1421
|
-
|
|
1422
|
-
|
|
1430
|
+
NK_PUBLIC void nk_dots_symmetric_i8_sme( //
|
|
1431
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_i32_t *result,
|
|
1432
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1423
1433
|
|
|
1424
1434
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
|
|
1425
1435
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_i32_t);
|
|
1436
|
+
nk_sme_start_streaming_();
|
|
1426
1437
|
nk_dots_symmetric_i8_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
|
|
1427
1438
|
row_start, row_count);
|
|
1439
|
+
nk_sme_stop_streaming_();
|
|
1428
1440
|
}
|
|
1429
1441
|
|
|
1430
1442
|
#pragma endregion I8 Integers
|
|
@@ -1519,10 +1531,9 @@ NK_PUBLIC svfloat16_t nk_e5m2x_to_f16x_ssve_(svbool_t predicate_b16x, svuint8_t
|
|
|
1519
1531
|
* Fused `e4m3` × `e4m3` → `f32` GEMM kernel using interleaved FMOPA.
|
|
1520
1532
|
* Converts `e4m3` → `f16` on-the-fly for A, B is pre-converted during packing.
|
|
1521
1533
|
*/
|
|
1522
|
-
|
|
1523
|
-
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1524
|
-
nk_size_t
|
|
1525
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1534
|
+
__arm_new("za") static void nk_dots_packed_e4m3_sme_streaming_( //
|
|
1535
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1536
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1526
1537
|
|
|
1527
1538
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1528
1539
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -1667,8 +1678,8 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sme(nk_size_t columns, nk_size_t de
|
|
|
1667
1678
|
}
|
|
1668
1679
|
|
|
1669
1680
|
/** @brief Streaming e4m3 → f16 pack using ZA tile transpose. */
|
|
1670
|
-
|
|
1671
|
-
void const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_bytes, void *tiles_ptr) {
|
|
1681
|
+
__arm_new("za") static void nk_dots_pack_e4m3_to_b16_sme_streaming_( //
|
|
1682
|
+
void const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_bytes, void *tiles_ptr) NK_STREAMING_ {
|
|
1672
1683
|
|
|
1673
1684
|
nk_size_t const expansion = 2;
|
|
1674
1685
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -1712,8 +1723,8 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_pack_e4m3_to_b16_sme
|
|
|
1712
1723
|
}
|
|
1713
1724
|
|
|
1714
1725
|
/** @brief Streaming e5m2 → f16 pack using ZA tile transpose. */
|
|
1715
|
-
|
|
1716
|
-
void const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_bytes, void *tiles_ptr) {
|
|
1726
|
+
__arm_new("za") static void nk_dots_pack_e5m2_to_b16_sme_streaming_( //
|
|
1727
|
+
void const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_bytes, void *tiles_ptr) NK_STREAMING_ {
|
|
1717
1728
|
|
|
1718
1729
|
nk_size_t const expansion = 2;
|
|
1719
1730
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -1777,7 +1788,9 @@ NK_PUBLIC void nk_dots_pack_e4m3_sme( //
|
|
|
1777
1788
|
|
|
1778
1789
|
nk_f16_t *tiles_ptr = (nk_f16_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
1779
1790
|
|
|
1791
|
+
nk_sme_start_streaming_();
|
|
1780
1792
|
nk_dots_pack_e4m3_to_b16_sme_streaming_(b, columns, depth, b_stride_in_bytes, tiles_ptr);
|
|
1793
|
+
nk_sme_stop_streaming_();
|
|
1781
1794
|
|
|
1782
1795
|
nk_size_t const data_size = total_vectors * vector_elements * sizeof(nk_f16_t);
|
|
1783
1796
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
@@ -1796,7 +1809,9 @@ NK_PUBLIC void nk_dots_packed_e4m3_sme( //
|
|
|
1796
1809
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
|
|
1797
1810
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1798
1811
|
|
|
1812
|
+
nk_sme_start_streaming_();
|
|
1799
1813
|
nk_dots_packed_e4m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1814
|
+
nk_sme_stop_streaming_();
|
|
1800
1815
|
}
|
|
1801
1816
|
|
|
1802
1817
|
/**
|
|
@@ -1805,9 +1820,9 @@ NK_PUBLIC void nk_dots_packed_e4m3_sme( //
|
|
|
1805
1820
|
* Pre-reads A columns into Z registers, then reloads ZA0 with converted B data
|
|
1806
1821
|
* per column tile. Eliminates all scalar B-packing loops.
|
|
1807
1822
|
*/
|
|
1808
|
-
|
|
1823
|
+
__arm_new("za") static void nk_dots_symmetric_e4m3_sme_streaming_( //
|
|
1809
1824
|
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1810
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
1825
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1811
1826
|
|
|
1812
1827
|
nk_size_t const expansion = 2;
|
|
1813
1828
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -2005,14 +2020,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_e4m3_sme_s
|
|
|
2005
2020
|
}
|
|
2006
2021
|
}
|
|
2007
2022
|
|
|
2008
|
-
NK_PUBLIC void nk_dots_symmetric_e4m3_sme(
|
|
2009
|
-
|
|
2010
|
-
|
|
2023
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_sme( //
|
|
2024
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
2025
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
2011
2026
|
|
|
2012
2027
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
|
|
2013
2028
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
2029
|
+
nk_sme_start_streaming_();
|
|
2014
2030
|
nk_dots_symmetric_e4m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
2015
2031
|
result_stride_elements, row_start, row_count);
|
|
2032
|
+
nk_sme_stop_streaming_();
|
|
2016
2033
|
}
|
|
2017
2034
|
|
|
2018
2035
|
#pragma endregion E4M3 Floats
|
|
@@ -2031,10 +2048,9 @@ NK_PUBLIC void nk_dots_symmetric_e4m3_sme(nk_e4m3_t const *vectors, nk_size_t ve
|
|
|
2031
2048
|
* Fused `e5m2` × `e5m2` → `f32` GEMM kernel using interleaved FMOPA.
|
|
2032
2049
|
* Converts `e5m2` → `f16` on-the-fly for A, B is pre-converted during packing.
|
|
2033
2050
|
*/
|
|
2034
|
-
|
|
2035
|
-
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
2036
|
-
nk_size_t
|
|
2037
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
2051
|
+
__arm_new("za") static void nk_dots_packed_e5m2_sme_streaming_( //
|
|
2052
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
2053
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
2038
2054
|
|
|
2039
2055
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
2040
2056
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -2196,7 +2212,9 @@ NK_PUBLIC void nk_dots_pack_e5m2_sme(nk_e5m2_t const *b, nk_size_t columns, nk_s
|
|
|
2196
2212
|
|
|
2197
2213
|
nk_f16_t *tiles_ptr = (nk_f16_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
2198
2214
|
|
|
2215
|
+
nk_sme_start_streaming_();
|
|
2199
2216
|
nk_dots_pack_e5m2_to_b16_sme_streaming_(b, columns, depth, b_stride_in_bytes, tiles_ptr);
|
|
2217
|
+
nk_sme_stop_streaming_();
|
|
2200
2218
|
|
|
2201
2219
|
nk_size_t const data_size = total_vectors * vector_elements * sizeof(nk_f16_t);
|
|
2202
2220
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
@@ -2210,15 +2228,16 @@ NK_PUBLIC void nk_dots_pack_e5m2_sme(nk_e5m2_t const *b, nk_size_t columns, nk_s
|
|
|
2210
2228
|
/* `e5m2` × `e5m2` → `f32` GEMM: public interface.
|
|
2211
2229
|
* Predicate-based edge handling eliminates scalar fallbacks.
|
|
2212
2230
|
*/
|
|
2213
|
-
NK_PUBLIC void nk_dots_packed_e5m2_sme(
|
|
2214
|
-
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
2215
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
2231
|
+
NK_PUBLIC void nk_dots_packed_e5m2_sme( //
|
|
2232
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
2216
2233
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
2217
2234
|
|
|
2218
2235
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
|
|
2219
2236
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
2220
2237
|
|
|
2238
|
+
nk_sme_start_streaming_();
|
|
2221
2239
|
nk_dots_packed_e5m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
2240
|
+
nk_sme_stop_streaming_();
|
|
2222
2241
|
}
|
|
2223
2242
|
|
|
2224
2243
|
/**
|
|
@@ -2227,9 +2246,9 @@ NK_PUBLIC void nk_dots_packed_e5m2_sme( //
|
|
|
2227
2246
|
* Pre-reads A columns into Z registers, then reloads ZA0 with converted B data
|
|
2228
2247
|
* per column tile. Eliminates all scalar B-packing loops.
|
|
2229
2248
|
*/
|
|
2230
|
-
|
|
2249
|
+
__arm_new("za") static void nk_dots_symmetric_e5m2_sme_streaming_( //
|
|
2231
2250
|
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
2232
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
2251
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
2233
2252
|
|
|
2234
2253
|
nk_size_t const expansion = 2;
|
|
2235
2254
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -2425,14 +2444,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_e5m2_sme_s
|
|
|
2425
2444
|
}
|
|
2426
2445
|
}
|
|
2427
2446
|
|
|
2428
|
-
NK_PUBLIC void nk_dots_symmetric_e5m2_sme(
|
|
2429
|
-
|
|
2430
|
-
|
|
2447
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_sme( //
|
|
2448
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
2449
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
2431
2450
|
|
|
2432
2451
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
|
|
2433
2452
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
2453
|
+
nk_sme_start_streaming_();
|
|
2434
2454
|
nk_dots_symmetric_e5m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
2435
2455
|
result_stride_elements, row_start, row_count);
|
|
2456
|
+
nk_sme_stop_streaming_();
|
|
2436
2457
|
}
|
|
2437
2458
|
|
|
2438
2459
|
#pragma endregion E5M2 Floats
|
|
@@ -2490,10 +2511,9 @@ NK_PUBLIC svint8_t nk_e2m3x_to_i8x_ssve_(svbool_t predicate_b8x, svuint8_t raw_b
|
|
|
2490
2511
|
* Converts `e2m3` → `i8` on-the-fly for A, B is pre-converted during packing.
|
|
2491
2512
|
* Accumulates in `i32` via `svmopa_za32_s8_m`, then converts to `f32` with 1/256 scaling.
|
|
2492
2513
|
*/
|
|
2493
|
-
|
|
2494
|
-
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
2495
|
-
nk_size_t
|
|
2496
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
2514
|
+
__arm_new("za") static void nk_dots_packed_e2m3_sme_streaming_( //
|
|
2515
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
2516
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
2497
2517
|
|
|
2498
2518
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
2499
2519
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -2648,9 +2668,8 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sme(nk_size_t columns, nk_size_t de
|
|
|
2648
2668
|
}
|
|
2649
2669
|
|
|
2650
2670
|
/** @brief Streaming pack helper for e2m3 → i8 conversion + quad-interleave using ZA tile transpose. */
|
|
2651
|
-
|
|
2652
|
-
void const *b, nk_size_t columns, nk_size_t depth,
|
|
2653
|
-
nk_size_t b_stride_bytes, void *tiles_ptr) {
|
|
2671
|
+
__arm_new("za") static void nk_dots_pack_e2m3_to_b8_sme_streaming_( //
|
|
2672
|
+
void const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_bytes, void *tiles_ptr) NK_STREAMING_ {
|
|
2654
2673
|
|
|
2655
2674
|
nk_size_t const expansion = 4;
|
|
2656
2675
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -2693,9 +2712,8 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_pack_e2m3_to_b8_sme_
|
|
|
2693
2712
|
}
|
|
2694
2713
|
}
|
|
2695
2714
|
|
|
2696
|
-
NK_PUBLIC void nk_dots_pack_e2m3_sme(
|
|
2697
|
-
nk_e2m3_t const *b, nk_size_t columns, nk_size_t depth,
|
|
2698
|
-
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
2715
|
+
NK_PUBLIC void nk_dots_pack_e2m3_sme( //
|
|
2716
|
+
nk_e2m3_t const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
2699
2717
|
|
|
2700
2718
|
nk_size_t const expansion = 4;
|
|
2701
2719
|
nk_size_t const tile_dimension = nk_sme_cntw_();
|
|
@@ -2714,7 +2732,9 @@ NK_PUBLIC void nk_dots_pack_e2m3_sme( //
|
|
|
2714
2732
|
|
|
2715
2733
|
nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
2716
2734
|
|
|
2735
|
+
nk_sme_start_streaming_();
|
|
2717
2736
|
nk_dots_pack_e2m3_to_b8_sme_streaming_(b, columns, depth, b_stride_in_bytes, tiles_ptr);
|
|
2737
|
+
nk_sme_stop_streaming_();
|
|
2718
2738
|
|
|
2719
2739
|
nk_size_t const data_size = total_vectors * vector_elements * sizeof(nk_i8_t);
|
|
2720
2740
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
@@ -2725,15 +2745,16 @@ NK_PUBLIC void nk_dots_pack_e2m3_sme( //
|
|
|
2725
2745
|
}
|
|
2726
2746
|
}
|
|
2727
2747
|
|
|
2728
|
-
NK_PUBLIC void nk_dots_packed_e2m3_sme(
|
|
2729
|
-
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
2730
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
2748
|
+
NK_PUBLIC void nk_dots_packed_e2m3_sme( //
|
|
2749
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
2731
2750
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
2732
2751
|
|
|
2733
2752
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
|
|
2734
2753
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
2735
2754
|
|
|
2755
|
+
nk_sme_start_streaming_();
|
|
2736
2756
|
nk_dots_packed_e2m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
2757
|
+
nk_sme_stop_streaming_();
|
|
2737
2758
|
}
|
|
2738
2759
|
|
|
2739
2760
|
/**
|
|
@@ -2742,9 +2763,9 @@ NK_PUBLIC void nk_dots_packed_e2m3_sme( //
|
|
|
2742
2763
|
* Pre-reads A columns into Z registers, then reloads ZA0 with converted B data
|
|
2743
2764
|
* per column tile. Accumulates in i32, converts to f32 with 1/256 scaling.
|
|
2744
2765
|
*/
|
|
2745
|
-
|
|
2766
|
+
__arm_new("za") static void nk_dots_symmetric_e2m3_sme_streaming_( //
|
|
2746
2767
|
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
2747
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
2768
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
2748
2769
|
|
|
2749
2770
|
nk_size_t const expansion = 4;
|
|
2750
2771
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -2947,14 +2968,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_e2m3_sme_s
|
|
|
2947
2968
|
}
|
|
2948
2969
|
}
|
|
2949
2970
|
|
|
2950
|
-
NK_PUBLIC void nk_dots_symmetric_e2m3_sme(
|
|
2951
|
-
|
|
2952
|
-
|
|
2971
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_sme( //
|
|
2972
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
2973
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
2953
2974
|
|
|
2954
2975
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
|
|
2955
2976
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
2977
|
+
nk_sme_start_streaming_();
|
|
2956
2978
|
nk_dots_symmetric_e2m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
2957
2979
|
result_stride_elements, row_start, row_count);
|
|
2980
|
+
nk_sme_stop_streaming_();
|
|
2958
2981
|
}
|
|
2959
2982
|
|
|
2960
2983
|
#pragma endregion E2M3 Floats
|
|
@@ -3012,10 +3035,9 @@ NK_PUBLIC svfloat16_t nk_e3m2x_to_f16x_ssve_(svbool_t predicate_b16x, svuint8_t
|
|
|
3012
3035
|
* Fused `e3m2` × `e3m2` → `f32` GEMM kernel using interleaved FMOPA.
|
|
3013
3036
|
* Converts `e3m2` → `f16` on-the-fly for A, B is pre-converted during packing.
|
|
3014
3037
|
*/
|
|
3015
|
-
|
|
3016
|
-
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
3017
|
-
nk_size_t
|
|
3018
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
3038
|
+
__arm_new("za") static void nk_dots_packed_e3m2_sme_streaming_( //
|
|
3039
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
3040
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
3019
3041
|
|
|
3020
3042
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
3021
3043
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -3158,8 +3180,8 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_sme(nk_size_t columns, nk_size_t de
|
|
|
3158
3180
|
}
|
|
3159
3181
|
|
|
3160
3182
|
/** @brief Streaming e3m2 → f16 pack using ZA tile transpose. */
|
|
3161
|
-
|
|
3162
|
-
void const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_bytes, void *tiles_ptr) {
|
|
3183
|
+
__arm_new("za") static void nk_dots_pack_e3m2_to_b16_sme_streaming_( //
|
|
3184
|
+
void const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_bytes, void *tiles_ptr) NK_STREAMING_ {
|
|
3163
3185
|
|
|
3164
3186
|
nk_size_t const expansion = 2;
|
|
3165
3187
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -3202,8 +3224,8 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_pack_e3m2_to_b16_sme
|
|
|
3202
3224
|
}
|
|
3203
3225
|
}
|
|
3204
3226
|
|
|
3205
|
-
NK_PUBLIC void nk_dots_pack_e3m2_sme(
|
|
3206
|
-
|
|
3227
|
+
NK_PUBLIC void nk_dots_pack_e3m2_sme( //
|
|
3228
|
+
nk_e3m2_t const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
3207
3229
|
|
|
3208
3230
|
nk_size_t const expansion = 2;
|
|
3209
3231
|
nk_size_t const tile_dimension = nk_sme_cntw_();
|
|
@@ -3222,7 +3244,9 @@ NK_PUBLIC void nk_dots_pack_e3m2_sme(nk_e3m2_t const *b, nk_size_t columns, nk_s
|
|
|
3222
3244
|
|
|
3223
3245
|
nk_f16_t *tiles_ptr = (nk_f16_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
3224
3246
|
|
|
3247
|
+
nk_sme_start_streaming_();
|
|
3225
3248
|
nk_dots_pack_e3m2_to_b16_sme_streaming_(b, columns, depth, b_stride_in_bytes, tiles_ptr);
|
|
3249
|
+
nk_sme_stop_streaming_();
|
|
3226
3250
|
|
|
3227
3251
|
nk_size_t const data_size = total_vectors * vector_elements * sizeof(nk_f16_t);
|
|
3228
3252
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
@@ -3236,15 +3260,16 @@ NK_PUBLIC void nk_dots_pack_e3m2_sme(nk_e3m2_t const *b, nk_size_t columns, nk_s
|
|
|
3236
3260
|
/* `e3m2` × `e3m2` → `f32` GEMM: public interface.
|
|
3237
3261
|
* Predicate-based edge handling eliminates scalar fallbacks.
|
|
3238
3262
|
*/
|
|
3239
|
-
NK_PUBLIC void nk_dots_packed_e3m2_sme(
|
|
3240
|
-
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
3241
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
3263
|
+
NK_PUBLIC void nk_dots_packed_e3m2_sme( //
|
|
3264
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
3242
3265
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
3243
3266
|
|
|
3244
3267
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
|
|
3245
3268
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
3246
3269
|
|
|
3270
|
+
nk_sme_start_streaming_();
|
|
3247
3271
|
nk_dots_packed_e3m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
3272
|
+
nk_sme_stop_streaming_();
|
|
3248
3273
|
}
|
|
3249
3274
|
|
|
3250
3275
|
/**
|
|
@@ -3253,9 +3278,9 @@ NK_PUBLIC void nk_dots_packed_e3m2_sme( //
|
|
|
3253
3278
|
* Pre-reads A columns into Z registers, then reloads ZA0 with converted B data
|
|
3254
3279
|
* per column tile. Eliminates all scalar B-packing loops.
|
|
3255
3280
|
*/
|
|
3256
|
-
|
|
3281
|
+
__arm_new("za") static void nk_dots_symmetric_e3m2_sme_streaming_( //
|
|
3257
3282
|
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
3258
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
3283
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
3259
3284
|
|
|
3260
3285
|
nk_size_t const expansion = 2;
|
|
3261
3286
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -3451,14 +3476,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_e3m2_sme_s
|
|
|
3451
3476
|
}
|
|
3452
3477
|
}
|
|
3453
3478
|
|
|
3454
|
-
NK_PUBLIC void nk_dots_symmetric_e3m2_sme(
|
|
3455
|
-
|
|
3456
|
-
|
|
3479
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_sme( //
|
|
3480
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
3481
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
3457
3482
|
|
|
3458
3483
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
|
|
3459
3484
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
3485
|
+
nk_sme_start_streaming_();
|
|
3460
3486
|
nk_dots_symmetric_e3m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
3461
3487
|
result_stride_elements, row_start, row_count);
|
|
3488
|
+
nk_sme_stop_streaming_();
|
|
3462
3489
|
}
|
|
3463
3490
|
|
|
3464
3491
|
#pragma endregion I8 Integers
|
|
@@ -3481,9 +3508,8 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sme(nk_size_t columns, nk_size_t dept
|
|
|
3481
3508
|
return nk_dots_packed_size_i8_sme(columns, depth);
|
|
3482
3509
|
}
|
|
3483
3510
|
|
|
3484
|
-
NK_PUBLIC void nk_dots_pack_u8_sme(
|
|
3485
|
-
nk_u8_t const *b, nk_size_t columns, nk_size_t depth,
|
|
3486
|
-
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
3511
|
+
NK_PUBLIC void nk_dots_pack_u8_sme( //
|
|
3512
|
+
nk_u8_t const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
3487
3513
|
|
|
3488
3514
|
nk_size_t const expansion = 4;
|
|
3489
3515
|
nk_size_t const tile_dimension = nk_sme_cntw_();
|
|
@@ -3502,7 +3528,9 @@ NK_PUBLIC void nk_dots_pack_u8_sme( //
|
|
|
3502
3528
|
|
|
3503
3529
|
nk_u8_t *tiles_ptr = (nk_u8_t *)((char *)b_packed + sizeof(nk_dots_sme_packed_header_t));
|
|
3504
3530
|
|
|
3531
|
+
nk_sme_start_streaming_();
|
|
3505
3532
|
nk_dots_pack_b8_sme_streaming_(b, columns, depth, b_stride_in_bytes, tiles_ptr);
|
|
3533
|
+
nk_sme_stop_streaming_();
|
|
3506
3534
|
|
|
3507
3535
|
nk_size_t const data_size = total_vectors * vector_elements * sizeof(nk_u8_t);
|
|
3508
3536
|
header->norms_offset = (nk_u32_t)(sizeof(nk_dots_sme_packed_header_t) + data_size);
|
|
@@ -3517,10 +3545,9 @@ NK_PUBLIC void nk_dots_pack_u8_sme( //
|
|
|
3517
3545
|
* `u8` × `u8` → `u32` GEMM core kernel using SME outer products.
|
|
3518
3546
|
* Same interleaved algorithm as i8 kernel, using UMOPA u8→u32.
|
|
3519
3547
|
*/
|
|
3520
|
-
|
|
3521
|
-
nk_u8_t const *a, void const *b_packed, nk_u32_t *c,
|
|
3522
|
-
nk_size_t
|
|
3523
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
3548
|
+
__arm_new("za") static void nk_dots_packed_u8_sme_streaming_( //
|
|
3549
|
+
nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
3550
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
3524
3551
|
|
|
3525
3552
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
3526
3553
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -3651,21 +3678,21 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u8_sme_stream
|
|
|
3651
3678
|
}
|
|
3652
3679
|
}
|
|
3653
3680
|
|
|
3654
|
-
NK_PUBLIC void nk_dots_packed_u8_sme(
|
|
3655
|
-
nk_u8_t const *a, void const *b_packed, nk_u32_t *c,
|
|
3656
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
3681
|
+
NK_PUBLIC void nk_dots_packed_u8_sme( //
|
|
3682
|
+
nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
3657
3683
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
3658
3684
|
|
|
3659
3685
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
|
|
3660
3686
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_u32_t);
|
|
3661
3687
|
|
|
3688
|
+
nk_sme_start_streaming_();
|
|
3662
3689
|
nk_dots_packed_u8_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
3690
|
+
nk_sme_stop_streaming_();
|
|
3663
3691
|
}
|
|
3664
3692
|
|
|
3665
|
-
|
|
3666
|
-
__arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u8_sme_streaming_(
|
|
3693
|
+
__arm_new("za") static void nk_dots_symmetric_u8_sme_streaming_( //
|
|
3667
3694
|
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_u32_t *result,
|
|
3668
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
3695
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
3669
3696
|
|
|
3670
3697
|
nk_size_t const expansion = 4;
|
|
3671
3698
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -3833,14 +3860,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u8_sme_str
|
|
|
3833
3860
|
}
|
|
3834
3861
|
}
|
|
3835
3862
|
|
|
3836
|
-
NK_PUBLIC void nk_dots_symmetric_u8_sme(
|
|
3837
|
-
|
|
3838
|
-
|
|
3863
|
+
NK_PUBLIC void nk_dots_symmetric_u8_sme( //
|
|
3864
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_u32_t *result,
|
|
3865
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
3839
3866
|
|
|
3840
3867
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
|
|
3841
3868
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_u32_t);
|
|
3869
|
+
nk_sme_start_streaming_();
|
|
3842
3870
|
nk_dots_symmetric_u8_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
|
|
3843
3871
|
row_start, row_count);
|
|
3872
|
+
nk_sme_stop_streaming_();
|
|
3844
3873
|
}
|
|
3845
3874
|
|
|
3846
3875
|
#pragma endregion U8 Integers
|
|
@@ -3876,9 +3905,8 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_u4_sme(nk_size_t columns, nk_size_t dept
|
|
|
3876
3905
|
return size;
|
|
3877
3906
|
}
|
|
3878
3907
|
|
|
3879
|
-
NK_PUBLIC void nk_dots_pack_u4_sme(
|
|
3880
|
-
nk_u4x2_t const *b, nk_size_t columns, nk_size_t depth,
|
|
3881
|
-
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
3908
|
+
NK_PUBLIC void nk_dots_pack_u4_sme( //
|
|
3909
|
+
nk_u4x2_t const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
3882
3910
|
|
|
3883
3911
|
nk_size_t const tile_dimension = nk_sme_cntw_();
|
|
3884
3912
|
nk_size_t const vector_elements = nk_sme_cntb_();
|
|
@@ -3945,10 +3973,9 @@ NK_PUBLIC void nk_dots_pack_u4_sme( //
|
|
|
3945
3973
|
* B input is pre-split low/high nibble vectors from nk_dots_pack_u4_sme.
|
|
3946
3974
|
* Two UMOPAs per depth step: one for low nibbles, one for high nibbles.
|
|
3947
3975
|
*/
|
|
3948
|
-
|
|
3949
|
-
nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c,
|
|
3950
|
-
nk_size_t
|
|
3951
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
3976
|
+
__arm_new("za") static void nk_dots_packed_u4_sme_streaming_( //
|
|
3977
|
+
nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
3978
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
3952
3979
|
|
|
3953
3980
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
3954
3981
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -4123,14 +4150,15 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u4_sme_stream
|
|
|
4123
4150
|
}
|
|
4124
4151
|
}
|
|
4125
4152
|
|
|
4126
|
-
NK_PUBLIC void nk_dots_packed_u4_sme(
|
|
4127
|
-
nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c,
|
|
4128
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
4153
|
+
NK_PUBLIC void nk_dots_packed_u4_sme( //
|
|
4154
|
+
nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
4129
4155
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
4130
4156
|
|
|
4131
4157
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u4x2_t);
|
|
4132
4158
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_u32_t);
|
|
4159
|
+
nk_sme_start_streaming_();
|
|
4133
4160
|
nk_dots_packed_u4_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
4161
|
+
nk_sme_stop_streaming_();
|
|
4134
4162
|
}
|
|
4135
4163
|
|
|
4136
4164
|
#pragma endregion Unsigned Integers
|
|
@@ -4222,10 +4250,9 @@ NK_PUBLIC void nk_dots_pack_i4_sme( //
|
|
|
4222
4250
|
* B input is pre-split sign-extended nibble vectors from nk_dots_pack_i4_sme.
|
|
4223
4251
|
* Two SMOPAs per depth step: one for low nibbles, one for high nibbles.
|
|
4224
4252
|
*/
|
|
4225
|
-
|
|
4226
|
-
nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c,
|
|
4227
|
-
nk_size_t
|
|
4228
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
4253
|
+
__arm_new("za") static void nk_dots_packed_i4_sme_streaming_( //
|
|
4254
|
+
nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
4255
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
4229
4256
|
|
|
4230
4257
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
4231
4258
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -4406,14 +4433,15 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_i4_sme_stream
|
|
|
4406
4433
|
}
|
|
4407
4434
|
}
|
|
4408
4435
|
|
|
4409
|
-
NK_PUBLIC void nk_dots_packed_i4_sme(
|
|
4410
|
-
nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c,
|
|
4411
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
4436
|
+
NK_PUBLIC void nk_dots_packed_i4_sme( //
|
|
4437
|
+
nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
4412
4438
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
4413
4439
|
|
|
4414
4440
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i4x2_t);
|
|
4415
4441
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_i32_t);
|
|
4442
|
+
nk_sme_start_streaming_();
|
|
4416
4443
|
nk_dots_packed_i4_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
4444
|
+
nk_sme_stop_streaming_();
|
|
4417
4445
|
}
|
|
4418
4446
|
|
|
4419
4447
|
/**
|
|
@@ -4421,9 +4449,9 @@ NK_PUBLIC void nk_dots_packed_i4_sme( //
|
|
|
4421
4449
|
* Loads packed nibble bytes directly into ZA0, splits into low/high nibbles in registers,
|
|
4422
4450
|
* issues 2 UMOPAs per depth step. ZA0 = staging tile, ZA1-ZA3 = accumulators.
|
|
4423
4451
|
*/
|
|
4424
|
-
|
|
4452
|
+
__arm_new("za") static void nk_dots_symmetric_u4_sme_streaming_( //
|
|
4425
4453
|
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_u32_t *result,
|
|
4426
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
4454
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
4427
4455
|
|
|
4428
4456
|
nk_size_t const expansion = 4;
|
|
4429
4457
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -4695,14 +4723,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u4_sme_str
|
|
|
4695
4723
|
}
|
|
4696
4724
|
}
|
|
4697
4725
|
|
|
4698
|
-
NK_PUBLIC void nk_dots_symmetric_u4_sme(
|
|
4699
|
-
|
|
4700
|
-
|
|
4726
|
+
NK_PUBLIC void nk_dots_symmetric_u4_sme( //
|
|
4727
|
+
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_u32_t *result,
|
|
4728
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
4701
4729
|
|
|
4702
4730
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u4x2_t);
|
|
4703
4731
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_u32_t);
|
|
4732
|
+
nk_sme_start_streaming_();
|
|
4704
4733
|
nk_dots_symmetric_u4_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
|
|
4705
4734
|
row_start, row_count);
|
|
4735
|
+
nk_sme_stop_streaming_();
|
|
4706
4736
|
}
|
|
4707
4737
|
|
|
4708
4738
|
/**
|
|
@@ -4710,9 +4740,9 @@ NK_PUBLIC void nk_dots_symmetric_u4_sme(nk_u4x2_t const *vectors, nk_size_t vect
|
|
|
4710
4740
|
* Loads packed nibble bytes directly into ZA0, sign-extends via LSL+ASR in registers,
|
|
4711
4741
|
* issues 2 SMOPAs per depth step. ZA0 = staging tile, ZA1-ZA3 = accumulators.
|
|
4712
4742
|
*/
|
|
4713
|
-
|
|
4743
|
+
__arm_new("za") static void nk_dots_symmetric_i4_sme_streaming_( //
|
|
4714
4744
|
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_i32_t *result,
|
|
4715
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
4745
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
4716
4746
|
|
|
4717
4747
|
nk_size_t const expansion = 4;
|
|
4718
4748
|
nk_size_t const tile_dimension = svcntw();
|
|
@@ -4983,14 +5013,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_i4_sme_str
|
|
|
4983
5013
|
}
|
|
4984
5014
|
}
|
|
4985
5015
|
|
|
4986
|
-
NK_PUBLIC void nk_dots_symmetric_i4_sme(
|
|
4987
|
-
|
|
4988
|
-
|
|
5016
|
+
NK_PUBLIC void nk_dots_symmetric_i4_sme( //
|
|
5017
|
+
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_i32_t *result,
|
|
5018
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
4989
5019
|
|
|
4990
5020
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i4x2_t);
|
|
4991
5021
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_i32_t);
|
|
5022
|
+
nk_sme_start_streaming_();
|
|
4992
5023
|
nk_dots_symmetric_i4_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
|
|
4993
5024
|
row_start, row_count);
|
|
5025
|
+
nk_sme_stop_streaming_();
|
|
4994
5026
|
}
|
|
4995
5027
|
|
|
4996
5028
|
#pragma endregion Signed Integers
|