numkong 7.4.4 → 7.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. package/README.md +1 -0
  2. package/binding.gyp +81 -5
  3. package/c/dispatch_f16.c +23 -0
  4. package/c/numkong.c +0 -13
  5. package/include/numkong/attention/sme.h +34 -31
  6. package/include/numkong/capabilities.h +2 -15
  7. package/include/numkong/cast/neon.h +15 -0
  8. package/include/numkong/curved/smef64.h +82 -62
  9. package/include/numkong/dot/rvvbf16.h +1 -1
  10. package/include/numkong/dot/rvvhalf.h +1 -1
  11. package/include/numkong/dot/sve.h +6 -5
  12. package/include/numkong/dot/svebfdot.h +2 -1
  13. package/include/numkong/dot/svehalf.h +6 -5
  14. package/include/numkong/dot/svesdot.h +3 -2
  15. package/include/numkong/dots/graniteamx.h +733 -0
  16. package/include/numkong/dots/serial.h +11 -4
  17. package/include/numkong/dots/sme.h +172 -140
  18. package/include/numkong/dots/smebi32.h +14 -11
  19. package/include/numkong/dots/smef64.h +31 -26
  20. package/include/numkong/dots.h +29 -3
  21. package/include/numkong/each/serial.h +22 -0
  22. package/include/numkong/geospatial/haswell.h +1 -1
  23. package/include/numkong/geospatial/neon.h +1 -1
  24. package/include/numkong/geospatial/serial.h +1 -1
  25. package/include/numkong/geospatial/skylake.h +1 -1
  26. package/include/numkong/maxsim/sme.h +94 -55
  27. package/include/numkong/mesh/README.md +13 -27
  28. package/include/numkong/mesh/haswell.h +25 -122
  29. package/include/numkong/mesh/neon.h +21 -110
  30. package/include/numkong/mesh/neonbfdot.h +4 -43
  31. package/include/numkong/mesh/rvv.h +7 -82
  32. package/include/numkong/mesh/serial.h +48 -53
  33. package/include/numkong/mesh/skylake.h +7 -123
  34. package/include/numkong/mesh/v128relaxed.h +9 -93
  35. package/include/numkong/mesh.h +2 -2
  36. package/include/numkong/mesh.hpp +35 -96
  37. package/include/numkong/reduce/neon.h +29 -0
  38. package/include/numkong/reduce/neonbfdot.h +2 -2
  39. package/include/numkong/reduce/neonfhm.h +4 -4
  40. package/include/numkong/reduce/sve.h +52 -0
  41. package/include/numkong/reduce.h +4 -0
  42. package/include/numkong/set/sve.h +6 -5
  43. package/include/numkong/sets/smebi32.h +35 -30
  44. package/include/numkong/sparse/sve2.h +3 -2
  45. package/include/numkong/spatial/sve.h +7 -6
  46. package/include/numkong/spatial/svebfdot.h +7 -4
  47. package/include/numkong/spatial/svehalf.h +5 -4
  48. package/include/numkong/spatial/svesdot.h +9 -8
  49. package/include/numkong/spatials/graniteamx.h +173 -0
  50. package/include/numkong/spatials/serial.h +22 -0
  51. package/include/numkong/spatials/sme.h +391 -350
  52. package/include/numkong/spatials/smef64.h +79 -70
  53. package/include/numkong/spatials.h +37 -4
  54. package/include/numkong/types.h +59 -0
  55. package/javascript/dist/cjs/numkong.js +13 -0
  56. package/javascript/dist/esm/numkong.js +13 -0
  57. package/javascript/numkong.c +56 -12
  58. package/javascript/numkong.ts +13 -0
  59. package/package.json +7 -7
  60. package/probes/probe.js +2 -2
  61. package/wasm/numkong.wasm +0 -0
@@ -41,8 +41,9 @@
41
41
  #include "numkong/types.h"
42
42
  #include "numkong/set/serial.h"
43
43
  #include "numkong/sets/serial.h"
44
- #include "numkong/dots/sme.h" // `nk_sme_zero_za32_*` constants
45
- #include "numkong/reduce.h" // `nk_reduce_moments_u1`
44
+ #include "numkong/reduce/sve.h" // `nk_svaddv_u32_`
45
+ #include "numkong/reduce/neon.h" // `nk_reduce_moments_u1_neon`
46
+ #include "numkong/dots/sme.h" // `nk_sme_zero_za32_*`
46
47
 
