numkong 7.4.5 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. package/README.md +1 -0
  2. package/binding.gyp +99 -5
  3. package/c/dispatch_e5m2.c +23 -3
  4. package/c/dispatch_f16.c +23 -0
  5. package/c/numkong.c +0 -13
  6. package/include/numkong/attention/sme.h +34 -31
  7. package/include/numkong/capabilities.h +2 -15
  8. package/include/numkong/cast/README.md +3 -0
  9. package/include/numkong/cast/haswell.h +28 -64
  10. package/include/numkong/cast/neon.h +15 -0
  11. package/include/numkong/cast/serial.h +17 -0
  12. package/include/numkong/cast/skylake.h +67 -52
  13. package/include/numkong/cast.h +1 -0
  14. package/include/numkong/curved/smef64.h +82 -62
  15. package/include/numkong/dot/README.md +1 -0
  16. package/include/numkong/dot/haswell.h +92 -13
  17. package/include/numkong/dot/rvvbf16.h +1 -1
  18. package/include/numkong/dot/rvvhalf.h +1 -1
  19. package/include/numkong/dot/serial.h +15 -0
  20. package/include/numkong/dot/skylake.h +61 -14
  21. package/include/numkong/dot/sve.h +6 -5
  22. package/include/numkong/dot/svebfdot.h +2 -1
  23. package/include/numkong/dot/svehalf.h +6 -5
  24. package/include/numkong/dot/svesdot.h +3 -2
  25. package/include/numkong/dots/README.md +2 -0
  26. package/include/numkong/dots/graniteamx.h +1167 -0
  27. package/include/numkong/dots/haswell.h +28 -28
  28. package/include/numkong/dots/sapphireamx.h +1 -1
  29. package/include/numkong/dots/serial.h +33 -11
  30. package/include/numkong/dots/skylake.h +28 -23
  31. package/include/numkong/dots/sme.h +172 -140
  32. package/include/numkong/dots/smebi32.h +14 -11
  33. package/include/numkong/dots/smef64.h +31 -26
  34. package/include/numkong/dots.h +41 -3
  35. package/include/numkong/each/serial.h +39 -0
  36. package/include/numkong/geospatial/haswell.h +1 -1
  37. package/include/numkong/geospatial/neon.h +1 -1
  38. package/include/numkong/geospatial/serial.h +15 -4
  39. package/include/numkong/geospatial/skylake.h +1 -1
  40. package/include/numkong/maxsim/serial.h +15 -0
  41. package/include/numkong/maxsim/sme.h +34 -33
  42. package/include/numkong/mesh/README.md +50 -44
  43. package/include/numkong/mesh/genoa.h +462 -0
  44. package/include/numkong/mesh/haswell.h +806 -933
  45. package/include/numkong/mesh/neon.h +871 -943
  46. package/include/numkong/mesh/neonbfdot.h +382 -522
  47. package/include/numkong/mesh/neonfhm.h +676 -0
  48. package/include/numkong/mesh/rvv.h +404 -319
  49. package/include/numkong/mesh/serial.h +225 -161
  50. package/include/numkong/mesh/skylake.h +1029 -1585
  51. package/include/numkong/mesh/v128relaxed.h +403 -377
  52. package/include/numkong/mesh.h +38 -0
  53. package/include/numkong/reduce/neon.h +29 -0
  54. package/include/numkong/reduce/neonbfdot.h +2 -2
  55. package/include/numkong/reduce/neonfhm.h +4 -4
  56. package/include/numkong/reduce/serial.h +15 -1
  57. package/include/numkong/reduce/sve.h +52 -0
  58. package/include/numkong/reduce.h +4 -0
  59. package/include/numkong/set/sve.h +6 -5
  60. package/include/numkong/sets/smebi32.h +35 -30
  61. package/include/numkong/sparse/serial.h +17 -2
  62. package/include/numkong/sparse/sve2.h +3 -2
  63. package/include/numkong/spatial/genoa.h +0 -68
  64. package/include/numkong/spatial/haswell.h +98 -56
  65. package/include/numkong/spatial/serial.h +15 -0
  66. package/include/numkong/spatial/skylake.h +114 -54
  67. package/include/numkong/spatial/sve.h +7 -6
  68. package/include/numkong/spatial/svebfdot.h +7 -4
  69. package/include/numkong/spatial/svehalf.h +5 -4
  70. package/include/numkong/spatial/svesdot.h +9 -8
  71. package/include/numkong/spatial.h +0 -12
  72. package/include/numkong/spatials/graniteamx.h +301 -0
  73. package/include/numkong/spatials/serial.h +39 -0
  74. package/include/numkong/spatials/skylake.h +2 -2
  75. package/include/numkong/spatials/sme.h +391 -350
  76. package/include/numkong/spatials/smef64.h +79 -70
  77. package/include/numkong/spatials.h +54 -4
  78. package/include/numkong/tensor.hpp +107 -23
  79. package/include/numkong/types.h +59 -0
  80. package/javascript/dist/cjs/numkong.js +13 -0
  81. package/javascript/dist/esm/numkong.js +13 -0
  82. package/javascript/numkong.c +59 -14
  83. package/javascript/numkong.ts +13 -0
  84. package/package.json +7 -7
  85. package/probes/probe.js +2 -2
  86. package/wasm/numkong.wasm +0 -0
