numkong 7.4.5 → 7.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. package/README.md +1 -0
  2. package/binding.gyp +81 -5
  3. package/c/dispatch_f16.c +23 -0
  4. package/c/numkong.c +0 -13
  5. package/include/numkong/attention/sme.h +34 -31
  6. package/include/numkong/capabilities.h +2 -15
  7. package/include/numkong/cast/neon.h +15 -0
  8. package/include/numkong/curved/smef64.h +82 -62
  9. package/include/numkong/dot/rvvbf16.h +1 -1
  10. package/include/numkong/dot/rvvhalf.h +1 -1
  11. package/include/numkong/dot/sve.h +6 -5
  12. package/include/numkong/dot/svebfdot.h +2 -1
  13. package/include/numkong/dot/svehalf.h +6 -5
  14. package/include/numkong/dot/svesdot.h +3 -2
  15. package/include/numkong/dots/graniteamx.h +733 -0
  16. package/include/numkong/dots/serial.h +11 -4
  17. package/include/numkong/dots/sme.h +172 -140
  18. package/include/numkong/dots/smebi32.h +14 -11
  19. package/include/numkong/dots/smef64.h +31 -26
  20. package/include/numkong/dots.h +29 -3
  21. package/include/numkong/each/serial.h +22 -0
  22. package/include/numkong/geospatial/haswell.h +1 -1
  23. package/include/numkong/geospatial/neon.h +1 -1
  24. package/include/numkong/geospatial/serial.h +1 -1
  25. package/include/numkong/geospatial/skylake.h +1 -1
  26. package/include/numkong/maxsim/sme.h +34 -33
  27. package/include/numkong/mesh/serial.h +22 -0
  28. package/include/numkong/reduce/neon.h +29 -0
  29. package/include/numkong/reduce/neonbfdot.h +2 -2
  30. package/include/numkong/reduce/neonfhm.h +4 -4
  31. package/include/numkong/reduce/sve.h +52 -0
  32. package/include/numkong/reduce.h +4 -0
  33. package/include/numkong/set/sve.h +6 -5
  34. package/include/numkong/sets/smebi32.h +35 -30
  35. package/include/numkong/sparse/sve2.h +3 -2
  36. package/include/numkong/spatial/sve.h +7 -6
  37. package/include/numkong/spatial/svebfdot.h +7 -4
  38. package/include/numkong/spatial/svehalf.h +5 -4
  39. package/include/numkong/spatial/svesdot.h +9 -8
  40. package/include/numkong/spatials/graniteamx.h +173 -0
  41. package/include/numkong/spatials/serial.h +22 -0
  42. package/include/numkong/spatials/sme.h +391 -350
  43. package/include/numkong/spatials/smef64.h +79 -70
  44. package/include/numkong/spatials.h +37 -4
  45. package/include/numkong/types.h +59 -0
  46. package/javascript/dist/cjs/numkong.js +13 -0
  47. package/javascript/dist/esm/numkong.js +13 -0
  48. package/javascript/numkong.c +56 -12
  49. package/javascript/numkong.ts +13 -0
  50. package/package.json +7 -7
  51. package/probes/probe.js +2 -2
  52. package/wasm/numkong.wasm +0 -0
