numkong 7.4.5 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. package/README.md +1 -0
  2. package/binding.gyp +99 -5
  3. package/c/dispatch_e5m2.c +23 -3
  4. package/c/dispatch_f16.c +23 -0
  5. package/c/numkong.c +0 -13
  6. package/include/numkong/attention/sme.h +34 -31
  7. package/include/numkong/capabilities.h +2 -15
  8. package/include/numkong/cast/README.md +3 -0
  9. package/include/numkong/cast/haswell.h +28 -64
  10. package/include/numkong/cast/neon.h +15 -0
  11. package/include/numkong/cast/serial.h +17 -0
  12. package/include/numkong/cast/skylake.h +67 -52
  13. package/include/numkong/cast.h +1 -0
  14. package/include/numkong/curved/smef64.h +82 -62
  15. package/include/numkong/dot/README.md +1 -0
  16. package/include/numkong/dot/haswell.h +92 -13
  17. package/include/numkong/dot/rvvbf16.h +1 -1
  18. package/include/numkong/dot/rvvhalf.h +1 -1
  19. package/include/numkong/dot/serial.h +15 -0
  20. package/include/numkong/dot/skylake.h +61 -14
  21. package/include/numkong/dot/sve.h +6 -5
  22. package/include/numkong/dot/svebfdot.h +2 -1
  23. package/include/numkong/dot/svehalf.h +6 -5
  24. package/include/numkong/dot/svesdot.h +3 -2
  25. package/include/numkong/dots/README.md +2 -0
  26. package/include/numkong/dots/graniteamx.h +1167 -0
  27. package/include/numkong/dots/haswell.h +28 -28
  28. package/include/numkong/dots/sapphireamx.h +1 -1
  29. package/include/numkong/dots/serial.h +33 -11
  30. package/include/numkong/dots/skylake.h +28 -23
  31. package/include/numkong/dots/sme.h +172 -140
  32. package/include/numkong/dots/smebi32.h +14 -11
  33. package/include/numkong/dots/smef64.h +31 -26
  34. package/include/numkong/dots.h +41 -3
  35. package/include/numkong/each/serial.h +39 -0
  36. package/include/numkong/geospatial/haswell.h +1 -1
  37. package/include/numkong/geospatial/neon.h +1 -1
  38. package/include/numkong/geospatial/serial.h +15 -4
  39. package/include/numkong/geospatial/skylake.h +1 -1
  40. package/include/numkong/maxsim/serial.h +15 -0
  41. package/include/numkong/maxsim/sme.h +34 -33
  42. package/include/numkong/mesh/README.md +50 -44
  43. package/include/numkong/mesh/genoa.h +462 -0
  44. package/include/numkong/mesh/haswell.h +806 -933
  45. package/include/numkong/mesh/neon.h +871 -943
  46. package/include/numkong/mesh/neonbfdot.h +382 -522
  47. package/include/numkong/mesh/neonfhm.h +676 -0
  48. package/include/numkong/mesh/rvv.h +404 -319
  49. package/include/numkong/mesh/serial.h +225 -161
  50. package/include/numkong/mesh/skylake.h +1029 -1585
  51. package/include/numkong/mesh/v128relaxed.h +403 -377
  52. package/include/numkong/mesh.h +38 -0
  53. package/include/numkong/reduce/neon.h +29 -0
  54. package/include/numkong/reduce/neonbfdot.h +2 -2
  55. package/include/numkong/reduce/neonfhm.h +4 -4
  56. package/include/numkong/reduce/serial.h +15 -1
  57. package/include/numkong/reduce/sve.h +52 -0
  58. package/include/numkong/reduce.h +4 -0
  59. package/include/numkong/set/sve.h +6 -5
  60. package/include/numkong/sets/smebi32.h +35 -30
  61. package/include/numkong/sparse/serial.h +17 -2
  62. package/include/numkong/sparse/sve2.h +3 -2
  63. package/include/numkong/spatial/genoa.h +0 -68
  64. package/include/numkong/spatial/haswell.h +98 -56
  65. package/include/numkong/spatial/serial.h +15 -0
  66. package/include/numkong/spatial/skylake.h +114 -54
  67. package/include/numkong/spatial/sve.h +7 -6
  68. package/include/numkong/spatial/svebfdot.h +7 -4
  69. package/include/numkong/spatial/svehalf.h +5 -4
  70. package/include/numkong/spatial/svesdot.h +9 -8
  71. package/include/numkong/spatial.h +0 -12
  72. package/include/numkong/spatials/graniteamx.h +301 -0
  73. package/include/numkong/spatials/serial.h +39 -0
  74. package/include/numkong/spatials/skylake.h +2 -2
  75. package/include/numkong/spatials/sme.h +391 -350
  76. package/include/numkong/spatials/smef64.h +79 -70
  77. package/include/numkong/spatials.h +54 -4
  78. package/include/numkong/tensor.hpp +107 -23
  79. package/include/numkong/types.h +59 -0
  80. package/javascript/dist/cjs/numkong.js +13 -0
  81. package/javascript/dist/esm/numkong.js +13 -0
  82. package/javascript/numkong.c +59 -14
  83. package/javascript/numkong.ts +13 -0
  84. package/package.json +7 -7
  85. package/probes/probe.js +2 -2
  86. package/wasm/numkong.wasm +0 -0
