numkong 7.4.5 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +99 -5
- package/c/dispatch_e5m2.c +23 -3
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +1167 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +33 -11
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +41 -3
- package/include/numkong/each/serial.h +39 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +15 -4
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/maxsim/sme.h +34 -33
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +225 -161
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +301 -0
- package/include/numkong/spatials/serial.h +39 -0
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +54 -4
- package/include/numkong/tensor.hpp +107 -23
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +59 -14
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -0,0 +1,1167 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Batched Dot Products for Granite Rapids.
|
|
3
|
+
* @file include/numkong/dots/graniteamx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date April 9, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dots.h
|
|
8
|
+
*
|
|
9
|
+
* Native FP16×FP16→FP32 GEMM kernels using Intel AMX-FP16 (TDPFP16PS) on Granite Rapids CPUs.
|
|
10
|
+
* Same tile geometry as BF16 (16 rows × 32 FP16 = 1KB per tile), same 2×2 output blocking,
|
|
11
|
+
* same packing format — only the tile multiply instruction differs.
|
|
12
|
+
*
|
|
13
|
+
* Tile register allocation:
|
|
14
|
+
*
|
|
15
|
+
* - TMM0, TMM1: A matrix tiles (row blocks i and i+16)
|
|
16
|
+
* - TMM2, TMM3: B matrix tiles (column blocks j and j+16)
|
|
17
|
+
* - TMM4-7: C accumulator tiles (2 × 2 output grid = 32×32 F32 results)
|
|
18
|
+
*
|
|
19
|
+
* @section amx_fp16_instructions Intel AMX-FP16 Instructions (Granite Rapids+)
|
|
20
|
+
*
|
|
21
|
+
* FP16 matrix multiply (AMX-FP16):
|
|
22
|
+
*
|
|
23
|
+
* Intrinsic Instruction Operation
|
|
24
|
+
* _tile_dpfp16ps TDPFP16PS (TMM, TMM, TMM) C += A × B (fp16 → f32)
|
|
25
|
+
*
|
|
26
|
+
* TDPFP16PS: 16 × 16 × 32 = 8192 FP16 MACs per instruction (same throughput as TDPBF16PS).
|
|
27
|
+
*
|
|
28
|
+
* @section ozaki_limitations F32→F64 via Ozaki Scheme — Attempted and Abandoned
|
|
29
|
+
*
|
|
30
|
+
* We explored using AMX-FP16 tiles to compute F32→F64 GEMMs via the Ozaki decomposition scheme,
|
|
31
|
+
* splitting each F32 scalar into 2 or 3 FP16 terms and performing cross-product TDPFP16PS operations.
|
|
32
|
+
*
|
|
33
|
+
* Results on Intel Xeon 6776P (Granite Rapids), single-threaded:
|
|
34
|
+
*
|
|
35
|
+
* | Variant | Speed (gso/s) | Precision | Notes |
|
|
36
|
+
* |----------------------|:-------------:|:---------:|:------------------------------------------:|
|
|
37
|
+
* | 2-term, 2×1 blocking | ~150 | ~22 bits | Split accumulators, N=16 F64 flush |
|
|
38
|
+
* | 3-term, 1×1 blocking | ~110 | ~22 bits | 3 accumulators by magnitude band |
|
|
39
|
+
* | Pipelined 2-term | ~156 | ~22 bits | Double-buffered A split, AMX/AVX-512 overlap|
|
|
40
|
+
* | MKL SGEMM | ~170 | ~20 bits | Pure F32, no decomposition |
|
|
41
|
+
* | Skylake F64 accum | ~50 | ~48 bits | F32×F32 multiply, F64 accumulation |
|
|
42
|
+
*
|
|
43
|
+
* The fundamental bottleneck is TDPFP16PS's internal F32 accumulation: each instruction sums
|
|
44
|
+
* 32 FP16×FP16 products into an F32 register (23-bit mantissa). Even with Ozaki cross-term
|
|
45
|
+
* separation into distinct TMM accumulators (preventing magnitude mixing) and periodic extraction
|
|
46
|
+
* to F64 running sums, the per-instruction accumulation of 32 products loses ~5 bits
|
|
47
|
+
* (log2(32) = 5), capping effective precision at ~28 - 5 = ~23 bits — barely exceeding F32 BLAS.
|
|
48
|
+
*
|
|
49
|
+
* Approaches attempted:
|
|
50
|
+
*
|
|
51
|
+
* - 2-term decomposition (a = a_high + a_low): 4 TDPFP16PS per depth tile, ~20-bit products.
|
|
52
|
+
* With split accumulators (main + correction) merged in F64: ~22-bit effective precision.
|
|
53
|
+
* Faster than MKL at small depths (≤512) but precision plateaus at ~22 bits.
|
|
54
|
+
*
|
|
55
|
+
* - 3-term decomposition (a = a_high + a_mid + a_low): 6 TDPFP16PS per depth tile, ~30-bit
|
|
56
|
+
* products. No precision improvement over 2-term because the F32 TMM accumulation is the
|
|
57
|
+
* bottleneck, not the decomposition quality. Strictly slower and no more precise.
|
|
58
|
+
*
|
|
59
|
+
* - Periodic F64 flush (extract TMM accumulators to F64 every N depth tiles): prevents precision
|
|
60
|
+
* degradation at large depths. With N=16, ~15% overhead. Effective precision still ~24 bits
|
|
61
|
+
* (limited by per-TDPFP16PS accumulation of 32 products, not by inter-tile accumulation).
|
|
62
|
+
*
|
|
63
|
+
* - AMX/AVX-512 pipelining (double-buffered A splitting overlapped with AMX compute): ~7%
|
|
64
|
+
* speedup at large sizes. Does not affect precision.
|
|
65
|
+
*
|
|
66
|
+
* Conclusion: AMX-FP16's F32 tile accumulation fundamentally limits Ozaki to ~22-24 bits —
|
|
67
|
+
* comparable to F32 BLAS, far short of the ~48-bit F64 precision needed to justify the complexity.
|
|
68
|
+
* For F32→F64 GEMM, pure AVX-512 with F64 FMA remains the correct approach.
|
|
69
|
+
*/
|
|
70
|
+
#ifndef NK_DOTS_GRANITEAMX_H
|
|
71
|
+
#define NK_DOTS_GRANITEAMX_H
|
|
72
|
+
|
|
73
|
+
#if NK_TARGET_X8664_
|
|
74
|
+
#if NK_TARGET_GRANITEAMX
|
|
75
|
+
|
|
76
|
+
#include "numkong/dots/serial.h"
|
|
77
|
+
#include "numkong/dots/sapphireamx.h"
|
|
78
|
+
|
|
79
|
+
#if defined(__cplusplus)
|
|
80
|
+
extern "C" {
|
|
81
|
+
#endif
|
|
82
|
+
|
|
83
|
+
#if defined(__clang__)
|
|
84
|
+
#pragma clang attribute push( \
|
|
85
|
+
__attribute__((target( \
|
|
86
|
+
"avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512vbmi,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8,amx-fp16"))), \
|
|
87
|
+
apply_to = function)
|
|
88
|
+
#elif defined(__GNUC__)
|
|
89
|
+
#pragma GCC push_options
|
|
90
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512vbmi", "f16c", "fma", \
|
|
91
|
+
"bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8", "amx-fp16")
|
|
92
|
+
#endif
|
|
93
|
+
|
|
94
|
+
#pragma region Tile Types
|
|
95
|
+
|
|
96
|
+
typedef struct {
|
|
97
|
+
NK_ALIGN64 nk_f16_t data[16][32]; // 16 rows × 32 columns = 1KB
|
|
98
|
+
} nk_dots_f16_a16x32_graniteamx_t;
|
|
99
|
+
|
|
100
|
+
typedef struct {
|
|
101
|
+
NK_ALIGN64 nk_f16_t data[16][16][2]; // 16 depth-groups × 16 columns × 2 = 1KB (pair-interleaved)
|
|
102
|
+
} nk_dots_f16_b32x16_graniteamx_t;
|
|
103
|
+
|
|
104
|
+
typedef struct {
|
|
105
|
+
NK_ALIGN64 nk_f32_t data[16][16]; // 16 × 16 = 1KB accumulator
|
|
106
|
+
} nk_dots_f16_state_graniteamx_t;
|
|
107
|
+
|
|
108
|
+
typedef struct {
|
|
109
|
+
nk_dots_f16_state_graniteamx_t c[2][2]; // 4KB total (2×2 output blocking)
|
|
110
|
+
} nk_dots_f16_state2x2_graniteamx_t;
|
|
111
|
+
|
|
112
|
+
#pragma endregion Tile Types
|
|
113
|
+
|
|
114
|
+
#pragma region Helpers
|
|
115
|
+
|
|
116
|
+
/* Initialize FP16 output state to zero */
|
|
117
|
+
NK_INTERNAL void nk_dots_f16_init_graniteamx_(nk_dots_f16_state_graniteamx_t *state) {
|
|
118
|
+
__m512 zero_f32x16 = _mm512_setzero_ps();
|
|
119
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_ps(state->data[row_idx], zero_f32x16); }
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
/* Load A tile from FP16 row-major source with masking for edge tiles */
|
|
123
|
+
NK_INTERNAL void nk_dots_f16_load_a_graniteamx_( //
|
|
124
|
+
nk_dots_f16_a16x32_graniteamx_t *a_tile, //
|
|
125
|
+
nk_f16_t const *src, nk_size_t src_stride_elements, //
|
|
126
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
127
|
+
|
|
128
|
+
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
129
|
+
__m512i zero_i16x32 = _mm512_setzero_si512();
|
|
130
|
+
|
|
131
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
132
|
+
if (row_idx < valid_rows) {
|
|
133
|
+
__m512i row_i16x32 = _mm512_maskz_loadu_epi16(column_mask, src + row_idx * src_stride_elements);
|
|
134
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], row_i16x32);
|
|
135
|
+
}
|
|
136
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
|
|
137
|
+
}
|
|
138
|
+
nk_compiler_barrier_sapphireamx_();
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
/* Store F32 state to output matrix with masking for edge tiles */
|
|
142
|
+
NK_INTERNAL void nk_dots_f16_store_graniteamx_( //
|
|
143
|
+
nk_dots_f16_state_graniteamx_t const *state, //
|
|
144
|
+
nk_f32_t *dst, nk_size_t dst_stride_elements, //
|
|
145
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
146
|
+
|
|
147
|
+
__mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
|
|
148
|
+
|
|
149
|
+
for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
|
|
150
|
+
__m512 row_f32x16 = _mm512_load_ps(state->data[row_idx]);
|
|
151
|
+
_mm512_mask_storeu_ps(dst + row_idx * dst_stride_elements, column_mask, row_f32x16);
|
|
152
|
+
}
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
NK_INTERNAL void nk_dots_f16_output2x2_graniteamx_( //
|
|
156
|
+
nk_dots_f16_state2x2_graniteamx_t const *state, //
|
|
157
|
+
nk_f32_t *dst, nk_size_t dst_stride_elements, //
|
|
158
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
159
|
+
|
|
160
|
+
nk_size_t const rows_high = (valid_rows > 16) ? 16 : valid_rows;
|
|
161
|
+
nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
|
|
162
|
+
nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
|
|
163
|
+
|
|
164
|
+
if (rows_high > 0 && cols_left > 0)
|
|
165
|
+
nk_dots_f16_store_graniteamx_(&state->c[0][0], dst, dst_stride_elements, rows_high, cols_left);
|
|
166
|
+
if (rows_high > 0 && cols_right > 0)
|
|
167
|
+
nk_dots_f16_store_graniteamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_high, cols_right);
|
|
168
|
+
|
|
169
|
+
if (valid_rows > 16) {
|
|
170
|
+
nk_size_t const rows_low = valid_rows - 16;
|
|
171
|
+
nk_f32_t *dst_low = dst + 16 * dst_stride_elements;
|
|
172
|
+
if (cols_left > 0)
|
|
173
|
+
nk_dots_f16_store_graniteamx_(&state->c[1][0], dst_low, dst_stride_elements, rows_low, cols_left);
|
|
174
|
+
if (cols_right > 0)
|
|
175
|
+
nk_dots_f16_store_graniteamx_(&state->c[1][1], dst_low + 16, dst_stride_elements, rows_low, cols_right);
|
|
176
|
+
}
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
NK_INTERNAL void nk_dots_f16_update_graniteamx_( //
|
|
180
|
+
nk_dots_f16_state_graniteamx_t *state, //
|
|
181
|
+
nk_dots_f16_a16x32_graniteamx_t const *a_tile_0, //
|
|
182
|
+
nk_dots_f16_a16x32_graniteamx_t const *a_tile_1, //
|
|
183
|
+
nk_dots_f16_a16x32_graniteamx_t const *a_tile_2, //
|
|
184
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile_0, //
|
|
185
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile_1, //
|
|
186
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile_2) {
|
|
187
|
+
|
|
188
|
+
_tile_loadd(0, state->data, 64);
|
|
189
|
+
_tile_loadd(1, a_tile_0->data, 64);
|
|
190
|
+
_tile_loadd(2, a_tile_1->data, 64);
|
|
191
|
+
_tile_loadd(3, a_tile_2->data, 64);
|
|
192
|
+
_tile_loadd(4, b_tile_0->data, 64);
|
|
193
|
+
_tile_loadd(5, b_tile_1->data, 64);
|
|
194
|
+
_tile_loadd(6, b_tile_2->data, 64);
|
|
195
|
+
|
|
196
|
+
_tile_dpfp16ps(0, 1, 4);
|
|
197
|
+
_tile_dpfp16ps(0, 2, 5);
|
|
198
|
+
_tile_dpfp16ps(0, 3, 6);
|
|
199
|
+
|
|
200
|
+
_tile_stored(0, state->data, 64);
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
#pragma endregion Helpers
|
|
204
|
+
|
|
205
|
+
#pragma region F16 Native
|
|
206
|
+
|
|
207
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_graniteamx(nk_size_t column_count, nk_size_t depth) {
|
|
208
|
+
nk_size_t const tmm_rows = 16;
|
|
209
|
+
nk_size_t const tmm_cols = 32;
|
|
210
|
+
nk_size_t const tile_bytes = 512 * sizeof(nk_f16_t); // 16 × 32 × 2 = 1KB
|
|
211
|
+
|
|
212
|
+
nk_size_t const full_column_tiles = column_count / tmm_rows;
|
|
213
|
+
nk_size_t const tiles_along_depth = nk_size_divide_round_up_(depth, tmm_cols);
|
|
214
|
+
nk_size_t const column_remainder_count = column_count - full_column_tiles * tmm_rows;
|
|
215
|
+
|
|
216
|
+
// Header (64 bytes aligned)
|
|
217
|
+
nk_size_t size = sizeof(nk_dots_amx_packed_header_t);
|
|
218
|
+
|
|
219
|
+
// All tiles for full column rows (pair-interleaved, depth remainder zero-padded)
|
|
220
|
+
size += full_column_tiles * tiles_along_depth * tile_bytes;
|
|
221
|
+
|
|
222
|
+
// Column edge: remaining rows for ALL depth columns, stored row-major
|
|
223
|
+
if (column_remainder_count > 0) size += column_remainder_count * depth * sizeof(nk_f16_t);
|
|
224
|
+
|
|
225
|
+
// Per-column norms for angular/euclidean distance (4 bytes each: f32)
|
|
226
|
+
size += column_count * sizeof(nk_f32_t);
|
|
227
|
+
|
|
228
|
+
return size;
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
NK_PUBLIC void nk_dots_pack_f16_graniteamx( //
|
|
232
|
+
nk_f16_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
233
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
234
|
+
|
|
235
|
+
// AMX FP16 tile dimensions: 16 rows × 32 columns (512 FP16 elements = 1KB)
|
|
236
|
+
nk_size_t const tmm_rows = 16;
|
|
237
|
+
nk_size_t const tmm_cols = 32;
|
|
238
|
+
nk_size_t const tile_elements = 512;
|
|
239
|
+
nk_size_t const tile_bytes = tile_elements * sizeof(nk_f16_t);
|
|
240
|
+
nk_size_t const b_stride_elements = b_stride_in_bytes / sizeof(nk_f16_t);
|
|
241
|
+
|
|
242
|
+
// Compute layout dimensions
|
|
243
|
+
nk_size_t const column_tiles_count = column_count / tmm_rows;
|
|
244
|
+
nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
|
|
245
|
+
nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
|
|
246
|
+
nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
|
|
247
|
+
|
|
248
|
+
// Write header with layout metadata
|
|
249
|
+
nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
|
|
250
|
+
header->full_column_tiles = (nk_u32_t)column_tiles_count;
|
|
251
|
+
header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
|
|
252
|
+
header->column_remainder_count = (nk_u32_t)column_remainder_count;
|
|
253
|
+
|
|
254
|
+
// Compute memory region offsets
|
|
255
|
+
nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
|
|
256
|
+
nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
|
|
257
|
+
header->column_edge_offset = (nk_u32_t)column_edge_offset;
|
|
258
|
+
|
|
259
|
+
// Pointers to packed data regions
|
|
260
|
+
nk_f16_t *tiles_ptr = (nk_f16_t *)((char *)b_packed + tiles_offset);
|
|
261
|
+
nk_f16_t *column_edge_ptr = (nk_f16_t *)((char *)b_packed + column_edge_offset);
|
|
262
|
+
|
|
263
|
+
// Pack tiles: gather 16 strided rows into aligned temporary, transpose via SIMD, copy to packed buffer.
|
|
264
|
+
// FP16 has the same 16-bit pair-interleaved layout as BF16, so reuse the BF16 transpose.
|
|
265
|
+
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
266
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
267
|
+
|
|
268
|
+
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
269
|
+
nk_f16_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
270
|
+
|
|
271
|
+
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
272
|
+
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
273
|
+
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
274
|
+
: (depth - src_column_start);
|
|
275
|
+
|
|
276
|
+
// Gather 16 strided source rows into a contiguous aligned tile
|
|
277
|
+
nk_dots_bf16_a16x32_sapphireamx_t source_tile;
|
|
278
|
+
if (columns_to_pack == tmm_cols) {
|
|
279
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
280
|
+
nk_f16_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
|
|
281
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], _mm512_loadu_si512(source_row));
|
|
282
|
+
}
|
|
283
|
+
}
|
|
284
|
+
else {
|
|
285
|
+
__mmask32 depth_mask = (__mmask32)((columns_to_pack < 32) ? ((1U << columns_to_pack) - 1) : ~0U);
|
|
286
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
287
|
+
nk_f16_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
|
|
288
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], _mm512_maskz_loadu_epi16(depth_mask, source_row));
|
|
289
|
+
}
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
// Transpose into aligned local, then copy to (potentially unaligned) packed buffer.
|
|
293
|
+
// BF16 and FP16 share identical 16-bit pair-interleaved layout for TDPBF16PS/TDPFP16PS.
|
|
294
|
+
nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
|
|
295
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
|
|
296
|
+
for (nk_size_t i = 0; i < tile_bytes; i += 64)
|
|
297
|
+
_mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
|
|
298
|
+
}
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
// Pack column-remainder rows using vectorized masked copies
|
|
302
|
+
if (column_remainder_count > 0) {
|
|
303
|
+
nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
|
|
304
|
+
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
305
|
+
nk_f16_t const *src_row = b + (remainder_start_row + row_idx) * b_stride_elements;
|
|
306
|
+
nk_f16_t *dst_row = column_edge_ptr + row_idx * depth;
|
|
307
|
+
nk_size_t column_idx = 0;
|
|
308
|
+
for (; column_idx + 32 <= depth; column_idx += 32) {
|
|
309
|
+
_mm512_storeu_si512(dst_row + column_idx, _mm512_loadu_si512(src_row + column_idx));
|
|
310
|
+
}
|
|
311
|
+
if (column_idx < depth) {
|
|
312
|
+
__mmask32 tail_mask = (__mmask32)((1U << (depth - column_idx)) - 1);
|
|
313
|
+
_mm512_mask_storeu_epi16(dst_row + column_idx, tail_mask,
|
|
314
|
+
_mm512_maskz_loadu_epi16(tail_mask, src_row + column_idx));
|
|
315
|
+
}
|
|
316
|
+
}
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
// Compute and store per-column norms for angular/euclidean distance
|
|
320
|
+
nk_size_t norms_offset = column_edge_offset +
|
|
321
|
+
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_f16_t) : 0);
|
|
322
|
+
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
323
|
+
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
324
|
+
for (nk_size_t col = 0; col < column_count; col++)
|
|
325
|
+
norms[col] = nk_dots_reduce_sumsq_f16_(b + col * b_stride_elements, depth);
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
NK_PUBLIC void nk_dots_packed_f16_graniteamx( //
|
|
329
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
330
|
+
nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
|
|
331
|
+
nk_unused_(cols_count);
|
|
332
|
+
|
|
333
|
+
// Parse packed B header
|
|
334
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
335
|
+
nk_size_t const column_tiles_count = header->full_column_tiles;
|
|
336
|
+
nk_size_t const depth_tiles_count = header->full_depth_tiles;
|
|
337
|
+
nk_size_t const column_remainder_count = header->column_remainder_count;
|
|
338
|
+
|
|
339
|
+
// Packed B data regions
|
|
340
|
+
nk_f16_t const *b_tiles_base = (nk_f16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
|
|
341
|
+
nk_f16_t const *col_edge_ptr = (nk_f16_t const *)((char const *)b_packed + header->column_edge_offset);
|
|
342
|
+
|
|
343
|
+
// Stride conversions
|
|
344
|
+
nk_size_t const a_stride_elements = a_stride_bytes / sizeof(nk_f16_t);
|
|
345
|
+
nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
|
|
346
|
+
|
|
347
|
+
// Tile dimensions
|
|
348
|
+
nk_size_t const tile_depth = 32;
|
|
349
|
+
nk_size_t const tile_size = 512;
|
|
350
|
+
nk_size_t const full_cols = column_tiles_count * 16;
|
|
351
|
+
|
|
352
|
+
// Block counts (32 × 32 output blocks = 2 × 2 tiles)
|
|
353
|
+
nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
|
|
354
|
+
nk_size_t const col_blocks_count = column_tiles_count / 2;
|
|
355
|
+
|
|
356
|
+
if (depth_tiles_count == 0) return;
|
|
357
|
+
|
|
358
|
+
// Tile buffers for A (only used for edge tiles)
|
|
359
|
+
nk_dots_f16_a16x32_graniteamx_t a_tile_top, a_tile_bottom;
|
|
360
|
+
nk_dots_f16_state2x2_graniteamx_t c_accum_buffer;
|
|
361
|
+
|
|
362
|
+
// Precompute: number of full depth-tiles (no masking needed)
|
|
363
|
+
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
364
|
+
nk_size_t const depth_remainder = depth % tile_depth;
|
|
365
|
+
|
|
366
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
367
|
+
|
|
368
|
+
// Loop order: row_blocks outer, col_blocks inner - maximizes A tile L2 cache reuse
|
|
369
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
370
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
371
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
372
|
+
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
373
|
+
|
|
374
|
+
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
375
|
+
nk_size_t const col_block_start = column_block_idx * 32;
|
|
376
|
+
nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
|
|
377
|
+
nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
|
|
378
|
+
|
|
379
|
+
// Zero accumulators (TMM4-7 stay resident across entire depth loop)
|
|
380
|
+
_tile_zero(4);
|
|
381
|
+
_tile_zero(5);
|
|
382
|
+
_tile_zero(6);
|
|
383
|
+
_tile_zero(7);
|
|
384
|
+
|
|
385
|
+
// Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
|
|
386
|
+
if (is_full_row_block && full_depth_tiles_count > 0) {
|
|
387
|
+
nk_f16_t const *a_top_base = a + row_block_start * a_stride_elements;
|
|
388
|
+
nk_f16_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_elements;
|
|
389
|
+
|
|
390
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile_left =
|
|
391
|
+
(nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
392
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile_right =
|
|
393
|
+
(nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
394
|
+
|
|
395
|
+
// Prologue: load first depth tile
|
|
396
|
+
_tile_loadd(0, a_top_base, a_stride_bytes);
|
|
397
|
+
_tile_loadd(1, a_bottom_base, a_stride_bytes);
|
|
398
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
399
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
400
|
+
|
|
401
|
+
// Main loop: 2-deep software pipelining
|
|
402
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < full_depth_tiles_count - 1; depth_tile_idx++) {
|
|
403
|
+
nk_size_t const next_depth_offset = (depth_tile_idx + 1) * tile_depth;
|
|
404
|
+
|
|
405
|
+
_tile_dpfp16ps(4, 0, 2);
|
|
406
|
+
_tile_dpfp16ps(5, 0, 3);
|
|
407
|
+
_tile_dpfp16ps(6, 1, 2);
|
|
408
|
+
_tile_dpfp16ps(7, 1, 3);
|
|
409
|
+
|
|
410
|
+
_tile_loadd(0, a_top_base + next_depth_offset, a_stride_bytes);
|
|
411
|
+
_tile_loadd(1, a_bottom_base + next_depth_offset, a_stride_bytes);
|
|
412
|
+
b_tile_left = (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
|
|
413
|
+
(b_column_left_base + depth_tile_idx + 1) *
|
|
414
|
+
tile_size);
|
|
415
|
+
b_tile_right = (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + (b_column_right_base +
|
|
416
|
+
depth_tile_idx + 1) *
|
|
417
|
+
tile_size);
|
|
418
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
419
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
420
|
+
}
|
|
421
|
+
|
|
422
|
+
// Epilogue: final depth tile
|
|
423
|
+
_tile_dpfp16ps(4, 0, 2);
|
|
424
|
+
_tile_dpfp16ps(5, 0, 3);
|
|
425
|
+
_tile_dpfp16ps(6, 1, 2);
|
|
426
|
+
_tile_dpfp16ps(7, 1, 3);
|
|
427
|
+
|
|
428
|
+
// Handle partial depth-tile (if any)
|
|
429
|
+
if (depth_remainder > 0) {
|
|
430
|
+
nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
|
|
431
|
+
|
|
432
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_top, a_top_base + depth_offset, a_stride_elements, 16,
|
|
433
|
+
depth_remainder);
|
|
434
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_bottom, a_bottom_base + depth_offset, a_stride_elements, 16,
|
|
435
|
+
depth_remainder);
|
|
436
|
+
|
|
437
|
+
b_tile_left = (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + (b_column_left_base +
|
|
438
|
+
full_depth_tiles_count) *
|
|
439
|
+
tile_size);
|
|
440
|
+
b_tile_right = (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + (b_column_right_base +
|
|
441
|
+
full_depth_tiles_count) *
|
|
442
|
+
tile_size);
|
|
443
|
+
|
|
444
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
445
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
446
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
447
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
448
|
+
|
|
449
|
+
_tile_dpfp16ps(4, 0, 2);
|
|
450
|
+
_tile_dpfp16ps(5, 0, 3);
|
|
451
|
+
_tile_dpfp16ps(6, 1, 2);
|
|
452
|
+
_tile_dpfp16ps(7, 1, 3);
|
|
453
|
+
}
|
|
454
|
+
}
|
|
455
|
+
// Full row-block but only partial depth tile (depth < tile_depth)
|
|
456
|
+
else if (is_full_row_block) {
|
|
457
|
+
nk_f16_t const *a_top_base = a + row_block_start * a_stride_elements;
|
|
458
|
+
nk_f16_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_elements;
|
|
459
|
+
|
|
460
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_top, a_top_base, a_stride_elements, 16, depth_remainder);
|
|
461
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_bottom, a_bottom_base, a_stride_elements, 16, depth_remainder);
|
|
462
|
+
|
|
463
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile_left =
|
|
464
|
+
(nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
465
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile_right =
|
|
466
|
+
(nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
467
|
+
|
|
468
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
469
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
470
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
471
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
472
|
+
|
|
473
|
+
_tile_dpfp16ps(4, 0, 2);
|
|
474
|
+
_tile_dpfp16ps(5, 0, 3);
|
|
475
|
+
_tile_dpfp16ps(6, 1, 2);
|
|
476
|
+
_tile_dpfp16ps(7, 1, 3);
|
|
477
|
+
}
|
|
478
|
+
// Slow path: edge row-block → buffered load with masking
|
|
479
|
+
else {
|
|
480
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
481
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
482
|
+
|
|
483
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
484
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
485
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
|
|
486
|
+
: depth_remainder;
|
|
487
|
+
|
|
488
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
|
|
489
|
+
a_stride_elements, rows_in_high_tile, valid_depth);
|
|
490
|
+
if (rows_in_low_tile > 0) {
|
|
491
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_bottom,
|
|
492
|
+
a + (row_block_start + 16) * a_stride_elements + depth_offset,
|
|
493
|
+
a_stride_elements, rows_in_low_tile, valid_depth);
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile_left =
|
|
497
|
+
(nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
|
|
498
|
+
(b_column_left_base + depth_tile_idx) * tile_size);
|
|
499
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile_right =
|
|
500
|
+
(nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
|
|
501
|
+
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
502
|
+
|
|
503
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
504
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
505
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
506
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
507
|
+
|
|
508
|
+
_tile_dpfp16ps(4, 0, 2);
|
|
509
|
+
_tile_dpfp16ps(5, 0, 3);
|
|
510
|
+
_tile_dpfp16ps(6, 1, 2);
|
|
511
|
+
_tile_dpfp16ps(7, 1, 3);
|
|
512
|
+
}
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
// Store accumulators to output (once per output block)
|
|
516
|
+
if (is_full_row_block) {
|
|
517
|
+
nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
|
|
518
|
+
_tile_stored(4, c_block, c_stride_bytes);
|
|
519
|
+
_tile_stored(5, c_block + 16, c_stride_bytes);
|
|
520
|
+
_tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
|
|
521
|
+
_tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
|
|
522
|
+
}
|
|
523
|
+
else {
|
|
524
|
+
_tile_stored(4, c_accum_buffer.c[0][0].data, 64);
|
|
525
|
+
_tile_stored(5, c_accum_buffer.c[0][1].data, 64);
|
|
526
|
+
_tile_stored(6, c_accum_buffer.c[1][0].data, 64);
|
|
527
|
+
_tile_stored(7, c_accum_buffer.c[1][1].data, 64);
|
|
528
|
+
nk_dots_f16_output2x2_graniteamx_(&c_accum_buffer,
|
|
529
|
+
c + row_block_start * c_stride_elements + col_block_start,
|
|
530
|
+
c_stride_elements, valid_rows_count, 32);
|
|
531
|
+
}
|
|
532
|
+
}
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
// Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
|
|
536
|
+
if (column_tiles_count % 2 == 1) {
|
|
537
|
+
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
538
|
+
nk_size_t const col_start = column_tile_idx * 16;
|
|
539
|
+
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
540
|
+
|
|
541
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
542
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
543
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
|
|
544
|
+
: (rows_count - row_block_start);
|
|
545
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
546
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
547
|
+
|
|
548
|
+
nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;
|
|
549
|
+
|
|
550
|
+
_tile_zero(4);
|
|
551
|
+
_tile_zero(6);
|
|
552
|
+
|
|
553
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
554
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
555
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
556
|
+
|
|
557
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
|
|
558
|
+
a_stride_elements, rows_in_high_tile, valid_depth);
|
|
559
|
+
if (rows_in_low_tile > 0) {
|
|
560
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_bottom,
|
|
561
|
+
a + (row_block_start + 16) * a_stride_elements + depth_offset,
|
|
562
|
+
a_stride_elements, rows_in_low_tile, valid_depth);
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
nk_dots_f16_b32x16_graniteamx_t const *b_tile =
|
|
566
|
+
(nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
|
|
567
|
+
(b_column_base + depth_tile_idx) * tile_size);
|
|
568
|
+
|
|
569
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
570
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
571
|
+
_tile_loadd(2, b_tile->data, 64);
|
|
572
|
+
|
|
573
|
+
_tile_dpfp16ps(4, 0, 2);
|
|
574
|
+
_tile_dpfp16ps(6, 1, 2);
|
|
575
|
+
}
|
|
576
|
+
|
|
577
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
578
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
579
|
+
|
|
580
|
+
nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
|
|
581
|
+
c_stride_elements, rows_in_high_tile, 16);
|
|
582
|
+
if (rows_in_low_tile > 0) {
|
|
583
|
+
nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
584
|
+
c_stride_elements, rows_in_low_tile, 16);
|
|
585
|
+
}
|
|
586
|
+
}
|
|
587
|
+
}
|
|
588
|
+
|
|
589
|
+
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
590
|
+
if (column_remainder_count > 0) {
|
|
591
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
592
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
593
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
|
|
594
|
+
: (rows_count - row_block_start);
|
|
595
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
596
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
597
|
+
|
|
598
|
+
nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;
|
|
599
|
+
nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
|
|
600
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
601
|
+
|
|
602
|
+
_tile_zero(4);
|
|
603
|
+
_tile_zero(6);
|
|
604
|
+
|
|
605
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
606
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
607
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
608
|
+
|
|
609
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
|
|
610
|
+
a_stride_elements, rows_in_high_tile, valid_depth);
|
|
611
|
+
if (rows_in_low_tile > 0) {
|
|
612
|
+
nk_dots_f16_load_a_graniteamx_(&a_tile_bottom,
|
|
613
|
+
a + (row_block_start + 16) * a_stride_elements + depth_offset,
|
|
614
|
+
a_stride_elements, rows_in_low_tile, valid_depth);
|
|
615
|
+
}
|
|
616
|
+
|
|
617
|
+
// Load edge columns as BF16-shaped tile (same 16-bit layout) and transpose on-the-fly
|
|
618
|
+
nk_dots_bf16_load_a_sapphireamx_(&b_as_a, (nk_bf16_t const *)(col_edge_ptr + depth_offset), depth,
|
|
619
|
+
column_remainder_count, valid_depth);
|
|
620
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
621
|
+
|
|
622
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
623
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
624
|
+
_tile_loadd(2, b_tile.data, 64);
|
|
625
|
+
|
|
626
|
+
_tile_dpfp16ps(4, 0, 2);
|
|
627
|
+
_tile_dpfp16ps(6, 1, 2);
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
631
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
632
|
+
|
|
633
|
+
nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
|
|
634
|
+
c_stride_elements, rows_in_high_tile, column_remainder_count);
|
|
635
|
+
if (rows_in_low_tile > 0) {
|
|
636
|
+
nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
637
|
+
c_stride_elements, rows_in_low_tile, column_remainder_count);
|
|
638
|
+
}
|
|
639
|
+
}
|
|
640
|
+
}
|
|
641
|
+
|
|
642
|
+
_tile_release();
|
|
643
|
+
}
|
|
644
|
+
|
|
645
|
+
NK_PUBLIC void nk_dots_symmetric_f16_graniteamx( //
|
|
646
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
647
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
|
|
648
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
649
|
+
|
|
650
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
651
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
652
|
+
|
|
653
|
+
nk_size_t const row_end = (row_count == 0)
|
|
654
|
+
? vectors_count
|
|
655
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
656
|
+
|
|
657
|
+
// Round depth up to multiple of 96 (3 tiles × 32 elements)
|
|
658
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
659
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
660
|
+
|
|
661
|
+
nk_dots_f16_a16x32_graniteamx_t a_tiles[3];
|
|
662
|
+
nk_dots_f16_a16x32_graniteamx_t b_src_tiles[3];
|
|
663
|
+
nk_dots_f16_b32x16_graniteamx_t b_tiles[3];
|
|
664
|
+
nk_dots_f16_state_graniteamx_t state;
|
|
665
|
+
|
|
666
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
667
|
+
|
|
668
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
669
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
670
|
+
|
|
671
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
672
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
673
|
+
|
|
674
|
+
nk_dots_f16_init_graniteamx_(&state);
|
|
675
|
+
|
|
676
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
677
|
+
nk_size_t const depth_base = depth_group_idx * 96;
|
|
678
|
+
|
|
679
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
680
|
+
nk_size_t const depth_start = depth_base + tile_idx * 32;
|
|
681
|
+
nk_size_t const valid_depth = (depth_start + 32 <= depth)
|
|
682
|
+
? 32
|
|
683
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
684
|
+
|
|
685
|
+
nk_dots_f16_load_a_graniteamx_( //
|
|
686
|
+
&a_tiles[tile_idx], //
|
|
687
|
+
vectors + row_tile * stride_elements + depth_start, //
|
|
688
|
+
stride_elements, valid_rows, valid_depth);
|
|
689
|
+
|
|
690
|
+
if (row_tile == col_tile) {
|
|
691
|
+
// Reuse A data as B (self-correlation on diagonal)
|
|
692
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(
|
|
693
|
+
(nk_dots_bf16_a16x32_sapphireamx_t const *)&a_tiles[tile_idx],
|
|
694
|
+
(nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
|
|
695
|
+
}
|
|
696
|
+
else {
|
|
697
|
+
nk_dots_f16_load_a_graniteamx_( //
|
|
698
|
+
&b_src_tiles[tile_idx], //
|
|
699
|
+
vectors + col_tile * stride_elements + depth_start, //
|
|
700
|
+
stride_elements, valid_cols, valid_depth);
|
|
701
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(
|
|
702
|
+
(nk_dots_bf16_a16x32_sapphireamx_t const *)&b_src_tiles[tile_idx],
|
|
703
|
+
(nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
|
|
704
|
+
}
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
nk_dots_f16_update_graniteamx_( //
|
|
708
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], //
|
|
709
|
+
&b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
710
|
+
}
|
|
711
|
+
|
|
712
|
+
nk_dots_f16_store_graniteamx_( //
|
|
713
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
714
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
715
|
+
}
|
|
716
|
+
}
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
#pragma endregion F16 Native
|
|
720
|
+
|
|
721
|
+
#pragma region E5M2 Source (widened to FP16 tiles)
|
|
722
|
+
|
|
723
|
+
/* E5M2 Granite AMX kernels: same F16 tile shapes, same TDPFP16PS compute body as the F16 path.
|
|
724
|
+
* The only difference is byte-to-word widening during A-load and B-pack: E5M2 shares F16's
|
|
725
|
+
* exponent bias (15), so `(byte << 8)` is the exact F16 bit pattern for every E5M2 value,
|
|
726
|
+
* including zero/subnormals/Inf/NaN. Tile buffers hold F16 after widen — "e5m2" in the
|
|
727
|
+
* typedefs refers to the source dtype, not the on-tile representation.
|
|
728
|
+
*
|
|
729
|
+
* The tile types below alias the F16 types (identical memory layout) so we can reuse the
|
|
730
|
+
* F16 init/store/output2x2/update internal helpers by pointer cast at the boundary. Only the
|
|
731
|
+
* public entry points and the load-A widen helper are new.
|
|
732
|
+
*/
|
|
733
|
+
|
|
734
|
+
// E5M2 tile types alias the F16 tile types byte-for-byte: tiles hold F16 bit patterns after the
// `(byte << 8)` widen, so the F16 init/store/output2x2/update helpers can be reused by cast.
typedef nk_dots_f16_a16x32_graniteamx_t nk_dots_e5m2_a16x32_graniteamx_t;   // A operand: 16 rows × 32 F16 words
typedef nk_dots_f16_b32x16_graniteamx_t nk_dots_e5m2_b32x16_graniteamx_t;   // B operand: pre-transposed 32×16 layout
typedef nk_dots_f16_state_graniteamx_t nk_dots_e5m2_state_graniteamx_t;     // Single 16×16 F32 accumulator
typedef nk_dots_f16_state2x2_graniteamx_t nk_dots_e5m2_state2x2_graniteamx_t; // 2×2 grid of accumulators (32×32 block)
|
|
738
|
+
|
|
739
|
+
/* Load A tile from E5M2 row-major source, widen to F16 via `(byte << 8)` into the F16 tile buffer. */
|
|
740
|
+
NK_INTERNAL void nk_dots_e5m2_load_a_graniteamx_( //
|
|
741
|
+
nk_dots_e5m2_a16x32_graniteamx_t *a_tile, //
|
|
742
|
+
nk_e5m2_t const *src, nk_size_t src_stride_elements, //
|
|
743
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
744
|
+
|
|
745
|
+
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
746
|
+
__m512i zero_i16x32 = _mm512_setzero_si512();
|
|
747
|
+
|
|
748
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
749
|
+
if (row_idx < valid_rows) {
|
|
750
|
+
__m256i e5m2_u8x32 = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride_elements);
|
|
751
|
+
__m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
|
|
752
|
+
__m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
|
|
753
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], f16_bits_i16x32);
|
|
754
|
+
}
|
|
755
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
|
|
756
|
+
}
|
|
757
|
+
nk_compiler_barrier_sapphireamx_();
|
|
758
|
+
}
|
|
759
|
+
|
|
760
|
+
/* Size in bytes of the packed-B buffer for the E5M2 Granite AMX kernel:
 * header + full 16-column tiles (held as F16 after widen, 1 KB each) +
 * a row-major F16 edge region for leftover columns + one F32 norm per column. */
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_graniteamx(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const rows_per_tile = 16;
    nk_size_t const depth_per_tile = 32;
    nk_size_t const bytes_per_tile = 512 * sizeof(nk_f16_t); // Tiles hold F16 after widen: same 1KB as F16.

    nk_size_t const whole_column_tiles = column_count / rows_per_tile;
    nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, depth_per_tile);
    nk_size_t const leftover_columns = column_count % rows_per_tile;

    // Leftover term is zero when columns divide evenly, so no branch is needed.
    nk_size_t const total = sizeof(nk_dots_amx_packed_header_t)            //
                            + whole_column_tiles * depth_tiles * bytes_per_tile //
                            + leftover_columns * depth * sizeof(nk_f16_t)  //
                            + column_count * sizeof(nk_f32_t);
    return total;
}
|
|
776
|
+
|
|
777
|
+
/**
 *  @brief  Packs an E5M2 B matrix into the tiled layout consumed by `nk_dots_packed_e5m2_graniteamx`.
 *
 *  Layout written into `b_packed`: header, then `column_tiles × depth_tiles` pre-transposed 1 KB
 *  F16 tiles, then (optionally) a row-major F16 "column edge" region for the last `< 16` columns,
 *  then one F32 squared-norm per column. All offsets are recorded in the header.
 *
 *  @param b                 Row-major E5M2 source, one column-vector of the logical B per row.
 *  @param column_count      Number of B columns (rows of `b`).
 *  @param depth             Elements per vector.
 *  @param b_stride_in_bytes Byte stride between rows of `b` (equals element stride: E5M2 is 1 byte).
 *  @param b_packed          Destination buffer, sized by `nk_dots_packed_size_e5m2_graniteamx`.
 */
