numkong 7.4.4 → 7.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. package/README.md +1 -0
  2. package/binding.gyp +81 -5
  3. package/c/dispatch_f16.c +23 -0
  4. package/c/numkong.c +0 -13
  5. package/include/numkong/attention/sme.h +34 -31
  6. package/include/numkong/capabilities.h +2 -15
  7. package/include/numkong/cast/neon.h +15 -0
  8. package/include/numkong/curved/smef64.h +82 -62
  9. package/include/numkong/dot/rvvbf16.h +1 -1
  10. package/include/numkong/dot/rvvhalf.h +1 -1
  11. package/include/numkong/dot/sve.h +6 -5
  12. package/include/numkong/dot/svebfdot.h +2 -1
  13. package/include/numkong/dot/svehalf.h +6 -5
  14. package/include/numkong/dot/svesdot.h +3 -2
  15. package/include/numkong/dots/graniteamx.h +733 -0
  16. package/include/numkong/dots/serial.h +11 -4
  17. package/include/numkong/dots/sme.h +172 -140
  18. package/include/numkong/dots/smebi32.h +14 -11
  19. package/include/numkong/dots/smef64.h +31 -26
  20. package/include/numkong/dots.h +29 -3
  21. package/include/numkong/each/serial.h +22 -0
  22. package/include/numkong/geospatial/haswell.h +1 -1
  23. package/include/numkong/geospatial/neon.h +1 -1
  24. package/include/numkong/geospatial/serial.h +1 -1
  25. package/include/numkong/geospatial/skylake.h +1 -1
  26. package/include/numkong/maxsim/sme.h +94 -55
  27. package/include/numkong/mesh/README.md +13 -27
  28. package/include/numkong/mesh/haswell.h +25 -122
  29. package/include/numkong/mesh/neon.h +21 -110
  30. package/include/numkong/mesh/neonbfdot.h +4 -43
  31. package/include/numkong/mesh/rvv.h +7 -82
  32. package/include/numkong/mesh/serial.h +48 -53
  33. package/include/numkong/mesh/skylake.h +7 -123
  34. package/include/numkong/mesh/v128relaxed.h +9 -93
  35. package/include/numkong/mesh.h +2 -2
  36. package/include/numkong/mesh.hpp +35 -96
  37. package/include/numkong/reduce/neon.h +29 -0
  38. package/include/numkong/reduce/neonbfdot.h +2 -2
  39. package/include/numkong/reduce/neonfhm.h +4 -4
  40. package/include/numkong/reduce/sve.h +52 -0
  41. package/include/numkong/reduce.h +4 -0
  42. package/include/numkong/set/sve.h +6 -5
  43. package/include/numkong/sets/smebi32.h +35 -30
  44. package/include/numkong/sparse/sve2.h +3 -2
  45. package/include/numkong/spatial/sve.h +7 -6
  46. package/include/numkong/spatial/svebfdot.h +7 -4
  47. package/include/numkong/spatial/svehalf.h +5 -4
  48. package/include/numkong/spatial/svesdot.h +9 -8
  49. package/include/numkong/spatials/graniteamx.h +173 -0
  50. package/include/numkong/spatials/serial.h +22 -0
  51. package/include/numkong/spatials/sme.h +391 -350
  52. package/include/numkong/spatials/smef64.h +79 -70
  53. package/include/numkong/spatials.h +37 -4
  54. package/include/numkong/types.h +59 -0
  55. package/javascript/dist/cjs/numkong.js +13 -0
  56. package/javascript/dist/esm/numkong.js +13 -0
  57. package/javascript/numkong.c +56 -12
  58. package/javascript/numkong.ts +13 -0
  59. package/package.json +7 -7
  60. package/probes/probe.js +2 -2
  61. package/wasm/numkong.wasm +0 -0
@@ -13,6 +13,7 @@
13
13
  #if NK_TARGET_SME
14
14
 
15
15
  #include "numkong/dots/serial.h"
16
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
16
17
  #include "numkong/dots/sme.h"
17
18
 
18
19
  #if defined(__cplusplus)
@@ -44,7 +45,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_f16_ssve_(nk_f16_t const *data, nk_size_
44
45
  svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
45
46
  accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
46
47
  }
47
- return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
48
+ return nk_svaddv_f32_(svptrue_b32(), accumulator_even_f32x) + nk_svaddv_f32_(svptrue_b32(), accumulator_odd_f32x);
48
49
  }
49
50
 
50
51
  NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -55,7 +56,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_bf16_ssve_(nk_bf16_t const *data, nk_siz
55
56
  svbfloat16_t values_bf16x = svld1_bf16(predicate_b16x, (nk_bf16_for_arm_simd_t const *)(data + i));
56
57
  accumulator_f32x = svbfdot_f32(accumulator_f32x, values_bf16x, values_bf16x);
57
58
  }
58
- return svaddv_f32(svptrue_b32(), accumulator_f32x);
59
+ return nk_svaddv_f32_(svptrue_b32(), accumulator_f32x);
59
60
  }
60
61
 
61
62
  NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e4m3_ssve_(nk_e4m3_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -79,7 +80,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e4m3_ssve_(nk_e4m3_t const *data, nk_siz
79
80
  svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
80
81
  accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
81
82
  }
82
- return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
83
+ return nk_svaddv_f32_(svptrue_b32(), accumulator_even_f32x) + nk_svaddv_f32_(svptrue_b32(), accumulator_odd_f32x);
83
84
  }
84
85
 
85
86
  NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e5m2_ssve_(nk_e5m2_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -103,7 +104,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e5m2_ssve_(nk_e5m2_t const *data, nk_siz
103
104
  svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
104
105
  accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
105
106
  }
106
- return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
107
+ return nk_svaddv_f32_(svptrue_b32(), accumulator_even_f32x) + nk_svaddv_f32_(svptrue_b32(), accumulator_odd_f32x);
107
108
  }
108
109
 
109
110
  NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -115,7 +116,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e2m3_ssve_(nk_e2m3_t const *data, nk_siz
115
116
  svint8_t values_i8x = nk_e2m3x_to_i8x_ssve_(predicate_b8x, raw_u8x);
116
117
  accumulator_i32x = svdot_s32(accumulator_i32x, values_i8x, values_i8x);
117
118
  }
118
- return (nk_f32_t)svaddv_s32(svptrue_b32(), accumulator_i32x) / 256.0f;
119
+ return (nk_f32_t)nk_svaddv_s32_(svptrue_b32(), accumulator_i32x) / 256.0f;
119
120
  }
120
121
 
121
122
  NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e3m2_ssve_(nk_e3m2_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -139,7 +140,7 @@ NK_PUBLIC nk_f32_t nk_dots_reduce_sumsq_e3m2_ssve_(nk_e3m2_t const *data, nk_siz
139
140
  svfloat32_t values_odd_f32x = svcvtlt_f32_f16_x(predicate_odd_b32x, values_f16x);
140
141
  accumulator_odd_f32x = svmla_f32_m(predicate_odd_b32x, accumulator_odd_f32x, values_odd_f32x, values_odd_f32x);
141
142
  }
142
- return svaddv_f32(svptrue_b32(), accumulator_even_f32x) + svaddv_f32(svptrue_b32(), accumulator_odd_f32x);
143
+ return nk_svaddv_f32_(svptrue_b32(), accumulator_even_f32x) + nk_svaddv_f32_(svptrue_b32(), accumulator_odd_f32x);
143
144
  }