@@ -604,12 +604,6 @@ NK_PUBLIC void nk_euclidean_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, n
604
604
  NK_PUBLIC void nk_sqeuclidean_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
605
605
  /** @copydoc nk_angular_f64 */
606
606
  NK_PUBLIC void nk_angular_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
607
- /** @copydoc nk_euclidean_f64 */
608
- NK_PUBLIC void nk_euclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
609
- /** @copydoc nk_sqeuclidean_f64 */
610
- NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
611
- /** @copydoc nk_angular_f64 */
612
- NK_PUBLIC void nk_angular_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
613
607
  #endif // NK_TARGET_GENOA
614
608
 
615
609
  #if NK_TARGET_DIAMOND
@@ -1263,8 +1257,6 @@ NK_PUBLIC void nk_euclidean_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size
1263
1257
  nk_euclidean_e5m2_neonfp8(a, b, n, result);
1264
1258
  #elif NK_TARGET_DIAMOND
1265
1259
  nk_euclidean_e5m2_diamond(a, b, n, result);
1266
- #elif NK_TARGET_GENOA
1267
- nk_euclidean_e5m2_genoa(a, b, n, result);
1268
1260
  #elif NK_TARGET_SKYLAKE
1269
1261
  nk_euclidean_e5m2_skylake(a, b, n, result);
1270
1262
  #elif NK_TARGET_RVV
@@ -1281,8 +1273,6 @@ NK_PUBLIC void nk_sqeuclidean_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_si
1281
1273
  nk_sqeuclidean_e5m2_neonfp8(a, b, n, result);
1282
1274
  #elif NK_TARGET_DIAMOND
1283
1275
  nk_sqeuclidean_e5m2_diamond(a, b, n, result);
1284
- #elif NK_TARGET_GENOA
1285
- nk_sqeuclidean_e5m2_genoa(a, b, n, result);
1286
1276
  #elif NK_TARGET_SKYLAKE
1287
1277
  nk_sqeuclidean_e5m2_skylake(a, b, n, result);
1288
1278
  #elif NK_TARGET_RVV
@@ -1299,8 +1289,6 @@ NK_PUBLIC void nk_angular_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t
1299
1289
  nk_angular_e5m2_neonfp8(a, b, n, result);
1300
1290
  #elif NK_TARGET_DIAMOND
1301
1291
  nk_angular_e5m2_diamond(a, b, n, result);
1302
- #elif NK_TARGET_GENOA
1303
- nk_angular_e5m2_genoa(a, b, n, result);
1304
1292
  #elif NK_TARGET_SKYLAKE
1305
1293
  nk_angular_e5m2_skylake(a, b, n, result);
1306
1294
  #elif NK_TARGET_RVV
@@ -0,0 +1,301 @@
1
+ /**
2
+ * @brief Batched Spatial Distances for Granite Rapids (AMX-FP16) with AVX-512 Finalization.
3
+ * @file include/numkong/spatials/graniteamx.h
4
+ * @author Ash Vardanian
5
+ * @date April 9, 2026
6
+ *
7
+ * @sa include/numkong/spatials.h
8
+ */
9
+ #ifndef NK_SPATIALS_GRANITEAMX_H
10
+ #define NK_SPATIALS_GRANITEAMX_H
11
+
12
+ #if NK_TARGET_X8664_
13
+ #if NK_TARGET_GRANITEAMX
14
+
15
+ #include "numkong/spatial/skylake.h"
16
+ #include "numkong/spatial/serial.h"
17
+ #include "numkong/dots/graniteamx.h"
18
+
19
+ #if defined(__cplusplus)
20
+ extern "C" {
21
+ #endif
22
+
23
+ #if defined(__clang__)
24
+ #pragma clang attribute push( \
25
+ __attribute__((target( \
26
+ "avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512vbmi,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8,amx-fp16"))), \
27
+ apply_to = function)
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC push_options
30
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512vbmi", "f16c", "fma", \
31
+ "bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8", "amx-fp16")
32
+ #endif
33
+
34
+ #pragma region F16 Packed
35
+
/**
 * @brief Per-row finalization step used by nk_angulars_packed_f16_graniteamx.
 *
 * For every row of `a`, recomputes the row's sum of squares over `depth` f16 values and hands it,
 * together with the norms stored inside `b_packed`, to nk_angulars_row_f32dots_sapphireamx_,
 * which operates on the matching row of `c`.
 *
 * @param a                 Row-major f16 input; row `r` starts at `a + r * a_stride_elements`.
 * @param b_packed          Packed buffer that begins with an nk_dots_amx_packed_header_t.
 * @param c                 f32 buffer; row `r` starts at `c + r * c_stride_elements`.
 * @param rows              Number of rows of `a` / `c` to process.
 * @param columns           Number of entries per row of `c` passed to the row helper.
 * @param depth             Number of f16 elements per row of `a` used for the sum of squares.
 * @param a_stride_elements Row stride of `a`, in elements (not bytes).
 * @param c_stride_elements Row stride of `c`, in elements (not bytes).
 *
 * NOTE(review): the row helper is defined elsewhere; presumably it rewrites the dot products
 * already stored in `c` into angular distances in place - confirm against dots/sapphireamx.h.
 */
