numkong 7.4.5 → 7.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. package/README.md +1 -0
  2. package/binding.gyp +81 -5
  3. package/c/dispatch_f16.c +23 -0
  4. package/c/numkong.c +0 -13
  5. package/include/numkong/attention/sme.h +34 -31
  6. package/include/numkong/capabilities.h +2 -15
  7. package/include/numkong/cast/neon.h +15 -0
  8. package/include/numkong/curved/smef64.h +82 -62
  9. package/include/numkong/dot/rvvbf16.h +1 -1
  10. package/include/numkong/dot/rvvhalf.h +1 -1
  11. package/include/numkong/dot/sve.h +6 -5
  12. package/include/numkong/dot/svebfdot.h +2 -1
  13. package/include/numkong/dot/svehalf.h +6 -5
  14. package/include/numkong/dot/svesdot.h +3 -2
  15. package/include/numkong/dots/graniteamx.h +733 -0
  16. package/include/numkong/dots/serial.h +11 -4
  17. package/include/numkong/dots/sme.h +172 -140
  18. package/include/numkong/dots/smebi32.h +14 -11
  19. package/include/numkong/dots/smef64.h +31 -26
  20. package/include/numkong/dots.h +29 -3
  21. package/include/numkong/each/serial.h +22 -0
  22. package/include/numkong/geospatial/haswell.h +1 -1
  23. package/include/numkong/geospatial/neon.h +1 -1
  24. package/include/numkong/geospatial/serial.h +1 -1
  25. package/include/numkong/geospatial/skylake.h +1 -1
  26. package/include/numkong/maxsim/sme.h +34 -33
  27. package/include/numkong/mesh/serial.h +22 -0
  28. package/include/numkong/reduce/neon.h +29 -0
  29. package/include/numkong/reduce/neonbfdot.h +2 -2
  30. package/include/numkong/reduce/neonfhm.h +4 -4
  31. package/include/numkong/reduce/sve.h +52 -0
  32. package/include/numkong/reduce.h +4 -0
  33. package/include/numkong/set/sve.h +6 -5
  34. package/include/numkong/sets/smebi32.h +35 -30
  35. package/include/numkong/sparse/sve2.h +3 -2
  36. package/include/numkong/spatial/sve.h +7 -6
  37. package/include/numkong/spatial/svebfdot.h +7 -4
  38. package/include/numkong/spatial/svehalf.h +5 -4
  39. package/include/numkong/spatial/svesdot.h +9 -8
  40. package/include/numkong/spatials/graniteamx.h +173 -0
  41. package/include/numkong/spatials/serial.h +22 -0
  42. package/include/numkong/spatials/sme.h +391 -350
  43. package/include/numkong/spatials/smef64.h +79 -70
  44. package/include/numkong/spatials.h +37 -4
  45. package/include/numkong/types.h +59 -0
  46. package/javascript/dist/cjs/numkong.js +13 -0
  47. package/javascript/dist/esm/numkong.js +13 -0
  48. package/javascript/numkong.c +56 -12
  49. package/javascript/numkong.ts +13 -0
  50. package/package.json +7 -7
  51. package/probes/probe.js +2 -2
  52. package/wasm/numkong.wasm +0 -0
