numkong 7.4.4 → 7.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. package/README.md +1 -0
  2. package/binding.gyp +81 -5
  3. package/c/dispatch_f16.c +23 -0
  4. package/c/numkong.c +0 -13
  5. package/include/numkong/attention/sme.h +34 -31
  6. package/include/numkong/capabilities.h +2 -15
  7. package/include/numkong/cast/neon.h +15 -0
  8. package/include/numkong/curved/smef64.h +82 -62
  9. package/include/numkong/dot/rvvbf16.h +1 -1
  10. package/include/numkong/dot/rvvhalf.h +1 -1
  11. package/include/numkong/dot/sve.h +6 -5
  12. package/include/numkong/dot/svebfdot.h +2 -1
  13. package/include/numkong/dot/svehalf.h +6 -5
  14. package/include/numkong/dot/svesdot.h +3 -2
  15. package/include/numkong/dots/graniteamx.h +733 -0
  16. package/include/numkong/dots/serial.h +11 -4
  17. package/include/numkong/dots/sme.h +172 -140
  18. package/include/numkong/dots/smebi32.h +14 -11
  19. package/include/numkong/dots/smef64.h +31 -26
  20. package/include/numkong/dots.h +29 -3
  21. package/include/numkong/each/serial.h +22 -0
  22. package/include/numkong/geospatial/haswell.h +1 -1
  23. package/include/numkong/geospatial/neon.h +1 -1
  24. package/include/numkong/geospatial/serial.h +1 -1
  25. package/include/numkong/geospatial/skylake.h +1 -1
  26. package/include/numkong/maxsim/sme.h +94 -55
  27. package/include/numkong/mesh/README.md +13 -27
  28. package/include/numkong/mesh/haswell.h +25 -122
  29. package/include/numkong/mesh/neon.h +21 -110
  30. package/include/numkong/mesh/neonbfdot.h +4 -43
  31. package/include/numkong/mesh/rvv.h +7 -82
  32. package/include/numkong/mesh/serial.h +48 -53
  33. package/include/numkong/mesh/skylake.h +7 -123
  34. package/include/numkong/mesh/v128relaxed.h +9 -93
  35. package/include/numkong/mesh.h +2 -2
  36. package/include/numkong/mesh.hpp +35 -96
  37. package/include/numkong/reduce/neon.h +29 -0
  38. package/include/numkong/reduce/neonbfdot.h +2 -2
  39. package/include/numkong/reduce/neonfhm.h +4 -4
  40. package/include/numkong/reduce/sve.h +52 -0
  41. package/include/numkong/reduce.h +4 -0
  42. package/include/numkong/set/sve.h +6 -5
  43. package/include/numkong/sets/smebi32.h +35 -30
  44. package/include/numkong/sparse/sve2.h +3 -2
  45. package/include/numkong/spatial/sve.h +7 -6
  46. package/include/numkong/spatial/svebfdot.h +7 -4
  47. package/include/numkong/spatial/svehalf.h +5 -4
  48. package/include/numkong/spatial/svesdot.h +9 -8
  49. package/include/numkong/spatials/graniteamx.h +173 -0
  50. package/include/numkong/spatials/serial.h +22 -0
  51. package/include/numkong/spatials/sme.h +391 -350
  52. package/include/numkong/spatials/smef64.h +79 -70
  53. package/include/numkong/spatials.h +37 -4
  54. package/include/numkong/types.h +59 -0
  55. package/javascript/dist/cjs/numkong.js +13 -0
  56. package/javascript/dist/esm/numkong.js +13 -0
  57. package/javascript/numkong.c +56 -12
  58. package/javascript/numkong.ts +13 -0
  59. package/package.json +7 -7
  60. package/probes/probe.js +2 -2
  61. package/wasm/numkong.wasm +0 -0