47
48
  #if defined(__cplusplus)
48
49
  extern "C" {
@@ -100,7 +101,7 @@ NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data, nk_
100
101
  svbool_t predicate_b8x = svwhilelt_b8_u64(offset, n_bytes);
101
102
  acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_b8x, svld1_u8(predicate_b8x, data + offset)), ones_u8x);
102
103
  }
103
- return (nk_u32_t)svaddv_u32(svptrue_b32(), acc_u32x);
104
+ return (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), acc_u32x);
104
105
  }
105
106
 
106
107
  #pragma region Hamming Distance
@@ -187,11 +188,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
187
188
  // Compute per-row population counts
188
189
  for (nk_size_t row = 0; row < row_count; row++) {
189
190
  nk_u1x8_t const *src_row = (nk_u1x8_t const *)((char const *)b + row * b_stride_in_bytes);
190
- {
191
- nk_u64_t nk_local_sum_, nk_local_sumsq_;
192
- nk_reduce_moments_u1(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
193
- norms_ptr[row] = (nk_u32_t)nk_local_sum_;
194
- }
191
+ nk_u64_t nk_local_sum_, nk_local_sumsq_;
192
+ nk_reduce_moments_u1_neon(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
193
+ norms_ptr[row] = (nk_u32_t)nk_local_sum_;
195
194
  }
196
195
  }
197
196
 
@@ -203,9 +202,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
203
202
  * Each ZA0.S batch covers 16 depth u32 steps (one full depth tile).
204
203
  * BMOPA expansion=1 for u32: each u32 contributes 32 bits via XNOR+POPCNT.
205
204
  */
206
- __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_(
205
+ __arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_( //
207
206
  nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
208
- nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
207
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
209
208
 
210
209
  nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
211
210
  nk_size_t const row_tile_count_b = header->row_tile_count;
@@ -344,11 +343,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
344
343
  }
345
344
  }
346
345
 
347
- NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c,
348
- nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
349
- nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
346
/**
 *  Public entry point for the packed Hamming kernel.
 *  Forwards every argument unchanged to `nk_hammings_packed_u1_smebi32_streaming_`, bracketing the call
 *  with `nk_sme_start_streaming_()` / `nk_sme_stop_streaming_()`.
 *  NOTE(review): those helpers presumably enter and leave SME streaming mode now that the kernel no longer
 *  carries `__arm_locally_streaming` — confirm against their definitions.
 */
+ NK_PUBLIC void nk_hammings_packed_u1_smebi32( //
347
+ nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
348
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
349
+ nk_sme_start_streaming_();
350
350
  nk_hammings_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
351
351
  c_stride_in_bytes);
352
+ nk_sme_stop_streaming_();
352
353
  }
353
354
 
354
355
  /**
@@ -357,9 +358,9 @@ NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
357
358
  * ZA1-3.S = BMOPA accumulators (3 B column tiles in fast path).
358
359
  * Mirrors the pattern of the packed kernel nk_hammings_packed_u1_smebi32_streaming_.
359
360
  */
360
- __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_(
361
+ __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_( //
361
362
  nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
362
- nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
363
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
363
364
 
364
365
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
365
366
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
@@ -545,12 +546,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_sme
545
546
  }
546
547
  }
547
548
 
548
- NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
549
- nk_size_t stride_in_bytes, nk_u32_t *result,
550
- nk_size_t result_stride_in_bytes, nk_size_t row_start,
551
- nk_size_t row_count) {
549
/**
 *  Public entry point for the symmetric Hamming kernel.
 *  Forwards every argument unchanged to `nk_hammings_symmetric_u1_smebi32_streaming_`, bracketing the call
 *  with `nk_sme_start_streaming_()` / `nk_sme_stop_streaming_()`.
 *  NOTE(review): those helpers presumably enter and leave SME streaming mode now that the kernel no longer
 *  carries `__arm_locally_streaming` — confirm against their definitions.
 */
+ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32( //
550
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
551
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
552
+ nk_sme_start_streaming_();
552
553
  nk_hammings_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
553
554
  result_stride_in_bytes, row_start, row_count);
555
+ nk_sme_stop_streaming_();
554
556
  }