@@ -0,0 +1,52 @@
1
+ /**
2
+ * @brief SVE horizontal reduction helpers with MSan unpoisoning.
3
+ * @file include/numkong/reduce/sve.h
4
+ * @author Ash Vardanian
5
+ * @date April 12, 2026
6
+ *
7
+ * LLVM's MSan does not instrument ARM SVE intrinsics — `svaddv` moves data
8
+ * from vector to scalar registers via architecture-specific paths invisible
9
+ * to the compiler, causing false-positive uninitialized-value reports.
10
+ * These macros wrap the reduction and unpoison the scalar result.
11
+ *
12
+ * The `svaddv` intrinsic stays inside a macro so it expands in the caller's
13
+ * target context — SVE and SME streaming translation units carry incompatible
14
+ * target attributes. The unpoisoning runs on the already-reduced scalar, so it
15
+ * lives in a target-agnostic `NK_INTERNAL` helper called from the macro.
16
+ *
17
+ * @sa include/numkong/reduce.h
18
+ */
19
+ #ifndef NK_REDUCE_SVE_H
20
+ #define NK_REDUCE_SVE_H
21
+
22
+ #if NK_TARGET_ARM64_
23
+ #if NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
24
+
25
+ #include "numkong/types.h"
26
+
27
+ NK_INTERNAL nk_f64_t nk_unpoison_f64_(nk_f64_t v) NK_STREAMING_COMPATIBLE_ {
28
+ nk_unpoison_(&v, sizeof(v));
29
+ return v;
30
+ }
31
+ NK_INTERNAL nk_f32_t nk_unpoison_f32_(nk_f32_t v) NK_STREAMING_COMPATIBLE_ {
32
+ nk_unpoison_(&v, sizeof(v));
33
+ return v;
34
+ }
35
+ NK_INTERNAL nk_u64_t nk_unpoison_u64_(nk_u64_t v) NK_STREAMING_COMPATIBLE_ {
36
+ nk_unpoison_(&v, sizeof(v));
37
+ return v;
38
+ }
39
+ NK_INTERNAL nk_i64_t nk_unpoison_i64_(nk_i64_t v) NK_STREAMING_COMPATIBLE_ {
40
+ nk_unpoison_(&v, sizeof(v));
41
+ return v;
42
+ }
43
+
44
+ #define nk_svaddv_f64_(predicate, vector) nk_unpoison_f64_(svaddv_f64((predicate), (vector)))
45
+ #define nk_svaddv_f32_(predicate, vector) nk_unpoison_f32_(svaddv_f32((predicate), (vector)))
46
+ #define nk_svaddv_u32_(predicate, vector) nk_unpoison_u64_(svaddv_u32((predicate), (vector)))
47
+ #define nk_svaddv_s32_(predicate, vector) nk_unpoison_i64_(svaddv_s32((predicate), (vector)))
48
+ #define nk_svaddv_u8_(predicate, vector) nk_unpoison_u64_(svaddv_u8((predicate), (vector)))
49
+
50
+ #endif // NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
51
+ #endif // NK_TARGET_ARM64_
52
+ #endif // NK_REDUCE_SVE_H
@@ -389,6 +389,8 @@ NK_PUBLIC void nk_reduce_moments_i16_neon(nk_i16_t const *, nk_size_t, nk_size_t
389
389
  /** @copydoc nk_reduce_moments_f64 */
390
390
  NK_PUBLIC void nk_reduce_moments_u16_neon(nk_u16_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
391
391
  /** @copydoc nk_reduce_moments_f64 */
392
+ NK_PUBLIC void nk_reduce_moments_u1_neon(nk_u1x8_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
393
+ /** @copydoc nk_reduce_moments_f64 */
392
394
  NK_PUBLIC void nk_reduce_moments_i32_neon(nk_i32_t const *, nk_size_t, nk_size_t, nk_i64_t *, nk_u64_t *);
393
395
  /** @copydoc nk_reduce_moments_f64 */
394
396
  NK_PUBLIC void nk_reduce_moments_u32_neon(nk_u32_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
@@ -1559,6 +1561,8 @@ NK_PUBLIC void nk_reduce_moments_u1(nk_u1x8_t const *d, nk_size_t n, nk_size_t s
1559
1561
  nk_reduce_moments_u1_skylake(d, n, s, sum, sumsq);
1560
1562
  #elif NK_TARGET_HASWELL
1561
1563
  nk_reduce_moments_u1_haswell(d, n, s, sum, sumsq);
1564
+ #elif NK_TARGET_NEON
1565
+ nk_reduce_moments_u1_neon(d, n, s, sum, sumsq);
1562
1566
  #else
1563
1567
  nk_reduce_moments_u1_serial(d, n, s, sum, sumsq);
1564
1568
  #endif
@@ -32,8 +32,9 @@
32
32
  #if NK_TARGET_ARM64_
33
33
  #if NK_TARGET_SVE
34
34
 
35
- #include "numkong/types.h" // `nk_u1x8_t`
36
- #include "numkong/set/neon.h" // `nk_hamming_u1_neon`
35
+ #include "numkong/types.h" // `nk_u1x8_t`
36
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
37
+ #include "numkong/set/neon.h" // `nk_hamming_u1_neon`
37
38
 
38
39
  #if defined(__cplusplus)
39
40
  extern "C" {
@@ -73,7 +74,7 @@ NK_PUBLIC void nk_hamming_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
73
74
  i += words_per_register;
74
75
  ++cycle;
75
76
  } while (i < n_bytes && cycle < 31);
76
- differences += svaddv_u8(all_predicate_b8x, popcount_u8x);
77
+ differences += nk_svaddv_u8_(all_predicate_b8x, popcount_u8x);
77
78
  popcount_u8x = svdup_n_u8(0);
78
79
  cycle = 0; // Reset the cycle counter.
79
80
  }
@@ -110,9 +111,9 @@ NK_PUBLIC void nk_jaccard_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
110
111
  i += words_per_register;
111
112
  ++cycle;
112
113
  } while (i < n_bytes && cycle < 31);
113
- intersection_count += svaddv_u8(all_predicate_b8x, intersection_popcount_u8x);
114
+ intersection_count += nk_svaddv_u8_(all_predicate_b8x, intersection_popcount_u8x);
114
115
  intersection_popcount_u8x = svdup_n_u8(0);
115
- union_count += svaddv_u8(all_predicate_b8x, union_popcount_u8x);
116
+ union_count += nk_svaddv_u8_(all_predicate_b8x, union_popcount_u8x);
116
117
  union_popcount_u8x = svdup_n_u8(0);
117
118
  cycle = 0; // Reset the cycle counter.
118
119
  }
@@ -41,8 +41,9 @@
41
41
  #include "numkong/types.h"
42
42
  #include "numkong/set/serial.h"
43
43
  #include "numkong/sets/serial.h"
44
- #include "numkong/dots/sme.h" // `nk_sme_zero_za32_*` constants
45
- #include "numkong/reduce.h" // `nk_reduce_moments_u1`
44
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
45
+ #include "numkong/reduce/neon.h" // `nk_reduce_moments_u1_neon`
46
+ #include "numkong/dots/sme.h" // `nk_sme_zero_za32_*`
46
47
 
47
48
  #if defined(__cplusplus)
48
49
  extern "C" {
@@ -100,7 +101,7 @@ NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data, nk_
100
101
  svbool_t predicate_b8x = svwhilelt_b8_u64(offset, n_bytes);
101
102
  acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_b8x, svld1_u8(predicate_b8x, data + offset)), ones_u8x);
102
103
  }