@@ -13,6 +13,7 @@
13
13
  #if NK_TARGET_SME
14
14
 
15
15
  #include "numkong/dots/serial.h"
16
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
16
17
  #include "numkong/dots/smef64.h"
17
18
 
18
19
  #if defined(__cplusplus)
@@ -44,7 +45,7 @@ NK_PUBLIC nk_f64_t nk_dots_reduce_sumsq_f32_ssve_(nk_f32_t const *data, nk_size_
44
45
  svfloat64_t values_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, values_f32x);
45
46
  accumulator_odd_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_odd_f64x, values_odd_f64x, values_odd_f64x);
46
47
  }
47
- return svaddv_f64(svptrue_b64(), accumulator_even_f64x) + svaddv_f64(svptrue_b64(), accumulator_odd_f64x);
48
+ return nk_svaddv_f64_(svptrue_b64(), accumulator_even_f64x) + nk_svaddv_f64_(svptrue_b64(), accumulator_odd_f64x);
48
49
  }
49
50
 
50
51
  NK_PUBLIC nk_f64_t nk_dots_reduce_sumsq_f64_ssve_(nk_f64_t const *data, nk_size_t count) NK_STREAMING_ {
@@ -55,7 +56,7 @@ NK_PUBLIC nk_f64_t nk_dots_reduce_sumsq_f64_ssve_(nk_f64_t const *data, nk_size_
55
56
  svfloat64_t values_f64x = svld1_f64(predicate_b64x, data + i);
56
57
  accumulator_f64x = svmla_f64_m(predicate_b64x, accumulator_f64x, values_f64x, values_f64x);
57
58
  }
58
- return svaddv_f64(svptrue_b64(), accumulator_f64x);
59
+ return nk_svaddv_f64_(svptrue_b64(), accumulator_f64x);
59
60
  }
60
61
 