@@ -39,9 +39,9 @@ extern "C" {
39
39
  * BMOPA gives matching = popcount(XNOR(a,b)).
40
40
  * dot(a,b) = popcount(a AND b) = (pop_a + pop_b - depth_bits + matching) / 2
41
41
  */
42
- __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_streaming_(
42
+ __arm_new("za") static void nk_dots_packed_u1_smebi32_streaming_( //
43
43
  nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
44
- nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
44
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
45
45
 
46
46
  nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
47
47
  nk_size_t const row_tile_count_b = header->row_tile_count;
@@ -204,20 +204,22 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
204
204
  }
205
205
  }
206
206
 
207
- NK_PUBLIC void nk_dots_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a,
208
- nk_size_t row_count_b, nk_size_t depth_bits, nk_size_t a_stride_in_bytes,
209
- nk_size_t c_stride_in_bytes) {
207
+ NK_PUBLIC void nk_dots_packed_u1_smebi32( //
208
+ nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
209
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
210
+ nk_sme_start_streaming_();
210
211
  nk_dots_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
211
212
  c_stride_in_bytes);
213
+ nk_sme_stop_streaming_();
212
214
  }
213
215
 
214
216
  /**
215
217
  * Symmetric u1 dot-product using ZA0 time-sharing + 3-tile fast path.
216
218
  * Same ZA transpose pattern as hammings_symmetric, but with dot extraction.
217
219
  */
218
- __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32_streaming_(
220
+ __arm_new("za") static void nk_dots_symmetric_u1_smebi32_streaming_( //
219
221
  nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
220
- nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
222
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
221
223
 
222
224
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
223
225
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
@@ -451,12 +453,13 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32
451
453
  }
452
454
  }
453
455
 
454
- NK_PUBLIC void nk_dots_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
455
- nk_size_t stride_in_bytes, nk_u32_t *result,
456
- nk_size_t result_stride_in_bytes, nk_size_t row_start,
457
- nk_size_t row_count) {
456
+ NK_PUBLIC void nk_dots_symmetric_u1_smebi32( //
457
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
458
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
459
+ nk_sme_start_streaming_();
458
460
  nk_dots_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
459
461
  result_stride_in_bytes, row_start, row_count);
462
+ nk_sme_stop_streaming_();
460
463
  }
461
464
 
462
465
  #if defined(__clang__)
@@ -153,9 +153,9 @@ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_
153
153
  }
154
154
  }
155
155
 
156
- __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_streaming_(
156
+ __arm_new("za") static void nk_dots_packed_f32_smef64_streaming_( //
157
157
  nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
158
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
158
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
159
159
 
160
160
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
161
161
  nk_size_t const column_tile_count = header->column_tile_count;
@@ -390,14 +390,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
390
390
  }
391
391
  }
392
392
 
393
- NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
394
- nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
395
- nk_size_t c_stride_in_bytes) {
393
+ NK_PUBLIC void nk_dots_packed_f32_smef64( //
394
+ nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
395
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
396
396
 
397
397
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
398
398
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
399
399
 
400
+ nk_sme_start_streaming_();
400
401
  nk_dots_packed_f32_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
402
+ nk_sme_stop_streaming_();
401
403
  }
402
404
 
403
405
  /**
@@ -406,9 +408,9 @@ NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed
406
408
  * pre-reads A columns into Z registers, then reloads ZA0 with widened B data
407
409
  * per column tile. Eliminates all scalar B-packing loops.
408
410
  */
409
- __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64_streaming_(
411
+ __arm_new("za") static void nk_dots_symmetric_f32_smef64_streaming_( //
410
412
  nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
411
- nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
413
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
412
414
 
413
415
  nk_size_t const tile_dimension = svcntd(); // 8 for SVL=512
414
416
  nk_size_t const depth_tile_size = svcntw(); // 16 for SVL=512
@@ -721,15 +723,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64
721
723
  }
722
724
  }
723
725
 
724
- NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
725
- nk_size_t stride_in_bytes, nk_f64_t *result,
726
- nk_size_t result_stride_in_bytes, nk_size_t row_start,
727
- nk_size_t row_count) {
726
+ NK_PUBLIC void nk_dots_symmetric_f32_smef64( //
727
+ nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
728
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
728
729
 
729
730
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
730
731
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
732
+ nk_sme_start_streaming_();
731
733
  nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
732
734
  result_stride_elements, row_start, row_count);
735
+ nk_sme_stop_streaming_();
733
736
  }
734
737
 
735
738
  #pragma endregion F32 Floats
@@ -783,17 +786,16 @@ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t v
783
786
  *
784
787
  * All slices fit in f32 (24-bit significand). Products: max 19+19 = 38 ≤ 53, exact in f64.
785
788
  */
786
- NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void) NK_STREAMING_ {
789
+ NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void) {
787
790
  return 0xFFFFFFFC00000000ULL; // keep top 19 sig bits
788
791
  }