103
- return (nk_u32_t)svaddv_u32(svptrue_b32(), acc_u32x);
104
+ return (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), acc_u32x);
104
105
  }
105
106
 
106
107
  #pragma region Hamming Distance
@@ -187,11 +188,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
187
188
  // Compute per-row population counts
188
189
  for (nk_size_t row = 0; row < row_count; row++) {
189
190
  nk_u1x8_t const *src_row = (nk_u1x8_t const *)((char const *)b + row * b_stride_in_bytes);
190
- {
191
- nk_u64_t nk_local_sum_, nk_local_sumsq_;
192
- nk_reduce_moments_u1(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
193
- norms_ptr[row] = (nk_u32_t)nk_local_sum_;
194
- }
191
+ nk_u64_t nk_local_sum_, nk_local_sumsq_;
192
+ nk_reduce_moments_u1_neon(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
193
+ norms_ptr[row] = (nk_u32_t)nk_local_sum_;
195
194
  }
196
195
  }
197
196
 
@@ -203,9 +202,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
203
202
  * Each ZA0.S batch covers 16 depth u32 steps (one full depth tile).
204
203
  * BMOPA expansion=1 for u32: each u32 contributes 32 bits via XNOR+POPCNT.
205
204
  */
206
- __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_(
205
+ __arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_( //
207
206
  nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
208
- nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
207
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
209
208
 
210
209
  nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
211
210
  nk_size_t const row_tile_count_b = header->row_tile_count;
@@ -344,11 +343,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
344
343
  }
345
344
  }
346
345
 
347
- NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c,
348
- nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
349
- nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
346
+ NK_PUBLIC void nk_hammings_packed_u1_smebi32( //
347
+ nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
348
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
349
+ nk_sme_start_streaming_();
350
350
  nk_hammings_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
351
351
  c_stride_in_bytes);
352
+ nk_sme_stop_streaming_();
352
353
  }