61
62
  NK_PUBLIC svfloat64_t nk_angulars_from_dot_f64x_ssvef64_(svbool_t predicate_b64x, svfloat64_t dots_f64x,
@@ -85,10 +86,9 @@ NK_PUBLIC svfloat64_t nk_euclideans_from_dot_f64x_ssvef64_(svbool_t predicate_b6
85
86
 
86
87
  #pragma region F32 Packed Angular
87
88
 
88
- __arm_locally_streaming static void nk_angulars_packed_f32_smef64_finalize_streaming_( //
89
- nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
90
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
91
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
89
+ static void nk_angulars_packed_f32_smef64_finalize_ssve_( //
90
+ nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
91
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
92
92
 
93
93
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
94
94
  nk_f64_t const *b_norms = (nk_f64_t const *)((char const *)b_packed + header->norms_offset);
@@ -110,26 +110,26 @@ __arm_locally_streaming static void nk_angulars_packed_f32_smef64_finalize_strea
110
110
  }
111
111
  }
112
112
 
113
- NK_PUBLIC void nk_angulars_packed_f32_smef64( //
114
- nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
115
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
113
+ NK_PUBLIC void nk_angulars_packed_f32_smef64( //
114
+ nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
116
115
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
117
116
 
118
117
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
119
118
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
120
119
 
120
+ nk_sme_start_streaming_();
121
121
  nk_dots_packed_f32_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
122
- nk_angulars_packed_f32_smef64_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
123
- c_stride_elements);
122
+ nk_angulars_packed_f32_smef64_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
123
+ c_stride_elements);
124
+ nk_sme_stop_streaming_();
124
125
  }
125
126
 
126
127
  #pragma endregion F32 Packed Angular
127
128
  #pragma region F32 Packed Euclidean
128
129
 
129
- __arm_locally_streaming static void nk_euclideans_packed_f32_smef64_finalize_streaming_( //
130
- nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
131
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
132
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
130
+ static void nk_euclideans_packed_f32_smef64_finalize_ssve_( //
131
+ nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
132
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
133
133
 
134
134
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
135
135
  nk_f64_t const *b_norms = (nk_f64_t const *)((char const *)b_packed + header->norms_offset);
@@ -151,25 +151,26 @@ __arm_locally_streaming static void nk_euclideans_packed_f32_smef64_finalize_str
151
151
  }
152
152
  }
153
153
 
154
- NK_PUBLIC void nk_euclideans_packed_f32_smef64( //
155
- nk_f32_t const *a, void const *b_packed, nk_f64_t *c, //
156
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
154
+ NK_PUBLIC void nk_euclideans_packed_f32_smef64( //
155
+ nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
157
156
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
158
157
 
159
158
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
160
159
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
161
160
 
161
+ nk_sme_start_streaming_();
162
162
  nk_dots_packed_f32_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
163
- nk_euclideans_packed_f32_smef64_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
164
- c_stride_elements);
163
+ nk_euclideans_packed_f32_smef64_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
164
+ c_stride_elements);
165
+ nk_sme_stop_streaming_();
165
166
  }
166
167
 
167
168
  #pragma endregion F32 Packed Euclidean
168
169
  #pragma region F32 Symmetric Angular
169
170
 
170
- __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_streaming_( //
171
- nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
172
- nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
171
+ static void nk_angulars_symmetric_f32_smef64_finalize_ssve_( //
172
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
173
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
173
174
  // Phase 1: cache row norms on diagonal
174
175
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
175
176
  nk_f32_t const *row_vector = vectors + row_index * stride_elements;
@@ -204,25 +205,27 @@ __arm_locally_streaming static void nk_angulars_symmetric_f32_smef64_finalize_st
204
205
  result[row_index * result_stride_elements + row_index] = 0;
205
206
  }
206
207
 
207
- NK_PUBLIC void nk_angulars_symmetric_f32_smef64( //
208
- nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
209
- nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
208
+ NK_PUBLIC void nk_angulars_symmetric_f32_smef64( //
209
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
210
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
210
211
 
211
212
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
212
213
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
213
214
 
215
+ nk_sme_start_streaming_();
214
216
  nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
215
217
  result_stride_elements, row_start, row_count);
216
- nk_angulars_symmetric_f32_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
217
- result_stride_elements, row_start, row_count);
218
+ nk_angulars_symmetric_f32_smef64_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
219
+ result_stride_elements, row_start, row_count);
220
+ nk_sme_stop_streaming_();
218
221
  }
219
222
 
220
223
  #pragma endregion F32 Symmetric Angular
221
224
  #pragma region F32 Symmetric Euclidean
222
225
 
223
- __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_streaming_( //
224
- nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
225
- nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
226
+ static void nk_euclideans_symmetric_f32_smef64_finalize_ssve_( //
227
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
228
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
226
229
  // Phase 1: cache row norms on diagonal
227
230
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
228
231
  nk_f32_t const *row_vector = vectors + row_index * stride_elements;
@@ -257,26 +260,27 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f32_smef64_finalize_
257
260
  result[row_index * result_stride_elements + row_index] = 0;
258
261
  }
259
262
 
260
- NK_PUBLIC void nk_euclideans_symmetric_f32_smef64( //
261
- nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
262
- nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
263
+ NK_PUBLIC void nk_euclideans_symmetric_f32_smef64( //
264
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
265
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
263
266
 
264
267
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
265
268
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
266
269
 
270
+ nk_sme_start_streaming_();
267
271
  nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
268
272
  result_stride_elements, row_start, row_count);
269
- nk_euclideans_symmetric_f32_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
270
- result_stride_elements, row_start, row_count);
273
+ nk_euclideans_symmetric_f32_smef64_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
274
+ result_stride_elements, row_start, row_count);
275
+ nk_sme_stop_streaming_();
271
276
  }
272
277
 