144
145
 
145
146
  NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -150,7 +151,7 @@ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i8_ssve_(nk_i8_t const *data, nk_size_t
150
151
  svint8_t loaded_i8x = svld1_s8(predicate_b8x, data + i);
151
152
  accumulator_i32x = svdot_s32(accumulator_i32x, loaded_i8x, loaded_i8x);
152
153
  }
153
- return (nk_u32_t)svaddv_s32(svptrue_b32(), accumulator_i32x);
154
+ return (nk_u32_t)nk_svaddv_s32_(svptrue_b32(), accumulator_i32x);
154
155
  }
155
156
 
156
157
  NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -161,7 +162,7 @@ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u8_ssve_(nk_u8_t const *data, nk_size_t
161
162
  svuint8_t loaded_u8x = svld1_u8(predicate_b8x, data + i);
162
163
  accumulator_u32x = svdot_u32(accumulator_u32x, loaded_u8x, loaded_u8x);
163
164
  }
164
- return (nk_u32_t)svaddv_u32(svptrue_b32(), accumulator_u32x);
165
+ return (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), accumulator_u32x);
165
166
  }
166
167
 
167
168
  NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -181,7 +182,7 @@ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_i4_ssve_(nk_i4x2_t const *data, nk_size_
181
182
  accumulator_i32x = svdot_s32(accumulator_i32x, low_i8x, low_i8x);
182
183
  accumulator_i32x = svdot_s32(accumulator_i32x, high_i8x, high_i8x);
183
184
  }
184
- return (nk_u32_t)svaddv_s32(svptrue_b32(), accumulator_i32x);
185
+ return (nk_u32_t)nk_svaddv_s32_(svptrue_b32(), accumulator_i32x);
185
186
  }
186
187
 
187
188
  NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -197,7 +198,7 @@ NK_PUBLIC nk_u32_t nk_dots_reduce_sumsq_u4_ssve_(nk_u4x2_t const *data, nk_size_
197
198
  accumulator_u32x = svdot_u32(accumulator_u32x, low_u8x, low_u8x);
198
199
  accumulator_u32x = svdot_u32(accumulator_u32x, high_u8x, high_u8x);
199
200
  }
200
- return (nk_u32_t)svaddv_u32(svptrue_b32(), accumulator_u32x);
201
+ return (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), accumulator_u32x);
201
202
  }
202
203
 
203
204
  NK_PUBLIC svfloat32_t nk_angulars_from_dot_f32x_ssve_(svbool_t predicate_b32x, svfloat32_t dots_f32x,
@@ -226,10 +227,9 @@ NK_PUBLIC svfloat32_t nk_euclideans_from_dot_f32x_ssve_(svbool_t predicate_b32x,
226
227
 
227
228
  #pragma region F16 Floats
228
229
 
229
- __arm_locally_streaming static void nk_angulars_packed_f16_sme_finalize_streaming_( //
230
- nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
231
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
232
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
230
+ static void nk_angulars_packed_f16_sme_finalize_ssve_( //
231
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
232
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
233
233
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
234
234
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
235
235
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -248,21 +248,21 @@ __arm_locally_streaming static void nk_angulars_packed_f16_sme_finalize_streamin
248
248
  }
249
249
  }
250
250
 
251
- NK_PUBLIC void nk_angulars_packed_f16_sme( //
252
- nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
253
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
251
+ NK_PUBLIC void nk_angulars_packed_f16_sme( //
252
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
254
253
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
255
254
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
256
255
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
256
+ nk_sme_start_streaming_();
257
257
  nk_dots_packed_f16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
258
- nk_angulars_packed_f16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
259
- c_stride_elements);
258
+ nk_angulars_packed_f16_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
259
+ c_stride_elements);
260
+ nk_sme_stop_streaming_();
260
261
  }
261
262
 
262
- __arm_locally_streaming static void nk_euclideans_packed_f16_sme_finalize_streaming_( //
263
- nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
264
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
265
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
263
+ static void nk_euclideans_packed_f16_sme_finalize_ssve_( //
264
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
265
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
266
266
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
267
267
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
268
268
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -281,20 +281,21 @@ __arm_locally_streaming static void nk_euclideans_packed_f16_sme_finalize_stream
281
281
  }
282
282
  }
283
283
 
284
- NK_PUBLIC void nk_euclideans_packed_f16_sme( //
285
- nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
286
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
284
+ NK_PUBLIC void nk_euclideans_packed_f16_sme( //
285
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
287
286
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
288
287
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
289
288
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
289
+ nk_sme_start_streaming_();
290
290
  nk_dots_packed_f16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
291
- nk_euclideans_packed_f16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
292
- c_stride_elements);
291
+ nk_euclideans_packed_f16_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
292
+ c_stride_elements);
293
+ nk_sme_stop_streaming_();
293
294
  }
294
295
 
295
- __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_streaming_( //
296
- nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
297
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
296
+ static void nk_angulars_symmetric_f16_sme_finalize_ssve_( //
297
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
298
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
298
299
  // Phase 1: cache row norms on diagonal
299
300
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
300
301
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -326,20 +327,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_f16_sme_finalize_strea
326
327
  result[row_index * result_stride_elements + row_index] = 0;
327
328
  }
328
329
 
329
- NK_PUBLIC void nk_angulars_symmetric_f16_sme( //
330
- nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
331
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
330
+ NK_PUBLIC void nk_angulars_symmetric_f16_sme( //
331
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
332
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
332
333
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
333
334
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
335
+ nk_sme_start_streaming_();
334
336
  nk_dots_symmetric_f16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
335
337
  row_start, row_count);
336
- nk_angulars_symmetric_f16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
337
- result_stride_elements, row_start, row_count);
338
+ nk_angulars_symmetric_f16_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
339
+ result_stride_elements, row_start, row_count);
340
+ nk_sme_stop_streaming_();
338
341
  }
339
342
 
340
- __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_streaming_( //
341
- nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
342
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
343
+ static void nk_euclideans_symmetric_f16_sme_finalize_ssve_( //
344
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
345
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
343
346
  // Phase 1: cache row norms on diagonal
344
347
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
345
348
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -371,25 +374,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f16_sme_finalize_str
371
374
  result[row_index * result_stride_elements + row_index] = 0;
372
375
  }
373
376
 
374
- NK_PUBLIC void nk_euclideans_symmetric_f16_sme( //
375
- nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
376
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
377
+ NK_PUBLIC void nk_euclideans_symmetric_f16_sme( //
378
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
379
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
377
380
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
378
381
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
382
+ nk_sme_start_streaming_();
379
383
  nk_dots_symmetric_f16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result, result_stride_elements,
380
384
  row_start, row_count);
381
- nk_euclideans_symmetric_f16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
382
- result_stride_elements, row_start, row_count);
385
+ nk_euclideans_symmetric_f16_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
386
+ result_stride_elements, row_start, row_count);
387
+ nk_sme_stop_streaming_();
383
388
  }
384
389
 
385
390
  #pragma endregion F16 Floats
386
391
 
387
392
  #pragma region BF16 Floats