353
354
 
354
355
  /**
@@ -357,9 +358,9 @@ NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
357
358
  * ZA1-3.S = BMOPA accumulators (3 B column tiles in fast path).
358
359
  * Mirrors the unpacked kernel nk_hammings_packed_u1_smebi32_streaming_ pattern.
359
360
  */
360
- __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_(
361
+ __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_( //
361
362
  nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
362
- nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
363
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
363
364
 
364
365
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
365
366
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
@@ -545,12 +546,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_sme
545
546
  }
546
547
  }
547
548
 
548
- NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
549
- nk_size_t stride_in_bytes, nk_u32_t *result,
550
- nk_size_t result_stride_in_bytes, nk_size_t row_start,
551
- nk_size_t row_count) {
549
+ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32( //
550
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
551
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
552
+ nk_sme_start_streaming_();
552
553
  nk_hammings_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
553
554
  result_stride_in_bytes, row_start, row_count);
555
+ nk_sme_stop_streaming_();
554
556
  }
555
557
 
556
558
  #pragma endregion Hamming Distance
@@ -581,9 +583,9 @@ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_siz
581
583
  * union = (norm_a + norm_b + hamming) / 2
582
584
  * jaccard = 1 - intersection / union (1.0 when union == 0)
583
585
  */
584
- __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_(
586
+ __arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_( //
585
587
  nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
586
- nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
588
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
587
589
 
588
590
  nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
589
591
  nk_size_t const row_tile_count_b = header->row_tile_count;
@@ -796,11 +798,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
796
798
  }
797
799
  }
798
800
 
799
- NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c,
800
- nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
801
- nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
801
+ NK_PUBLIC void nk_jaccards_packed_u1_smebi32( //
802
+ nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
803
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
804
+ nk_sme_start_streaming_();
802
805
  nk_jaccards_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
803
806
  c_stride_in_bytes);
807
+ nk_sme_stop_streaming_();
804
808
  }
805
809
 
806
810
  /**
@@ -808,9 +812,9 @@ NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
808
812
  * Fills upper triangle only (column_tile >= row_tile); caller sees result[i][j] for j >= i.
809
813
  * Norms computed on-the-fly using streaming SVE popcount.
810
814
  */
811
- __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_(
815
+ __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_( //
812
816
  nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
813
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
817
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
814
818
 
815
819
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
816
820
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
@@ -1104,12 +1108,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
1104
1108
  }
1105
1109
  }
1106
1110
 
1107
- NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
1108
- nk_size_t stride_in_bytes, nk_f32_t *result,
1109
- nk_size_t result_stride_in_bytes, nk_size_t row_start,
1110
- nk_size_t row_count) {
1111
+ NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32( //
1112
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
1113
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1114
+ nk_sme_start_streaming_();
1111
1115
  nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
1112
1116
  result_stride_in_bytes, row_start, row_count);
1117
+ nk_sme_stop_streaming_();
1113
1118
  }
1114
1119
 
1115
1120
  #pragma endregion Jaccard Distance
@@ -12,6 +12,7 @@
12
12
  #if NK_TARGET_ARM64_
13
13
 
14
14
  #include "numkong/types.h"
15
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
15
16
 
16
17
  #if defined(__cplusplus)
17
18
  extern "C" {
@@ -395,7 +396,7 @@ NK_PUBLIC void nk_sparse_dot_u32f32_sve2( //
395
396
  a_idx += a_step;
396
397
  b_idx += b_step;
397
398
  }
398
- *product = svaddv_f64(predicate_all_b64x, product_f64x);
399
+ *product = nk_svaddv_f64_(predicate_all_b64x, product_f64x);
399
400
  }
400
401
 
401
402
  #if defined(__clang__)
@@ -485,7 +486,7 @@ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
485
486
  a_idx += a_step;
486
487
  b_idx += b_step;
487
488
  }
488
- *product = svaddv_f32(svptrue_b32(), product_f32x);
489
+ *product = nk_svaddv_f32_(svptrue_b32(), product_f32x);
489
490
  }
490
491
 
491
492
  #if defined(__clang__)
@@ -36,6 +36,7 @@
36
36
  #if NK_TARGET_SVE
37
37
 
38
38
  #include "numkong/types.h"