36
+ NK_INTERNAL void nk_angulars_packed_f16_graniteamx_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
37
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
38
+ nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
39
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
/* The norms of the packed side live inside the same buffer, at a byte offset recorded in its header. */
40
+ nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
41
+ for (nk_size_t row = 0; row < rows; row++) {
/* The query-side norm is not cached anywhere, so it is recomputed from the raw row here. */
42
+ nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_f16_(a + row * a_stride_elements, depth);
43
+ nk_angulars_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
44
+ }
45
+ }
46
+
/**
 * @brief Public entry point: f16 rows of `a` against a pre-packed `b_packed`, angular variant.
 *
 * Runs in two phases: nk_dots_packed_f16_graniteamx is called first with the same `c` buffer,
 * then nk_angulars_packed_f16_graniteamx_finalize_ post-processes `c` row by row.
 *
 * @param a_stride_in_bytes Row stride of `a` in bytes; divided by sizeof(nk_f16_t) for the finalizer.
 * @param c_stride_in_bytes Row stride of `c` in bytes; divided by sizeof(nk_f32_t) for the finalizer.
 *
 * NOTE(review): the divisions truncate, so strides are presumably expected to be multiples of
 * the element size - confirm with callers.
 */
47
+ NK_PUBLIC void nk_angulars_packed_f16_graniteamx( //
48
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
49
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
50
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
/* The finalizer indexes by element, while the dots kernel below takes the byte strides unchanged. */
51
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
52
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
53
+ nk_dots_packed_f16_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
54
+ nk_angulars_packed_f16_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
55
+ c_stride_elements);
56
+ }
57
+
/**
 * @brief Per-row finalization step used by nk_euclideans_packed_f16_graniteamx.
 *
 * Same shape as the angular finalizer: for each of `rows` rows it computes the sum of squares of
 * `depth` f16 values from `a`, then calls nk_euclideans_row_f32dots_sapphireamx_ on the matching
 * row of `c` with the norms read out of `b_packed`.
 *
 * @param a                 Row-major f16 input; row `r` starts at `a + r * a_stride_elements`.
 * @param b_packed          Packed buffer that begins with an nk_dots_amx_packed_header_t.
 * @param c                 f32 buffer; row `r` starts at `c + r * c_stride_elements`.
 * @param rows              Number of rows to process.
 * @param columns           Number of entries per row of `c` passed to the row helper.
 * @param depth             Number of f16 elements per row of `a` used for the sum of squares.
 * @param a_stride_elements Row stride of `a`, in elements.
 * @param c_stride_elements Row stride of `c`, in elements.
 *
 * NOTE(review): presumably the row helper turns stored dot products into Euclidean distances
 * in place - it is defined outside this file, so confirm there.
 */
58
+ NK_INTERNAL void nk_euclideans_packed_f16_graniteamx_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
59
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
60
+ nk_size_t a_stride_elements,
61
+ nk_size_t c_stride_elements) {
62
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
/* Locate the norms block via the byte offset stored in the packed header. */
63
+ nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
64
+ for (nk_size_t row = 0; row < rows; row++) {
65
+ nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_f16_(a + row * a_stride_elements, depth);
66
+ nk_euclideans_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
67
+ }
68
+ }
69
+
/**
 * @brief Public entry point: f16 rows of `a` against a pre-packed `b_packed`, Euclidean variant.
 *
 * Calls nk_dots_packed_f16_graniteamx on `c`, then nk_euclideans_packed_f16_graniteamx_finalize_
 * on the same buffer.
 *
 * @param a_stride_in_bytes Row stride of `a` in bytes; divided by sizeof(nk_f16_t) for the finalizer.
 * @param c_stride_in_bytes Row stride of `c` in bytes; divided by sizeof(nk_f32_t) for the finalizer.
 */