NK_PUBLIC void nk_dots_pack_e5m2_graniteamx(             //
    nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth, //
    nk_size_t b_stride_in_bytes, void *b_packed) {

    nk_size_t const tmm_rows = 16;
    nk_size_t const tmm_cols = 32;
    nk_size_t const tile_elements = 512;
    nk_size_t const tile_bytes = tile_elements * sizeof(nk_f16_t);
    nk_size_t const b_stride_elements = b_stride_in_bytes; // E5M2: 1 byte per element

    nk_size_t const column_tiles_count = column_count / tmm_rows;
    nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
    nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
    nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;

    // Fill the header first; the compute kernel reads tile counts and offsets from here.
    nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
    header->full_column_tiles = (nk_u32_t)column_tiles_count;
    header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
    header->column_remainder_count = (nk_u32_t)column_remainder_count;

    nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
    header->column_edge_offset = (nk_u32_t)column_edge_offset;

    nk_f16_t *tiles_ptr = (nk_f16_t *)((char *)b_packed + tiles_offset);
    nk_f16_t *column_edge_ptr = (nk_f16_t *)((char *)b_packed + column_edge_offset);

    // Pack tiles: widen-and-gather 16 strided E5M2 rows into an F16-aligned source tile,
    // then transpose via the shared BF16/F16 transposer (same 16-bit lane layout).
    for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
        for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {

            // Tiles are stored column-tile-major: all depth tiles of one column tile are contiguous.
            nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
            nk_f16_t *tile_output = tiles_ptr + tile_index * tile_elements;

            nk_size_t const src_row_start = column_tile_idx * tmm_rows;
            nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
            nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
                                                                                     : (depth - src_column_start);

            nk_dots_bf16_a16x32_sapphireamx_t source_tile;
            // Masked byte load zero-fills the depth tail, so the widened tail is +0.0 in F16.
            __mmask32 depth_mask = (__mmask32)((columns_to_pack < 32) ? ((1U << columns_to_pack) - 1) : ~0U);
            for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
                nk_e5m2_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
                __m256i e5m2_u8x32 = _mm256_maskz_loadu_epi8(depth_mask, source_row);
                __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
                // `(byte << 8)` is the exact F16 bit pattern: E5M2 and F16 share exponent bias 15.
                __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
                _mm512_store_si512(&source_tile.data[row_idx][0], f16_bits_i16x32);
            }

            nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
            nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
            // Copy the transposed tile into the (possibly unaligned) packed buffer, 64 bytes at a time.
            for (nk_size_t i = 0; i < tile_bytes; i += 64)
                _mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
        }
    }

    // Column-remainder: widen and store row-major (depth contiguous per column)
    if (column_remainder_count > 0) {
        nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
        for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
            nk_e5m2_t const *src_row = b + (remainder_start_row + row_idx) * b_stride_elements;
            nk_f16_t *dst_row = column_edge_ptr + row_idx * depth;
            nk_size_t column_idx = 0;
            // Full 32-element chunks: straight widen-and-store.
            for (; column_idx + 32 <= depth; column_idx += 32) {
                __m256i e5m2_u8x32 = _mm256_loadu_si256((__m256i const *)(src_row + column_idx));
                __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
                __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
                _mm512_storeu_si512(dst_row + column_idx, f16_bits_i16x32);
            }
            // Depth tail: masked load and masked 16-bit store keep us inside the buffers.
            if (column_idx < depth) {
                __mmask32 tail_mask = (__mmask32)((1U << (depth - column_idx)) - 1);
                __m256i e5m2_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, src_row + column_idx);
                __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
                __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
                _mm512_mask_storeu_epi16(dst_row + column_idx, tail_mask, f16_bits_i16x32);
            }
        }
    }

    // Per-column squared norms follow the edge region (or the tiles, when there is no edge).
    nk_size_t norms_offset = column_edge_offset +
                             (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_f16_t) : 0);
    header->norms_byte_offset = (nk_u32_t)norms_offset;
    nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
    for (nk_size_t col = 0; col < column_count; col++)
        norms[col] = nk_dots_reduce_sumsq_e5m2_(b + col * b_stride_elements, depth);
}
|
|
864
|
+
|
|
865
|
+
/**
 *  @brief  A × packed-B matmul for E5M2 inputs with F32 output, on AMX via TDPFP16PS.
 *
 *  A rows are widened byte→F16 on the fly (`nk_dots_e5m2_load_a_graniteamx_`); B tiles come
 *  pre-widened and pre-transposed from `nk_dots_pack_e5m2_graniteamx`. The output is processed
 *  in 32×32 blocks using a 2×2 grid of tile accumulators (TMM4..7), with three edge paths:
 *  an odd trailing 16-column tile, a `< 16`-column edge packed row-major, and partial row blocks.
 *
 *  @param a               Row-major E5M2 A matrix.
 *  @param b_packed        Buffer produced by `nk_dots_pack_e5m2_graniteamx`.
 *  @param c               Row-major F32 output, `rows_count × cols_count`.
 *  @param cols_count      Unused here: column geometry is read from the packed header.
 *  @param a_stride_bytes  Byte stride between A rows (equals element stride: E5M2 is 1 byte).
 *  @param c_stride_bytes  Byte stride between C rows.
 */