388
393
 
389
- __arm_locally_streaming static void nk_angulars_packed_bf16_sme_finalize_streaming_( //
390
- nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
391
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
392
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
394
+ static void nk_angulars_packed_bf16_sme_finalize_ssve_( //
395
+ nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
396
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
393
397
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
394
398
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
395
399
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -408,21 +412,21 @@ __arm_locally_streaming static void nk_angulars_packed_bf16_sme_finalize_streami
408
412
  }
409
413
  }
410
414
 
411
- NK_PUBLIC void nk_angulars_packed_bf16_sme( //
412
- nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
413
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
415
+ NK_PUBLIC void nk_angulars_packed_bf16_sme( //
416
+ nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
414
417
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
415
418
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
416
419
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
420
+ nk_sme_start_streaming_();
417
421
  nk_dots_packed_bf16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
418
- nk_angulars_packed_bf16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
419
- c_stride_elements);
422
+ nk_angulars_packed_bf16_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
423
+ c_stride_elements);
424
+ nk_sme_stop_streaming_();
420
425
  }
421
426
 
422
- __arm_locally_streaming static void nk_euclideans_packed_bf16_sme_finalize_streaming_( //
423
- nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
424
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
425
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
427
+ static void nk_euclideans_packed_bf16_sme_finalize_ssve_( //
428
+ nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
429
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
426
430
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
427
431
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
428
432
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -441,20 +445,21 @@ __arm_locally_streaming static void nk_euclideans_packed_bf16_sme_finalize_strea
441
445
  }
442
446
  }
443
447
 
444
- NK_PUBLIC void nk_euclideans_packed_bf16_sme( //
445
- nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
446
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
448
+ NK_PUBLIC void nk_euclideans_packed_bf16_sme( //
449
+ nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
447
450
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
448
451
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
449
452
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
453
+ nk_sme_start_streaming_();
450
454
  nk_dots_packed_bf16_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
451
- nk_euclideans_packed_bf16_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
452
- c_stride_elements);
455
+ nk_euclideans_packed_bf16_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
456
+ c_stride_elements);
457
+ nk_sme_stop_streaming_();
453
458
  }
454
459
 
455
- __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_streaming_( //
456
- nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
457
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
460
+ static void nk_angulars_symmetric_bf16_sme_finalize_ssve_( //
461
+ nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
462
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
458
463
  // Phase 1: cache row norms on diagonal
459
464
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
460
465
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -486,20 +491,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_bf16_sme_finalize_stre
486
491
  result[row_index * result_stride_elements + row_index] = 0;
487
492
  }
488
493
 
489
- NK_PUBLIC void nk_angulars_symmetric_bf16_sme( //
490
- nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
491
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
494
+ NK_PUBLIC void nk_angulars_symmetric_bf16_sme( //
495
+ nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
496
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
492
497
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
493
498
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
499
+ nk_sme_start_streaming_();
494
500
  nk_dots_symmetric_bf16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
495
501
  result_stride_elements, row_start, row_count);
496
- nk_angulars_symmetric_bf16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
497
- result_stride_elements, row_start, row_count);
502
+ nk_angulars_symmetric_bf16_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
503
+ result_stride_elements, row_start, row_count);
504
+ nk_sme_stop_streaming_();
498
505
  }
499
506
 
500
- __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_streaming_( //
501
- nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
502
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
507
+ static void nk_euclideans_symmetric_bf16_sme_finalize_ssve_( //
508
+ nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
509
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
503
510
  // Phase 1: cache row norms on diagonal
504
511
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
505
512
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -531,25 +538,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_bf16_sme_finalize_st
531
538
  result[row_index * result_stride_elements + row_index] = 0;
532
539
  }
533
540
 
534
- NK_PUBLIC void nk_euclideans_symmetric_bf16_sme( //
535
- nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
536
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
541
+ NK_PUBLIC void nk_euclideans_symmetric_bf16_sme( //
542
+ nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
543
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
537
544
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
538
545
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
546
+ nk_sme_start_streaming_();
539
547
  nk_dots_symmetric_bf16_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
540
548
  result_stride_elements, row_start, row_count);
541
- nk_euclideans_symmetric_bf16_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
542
- result_stride_elements, row_start, row_count);
549
+ nk_euclideans_symmetric_bf16_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
550
+ result_stride_elements, row_start, row_count);
551
+ nk_sme_stop_streaming_();
543
552
  }
544
553
 
545
554
  #pragma endregion BF16 Floats
546
555
 
547
556
  #pragma region E4M3 Floats
548
557
 
549
- __arm_locally_streaming static void nk_angulars_packed_e4m3_sme_finalize_streaming_( //
550
- nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
551
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
552
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
558
+ static void nk_angulars_packed_e4m3_sme_finalize_ssve_( //
559
+ nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
560
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
553
561
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
554
562
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
555
563
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -568,21 +576,21 @@ __arm_locally_streaming static void nk_angulars_packed_e4m3_sme_finalize_streami
568
576
  }
569
577
  }
570
578
 
571
- NK_PUBLIC void nk_angulars_packed_e4m3_sme( //
572
- nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
573
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
579
+ NK_PUBLIC void nk_angulars_packed_e4m3_sme( //
580
+ nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
574
581
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
575
582
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
576
583
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
584
+ nk_sme_start_streaming_();
577
585
  nk_dots_packed_e4m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
578
- nk_angulars_packed_e4m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
579
- c_stride_elements);
586
+ nk_angulars_packed_e4m3_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
587
+ c_stride_elements);
588
+ nk_sme_stop_streaming_();
580
589
  }
581
590
 
582
- __arm_locally_streaming static void nk_euclideans_packed_e4m3_sme_finalize_streaming_( //
583
- nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
584
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
585
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
591
+ static void nk_euclideans_packed_e4m3_sme_finalize_ssve_( //
592
+ nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
593
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
586
594
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
587
595
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
588
596
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -601,20 +609,21 @@ __arm_locally_streaming static void nk_euclideans_packed_e4m3_sme_finalize_strea
601
609
  }
602
610
  }
603
611
 
604
- NK_PUBLIC void nk_euclideans_packed_e4m3_sme( //
605
- nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
606
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
612
+ NK_PUBLIC void nk_euclideans_packed_e4m3_sme( //
613
+ nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
607
614
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
608
615
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
609
616
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
617
+ nk_sme_start_streaming_();
610
618
  nk_dots_packed_e4m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
611
- nk_euclideans_packed_e4m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
612
- c_stride_elements);
619
+ nk_euclideans_packed_e4m3_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
620
+ c_stride_elements);
621
+ nk_sme_stop_streaming_();
613
622
  }
614
623
 
615
- __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_streaming_( //
616
- nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
617
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
624
+ static void nk_angulars_symmetric_e4m3_sme_finalize_ssve_( //
625
+ nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
626
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
618
627
  // Phase 1: cache row norms on diagonal
619
628
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
620
629
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -646,20 +655,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_e4m3_sme_finalize_stre
646
655
  result[row_index * result_stride_elements + row_index] = 0;
647
656
  }
648
657
 
649
- NK_PUBLIC void nk_angulars_symmetric_e4m3_sme( //
650
- nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
651
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
658
+ NK_PUBLIC void nk_angulars_symmetric_e4m3_sme( //
659
+ nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
660
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
652
661
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
653
662
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
663
+ nk_sme_start_streaming_();
654
664
  nk_dots_symmetric_e4m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
655
665
  result_stride_elements, row_start, row_count);
656
- nk_angulars_symmetric_e4m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
657
- result_stride_elements, row_start, row_count);
666
+ nk_angulars_symmetric_e4m3_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
667
+ result_stride_elements, row_start, row_count);
668
+ nk_sme_stop_streaming_();
658
669
  }