70
+ NK_PUBLIC void nk_euclideans_packed_f16_graniteamx( //
71
+ nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
72
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
73
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
/* Element strides are only needed by the finalizer; the dots kernel receives the byte strides. */
74
+ nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
75
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
76
+ nk_dots_packed_f16_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
77
+ nk_euclideans_packed_f16_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
78
+ c_stride_elements);
79
+ }
80
+
81
+ #pragma endregion F16 Packed
82
+
83
+ #pragma region F16 Symmetric
84
+
/**
 * @brief Finalization step used by nk_angulars_symmetric_f16_graniteamx.
 *
 * Only rows in [row_start, row_start + row_count) are touched, and within each row only the
 * entries strictly to the right of the diagonal are passed to the row helper.
 *
 * @param vectors                f16 vectors; vector `i` starts at `vectors + i * stride_elements`.
 * @param vectors_count          Total number of vectors, i.e. the column count of `result`.
 * @param depth                  Number of f16 elements per vector used for the sum of squares.
 * @param stride_elements        Stride between consecutive vectors, in elements.
 * @param result                 f32 matrix; row `r` starts at `result + r * result_stride_elements`.
 * @param result_stride_elements Row stride of `result`, in elements.
 * @param row_start              First row handled by this call.
 * @param row_count              Number of rows handled by this call.
 *
 * NOTE(review): nk_angulars_row_f32dots_sapphireamx_ is defined elsewhere; presumably it converts
 * dot products already present in `result` into angular distances in place - confirm.
 */
85
+ NK_INTERNAL void nk_angulars_symmetric_f16_graniteamx_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
86
+ nk_size_t depth, nk_size_t stride_elements,
87
+ nk_f32_t *result, nk_size_t result_stride_elements,
88
+ nk_size_t row_start, nk_size_t row_count) {
89
+
/* Use the diagonal cell of each handled row as scratch space for that row's sum of squares. */
90
+ for (nk_size_t row = row_start; row < row_start + row_count; row++)
91
+ result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_f16_(vectors + row * stride_elements, depth);
92
+
/* Column sums of squares are produced 256 at a time into a fixed-size stack buffer. */
93
+ nk_f32_t column_norms_cache[256];
94
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
95
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
96
+ for (nk_size_t col = chunk_start; col < chunk_end; col++)
97
+ column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
98
+
99
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) {
100
+ nk_f32_t *r_row = result + row * result_stride_elements;
/* col_start = max(chunk_start, row + 1): skip the diagonal and everything left of it. */
101
+ nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
102
+ if (col_start >= chunk_end) continue;
/* r_row[row] still holds the row's sum of squares written by the first loop. */
103
+ nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
104
+ r_row[row], chunk_end - col_start);
105
+ }
106
+ }
107
+
/* The scratch values on the diagonal are no longer needed; overwrite them with zero. */
108
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
109
+ }
110
+
/**
 * @brief Public entry point: all-pairs angular variant over one set of f16 vectors.
 *
 * Calls nk_dots_symmetric_f16_graniteamx on `result` for the requested row range, then runs
 * nk_angulars_symmetric_f16_graniteamx_finalize_ over the same range.
 *
 * @param stride_in_bytes        Stride between vectors in bytes; divided by sizeof(nk_f16_t)
 *                               for the finalizer.
 * @param result_stride_in_bytes Row stride of `result` in bytes; divided by sizeof(nk_f32_t)
 *                               for the finalizer.
 * @param row_start              First row of `result` handled by this call.
 * @param row_count              Number of rows of `result` handled by this call.
 */
111
+ NK_PUBLIC void nk_angulars_symmetric_f16_graniteamx( //
112
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
113
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
/* The dots kernel takes byte strides as-is; the finalizer indexes by element. */
114
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
115
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
116
+ nk_dots_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
117
+ row_start, row_count);
118
+ nk_angulars_symmetric_f16_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
119
+ result_stride_elements, row_start, row_count);
120
+ }
121
+
/**
 * @brief Finalization step used by nk_euclideans_symmetric_f16_graniteamx.
 *
 * Structure is identical to the angular symmetric finalizer, except that each row segment is
 * handed to nk_euclideans_row_f32dots_sapphireamx_. Only rows in
 * [row_start, row_start + row_count) and only entries strictly right of the diagonal are passed on.
 *
 * @param vectors                f16 vectors; vector `i` starts at `vectors + i * stride_elements`.
 * @param vectors_count          Total number of vectors, i.e. the column count of `result`.
 * @param depth                  Number of f16 elements per vector used for the sum of squares.
 * @param stride_elements        Stride between consecutive vectors, in elements.
 * @param result                 f32 matrix; row `r` starts at `result + r * result_stride_elements`.
 * @param result_stride_elements Row stride of `result`, in elements.
 * @param row_start              First row handled by this call.
 * @param row_count              Number of rows handled by this call.
 *
 * NOTE(review): the row helper is defined elsewhere; presumably it turns stored dot products
 * into Euclidean distances in place - confirm.
 */