NK_PUBLIC void nk_dots_packed_e5m2_graniteamx( //
    nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
    nk_unused_(cols_count);

    nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
    nk_size_t const column_tiles_count = header->full_column_tiles;
    nk_size_t const depth_tiles_count = header->full_depth_tiles;
    nk_size_t const column_remainder_count = header->column_remainder_count;

    nk_f16_t const *b_tiles_base = (nk_f16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
    nk_f16_t const *col_edge_ptr = (nk_f16_t const *)((char const *)b_packed + header->column_edge_offset);

    nk_size_t const a_stride_elements = a_stride_bytes; // E5M2: 1 byte per element
    nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);

    nk_size_t const tile_depth = 32;
    nk_size_t const tile_size = 512;
    nk_size_t const full_cols = column_tiles_count * 16;

    nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
    nk_size_t const col_blocks_count = column_tiles_count / 2; // paired column tiles → 32-wide blocks

    if (depth_tiles_count == 0) return;

    nk_dots_e5m2_a16x32_graniteamx_t a_tile_top, a_tile_bottom;
    nk_dots_e5m2_state2x2_graniteamx_t c_accum_buffer;

    nk_size_t const full_depth_tiles_count = depth / tile_depth;
    nk_size_t const depth_remainder = depth % tile_depth;

    nk_amx_tile_configure_sapphireamx_();

    // Main path: 32×32 output blocks. TMM0/1 hold the top/bottom halves of A's 32 rows,
    // TMM2/3 the left/right B column tiles, TMM4..7 the 2×2 F32 accumulators.
    for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
        nk_size_t const row_block_start = row_block_idx * 32;
        nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
        nk_size_t const is_full_row_block = (valid_rows_count == 32);

        for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
            nk_size_t const col_block_start = column_block_idx * 32;
            nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
            nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;

            _tile_zero(4);
            _tile_zero(5);
            _tile_zero(6);
            _tile_zero(7);

            // E5M2 A needs widen-then-tile-load every time (no fast in-place path), so we always
            // route through the widen buffer a_tile_{top,bottom} rather than `_tile_loadd` from A.
            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
                nk_size_t const rows_in_high_tile = is_full_row_block
                                                        ? 16
                                                        : ((valid_rows_count > 16) ? 16 : valid_rows_count);
                nk_size_t const rows_in_low_tile = is_full_row_block
                                                       ? 16
                                                       : ((valid_rows_count > 16) ? valid_rows_count - 16 : 0);

                nk_dots_e5m2_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
                                                a_stride_elements, rows_in_high_tile, valid_depth);
                if (rows_in_low_tile > 0) {
                    nk_dots_e5m2_load_a_graniteamx_(&a_tile_bottom,
                                                    a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                    a_stride_elements, rows_in_low_tile, valid_depth);
                }

                nk_dots_f16_b32x16_graniteamx_t const *b_tile_left = //
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                              (b_column_left_base + depth_tile_idx) * tile_size);
                nk_dots_f16_b32x16_graniteamx_t const *b_tile_right = //
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                              (b_column_right_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_top.data, 64);
                if (rows_in_low_tile > 0) _tile_loadd(1, a_tile_bottom.data, 64);
                _tile_loadd(2, b_tile_left->data, 64);
                _tile_loadd(3, b_tile_right->data, 64);

                _tile_dpfp16ps(4, 0, 2);
                _tile_dpfp16ps(5, 0, 3);
                if (rows_in_low_tile > 0) {
                    _tile_dpfp16ps(6, 1, 2);
                    _tile_dpfp16ps(7, 1, 3);
                }
            }

            if (is_full_row_block) {
                // Fast store: accumulators go straight into C with C's own stride.
                nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
                _tile_stored(4, c_block, c_stride_bytes);
                _tile_stored(5, c_block + 16, c_stride_bytes);
                _tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
                _tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
            }
            else {
                // Partial row block: spill accumulators to a buffer, then copy only the valid rows.
                _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
                _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
                _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
                _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
                nk_dots_f16_output2x2_graniteamx_(&c_accum_buffer,
                                                  c + row_block_start * c_stride_elements + col_block_start,
                                                  c_stride_elements, valid_rows_count, 32);
            }
        }
    }

    // Odd column tile
    if (column_tiles_count % 2 == 1) {
        nk_size_t const column_tile_idx = column_tiles_count - 1;
        nk_size_t const col_start = column_tile_idx * 16;
        nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;

        for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
            nk_size_t const row_block_start = row_block_idx * 32;
            nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
                                                                                    : (rows_count - row_block_start);
            nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
            nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

            nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e5m2_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
                                                a_stride_elements, rows_in_high_tile, valid_depth);
                if (rows_in_low_tile > 0) {
                    nk_dots_e5m2_load_a_graniteamx_(&a_tile_bottom,
                                                    a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                    a_stride_elements, rows_in_low_tile, valid_depth);
                }

                nk_dots_f16_b32x16_graniteamx_t const *b_tile = //
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                              (b_column_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_top.data, 64);
                if (rows_in_low_tile > 0) _tile_loadd(1, a_tile_bottom.data, 64);
                _tile_loadd(2, b_tile->data, 64);

                _tile_dpfp16ps(4, 0, 2);
                if (rows_in_low_tile > 0) _tile_dpfp16ps(6, 1, 2);
            }

            // `c_low_state` stays uninitialized when rows_in_low_tile == 0, but every read of it
            // below is guarded by the same condition.
            _tile_stored(4, c_high_state.data, 64);
            if (rows_in_low_tile > 0) _tile_stored(6, c_low_state.data, 64);

            nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
                                          c_stride_elements, rows_in_high_tile, 16);
            if (rows_in_low_tile > 0) {
                nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + col_start,
                                              c_stride_elements, rows_in_low_tile, 16);
            }
        }
    }

    // Column edge (fewer than 16 extra columns, stored row-major as F16 after pack)
    if (column_remainder_count > 0) {
        for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
            nk_size_t const row_block_start = row_block_idx * 32;
            nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
                                                                                    : (rows_count - row_block_start);
            nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
            nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

            nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;
            nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
            nk_dots_bf16_b32x16_sapphireamx_t b_tile;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e5m2_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
                                                a_stride_elements, rows_in_high_tile, valid_depth);
                if (rows_in_low_tile > 0) {
                    nk_dots_e5m2_load_a_graniteamx_(&a_tile_bottom,
                                                    a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                    a_stride_elements, rows_in_low_tile, valid_depth);
                }

                // The edge region holds F16 bit patterns; the BF16 loader/transposer only move
                // 16-bit lanes, so they are reused here unchanged.
                nk_dots_bf16_load_a_sapphireamx_(&b_as_a, (nk_bf16_t const *)(col_edge_ptr + depth_offset), depth,
                                                 column_remainder_count, valid_depth);
                nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);

                _tile_loadd(0, a_tile_top.data, 64);
                if (rows_in_low_tile > 0) _tile_loadd(1, a_tile_bottom.data, 64);
                _tile_loadd(2, b_tile.data, 64);

                _tile_dpfp16ps(4, 0, 2);
                if (rows_in_low_tile > 0) _tile_dpfp16ps(6, 1, 2);
            }

            _tile_stored(4, c_high_state.data, 64);
            if (rows_in_low_tile > 0) _tile_stored(6, c_low_state.data, 64);

            nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
                                          c_stride_elements, rows_in_high_tile, column_remainder_count);
            if (rows_in_low_tile > 0) {
                nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + full_cols,
                                              c_stride_elements, rows_in_low_tile, column_remainder_count);
            }
        }
    }

    _tile_release();
}
|
|
1080
|
+
|
|
1081
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_graniteamx( //
|
|
1082
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
1083
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
|
|
1084
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
1085
|
+
|
|
1086
|
+
nk_size_t const stride_elements = stride_in_bytes; // E5M2: 1 byte per element
|
|
1087
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1088
|
+
|
|
1089
|
+
nk_size_t const row_end = (row_count == 0)
|
|
1090
|
+
? vectors_count
|
|
1091
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
1092
|
+
|
|
1093
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
1094
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
1095
|
+
|
|
1096
|
+
nk_dots_e5m2_a16x32_graniteamx_t a_tiles[3];
|
|
1097
|
+
nk_dots_e5m2_a16x32_graniteamx_t b_src_tiles[3];
|
|
1098
|
+
nk_dots_f16_b32x16_graniteamx_t b_tiles[3];
|
|
1099
|
+
nk_dots_f16_state_graniteamx_t state;
|
|
1100
|
+
|
|
1101
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
1102
|
+
|
|
1103
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
1104
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
1105
|
+
|
|
1106
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
1107
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
1108
|
+
|
|
1109
|
+
nk_dots_f16_init_graniteamx_(&state);
|
|
1110
|
+
|
|
1111
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
1112
|
+
nk_size_t const depth_base = depth_group_idx * 96;
|
|
1113
|
+
|
|
1114
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
1115
|
+
nk_size_t const depth_start = depth_base + tile_idx * 32;
|
|
1116
|
+
nk_size_t const valid_depth = (depth_start + 32 <= depth)
|
|
1117
|
+
? 32
|
|
1118
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
1119
|
+
|
|
1120
|
+
nk_dots_e5m2_load_a_graniteamx_( //
|
|
1121
|
+
&a_tiles[tile_idx], //
|
|
1122
|
+
vectors + row_tile * stride_elements + depth_start, //
|
|
1123
|
+
stride_elements, valid_rows, valid_depth);
|
|
1124
|
+
|
|
1125
|
+
if (row_tile == col_tile) {
|
|
1126
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(
|
|
1127
|
+
(nk_dots_bf16_a16x32_sapphireamx_t const *)&a_tiles[tile_idx],
|
|
1128
|
+
(nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
|
|
1129
|
+
}
|
|
1130
|
+
else {
|
|
1131
|
+
nk_dots_e5m2_load_a_graniteamx_( //
|
|
1132
|
+
&b_src_tiles[tile_idx], //
|
|
1133
|
+
vectors + col_tile * stride_elements + depth_start, //
|
|
1134
|
+
stride_elements, valid_cols, valid_depth);
|
|
1135
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(
|
|
1136
|
+
(nk_dots_bf16_a16x32_sapphireamx_t const *)&b_src_tiles[tile_idx],
|
|
1137
|
+
(nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
|
|
1138
|
+
}
|
|
1139
|
+
}
|
|
1140
|
+
|
|
1141
|
+
nk_dots_f16_update_graniteamx_( //
|
|
1142
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], //
|
|
1143
|
+
&b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
1144
|
+
}
|
|
1145
|
+
|
|
1146
|
+
nk_dots_f16_store_graniteamx_( //
|
|
1147
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
1148
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
1149
|
+
}
|
|
1150
|
+
}
|
|
1151
|
+
}
|
|
1152
|
+
|
|
1153
|
+
#pragma endregion E5M2 Source
|
|
1154
|
+
|
|
1155
|
+
#if defined(__clang__)
|
|
1156
|
+
#pragma clang attribute pop
|
|
1157
|
+
#elif defined(__GNUC__)
|
|
1158
|
+
#pragma GCC pop_options
|
|
1159
|
+
#endif
|
|
1160
|
+
|
|
1161
|
+
#if defined(__cplusplus)
|
|
1162
|
+
} // extern "C"
|
|
1163
|
+
#endif
|
|
1164
|
+
|
|
1165
|
+
#endif // NK_TARGET_GRANITEAMX
|
|
1166
|
+
#endif // NK_TARGET_X8664_
|
|
1167
|
+
#endif // NK_DOTS_GRANITEAMX_H
|