@@ -0,0 +1,733 @@
1
+ /**
2
+ * @brief SIMD-accelerated Batched Dot Products for Granite Rapids.
3
+ * @file include/numkong/dots/graniteamx.h
4
+ * @author Ash Vardanian
5
+ * @date April 9, 2026
6
+ *
7
+ * @sa include/numkong/dots.h
8
+ *
9
+ * Native FP16×FP16→FP32 GEMM kernels using Intel AMX-FP16 (TDPFP16PS) on Granite Rapids CPUs.
10
+ * Same tile geometry as BF16 (16 rows × 32 FP16 = 1KB per tile), same 2×2 output blocking,
11
+ * same packing format — only the tile multiply instruction differs.
12
+ *
13
+ * Tile register allocation:
14
+ *
15
+ * - TMM0, TMM1: A matrix tiles (row blocks i and i+16)
16
+ * - TMM2, TMM3: B matrix tiles (column blocks j and j+16)
17
+ * - TMM4-7: C accumulator tiles (2 × 2 output grid = 32×32 F32 results)
18
+ *
19
+ * @section amx_fp16_instructions Intel AMX-FP16 Instructions (Granite Rapids+)
20
+ *
21
+ * FP16 matrix multiply (AMX-FP16):
22
+ *
23
+ * Intrinsic Instruction Operation
24
+ * _tile_dpfp16ps TDPFP16PS (TMM, TMM, TMM) C += A × B (fp16 → f32)
25
+ *
26
+ * TDPFP16PS: 16 × 16 × 32 = 8192 FP16 MACs per instruction (same throughput as TDPBF16PS).
27
+ *
28
+ * @section ozaki_limitations F32→F64 via Ozaki Scheme — Attempted and Abandoned
29
+ *
30
+ * We explored using AMX-FP16 tiles to compute F32→F64 GEMMs via the Ozaki decomposition scheme,
31
+ * splitting each F32 scalar into 2 or 3 FP16 terms and performing cross-product TDPFP16PS operations.
32
+ *
33
+ * Results on Intel Xeon 6776P (Granite Rapids), single-threaded:
34
+ *
35
+ * | Variant | Speed (gso/s) | Precision | Notes |
36
+ * |----------------------|:-------------:|:---------:|:------------------------------------------:|
37
+ * | 2-term, 2×1 blocking | ~150 | ~22 bits | Split accumulators, N=16 F64 flush |
38
+ * | 3-term, 1×1 blocking | ~110 | ~22 bits | 3 accumulators by magnitude band |
39
+ * | Pipelined 2-term | ~156 | ~22 bits | Double-buffered A split, AMX/AVX-512 overlap|
40
+ * | MKL SGEMM | ~170 | ~20 bits | Pure F32, no decomposition |
41
+ * | Skylake F64 accum | ~50 | ~48 bits | F32×F32 multiply, F64 accumulation |
42
+ *
43
+ * The fundamental bottleneck is TDPFP16PS's internal F32 accumulation: each instruction sums
44
+ * 32 FP16×FP16 products into an F32 register (23-bit mantissa). Even with Ozaki cross-term
45
+ * separation into distinct TMM accumulators (preventing magnitude mixing) and periodic extraction
46
+ * to F64 running sums, the per-instruction accumulation of 32 products loses ~5 bits
47
+ * (log2(32) = 5), capping effective precision at ~28 - 5 = ~23 bits — barely exceeding F32 BLAS.
48
+ *
49
+ * Approaches attempted:
50
+ *
51
+ * - 2-term decomposition (a = a_high + a_low): 4 TDPFP16PS per depth tile, ~20-bit products.
52
+ * With split accumulators (main + correction) merged in F64: ~22-bit effective precision.
53
+ * Faster than MKL at small depths (≤512) but precision plateaus at ~22 bits.
54
+ *
55
+ * - 3-term decomposition (a = a_high + a_mid + a_low): 6 TDPFP16PS per depth tile, ~30-bit
56
+ * products. No precision improvement over 2-term because the F32 TMM accumulation is the
57
+ * bottleneck, not the decomposition quality. Strictly slower and no more precise.
58
+ *
59
+ * - Periodic F64 flush (extract TMM accumulators to F64 every N depth tiles): prevents precision
60
+ * degradation at large depths. With N=16, ~15% overhead. Effective precision still ~24 bits
61
+ * (limited by per-TDPFP16PS accumulation of 32 products, not by inter-tile accumulation).
62
+ *
63
+ * - AMX/AVX-512 pipelining (double-buffered A splitting overlapped with AMX compute): ~7%
64
+ * speedup at large sizes. Does not affect precision.
65
+ *
66
+ * Conclusion: AMX-FP16's F32 tile accumulation fundamentally limits Ozaki to ~22-24 bits —
67
+ * comparable to F32 BLAS, far short of the ~48-bit F64 precision needed to justify the complexity.
68
+ * For F32→F64 GEMM, pure AVX-512 with F64 FMA remains the correct approach.
69
+ */
70
+ #ifndef NK_DOTS_GRANITEAMX_H
71
+ #define NK_DOTS_GRANITEAMX_H
72
+
73
+ #if NK_TARGET_X8664_
74
+ #if NK_TARGET_GRANITEAMX
75
+
76
+ #include "numkong/dots/serial.h"
77
+ #include "numkong/dots/sapphireamx.h"
78
+
79
+ #if defined(__cplusplus)
80
+ extern "C" {
81
+ #endif
82
+
83
+ #if defined(__clang__)
84
+ #pragma clang attribute push( \
85
+ __attribute__((target( \
86
+ "avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512vbmi,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8,amx-fp16"))), \
87
+ apply_to = function)
88
+ #elif defined(__GNUC__)
89
+ #pragma GCC push_options
90
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512vbmi", "f16c", "fma", \
91
+ "bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8", "amx-fp16")
92
+ #endif
93
+
94
+ #pragma region Tile Types
95
+
96
/**
 *  A-side tile: 16 rows × 32 FP16 values, row-major with 64 bytes per row, i.e. the memory image
 *  of one TMM register as loaded with `_tile_loadd(..., 64)`.
 */
typedef struct {
    NK_ALIGN64 nk_f16_t data[16][32]; // 16 rows × 32 columns = 1KB
} nk_dots_f16_a16x32_graniteamx_t;

/**
 *  B-side tile in the pair-interleaved layout consumed by TDPFP16PS: each of the 16 rows holds
 *  16 output columns × 2 consecutive depth values, so one tile spans 32 depth values × 16 columns.
 */
typedef struct {
    NK_ALIGN64 nk_f16_t data[16][16][2]; // 16 depth-groups × 16 columns × 2 = 1KB (pair-interleaved)
} nk_dots_f16_b32x16_graniteamx_t;