659
670
 
660
- __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_streaming_( //
661
- nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
662
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
671
+ static void nk_euclideans_symmetric_e4m3_sme_finalize_ssve_( //
672
+ nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
673
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
663
674
  // Phase 1: cache row norms on diagonal
664
675
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
665
676
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -691,25 +702,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e4m3_sme_finalize_st
691
702
  result[row_index * result_stride_elements + row_index] = 0;
692
703
  }
693
704
 
694
- NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme( //
695
- nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
696
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
705
+ NK_PUBLIC void nk_euclideans_symmetric_e4m3_sme( //
706
+ nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
707
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
697
708
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
698
709
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
710
+ nk_sme_start_streaming_();
699
711
  nk_dots_symmetric_e4m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
700
712
  result_stride_elements, row_start, row_count);
701
- nk_euclideans_symmetric_e4m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
702
- result_stride_elements, row_start, row_count);
713
+ nk_euclideans_symmetric_e4m3_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
714
+ result_stride_elements, row_start, row_count);
715
+ nk_sme_stop_streaming_();
703
716
  }
704
717
 
705
718
  #pragma endregion E4M3 Floats
706
719
 
707
720
  #pragma region E5M2 Floats
708
721
 
709
- __arm_locally_streaming static void nk_angulars_packed_e5m2_sme_finalize_streaming_( //
710
- nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
711
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
712
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
722
+ static void nk_angulars_packed_e5m2_sme_finalize_ssve_( //
723
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
724
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
713
725
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
714
726
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
715
727
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -728,21 +740,21 @@ __arm_locally_streaming static void nk_angulars_packed_e5m2_sme_finalize_streami
728
740
  }
729
741
  }
730
742
 
731
- NK_PUBLIC void nk_angulars_packed_e5m2_sme( //
732
- nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
733
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
743
+ NK_PUBLIC void nk_angulars_packed_e5m2_sme( //
744
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
734
745
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
735
746
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
736
747
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
748
+ nk_sme_start_streaming_();
737
749
  nk_dots_packed_e5m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
738
- nk_angulars_packed_e5m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
739
- c_stride_elements);
750
+ nk_angulars_packed_e5m2_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
751
+ c_stride_elements);
752
+ nk_sme_stop_streaming_();
740
753
  }
741
754
 
742
- __arm_locally_streaming static void nk_euclideans_packed_e5m2_sme_finalize_streaming_( //
743
- nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
744
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
745
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
755
+ static void nk_euclideans_packed_e5m2_sme_finalize_ssve_( //
756
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
757
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
746
758
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
747
759
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
748
760
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -761,20 +773,21 @@ __arm_locally_streaming static void nk_euclideans_packed_e5m2_sme_finalize_strea
761
773
  }
762
774
  }
763
775
 
764
- NK_PUBLIC void nk_euclideans_packed_e5m2_sme( //
765
- nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
766
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
776
+ NK_PUBLIC void nk_euclideans_packed_e5m2_sme( //
777
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
767
778
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
768
779
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
769
780
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
781
+ nk_sme_start_streaming_();
770
782
  nk_dots_packed_e5m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
771
- nk_euclideans_packed_e5m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
772
- c_stride_elements);
783
+ nk_euclideans_packed_e5m2_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
784
+ c_stride_elements);
785
+ nk_sme_stop_streaming_();
773
786
  }
774
787
 
775
- __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_streaming_( //
776
- nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
777
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
788
+ static void nk_angulars_symmetric_e5m2_sme_finalize_ssve_( //
789
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
790
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
778
791
  // Phase 1: cache row norms on diagonal
779
792
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
780
793
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -806,20 +819,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_e5m2_sme_finalize_stre
806
819
  result[row_index * result_stride_elements + row_index] = 0;
807
820
  }
808
821
 
809
- NK_PUBLIC void nk_angulars_symmetric_e5m2_sme( //
810
- nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
811
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
822
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_sme( //
823
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
824
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
812
825
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
813
826
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
827
+ nk_sme_start_streaming_();
814
828
  nk_dots_symmetric_e5m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
815
829
  result_stride_elements, row_start, row_count);
816
- nk_angulars_symmetric_e5m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
817
- result_stride_elements, row_start, row_count);
830
+ nk_angulars_symmetric_e5m2_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
831
+ result_stride_elements, row_start, row_count);
832
+ nk_sme_stop_streaming_();
818
833
  }
819
834
 
820
- __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_streaming_( //
821
- nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
822
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
835
+ static void nk_euclideans_symmetric_e5m2_sme_finalize_ssve_( //
836
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
837
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
823
838
  // Phase 1: cache row norms on diagonal
824
839
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
825
840
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -851,25 +866,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e5m2_sme_finalize_st
851
866
  result[row_index * result_stride_elements + row_index] = 0;
852
867
  }
853
868
 
854
- NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme( //
855
- nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
856
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
869
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_sme( //
870
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
871
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
857
872
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
858
873
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
874
+ nk_sme_start_streaming_();
859
875
  nk_dots_symmetric_e5m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
860
876
  result_stride_elements, row_start, row_count);
861
- nk_euclideans_symmetric_e5m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
862
- result_stride_elements, row_start, row_count);
877
+ nk_euclideans_symmetric_e5m2_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
878
+ result_stride_elements, row_start, row_count);
879
+ nk_sme_stop_streaming_();
863
880
  }
864
881
 
865
882
  #pragma endregion E5M2 Floats
866
883
 
867
884
  #pragma region E2M3 Floats
868
885
 
869
- __arm_locally_streaming static void nk_angulars_packed_e2m3_sme_finalize_streaming_( //
870
- nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
871
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
872
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
886
+ static void nk_angulars_packed_e2m3_sme_finalize_ssve_( //
887
+ nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
888
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
873
889
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
874
890
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
875
891
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -888,21 +904,21 @@ __arm_locally_streaming static void nk_angulars_packed_e2m3_sme_finalize_streami
888
904
  }
889
905
  }
890
906
 
891
- NK_PUBLIC void nk_angulars_packed_e2m3_sme( //
892
- nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
893
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
907
+ NK_PUBLIC void nk_angulars_packed_e2m3_sme( //
908
+ nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
894
909
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
895
910
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
896
911
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
912
+ nk_sme_start_streaming_();
897
913
  nk_dots_packed_e2m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
898
- nk_angulars_packed_e2m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
899
- c_stride_elements);
914
+ nk_angulars_packed_e2m3_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
915
+ c_stride_elements);
916
+ nk_sme_stop_streaming_();
900
917
  }
901
918
 