789
- NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void) NK_STREAMING_ {
792
+ NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void) {
790
793
  return 0xFFFFFFF000000000ULL; // keep top 17 sig bits
791
794
  }
792
795
 
793
796
  /* Split a scalar f64 into 3 non-overlapping Ozaki slices (19+17+17 mantissa bits).
794
797
  * Each slice fits in f32. Outputs stored via pointers. */
795
- NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, nk_f64_t *slice_1,
796
- nk_f64_t *slice_2) NK_STREAMING_ {
798
+ NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, nk_f64_t *slice_1, nk_f64_t *slice_2) {
797
799
  nk_fui64_t pun;
798
800
  pun.f = val;
799
801
  pun.u &= nk_f64_smef64_ozaki_mask_19_bits_();
@@ -805,9 +807,9 @@ NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, n
805
807
  *slice_2 = residual - *slice_1;
806
808
  }
807
809
 
808
- __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64_streaming_(
810
+ __arm_new("za") static void nk_dots_symmetric_f64_smef64_streaming_( //
809
811
  nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
810
- nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
812
+ nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
811
813
 
812
814
  nk_size_t const tile_dimension = svcntd();
813
815
  nk_size_t const depth_steps_per_batch = tile_dimension;
@@ -929,15 +931,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64
929
931
  }
930
932
  }
931
933
 
932
- NK_PUBLIC void nk_dots_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
933
- nk_size_t stride_in_bytes, nk_f64_t *result,
934
- nk_size_t result_stride_in_bytes, nk_size_t row_start,
935
- nk_size_t row_count) {
934
+ NK_PUBLIC void nk_dots_symmetric_f64_smef64( //
935
+ nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
936
+ nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
936
937
 
937
938
  nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
938
939
  nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
940
+ nk_sme_start_streaming_();
939
941
  nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
940
942
  result_stride_elements, row_start, row_count);
943
+ nk_sme_stop_streaming_();
941
944
  }
942
945
 
943
946
  NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t columns, nk_size_t depth) {
@@ -1018,9 +1021,9 @@ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_
1018
1021
  }
1019
1022
  }
1020
1023
 
1021
- __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_streaming_(
1024
+ __arm_new("za") static void nk_dots_packed_f64_smef64_streaming_( //
1022
1025
  nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1023
- nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
1026
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
1024
1027
 
1025
1028
  // Read header
1026
1029
  nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
@@ -1296,14 +1299,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
1296
1299
  }
1297
1300
  }
1298
1301
 
1299
- NK_PUBLIC void nk_dots_packed_f64_smef64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows,
1300
- nk_size_t columns, nk_size_t depth, nk_size_t a_stride_in_bytes,
1301
- nk_size_t c_stride_in_bytes) {
1302
+ NK_PUBLIC void nk_dots_packed_f64_smef64( //
1303
+ nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
1304
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
1302
1305
 
1303
1306
  nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f64_t);
1304
1307
  nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
1305
1308
 
1309
+ nk_sme_start_streaming_();
1306
1310
  nk_dots_packed_f64_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
1311
+ nk_sme_stop_streaming_();
1307
1312
  }
1308
1313
 
1309
1314
  #pragma endregion F64 Floats
@@ -681,6 +681,25 @@ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx(nk_u8_t const *vectors, nk_size_
681
681
  nk_size_t row_start, nk_size_t row_count);
682
682
  #endif // NK_TARGET_SAPPHIREAMX
683
683
 