/** One 16 × 16 F32 accumulator tile spilled to memory (64 bytes per row). */
typedef struct {
    NK_ALIGN64 nk_f32_t data[16][16]; // 16 × 16 = 1KB accumulator
} nk_dots_f16_state_graniteamx_t;

/** Four accumulator tiles covering a 32 × 32 output block, indexed as `c[row_half][column_half]`. */
typedef struct {
    nk_dots_f16_state_graniteamx_t c[2][2]; // 4KB total (2×2 output blocking)
} nk_dots_f16_state2x2_graniteamx_t;
111
+
112
+ #pragma endregion Tile Types
113
+
114
+ #pragma region Helpers
115
+
116
+ /* Initialize FP16 output state to zero */
117
+ NK_INTERNAL void nk_dots_f16_init_graniteamx_(nk_dots_f16_state_graniteamx_t *state) {
118
+ __m512 zero_f32x16 = _mm512_setzero_ps();
119
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_ps(state->data[row_idx], zero_f32x16); }
120
+ }
121
+
122
+ /* Load A tile from FP16 row-major source with masking for edge tiles */
123
+ NK_INTERNAL void nk_dots_f16_load_a_graniteamx_( //
124
+ nk_dots_f16_a16x32_graniteamx_t *a_tile, //
125
+ nk_f16_t const *src, nk_size_t src_stride_elements, //
126
+ nk_size_t valid_rows, nk_size_t valid_cols) {
127
+
128
+ __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
129
+ __m512i zero_i16x32 = _mm512_setzero_si512();
130
+
131
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
132
+ if (row_idx < valid_rows) {
133
+ __m512i row_i16x32 = _mm512_maskz_loadu_epi16(column_mask, src + row_idx * src_stride_elements);
134
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], row_i16x32);
135
+ }
136
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
137
+ }
138
+ nk_compiler_barrier_sapphireamx_();
139
+ }
140
+
141
+ /* Store F32 state to output matrix with masking for edge tiles */
142
+ NK_INTERNAL void nk_dots_f16_store_graniteamx_( //
143
+ nk_dots_f16_state_graniteamx_t const *state, //
144
+ nk_f32_t *dst, nk_size_t dst_stride_elements, //
145
+ nk_size_t valid_rows, nk_size_t valid_cols) {
146
+
147
+ __mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
148
+
149
+ for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
150
+ __m512 row_f32x16 = _mm512_load_ps(state->data[row_idx]);
151
+ _mm512_mask_storeu_ps(dst + row_idx * dst_stride_elements, column_mask, row_f32x16);
152
+ }
153
+ }
154
+
155
+ NK_INTERNAL void nk_dots_f16_output2x2_graniteamx_( //
156
+ nk_dots_f16_state2x2_graniteamx_t const *state, //
157
+ nk_f32_t *dst, nk_size_t dst_stride_elements, //
158
+ nk_size_t valid_rows, nk_size_t valid_cols) {
159
+
160
+ nk_size_t const rows_high = (valid_rows > 16) ? 16 : valid_rows;
161
+ nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
162
+ nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
163
+
164
+ if (rows_high > 0 && cols_left > 0)
165
+ nk_dots_f16_store_graniteamx_(&state->c[0][0], dst, dst_stride_elements, rows_high, cols_left);
166
+ if (rows_high > 0 && cols_right > 0)
167
+ nk_dots_f16_store_graniteamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_high, cols_right);
168
+
169
+ if (valid_rows > 16) {
170
+ nk_size_t const rows_low = valid_rows - 16;
171
+ nk_f32_t *dst_low = dst + 16 * dst_stride_elements;
172
+ if (cols_left > 0)
173
+ nk_dots_f16_store_graniteamx_(&state->c[1][0], dst_low, dst_stride_elements, rows_low, cols_left);
174
+ if (cols_right > 0)
175
+ nk_dots_f16_store_graniteamx_(&state->c[1][1], dst_low + 16, dst_stride_elements, rows_low, cols_right);
176
+ }
177
+ }
178
+
179
+ NK_INTERNAL void nk_dots_f16_update_graniteamx_( //
180
+ nk_dots_f16_state_graniteamx_t *state, //
181
+ nk_dots_f16_a16x32_graniteamx_t const *a_tile_0, //
182
+ nk_dots_f16_a16x32_graniteamx_t const *a_tile_1, //
183
+ nk_dots_f16_a16x32_graniteamx_t const *a_tile_2, //
184
+ nk_dots_f16_b32x16_graniteamx_t const *b_tile_0, //
185
+ nk_dots_f16_b32x16_graniteamx_t const *b_tile_1, //
186
+ nk_dots_f16_b32x16_graniteamx_t const *b_tile_2) {
187
+
188
+ _tile_loadd(0, state->data, 64);
189
+ _tile_loadd(1, a_tile_0->data, 64);
190
+ _tile_loadd(2, a_tile_1->data, 64);
191
+ _tile_loadd(3, a_tile_2->data, 64);
192
+ _tile_loadd(4, b_tile_0->data, 64);
193
+ _tile_loadd(5, b_tile_1->data, 64);
194
+ _tile_loadd(6, b_tile_2->data, 64);
195
+
196
+ _tile_dpfp16ps(0, 1, 4);
197
+ _tile_dpfp16ps(0, 2, 5);
198
+ _tile_dpfp16ps(0, 3, 6);
199
+
200
+ _tile_stored(0, state->data, 64);
201
+ }
202
+
203
+ #pragma endregion Helpers
204
+
205
+ #pragma region F16 Native
206
+
207
+ NK_PUBLIC nk_size_t nk_dots_packed_size_f16_graniteamx(nk_size_t column_count, nk_size_t depth) {
208
+ nk_size_t const tmm_rows = 16;
209
+ nk_size_t const tmm_cols = 32;
210
+ nk_size_t const tile_bytes = 512 * sizeof(nk_f16_t); // 16 × 32 × 2 = 1KB
211
+
212
+ nk_size_t const full_column_tiles = column_count / tmm_rows;
213
+ nk_size_t const tiles_along_depth = nk_size_divide_round_up_(depth, tmm_cols);
214
+ nk_size_t const column_remainder_count = column_count - full_column_tiles * tmm_rows;
215
+
216
+ // Header (64 bytes aligned)
217
+ nk_size_t size = sizeof(nk_dots_amx_packed_header_t);
218
+
219
+ // All tiles for full column rows (pair-interleaved, depth remainder zero-padded)
220
+ size += full_column_tiles * tiles_along_depth * tile_bytes;
221
+
222
+ // Column edge: remaining rows for ALL depth columns, stored row-major
223
+ if (column_remainder_count > 0) size += column_remainder_count * depth * sizeof(nk_f16_t);
224
+
225
+ // Per-column norms for angular/euclidean distance (4 bytes each: f32)
226
+ size += column_count * sizeof(nk_f32_t);
227
+
228
+ return size;
229
+ }
230
+
231
+ NK_PUBLIC void nk_dots_pack_f16_graniteamx( //
232
+ nk_f16_t const *b, nk_size_t column_count, nk_size_t depth, //
233
+ nk_size_t b_stride_in_bytes, void *b_packed) {
234
+
235
+ // AMX FP16 tile dimensions: 16 rows × 32 columns (512 FP16 elements = 1KB)
236
+ nk_size_t const tmm_rows = 16;
237
+ nk_size_t const tmm_cols = 32;
238
+ nk_size_t const tile_elements = 512;
239
+ nk_size_t const tile_bytes = tile_elements * sizeof(nk_f16_t);
240
+ nk_size_t const b_stride_elements = b_stride_in_bytes / sizeof(nk_f16_t);
241
+
242
+ // Compute layout dimensions
243
+ nk_size_t const column_tiles_count = column_count / tmm_rows;
244
+ nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
245
+ nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
246
+ nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
247
+
248
+ // Write header with layout metadata
249
+ nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
250
+ header->full_column_tiles = (nk_u32_t)column_tiles_count;
251
+ header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
252
+ header->column_remainder_count = (nk_u32_t)column_remainder_count;
253
+
254
+ // Compute memory region offsets
255
+ nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
256
+ nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
257
+ header->column_edge_offset = (nk_u32_t)column_edge_offset;
258
+
259
+ // Pointers to packed data regions
260
+ nk_f16_t *tiles_ptr = (nk_f16_t *)((char *)b_packed + tiles_offset);
261
+ nk_f16_t *column_edge_ptr = (nk_f16_t *)((char *)b_packed + column_edge_offset);
262
+
263
+ // Pack tiles: gather 16 strided rows into aligned temporary, transpose via SIMD, copy to packed buffer.
264
+ // FP16 has the same 16-bit pair-interleaved layout as BF16, so reuse the BF16 transpose.
265
+ for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
266
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
267
+
268
+ nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
269
+ nk_f16_t *tile_output = tiles_ptr + tile_index * tile_elements;
270
+
271
+ nk_size_t const src_row_start = column_tile_idx * tmm_rows;
272
+ nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
273
+ nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
274
+ : (depth - src_column_start);
275
+
276
+ // Gather 16 strided source rows into a contiguous aligned tile
277
+ nk_dots_bf16_a16x32_sapphireamx_t source_tile;
278
+ if (columns_to_pack == tmm_cols) {
279
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
280
+ nk_f16_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
281
+ _mm512_store_si512(&source_tile.data[row_idx][0], _mm512_loadu_si512(source_row));
282
+ }
283
+ }
284
+ else {
285
+ __mmask32 depth_mask = (__mmask32)((columns_to_pack < 32) ? ((1U << columns_to_pack) - 1) : ~0U);
286
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
287
+ nk_f16_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
288
+ _mm512_store_si512(&source_tile.data[row_idx][0], _mm512_maskz_loadu_epi16(depth_mask, source_row));
289
+ }
290
+ }
291
+
292
+ // Transpose into aligned local, then copy to (potentially unaligned) packed buffer.
293
+ // BF16 and FP16 share identical 16-bit pair-interleaved layout for TDPBF16PS/TDPFP16PS.
294
+ nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
295
+ nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
296
+ for (nk_size_t i = 0; i < tile_bytes; i += 64)
297
+ _mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
298
+ }
299
+ }
300
+
301
+ // Pack column-remainder rows using vectorized masked copies
302
+ if (column_remainder_count > 0) {
303
+ nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
304
+ for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
305
+ nk_f16_t const *src_row = b + (remainder_start_row + row_idx) * b_stride_elements;
306
+ nk_f16_t *dst_row = column_edge_ptr + row_idx * depth;
307
+ nk_size_t column_idx = 0;
308
+ for (; column_idx + 32 <= depth; column_idx += 32) {
309
+ _mm512_storeu_si512(dst_row + column_idx, _mm512_loadu_si512(src_row + column_idx));
310
+ }
311
+ if (column_idx < depth) {
312
+ __mmask32 tail_mask = (__mmask32)((1U << (depth - column_idx)) - 1);
313
+ _mm512_mask_storeu_epi16(dst_row + column_idx, tail_mask,
314
+ _mm512_maskz_loadu_epi16(tail_mask, src_row + column_idx));
315
+ }
316
+ }
317
+ }
318
+
319
+ // Compute and store per-column norms for angular/euclidean distance
320
+ nk_size_t norms_offset = column_edge_offset +
321
+ (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_f16_t) : 0);
322
+ header->norms_byte_offset = (nk_u32_t)norms_offset;
323
+ nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
324
+ for (nk_size_t col = 0; col < column_count; col++)
325
+ norms[col] = nk_dots_reduce_sumsq_f16_(b + col * b_stride_elements, depth);
326
+ }
327
+
328
/**
 *  @brief FP16 × FP16 → F32 GEMM against a B matrix packed by `nk_dots_pack_f16_graniteamx`:
 *         C[i][j] = Σ_k A[i][k] · B[j][k].
 *
 *  The output is produced in three passes:
 *  1. 32 × 32 blocks (2 × 2 tiles, accumulators TMM4-7) for each pair of full 16-row B tile groups;
 *  2. one 16-column strip (accumulators TMM4 and TMM6) when the number of full B tile groups is odd;
 *  3. the `column_remainder_count` (< 16) leftover B rows, read row-major from the packed buffer and
 *     transposed into B-tile layout on the fly.
 *
 *  @param a               Row-major A matrix with `rows_count` rows of `depth` FP16 values.
 *  @param b_packed        Buffer filled by `nk_dots_pack_f16_graniteamx`; its header supplies the B geometry.
 *  @param c               Row-major F32 output.
 *  @param rows_count      Rows of A and of C.
 *  @param cols_count      Unused; the column count comes from the packed header.
 *  @param depth           Shared dimension; must equal the `depth` used when packing (it is also the row
 *                         length of the packed leftover-row region).
 *  @param a_stride_bytes  Byte distance between rows of A.
 *  @param c_stride_bytes  Byte distance between rows of C.
 *
 *  NOTE(review): when the header reports zero depth tiles (depth == 0) the function returns before
 *  touching C, so C is not zero-filled — confirm callers expect this.
 */