122
+ NK_INTERNAL void nk_euclideans_symmetric_f16_graniteamx_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
123
+ nk_size_t depth, nk_size_t stride_elements,
124
+ nk_f32_t *result, nk_size_t result_stride_elements,
125
+ nk_size_t row_start, nk_size_t row_count) {
126
+
/* Stash each handled row's sum of squares in its own diagonal cell. */
127
+ for (nk_size_t row = row_start; row < row_start + row_count; row++)
128
+ result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_f16_(vectors + row * stride_elements, depth);
129
+
/* Walk the columns in chunks of up to 256 so the per-column sums fit a fixed stack buffer. */
130
+ nk_f32_t column_norms_cache[256];
131
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
132
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
133
+ for (nk_size_t col = chunk_start; col < chunk_end; col++)
134
+ column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
135
+
136
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) {
137
+ nk_f32_t *r_row = result + row * result_stride_elements;
/* Begin at the later of the chunk start and the first column past the diagonal. */
138
+ nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
139
+ if (col_start >= chunk_end) continue;
/* The third argument reads back the row sum of squares stored on the diagonal above. */
140
+ nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
141
+ r_row[row], chunk_end - col_start);
142
+ }
143
+ }
144
+
/* Replace the temporary diagonal values with zero. */
145
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
146
+ }
147
+
/**
 * @brief Public entry point: all-pairs Euclidean variant over one set of f16 vectors.
 *
 * Calls nk_dots_symmetric_f16_graniteamx on `result` for the requested row range, then runs
 * nk_euclideans_symmetric_f16_graniteamx_finalize_ over the same range.
 *
 * @param stride_in_bytes        Stride between vectors in bytes; divided by sizeof(nk_f16_t)
 *                               for the finalizer.
 * @param result_stride_in_bytes Row stride of `result` in bytes; divided by sizeof(nk_f32_t)
 *                               for the finalizer.
 * @param row_start              First row of `result` handled by this call.
 * @param row_count              Number of rows of `result` handled by this call.
 */
148
+ NK_PUBLIC void nk_euclideans_symmetric_f16_graniteamx( //
149
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
150
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
/* Byte strides go to the dots kernel; element strides go to the finalizer. */
151
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
152
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
153
+ nk_dots_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
154
+ row_start, row_count);
155
+ nk_euclideans_symmetric_f16_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
156
+ result_stride_elements, row_start, row_count);
157
+ }
158
+
159
+ #pragma endregion F16 Symmetric
160
+
161
+ #pragma region E5M2 Packed
162
+
/**
 * @brief Per-row finalization step used by nk_angulars_packed_e5m2_graniteamx.
 *
 * For each of `rows` rows, computes the sum of squares of `depth` e5m2 values from `a` via
 * nk_dots_reduce_sumsq_e5m2_, then calls nk_angulars_row_f32dots_sapphireamx_ on the matching
 * row of `c` with the norms read out of `b_packed`.
 *
 * @param a                 Row-major e5m2 input; row `r` starts at `a + r * a_stride_elements`.
 * @param b_packed          Packed buffer that begins with an nk_dots_amx_packed_header_t.
 * @param c                 f32 buffer; row `r` starts at `c + r * c_stride_elements`.
 * @param rows              Number of rows to process.
 * @param columns           Number of entries per row of `c` passed to the row helper.
 * @param depth             Number of e5m2 elements per row of `a` used for the sum of squares.
 * @param a_stride_elements Row stride of `a`, in elements.
 * @param c_stride_elements Row stride of `c`, in elements.
 *
 * NOTE(review): presumably the row helper rewrites dot products in `c` into angular distances
 * in place - it is defined outside this file, so confirm there.
 */
163
+ NK_INTERNAL void nk_angulars_packed_e5m2_graniteamx_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
164
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
165
+ nk_size_t a_stride_elements,
166
+ nk_size_t c_stride_elements) {
167
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
/* The norms block sits inside the packed buffer at the byte offset named by the header. */
168
+ nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
169
+ for (nk_size_t row = 0; row < rows; row++) {
170
+ nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_e5m2_(a + row * a_stride_elements, depth);
171
+ nk_angulars_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
172
+ }
173
+ }
174
+
/**
 * @brief Public entry point: e5m2 rows of `a` against a pre-packed `b_packed`, angular variant.
 *
 * Calls nk_dots_packed_e5m2_graniteamx on `c`, then nk_angulars_packed_e5m2_graniteamx_finalize_
 * on the same buffer.
 *
 * @param a_stride_in_bytes Row stride of `a` in bytes; forwarded unchanged as the element stride.
 * @param c_stride_in_bytes Row stride of `c` in bytes; divided by sizeof(nk_f32_t) for the finalizer.
 */