902
- __arm_locally_streaming static void nk_euclideans_packed_e2m3_sme_finalize_streaming_( //
903
- nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
904
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
905
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
919
+ static void nk_euclideans_packed_e2m3_sme_finalize_ssve_( //
920
+ nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
921
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
906
922
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
907
923
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
908
924
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -921,20 +937,21 @@ __arm_locally_streaming static void nk_euclideans_packed_e2m3_sme_finalize_strea
921
937
  }
922
938
  }
923
939
 
924
- NK_PUBLIC void nk_euclideans_packed_e2m3_sme( //
925
- nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
926
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
940
+ NK_PUBLIC void nk_euclideans_packed_e2m3_sme( //
941
+ nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
927
942
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
928
943
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
929
944
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
945
+ nk_sme_start_streaming_();
930
946
  nk_dots_packed_e2m3_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
931
- nk_euclideans_packed_e2m3_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
932
- c_stride_elements);
947
+ nk_euclideans_packed_e2m3_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
948
+ c_stride_elements);
949
+ nk_sme_stop_streaming_();
933
950
  }
934
951
 
935
- __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_streaming_( //
936
- nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
937
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
952
+ static void nk_angulars_symmetric_e2m3_sme_finalize_ssve_( //
953
+ nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
954
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
938
955
  // Phase 1: cache row norms on diagonal
939
956
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
940
957
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -966,20 +983,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_e2m3_sme_finalize_stre
966
983
  result[row_index * result_stride_elements + row_index] = 0;
967
984
  }
968
985
 
969
- NK_PUBLIC void nk_angulars_symmetric_e2m3_sme( //
970
- nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
971
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
986
+ NK_PUBLIC void nk_angulars_symmetric_e2m3_sme( //
987
+ nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
988
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
972
989
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
973
990
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
991
+ nk_sme_start_streaming_();
974
992
  nk_dots_symmetric_e2m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
975
993
  result_stride_elements, row_start, row_count);
976
- nk_angulars_symmetric_e2m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
977
- result_stride_elements, row_start, row_count);
994
+ nk_angulars_symmetric_e2m3_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
995
+ result_stride_elements, row_start, row_count);
996
+ nk_sme_stop_streaming_();
978
997
  }
979
998
 
980
- __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_streaming_( //
981
- nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
982
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
999
+ static void nk_euclideans_symmetric_e2m3_sme_finalize_ssve_( //
1000
+ nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1001
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
983
1002
  // Phase 1: cache row norms on diagonal
984
1003
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
985
1004
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -1011,25 +1030,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e2m3_sme_finalize_st
1011
1030
  result[row_index * result_stride_elements + row_index] = 0;
1012
1031
  }
1013
1032
 
1014
- NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme( //
1015
- nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1016
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1033
+ NK_PUBLIC void nk_euclideans_symmetric_e2m3_sme( //
1034
+ nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1035
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1017
1036
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
1018
1037
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1038
+ nk_sme_start_streaming_();
1019
1039
  nk_dots_symmetric_e2m3_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
1020
1040
  result_stride_elements, row_start, row_count);
1021
- nk_euclideans_symmetric_e2m3_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1022
- result_stride_elements, row_start, row_count);
1041
+ nk_euclideans_symmetric_e2m3_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1042
+ result_stride_elements, row_start, row_count);
1043
+ nk_sme_stop_streaming_();
1023
1044
  }
1024
1045
 
1025
1046
  #pragma endregion E2M3 Floats
1026
1047
 
1027
1048
  #pragma region E3M2 Floats
1028
1049
 
1029
- __arm_locally_streaming static void nk_angulars_packed_e3m2_sme_finalize_streaming_( //
1030
- nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
1031
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1032
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1050
+ static void nk_angulars_packed_e3m2_sme_finalize_ssve_( //
1051
+ nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1052
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1033
1053
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1034
1054
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
1035
1055
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1048,21 +1068,21 @@ __arm_locally_streaming static void nk_angulars_packed_e3m2_sme_finalize_streami
1048
1068
  }
1049
1069
  }
1050
1070
 
1051
- NK_PUBLIC void nk_angulars_packed_e3m2_sme( //
1052
- nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
1053
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1071
+ NK_PUBLIC void nk_angulars_packed_e3m2_sme( //
1072
+ nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1054
1073
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1055
1074
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
1056
1075
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1076
+ nk_sme_start_streaming_();
1057
1077
  nk_dots_packed_e3m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1058
- nk_angulars_packed_e3m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1059
- c_stride_elements);
1078
+ nk_angulars_packed_e3m2_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1079
+ c_stride_elements);
1080
+ nk_sme_stop_streaming_();
1060
1081
  }
1061
1082
 
1062
- __arm_locally_streaming static void nk_euclideans_packed_e3m2_sme_finalize_streaming_( //
1063
- nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
1064
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1065
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1083
+ static void nk_euclideans_packed_e3m2_sme_finalize_ssve_( //
1084
+ nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1085
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1066
1086
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1067
1087
  nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_offset);
1068
1088
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1081,20 +1101,21 @@ __arm_locally_streaming static void nk_euclideans_packed_e3m2_sme_finalize_strea
1081
1101
  }
1082
1102
  }
1083
1103
 
1084
- NK_PUBLIC void nk_euclideans_packed_e3m2_sme( //
1085
- nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
1086
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1104
+ NK_PUBLIC void nk_euclideans_packed_e3m2_sme( //
1105
+ nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1087
1106
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1088
1107
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
1089
1108
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1109
+ nk_sme_start_streaming_();
1090
1110
  nk_dots_packed_e3m2_sme_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1091
- nk_euclideans_packed_e3m2_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1092
- c_stride_elements);
1111
+ nk_euclideans_packed_e3m2_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1112
+ c_stride_elements);
1113
+ nk_sme_stop_streaming_();
1093
1114
  }
1094
1115
 
1095
- __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_streaming_( //
1096
- nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1097
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1116
+ static void nk_angulars_symmetric_e3m2_sme_finalize_ssve_( //
1117
+ nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1118
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1098
1119
  // Phase 1: cache row norms on diagonal
1099
1120
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1100
1121
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -1126,20 +1147,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_e3m2_sme_finalize_stre
1126
1147
  result[row_index * result_stride_elements + row_index] = 0;
1127
1148
  }
1128
1149
 
1129
- NK_PUBLIC void nk_angulars_symmetric_e3m2_sme( //
1130
- nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1131
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1150
+ NK_PUBLIC void nk_angulars_symmetric_e3m2_sme( //
1151
+ nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1152
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1132
1153
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
1133
1154
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1155
+ nk_sme_start_streaming_();
1134
1156
  nk_dots_symmetric_e3m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
1135
1157
  result_stride_elements, row_start, row_count);
1136
- nk_angulars_symmetric_e3m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1137
- result_stride_elements, row_start, row_count);
1158
+ nk_angulars_symmetric_e3m2_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1159
+ result_stride_elements, row_start, row_count);
1160
+ nk_sme_stop_streaming_();
1138
1161
  }
1139
1162
 