684
+ /* Granite Rapids backends using Intel AMX-FP16 (Advanced Matrix Extensions with FP16 support).
685
+ * AMX-FP16 adds TDPFP16PS (FP16×FP16→FP32 tile multiply-accumulate), same tile geometry as BF16.
686
+ * The F32 Ozaki kernel splits F32 inputs into 2 FP16 halves for ~35-40 bit effective precision.
687
+ */
688
+ #if NK_TARGET_GRANITEAMX
689
+ /** @copydoc nk_dots_packed_size_f16 */
690
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_graniteamx(nk_size_t width, nk_size_t depth);
691
+ /** @copydoc nk_dots_pack_f16 */
692
+ NK_PUBLIC void nk_dots_pack_f16_graniteamx(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
693
+ void *b_packed);
694
+ /** @copydoc nk_dots_packed_f16 */
695
+ NK_PUBLIC void nk_dots_packed_f16_graniteamx(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
696
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
697
+ /** @copydoc nk_dots_symmetric_f16 */
698
+ NK_PUBLIC void nk_dots_symmetric_f16_graniteamx(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
699
+ nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
700
+ nk_size_t row_start, nk_size_t row_count);
701
+ #endif // NK_TARGET_GRANITEAMX
702
+
684
703
  /* ARM SME backends using Scalable Matrix Extension.
685
704
  * SME provides ZA tile registers for outer product operations.
686
705
  * F16/BF16/I8/U8/E4M3 use ZA32 tiles, F32/F64 use ZA64 tiles (FEAT_SME_F64F64).
@@ -1858,6 +1877,7 @@ NK_PUBLIC void nk_dots_symmetric_u1_loongsonasx(nk_u1x8_t const *vectors, nk_siz
1858
1877
  #include "numkong/dots/genoa.h"
1859
1878
  #include "numkong/dots/diamond.h"
1860
1879
  #include "numkong/dots/sapphireamx.h"
1880
+ #include "numkong/dots/graniteamx.h"
1861
1881
  #include "numkong/dots/neon.h"
1862
1882
  #include "numkong/dots/neonsdot.h"
1863
1883
  #include "numkong/dots/neonfhm.h"
@@ -2002,7 +2022,9 @@ NK_PUBLIC void nk_dots_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f6
2002
2022
  }
2003
2023
 
2004
2024
  NK_PUBLIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth) {
2005
- #if NK_TARGET_SME
2025
+ #if NK_TARGET_GRANITEAMX
2026
+ return nk_dots_packed_size_f16_graniteamx(width, depth);
2027
+ #elif NK_TARGET_SME
2006
2028
  return nk_dots_packed_size_f16_sme(width, depth);
2007
2029
  #elif NK_TARGET_NEONFHM
2008
2030
  return nk_dots_packed_size_f16_neonfhm(width, depth);
@@ -2023,7 +2045,9 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth) {
2023
2045
 
2024
2046
  NK_PUBLIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
2025
2047
  void *b_packed) {
2026
- #if NK_TARGET_SME
2048
+ #if NK_TARGET_GRANITEAMX
2049
+ nk_dots_pack_f16_graniteamx(b, width, depth, b_stride, b_packed);
2050
+ #elif NK_TARGET_SME
2027
2051
  nk_dots_pack_f16_sme(b, width, depth, b_stride, b_packed);
2028
2052
  #elif NK_TARGET_NEONFHM
2029
2053
  nk_dots_pack_f16_neonfhm(b, width, depth, b_stride, b_packed);
@@ -2044,7 +2068,9 @@ NK_PUBLIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t de
2044
2068
 
2045
2069
  NK_PUBLIC void nk_dots_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
2046
2070
  nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
2047
- #if NK_TARGET_SME
2071
+ #if NK_TARGET_GRANITEAMX
2072
+ nk_dots_packed_f16_graniteamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
2073
+ #elif NK_TARGET_SME
2048
2074
  nk_dots_packed_f16_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
2049
2075
  #elif NK_TARGET_NEONFHM
2050
2076
  nk_dots_packed_f16_neonfhm(a, b_packed, c, height, width, depth, a_stride, c_stride);
@@ -76,6 +76,18 @@ extern "C" {
76
76
  } \
77
77
  }
78
78
 
79
+ /* Optimize serial fallbacks for size — see dots/serial.h for rationale. */
80
+ #if defined(NDEBUG)
81
+ #if defined(_MSC_VER)
82
+ #pragma optimize("s", on)
83
+ #elif defined(__clang__)
84
+ #pragma clang attribute push(__attribute__((minsize)), apply_to = function)
85
+ #elif defined(__GNUC__)
86
+ #pragma GCC push_options
87
+ #pragma GCC optimize("Os")
88
+ #endif
89
+ #endif
90
+
79
91
  nk_define_each_sum_(f64, f64, nk_assign_from_to_, nk_assign_from_to_) // nk_each_sum_f64_serial
80
92
  nk_define_each_sum_(f32, f32, nk_assign_from_to_, nk_assign_from_to_) // nk_each_sum_f32_serial
81
93
  nk_define_each_sum_(f16, f32, nk_f16_to_f32_serial, nk_f32_to_f16_serial) // nk_each_sum_f16_serial
@@ -253,6 +265,16 @@ NK_PUBLIC void nk_each_fma_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, n
253
265
  }