175
+ NK_PUBLIC void nk_angulars_packed_e5m2_graniteamx( //
176
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
177
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
178
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
/* NOTE(review): no division here, unlike the f16 path - presumably sizeof(nk_e5m2_t) == 1; confirm. */
179
+ nk_size_t const a_stride_elements = a_stride_in_bytes;
180
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
181
+ nk_dots_packed_e5m2_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
182
+ nk_angulars_packed_e5m2_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
183
+ c_stride_elements);
184
+ }
185
+
/**
 * @brief Per-row finalization step used by nk_euclideans_packed_e5m2_graniteamx.
 *
 * For each of `rows` rows, computes the sum of squares of `depth` e5m2 values from `a`, then
 * calls nk_euclideans_row_f32dots_sapphireamx_ on the matching row of `c` with the norms read
 * out of `b_packed`.
 *
 * @param a                 Row-major e5m2 input; row `r` starts at `a + r * a_stride_elements`.
 * @param b_packed          Packed buffer that begins with an nk_dots_amx_packed_header_t.
 * @param c                 f32 buffer; row `r` starts at `c + r * c_stride_elements`.
 * @param rows              Number of rows to process.
 * @param columns           Number of entries per row of `c` passed to the row helper.
 * @param depth             Number of e5m2 elements per row of `a` used for the sum of squares.
 * @param a_stride_elements Row stride of `a`, in elements.
 * @param c_stride_elements Row stride of `c`, in elements.
 *
 * NOTE(review): presumably the row helper turns stored dot products into Euclidean distances
 * in place - confirm against its definition.
 */
186
+ NK_INTERNAL void nk_euclideans_packed_e5m2_graniteamx_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
187
+ nk_size_t rows, nk_size_t columns, nk_size_t depth,
188
+ nk_size_t a_stride_elements,
189
+ nk_size_t c_stride_elements) {
190
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
/* Find the norms block through the header's recorded byte offset. */
191
+ nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
192
+ for (nk_size_t row = 0; row < rows; row++) {
193
+ nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_e5m2_(a + row * a_stride_elements, depth);
194
+ nk_euclideans_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
195
+ }
196
+ }
197
+
/**
 * @brief Public entry point: e5m2 rows of `a` against a pre-packed `b_packed`, Euclidean variant.
 *
 * Calls nk_dots_packed_e5m2_graniteamx on `c`, then nk_euclideans_packed_e5m2_graniteamx_finalize_
 * on the same buffer.
 *
 * @param a_stride_in_bytes Row stride of `a` in bytes; forwarded unchanged as the element stride.
 * @param c_stride_in_bytes Row stride of `c` in bytes; divided by sizeof(nk_f32_t) for the finalizer.
 */
198
+ NK_PUBLIC void nk_euclideans_packed_e5m2_graniteamx( //
199
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
200
+ nk_size_t rows, nk_size_t columns, nk_size_t depth, //
201
+ nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
/* NOTE(review): byte stride reused as element stride - presumably nk_e5m2_t is one byte wide; confirm. */
202
+ nk_size_t const a_stride_elements = a_stride_in_bytes;
203
+ nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
204
+ nk_dots_packed_e5m2_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
205
+ nk_euclideans_packed_e5m2_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
206
+ c_stride_elements);
207
+ }
208
+
209
+ #pragma endregion E5M2 Packed
210
+
211
+ #pragma region E5M2 Symmetric
212
+
/**
 * @brief Finalization step used by nk_angulars_symmetric_e5m2_graniteamx.
 *
 * Only rows in [row_start, row_start + row_count) are touched, and within each row only the
 * entries strictly to the right of the diagonal are passed to the row helper.
 *
 * @param vectors                e5m2 vectors; vector `i` starts at `vectors + i * stride_elements`.
 * @param vectors_count          Total number of vectors, i.e. the column count of `result`.
 * @param depth                  Number of e5m2 elements per vector used for the sum of squares.
 * @param stride_elements        Stride between consecutive vectors, in elements.
 * @param result                 f32 matrix; row `r` starts at `result + r * result_stride_elements`.
 * @param result_stride_elements Row stride of `result`, in elements.
 * @param row_start              First row handled by this call.
 * @param row_count              Number of rows handled by this call.
 *
 * NOTE(review): nk_angulars_row_f32dots_sapphireamx_ is defined elsewhere; presumably it converts
 * dot products already present in `result` into angular distances in place - confirm.
 */
213
+ NK_INTERNAL void nk_angulars_symmetric_e5m2_graniteamx_finalize_(nk_e5m2_t const *vectors, nk_size_t vectors_count,
214
+ nk_size_t depth, nk_size_t stride_elements,
215
+ nk_f32_t *result, nk_size_t result_stride_elements,
216
+ nk_size_t row_start, nk_size_t row_count) {
217
+
/* Borrow each handled row's diagonal cell to hold that row's sum of squares. */
218
+ for (nk_size_t row = row_start; row < row_start + row_count; row++)
219
+ result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e5m2_(vectors + row * stride_elements, depth);
220
+
/* Per-column sums of squares are computed for up to 256 columns at a time. */
221
+ nk_f32_t column_norms_cache[256];
222
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
223
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
224
+ for (nk_size_t col = chunk_start; col < chunk_end; col++)
225
+ column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
226
+
227
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) {
228
+ nk_f32_t *r_row = result + row * result_stride_elements;
/* max(chunk_start, row + 1): nothing on or left of the diagonal is handed to the helper. */
229
+ nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
230
+ if (col_start >= chunk_end) continue;
/* r_row[row] is the row's sum of squares stored by the first loop. */
231
+ nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
232
+ r_row[row], chunk_end - col_start);
233
+ }
234
+ }
235
+
/* Clear the borrowed diagonal cells to zero. */
236
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
237
+ }
238
+
/**
 * @brief Public entry point: all-pairs angular variant over one set of e5m2 vectors.
 *
 * Calls nk_dots_symmetric_e5m2_graniteamx on `result` for the requested row range, then runs
 * nk_angulars_symmetric_e5m2_graniteamx_finalize_ over the same range.
 *
 * @param stride_in_bytes        Stride between vectors in bytes; forwarded unchanged as the
 *                               element stride.
 * @param result_stride_in_bytes Row stride of `result` in bytes; divided by sizeof(nk_f32_t)
 *                               for the finalizer.
 * @param row_start              First row of `result` handled by this call.
 * @param row_count              Number of rows of `result` handled by this call.
 */