39
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
39
40
  #include "numkong/spatial/neon.h" // `nk_f64_sqrt_neon`
40
41
  #include "numkong/dot/sve.h" // `nk_dot_stable_sum_f64_sve_`
41
42
 
@@ -113,7 +114,7 @@ NK_PUBLIC void nk_sqeuclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_s
113
114
  svfloat64_t diff_odd_f64x = svsub_f64_x(pred_odd_b64x, a_odd_f64x, b_odd_f64x);
114
115
  dist_sq_f64x = svmla_f64_m(pred_odd_b64x, dist_sq_f64x, diff_odd_f64x, diff_odd_f64x);
115
116
  }
116
- nk_f64_t dist_sq_f64 = svaddv_f64(svptrue_b64(), dist_sq_f64x);
117
+ nk_f64_t dist_sq_f64 = nk_svaddv_f64_(svptrue_b64(), dist_sq_f64x);
117
118
  *result = dist_sq_f64;
118
119
  }
119
120
 
@@ -149,9 +150,9 @@ NK_PUBLIC void nk_angular_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_
149
150
  b2_f64x = svmla_f64_m(pred_odd_b64x, b2_f64x, b_odd_f64x, b_odd_f64x);
150
151
  }
151
152
 
152
- nk_f64_t ab_f64 = svaddv_f64(svptrue_b64(), ab_f64x);
153
- nk_f64_t a2_f64 = svaddv_f64(svptrue_b64(), a2_f64x);
154
- nk_f64_t b2_f64 = svaddv_f64(svptrue_b64(), b2_f64x);
153
+ nk_f64_t ab_f64 = nk_svaddv_f64_(svptrue_b64(), ab_f64x);
154
+ nk_f64_t a2_f64 = nk_svaddv_f64_(svptrue_b64(), a2_f64x);
155
+ nk_f64_t b2_f64 = nk_svaddv_f64_(svptrue_b64(), b2_f64x);
155
156
  *result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
156
157
  }
157
158
 
@@ -225,8 +226,8 @@ NK_PUBLIC void nk_angular_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_
225
226
  } while (i < n);
226
227
 
227
228
  nk_f64_t ab_f64 = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, ab_sum_f64x, ab_compensation_f64x);
228
- nk_f64_t a2_f64 = svaddv_f64(predicate_all_b64x, a2_f64x);
229
- nk_f64_t b2_f64 = svaddv_f64(predicate_all_b64x, b2_f64x);
229
+ nk_f64_t a2_f64 = nk_svaddv_f64_(predicate_all_b64x, a2_f64x);
230
+ nk_f64_t b2_f64 = nk_svaddv_f64_(predicate_all_b64x, b2_f64x);
230
231
  *result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
231
232
  }
232
233
 
@@ -36,6 +36,7 @@
36
36
  #if NK_TARGET_SVEBFDOT
37
37
 
38
38
  #include "numkong/types.h"
39
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
39
40
  #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
40
41
 
41
42
  #if defined(__cplusplus)
@@ -75,7 +76,9 @@ NK_PUBLIC void nk_sqeuclidean_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t c
75
76
  d2_high_f32x = svmla_f32_m(predicate_high_b32x, d2_high_f32x, a_minus_b_high_f32x, a_minus_b_high_f32x);
76
77
  i += svcnth();
77
78
  } while (i < n);
78
- nk_f32_t d2 = svaddv_f32(svptrue_b32(), d2_low_f32x) + svaddv_f32(svptrue_b32(), d2_high_f32x);
79
+ nk_f32_t d2_low = nk_svaddv_f32_(svptrue_b32(), d2_low_f32x);
80
+ nk_f32_t d2_high = nk_svaddv_f32_(svptrue_b32(), d2_high_f32x);
81
+ nk_f32_t d2 = d2_low + d2_high;
79
82
  *result = d2;
80
83
  }
81
84
  NK_PUBLIC void nk_euclidean_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
@@ -101,9 +104,9 @@ NK_PUBLIC void nk_angular_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t const
101
104
  i += svcnth();
102
105
  } while (i < n);
103
106
 