1140
- __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_streaming_( //
1141
- nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1142
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1163
+ static void nk_euclideans_symmetric_e3m2_sme_finalize_ssve_( //
1164
+ nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1165
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1143
1166
  // Phase 1: cache row norms on diagonal
1144
1167
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1145
1168
  nk_f32_t *result_row = result + row_index * result_stride_elements;
@@ -1171,24 +1194,25 @@ __arm_locally_streaming static void nk_euclideans_symmetric_e3m2_sme_finalize_st
1171
1194
  result[row_index * result_stride_elements + row_index] = 0;
1172
1195
  }
1173
1196
 
1174
- NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme( //
1175
- nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1176
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1197
+ NK_PUBLIC void nk_euclideans_symmetric_e3m2_sme( //
1198
+ nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1199
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1177
1200
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
1178
1201
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1202
+ nk_sme_start_streaming_();
1179
1203
  nk_dots_symmetric_e3m2_sme_streaming_(vectors, vectors_count, depth, stride_elements, result,
1180
1204
  result_stride_elements, row_start, row_count);
1181
- nk_euclideans_symmetric_e3m2_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1182
- result_stride_elements, row_start, row_count);
1205
+ nk_euclideans_symmetric_e3m2_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1206
+ result_stride_elements, row_start, row_count);
1207
+ nk_sme_stop_streaming_();
1183
1208
  }
1184
1209
 
1185
1210
  #pragma endregion E3M2 Floats
1186
1211
  #pragma region I8 Integers
1187
1212
 
1188
- __arm_locally_streaming static void nk_angulars_packed_i8_sme_finalize_streaming_( //
1189
- nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
1190
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1191
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1213
+ static void nk_angulars_packed_i8_sme_finalize_ssve_( //
1214
+ nk_i8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1215
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1192
1216
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1193
1217
  nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
1194
1218
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1209,22 +1233,22 @@ __arm_locally_streaming static void nk_angulars_packed_i8_sme_finalize_streaming
1209
1233
  }
1210
1234
  }
1211
1235
 
1212
- NK_PUBLIC void nk_angulars_packed_i8_sme( //
1213
- nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
1214
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1236
+ NK_PUBLIC void nk_angulars_packed_i8_sme( //
1237
+ nk_i8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1215
1238
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1216
1239
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
1217
1240
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1241
+ nk_sme_start_streaming_();
1218
1242
  nk_dots_packed_i8_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
1219
1243
  c_stride_elements);
1220
- nk_angulars_packed_i8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1221
- c_stride_elements);
1244
+ nk_angulars_packed_i8_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1245
+ c_stride_elements);
1246
+ nk_sme_stop_streaming_();
1222
1247
  }
1223
1248
 
1224
- __arm_locally_streaming static void nk_euclideans_packed_i8_sme_finalize_streaming_( //
1225
- nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
1226
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1227
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1249
+ static void nk_euclideans_packed_i8_sme_finalize_ssve_( //
1250
+ nk_i8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1251
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1228
1252
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1229
1253
  nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
1230
1254
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1245,21 +1269,22 @@ __arm_locally_streaming static void nk_euclideans_packed_i8_sme_finalize_streami
1245
1269
  }
1246
1270
  }
1247
1271
 
1248
- NK_PUBLIC void nk_euclideans_packed_i8_sme( //
1249
- nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
1250
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1272
+ NK_PUBLIC void nk_euclideans_packed_i8_sme( //
1273
+ nk_i8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1251
1274
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1252
1275
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
1253
1276
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1277
+ nk_sme_start_streaming_();
1254
1278
  nk_dots_packed_i8_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
1255
1279
  c_stride_elements);
1256
- nk_euclideans_packed_i8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1257
- c_stride_elements);
1280
+ nk_euclideans_packed_i8_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1281
+ c_stride_elements);
1282
+ nk_sme_stop_streaming_();
1258
1283
  }
1259
1284
 
1260
- __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_streaming_( //
1261
- nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1262
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1285
+ static void nk_angulars_symmetric_i8_sme_finalize_ssve_( //
1286
+ nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1287
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1263
1288
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1264
1289
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1265
1290
  nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i8_ssve_(vectors + row_index * stride_elements, depth);
@@ -1294,20 +1319,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_i8_sme_finalize_stream
1294
1319
  result[row_index * result_stride_elements + row_index] = 0;
1295
1320
  }
1296
1321
 
1297
- NK_PUBLIC void nk_angulars_symmetric_i8_sme( //
1298
- nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1299
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1322
+ NK_PUBLIC void nk_angulars_symmetric_i8_sme( //
1323
+ nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1324
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1300
1325
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
1301
1326
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1327
+ nk_sme_start_streaming_();
1302
1328
  nk_dots_symmetric_i8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
1303
1329
  result_stride_elements, row_start, row_count);
1304
- nk_angulars_symmetric_i8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1305
- result_stride_elements, row_start, row_count);
1330
+ nk_angulars_symmetric_i8_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1331
+ result_stride_elements, row_start, row_count);
1332
+ nk_sme_stop_streaming_();
1306
1333
  }
1307
1334
 
1308
- __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_streaming_( //
1309
- nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1310
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1335
+ static void nk_euclideans_symmetric_i8_sme_finalize_ssve_( //
1336
+ nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1337
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1311
1338
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1312
1339
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1313
1340
  nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i8_ssve_(vectors + row_index * stride_elements, depth);
@@ -1342,25 +1369,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i8_sme_finalize_stre
1342
1369
  result[row_index * result_stride_elements + row_index] = 0;
1343
1370
  }
1344
1371
 
1345
- NK_PUBLIC void nk_euclideans_symmetric_i8_sme( //
1346
- nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1347
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1372
+ NK_PUBLIC void nk_euclideans_symmetric_i8_sme( //
1373
+ nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1374
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1348
1375
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
1349
1376
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1377
+ nk_sme_start_streaming_();
1350
1378
  nk_dots_symmetric_i8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
1351
1379
  result_stride_elements, row_start, row_count);
1352
- nk_euclideans_symmetric_i8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1353
- result_stride_elements, row_start, row_count);
1380
+ nk_euclideans_symmetric_i8_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1381
+ result_stride_elements, row_start, row_count);
1382
+ nk_sme_stop_streaming_();
1354
1383
  }
1355
1384
 
1356
1385
  #pragma endregion I8 Integers
1357
1386
 
1358
1387
  #pragma region U8 Integers
1359
1388
 
1360
- __arm_locally_streaming static void nk_angulars_packed_u8_sme_finalize_streaming_( //
1361
- nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
1362
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1363
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1389
+ static void nk_angulars_packed_u8_sme_finalize_ssve_( //
1390
+ nk_u8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1391
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1364
1392
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1365
1393
  nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
1366
1394
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1381,22 +1409,22 @@ __arm_locally_streaming static void nk_angulars_packed_u8_sme_finalize_streaming
1381
1409
  }
1382
1410
  }
1383
1411
 
1384
- NK_PUBLIC void nk_angulars_packed_u8_sme( //
1385
- nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
1386
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1412
+ NK_PUBLIC void nk_angulars_packed_u8_sme( //
1413
+ nk_u8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1387
1414
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1388
1415
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
1389
1416
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1417
+ nk_sme_start_streaming_();
1390
1418
  nk_dots_packed_u8_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
1391
1419
  c_stride_elements);
1392
- nk_angulars_packed_u8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1393
- c_stride_elements);
1420
+ nk_angulars_packed_u8_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1421
+ c_stride_elements);
1422
+ nk_sme_stop_streaming_();
1394
1423
  }
1395
1424
 