239
+ NK_PUBLIC void nk_angulars_symmetric_e5m2_graniteamx( //
240
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
241
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
/* NOTE(review): byte stride reused as element stride - presumably sizeof(nk_e5m2_t) == 1; confirm. */
242
+ nk_size_t const stride_elements = stride_in_bytes;
243
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
244
+ nk_dots_symmetric_e5m2_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
245
+ row_start, row_count);
246
+ nk_angulars_symmetric_e5m2_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
247
+ result_stride_elements, row_start, row_count);
248
+ }
249
+
/**
 * @brief Finalization step used by nk_euclideans_symmetric_e5m2_graniteamx.
 *
 * Same structure as the angular e5m2 symmetric finalizer, but each row segment is handed to
 * nk_euclideans_row_f32dots_sapphireamx_. Only rows in [row_start, row_start + row_count) and
 * only entries strictly right of the diagonal are passed on.
 *
 * @param vectors                e5m2 vectors; vector `i` starts at `vectors + i * stride_elements`.
 * @param vectors_count          Total number of vectors, i.e. the column count of `result`.
 * @param depth                  Number of e5m2 elements per vector used for the sum of squares.
 * @param stride_elements        Stride between consecutive vectors, in elements.
 * @param result                 f32 matrix; row `r` starts at `result + r * result_stride_elements`.
 * @param result_stride_elements Row stride of `result`, in elements.
 * @param row_start              First row handled by this call.
 * @param row_count              Number of rows handled by this call.
 *
 * NOTE(review): the row helper is defined elsewhere; presumably it turns stored dot products
 * into Euclidean distances in place - confirm.
 */
250
+ NK_INTERNAL void nk_euclideans_symmetric_e5m2_graniteamx_finalize_(nk_e5m2_t const *vectors, nk_size_t vectors_count,
251
+ nk_size_t depth, nk_size_t stride_elements,
252
+ nk_f32_t *result, nk_size_t result_stride_elements,
253
+ nk_size_t row_start, nk_size_t row_count) {
254
+
/* Keep each handled row's sum of squares in its diagonal cell until the end of the function. */
255
+ for (nk_size_t row = row_start; row < row_start + row_count; row++)
256
+ result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e5m2_(vectors + row * stride_elements, depth);
257
+
/* Columns are visited in chunks of at most 256, matching the size of the stack cache. */
258
+ nk_f32_t column_norms_cache[256];
259
+ for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
260
+ nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
261
+ for (nk_size_t col = chunk_start; col < chunk_end; col++)
262
+ column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
263
+
264
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) {
265
+ nk_f32_t *r_row = result + row * result_stride_elements;
/* Start at whichever is later: the chunk start or the first column past the diagonal. */
266
+ nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
267
+ if (col_start >= chunk_end) continue;
/* The row sum of squares is read back from the diagonal cell. */
268
+ nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
269
+ r_row[row], chunk_end - col_start);
270
+ }
271
+ }
272
+
/* Overwrite the temporary diagonal values with zero. */
273
+ for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
274
+ }
275
+
/**
 * @brief Public entry point: all-pairs Euclidean variant over one set of e5m2 vectors.
 *
 * Calls nk_dots_symmetric_e5m2_graniteamx on `result` for the requested row range, then runs
 * nk_euclideans_symmetric_e5m2_graniteamx_finalize_ over the same range.
 *
 * @param stride_in_bytes        Stride between vectors in bytes; forwarded unchanged as the
 *                               element stride.
 * @param result_stride_in_bytes Row stride of `result` in bytes; divided by sizeof(nk_f32_t)
 *                               for the finalizer.
 * @param row_start              First row of `result` handled by this call.
 * @param row_count              Number of rows of `result` handled by this call.
 */