104
- nk_f32_t ab = svaddv_f32(svptrue_b32(), ab_f32x);
105
- nk_f32_t a2 = svaddv_f32(svptrue_b32(), a2_f32x);
106
- nk_f32_t b2 = svaddv_f32(svptrue_b32(), b2_f32x);
107
+ nk_f32_t ab = nk_svaddv_f32_(svptrue_b32(), ab_f32x);
108
+ nk_f32_t a2 = nk_svaddv_f32_(svptrue_b32(), a2_f32x);
109
+ nk_f32_t b2 = nk_svaddv_f32_(svptrue_b32(), b2_f32x);
107
110
  *result = nk_angular_normalize_f32_neon_(ab, a2, b2);
108
111
  }
109
112
 
@@ -32,6 +32,7 @@
32
32
  #if NK_TARGET_SVEHALF
33
33
 
34
34
  #include "numkong/types.h"
35
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
35
36
  #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
36
37
 
37
38
  #if defined(__cplusplus)
@@ -74,7 +75,7 @@ NK_PUBLIC void nk_sqeuclidean_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const
74
75
 
75
76
  i += svcnth();
76
77
  } while (i < n);
77
- *result = svaddv_f32(svptrue_b32(), d2_f32x);
78
+ *result = nk_svaddv_f32_(svptrue_b32(), d2_f32x);
78
79
  }
79
80
 
80
81
  NK_PUBLIC void nk_euclidean_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
@@ -114,9 +115,9 @@ NK_PUBLIC void nk_angular_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const *b_
114
115
  i += svcnth();
115
116
  } while (i < n);
116
117
 
117
- nk_f32_t ab_f32 = svaddv_f32(svptrue_b32(), ab_f32x);
118
- nk_f32_t a2_f32 = svaddv_f32(svptrue_b32(), a2_f32x);
119
- nk_f32_t b2_f32 = svaddv_f32(svptrue_b32(), b2_f32x);
118
+ nk_f32_t ab_f32 = nk_svaddv_f32_(svptrue_b32(), ab_f32x);
119
+ nk_f32_t a2_f32 = nk_svaddv_f32_(svptrue_b32(), a2_f32x);
120
+ nk_f32_t b2_f32 = nk_svaddv_f32_(svptrue_b32(), b2_f32x);
120
121
  *result = nk_angular_normalize_f32_neon_(ab_f32, a2_f32, b2_f32);
121
122
  }
122
123
 
@@ -34,6 +34,7 @@
34
34
  #if NK_TARGET_SVESDOT
35
35
 
36
36
  #include "numkong/types.h"
37
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
37
38
  #include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
38
39
 
39
40
  #if defined(__cplusplus)
@@ -58,7 +59,7 @@ NK_PUBLIC void nk_sqeuclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_
58
59
  distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
59
60
  i += svcntb();
60
61
  } while (i < n);
61
- *result = (nk_u32_t)svaddv_u32(svptrue_b32(), distance_sq_u32x);
62
+ *result = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), distance_sq_u32x);
62
63
  }
63
64
  NK_PUBLIC void nk_euclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
64
65
  nk_u32_t distance_sq_u32;
@@ -81,9 +82,9 @@ NK_PUBLIC void nk_angular_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size
81
82
  i += svcntb();
82
83
  } while (i < n);
83
84
 
84
- nk_i32_t ab = (nk_i32_t)svaddv_s32(svptrue_b32(), ab_i32x);
85
- nk_i32_t a2 = (nk_i32_t)svaddv_s32(svptrue_b32(), a2_i32x);
86
- nk_i32_t b2 = (nk_i32_t)svaddv_s32(svptrue_b32(), b2_i32x);
85
+ nk_i32_t ab = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), ab_i32x);
86
+ nk_i32_t a2 = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), a2_i32x);
87
+ nk_i32_t b2 = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), b2_i32x);
87
88
  *result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
88
89
  }
89
90
 
@@ -98,7 +99,7 @@ NK_PUBLIC void nk_sqeuclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_
98
99
  distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
99
100
  i += svcntb();
100
101
  } while (i < n);
101
- *result = (nk_u32_t)svaddv_u32(svptrue_b32(), distance_sq_u32x);
102
+ *result = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), distance_sq_u32x);
102
103
  }
103
104
  NK_PUBLIC void nk_euclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
104
105
  nk_u32_t distance_sq_u32;