555
557
 
556
558
  #pragma endregion Hamming Distance
@@ -581,9 +583,9 @@ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_siz
581
583
  * union = (norm_a + norm_b + hamming) / 2
582
584
  * jaccard = 1 - intersection / union (1.0 when union == 0)
583
585
  */
584
- __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_(
586
+ __arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_( //
585
587
  nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
586
- nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
588
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
587
589
 
588
590
  nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
589
591
  nk_size_t const row_tile_count_b = header->row_tile_count;
@@ -796,11 +798,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
796
798
  }
797
799
  }
798
800
 
799
- NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c,
800
- nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
801
- nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
801
/**
 *  Public entry point for the packed Jaccard kernel.
 *  Forwards every argument unchanged to `nk_jaccards_packed_u1_smebi32_streaming_`, bracketing the call
 *  with `nk_sme_start_streaming_()` / `nk_sme_stop_streaming_()`.
 *  NOTE(review): those helpers presumably enter and leave SME streaming mode now that the kernel no longer
 *  carries `__arm_locally_streaming` — confirm against their definitions.
 */
+ NK_PUBLIC void nk_jaccards_packed_u1_smebi32( //
802
+ nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
803
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
804
+ nk_sme_start_streaming_();
802
805
  nk_jaccards_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
803
806
  c_stride_in_bytes);
807
+ nk_sme_stop_streaming_();
804
808
  }
805
809
 
806
810
  /**
@@ -808,9 +812,9 @@ NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
808
812
  * Fills upper triangle only (column_tile >= row_tile); caller sees result[i][j] for j >= i.
809
813
  * Norms computed on-the-fly using streaming SVE popcount.
810
814
  */
811
- __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_(
815
+ __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_( //
812
816
  nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
813
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
817
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
814
818
 
815
819
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
816
820
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
@@ -1104,12 +1108,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
1104
1108
  }
1105
1109
  }
1106
1110
 
1107
- NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
1108
- nk_size_t stride_in_bytes, nk_f32_t *result,
1109
- nk_size_t result_stride_in_bytes, nk_size_t row_start,
1110
- nk_size_t row_count) {
1111
/**
 *  Public entry point for the symmetric Jaccard kernel.
 *  Forwards every argument unchanged to `nk_jaccards_symmetric_u1_smebi32_streaming_`, bracketing the call
 *  with `nk_sme_start_streaming_()` / `nk_sme_stop_streaming_()`.
 *  NOTE(review): those helpers presumably enter and leave SME streaming mode now that the kernel no longer
 *  carries `__arm_locally_streaming` — confirm against their definitions.
 */
+ NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32( //
1112
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
1113
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1114
+ nk_sme_start_streaming_();
1111
1115
  nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
1112
1116
  result_stride_in_bytes, row_start, row_count);
1117
+ nk_sme_stop_streaming_();
1113
1118
  }
1114
1119
 
1115
1120
  #pragma endregion Jaccard Distance
@@ -12,6 +12,7 @@
12
12
  #if NK_TARGET_ARM64_
13
13
 
14
14
  #include "numkong/types.h"