276
+ NK_PUBLIC void nk_euclideans_symmetric_e5m2_graniteamx( //
277
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
278
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
/* NOTE(review): no sizeof division on the input stride - presumably nk_e5m2_t is one byte; confirm. */
279
+ nk_size_t const stride_elements = stride_in_bytes;
280
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
281
+ nk_dots_symmetric_e5m2_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
282
+ row_start, row_count);
283
+ nk_euclideans_symmetric_e5m2_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
284
+ result_stride_elements, row_start, row_count);
285
+ }
286
+
287
+ #pragma endregion E5M2 Symmetric
288
+
289
+ #if defined(__clang__)
290
+ #pragma clang attribute pop
291
+ #elif defined(__GNUC__)
292
+ #pragma GCC pop_options
293
+ #endif
294
+
295
+ #if defined(__cplusplus)
296
+ } // extern "C"
297
+ #endif
298
+
299
+ #endif // NK_TARGET_GRANITEAMX
300
+ #endif // NK_TARGET_X8664_
301
+ #endif // NK_SPATIALS_GRANITEAMX_H
@@ -15,6 +15,29 @@
15
15
  extern "C" {
16
16
  #endif
17
17
 
18
+ /* Keep the serial instantiations below actually scalar, regardless of build type.
19
+ * Without this, -O3 + LTO can vectorize or clone the serial kernels under AVX-512
20
+ * callers in dispatch_*.c, which wastes binary and breaks the nk_*_serial-as-scalar-oracle
21
+ * contract that tests and numerical-stability docs rely on. See dots/serial.h. */
22
+ #if defined(__clang__)
23
+ #pragma clang attribute push(__attribute__((noinline)), apply_to = function)
24
+ #elif defined(__GNUC__)
25
+ #pragma GCC push_options
26
+ #pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
27
+ #endif
28
+
29
+ /* Size bias for release. Gated on NDEBUG so Debug builds keep -O0 for stepping. */
30
+ #if defined(NDEBUG)
31
+ #if defined(_MSC_VER)
32
+ #pragma optimize("s", on)
33
+ #elif defined(__clang__)
34
+ #pragma clang attribute push(__attribute__((minsize)), apply_to = function)
35
+ #elif defined(__GNUC__)
36
+ #pragma GCC push_options
37
+ #pragma GCC optimize("Os")
38
+ #endif
39
+ #endif
40
+
18
41
  nk_define_cross_normalized_packed_(angular, f64, serial, f64, f64, f64, /*norm_value_type=*/f64, f64, nk_b256_vec_t,
19
42
  nk_dots_packed_f64_serial, nk_angular_through_f64_from_dot_serial_,
20
43
  nk_dots_reduce_sumsq_f64_, nk_load_b256_serial_, nk_partial_load_b64x4_serial_,
@@ -219,6 +242,22 @@ nk_define_cross_normalized_symmetric_(euclidean, u4, serial, u4x2, u32, /*norm_v
219
242
  nk_dots_reduce_sumsq_u4_, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
220
243
  nk_store_b128_serial_, nk_partial_store_b32x4_serial_, 2)
221
244
 
245
+ #if defined(NDEBUG)
246
+ #if defined(_MSC_VER)
247
+ #pragma optimize("", on)
248
+ #elif defined(__clang__)
249
+ #pragma clang attribute pop
250
+ #elif defined(__GNUC__)
251
+ #pragma GCC pop_options
252
+ #endif
253
+ #endif
254
+
255
+ #if defined(__clang__)
256
+ #pragma clang attribute pop
257
+ #elif defined(__GNUC__)
258
+ #pragma GCC pop_options
259
+ #endif
260
+
222
261
  #if defined(__cplusplus)
223
262
  } // extern "C"
224
263
  #endif
@@ -97,11 +97,11 @@ nk_define_cross_normalized_symmetric_(euclidean, bf16, skylake, bf16, f32, /*nor
97
97
  nk_dots_reduce_sumsq_bf16_, nk_load_b128_haswell_, nk_partial_load_b32x4_skylake_,
98
98
  nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_, 1)
99
99
 
100
- nk_define_cross_normalized_packed_(angular, e4m3, skylake, e4m3, f32, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
100
+ nk_define_cross_normalized_packed_(angular, e4m3, skylake, e4m3, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
101
101
  nk_dots_packed_e4m3_skylake, nk_angular_through_f32_from_dot_haswell_,
102
102
  nk_dots_reduce_sumsq_e4m3_, nk_load_b128_haswell_, nk_partial_load_b32x4_skylake_,
103
103
  nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_, 1)
104
- nk_define_cross_normalized_packed_(euclidean, e4m3, skylake, e4m3, f32, f32, /*norm_value_type=*/f32, f32,
104
+ nk_define_cross_normalized_packed_(euclidean, e4m3, skylake, e4m3, f16, f32, /*norm_value_type=*/f32, f32,
105
105
  nk_b128_vec_t, nk_dots_packed_e4m3_skylake,
106
106
  nk_euclidean_through_f32_from_dot_haswell_, nk_dots_reduce_sumsq_e4m3_,
107
107
  nk_load_b128_haswell_, nk_partial_load_b32x4_skylake_, nk_store_b128_haswell_,