@@ -121,9 +122,9 @@ NK_PUBLIC void nk_angular_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size
121
122
  i += svcntb();
122
123
  } while (i < n);
123
124
 
124
- nk_u32_t ab = (nk_u32_t)svaddv_u32(svptrue_b32(), ab_u32x);
125
- nk_u32_t a2 = (nk_u32_t)svaddv_u32(svptrue_b32(), a2_u32x);
126
- nk_u32_t b2 = (nk_u32_t)svaddv_u32(svptrue_b32(), b2_u32x);
125
+ nk_u32_t ab = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), ab_u32x);
126
+ nk_u32_t a2 = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), a2_u32x);
127
+ nk_u32_t b2 = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), b2_u32x);
127
128
  *result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
128
129
  }
129
130
 
@@ -0,0 +1,173 @@
1
+ /**
2
+ * @brief Batched Spatial Distances for Granite Rapids (AMX-FP16) with AVX-512 Finalization.
3
+ * @file include/numkong/spatials/graniteamx.h
4
+ * @author Ash Vardanian
5
+ * @date April 9, 2026
6
+ *
7
+ * @sa include/numkong/spatials.h
8
+ */
9
+ #ifndef NK_SPATIALS_GRANITEAMX_H
10
+ #define NK_SPATIALS_GRANITEAMX_H
11
+
12
+ #if NK_TARGET_X8664_
13
+ #if NK_TARGET_GRANITEAMX
14
+
15
+ #include "numkong/spatial/skylake.h"
16
+ #include "numkong/spatial/serial.h"
17
+ #include "numkong/dots/graniteamx.h"
18
+
19
+ #if defined(__cplusplus)
20
+ extern "C" {
21
+ #endif
22
+
23
+ #if defined(__clang__)
24
+ #pragma clang attribute push( \
25
+ __attribute__((target( \
26
+ "avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512vbmi,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8,amx-fp16"))), \
27
+ apply_to = function)
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC push_options
30
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512vbmi", "f16c", "fma", \
31
+ "bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8", "amx-fp16")
32
+ #endif
33
+
34
+ #pragma region F16 Packed
35
+
36
+ NK_INTERNAL void nk_angulars_packed_f16_graniteamx_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
37
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
38
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
39
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
40
+ nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
41
+ for (nk_size_t row = 0; row < rows; row++) {
42
+ nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_f16_(a + row * a_stride_elements, depth);
43
+ nk_angulars_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
44
+ }
45
+ }
46
+
47
+ NK_PUBLIC void nk_angulars_packed_f16_graniteamx( //
48
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
49
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
50
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
51
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
52
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
53
+ nk_dots_packed_f16_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
54
+ nk_angulars_packed_f16_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
55
+ c_stride_elements);
56
+ }
57
+
58
+ NK_INTERNAL void nk_euclideans_packed_f16_graniteamx_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
59
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
60
+ nk_size_t a_stride_elements,
61
+ nk_size_t c_stride_elements) {
62
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
63
+ nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
64
+ for (nk_size_t row = 0; row < rows; row++) {
65
+ nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_f16_(a + row * a_stride_elements, depth);
66
+ nk_euclideans_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
67
+ }
68
+ }
69
+
70
+ NK_PUBLIC void nk_euclideans_packed_f16_graniteamx( //
71
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
72
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
73
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
74
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
75
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
76
+ nk_dots_packed_f16_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
77
+ nk_euclideans_packed_f16_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
78
+ c_stride_elements);
79
+ }
80
+
81
+ #pragma endregion F16 Packed
82
+
83
+ #pragma region F16 Symmetric
84
+
85
+ NK_INTERNAL void nk_angulars_symmetric_f16_graniteamx_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
86
+ nk_size_t depth, nk_size_t stride_elements,
87
+ nk_f32_t *result, nk_size_t result_stride_elements,
88
+ nk_size_t row_start, nk_size_t row_count) {
89
+
90
+ for (nk_size_t row = row_start; row < row_start + row_count; row++)
91
+ result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_f16_(vectors + row * stride_elements, depth);
92
+
93
+ nk_f32_t column_norms_cache[256];
94
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
95
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
96
+ for (nk_size_t col = chunk_start; col < chunk_end; col++)
97
+ column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
98
+
99
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) {
100
+ nk_f32_t *r_row = result + row * result_stride_elements;
101
+ nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
102
+ if (col_start >= chunk_end) continue;
103
+ nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
104
+ r_row[row], chunk_end - col_start);
105
+ }
106
+ }
107
+
108
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
109
+ }
110
+
111
+ NK_PUBLIC void nk_angulars_symmetric_f16_graniteamx( //
112
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
113
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
114
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
115
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
116
+ nk_dots_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
117
+ row_start, row_count);
118
+ nk_angulars_symmetric_f16_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
119
+ result_stride_elements, row_start, row_count);
120
+ }
121
+
122
+ NK_INTERNAL void nk_euclideans_symmetric_f16_graniteamx_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
123
+ nk_size_t depth, nk_size_t stride_elements,
124
+ nk_f32_t *result, nk_size_t result_stride_elements,
125
+ nk_size_t row_start, nk_size_t row_count) {
126
+
127
+ for (nk_size_t row = row_start; row < row_start + row_count; row++)
128
+ result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_f16_(vectors + row * stride_elements, depth);
129
+
130
+ nk_f32_t column_norms_cache[256];
131
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
132
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
133
+ for (nk_size_t col = chunk_start; col < chunk_end; col++)
134
+ column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
135
+
136
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) {
137
+ nk_f32_t *r_row = result + row * result_stride_elements;
138
+ nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
139
+ if (col_start >= chunk_end) continue;
140
+ nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
141
+ r_row[row], chunk_end - col_start);
142
+ }
143
+ }
144
+
145
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
146
+ }
147
+
148
+ NK_PUBLIC void nk_euclideans_symmetric_f16_graniteamx( //
149
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
150
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
151
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
152
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
153
+ nk_dots_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
154
+ row_start, row_count);
155
+ nk_euclideans_symmetric_f16_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
156
+ result_stride_elements, row_start, row_count);
157
+ }
158
+
159
+ #pragma endregion F16 Symmetric
160
+
161
+ #if defined(__clang__)
162
+ #pragma clang attribute pop
163
+ #elif defined(__GNUC__)
164
+ #pragma GCC pop_options
165
+ #endif
166
+
167
+ #if defined(__cplusplus)
168
+ } // extern "C"
169
+ #endif
170
+
171
+ #endif // NK_TARGET_GRANITEAMX
172
+ #endif // NK_TARGET_X8664_
173
+ #endif // NK_SPATIALS_GRANITEAMX_H
@@ -15,6 +15,18 @@
15
15
  extern "C" {
16
16
  #endif
17
17
 
18
+ /* Optimize serial fallbacks for size — see dots/serial.h for rationale. */
19
+ #if defined(NDEBUG)
20
+ #if defined(_MSC_VER)
21
+ #pragma optimize("s", on)
22
+ #elif defined(__clang__)
23
+ #pragma clang attribute push(__attribute__((minsize)), apply_to = function)
24
+ #elif defined(__GNUC__)
25
+ #pragma GCC push_options
26
+ #pragma GCC optimize("Os")
27
+ #endif
28
+ #endif
29
+
18
30
  nk_define_cross_normalized_packed_(angular, f64, serial, f64, f64, f64, /*norm_value_type=*/f64, f64, nk_b256_vec_t,
19
31
  nk_dots_packed_f64_serial, nk_angular_through_f64_from_dot_serial_,
20
32
  nk_dots_reduce_sumsq_f64_, nk_load_b256_serial_, nk_partial_load_b64x4_serial_,
@@ -219,6 +231,16 @@ nk_define_cross_normalized_symmetric_(euclidean, u4, serial, u4x2, u32, /*norm_v
219
231
  nk_dots_reduce_sumsq_u4_, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
220
232
  nk_store_b128_serial_, nk_partial_store_b32x4_serial_, 2)
221
233
 
234
+ #if defined(NDEBUG)
235
+ #if defined(_MSC_VER)
236
+ #pragma optimize("", on)
237
+ #elif defined(__clang__)
238
+ #pragma clang attribute pop
239
+ #elif defined(__GNUC__)
240
+ #pragma GCC pop_options
241
+ #endif
242
+ #endif
243
+
222
244
  #if defined(__cplusplus)
223
245
  } // extern "C"
224
246
  #endif