15
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`, `nk_svaddv_f32_`
15
16
 
16
17
  #if defined(__cplusplus)
17
18
  extern "C" {
@@ -395,7 +396,7 @@ NK_PUBLIC void nk_sparse_dot_u32f32_sve2( //
395
396
  a_idx += a_step;
396
397
  b_idx += b_step;
397
398
  }
398
- *product = svaddv_f64(predicate_all_b64x, product_f64x);
399
+ *product = nk_svaddv_f64_(predicate_all_b64x, product_f64x);
399
400
  }
400
401
 
401
402
  #if defined(__clang__)
@@ -485,7 +486,7 @@ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
485
486
  a_idx += a_step;
486
487
  b_idx += b_step;
487
488
  }
488
- *product = svaddv_f32(svptrue_b32(), product_f32x);
489
+ *product = nk_svaddv_f32_(svptrue_b32(), product_f32x);
489
490
  }
490
491
 
491
492
  #if defined(__clang__)
@@ -36,6 +36,7 @@
36
36
  #if NK_TARGET_SVE
37
37
 
38
38
  #include "numkong/types.h"
39
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
39
40
  #include "numkong/spatial/neon.h" // `nk_f64_sqrt_neon`
40
41
  #include "numkong/dot/sve.h" // `nk_dot_stable_sum_f64_sve_`
41
42
 
@@ -113,7 +114,7 @@ NK_PUBLIC void nk_sqeuclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_s
113
114
  svfloat64_t diff_odd_f64x = svsub_f64_x(pred_odd_b64x, a_odd_f64x, b_odd_f64x);
114
115
  dist_sq_f64x = svmla_f64_m(pred_odd_b64x, dist_sq_f64x, diff_odd_f64x, diff_odd_f64x);
115
116
  }
116
- nk_f64_t dist_sq_f64 = svaddv_f64(svptrue_b64(), dist_sq_f64x);
117
+ nk_f64_t dist_sq_f64 = nk_svaddv_f64_(svptrue_b64(), dist_sq_f64x);
117
118
  *result = dist_sq_f64;
118
119
  }
119
120
 
@@ -149,9 +150,9 @@ NK_PUBLIC void nk_angular_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_
149
150
  b2_f64x = svmla_f64_m(pred_odd_b64x, b2_f64x, b_odd_f64x, b_odd_f64x);
150
151
  }
151
152
 
152
- nk_f64_t ab_f64 = svaddv_f64(svptrue_b64(), ab_f64x);
153
- nk_f64_t a2_f64 = svaddv_f64(svptrue_b64(), a2_f64x);
154
- nk_f64_t b2_f64 = svaddv_f64(svptrue_b64(), b2_f64x);
153
+ nk_f64_t ab_f64 = nk_svaddv_f64_(svptrue_b64(), ab_f64x);
154
+ nk_f64_t a2_f64 = nk_svaddv_f64_(svptrue_b64(), a2_f64x);
155
+ nk_f64_t b2_f64 = nk_svaddv_f64_(svptrue_b64(), b2_f64x);
155
156
  *result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
156
157
  }
157
158
 
@@ -225,8 +226,8 @@ NK_PUBLIC void nk_angular_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_
225
226
  } while (i < n);
226
227
 
227
228
  nk_f64_t ab_f64 = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, ab_sum_f64x, ab_compensation_f64x);
228
- nk_f64_t a2_f64 = svaddv_f64(predicate_all_b64x, a2_f64x);
229
- nk_f64_t b2_f64 = svaddv_f64(predicate_all_b64x, b2_f64x);
229
+ nk_f64_t a2_f64 = nk_svaddv_f64_(predicate_all_b64x, a2_f64x);
230
+ nk_f64_t b2_f64 = nk_svaddv_f64_(predicate_all_b64x, b2_f64x);
230
231
  *result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
231
232
  }
232
233
 
@@ -36,6 +36,7 @@
36
36
  #if NK_TARGET_SVEBFDOT
37
37
 
38
38
  #include "numkong/types.h"
39
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f32_`
39
40
  #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
40
41
 
41
42
  #if defined(__cplusplus)
@@ -75,7 +76,9 @@ NK_PUBLIC void nk_sqeuclidean_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t c
75
76
  d2_high_f32x = svmla_f32_m(predicate_high_b32x, d2_high_f32x, a_minus_b_high_f32x, a_minus_b_high_f32x);
76
77
  i += svcnth();
77
78
  } while (i < n);
78
- nk_f32_t d2 = svaddv_f32(svptrue_b32(), d2_low_f32x) + svaddv_f32(svptrue_b32(), d2_high_f32x);
79
+ nk_f32_t d2_low = nk_svaddv_f32_(svptrue_b32(), d2_low_f32x);
80
+ nk_f32_t d2_high = nk_svaddv_f32_(svptrue_b32(), d2_high_f32x);
81
+ nk_f32_t d2 = d2_low + d2_high;
79
82
  *result = d2;
80
83
  }
81
84
  NK_PUBLIC void nk_euclidean_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
@@ -101,9 +104,9 @@ NK_PUBLIC void nk_angular_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t const
101
104
  i += svcnth();
102
105
  } while (i < n);
103
106
 