1396
- __arm_locally_streaming static void nk_euclideans_packed_u8_sme_finalize_streaming_( //
1397
- nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
1398
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1399
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1425
+ static void nk_euclideans_packed_u8_sme_finalize_ssve_( //
1426
+ nk_u8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1427
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1400
1428
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1401
1429
  nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
1402
1430
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1417,21 +1445,22 @@ __arm_locally_streaming static void nk_euclideans_packed_u8_sme_finalize_streami
1417
1445
  }
1418
1446
  }
1419
1447
 
1420
- NK_PUBLIC void nk_euclideans_packed_u8_sme( //
1421
- nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
1422
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1448
+ NK_PUBLIC void nk_euclideans_packed_u8_sme( //
1449
+ nk_u8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1423
1450
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1424
1451
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
1425
1452
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1453
+ nk_sme_start_streaming_();
1426
1454
  nk_dots_packed_u8_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
1427
1455
  c_stride_elements);
1428
- nk_euclideans_packed_u8_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1429
- c_stride_elements);
1456
+ nk_euclideans_packed_u8_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1457
+ c_stride_elements);
1458
+ nk_sme_stop_streaming_();
1430
1459
  }
1431
1460
 
1432
- __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_streaming_( //
1433
- nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1434
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1461
+ static void nk_angulars_symmetric_u8_sme_finalize_ssve_( //
1462
+ nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1463
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1435
1464
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1436
1465
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1437
1466
  nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u8_ssve_(vectors + row_index * stride_elements, depth);
@@ -1466,20 +1495,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_u8_sme_finalize_stream
1466
1495
  result[row_index * result_stride_elements + row_index] = 0;
1467
1496
  }
1468
1497
 
1469
- NK_PUBLIC void nk_angulars_symmetric_u8_sme( //
1470
- nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1471
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1498
+ NK_PUBLIC void nk_angulars_symmetric_u8_sme( //
1499
+ nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1500
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1472
1501
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
1473
1502
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1503
+ nk_sme_start_streaming_();
1474
1504
  nk_dots_symmetric_u8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
1475
1505
  result_stride_elements, row_start, row_count);
1476
- nk_angulars_symmetric_u8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1477
- result_stride_elements, row_start, row_count);
1506
+ nk_angulars_symmetric_u8_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1507
+ result_stride_elements, row_start, row_count);
1508
+ nk_sme_stop_streaming_();
1478
1509
  }
1479
1510
 
1480
- __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_streaming_( //
1481
- nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1482
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1511
+ static void nk_euclideans_symmetric_u8_sme_finalize_ssve_( //
1512
+ nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1513
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1483
1514
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1484
1515
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1485
1516
  nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u8_ssve_(vectors + row_index * stride_elements, depth);
@@ -1514,25 +1545,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u8_sme_finalize_stre
1514
1545
  result[row_index * result_stride_elements + row_index] = 0;
1515
1546
  }
1516
1547
 
1517
- NK_PUBLIC void nk_euclideans_symmetric_u8_sme( //
1518
- nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1519
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1548
+ NK_PUBLIC void nk_euclideans_symmetric_u8_sme( //
1549
+ nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1550
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1520
1551
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
1521
1552
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1553
+ nk_sme_start_streaming_();
1522
1554
  nk_dots_symmetric_u8_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
1523
1555
  result_stride_elements, row_start, row_count);
1524
- nk_euclideans_symmetric_u8_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1525
- result_stride_elements, row_start, row_count);
1556
+ nk_euclideans_symmetric_u8_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1557
+ result_stride_elements, row_start, row_count);
1558
+ nk_sme_stop_streaming_();
1526
1559
  }
1527
1560
 
1528
1561
  #pragma endregion U8 Integers
1529
1562
 
1530
1563
  #pragma region I4 Integers
1531
1564
 
1532
- __arm_locally_streaming static void nk_angulars_packed_i4_sme_finalize_streaming_( //
1533
- nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1534
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1535
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1565
+ static void nk_angulars_packed_i4_sme_finalize_ssve_( //
1566
+ nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1567
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1536
1568
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1537
1569
  nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
1538
1570
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1553,22 +1585,22 @@ __arm_locally_streaming static void nk_angulars_packed_i4_sme_finalize_streaming
1553
1585
  }
1554
1586
  }
1555
1587
 
1556
- NK_PUBLIC void nk_angulars_packed_i4_sme( //
1557
- nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1558
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1588
+ NK_PUBLIC void nk_angulars_packed_i4_sme( //
1589
+ nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1559
1590
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1560
1591
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i4x2_t);
1561
1592
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1593
+ nk_sme_start_streaming_();
1562
1594
  nk_dots_packed_i4_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
1563
1595
  c_stride_elements);
1564
- nk_angulars_packed_i4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1565
- c_stride_elements);
1596
+ nk_angulars_packed_i4_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1597
+ c_stride_elements);
1598
+ nk_sme_stop_streaming_();
1566
1599
  }
1567
1600
 
1568
- __arm_locally_streaming static void nk_euclideans_packed_i4_sme_finalize_streaming_( //
1569
- nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1570
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1571
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1601
+ static void nk_euclideans_packed_i4_sme_finalize_ssve_( //
1602
+ nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1603
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1572
1604
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1573
1605
  nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
1574
1606
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1589,21 +1621,22 @@ __arm_locally_streaming static void nk_euclideans_packed_i4_sme_finalize_streami
1589
1621
  }
1590
1622
  }
1591
1623
 
1592
- NK_PUBLIC void nk_euclideans_packed_i4_sme( //
1593
- nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1594
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1624
+ NK_PUBLIC void nk_euclideans_packed_i4_sme( //
1625
+ nk_i4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1595
1626
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1596
1627
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i4x2_t);
1597
1628
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1629
+ nk_sme_start_streaming_();
1598
1630
  nk_dots_packed_i4_sme_streaming_(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_elements,
1599
1631
  c_stride_elements);
1600
- nk_euclideans_packed_i4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1601
- c_stride_elements);
1632
+ nk_euclideans_packed_i4_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1633
+ c_stride_elements);
1634
+ nk_sme_stop_streaming_();
1602
1635
  }
1603
1636
 
1604
- __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_streaming_( //
1605
- nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1606
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1637
+ static void nk_angulars_symmetric_i4_sme_finalize_ssve_( //
1638
+ nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1639
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1607
1640
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1608
1641
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1609
1642
  nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i4_ssve_(vectors + row_index * stride_elements, depth);
@@ -1638,20 +1671,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_i4_sme_finalize_stream
1638
1671
  result[row_index * result_stride_elements + row_index] = 0;
1639
1672
  }
1640
1673
 
1641
- NK_PUBLIC void nk_angulars_symmetric_i4_sme( //
1642
- nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1643
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1674
+ NK_PUBLIC void nk_angulars_symmetric_i4_sme( //
1675
+ nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1676
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1644
1677
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i4x2_t);
1645
1678
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1679
+ nk_sme_start_streaming_();
1646
1680
  nk_dots_symmetric_i4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
1647
1681
  result_stride_elements, row_start, row_count);
1648
- nk_angulars_symmetric_i4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1649
- result_stride_elements, row_start, row_count);
1682
+ nk_angulars_symmetric_i4_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1683
+ result_stride_elements, row_start, row_count);
1684
+ nk_sme_stop_streaming_();
1650
1685
  }
