numkong 7.4.5 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +99 -5
- package/c/dispatch_e5m2.c +23 -3
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +1167 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +33 -11
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +41 -3
- package/include/numkong/each/serial.h +39 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +15 -4
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/maxsim/sme.h +34 -33
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +225 -161
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +301 -0
- package/include/numkong/spatials/serial.h +39 -0
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +54 -4
- package/include/numkong/tensor.hpp +107 -23
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +59 -14
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
#if NK_TARGET_SME
|
|
14
14
|
|
|
15
15
|
#include "numkong/dots/serial.h"
|
|
16
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
16
17
|
#include "numkong/dots/sme.h"
|
|
17
18
|
|
|
18
19
|
#if defined(__cplusplus)
|
|
@@ -44,7 +45,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_f16_ssve_(nk_f16_t const *data, nk_size_
|
|
|
44
45
|
svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
|
|
45
46
|
accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
|
|
46
47
|
}
|
|
47
|
-
return
|
|
48
|
+
return nk_svaddv_f32_(svptrue_b32(), accumulator_even_f32x) + nk_svaddv_f32_(svptrue_b32(), accumulator_odd_f32x);
|
|
48
49
|
}
|
|
49
50
|
|
|
50
51
|
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
@@ -55,7 +56,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_siz
|
|
|
55
56
|
svbfloat16_t values_bf16x = svld1_bf16(predicate_b16x, (nk_bf16_for_arm_simd_t const *)(data + i));
|
|
56
57
|
accumulator_f32x = svbfdot_f32(accumulator_f32x, values_bf16x, values_bf16x);
|
|
57
58
|
}
|
|
58
|
-
return
|
|
59
|
+
return nk_svaddv_f32_(svptrue_b32(), accumulator_f32x);
|
|
59
60
|
}
|
|
60
61
|
|
|
61
62
|
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e4m3_ssve_(nk_e4m3_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
@@ -79,7 +80,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e4m3_ssve_(nk_e4m3_t const *data, nk_siz
|
|
|
79
80
|
svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
|
|
80
81
|
accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
|
|
81
82
|
}
|
|
82
|
-
return
|
|
83
|
+
return nk_svaddv_f32_(svptrue_b32(), accumulator_even_f32x) + nk_svaddv_f32_(svptrue_b32(), accumulator_odd_f32x);
|
|
83
84
|
}
|
|
84
85
|
|
|
85
86
|
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e5m2_ssve_(nk_e5m2_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
@@ -103,7 +104,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e5m2_ssve_(nk_e5m2_t const *data, nk_siz
|
|
|
103
104
|
svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
|
|
104
105
|
accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
|
|
105
106
|
}
|
|
106
|
-
return
|
|
107
|
+
return nk_svaddv_f32_(svptrue_b32(), accumulator_even_f32x) + nk_svaddv_f32_(svptrue_b32(), accumulator_odd_f32x);
|
|
107
108
|
}
|
|
108
109
|
|
|
109
110
|
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
@@ -115,7 +116,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_siz
|
|
|
115
116
|
svint8_t values_i8x = nk_e2m3x_to_i8x_ssve_(predicate_b8x, raw_u8x);
|
|
116
117
|
accumulator_i32x = svdot_s32(accumulator_i32x, values_i8x, values_i8x);
|
|
117
118
|
}
|
|
118
|
-
return (nk_f32_t)
|
|
119
|
+
return (nk_f32_t)nk_svaddv_s32_(svptrue_b32(), accumulator_i32x) / 256.0f;
|
|
119
120
|
}
|
|
120
121
|
|
|
121
122
|
NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e3m2_ssve_(nk_e3m2_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
@@ -139,7 +140,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e3m2_ssve_(nk_e3m2_t const *data, nk_siz
|
|
|
139
140
|
svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
|
|
140
141
|
accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
|
|
141
142
|
}
|
|
142
|
-
return
|
|
143
|
+
return nk_svaddv_f32_(svptrue_b32(), accumulator_even_f32x) + nk_svaddv_f32_(svptrue_b32(), accumulator_odd_f32x);
|
|
143
144
|
}
|
|
144
145
|
|
|
145
146
|
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
@@ -150,7 +151,7 @@ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t
|
|
|
150
151
|
svint8_t loaded_i8x = svld1_s8(predicate_b8x, data + i);
|
|
151
152
|
accumulator_i32x = svdot_s32(accumulator_i32x, loaded_i8x, loaded_i8x);
|
|
152
153
|
}
|
|
153
|
-
return (nk_u32_t)
|
|
154
|
+
return (nk_u32_t)nk_svaddv_s32_(svptrue_b32(), accumulator_i32x);
|
|
154
155
|
}
|
|
155
156
|
|
|
156
157
|
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
@@ -161,7 +162,7 @@ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t
|
|
|
161
162
|
svuint8_t loaded_u8x = svld1_u8(predicate_b8x, data + i);
|
|
162
163
|
accumulator_u32x = svdot_u32(accumulator_u32x, loaded_u8x, loaded_u8x);
|
|
163
164
|
}
|
|
164
|
-
return (nk_u32_t)
|
|
165
|
+
return (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), accumulator_u32x);
|
|
165
166
|
}
|
|
166
167
|
|
|
167
168
|
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
@@ -181,7 +182,7 @@ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_
|
|
|
181
182
|
accumulator_i32x = svdot_s32(accumulator_i32x, low_i8x, low_i8x);
|
|
182
183
|
accumulator_i32x = svdot_s32(accumulator_i32x, high_i8x, high_i8x);
|
|
183
184
|
}
|
|
184
|
-
return (nk_u32_t)
|
|
185
|
+
return (nk_u32_t)nk_svaddv_s32_(svptrue_b32(), accumulator_i32x);
|
|
185
186
|
}
|
|
186
187
|
|
|
187
188
|
NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_t count) NK_STREAMING_ {
|
|
@@ -197,7 +198,7 @@ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_
|
|
|
197
198
|
accumulator_u32x = svdot_u32(accumulator_u32x, low_u8x, low_u8x);
|
|
198
199
|
accumulator_u32x = svdot_u32(accumulator_u32x, high_u8x, high_u8x);
|
|
199
200
|
}
|
|
200
|
-
return (nk_u32_t)
|
|
201
|
+
return (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), accumulator_u32x);
|
|
201
202
|
}
|
|
202
203
|
|
|
203
204
|
NK_PUBLIC svfloat32_t nk_angulars_from_dot_f32x_ssve_(svbool_t predicate_b32x, svfloat32_t dots_f32x,
|
|
@@ -226,10 +227,9 @@ NK_PUBLIC svfloat32_t nk_euclideans_from_dot_f32x_ssve_(svbool_t predicate_b32x,
|
|
|
226
227
|
|
|
227
228
|
#pragma region F16 Floats
|
|
228
229
|
|
|
229
|
-
|
|
230
|
-
nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
231
|
-
nk_size_t
|
|
232
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
230
|
+
static void nk_angulars_packed_f16_sme_finalize_ssve_( //
|
|
231
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
232
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
233
233
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
234
234
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
235
235
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -248,21 +248,21 @@ __arm_locally_streaming static void nk_angulars_packed_f16_sme_finalize_streamin
|
|
|
248
248
|
}
|
|
249
249
|
}
|
|
250
250
|
|
|
251
|
-
NK_PUBLIC void nk_angulars_packed_f16_sme(
|
|
252
|
-
nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
253
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
251
|
+
NK_PUBLIC void nk_angulars_packed_f16_sme( //
|
|
252
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
254
253
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
255
254
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
|
|
256
255
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
256
|
+
nk_sme_start_streaming_();
|
|
257
257
|
nk_dots_packed_f16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
258
|
-
|
|
259
|
-
|
|
258
|
+
nk_angulars_packed_f16_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
259
|
+
c_stride_elements);
|
|
260
|
+
nk_sme_stop_streaming_();
|
|
260
261
|
}
|
|
261
262
|
|
|
262
|
-
|
|
263
|
-
nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
264
|
-
nk_size_t
|
|
265
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
263
|
+
static void nk_euclideans_packed_f16_sme_finalize_ssve_( //
|
|
264
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
265
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
266
266
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
267
267
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
268
268
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -281,20 +281,21 @@ __arm_locally_streaming static void nk_euclideans_packed_f16_sme_finalize_stream
|
|
|
281
281
|
}
|
|
282
282
|
}
|
|
283
283
|
|
|
284
|
-
NK_PUBLIC void nk_euclideans_packed_f16_sme(
|
|
285
|
-
nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
286
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
284
|
+
NK_PUBLIC void nk_euclideans_packed_f16_sme( //
|
|
285
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
287
286
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
288
287
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
|
|
289
288
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
289
|
+
nk_sme_start_streaming_();
|
|
290
290
|
nk_dots_packed_f16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
291
|
-
|
|
292
|
-
|
|
291
|
+
nk_euclideans_packed_f16_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
292
|
+
c_stride_elements);
|
|
293
|
+
nk_sme_stop_streaming_();
|
|
293
294
|
}
|
|
294
295
|
|
|
295
|
-
|
|
296
|
-
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
297
|
-
|
|
296
|
+
static void nk_angulars_symmetric_f16_sme_finalize_ssve_( //
|
|
297
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
298
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
298
299
|
// Phase 1: cache row norms on diagonal
|
|
299
300
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
300
301
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -326,20 +327,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_strea
|
|
|
326
327
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
327
328
|
}
|
|
328
329
|
|
|
329
|
-
NK_PUBLIC void nk_angulars_symmetric_f16_sme(
|
|
330
|
-
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
331
|
-
|
|
330
|
+
NK_PUBLIC void nk_angulars_symmetric_f16_sme( //
|
|
331
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
332
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
332
333
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
333
334
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
335
|
+
nk_sme_start_streaming_();
|
|
334
336
|
nk_dots_symmetric_f16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
|
|
335
337
|
row_start, row_count);
|
|
336
|
-
|
|
337
|
-
|
|
338
|
+
nk_angulars_symmetric_f16_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
339
|
+
result_stride_elements, row_start, row_count);
|
|
340
|
+
nk_sme_stop_streaming_();
|
|
338
341
|
}
|
|
339
342
|
|
|
340
|
-
|
|
341
|
-
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
342
|
-
|
|
343
|
+
static void nk_euclideans_symmetric_f16_sme_finalize_ssve_( //
|
|
344
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
345
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
343
346
|
// Phase 1: cache row norms on diagonal
|
|
344
347
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
345
348
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -371,25 +374,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_str
|
|
|
371
374
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
372
375
|
}
|
|
373
376
|
|
|
374
|
-
NK_PUBLIC void nk_euclideans_symmetric_f16_sme(
|
|
375
|
-
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
376
|
-
|
|
377
|
+
NK_PUBLIC void nk_euclideans_symmetric_f16_sme( //
|
|
378
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
379
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
377
380
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
378
381
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
382
|
+
nk_sme_start_streaming_();
|
|
379
383
|
nk_dots_symmetric_f16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
|
|
380
384
|
row_start, row_count);
|
|
381
|
-
|
|
382
|
-
|
|
385
|
+
nk_euclideans_symmetric_f16_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
386
|
+
result_stride_elements, row_start, row_count);
|
|
387
|
+
nk_sme_stop_streaming_();
|
|
383
388
|
}
|
|
384
389
|
|
|
385
390
|
#pragma endregion F16 Floats
|
|
386
391
|
|
|
387
392
|
#pragma region BF16 Floats
|
|
388
393
|
|
|
389
|
-
|
|
390
|
-
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
391
|
-
nk_size_t
|
|
392
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
394
|
+
static void nk_angulars_packed_bf16_sme_finalize_ssve_( //
|
|
395
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
396
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
393
397
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
394
398
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
395
399
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -408,21 +412,21 @@ __arm_locally_streaming static void nk_angulars_packed_bf16_sme_finalize_streami
|
|
|
408
412
|
}
|
|
409
413
|
}
|
|
410
414
|
|
|
411
|
-
NK_PUBLIC void nk_angulars_packed_bf16_sme(
|
|
412
|
-
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
413
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
415
|
+
NK_PUBLIC void nk_angulars_packed_bf16_sme( //
|
|
416
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
414
417
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
415
418
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
|
|
416
419
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
420
|
+
nk_sme_start_streaming_();
|
|
417
421
|
nk_dots_packed_bf16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
418
|
-
|
|
419
|
-
|
|
422
|
+
nk_angulars_packed_bf16_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
423
|
+
c_stride_elements);
|
|
424
|
+
nk_sme_stop_streaming_();
|
|
420
425
|
}
|
|
421
426
|
|
|
422
|
-
|
|
423
|
-
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
424
|
-
nk_size_t
|
|
425
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
427
|
+
static void nk_euclideans_packed_bf16_sme_finalize_ssve_( //
|
|
428
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
429
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
426
430
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
427
431
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
428
432
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -441,20 +445,21 @@ __arm_locally_streaming static void nk_euclideans_packed_bf16_sme_finalize_strea
|
|
|
441
445
|
}
|
|
442
446
|
}
|
|
443
447
|
|
|
444
|
-
NK_PUBLIC void nk_euclideans_packed_bf16_sme(
|
|
445
|
-
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
446
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
448
|
+
NK_PUBLIC void nk_euclideans_packed_bf16_sme( //
|
|
449
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
447
450
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
448
451
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
|
|
449
452
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
453
|
+
nk_sme_start_streaming_();
|
|
450
454
|
nk_dots_packed_bf16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
451
|
-
|
|
452
|
-
|
|
455
|
+
nk_euclideans_packed_bf16_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
456
|
+
c_stride_elements);
|
|
457
|
+
nk_sme_stop_streaming_();
|
|
453
458
|
}
|
|
454
459
|
|
|
455
|
-
|
|
456
|
-
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
457
|
-
|
|
460
|
+
static void nk_angulars_symmetric_bf16_sme_finalize_ssve_( //
|
|
461
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
462
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
458
463
|
// Phase 1: cache row norms on diagonal
|
|
459
464
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
460
465
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -486,20 +491,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_stre
|
|
|
486
491
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
487
492
|
}
|
|
488
493
|
|
|
489
|
-
NK_PUBLIC void nk_angulars_symmetric_bf16_sme(
|
|
490
|
-
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
491
|
-
|
|
494
|
+
NK_PUBLIC void nk_angulars_symmetric_bf16_sme( //
|
|
495
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
496
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
492
497
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
493
498
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
499
|
+
nk_sme_start_streaming_();
|
|
494
500
|
nk_dots_symmetric_bf16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
495
501
|
result_stride_elements, row_start, row_count);
|
|
496
|
-
|
|
497
|
-
|
|
502
|
+
nk_angulars_symmetric_bf16_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
503
|
+
result_stride_elements, row_start, row_count);
|
|
504
|
+
nk_sme_stop_streaming_();
|
|
498
505
|
}
|
|
499
506
|
|
|
500
|
-
|
|
501
|
-
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
502
|
-
|
|
507
|
+
static void nk_euclideans_symmetric_bf16_sme_finalize_ssve_( //
|
|
508
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
509
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
503
510
|
// Phase 1: cache row norms on diagonal
|
|
504
511
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
505
512
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -531,25 +538,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_st
|
|
|
531
538
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
532
539
|
}
|
|
533
540
|
|
|
534
|
-
NK_PUBLIC void nk_euclideans_symmetric_bf16_sme(
|
|
535
|
-
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
536
|
-
|
|
541
|
+
NK_PUBLIC void nk_euclideans_symmetric_bf16_sme( //
|
|
542
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
543
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
537
544
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
538
545
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
546
|
+
nk_sme_start_streaming_();
|
|
539
547
|
nk_dots_symmetric_bf16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
540
548
|
result_stride_elements, row_start, row_count);
|
|
541
|
-
|
|
542
|
-
|
|
549
|
+
nk_euclideans_symmetric_bf16_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
550
|
+
result_stride_elements, row_start, row_count);
|
|
551
|
+
nk_sme_stop_streaming_();
|
|
543
552
|
}
|
|
544
553
|
|
|
545
554
|
#pragma endregion BF16 Floats
|
|
546
555
|
|
|
547
556
|
#pragma region E4M3 Floats
|
|
548
557
|
|
|
549
|
-
|
|
550
|
-
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
551
|
-
nk_size_t
|
|
552
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
558
|
+
static void nk_angulars_packed_e4m3_sme_finalize_ssve_( //
|
|
559
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
560
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
553
561
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
554
562
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
555
563
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -568,21 +576,21 @@ __arm_locally_streaming static void nk_angulars_packed_e4m3_sme_finalize_streami
|
|
|
568
576
|
}
|
|
569
577
|
}
|
|
570
578
|
|
|
571
|
-
NK_PUBLIC void nk_angulars_packed_e4m3_sme(
|
|
572
|
-
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
573
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
579
|
+
NK_PUBLIC void nk_angulars_packed_e4m3_sme( //
|
|
580
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
574
581
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
575
582
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
|
|
576
583
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
584
|
+
nk_sme_start_streaming_();
|
|
577
585
|
nk_dots_packed_e4m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
578
|
-
|
|
579
|
-
|
|
586
|
+
nk_angulars_packed_e4m3_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
587
|
+
c_stride_elements);
|
|
588
|
+
nk_sme_stop_streaming_();
|
|
580
589
|
}
|
|
581
590
|
|
|
582
|
-
|
|
583
|
-
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
584
|
-
nk_size_t
|
|
585
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
591
|
+
static void nk_euclideans_packed_e4m3_sme_finalize_ssve_( //
|
|
592
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
593
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
586
594
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
587
595
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
588
596
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -601,20 +609,21 @@ __arm_locally_streaming static void nk_euclideans_packed_e4m3_sme_finalize_strea
|
|
|
601
609
|
}
|
|
602
610
|
}
|
|
603
611
|
|
|
604
|
-
NK_PUBLIC void nk_euclideans_packed_e4m3_sme(
|
|
605
|
-
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
606
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
612
|
+
NK_PUBLIC void nk_euclideans_packed_e4m3_sme( //
|
|
613
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
607
614
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
608
615
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
|
|
609
616
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
617
|
+
nk_sme_start_streaming_();
|
|
610
618
|
nk_dots_packed_e4m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
611
|
-
|
|
612
|
-
|
|
619
|
+
nk_euclideans_packed_e4m3_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
620
|
+
c_stride_elements);
|
|
621
|
+
nk_sme_stop_streaming_();
|
|
613
622
|
}
|
|
614
623
|
|
|
615
|
-
|
|
616
|
-
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
617
|
-
|
|
624
|
+
static void nk_angulars_symmetric_e4m3_sme_finalize_ssve_( //
|
|
625
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
626
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
618
627
|
// Phase 1: cache row norms on diagonal
|
|
619
628
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
620
629
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -646,20 +655,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_stre
|
|
|
646
655
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
647
656
|
}
|
|
648
657
|
|
|
649
|
-
NK_PUBLIC void nk_angulars_symmetric_e4m3_sme(
|
|
650
|
-
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
651
|
-
|
|
658
|
+
NK_PUBLIC void nk_angulars_symmetric_e4m3_sme( //
|
|
659
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
660
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
652
661
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
|
|
653
662
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
663
|
+
nk_sme_start_streaming_();
|
|
654
664
|
nk_dots_symmetric_e4m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
655
665
|
result_stride_elements, row_start, row_count);
|
|
656
|
-
|
|
657
|
-
|
|
666
|
+
nk_angulars_symmetric_e4m3_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
667
|
+
result_stride_elements, row_start, row_count);
|
|
668
|
+
nk_sme_stop_streaming_();
|
|
658
669
|
}
|
|
659
670
|
|
|
660
|
-
|
|
661
|
-
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
662
|
-
|
|
671
|
+
static void nk_euclideans_symmetric_e4m3_sme_finalize_ssve_( //
|
|
672
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
673
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
663
674
|
// Phase 1: cache row norms on diagonal
|
|
664
675
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
665
676
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -691,25 +702,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_st
|
|
|
691
702
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
692
703
|
}
|
|
693
704
|
|
|
694
|
-
NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme(
|
|
695
|
-
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
696
|
-
|
|
705
|
+
NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme( //
|
|
706
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
707
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
697
708
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
|
|
698
709
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
710
|
+
nk_sme_start_streaming_();
|
|
699
711
|
nk_dots_symmetric_e4m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
700
712
|
result_stride_elements, row_start, row_count);
|
|
701
|
-
|
|
702
|
-
|
|
713
|
+
nk_euclideans_symmetric_e4m3_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
714
|
+
result_stride_elements, row_start, row_count);
|
|
715
|
+
nk_sme_stop_streaming_();
|
|
703
716
|
}
|
|
704
717
|
|
|
705
718
|
#pragma endregion E4M3 Floats
|
|
706
719
|
|
|
707
720
|
#pragma region E5M2 Floats
|
|
708
721
|
|
|
709
|
-
|
|
710
|
-
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
711
|
-
nk_size_t
|
|
712
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
722
|
+
static void nk_angulars_packed_e5m2_sme_finalize_ssve_( //
|
|
723
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
724
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
713
725
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
714
726
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
715
727
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -728,21 +740,21 @@ __arm_locally_streaming static void nk_angulars_packed_e5m2_sme_finalize_streami
|
|
|
728
740
|
}
|
|
729
741
|
}
|
|
730
742
|
|
|
731
|
-
NK_PUBLIC void nk_angulars_packed_e5m2_sme(
|
|
732
|
-
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
733
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
743
|
+
NK_PUBLIC void nk_angulars_packed_e5m2_sme( //
|
|
744
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
734
745
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
735
746
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
|
|
736
747
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
748
|
+
nk_sme_start_streaming_();
|
|
737
749
|
nk_dots_packed_e5m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
738
|
-
|
|
739
|
-
|
|
750
|
+
nk_angulars_packed_e5m2_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
751
|
+
c_stride_elements);
|
|
752
|
+
nk_sme_stop_streaming_();
|
|
740
753
|
}
|
|
741
754
|
|
|
742
|
-
|
|
743
|
-
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
744
|
-
nk_size_t
|
|
745
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
755
|
+
static void nk_euclideans_packed_e5m2_sme_finalize_ssve_( //
|
|
756
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
757
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
746
758
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
747
759
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
748
760
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -761,20 +773,21 @@ __arm_locally_streaming static void nk_euclideans_packed_e5m2_sme_finalize_strea
|
|
|
761
773
|
}
|
|
762
774
|
}
|
|
763
775
|
|
|
764
|
-
NK_PUBLIC void nk_euclideans_packed_e5m2_sme(
|
|
765
|
-
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
766
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
776
|
+
NK_PUBLIC void nk_euclideans_packed_e5m2_sme( //
|
|
777
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
767
778
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
768
779
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
|
|
769
780
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
781
|
+
nk_sme_start_streaming_();
|
|
770
782
|
nk_dots_packed_e5m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
771
|
-
|
|
772
|
-
|
|
783
|
+
nk_euclideans_packed_e5m2_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
784
|
+
c_stride_elements);
|
|
785
|
+
nk_sme_stop_streaming_();
|
|
773
786
|
}
|
|
774
787
|
|
|
775
|
-
|
|
776
|
-
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
777
|
-
|
|
788
|
+
static void nk_angulars_symmetric_e5m2_sme_finalize_ssve_( //
|
|
789
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
790
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
778
791
|
// Phase 1: cache row norms on diagonal
|
|
779
792
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
780
793
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -806,20 +819,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_stre
|
|
|
806
819
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
807
820
|
}
|
|
808
821
|
|
|
809
|
-
NK_PUBLIC void nk_angulars_symmetric_e5m2_sme(
|
|
810
|
-
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
811
|
-
|
|
822
|
+
NK_PUBLIC void nk_angulars_symmetric_e5m2_sme( //
|
|
823
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
824
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
812
825
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
|
|
813
826
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
827
|
+
nk_sme_start_streaming_();
|
|
814
828
|
nk_dots_symmetric_e5m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
815
829
|
result_stride_elements, row_start, row_count);
|
|
816
|
-
|
|
817
|
-
|
|
830
|
+
nk_angulars_symmetric_e5m2_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
831
|
+
result_stride_elements, row_start, row_count);
|
|
832
|
+
nk_sme_stop_streaming_();
|
|
818
833
|
}
|
|
819
834
|
|
|
820
|
-
|
|
821
|
-
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
822
|
-
|
|
835
|
+
static void nk_euclideans_symmetric_e5m2_sme_finalize_ssve_( //
|
|
836
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
837
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
823
838
|
// Phase 1: cache row norms on diagonal
|
|
824
839
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
825
840
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -851,25 +866,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_st
|
|
|
851
866
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
852
867
|
}
|
|
853
868
|
|
|
854
|
-
NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme(
|
|
855
|
-
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
856
|
-
|
|
869
|
+
NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme( //
|
|
870
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
871
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
857
872
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
|
|
858
873
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
874
|
+
nk_sme_start_streaming_();
|
|
859
875
|
nk_dots_symmetric_e5m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
860
876
|
result_stride_elements, row_start, row_count);
|
|
861
|
-
|
|
862
|
-
|
|
877
|
+
nk_euclideans_symmetric_e5m2_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
878
|
+
result_stride_elements, row_start, row_count);
|
|
879
|
+
nk_sme_stop_streaming_();
|
|
863
880
|
}
|
|
864
881
|
|
|
865
882
|
#pragma endregion E5M2 Floats
|
|
866
883
|
|
|
867
884
|
#pragma region E2M3 Floats
|
|
868
885
|
|
|
869
|
-
|
|
870
|
-
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
871
|
-
nk_size_t
|
|
872
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
886
|
+
static void nk_angulars_packed_e2m3_sme_finalize_ssve_( //
|
|
887
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
888
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
873
889
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
874
890
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
875
891
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -888,21 +904,21 @@ __arm_locally_streaming static void nk_angulars_packed_e2m3_sme_finalize_streami
|
|
|
888
904
|
}
|
|
889
905
|
}
|
|
890
906
|
|
|
891
|
-
NK_PUBLIC void nk_angulars_packed_e2m3_sme(
|
|
892
|
-
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
893
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
907
|
+
NK_PUBLIC void nk_angulars_packed_e2m3_sme( //
|
|
908
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
894
909
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
895
910
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
|
|
896
911
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
912
|
+
nk_sme_start_streaming_();
|
|
897
913
|
nk_dots_packed_e2m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
898
|
-
|
|
899
|
-
|
|
914
|
+
nk_angulars_packed_e2m3_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
915
|
+
c_stride_elements);
|
|
916
|
+
nk_sme_stop_streaming_();
|
|
900
917
|
}
|
|
901
918
|
|
|
902
|
-
|
|
903
|
-
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
904
|
-
nk_size_t
|
|
905
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
919
|
+
static void nk_euclideans_packed_e2m3_sme_finalize_ssve_( //
|
|
920
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
921
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
906
922
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
907
923
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
908
924
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -921,20 +937,21 @@ __arm_locally_streaming static void nk_euclideans_packed_e2m3_sme_finalize_strea
|
|
|
921
937
|
}
|
|
922
938
|
}
|
|
923
939
|
|
|
924
|
-
NK_PUBLIC void nk_euclideans_packed_e2m3_sme(
|
|
925
|
-
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
|
|
926
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
940
|
+
NK_PUBLIC void nk_euclideans_packed_e2m3_sme( //
|
|
941
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
927
942
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
928
943
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
|
|
929
944
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
945
|
+
nk_sme_start_streaming_();
|
|
930
946
|
nk_dots_packed_e2m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
931
|
-
|
|
932
|
-
|
|
947
|
+
nk_euclideans_packed_e2m3_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
948
|
+
c_stride_elements);
|
|
949
|
+
nk_sme_stop_streaming_();
|
|
933
950
|
}
|
|
934
951
|
|
|
935
|
-
|
|
936
|
-
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
937
|
-
|
|
952
|
+
static void nk_angulars_symmetric_e2m3_sme_finalize_ssve_( //
|
|
953
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
954
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
938
955
|
// Phase 1: cache row norms on diagonal
|
|
939
956
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
940
957
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -966,20 +983,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_stre
|
|
|
966
983
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
967
984
|
}
|
|
968
985
|
|
|
969
|
-
NK_PUBLIC void nk_angulars_symmetric_e2m3_sme(
|
|
970
|
-
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
971
|
-
|
|
986
|
+
NK_PUBLIC void nk_angulars_symmetric_e2m3_sme( //
|
|
987
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
988
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
972
989
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
|
|
973
990
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
991
|
+
nk_sme_start_streaming_();
|
|
974
992
|
nk_dots_symmetric_e2m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
975
993
|
result_stride_elements, row_start, row_count);
|
|
976
|
-
|
|
977
|
-
|
|
994
|
+
nk_angulars_symmetric_e2m3_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
995
|
+
result_stride_elements, row_start, row_count);
|
|
996
|
+
nk_sme_stop_streaming_();
|
|
978
997
|
}
|
|
979
998
|
|
|
980
|
-
|
|
981
|
-
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
982
|
-
|
|
999
|
+
static void nk_euclideans_symmetric_e2m3_sme_finalize_ssve_( //
|
|
1000
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1001
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
983
1002
|
// Phase 1: cache row norms on diagonal
|
|
984
1003
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
985
1004
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -1011,25 +1030,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_st
|
|
|
1011
1030
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1012
1031
|
}
|
|
1013
1032
|
|
|
1014
|
-
NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme(
|
|
1015
|
-
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1016
|
-
|
|
1033
|
+
NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme( //
|
|
1034
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1035
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1017
1036
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
|
|
1018
1037
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1038
|
+
nk_sme_start_streaming_();
|
|
1019
1039
|
nk_dots_symmetric_e2m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1020
1040
|
result_stride_elements, row_start, row_count);
|
|
1021
|
-
|
|
1022
|
-
|
|
1041
|
+
nk_euclideans_symmetric_e2m3_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1042
|
+
result_stride_elements, row_start, row_count);
|
|
1043
|
+
nk_sme_stop_streaming_();
|
|
1023
1044
|
}
|
|
1024
1045
|
|
|
1025
1046
|
#pragma endregion E2M3 Floats
|
|
1026
1047
|
|
|
1027
1048
|
#pragma region E3M2 Floats
|
|
1028
1049
|
|
|
1029
|
-
|
|
1030
|
-
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1031
|
-
nk_size_t
|
|
1032
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1050
|
+
static void nk_angulars_packed_e3m2_sme_finalize_ssve_( //
|
|
1051
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1052
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1033
1053
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1034
1054
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1035
1055
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1048,21 +1068,21 @@ __arm_locally_streaming static void nk_angulars_packed_e3m2_sme_finalize_streami
|
|
|
1048
1068
|
}
|
|
1049
1069
|
}
|
|
1050
1070
|
|
|
1051
|
-
NK_PUBLIC void nk_angulars_packed_e3m2_sme(
|
|
1052
|
-
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1053
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1071
|
+
NK_PUBLIC void nk_angulars_packed_e3m2_sme( //
|
|
1072
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1054
1073
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1055
1074
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1056
1075
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1076
|
+
nk_sme_start_streaming_();
|
|
1057
1077
|
nk_dots_packed_e3m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1058
|
-
|
|
1059
|
-
|
|
1078
|
+
nk_angulars_packed_e3m2_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1079
|
+
c_stride_elements);
|
|
1080
|
+
nk_sme_stop_streaming_();
|
|
1060
1081
|
}
|
|
1061
1082
|
|
|
1062
|
-
|
|
1063
|
-
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1064
|
-
nk_size_t
|
|
1065
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1083
|
+
static void nk_euclideans_packed_e3m2_sme_finalize_ssve_( //
|
|
1084
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1085
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1066
1086
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1067
1087
|
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1068
1088
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1081,20 +1101,21 @@ __arm_locally_streaming static void nk_euclideans_packed_e3m2_sme_finalize_strea
|
|
|
1081
1101
|
}
|
|
1082
1102
|
}
|
|
1083
1103
|
|
|
1084
|
-
NK_PUBLIC void nk_euclideans_packed_e3m2_sme(
|
|
1085
|
-
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1086
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1104
|
+
NK_PUBLIC void nk_euclideans_packed_e3m2_sme( //
|
|
1105
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1087
1106
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1088
1107
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1089
1108
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1109
|
+
nk_sme_start_streaming_();
|
|
1090
1110
|
nk_dots_packed_e3m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1091
|
-
|
|
1092
|
-
|
|
1111
|
+
nk_euclideans_packed_e3m2_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1112
|
+
c_stride_elements);
|
|
1113
|
+
nk_sme_stop_streaming_();
|
|
1093
1114
|
}
|
|
1094
1115
|
|
|
1095
|
-
|
|
1096
|
-
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1097
|
-
|
|
1116
|
+
static void nk_angulars_symmetric_e3m2_sme_finalize_ssve_( //
|
|
1117
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1118
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1098
1119
|
// Phase 1: cache row norms on diagonal
|
|
1099
1120
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1100
1121
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -1126,20 +1147,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_stre
|
|
|
1126
1147
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1127
1148
|
}
|
|
1128
1149
|
|
|
1129
|
-
NK_PUBLIC void nk_angulars_symmetric_e3m2_sme(
|
|
1130
|
-
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1131
|
-
|
|
1150
|
+
NK_PUBLIC void nk_angulars_symmetric_e3m2_sme( //
|
|
1151
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1152
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1132
1153
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1133
1154
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1155
|
+
nk_sme_start_streaming_();
|
|
1134
1156
|
nk_dots_symmetric_e3m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1135
1157
|
result_stride_elements, row_start, row_count);
|
|
1136
|
-
|
|
1137
|
-
|
|
1158
|
+
nk_angulars_symmetric_e3m2_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1159
|
+
result_stride_elements, row_start, row_count);
|
|
1160
|
+
nk_sme_stop_streaming_();
|
|
1138
1161
|
}
|
|
1139
1162
|
|
|
1140
|
-
|
|
1141
|
-
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1142
|
-
|
|
1163
|
+
static void nk_euclideans_symmetric_e3m2_sme_finalize_ssve_( //
|
|
1164
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1165
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1143
1166
|
// Phase 1: cache row norms on diagonal
|
|
1144
1167
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1145
1168
|
nk_f32_t *result_row = result + row_index * result_stride_elements;
|
|
@@ -1171,24 +1194,25 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_st
|
|
|
1171
1194
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1172
1195
|
}
|
|
1173
1196
|
|
|
1174
|
-
NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme(
|
|
1175
|
-
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1176
|
-
|
|
1197
|
+
NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme( //
|
|
1198
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1199
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1177
1200
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1178
1201
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1202
|
+
nk_sme_start_streaming_();
|
|
1179
1203
|
nk_dots_symmetric_e3m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
1180
1204
|
result_stride_elements, row_start, row_count);
|
|
1181
|
-
|
|
1182
|
-
|
|
1205
|
+
nk_euclideans_symmetric_e3m2_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1206
|
+
result_stride_elements, row_start, row_count);
|
|
1207
|
+
nk_sme_stop_streaming_();
|
|
1183
1208
|
}
|
|
1184
1209
|
|
|
1185
1210
|
#pragma endregion E3M2 Floats
|
|
1186
1211
|
#pragma region I8 Integers
|
|
1187
1212
|
|
|
1188
|
-
|
|
1189
|
-
nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1190
|
-
nk_size_t
|
|
1191
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1213
|
+
static void nk_angulars_packed_i8_sme_finalize_ssve_( //
|
|
1214
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1215
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1192
1216
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1193
1217
|
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1194
1218
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1209,22 +1233,22 @@ __arm_locally_streaming static void nk_angulars_packed_i8_sme_finalize_streaming
|
|
|
1209
1233
|
}
|
|
1210
1234
|
}
|
|
1211
1235
|
|
|
1212
|
-
NK_PUBLIC void nk_angulars_packed_i8_sme(
|
|
1213
|
-
nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1214
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1236
|
+
NK_PUBLIC void nk_angulars_packed_i8_sme( //
|
|
1237
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1215
1238
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1216
1239
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
|
|
1217
1240
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1241
|
+
nk_sme_start_streaming_();
|
|
1218
1242
|
nk_dots_packed_i8_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1219
1243
|
c_stride_elements);
|
|
1220
|
-
|
|
1221
|
-
|
|
1244
|
+
nk_angulars_packed_i8_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1245
|
+
c_stride_elements);
|
|
1246
|
+
nk_sme_stop_streaming_();
|
|
1222
1247
|
}
|
|
1223
1248
|
|
|
1224
|
-
|
|
1225
|
-
nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1226
|
-
nk_size_t
|
|
1227
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1249
|
+
static void nk_euclideans_packed_i8_sme_finalize_ssve_( //
|
|
1250
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1251
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1228
1252
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1229
1253
|
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1230
1254
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1245,21 +1269,22 @@ __arm_locally_streaming static void nk_euclideans_packed_i8_sme_finalize_streami
|
|
|
1245
1269
|
}
|
|
1246
1270
|
}
|
|
1247
1271
|
|
|
1248
|
-
NK_PUBLIC void nk_euclideans_packed_i8_sme(
|
|
1249
|
-
nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1250
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1272
|
+
NK_PUBLIC void nk_euclideans_packed_i8_sme( //
|
|
1273
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1251
1274
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1252
1275
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
|
|
1253
1276
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1277
|
+
nk_sme_start_streaming_();
|
|
1254
1278
|
nk_dots_packed_i8_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1255
1279
|
c_stride_elements);
|
|
1256
|
-
|
|
1257
|
-
|
|
1280
|
+
nk_euclideans_packed_i8_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1281
|
+
c_stride_elements);
|
|
1282
|
+
nk_sme_stop_streaming_();
|
|
1258
1283
|
}
|
|
1259
1284
|
|
|
1260
|
-
|
|
1261
|
-
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1262
|
-
|
|
1285
|
+
static void nk_angulars_symmetric_i8_sme_finalize_ssve_( //
|
|
1286
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1287
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1263
1288
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1264
1289
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1265
1290
|
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i8_ssve_(vectors + row_index * stride_elements, depth);
|
|
@@ -1294,20 +1319,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_stream
|
|
|
1294
1319
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1295
1320
|
}
|
|
1296
1321
|
|
|
1297
|
-
NK_PUBLIC void nk_angulars_symmetric_i8_sme(
|
|
1298
|
-
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1299
|
-
|
|
1322
|
+
NK_PUBLIC void nk_angulars_symmetric_i8_sme( //
|
|
1323
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1324
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1300
1325
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
|
|
1301
1326
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1327
|
+
nk_sme_start_streaming_();
|
|
1302
1328
|
nk_dots_symmetric_i8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
|
|
1303
1329
|
result_stride_elements, row_start, row_count);
|
|
1304
|
-
|
|
1305
|
-
|
|
1330
|
+
nk_angulars_symmetric_i8_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1331
|
+
result_stride_elements, row_start, row_count);
|
|
1332
|
+
nk_sme_stop_streaming_();
|
|
1306
1333
|
}
|
|
1307
1334
|
|
|
1308
|
-
|
|
1309
|
-
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1310
|
-
|
|
1335
|
+
static void nk_euclideans_symmetric_i8_sme_finalize_ssve_( //
|
|
1336
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1337
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1311
1338
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1312
1339
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1313
1340
|
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i8_ssve_(vectors + row_index * stride_elements, depth);
|
|
@@ -1342,25 +1369,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_stre
|
|
|
1342
1369
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1343
1370
|
}
|
|
1344
1371
|
|
|
1345
|
-
NK_PUBLIC void nk_euclideans_symmetric_i8_sme(
|
|
1346
|
-
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1347
|
-
|
|
1372
|
+
NK_PUBLIC void nk_euclideans_symmetric_i8_sme( //
|
|
1373
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1374
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1348
1375
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
|
|
1349
1376
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1377
|
+
nk_sme_start_streaming_();
|
|
1350
1378
|
nk_dots_symmetric_i8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
|
|
1351
1379
|
result_stride_elements, row_start, row_count);
|
|
1352
|
-
|
|
1353
|
-
|
|
1380
|
+
nk_euclideans_symmetric_i8_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1381
|
+
result_stride_elements, row_start, row_count);
|
|
1382
|
+
nk_sme_stop_streaming_();
|
|
1354
1383
|
}
|
|
1355
1384
|
|
|
1356
1385
|
#pragma endregion I8 Integers
|
|
1357
1386
|
|
|
1358
1387
|
#pragma region U8 Integers
|
|
1359
1388
|
|
|
1360
|
-
|
|
1361
|
-
nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1362
|
-
nk_size_t
|
|
1363
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1389
|
+
static void nk_angulars_packed_u8_sme_finalize_ssve_( //
|
|
1390
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1391
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1364
1392
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1365
1393
|
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1366
1394
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1381,22 +1409,22 @@ __arm_locally_streaming static void nk_angulars_packed_u8_sme_finalize_streaming
|
|
|
1381
1409
|
}
|
|
1382
1410
|
}
|
|
1383
1411
|
|
|
1384
|
-
NK_PUBLIC void nk_angulars_packed_u8_sme(
|
|
1385
|
-
nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1386
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1412
|
+
NK_PUBLIC void nk_angulars_packed_u8_sme( //
|
|
1413
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1387
1414
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1388
1415
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
|
|
1389
1416
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1417
|
+
nk_sme_start_streaming_();
|
|
1390
1418
|
nk_dots_packed_u8_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1391
1419
|
c_stride_elements);
|
|
1392
|
-
|
|
1393
|
-
|
|
1420
|
+
nk_angulars_packed_u8_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1421
|
+
c_stride_elements);
|
|
1422
|
+
nk_sme_stop_streaming_();
|
|
1394
1423
|
}
|
|
1395
1424
|
|
|
1396
|
-
|
|
1397
|
-
nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1398
|
-
nk_size_t
|
|
1399
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1425
|
+
static void nk_euclideans_packed_u8_sme_finalize_ssve_( //
|
|
1426
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1427
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1400
1428
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1401
1429
|
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1402
1430
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1417,21 +1445,22 @@ __arm_locally_streaming static void nk_euclideans_packed_u8_sme_finalize_streami
|
|
|
1417
1445
|
}
|
|
1418
1446
|
}
|
|
1419
1447
|
|
|
1420
|
-
NK_PUBLIC void nk_euclideans_packed_u8_sme(
|
|
1421
|
-
nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1422
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1448
|
+
NK_PUBLIC void nk_euclideans_packed_u8_sme( //
|
|
1449
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1423
1450
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1424
1451
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
|
|
1425
1452
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1453
|
+
nk_sme_start_streaming_();
|
|
1426
1454
|
nk_dots_packed_u8_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1427
1455
|
c_stride_elements);
|
|
1428
|
-
|
|
1429
|
-
|
|
1456
|
+
nk_euclideans_packed_u8_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1457
|
+
c_stride_elements);
|
|
1458
|
+
nk_sme_stop_streaming_();
|
|
1430
1459
|
}
|
|
1431
1460
|
|
|
1432
|
-
|
|
1433
|
-
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1434
|
-
|
|
1461
|
+
static void nk_angulars_symmetric_u8_sme_finalize_ssve_( //
|
|
1462
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1463
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1435
1464
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1436
1465
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1437
1466
|
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u8_ssve_(vectors + row_index * stride_elements, depth);
|
|
@@ -1466,20 +1495,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_stream
|
|
|
1466
1495
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1467
1496
|
}
|
|
1468
1497
|
|
|
1469
|
-
NK_PUBLIC void nk_angulars_symmetric_u8_sme(
|
|
1470
|
-
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1471
|
-
|
|
1498
|
+
NK_PUBLIC void nk_angulars_symmetric_u8_sme( //
|
|
1499
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1500
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1472
1501
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
|
|
1473
1502
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1503
|
+
nk_sme_start_streaming_();
|
|
1474
1504
|
nk_dots_symmetric_u8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
|
|
1475
1505
|
result_stride_elements, row_start, row_count);
|
|
1476
|
-
|
|
1477
|
-
|
|
1506
|
+
nk_angulars_symmetric_u8_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1507
|
+
result_stride_elements, row_start, row_count);
|
|
1508
|
+
nk_sme_stop_streaming_();
|
|
1478
1509
|
}
|
|
1479
1510
|
|
|
1480
|
-
|
|
1481
|
-
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1482
|
-
|
|
1511
|
+
static void nk_euclideans_symmetric_u8_sme_finalize_ssve_( //
|
|
1512
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1513
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1483
1514
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1484
1515
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1485
1516
|
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u8_ssve_(vectors + row_index * stride_elements, depth);
|
|
@@ -1514,25 +1545,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_stre
|
|
|
1514
1545
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1515
1546
|
}
|
|
1516
1547
|
|
|
1517
|
-
NK_PUBLIC void nk_euclideans_symmetric_u8_sme(
|
|
1518
|
-
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1519
|
-
|
|
1548
|
+
NK_PUBLIC void nk_euclideans_symmetric_u8_sme( //
|
|
1549
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1550
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1520
1551
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
|
|
1521
1552
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1553
|
+
nk_sme_start_streaming_();
|
|
1522
1554
|
nk_dots_symmetric_u8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
|
|
1523
1555
|
result_stride_elements, row_start, row_count);
|
|
1524
|
-
|
|
1525
|
-
|
|
1556
|
+
nk_euclideans_symmetric_u8_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1557
|
+
result_stride_elements, row_start, row_count);
|
|
1558
|
+
nk_sme_stop_streaming_();
|
|
1526
1559
|
}
|
|
1527
1560
|
|
|
1528
1561
|
#pragma endregion U8 Integers
|
|
1529
1562
|
|
|
1530
1563
|
#pragma region I4 Integers
|
|
1531
1564
|
|
|
1532
|
-
|
|
1533
|
-
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1534
|
-
nk_size_t
|
|
1535
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1565
|
+
static void nk_angulars_packed_i4_sme_finalize_ssve_( //
|
|
1566
|
+
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1567
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1536
1568
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1537
1569
|
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1538
1570
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1553,22 +1585,22 @@ __arm_locally_streaming static void nk_angulars_packed_i4_sme_finalize_streaming
|
|
|
1553
1585
|
}
|
|
1554
1586
|
}
|
|
1555
1587
|
|
|
1556
|
-
NK_PUBLIC void nk_angulars_packed_i4_sme(
|
|
1557
|
-
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1558
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1588
|
+
NK_PUBLIC void nk_angulars_packed_i4_sme( //
|
|
1589
|
+
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1559
1590
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1560
1591
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i4x2_t);
|
|
1561
1592
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1593
|
+
nk_sme_start_streaming_();
|
|
1562
1594
|
nk_dots_packed_i4_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1563
1595
|
c_stride_elements);
|
|
1564
|
-
|
|
1565
|
-
|
|
1596
|
+
nk_angulars_packed_i4_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1597
|
+
c_stride_elements);
|
|
1598
|
+
nk_sme_stop_streaming_();
|
|
1566
1599
|
}
|
|
1567
1600
|
|
|
1568
|
-
|
|
1569
|
-
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1570
|
-
nk_size_t
|
|
1571
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1601
|
+
static void nk_euclideans_packed_i4_sme_finalize_ssve_( //
|
|
1602
|
+
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1603
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1572
1604
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1573
1605
|
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1574
1606
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1589,21 +1621,22 @@ __arm_locally_streaming static void nk_euclideans_packed_i4_sme_finalize_streami
|
|
|
1589
1621
|
}
|
|
1590
1622
|
}
|
|
1591
1623
|
|
|
1592
|
-
NK_PUBLIC void nk_euclideans_packed_i4_sme(
|
|
1593
|
-
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1594
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1624
|
+
NK_PUBLIC void nk_euclideans_packed_i4_sme( //
|
|
1625
|
+
nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1595
1626
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1596
1627
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i4x2_t);
|
|
1597
1628
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1629
|
+
nk_sme_start_streaming_();
|
|
1598
1630
|
nk_dots_packed_i4_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1599
1631
|
c_stride_elements);
|
|
1600
|
-
|
|
1601
|
-
|
|
1632
|
+
nk_euclideans_packed_i4_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1633
|
+
c_stride_elements);
|
|
1634
|
+
nk_sme_stop_streaming_();
|
|
1602
1635
|
}
|
|
1603
1636
|
|
|
1604
|
-
|
|
1605
|
-
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1606
|
-
|
|
1637
|
+
static void nk_angulars_symmetric_i4_sme_finalize_ssve_( //
|
|
1638
|
+
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1639
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1607
1640
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1608
1641
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1609
1642
|
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i4_ssve_(vectors + row_index * stride_elements, depth);
|
|
@@ -1638,20 +1671,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_stream
|
|
|
1638
1671
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1639
1672
|
}
|
|
1640
1673
|
|
|
1641
|
-
NK_PUBLIC void nk_angulars_symmetric_i4_sme(
|
|
1642
|
-
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1643
|
-
|
|
1674
|
+
NK_PUBLIC void nk_angulars_symmetric_i4_sme( //
|
|
1675
|
+
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1676
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1644
1677
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i4x2_t);
|
|
1645
1678
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1679
|
+
nk_sme_start_streaming_();
|
|
1646
1680
|
nk_dots_symmetric_i4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
|
|
1647
1681
|
result_stride_elements, row_start, row_count);
|
|
1648
|
-
|
|
1649
|
-
|
|
1682
|
+
nk_angulars_symmetric_i4_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1683
|
+
result_stride_elements, row_start, row_count);
|
|
1684
|
+
nk_sme_stop_streaming_();
|
|
1650
1685
|
}
|
|
1651
1686
|
|
|
1652
|
-
|
|
1653
|
-
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1654
|
-
|
|
1687
|
+
static void nk_euclideans_symmetric_i4_sme_finalize_ssve_( //
|
|
1688
|
+
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1689
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1655
1690
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1656
1691
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1657
1692
|
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i4_ssve_(vectors + row_index * stride_elements, depth);
|
|
@@ -1686,25 +1721,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_stre
|
|
|
1686
1721
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1687
1722
|
}
|
|
1688
1723
|
|
|
1689
|
-
NK_PUBLIC void nk_euclideans_symmetric_i4_sme(
|
|
1690
|
-
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1691
|
-
|
|
1724
|
+
NK_PUBLIC void nk_euclideans_symmetric_i4_sme( //
|
|
1725
|
+
nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1726
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1692
1727
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i4x2_t);
|
|
1693
1728
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1729
|
+
nk_sme_start_streaming_();
|
|
1694
1730
|
nk_dots_symmetric_i4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
|
|
1695
1731
|
result_stride_elements, row_start, row_count);
|
|
1696
|
-
|
|
1697
|
-
|
|
1732
|
+
nk_euclideans_symmetric_i4_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1733
|
+
result_stride_elements, row_start, row_count);
|
|
1734
|
+
nk_sme_stop_streaming_();
|
|
1698
1735
|
}
|
|
1699
1736
|
|
|
1700
1737
|
#pragma endregion Signed Integers
|
|
1701
1738
|
|
|
1702
1739
|
#pragma region U4 Integers
|
|
1703
1740
|
|
|
1704
|
-
|
|
1705
|
-
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1706
|
-
nk_size_t
|
|
1707
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1741
|
+
static void nk_angulars_packed_u4_sme_finalize_ssve_( //
|
|
1742
|
+
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1743
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1708
1744
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1709
1745
|
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1710
1746
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1725,22 +1761,22 @@ __arm_locally_streaming static void nk_angulars_packed_u4_sme_finalize_streaming
|
|
|
1725
1761
|
}
|
|
1726
1762
|
}
|
|
1727
1763
|
|
|
1728
|
-
NK_PUBLIC void nk_angulars_packed_u4_sme(
|
|
1729
|
-
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1730
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1764
|
+
NK_PUBLIC void nk_angulars_packed_u4_sme( //
|
|
1765
|
+
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1731
1766
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1732
1767
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u4x2_t);
|
|
1733
1768
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1769
|
+
nk_sme_start_streaming_();
|
|
1734
1770
|
nk_dots_packed_u4_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1735
1771
|
c_stride_elements);
|
|
1736
|
-
|
|
1737
|
-
|
|
1772
|
+
nk_angulars_packed_u4_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1773
|
+
c_stride_elements);
|
|
1774
|
+
nk_sme_stop_streaming_();
|
|
1738
1775
|
}
|
|
1739
1776
|
|
|
1740
|
-
|
|
1741
|
-
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1742
|
-
nk_size_t
|
|
1743
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1777
|
+
static void nk_euclideans_packed_u4_sme_finalize_ssve_( //
|
|
1778
|
+
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1779
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1744
1780
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
1745
1781
|
nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
|
|
1746
1782
|
for (nk_size_t row_index = 0; row_index < rows; row_index++) {
|
|
@@ -1761,21 +1797,22 @@ __arm_locally_streaming static void nk_euclideans_packed_u4_sme_finalize_streami
|
|
|
1761
1797
|
}
|
|
1762
1798
|
}
|
|
1763
1799
|
|
|
1764
|
-
NK_PUBLIC void nk_euclideans_packed_u4_sme(
|
|
1765
|
-
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c,
|
|
1766
|
-
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1800
|
+
NK_PUBLIC void nk_euclideans_packed_u4_sme( //
|
|
1801
|
+
nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1767
1802
|
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1768
1803
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u4x2_t);
|
|
1769
1804
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1805
|
+
nk_sme_start_streaming_();
|
|
1770
1806
|
nk_dots_packed_u4_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
|
|
1771
1807
|
c_stride_elements);
|
|
1772
|
-
|
|
1773
|
-
|
|
1808
|
+
nk_euclideans_packed_u4_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1809
|
+
c_stride_elements);
|
|
1810
|
+
nk_sme_stop_streaming_();
|
|
1774
1811
|
}
|
|
1775
1812
|
|
|
1776
|
-
|
|
1777
|
-
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1778
|
-
|
|
1813
|
+
static void nk_angulars_symmetric_u4_sme_finalize_ssve_( //
|
|
1814
|
+
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1815
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1779
1816
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1780
1817
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1781
1818
|
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u4_ssve_(vectors + row_index * stride_elements, depth);
|
|
@@ -1810,20 +1847,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_stream
|
|
|
1810
1847
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1811
1848
|
}
|
|
1812
1849
|
|
|
1813
|
-
NK_PUBLIC void nk_angulars_symmetric_u4_sme(
|
|
1814
|
-
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1815
|
-
|
|
1850
|
+
NK_PUBLIC void nk_angulars_symmetric_u4_sme( //
|
|
1851
|
+
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1852
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1816
1853
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u4x2_t);
|
|
1817
1854
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1855
|
+
nk_sme_start_streaming_();
|
|
1818
1856
|
nk_dots_symmetric_u4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
|
|
1819
1857
|
result_stride_elements, row_start, row_count);
|
|
1820
|
-
|
|
1821
|
-
|
|
1858
|
+
nk_angulars_symmetric_u4_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1859
|
+
result_stride_elements, row_start, row_count);
|
|
1860
|
+
nk_sme_stop_streaming_();
|
|
1822
1861
|
}
|
|
1823
1862
|
|
|
1824
|
-
|
|
1825
|
-
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements,
|
|
1826
|
-
|
|
1863
|
+
static void nk_euclideans_symmetric_u4_sme_finalize_ssve_( //
|
|
1864
|
+
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
|
|
1865
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
1827
1866
|
// Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
|
|
1828
1867
|
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
|
|
1829
1868
|
nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u4_ssve_(vectors + row_index * stride_elements, depth);
|
|
@@ -1858,15 +1897,17 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_stre
|
|
|
1858
1897
|
result[row_index * result_stride_elements + row_index] = 0;
|
|
1859
1898
|
}
|
|
1860
1899
|
|
|
1861
|
-
NK_PUBLIC void nk_euclideans_symmetric_u4_sme(
|
|
1862
|
-
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,
|
|
1863
|
-
|
|
1900
|
+
NK_PUBLIC void nk_euclideans_symmetric_u4_sme( //
|
|
1901
|
+
nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
|
|
1902
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1864
1903
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u4x2_t);
|
|
1865
1904
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1905
|
+
nk_sme_start_streaming_();
|
|
1866
1906
|
nk_dots_symmetric_u4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
|
|
1867
1907
|
result_stride_elements, row_start, row_count);
|
|
1868
|
-
|
|
1869
|
-
|
|
1908
|
+
nk_euclideans_symmetric_u4_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
|
|
1909
|
+
result_stride_elements, row_start, row_count);
|
|
1910
|
+
nk_sme_stop_streaming_();
|
|
1870
1911
|
}
|
|
1871
1912
|
|
|
1872
1913
|
#pragma endregion Unsigned Integers
|