NK_PUBLIC void nk_dots_packed_f16_graniteamx(             //
    nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
    nk_unused_(cols_count);

    // Parse packed B header
    nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
    nk_size_t const column_tiles_count = header->full_column_tiles;
    nk_size_t const depth_tiles_count = header->full_depth_tiles; // ceil(depth / 32), including a padded tail
    nk_size_t const column_remainder_count = header->column_remainder_count;

    // Packed B data regions
    nk_f16_t const *b_tiles_base = (nk_f16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
    nk_f16_t const *col_edge_ptr = (nk_f16_t const *)((char const *)b_packed + header->column_edge_offset);

    // Stride conversions
    nk_size_t const a_stride_elements = a_stride_bytes / sizeof(nk_f16_t);
    nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);

    // Tile dimensions
    nk_size_t const tile_depth = 32;  // FP16 values per tile row
    nk_size_t const tile_size = 512;  // FP16 values per packed B tile
    nk_size_t const full_cols = column_tiles_count * 16; // first output column of the leftover region

    // Block counts (32 × 32 output blocks = 2 × 2 tiles)
    nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
    nk_size_t const col_blocks_count = column_tiles_count / 2;

    if (depth_tiles_count == 0) return;

    // Tile buffers for A (only used for edge tiles)
    nk_dots_f16_a16x32_graniteamx_t a_tile_top, a_tile_bottom;
    nk_dots_f16_state2x2_graniteamx_t c_accum_buffer;

    // Precompute: number of full depth-tiles (no masking needed)
    nk_size_t const full_depth_tiles_count = depth / tile_depth;
    nk_size_t const depth_remainder = depth % tile_depth;

    // NOTE(review): reuses the Sapphire Rapids tile configuration; presumably every TMM register is set
    // to 16 rows × 64 bytes, which matches the FP16 tile geometry here — confirm in sapphireamx.h.
    nk_amx_tile_configure_sapphireamx_();

    // Loop order: row_blocks outer, col_blocks inner - maximizes A tile L2 cache reuse
    for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
        nk_size_t const row_block_start = row_block_idx * 32;
        nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
        nk_size_t const is_full_row_block = (valid_rows_count == 32);

        for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
            nk_size_t const col_block_start = column_block_idx * 32;
            // Packed tiles are grouped per 16-row B group, contiguous along depth.
            nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
            nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;

            // Zero accumulators (TMM4-7 stay resident across entire depth loop)
            _tile_zero(4);
            _tile_zero(5);
            _tile_zero(6);
            _tile_zero(7);

            // Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
            if (is_full_row_block && full_depth_tiles_count > 0) {
                // A is read straight from the caller's matrix: 16 rows of 64 bytes (32 FP16), a_stride_bytes apart.
                nk_f16_t const *a_top_base = a + row_block_start * a_stride_elements;
                nk_f16_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_elements;

                nk_dots_f16_b32x16_graniteamx_t const *b_tile_left =
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
                nk_dots_f16_b32x16_graniteamx_t const *b_tile_right =
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + b_column_right_base * tile_size);

                // Prologue: load first depth tile
                _tile_loadd(0, a_top_base, a_stride_bytes);
                _tile_loadd(1, a_bottom_base, a_stride_bytes);
                _tile_loadd(2, b_tile_left->data, 64);
                _tile_loadd(3, b_tile_right->data, 64);

                // Main loop: 2-deep software pipelining
                // (full_depth_tiles_count >= 1 here, so the `- 1` cannot underflow)
                for (nk_size_t depth_tile_idx = 0; depth_tile_idx < full_depth_tiles_count - 1; depth_tile_idx++) {
                    nk_size_t const next_depth_offset = (depth_tile_idx + 1) * tile_depth;

                    // TMM4 = top×left, TMM5 = top×right, TMM6 = bottom×left, TMM7 = bottom×right
                    _tile_dpfp16ps(4, 0, 2);
                    _tile_dpfp16ps(5, 0, 3);
                    _tile_dpfp16ps(6, 1, 2);
                    _tile_dpfp16ps(7, 1, 3);

                    _tile_loadd(0, a_top_base + next_depth_offset, a_stride_bytes);
                    _tile_loadd(1, a_bottom_base + next_depth_offset, a_stride_bytes);
                    b_tile_left = (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                                           (b_column_left_base + depth_tile_idx + 1) *
                                                                               tile_size);
                    b_tile_right = (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + (b_column_right_base +
                                                                                             depth_tile_idx + 1) *
                                                                                                tile_size);
                    _tile_loadd(2, b_tile_left->data, 64);
                    _tile_loadd(3, b_tile_right->data, 64);
                }

                // Epilogue: final depth tile
                _tile_dpfp16ps(4, 0, 2);
                _tile_dpfp16ps(5, 0, 3);
                _tile_dpfp16ps(6, 1, 2);
                _tile_dpfp16ps(7, 1, 3);

                // Handle partial depth-tile (if any)
                // A is staged through zero-padded buffers so reads stop at `depth`; the packed B tail tile
                // was zero-padded when packing.
                if (depth_remainder > 0) {
                    nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;

                    nk_dots_f16_load_a_graniteamx_(&a_tile_top, a_top_base + depth_offset, a_stride_elements, 16,
                                                   depth_remainder);
                    nk_dots_f16_load_a_graniteamx_(&a_tile_bottom, a_bottom_base + depth_offset, a_stride_elements, 16,
                                                   depth_remainder);

                    b_tile_left = (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + (b_column_left_base +
                                                                                            full_depth_tiles_count) *
                                                                                               tile_size);
                    b_tile_right = (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + (b_column_right_base +
                                                                                             full_depth_tiles_count) *
                                                                                                tile_size);

                    _tile_loadd(0, a_tile_top.data, 64);
                    _tile_loadd(1, a_tile_bottom.data, 64);
                    _tile_loadd(2, b_tile_left->data, 64);
                    _tile_loadd(3, b_tile_right->data, 64);

                    _tile_dpfp16ps(4, 0, 2);
                    _tile_dpfp16ps(5, 0, 3);
                    _tile_dpfp16ps(6, 1, 2);
                    _tile_dpfp16ps(7, 1, 3);
                }
            }
            // Full row-block but only partial depth tile (depth < tile_depth)
            else if (is_full_row_block) {
                nk_f16_t const *a_top_base = a + row_block_start * a_stride_elements;
                nk_f16_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_elements;

                nk_dots_f16_load_a_graniteamx_(&a_tile_top, a_top_base, a_stride_elements, 16, depth_remainder);
                nk_dots_f16_load_a_graniteamx_(&a_tile_bottom, a_bottom_base, a_stride_elements, 16, depth_remainder);

                nk_dots_f16_b32x16_graniteamx_t const *b_tile_left =
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
                nk_dots_f16_b32x16_graniteamx_t const *b_tile_right =
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + b_column_right_base * tile_size);

                _tile_loadd(0, a_tile_top.data, 64);
                _tile_loadd(1, a_tile_bottom.data, 64);
                _tile_loadd(2, b_tile_left->data, 64);
                _tile_loadd(3, b_tile_right->data, 64);

                _tile_dpfp16ps(4, 0, 2);
                _tile_dpfp16ps(5, 0, 3);
                _tile_dpfp16ps(6, 1, 2);
                _tile_dpfp16ps(7, 1, 3);
            }
            // Slow path: edge row-block → buffered load with masking
            else {
                nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
                nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

                for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                    nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                    nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
                                                                                            : depth_remainder;

                    nk_dots_f16_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
                                                   a_stride_elements, rows_in_high_tile, valid_depth);
                    // NOTE(review): when rows_in_low_tile == 0, a_tile_bottom is not refreshed (it may be stale or
                    // never written), so TMM1 and accumulators TMM6/7 hold meaningless values below. They are never
                    // emitted (only the top half is stored when valid_rows_count <= 16), but zeroing the buffer would
                    // avoid reading indeterminate memory and two wasted TDPFP16PS per depth tile.
                    if (rows_in_low_tile > 0) {
                        nk_dots_f16_load_a_graniteamx_(&a_tile_bottom,
                                                       a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                       a_stride_elements, rows_in_low_tile, valid_depth);
                    }

                    nk_dots_f16_b32x16_graniteamx_t const *b_tile_left =
                        (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                                  (b_column_left_base + depth_tile_idx) * tile_size);
                    nk_dots_f16_b32x16_graniteamx_t const *b_tile_right =
                        (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                                  (b_column_right_base + depth_tile_idx) * tile_size);

                    _tile_loadd(0, a_tile_top.data, 64);
                    _tile_loadd(1, a_tile_bottom.data, 64);
                    _tile_loadd(2, b_tile_left->data, 64);
                    _tile_loadd(3, b_tile_right->data, 64);

                    _tile_dpfp16ps(4, 0, 2);
                    _tile_dpfp16ps(5, 0, 3);
                    _tile_dpfp16ps(6, 1, 2);
                    _tile_dpfp16ps(7, 1, 3);
                }
            }

            // Store accumulators to output (once per output block)
            // Full blocks go straight to C; partial blocks are spilled and written with column/row masking.
            if (is_full_row_block) {
                nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
                _tile_stored(4, c_block, c_stride_bytes);
                _tile_stored(5, c_block + 16, c_stride_bytes);
                _tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
                _tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
            }
            else {
                _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
                _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
                _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
                _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
                nk_dots_f16_output2x2_graniteamx_(&c_accum_buffer,
                                                  c + row_block_start * c_stride_elements + col_block_start,
                                                  c_stride_elements, valid_rows_count, 32);
            }
        }
    }

    // Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
    // Always uses the buffered A path; only TMM4 (top rows) and TMM6 (bottom rows) accumulate.
    if (column_tiles_count % 2 == 1) {
        nk_size_t const column_tile_idx = column_tiles_count - 1;
        nk_size_t const col_start = column_tile_idx * 16;
        nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;

        for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
            nk_size_t const row_block_start = row_block_idx * 32;
            nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
                                                                                     : (rows_count - row_block_start);
            nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
            nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

            nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_f16_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
                                               a_stride_elements, rows_in_high_tile, valid_depth);
                // As above: a_tile_bottom is left stale when there is no lower half; TMM6 is then discarded.
                if (rows_in_low_tile > 0) {
                    nk_dots_f16_load_a_graniteamx_(&a_tile_bottom,
                                                   a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                   a_stride_elements, rows_in_low_tile, valid_depth);
                }

                nk_dots_f16_b32x16_graniteamx_t const *b_tile =
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                              (b_column_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_top.data, 64);
                _tile_loadd(1, a_tile_bottom.data, 64);
                _tile_loadd(2, b_tile->data, 64);

                _tile_dpfp16ps(4, 0, 2);
                _tile_dpfp16ps(6, 1, 2);
            }

            _tile_stored(4, c_high_state.data, 64);
            _tile_stored(6, c_low_state.data, 64);

            nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
                                          c_stride_elements, rows_in_high_tile, 16);
            if (rows_in_low_tile > 0) {
                nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + col_start,
                                              c_stride_elements, rows_in_low_tile, 16);
            }
        }
    }

    // Handle column-edge (remaining columns < 16) using AMX with partial tiles
    if (column_remainder_count > 0) {
        for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
            nk_size_t const row_block_start = row_block_idx * 32;
            nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
                                                                                     : (rows_count - row_block_start);
            nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
            nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

            nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;
            nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
            nk_dots_bf16_b32x16_sapphireamx_t b_tile;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_f16_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
                                               a_stride_elements, rows_in_high_tile, valid_depth);
                if (rows_in_low_tile > 0) {
                    nk_dots_f16_load_a_graniteamx_(&a_tile_bottom,
                                                   a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                   a_stride_elements, rows_in_low_tile, valid_depth);
                }

                // Load edge columns as BF16-shaped tile (same 16-bit layout) and transpose on-the-fly
                // The leftover B rows are stored row-major with row length `depth`, hence the `depth` stride.
                // NOTE(review): assumes the BF16 loader zero-fills rows >= column_remainder_count and lanes
                // >= valid_depth, like the FP16 loader above — confirm in sapphireamx.h.
                nk_dots_bf16_load_a_sapphireamx_(&b_as_a, (nk_bf16_t const *)(col_edge_ptr + depth_offset), depth,
                                                 column_remainder_count, valid_depth);
                nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);

                _tile_loadd(0, a_tile_top.data, 64);
                _tile_loadd(1, a_tile_bottom.data, 64);
                _tile_loadd(2, b_tile.data, 64);

                _tile_dpfp16ps(4, 0, 2);
                _tile_dpfp16ps(6, 1, 2);
            }

            _tile_stored(4, c_high_state.data, 64);
            _tile_stored(6, c_low_state.data, 64);

            nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
                                          c_stride_elements, rows_in_high_tile, column_remainder_count);
            if (rows_in_low_tile > 0) {
                nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + full_cols,
                                              c_stride_elements, rows_in_low_tile, column_remainder_count);
            }
        }
    }

    _tile_release();
}
644
+
645
+ NK_PUBLIC void nk_dots_symmetric_f16_graniteamx( //
646
+ nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
647
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
648
+ nk_size_t row_start, nk_size_t row_count) {
649
+
650
+ nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
651
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
652
+
653
+ nk_size_t const row_end = (row_count == 0)
654
+ ? vectors_count
655
+ : (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
656
+
657
+ // Round depth up to multiple of 96 (3 tiles × 32 elements)
658
+ nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
659
+ nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
660
+
661
+ nk_dots_f16_a16x32_graniteamx_t a_tiles[3];
662
+ nk_dots_f16_a16x32_graniteamx_t b_src_tiles[3];
663
+ nk_dots_f16_b32x16_graniteamx_t b_tiles[3];
664
+ nk_dots_f16_state_graniteamx_t state;
665
+
666
+ nk_amx_tile_configure_sapphireamx_();
667
+
668
+ for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
669
+ nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
670
+
671
+ for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
672
+ nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
673
+
674
+ nk_dots_f16_init_graniteamx_(&state);
675
+
676
+ for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
677
+ nk_size_t const depth_base = depth_group_idx * 96;
678
+
679
+ for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
680
+ nk_size_t const depth_start = depth_base + tile_idx * 32;
681
+ nk_size_t const valid_depth = (depth_start + 32 <= depth)
682
+ ? 32
683
+ : (depth > depth_start ? depth - depth_start : 0);
684
+
685
+ nk_dots_f16_load_a_graniteamx_( //
686
+ &a_tiles[tile_idx], //
687
+ vectors + row_tile * stride_elements + depth_start, //
688
+ stride_elements, valid_rows, valid_depth);
689
+
690
+ if (row_tile == col_tile) {
691
+ // Reuse A data as B (self-correlation on diagonal)
692
+ nk_dots_pack_bf16_transposed_sapphireamx_(
693
+ (nk_dots_bf16_a16x32_sapphireamx_t const *)&a_tiles[tile_idx],
694
+ (nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
695
+ }
696
+ else {
697
+ nk_dots_f16_load_a_graniteamx_( //
698
+ &b_src_tiles[tile_idx], //
699
+ vectors + col_tile * stride_elements + depth_start, //
700
+ stride_elements, valid_cols, valid_depth);
701
+ nk_dots_pack_bf16_transposed_sapphireamx_(
702
+ (nk_dots_bf16_a16x32_sapphireamx_t const *)&b_src_tiles[tile_idx],
703
+ (nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
704
+ }
705
+ }
706
+
707
+ nk_dots_f16_update_graniteamx_( //
708
+ &state, &a_tiles[0], &a_tiles[1], &a_tiles[2], //
709
+ &b_tiles[0], &b_tiles[1], &b_tiles[2]);
710
+ }
711
+
712
+ nk_dots_f16_store_graniteamx_( //
713
+ &state, result + row_tile * result_stride_elements + col_tile, //
714
+ result_stride_elements, valid_rows, valid_cols);
715
+ }
716
+ }
717
+ }
718
+
719
+ #pragma endregion F16 Native
720
+
721
+ #if defined(__clang__)
722
+ #pragma clang attribute pop
723
+ #elif defined(__GNUC__)
724
+ #pragma GCC pop_options
725
+ #endif
726
+
727
+ #if defined(__cplusplus)
728
+ } // extern "C"
729
+ #endif
730
+
731
+ #endif // NK_TARGET_GRANITEAMX
732
+ #endif // NK_TARGET_X8664_
733
+ #endif // NK_DOTS_GRANITEAMX_H