273
278
  #pragma endregion F32 Symmetric Euclidean
274
279
  #pragma region F64 Packed Angular
275
280
 
276
- __arm_locally_streaming static void nk_angulars_packed_f64_smef64_finalize_streaming_( //
277
- nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
278
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
279
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
281
+ static void nk_angulars_packed_f64_smef64_finalize_ssve_( //
282
+ nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
283
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
280
284
 
281
285
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
282
286
  nk_f64_t const *b_norms = (nk_f64_t const *)((char const *)b_packed + header->norms_offset);
@@ -298,26 +302,26 @@ __arm_locally_streaming static void nk_angulars_packed_f64_smef64_finalize_strea
298
302
  }
299
303
  }
300
304
 
301
- NK_PUBLIC void nk_angulars_packed_f64_smef64( //
302
- nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
303
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
305
+ NK_PUBLIC void nk_angulars_packed_f64_smef64( //
306
+ nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
304
307
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
305
308
 
306
309
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f64_t);
307
310
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
308
311
 
312
+ nk_sme_start_streaming_();
309
313
  nk_dots_packed_f64_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
310
- nk_angulars_packed_f64_smef64_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
311
- c_stride_elements);
314
+ nk_angulars_packed_f64_smef64_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
315
+ c_stride_elements);
316
+ nk_sme_stop_streaming_();
312
317
  }
313
318
 
314
319
  #pragma endregion F64 Packed Angular
315
320
  #pragma region F64 Packed Euclidean
316
321
 
317
- __arm_locally_streaming static void nk_euclideans_packed_f64_smef64_finalize_streaming_( //
318
- nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
319
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
320
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
322
+ static void nk_euclideans_packed_f64_smef64_finalize_ssve_( //
323
+ nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
324
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
321
325
 
322
326
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
323
327
  nk_f64_t const *b_norms = (nk_f64_t const *)((char const *)b_packed + header->norms_offset);
@@ -339,25 +343,26 @@ __arm_locally_streaming static void nk_euclideans_packed_f64_smef64_finalize_str
339
343
  }
340
344
  }
341
345
 
342
- NK_PUBLIC void nk_euclideans_packed_f64_smef64( //
343
- nk_f64_t const *a, void const *b_packed, nk_f64_t *c, //
344
- nk_size_t rows, nk_size_t columns, nk_size_t depth, //
346
+ NK_PUBLIC void nk_euclideans_packed_f64_smef64( //
347
+ nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
345
348
  nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
346
349
 
347
350
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f64_t);
348
351
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
349
352
 
353
+ nk_sme_start_streaming_();
350
354
  nk_dots_packed_f64_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
351
- nk_euclideans_packed_f64_smef64_finalize_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements,
352
- c_stride_elements);
355
+ nk_euclideans_packed_f64_smef64_finalize_ssve_(a, b_packed, c, rows, columns, depth, a_stride_elements,
356
+ c_stride_elements);
357
+ nk_sme_stop_streaming_();
353
358
  }
354
359
 
355
360
  #pragma endregion F64 Packed Euclidean
356
361
  #pragma region F64 Symmetric Angular
357
362
 
358
- __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_streaming_( //
359
- nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
360
- nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
363
+ static void nk_angulars_symmetric_f64_smef64_finalize_ssve_( //
364
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
365
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
361
366
  // Phase 1: cache row norms on diagonal
362
367
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
363
368
  nk_f64_t const *row_vector = vectors + row_index * stride_elements;
@@ -392,25 +397,27 @@ __arm_locally_streaming static void nk_angulars_symmetric_f64_smef64_finalize_st
392
397
  result[row_index * result_stride_elements + row_index] = 0;
393
398
  }
394
399
 
395
- NK_PUBLIC void nk_angulars_symmetric_f64_smef64( //
396
- nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
397
- nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
400
+ NK_PUBLIC void nk_angulars_symmetric_f64_smef64( //
401
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
402
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
398
403
 
399
404
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
400
405
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
401
406
 
407
+ nk_sme_start_streaming_();
402
408
  nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
403
409
  result_stride_elements, row_start, row_count);
404
- nk_angulars_symmetric_f64_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
405
- result_stride_elements, row_start, row_count);
410
+ nk_angulars_symmetric_f64_smef64_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
411
+ result_stride_elements, row_start, row_count);
412
+ nk_sme_stop_streaming_();
406
413
  }