254
266
  }
255
267
 
268
+ #if defined(NDEBUG)
269
+ #if defined(_MSC_VER)
270
+ #pragma optimize("", on)
271
+ #elif defined(__clang__)
272
+ #pragma clang attribute pop
273
+ #elif defined(__GNUC__)
274
+ #pragma GCC pop_options
275
+ #endif
276
+ #endif
277
+
256
278
  #if defined(__cplusplus)
257
279
  } // extern "C"
258
280
  #endif
@@ -24,7 +24,7 @@
24
24
  #if NK_TARGET_HASWELL
25
25
 
26
26
  #include "numkong/types.h"
27
- #include "numkong/trigonometry/haswell.h" // `nk_sin_f64x4_haswell_`, `nk_cos_f64x4_haswell_`, `nk_atan2_f64x4_haswell_`, etc.
27
+ #include "numkong/trigonometry/haswell.h" // `nk_sin_f64x4_haswell_`, `nk_cos_f64x4_haswell_`, `nk_atan2_f64x4_haswell_`
28
28
 
29
29
  #if defined(__cplusplus)
30
30
  extern "C" {
@@ -21,7 +21,7 @@
21
21
  #if NK_TARGET_NEON
22
22
 
23
23
  #include "numkong/types.h"
24
- #include "numkong/trigonometry/neon.h" // `nk_sin_f64x2_neon_`, `nk_cos_f64x2_neon_`, `nk_atan2_f64x2_neon_`, etc.
24
+ #include "numkong/trigonometry/neon.h" // `nk_sin_f64x2_neon_`, `nk_cos_f64x2_neon_`, `nk_atan2_f64x2_neon_`
25
25
 
26
26
  #if defined(__cplusplus)
27
27
  extern "C" {
@@ -11,7 +11,7 @@
11
11
 
12
12
  #include "numkong/types.h"
13
13
  #include "numkong/spatial/serial.h" // `nk_f64_sqrt_serial`, `nk_f32_sqrt_serial`
14
- #include "numkong/trigonometry/serial.h" // `nk_f64_sin`, `nk_f64_cos`, `nk_f64_atan2`, etc.
14
+ #include "numkong/trigonometry/serial.h" // `nk_f64_sin`, `nk_f64_cos`, `nk_f64_atan2`
15
15
 
16
16
  #if defined(__cplusplus)
17
17
  extern "C" {
@@ -24,7 +24,7 @@
24
24
  #if NK_TARGET_SKYLAKE
25
25
 
26
26
  #include "numkong/types.h"
27
- #include "numkong/trigonometry/skylake.h" // `nk_sin_f64x8_skylake_`, `nk_cos_f64x8_skylake_`, `nk_atan2_f64x8_skylake_`, etc.
27
+ #include "numkong/trigonometry/skylake.h" // `nk_sin_f64x8_skylake_`, `nk_cos_f64x8_skylake_`, `nk_atan2_f64x8_skylake_`
28
28
 
29
29
  #if defined(__cplusplus)
30
30
  extern "C" {
@@ -46,7 +46,8 @@
46
46
  #if NK_TARGET_ARM64_
47
47
  #if NK_TARGET_SME
48
48
 
49
- #include "numkong/dots/sme.h" // nk_dots_sme_packed_header_t, nk_dots_pack_{f16,bf16}_sme, nk_dots_packed_size_{f16,bf16}_sme
49
+ #include "numkong/dots/sme.h" // `nk_dots_sme_packed_header_t`
50
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
50
51
 
51
52
  #if defined(__cplusplus)
52
53
  extern "C" {
@@ -90,10 +91,9 @@ NK_STATIC_ASSERT(sizeof(nk_maxsim_sme_packed_header_t) == 64, nk_maxsim_sme_pack
90
91
  *
91
92
  * 1-tile remainder: uses ZA0 only, with predicated loads for partial tiles.
92
93
  */
93
- __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streaming_( //
94
- void const *query_packed, void const *document_packed, //
95
- nk_size_t query_count, nk_size_t document_count, //
96
- nk_size_t depth, nk_f32_t *result) {
94
+ __arm_new("za") static void nk_maxsim_packed_f16_streaming_( //
95
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
96
+ nk_size_t depth, nk_f32_t *result) NK_STREAMING_ {
97
97
 
98
98
  nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
99
99
  nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
@@ -258,18 +258,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
258
258
  document_inverse_norms_f32x);
259
259
  svfloat32_t angular_distance_f32x = svmax_f32_x(
260
260
  row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
261
- total_angular_distance += svaddv_f32(row_predicate_b32x, angular_distance_f32x);
261
+ total_angular_distance += nk_svaddv_f32_(row_predicate_b32x, angular_distance_f32x);
262
262
  }
263
263
 
264
264
  *result = total_angular_distance;
265
265
  }
266
266
 
267
- NK_PUBLIC void nk_maxsim_packed_f16_sme( //
268
- void const *query_packed, void const *document_packed, //
269
- nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
270
- nk_f32_t *result) { //
267
+ NK_PUBLIC void nk_maxsim_packed_f16_sme( //
268
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
269
+ nk_size_t depth, nk_f32_t *result) {
271
270
 
271
+ nk_sme_start_streaming_();
272
272
  nk_maxsim_packed_f16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
273
+ nk_sme_stop_streaming_();
273
274
  }
274
275
 
275
276
  /**
@@ -281,10 +282,9 @@ NK_PUBLIC void nk_maxsim_packed_f16_sme( //
281
282
  *
282
283
  * 1-tile remainder: uses ZA0 only, with predicated loads for partial tiles.
283
284
  */
284
- __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_streaming_( //
285
- void const *query_packed, void const *document_packed, //
286
- nk_size_t query_count, nk_size_t document_count, //
287
- nk_size_t depth, nk_f32_t *result) {
285
+ __arm_new("za") static void nk_maxsim_packed_bf16_streaming_( //
286
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
287
+ nk_size_t depth, nk_f32_t *result) NK_STREAMING_ {
288
288
 
289
289
  nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
290
290
  nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
@@ -454,18 +454,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
454
454
  document_inverse_norms_f32x);
455
455
  svfloat32_t angular_distance_f32x = svmax_f32_x(
456
456
  row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
457
- total_angular_distance += svaddv_f32(row_predicate_b32x, angular_distance_f32x);
457
+ total_angular_distance += nk_svaddv_f32_(row_predicate_b32x, angular_distance_f32x);
458
458
  }
459
459
 
460
460
  *result = total_angular_distance;
461
461
  }
462
462
 
463
- NK_PUBLIC void nk_maxsim_packed_bf16_sme( //
464
- void const *query_packed, void const *document_packed, //
465
- nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
466
- nk_f32_t *result) { //
463
+ NK_PUBLIC void nk_maxsim_packed_bf16_sme( //
464
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
465
+ nk_size_t depth, nk_f32_t *result) {
467
466
 
467
+ nk_sme_start_streaming_();
468
468
  nk_maxsim_packed_bf16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
469
+ nk_sme_stop_streaming_();
469
470
  }
470
471
 
471
472
  NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sme(nk_size_t columns, nk_size_t depth) { //
@@ -649,7 +650,47 @@ NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
649
650
  svfloat64_t b_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, b_f32x);
650
651
  accumulator_odd_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_odd_f64x, a_odd_f64x, b_odd_f64x);
651
652
  }
652
- return svaddv_f64(svptrue_b64(), accumulator_even_f64x) + svaddv_f64(svptrue_b64(), accumulator_odd_f64x);
653
+ return nk_svaddv_f64_(svptrue_b64(), accumulator_even_f64x) + nk_svaddv_f64_(svptrue_b64(), accumulator_odd_f64x);
654
+ }
655
+
656
+ /**
657
+ * Streaming-compatible angular distance accumulation from pre-reduced dot products
658
+ * and contiguous f64 norm arrays.
659
+ * Computes rsqrt via Newton-Raphson and accumulates `1 - dot / sqrt(||q||^2 * ||d||^2)`.
660
+ */
661
+ NK_PUBLIC nk_f64_t nk_maxsim_angular_from_dots_ssve_( //
662
+ nk_f64_t const *dot_products, nk_size_t count, //
663
+ nk_f64_t const *query_norms_f64, nk_f64_t const *document_norms_f64) NK_STREAMING_ { //
664
+
665
+ nk_f64_t total_angular_distance_f64 = 0.0;
666
+ nk_size_t const vector_length = svcntd();
667
+ for (nk_size_t i = 0; i < count; i += vector_length) {
668
+ svbool_t predicate_b64x = svwhilelt_b64_u64(i, count);
669
+ svfloat64_t dot_products_f64x = svld1_f64(predicate_b64x, dot_products + i);
670
+ svfloat64_t query_norms_f64x = svld1_f64(predicate_b64x, query_norms_f64 + i);
671
+ svfloat64_t document_norms_f64x = svld1_f64(predicate_b64x, document_norms_f64 + i);
672
+
673
+ // norm_product = query_norm * document_norm
674
+ svfloat64_t norm_products_f64x = svmul_f64_x(predicate_b64x, query_norms_f64x, document_norms_f64x);
675
+
676
+ // Newton-Raphson rsqrt: estimate then two refinement steps
677
+ svfloat64_t rsqrt_f64x = svrsqrte_f64(norm_products_f64x);
678
+ rsqrt_f64x = svmul_f64_x(predicate_b64x, rsqrt_f64x,
679
+ svrsqrts_f64(svmul_f64_x(predicate_b64x, norm_products_f64x, rsqrt_f64x), rsqrt_f64x));
680
+ rsqrt_f64x = svmul_f64_x(predicate_b64x, rsqrt_f64x,
681
+ svrsqrts_f64(svmul_f64_x(predicate_b64x, norm_products_f64x, rsqrt_f64x), rsqrt_f64x));
682
+
683
+ // cosine = dot_product * rsqrt(norm_product), zeroed where norm <= 0
684
+ svbool_t positive_b64x = svcmpgt_f64(predicate_b64x, norm_products_f64x, svdup_n_f64(0.0));
685
+ svfloat64_t cosine_f64x = svmul_f64_z(positive_b64x, dot_products_f64x, rsqrt_f64x);
686
+
687
+ // angular_distance = max(0, 1 - cosine)
688
+ svfloat64_t angular_distance_f64x = svsub_f64_x(predicate_b64x, svdup_f64(1.0), cosine_f64x);
689
+ angular_distance_f64x = svmax_f64_x(predicate_b64x, angular_distance_f64x, svdup_f64(0.0));
690
+
691
+ total_angular_distance_f64 += nk_svaddv_f64_(predicate_b64x, angular_distance_f64x);
692
+ }
693
+ return total_angular_distance_f64;
653
694
  }
654
695
 
655
696
  /**
@@ -661,10 +702,9 @@ NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
661
702
  * Refinement: tile-wide interleaved f64 dot products for the winning (query, document) pairs.
662
703
  * Angular distance: 1 - dot / sqrt(||q||^2 * ||d||^2), accumulated with f64.
663
704
  */
664
- __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streaming_( //
665
- void const *query_packed, void const *document_packed, //
666
- nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
667
- nk_f64_t *result) {
705
+ __arm_new("za") static void nk_maxsim_packed_f32_streaming_( //
706
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
707
+ nk_size_t depth, nk_f64_t *result) NK_STREAMING_ {
668
708
 
669
709
  nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
670
710
  nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
@@ -895,48 +935,47 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
895
935
  svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_3_f32x));
896
936
  }
897
937
 
898
- // Reduce accumulators and compute angular distance per row
899
- svfloat64_t *batch_accumulators[] = {&accumulator_0_f64x, &accumulator_1_f64x, &accumulator_2_f64x,
900
- &accumulator_3_f64x};
901
- for (nk_size_t batch_index = 0; batch_index < 4; batch_index++) {
902
- nk_size_t query_index = row_start + row_batch_start + batch_index;
903
- nk_u32_t best_document_index = best_document_indices[row_batch_start + batch_index];
904
- nk_f64_t dot_product_f64 = svaddv_f64(svptrue_b64(), *batch_accumulators[batch_index]);
905
- nk_f64_t norm_product_f64 = (nk_f64_t)query_norms[query_index] *
906
- (nk_f64_t)document_norms[best_document_index];
907
- nk_f64_t cosine_f64 = (norm_product_f64 > 0.0) ? dot_product_f64 * nk_f64_rsqrt_serial(norm_product_f64)
908
- : 0.0;
909
- nk_f64_t angular_distance_f64 = 1.0 - cosine_f64;
910
- if (angular_distance_f64 < 0.0) angular_distance_f64 = 0.0;
911
- total_angular_distance_f64 += angular_distance_f64;
938
+ // Reduce SVE accumulators to scalars and compute angular distances
939
+ nk_f64_t dot_products_f64[4];
940
+ dot_products_f64[0] = nk_svaddv_f64_(svptrue_b64(), accumulator_0_f64x);
941
+ dot_products_f64[1] = nk_svaddv_f64_(svptrue_b64(), accumulator_1_f64x);
942
+ dot_products_f64[2] = nk_svaddv_f64_(svptrue_b64(), accumulator_2_f64x);
943
+ dot_products_f64[3] = nk_svaddv_f64_(svptrue_b64(), accumulator_3_f64x);
944
+ nk_f64_t batch_query_norms_f64[4], batch_document_norms_f64[4];
945
+ for (nk_size_t i = 0; i < 4; i++) {
946
+ batch_query_norms_f64[i] = (nk_f64_t)query_norms[row_start + row_batch_start + i];
947
+ batch_document_norms_f64[i] = (nk_f64_t)document_norms[best_document_indices[row_batch_start + i]];
912
948
  }
949
+ total_angular_distance_f64 += nk_maxsim_angular_from_dots_ssve_(dot_products_f64, 4, batch_query_norms_f64,
950
+ batch_document_norms_f64);
913
951
  }
914
952
 
915
- // Remainder: 1 row at a time
916
- for (; row_batch_start < rows_remaining; row_batch_start++) {
917
- nk_size_t query_index = row_start + row_batch_start;
918
- nk_u32_t best_document_index = best_document_indices[row_batch_start];
919
- nk_f64_t dot_product_f64 = nk_maxsim_reduce_dot_f32_ssve_(query_original_ptrs[row_batch_start],
920
- document_original_ptrs[row_batch_start], depth);
921
- nk_f64_t norm_product_f64 = (nk_f64_t)query_norms[query_index] *
922
- (nk_f64_t)document_norms[best_document_index];
923
- nk_f64_t cosine_f64 = (norm_product_f64 > 0.0) ? dot_product_f64 * nk_f64_rsqrt_serial(norm_product_f64)
924
- : 0.0;
925
- nk_f64_t angular_distance_f64 = 1.0 - cosine_f64;
926
- if (angular_distance_f64 < 0.0) angular_distance_f64 = 0.0;
927
- total_angular_distance_f64 += angular_distance_f64;
953
+ // Remainder: compute dot products then batch the angular distance
954
+ nk_size_t remainder_count = rows_remaining - row_batch_start;
955
+ if (remainder_count > 0) {
956
+ nk_f64_t remainder_dot_products_f64[3];
957
+ nk_f64_t remainder_query_norms_f64[3], remainder_document_norms_f64[3];
958
+ for (nk_size_t i = 0; i < remainder_count; i++) {
959
+ remainder_dot_products_f64[i] = nk_maxsim_reduce_dot_f32_ssve_(
960
+ query_original_ptrs[row_batch_start + i], document_original_ptrs[row_batch_start + i], depth);
961
+ remainder_query_norms_f64[i] = (nk_f64_t)query_norms[row_start + row_batch_start + i];
962
+ remainder_document_norms_f64[i] = (nk_f64_t)document_norms[best_document_indices[row_batch_start + i]];
963
+ }
964
+ total_angular_distance_f64 += nk_maxsim_angular_from_dots_ssve_(
965
+ remainder_dot_products_f64, remainder_count, remainder_query_norms_f64, remainder_document_norms_f64);
928
966
  }
929
967
  }
930
968
 
931
969
  *result = total_angular_distance_f64;
932
970
  }
933
971
 
934
- NK_PUBLIC void nk_maxsim_packed_f32_sme( //
935
- void const *query_packed, void const *document_packed, //
936
- nk_size_t query_count, nk_size_t document_count, nk_size_t depth, //
937
- nk_f64_t *result) { //
972
+ NK_PUBLIC void nk_maxsim_packed_f32_sme( //
973
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
974
+ nk_size_t depth, nk_f64_t *result) {
938
975
 
976
+ nk_sme_start_streaming_();
939
977
  nk_maxsim_packed_f32_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
978
+ nk_sme_stop_streaming_();
940
979
  }
941
980
 
942
981
  #if defined(__clang__)