104
- nk_f32_t ab = svaddv_f32(svptrue_b32(), ab_f32x);
105
- nk_f32_t a2 = svaddv_f32(svptrue_b32(), a2_f32x);
106
- nk_f32_t b2 = svaddv_f32(svptrue_b32(), b2_f32x);
107
+ nk_f32_t ab = nk_svaddv_f32_(svptrue_b32(), ab_f32x);
108
+ nk_f32_t a2 = nk_svaddv_f32_(svptrue_b32(), a2_f32x);
109
+ nk_f32_t b2 = nk_svaddv_f32_(svptrue_b32(), b2_f32x);
107
110
  *result = nk_angular_normalize_f32_neon_(ab, a2, b2);
108
111
  }
109
112
 
@@ -32,6 +32,7 @@
32
32
  #if NK_TARGET_SVEHALF
33
33
 
34
34
  #include "numkong/types.h"
35
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f32_`
35
36
  #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
36
37
 
37
38
  #if defined(__cplusplus)
@@ -74,7 +75,7 @@ NK_PUBLIC void nk_sqeuclidean_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const
74
75
 
75
76
  i += svcnth();
76
77
  } while (i < n);
77
- *result = svaddv_f32(svptrue_b32(), d2_f32x);
78
+ *result = nk_svaddv_f32_(svptrue_b32(), d2_f32x);
78
79
  }
79
80
 
80
81
  NK_PUBLIC void nk_euclidean_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
@@ -114,9 +115,9 @@ NK_PUBLIC void nk_angular_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const *b_
114
115
  i += svcnth();
115
116
  } while (i < n);
116
117
 
117
- nk_f32_t ab_f32 = svaddv_f32(svptrue_b32(), ab_f32x);
118
- nk_f32_t a2_f32 = svaddv_f32(svptrue_b32(), a2_f32x);
119
- nk_f32_t b2_f32 = svaddv_f32(svptrue_b32(), b2_f32x);
118
+ nk_f32_t ab_f32 = nk_svaddv_f32_(svptrue_b32(), ab_f32x);
119
+ nk_f32_t a2_f32 = nk_svaddv_f32_(svptrue_b32(), a2_f32x);
120
+ nk_f32_t b2_f32 = nk_svaddv_f32_(svptrue_b32(), b2_f32x);
120
121
  *result = nk_angular_normalize_f32_neon_(ab_f32, a2_f32, b2_f32);
121
122
  }
122
123
 
@@ -34,6 +34,7 @@
34
34
  #if NK_TARGET_SVESDOT
35
35
 
36
36
  #include "numkong/types.h"
37
+ #include "numkong/reduce/sve.h" // `nk_svaddv_u32_`, `nk_svaddv_s32_`
37
38
  #include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
38
39
 
39
40
  #if defined(__cplusplus)
@@ -58,7 +59,7 @@ NK_PUBLIC void nk_sqeuclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_
58
59
  distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
59
60
  i += svcntb();
60
61
  } while (i < n);
61
- *result = (nk_u32_t)svaddv_u32(svptrue_b32(), distance_sq_u32x);
62
+ *result = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), distance_sq_u32x);
62
63
  }
63
64
  NK_PUBLIC void nk_euclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
64
65
  nk_u32_t distance_sq_u32;
@@ -81,9 +82,9 @@ NK_PUBLIC void nk_angular_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size
81
82
  i += svcntb();
82
83
  } while (i < n);
83
84
 
84
- nk_i32_t ab = (nk_i32_t)svaddv_s32(svptrue_b32(), ab_i32x);
85
- nk_i32_t a2 = (nk_i32_t)svaddv_s32(svptrue_b32(), a2_i32x);
86
- nk_i32_t b2 = (nk_i32_t)svaddv_s32(svptrue_b32(), b2_i32x);
85
+ nk_i32_t ab = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), ab_i32x);
86
+ nk_i32_t a2 = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), a2_i32x);
87
+ nk_i32_t b2 = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), b2_i32x);
87
88
  *result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
88
89
  }
89
90
 
@@ -98,7 +99,7 @@ NK_PUBLIC void nk_sqeuclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_
98
99
  distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
99
100
  i += svcntb();
100
101
  } while (i < n);
101
- *result = (nk_u32_t)svaddv_u32(svptrue_b32(), distance_sq_u32x);
102
+ *result = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), distance_sq_u32x);
102
103
  }
103
104
  NK_PUBLIC void nk_euclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
104
105
  nk_u32_t distance_sq_u32;
@@ -121,9 +122,9 @@ NK_PUBLIC void nk_angular_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size
121
122
  i += svcntb();
122
123
  } while (i < n);
123
124
 
124
- nk_u32_t ab = (nk_u32_t)svaddv_u32(svptrue_b32(), ab_u32x);
125
- nk_u32_t a2 = (nk_u32_t)svaddv_u32(svptrue_b32(), a2_u32x);
126
- nk_u32_t b2 = (nk_u32_t)svaddv_u32(svptrue_b32(), b2_u32x);
125
+ nk_u32_t ab = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), ab_u32x);
126
+ nk_u32_t a2 = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), a2_u32x);
127
+ nk_u32_t b2 = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), b2_u32x);
127
128
  *result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
128
129
  }
129
130
 
@@ -0,0 +1,173 @@
1
+ /**
2
+ * @brief Batched Spatial Distances for Granite Rapids (AMX-FP16) with AVX-512 Finalization.
3
+ * @file include/numkong/spatials/graniteamx.h
4
+ * @author Ash Vardanian
5
+ * @date April 9, 2026
6
+ *
7
+ * @sa include/numkong/spatials.h
8
+ */
9
+ #ifndef NK_SPATIALS_GRANITEAMX_H
10
+ #define NK_SPATIALS_GRANITEAMX_H
11
+
12
+ #if NK_TARGET_X8664_
13
+ #if NK_TARGET_GRANITEAMX
14
+
15
+ #include "numkong/spatial/skylake.h"
16
+ #include "numkong/spatial/serial.h"
17
+ #include "numkong/dots/graniteamx.h"
18
+
19
+ #if defined(__cplusplus)
20
+ extern "C" {
21
+ #endif
22
+
23
+ #if defined(__clang__)
24
+ #pragma clang attribute push( \
25
+ __attribute__((target( \
26
+ "avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512vbmi,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8,amx-fp16"))), \
27
+ apply_to = function)
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC push_options
30
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512vbmi", "f16c", "fma", \
31
+ "bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8", "amx-fp16")
32
+ #endif
33
+
34
+ #pragma region F16 Packed
35
+
36
+ NK_INTERNAL void nk_angulars_packed_f16_graniteamx_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
37
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
38
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
39
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
40
+ nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
41
+ for (nk_size_t row = 0; row < rows; row++) {
42
+ nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_f16_(a + row * a_stride_elements, depth);
43
+ nk_angulars_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
44
+ }
45
+ }
46
+
47
+ NK_PUBLIC void nk_angulars_packed_f16_graniteamx( //
48
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
49
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
50
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
51
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
52
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
53
+ nk_dots_packed_f16_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
54
+ nk_angulars_packed_f16_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
55
+ c_stride_elements);
56
+ }
57
+
58
+ NK_INTERNAL void nk_euclideans_packed_f16_graniteamx_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
59
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
60
+ nk_size_t a_stride_elements,
61
+ nk_size_t c_stride_elements) {
62
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
63
+ nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
64
+ for (nk_size_t row = 0; row < rows; row++) {
65
+ nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_f16_(a + row * a_stride_elements, depth);
66
+ nk_euclideans_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
67
+ }
68
+ }
69
+
70
+ NK_PUBLIC void nk_euclideans_packed_f16_graniteamx( //
71
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
72
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
73
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
74
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
75
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
76
+ nk_dots_packed_f16_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
77
+ nk_euclideans_packed_f16_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
78
+ c_stride_elements);
79
+ }
80
+
81
+ #pragma endregion F16 Packed
82
+
83
+ #pragma region F16 Symmetric
84
+
85
+ NK_INTERNAL void nk_angulars_symmetric_f16_graniteamx_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
86
+ nk_size_t depth, nk_size_t stride_elements,
87
+ nk_f32_t *result, nk_size_t result_stride_elements,
88
+ nk_size_t row_start, nk_size_t row_count) {
89
+
90
+ for (nk_size_t row = row_start; row < row_start + row_count; row++)
91
+ result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_f16_(vectors + row * stride_elements, depth);
92
+
93
+ nk_f32_t column_norms_cache[256];
94
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
95
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
96
+ for (nk_size_t col = chunk_start; col < chunk_end; col++)
97
+ column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
98
+
99
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) {
100
+ nk_f32_t *r_row = result + row * result_stride_elements;
101
+ nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
102
+ if (col_start >= chunk_end) continue;
103
+ nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
104
+ r_row[row], chunk_end - col_start);
105
+ }
106
+ }
107
+
108
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
109
+ }
110
+
111
+ NK_PUBLIC void nk_angulars_symmetric_f16_graniteamx( //
112
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
113
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
114
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
115
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
116
+ nk_dots_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
117
+ row_start, row_count);
118
+ nk_angulars_symmetric_f16_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
119
+ result_stride_elements, row_start, row_count);
120
+ }
121
+
122
+ NK_INTERNAL void nk_euclideans_symmetric_f16_graniteamx_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
123
+ nk_size_t depth, nk_size_t stride_elements,
124
+ nk_f32_t *result, nk_size_t result_stride_elements,
125
+ nk_size_t row_start, nk_size_t row_count) {
126
+
127
+ for (nk_size_t row = row_start; row < row_start + row_count; row++)
128
+ result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_f16_(vectors + row * stride_elements, depth);
129
+
130
+ nk_f32_t column_norms_cache[256];
131
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
132
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
133
+ for (nk_size_t col = chunk_start; col < chunk_end; col++)
134
+ column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
135
+
136
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) {
137
+ nk_f32_t *r_row = result + row * result_stride_elements;
138
+ nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
139
+ if (col_start >= chunk_end) continue;
140
+ nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
141
+ r_row[row], chunk_end - col_start);
142
+ }
143
+ }
144
+
145
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
146
+ }
147
+
148
+ NK_PUBLIC void nk_euclideans_symmetric_f16_graniteamx( //
149
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
150
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
151
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
152
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
153
+ nk_dots_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
154
+ row_start, row_count);
155
+ nk_euclideans_symmetric_f16_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
156
+ result_stride_elements, row_start, row_count);
157
+ }
158
+
159
+ #pragma endregion F16 Symmetric
160
+
161
+ #if defined(__clang__)
162
+ #pragma clang attribute pop
163
+ #elif defined(__GNUC__)
164
+ #pragma GCC pop_options
165
+ #endif
166
+
167
+ #if defined(__cplusplus)
168
+ } // extern "C"
169
+ #endif
170
+
171
+ #endif // NK_TARGET_GRANITEAMX
172
+ #endif // NK_TARGET_X8664_
173
+ #endif // NK_SPATIALS_GRANITEAMX_H
@@ -15,6 +15,18 @@
15
15
  extern "C" {
16
16
  #endif
17
17
 
18
+ /* Optimize serial fallbacks for size — see dots/serial.h for rationale. */
19
+ #if defined(NDEBUG)
20
+ #if defined(_MSC_VER)
21
+ #pragma optimize("s", on)
22
+ #elif defined(__clang__)
23
+ #pragma clang attribute push(__attribute__((minsize)), apply_to = function)
24
+ #elif defined(__GNUC__)
25
+ #pragma GCC push_options
26
+ #pragma GCC optimize("Os")
27
+ #endif
28
+ #endif
29
+
18
30
  nk_define_cross_normalized_packed_(angular, f64, serial, f64, f64, f64, /*norm_value_type=*/f64, f64, nk_b256_vec_t,
19
31
  nk_dots_packed_f64_serial, nk_angular_through_f64_from_dot_serial_,
20
32
  nk_dots_reduce_sumsq_f64_, nk_load_b256_serial_, nk_partial_load_b64x4_serial_,
@@ -219,6 +231,16 @@ nk_define_cross_normalized_symmetric_(euclidean, u4, serial, u4x2, u32, /*norm_v
219
231
  nk_dots_reduce_sumsq_u4_, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
220
232
  nk_store_b128_serial_, nk_partial_store_b32x4_serial_, 2)
221
233
 
234
+ #if defined(NDEBUG)
235
+ #if defined(_MSC_VER)
236
+ #pragma optimize("", on)
237
+ #elif defined(__clang__)
238
+ #pragma clang attribute pop
239
+ #elif defined(__GNUC__)
240
+ #pragma GCC pop_options
241
+ #endif
242
+ #endif
243
+
222
244
  #if defined(__cplusplus)
223
245
  } // extern "C"
224
246
  #endif