407
414
 
408
415
  #pragma endregion F64 Symmetric Angular
409
416
  #pragma region F64 Symmetric Euclidean
410
417
 
411
- __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_streaming_( //
412
- nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, //
413
- nk_f64_t *result, nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
418
+ static void nk_euclideans_symmetric_f64_smef64_finalize_ssve_( //
419
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
420
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
414
421
  // Phase 1: cache row norms on diagonal
415
422
  for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) {
416
423
  nk_f64_t const *row_vector = vectors + row_index * stride_elements;
@@ -445,17 +452,19 @@ __arm_locally_streaming static void nk_euclideans_symmetric_f64_smef64_finalize_
445
452
  result[row_index * result_stride_elements + row_index] = 0;
446
453
  }
447
454
 
448
- NK_PUBLIC void nk_euclideans_symmetric_f64_smef64( //
449
- nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
450
- nk_f64_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
455
+ NK_PUBLIC void nk_euclideans_symmetric_f64_smef64( //
456
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
457
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
451
458
 
452
459
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
453
460
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
454
461
 
462
+ nk_sme_start_streaming_();
455
463
  nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
456
464
  result_stride_elements, row_start, row_count);
457
- nk_euclideans_symmetric_f64_smef64_finalize_streaming_(vectors, vectors_count, depth, stride_elements, result,
458
- result_stride_elements, row_start, row_count);
465
+ nk_euclideans_symmetric_f64_smef64_finalize_ssve_(vectors, vectors_count, depth, stride_elements, result,
466
+ result_stride_elements, row_start, row_count);
467
+ nk_sme_stop_streaming_();
459
468
  }
460
469
 
461
470
  #pragma endregion F64 Symmetric Euclidean
@@ -739,6 +739,45 @@ NK_PUBLIC void nk_euclideans_symmetric_u8_sapphireamx(nk_u8_t const *vectors, nk
739
739
  nk_size_t row_start, nk_size_t row_count);
740
740
  #endif // NK_TARGET_SAPPHIREAMX
741
741
 
742
+ /* Granite Rapids backends using Intel AMX-FP16.
743
+ * Native FP16 spatial kernels.
744
+ */
745
+ #if NK_TARGET_GRANITEAMX
746
+ /** @copydoc nk_angulars_packed_f16 */
747
+ NK_PUBLIC void nk_angulars_packed_f16_graniteamx(nk_f16_t const *a, void const *b_packed, nk_f32_t *result,
748
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
749
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
750
+ /** @copydoc nk_angulars_symmetric_f16 */
751
+ NK_PUBLIC void nk_angulars_symmetric_f16_graniteamx(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
752
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
753
+ nk_size_t row_start, nk_size_t row_count);
754
+ /** @copydoc nk_euclideans_packed_f16 */
755
+ NK_PUBLIC void nk_euclideans_packed_f16_graniteamx(nk_f16_t const *a, void const *b_packed, nk_f32_t *result,
756
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
757
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
758
+ /** @copydoc nk_euclideans_symmetric_f16 */
759
+ NK_PUBLIC void nk_euclideans_symmetric_f16_graniteamx(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
760
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
761
+ nk_size_t row_start, nk_size_t row_count);
762
+ /** @copydoc nk_angulars_packed_f16 */
763
+ NK_PUBLIC void nk_angulars_packed_e5m2_graniteamx(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
764
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
765
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
766
+ /** @copydoc nk_angulars_symmetric_f16 */
767
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_graniteamx(nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
768
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
769
+ nk_size_t row_start, nk_size_t row_count);
770
+ /** @copydoc nk_euclideans_packed_f16 */
771
+ NK_PUBLIC void nk_euclideans_packed_e5m2_graniteamx(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
772
+ nk_size_t rows, nk_size_t cols, nk_size_t depth,
773
+ nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
774
+ /** @copydoc nk_euclideans_symmetric_f16 */
775
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_graniteamx(nk_e5m2_t const *vectors, nk_size_t vectors_count,
776
+ nk_size_t depth, nk_size_t stride, nk_f32_t *result,
777
+ nk_size_t result_stride, nk_size_t row_start,
778
+ nk_size_t row_count);
779
+ #endif // NK_TARGET_GRANITEAMX
780
+
742
781
  /* ARM SME backends using Scalable Matrix Extension.
743
782
  * SME provides ZA tile registers for outer product operations.
744
783
  * F16/BF16/I8/U8/E4M3 use ZA32 tiles, F32/F64 use ZA64 tiles (FEAT_SME_F64F64).
@@ -2078,6 +2117,7 @@ NK_PUBLIC void nk_euclideans_symmetric_u8_rvv(nk_u8_t const *vectors, nk_size_t
2078
2117
  #include "numkong/spatials/alder.h"
2079
2118
  #include "numkong/spatials/sierra.h"
2080
2119
  #include "numkong/spatials/sapphireamx.h"
2120
+ #include "numkong/spatials/graniteamx.h"
2081
2121
  #include "numkong/spatials/rvv.h"
2082
2122
  #include "numkong/spatials/v128relaxed.h"
2083
2123
  #include "numkong/spatials/sme.h"
@@ -2290,7 +2330,9 @@ NK_PUBLIC void nk_euclideans_symmetric_f32(nk_f32_t const *vectors, nk_size_t ve
2290
2330
  NK_PUBLIC void nk_angulars_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2291
2331
  nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2292
2332
  nk_size_t r_stride_in_bytes) {
2293
- #if NK_TARGET_SME
2333
+ #if NK_TARGET_GRANITEAMX
2334
+ nk_angulars_packed_f16_graniteamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2335
+ #elif NK_TARGET_SME
2294
2336
  nk_angulars_packed_f16_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2295
2337
  #elif NK_TARGET_NEONFHM
2296
2338
  nk_angulars_packed_f16_neonfhm(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
@@ -2311,7 +2353,10 @@ NK_PUBLIC void nk_angulars_packed_f16(nk_f16_t const *a, void const *b_packed, n
2311
2353
  NK_PUBLIC void nk_angulars_symmetric_f16(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
2312
2354
  nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2313
2355
  nk_size_t row_start, nk_size_t row_count) {
2314
- #if NK_TARGET_SME
2356
+ #if NK_TARGET_GRANITEAMX
2357
+ nk_angulars_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride, result, result_stride, row_start,
2358
+ row_count);
2359
+ #elif NK_TARGET_SME
2315
2360
  nk_angulars_symmetric_f16_sme(vectors, vectors_count, depth, stride, result, result_stride, row_start, row_count);
2316
2361
  #elif NK_TARGET_NEONFHM
2317
2362
  nk_angulars_symmetric_f16_neonfhm(vectors, vectors_count, depth, stride, result, result_stride, row_start,
@@ -2337,7 +2382,9 @@ NK_PUBLIC void nk_angulars_symmetric_f16(nk_f16_t const *vectors, nk_size_t vect
2337
2382
  NK_PUBLIC void nk_euclideans_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *result, nk_size_t rows,
2338
2383
  nk_size_t cols, nk_size_t depth, nk_size_t a_stride_in_bytes,
2339
2384
  nk_size_t r_stride_in_bytes) {
2340
- #if NK_TARGET_SME
2385
+ #if NK_TARGET_GRANITEAMX
2386
+ nk_euclideans_packed_f16_graniteamx(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2387
+ #elif NK_TARGET_SME
2341
2388
  nk_euclideans_packed_f16_sme(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
2342
2389
  #elif NK_TARGET_NEONFHM
2343
2390
  nk_euclideans_packed_f16_neonfhm(a, b_packed, result, rows, cols, depth, a_stride_in_bytes, r_stride_in_bytes);
@@ -2358,7 +2405,10 @@ NK_PUBLIC void nk_euclideans_packed_f16(nk_f16_t const *a, void const *b_packed,
2358
2405
  NK_PUBLIC void nk_euclideans_symmetric_f16(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
2359
2406
  nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
2360
2407
  nk_size_t row_start, nk_size_t row_count) {
2361
- #if NK_TARGET_SME
2408
+ #if NK_TARGET_GRANITEAMX
2409
+ nk_euclideans_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride, result, result_stride, row_start,
2410
+ row_count);
2411
+ #elif NK_TARGET_SME
2362
2412
  nk_euclideans_symmetric_f16_sme(vectors, vectors_count, depth, stride, result, result_stride, row_start, row_count);
2363
2413
  #elif NK_TARGET_NEONFHM
2364
2414
  nk_euclideans_symmetric_f16_neonfhm(vectors, vectors_count, depth, stride, result, result_stride, row_start,