1651
1686
 
1652
- __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_streaming_( //
1653
- nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1654
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1687
+ static void nk_euclideans_symmetric_i4_sme_finalize_ssve_( //
1688
+ nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1689
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1655
1690
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1656
1691
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1657
1692
  nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_i4_ssve_(vectors + row_index * stride_elements, depth);
@@ -1686,25 +1721,26 @@ __arm_locally_streaming static void nk_euclideans_symmetric_i4_sme_finalize_stre
1686
1721
  result[row_index * result_stride_elements + row_index] = 0;
1687
1722
  }
1688
1723
 
1689
- NK_PUBLIC void nk_euclideans_symmetric_i4_sme( //
1690
- nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1691
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1724
+ NK_PUBLIC void nk_euclideans_symmetric_i4_sme( //
1725
+ nk_i4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1726
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1692
1727
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i4x2_t);
1693
1728
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1729
+ nk_sme_start_streaming_();
1694
1730
  nk_dots_symmetric_i4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_i32_t *)result,
1695
1731
  result_stride_elements, row_start, row_count);
1696
- nk_euclideans_symmetric_i4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1697
- result_stride_elements, row_start, row_count);
1732
+ nk_euclideans_symmetric_i4_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1733
+ result_stride_elements, row_start, row_count);
1734
+ nk_sme_stop_streaming_();
1698
1735
  }
1699
1736
 
1700
1737
  #pragma endregion Signed Integers
1701
1738
 
1702
1739
  #pragma region U4 Integers
1703
1740
 
1704
- __arm_locally_streaming static void nk_angulars_packed_u4_sme_finalize_streaming_( //
1705
- nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1706
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1707
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1741
+ static void nk_angulars_packed_u4_sme_finalize_ssve_( //
1742
+ nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1743
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1708
1744
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1709
1745
  nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
1710
1746
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1725,22 +1761,22 @@ __arm_locally_streaming static void nk_angulars_packed_u4_sme_finalize_streaming
1725
1761
  }
1726
1762
  }
1727
1763
 
1728
- NK_PUBLIC void nk_angulars_packed_u4_sme( //
1729
- nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1730
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1764
+ NK_PUBLIC void nk_angulars_packed_u4_sme( //
1765
+ nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1731
1766
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1732
1767
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u4x2_t);
1733
1768
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1769
+ nk_sme_start_streaming_();
1734
1770
  nk_dots_packed_u4_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
1735
1771
  c_stride_elements);
1736
- nk_angulars_packed_u4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1737
- c_stride_elements);
1772
+ nk_angulars_packed_u4_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1773
+ c_stride_elements);
1774
+ nk_sme_stop_streaming_();
1738
1775
  }
1739
1776
 
1740
- __arm_locally_streaming static void nk_euclideans_packed_u4_sme_finalize_streaming_( //
1741
- nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1742
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1743
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1777
+ static void nk_euclideans_packed_u4_sme_finalize_ssve_( //
1778
+ nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1779
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1744
1780
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
1745
1781
  nk_u32_t const *b_norms = (nk_u32_t const *)((char const *)b_packed + header->norms_offset);
1746
1782
  for (nk_size_t row_index = 0; row_index < rows; row_index++) {
@@ -1761,21 +1797,22 @@ __arm_locally_streaming static void nk_euclideans_packed_u4_sme_finalize_streami
1761
1797
  }
1762
1798
  }
1763
1799
 
1764
- NK_PUBLIC void nk_euclideans_packed_u4_sme( //
1765
- nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, //
1766
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
1800
+ NK_PUBLIC void nk_euclideans_packed_u4_sme( //
1801
+ nk_u4x2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1767
1802
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1768
1803
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u4x2_t);
1769
1804
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
1805
+ nk_sme_start_streaming_();
1770
1806
  nk_dots_packed_u4_sme_streaming_(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_elements,
1771
1807
  c_stride_elements);
1772
- nk_euclideans_packed_u4_sme_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1773
- c_stride_elements);
1808
+ nk_euclideans_packed_u4_sme_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
1809
+ c_stride_elements);
1810
+ nk_sme_stop_streaming_();
1774
1811
  }
1775
1812
 
1776
- __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_streaming_( //
1777
- nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1778
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1813
+ static void nk_angulars_symmetric_u4_sme_finalize_ssve_( //
1814
+ nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1815
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1779
1816
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1780
1817
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1781
1818
  nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u4_ssve_(vectors + row_index * stride_elements, depth);
@@ -1810,20 +1847,22 @@ __arm_locally_streaming static void nk_angulars_symmetric_u4_sme_finalize_stream
1810
1847
  result[row_index * result_stride_elements + row_index] = 0;
1811
1848
  }
1812
1849
 
1813
- NK_PUBLIC void nk_angulars_symmetric_u4_sme( //
1814
- nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1815
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1850
+ NK_PUBLIC void nk_angulars_symmetric_u4_sme( //
1851
+ nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1852
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1816
1853
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u4x2_t);
1817
1854
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1855
+ nk_sme_start_streaming_();
1818
1856
  nk_dots_symmetric_u4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
1819
1857
  result_stride_elements, row_start, row_count);
1820
- nk_angulars_symmetric_u4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1821
- result_stride_elements, row_start, row_count);
1858
+ nk_angulars_symmetric_u4_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1859
+ result_stride_elements, row_start, row_count);
1860
+ nk_sme_stop_streaming_();
1822
1861
  }
1823
1862
 
1824
- __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_streaming_( //
1825
- nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
1826
- nk_f32_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
1863
+ static void nk_euclideans_symmetric_u4_sme_finalize_ssve_( //
1864
+ nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f32_t *result,
1865
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
1827
1866
  // Phase 1: cache row norms on diagonal (store as u32 in f32 slot)
1828
1867
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
1829
1868
  nk_u32_t row_sumsq_u32 = nk_dots_reduce_sumsq_u4_ssve_(vectors + row_index * stride_elements, depth);
@@ -1858,15 +1897,17 @@ __arm_locally_streaming static void nk_euclideans_symmetric_u4_sme_finalize_stre
1858
1897
  result[row_index * result_stride_elements + row_index] = 0;
1859
1898
  }
1860
1899
 
1861
- NK_PUBLIC void nk_euclideans_symmetric_u4_sme( //
1862
- nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
1863
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1900
+ NK_PUBLIC void nk_euclideans_symmetric_u4_sme( //
1901
+ nk_u4x2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f32_t *result,
1902
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1864
1903
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u4x2_t);
1865
1904
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
1905
+ nk_sme_start_streaming_();
1866
1906
  nk_dots_symmetric_u4_sme_streaming_(vectors, vectors_count, depth, stride_elements, (nk_u32_t *)result,
1867
1907
  result_stride_elements, row_start, row_count);
1868
- nk_euclideans_symmetric_u4_sme_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
1869
- result_stride_elements, row_start, row_count);
1908
+ nk_euclideans_symmetric_u4_sme_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
1909
+ result_stride_elements, row_start, row_count);
1910
+ nk_sme_stop_streaming_();
1870
1911
  }
1871
1912
 
1872
1913
  #pragma endregion Unsigned Integers