numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -214,8 +214,8 @@ NK_INTERNAL void nk_compiler_barrier_sapphireamx_(void) { __asm__ volatile("" ::
|
|
|
214
214
|
|
|
215
215
|
/* Initialize BF16 output state to zero */
|
|
216
216
|
NK_INTERNAL void nk_dots_bf16_init_sapphireamx_(nk_dots_bf16_state_sapphireamx_t *state) {
|
|
217
|
-
__m512
|
|
218
|
-
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_ps(state->data[row_idx],
|
|
217
|
+
__m512 zero_f32x16 = _mm512_setzero_ps();
|
|
218
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_ps(state->data[row_idx], zero_f32x16); }
|
|
219
219
|
}
|
|
220
220
|
|
|
221
221
|
/* Load A tile from row-major source with masking for edge tiles */
|
|
@@ -225,14 +225,14 @@ NK_INTERNAL void nk_dots_bf16_load_a_sapphireamx_( //
|
|
|
225
225
|
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
226
226
|
|
|
227
227
|
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
228
|
-
__m512i
|
|
228
|
+
__m512i zero_i16x32 = _mm512_setzero_si512();
|
|
229
229
|
|
|
230
230
|
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
231
231
|
if (row_idx < valid_rows) {
|
|
232
|
-
__m512i
|
|
233
|
-
_mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
232
|
+
__m512i row_i16x32 = _mm512_maskz_loadu_epi16(column_mask, src + row_idx * src_stride_elements);
|
|
233
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], row_i16x32);
|
|
234
234
|
}
|
|
235
|
-
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
235
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
|
|
236
236
|
}
|
|
237
237
|
nk_compiler_barrier_sapphireamx_();
|
|
238
238
|
}
|
|
@@ -246,8 +246,8 @@ NK_INTERNAL void nk_dots_bf16_store_sapphireamx_( //
|
|
|
246
246
|
__mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
|
|
247
247
|
|
|
248
248
|
for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
|
|
249
|
-
__m512
|
|
250
|
-
_mm512_mask_storeu_ps(dst + row_idx * dst_stride_elements, column_mask,
|
|
249
|
+
__m512 row_f32x16 = _mm512_load_ps(state->data[row_idx]);
|
|
250
|
+
_mm512_mask_storeu_ps(dst + row_idx * dst_stride_elements, column_mask, row_f32x16);
|
|
251
251
|
}
|
|
252
252
|
}
|
|
253
253
|
|
|
@@ -281,8 +281,10 @@ NK_INTERNAL void nk_dots_bf16_update_sapphireamx_( //
|
|
|
281
281
|
|
|
282
282
|
/* Initialize INT8 output state to zero */
|
|
283
283
|
NK_INTERNAL void nk_dots_i8_init_sapphireamx_(nk_dots_i8_state_sapphireamx_t *state) {
|
|
284
|
-
__m512i
|
|
285
|
-
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
284
|
+
__m512i zero_i32x16 = _mm512_setzero_si512();
|
|
285
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
286
|
+
_mm512_store_si512((__m512i *)state->data[row_idx], zero_i32x16);
|
|
287
|
+
}
|
|
286
288
|
}
|
|
287
289
|
|
|
288
290
|
/* Load A tile from row-major source with masking for edge tiles */
|
|
@@ -292,14 +294,14 @@ NK_INTERNAL void nk_dots_i8_load_a_sapphireamx_( //
|
|
|
292
294
|
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
293
295
|
|
|
294
296
|
__mmask64 column_mask = (valid_cols >= 64) ? 0xFFFFFFFFFFFFFFFFULL : ((__mmask64)1 << valid_cols) - 1;
|
|
295
|
-
__m512i
|
|
297
|
+
__m512i zero_i8x64 = _mm512_setzero_si512();
|
|
296
298
|
|
|
297
299
|
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
298
300
|
if (row_idx < valid_rows) {
|
|
299
|
-
__m512i
|
|
300
|
-
_mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
301
|
+
__m512i row_i8x64 = _mm512_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
|
|
302
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], row_i8x64);
|
|
301
303
|
}
|
|
302
|
-
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
304
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i8x64); }
|
|
303
305
|
}
|
|
304
306
|
nk_compiler_barrier_sapphireamx_();
|
|
305
307
|
}
|
|
@@ -313,8 +315,8 @@ NK_INTERNAL void nk_dots_i8_store_sapphireamx_( //
|
|
|
313
315
|
__mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
|
|
314
316
|
|
|
315
317
|
for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
|
|
316
|
-
__m512i
|
|
317
|
-
_mm512_mask_storeu_epi32(dst + row_idx * dst_stride_elements, column_mask,
|
|
318
|
+
__m512i row_i32x16 = _mm512_load_si512((__m512i const *)state->data[row_idx]);
|
|
319
|
+
_mm512_mask_storeu_epi32(dst + row_idx * dst_stride_elements, column_mask, row_i32x16);
|
|
318
320
|
}
|
|
319
321
|
}
|
|
320
322
|
|
|
@@ -353,24 +355,23 @@ NK_INTERNAL void nk_dots_bf16_output2x2_sapphireamx_( //
|
|
|
353
355
|
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
354
356
|
|
|
355
357
|
// Rows 0-15
|
|
356
|
-
nk_size_t const
|
|
358
|
+
nk_size_t const rows_high = (valid_rows > 16) ? 16 : valid_rows;
|
|
357
359
|
nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
|
|
358
360
|
nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
|
|
359
361
|
|
|
360
|
-
if (
|
|
361
|
-
nk_dots_bf16_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements,
|
|
362
|
-
if (
|
|
363
|
-
nk_dots_bf16_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements,
|
|
362
|
+
if (rows_high > 0 && cols_left > 0)
|
|
363
|
+
nk_dots_bf16_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_high, cols_left);
|
|
364
|
+
if (rows_high > 0 && cols_right > 0)
|
|
365
|
+
nk_dots_bf16_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_high, cols_right);
|
|
364
366
|
|
|
365
367
|
// Rows 16-31
|
|
366
368
|
if (valid_rows > 16) {
|
|
367
|
-
nk_size_t const
|
|
368
|
-
nk_f32_t *
|
|
369
|
+
nk_size_t const rows_low = valid_rows - 16;
|
|
370
|
+
nk_f32_t *dst_low = dst + 16 * dst_stride_elements;
|
|
369
371
|
if (cols_left > 0)
|
|
370
|
-
nk_dots_bf16_store_sapphireamx_(&state->c[1][0],
|
|
372
|
+
nk_dots_bf16_store_sapphireamx_(&state->c[1][0], dst_low, dst_stride_elements, rows_low, cols_left);
|
|
371
373
|
if (cols_right > 0)
|
|
372
|
-
nk_dots_bf16_store_sapphireamx_(&state->c[1][1],
|
|
373
|
-
cols_right);
|
|
374
|
+
nk_dots_bf16_store_sapphireamx_(&state->c[1][1], dst_low + 16, dst_stride_elements, rows_low, cols_right);
|
|
374
375
|
}
|
|
375
376
|
}
|
|
376
377
|
|
|
@@ -380,22 +381,22 @@ NK_INTERNAL void nk_dots_i8_output2x2_sapphireamx_( //
|
|
|
380
381
|
nk_i32_t *dst, nk_size_t dst_stride_elements, //
|
|
381
382
|
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
382
383
|
|
|
383
|
-
nk_size_t const
|
|
384
|
+
nk_size_t const rows_high = (valid_rows > 16) ? 16 : valid_rows;
|
|
384
385
|
nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
|
|
385
386
|
nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
|
|
386
387
|
|
|
387
|
-
if (
|
|
388
|
-
nk_dots_i8_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements,
|
|
389
|
-
if (
|
|
390
|
-
nk_dots_i8_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements,
|
|
388
|
+
if (rows_high > 0 && cols_left > 0)
|
|
389
|
+
nk_dots_i8_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_high, cols_left);
|
|
390
|
+
if (rows_high > 0 && cols_right > 0)
|
|
391
|
+
nk_dots_i8_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_high, cols_right);
|
|
391
392
|
|
|
392
393
|
if (valid_rows > 16) {
|
|
393
|
-
nk_size_t const
|
|
394
|
-
nk_i32_t *
|
|
394
|
+
nk_size_t const rows_low = valid_rows - 16;
|
|
395
|
+
nk_i32_t *dst_low = dst + 16 * dst_stride_elements;
|
|
395
396
|
if (cols_left > 0)
|
|
396
|
-
nk_dots_i8_store_sapphireamx_(&state->c[1][0],
|
|
397
|
+
nk_dots_i8_store_sapphireamx_(&state->c[1][0], dst_low, dst_stride_elements, rows_low, cols_left);
|
|
397
398
|
if (cols_right > 0)
|
|
398
|
-
nk_dots_i8_store_sapphireamx_(&state->c[1][1],
|
|
399
|
+
nk_dots_i8_store_sapphireamx_(&state->c[1][1], dst_low + 16, dst_stride_elements, rows_low, cols_right);
|
|
399
400
|
}
|
|
400
401
|
}
|
|
401
402
|
|
|
@@ -441,114 +442,114 @@ NK_INTERNAL void nk_dots_pack_u8_transposed_sapphireamx_( //
|
|
|
441
442
|
|
|
442
443
|
// Load all 16 rows - each row is 64 UINT8 = 64 bytes = 1 ZMM
|
|
443
444
|
// Treat as 16 × 32-bit elements per row (each 32-bit = quad of UINT8)
|
|
444
|
-
__m512i
|
|
445
|
-
__m512i
|
|
446
|
-
__m512i
|
|
447
|
-
__m512i
|
|
448
|
-
__m512i
|
|
449
|
-
__m512i
|
|
450
|
-
__m512i
|
|
451
|
-
__m512i
|
|
452
|
-
__m512i
|
|
453
|
-
__m512i
|
|
454
|
-
__m512i
|
|
455
|
-
__m512i
|
|
456
|
-
__m512i
|
|
457
|
-
__m512i
|
|
458
|
-
__m512i
|
|
459
|
-
__m512i
|
|
445
|
+
__m512i row00_i32x16 = _mm512_load_si512(&a_tile->data[0][0]);
|
|
446
|
+
__m512i row01_i32x16 = _mm512_load_si512(&a_tile->data[1][0]);
|
|
447
|
+
__m512i row02_i32x16 = _mm512_load_si512(&a_tile->data[2][0]);
|
|
448
|
+
__m512i row03_i32x16 = _mm512_load_si512(&a_tile->data[3][0]);
|
|
449
|
+
__m512i row04_i32x16 = _mm512_load_si512(&a_tile->data[4][0]);
|
|
450
|
+
__m512i row05_i32x16 = _mm512_load_si512(&a_tile->data[5][0]);
|
|
451
|
+
__m512i row06_i32x16 = _mm512_load_si512(&a_tile->data[6][0]);
|
|
452
|
+
__m512i row07_i32x16 = _mm512_load_si512(&a_tile->data[7][0]);
|
|
453
|
+
__m512i row08_i32x16 = _mm512_load_si512(&a_tile->data[8][0]);
|
|
454
|
+
__m512i row09_i32x16 = _mm512_load_si512(&a_tile->data[9][0]);
|
|
455
|
+
__m512i row10_i32x16 = _mm512_load_si512(&a_tile->data[10][0]);
|
|
456
|
+
__m512i row11_i32x16 = _mm512_load_si512(&a_tile->data[11][0]);
|
|
457
|
+
__m512i row12_i32x16 = _mm512_load_si512(&a_tile->data[12][0]);
|
|
458
|
+
__m512i row13_i32x16 = _mm512_load_si512(&a_tile->data[13][0]);
|
|
459
|
+
__m512i row14_i32x16 = _mm512_load_si512(&a_tile->data[14][0]);
|
|
460
|
+
__m512i row15_i32x16 = _mm512_load_si512(&a_tile->data[15][0]);
|
|
460
461
|
|
|
461
462
|
// 16×16 transpose of 32-bit elements using hierarchical unpacks
|
|
462
463
|
// Stage 1: Unpack adjacent row pairs at 32-bit granularity
|
|
463
|
-
__m512i
|
|
464
|
-
__m512i
|
|
465
|
-
__m512i
|
|
466
|
-
__m512i
|
|
467
|
-
__m512i
|
|
468
|
-
__m512i
|
|
469
|
-
__m512i
|
|
470
|
-
__m512i
|
|
471
|
-
__m512i
|
|
472
|
-
__m512i
|
|
473
|
-
__m512i
|
|
474
|
-
__m512i
|
|
475
|
-
__m512i
|
|
476
|
-
__m512i
|
|
477
|
-
__m512i
|
|
478
|
-
__m512i
|
|
464
|
+
__m512i t01_low_i32x16 = _mm512_unpacklo_epi32(row00_i32x16, row01_i32x16);
|
|
465
|
+
__m512i t01_high_i32x16 = _mm512_unpackhi_epi32(row00_i32x16, row01_i32x16);
|
|
466
|
+
__m512i t23_low_i32x16 = _mm512_unpacklo_epi32(row02_i32x16, row03_i32x16);
|
|
467
|
+
__m512i t23_high_i32x16 = _mm512_unpackhi_epi32(row02_i32x16, row03_i32x16);
|
|
468
|
+
__m512i t45_low_i32x16 = _mm512_unpacklo_epi32(row04_i32x16, row05_i32x16);
|
|
469
|
+
__m512i t45_high_i32x16 = _mm512_unpackhi_epi32(row04_i32x16, row05_i32x16);
|
|
470
|
+
__m512i t67_low_i32x16 = _mm512_unpacklo_epi32(row06_i32x16, row07_i32x16);
|
|
471
|
+
__m512i t67_high_i32x16 = _mm512_unpackhi_epi32(row06_i32x16, row07_i32x16);
|
|
472
|
+
__m512i t89_low_i32x16 = _mm512_unpacklo_epi32(row08_i32x16, row09_i32x16);
|
|
473
|
+
__m512i t89_high_i32x16 = _mm512_unpackhi_epi32(row08_i32x16, row09_i32x16);
|
|
474
|
+
__m512i tab_low_i32x16 = _mm512_unpacklo_epi32(row10_i32x16, row11_i32x16);
|
|
475
|
+
__m512i tab_high_i32x16 = _mm512_unpackhi_epi32(row10_i32x16, row11_i32x16);
|
|
476
|
+
__m512i tcd_low_i32x16 = _mm512_unpacklo_epi32(row12_i32x16, row13_i32x16);
|
|
477
|
+
__m512i tcd_high_i32x16 = _mm512_unpackhi_epi32(row12_i32x16, row13_i32x16);
|
|
478
|
+
__m512i tef_low_i32x16 = _mm512_unpacklo_epi32(row14_i32x16, row15_i32x16);
|
|
479
|
+
__m512i tef_high_i32x16 = _mm512_unpackhi_epi32(row14_i32x16, row15_i32x16);
|
|
479
480
|
|
|
480
481
|
// Stage 2: Unpack at 64-bit granularity
|
|
481
|
-
__m512i
|
|
482
|
-
__m512i
|
|
483
|
-
__m512i
|
|
484
|
-
__m512i
|
|
485
|
-
__m512i
|
|
486
|
-
__m512i
|
|
487
|
-
__m512i
|
|
488
|
-
__m512i
|
|
489
|
-
__m512i
|
|
490
|
-
__m512i
|
|
491
|
-
__m512i
|
|
492
|
-
__m512i
|
|
493
|
-
__m512i
|
|
494
|
-
__m512i
|
|
495
|
-
__m512i
|
|
496
|
-
__m512i
|
|
482
|
+
__m512i u0123_ll_i32x16 = _mm512_unpacklo_epi64(t01_low_i32x16, t23_low_i32x16);
|
|
483
|
+
__m512i u0123_lh_i32x16 = _mm512_unpackhi_epi64(t01_low_i32x16, t23_low_i32x16);
|
|
484
|
+
__m512i u0123_hl_i32x16 = _mm512_unpacklo_epi64(t01_high_i32x16, t23_high_i32x16);
|
|
485
|
+
__m512i u0123_hh_i32x16 = _mm512_unpackhi_epi64(t01_high_i32x16, t23_high_i32x16);
|
|
486
|
+
__m512i u4567_ll_i32x16 = _mm512_unpacklo_epi64(t45_low_i32x16, t67_low_i32x16);
|
|
487
|
+
__m512i u4567_lh_i32x16 = _mm512_unpackhi_epi64(t45_low_i32x16, t67_low_i32x16);
|
|
488
|
+
__m512i u4567_hl_i32x16 = _mm512_unpacklo_epi64(t45_high_i32x16, t67_high_i32x16);
|
|
489
|
+
__m512i u4567_hh_i32x16 = _mm512_unpackhi_epi64(t45_high_i32x16, t67_high_i32x16);
|
|
490
|
+
__m512i u89ab_ll_i32x16 = _mm512_unpacklo_epi64(t89_low_i32x16, tab_low_i32x16);
|
|
491
|
+
__m512i u89ab_lh_i32x16 = _mm512_unpackhi_epi64(t89_low_i32x16, tab_low_i32x16);
|
|
492
|
+
__m512i u89ab_hl_i32x16 = _mm512_unpacklo_epi64(t89_high_i32x16, tab_high_i32x16);
|
|
493
|
+
__m512i u89ab_hh_i32x16 = _mm512_unpackhi_epi64(t89_high_i32x16, tab_high_i32x16);
|
|
494
|
+
__m512i ucdef_ll_i32x16 = _mm512_unpacklo_epi64(tcd_low_i32x16, tef_low_i32x16);
|
|
495
|
+
__m512i ucdef_lh_i32x16 = _mm512_unpackhi_epi64(tcd_low_i32x16, tef_low_i32x16);
|
|
496
|
+
__m512i ucdef_hl_i32x16 = _mm512_unpacklo_epi64(tcd_high_i32x16, tef_high_i32x16);
|
|
497
|
+
__m512i ucdef_hh_i32x16 = _mm512_unpackhi_epi64(tcd_high_i32x16, tef_high_i32x16);
|
|
497
498
|
|
|
498
499
|
// Stage 3: Shuffle 128-bit lanes
|
|
499
|
-
__m512i
|
|
500
|
-
__m512i
|
|
501
|
-
__m512i
|
|
502
|
-
__m512i
|
|
503
|
-
__m512i
|
|
504
|
-
__m512i
|
|
505
|
-
__m512i
|
|
506
|
-
__m512i
|
|
507
|
-
__m512i
|
|
508
|
-
__m512i
|
|
509
|
-
__m512i
|
|
510
|
-
__m512i
|
|
511
|
-
__m512i
|
|
512
|
-
__m512i
|
|
513
|
-
__m512i
|
|
514
|
-
__m512i
|
|
500
|
+
__m512i v0_a_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0x88);
|
|
501
|
+
__m512i v0_b_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0xDD);
|
|
502
|
+
__m512i v1_a_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0x88);
|
|
503
|
+
__m512i v1_b_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0xDD);
|
|
504
|
+
__m512i v2_a_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0x88);
|
|
505
|
+
__m512i v2_b_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0xDD);
|
|
506
|
+
__m512i v3_a_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0x88);
|
|
507
|
+
__m512i v3_b_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0xDD);
|
|
508
|
+
__m512i v4_a_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0x88);
|
|
509
|
+
__m512i v4_b_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0xDD);
|
|
510
|
+
__m512i v5_a_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0x88);
|
|
511
|
+
__m512i v5_b_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0xDD);
|
|
512
|
+
__m512i v6_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0x88);
|
|
513
|
+
__m512i v6_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0xDD);
|
|
514
|
+
__m512i v7_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0x88);
|
|
515
|
+
__m512i v7_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0xDD);
|
|
515
516
|
|
|
516
517
|
// Stage 4: Final 256-bit shuffle to complete transpose
|
|
517
|
-
__m512i
|
|
518
|
-
__m512i
|
|
519
|
-
__m512i
|
|
520
|
-
__m512i
|
|
521
|
-
__m512i
|
|
522
|
-
__m512i
|
|
523
|
-
__m512i
|
|
524
|
-
__m512i
|
|
525
|
-
__m512i
|
|
526
|
-
__m512i
|
|
527
|
-
__m512i
|
|
528
|
-
__m512i
|
|
529
|
-
__m512i
|
|
530
|
-
__m512i
|
|
531
|
-
__m512i
|
|
532
|
-
__m512i
|
|
518
|
+
__m512i out00_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0x88);
|
|
519
|
+
__m512i out01_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0x88);
|
|
520
|
+
__m512i out02_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0x88);
|
|
521
|
+
__m512i out03_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0x88);
|
|
522
|
+
__m512i out04_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0xDD);
|
|
523
|
+
__m512i out05_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0xDD);
|
|
524
|
+
__m512i out06_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0xDD);
|
|
525
|
+
__m512i out07_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0xDD);
|
|
526
|
+
__m512i out08_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0x88);
|
|
527
|
+
__m512i out09_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0x88);
|
|
528
|
+
__m512i out10_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0x88);
|
|
529
|
+
__m512i out11_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0x88);
|
|
530
|
+
__m512i out12_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0xDD);
|
|
531
|
+
__m512i out13_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0xDD);
|
|
532
|
+
__m512i out14_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0xDD);
|
|
533
|
+
__m512i out15_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0xDD);
|
|
533
534
|
|
|
534
535
|
// Store transposed results - each output row is one depth_group
|
|
535
536
|
// Output layout: B.data[depth_group][column][quad] = 16 columns × 4 UINT8 = 64 bytes
|
|
536
|
-
_mm512_store_si512(&b_tile->data[0][0][0],
|
|
537
|
-
_mm512_store_si512(&b_tile->data[1][0][0],
|
|
538
|
-
_mm512_store_si512(&b_tile->data[2][0][0],
|
|
539
|
-
_mm512_store_si512(&b_tile->data[3][0][0],
|
|
540
|
-
_mm512_store_si512(&b_tile->data[4][0][0],
|
|
541
|
-
_mm512_store_si512(&b_tile->data[5][0][0],
|
|
542
|
-
_mm512_store_si512(&b_tile->data[6][0][0],
|
|
543
|
-
_mm512_store_si512(&b_tile->data[7][0][0],
|
|
544
|
-
_mm512_store_si512(&b_tile->data[8][0][0],
|
|
545
|
-
_mm512_store_si512(&b_tile->data[9][0][0],
|
|
546
|
-
_mm512_store_si512(&b_tile->data[10][0][0],
|
|
547
|
-
_mm512_store_si512(&b_tile->data[11][0][0],
|
|
548
|
-
_mm512_store_si512(&b_tile->data[12][0][0],
|
|
549
|
-
_mm512_store_si512(&b_tile->data[13][0][0],
|
|
550
|
-
_mm512_store_si512(&b_tile->data[14][0][0],
|
|
551
|
-
_mm512_store_si512(&b_tile->data[15][0][0],
|
|
537
|
+
_mm512_store_si512(&b_tile->data[0][0][0], out00_i32x16);
|
|
538
|
+
_mm512_store_si512(&b_tile->data[1][0][0], out01_i32x16);
|
|
539
|
+
_mm512_store_si512(&b_tile->data[2][0][0], out02_i32x16);
|
|
540
|
+
_mm512_store_si512(&b_tile->data[3][0][0], out03_i32x16);
|
|
541
|
+
_mm512_store_si512(&b_tile->data[4][0][0], out08_i32x16);
|
|
542
|
+
_mm512_store_si512(&b_tile->data[5][0][0], out09_i32x16);
|
|
543
|
+
_mm512_store_si512(&b_tile->data[6][0][0], out10_i32x16);
|
|
544
|
+
_mm512_store_si512(&b_tile->data[7][0][0], out11_i32x16);
|
|
545
|
+
_mm512_store_si512(&b_tile->data[8][0][0], out04_i32x16);
|
|
546
|
+
_mm512_store_si512(&b_tile->data[9][0][0], out05_i32x16);
|
|
547
|
+
_mm512_store_si512(&b_tile->data[10][0][0], out06_i32x16);
|
|
548
|
+
_mm512_store_si512(&b_tile->data[11][0][0], out07_i32x16);
|
|
549
|
+
_mm512_store_si512(&b_tile->data[12][0][0], out12_i32x16);
|
|
550
|
+
_mm512_store_si512(&b_tile->data[13][0][0], out13_i32x16);
|
|
551
|
+
_mm512_store_si512(&b_tile->data[14][0][0], out14_i32x16);
|
|
552
|
+
_mm512_store_si512(&b_tile->data[15][0][0], out15_i32x16);
|
|
552
553
|
|
|
553
554
|
nk_compiler_barrier_sapphireamx_();
|
|
554
555
|
}
|
|
@@ -588,17 +589,17 @@ NK_INTERNAL void nk_dots_e4m3_load_a_sapphireamx_( //
|
|
|
588
589
|
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
589
590
|
|
|
590
591
|
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
591
|
-
__m512i
|
|
592
|
+
__m512i zero_i16x32 = _mm512_setzero_si512();
|
|
592
593
|
|
|
593
594
|
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
594
595
|
if (row_idx < valid_rows) {
|
|
595
596
|
// Load 32 E4M3 bytes with masking
|
|
596
|
-
__m256i
|
|
597
|
+
__m256i e4m3_row_u8x32 = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
|
|
597
598
|
// Convert to 32 BF16 values
|
|
598
|
-
__m512i
|
|
599
|
-
_mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
599
|
+
__m512i bf16_row_i16x32 = nk_e4m3x32_to_bf16x32_icelake_(e4m3_row_u8x32);
|
|
600
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row_i16x32);
|
|
600
601
|
}
|
|
601
|
-
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
602
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
|
|
602
603
|
}
|
|
603
604
|
nk_compiler_barrier_sapphireamx_();
|
|
604
605
|
}
|
|
@@ -610,15 +611,15 @@ NK_INTERNAL void nk_dots_e5m2_load_a_sapphireamx_( //
|
|
|
610
611
|
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
611
612
|
|
|
612
613
|
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
613
|
-
__m512i
|
|
614
|
+
__m512i zero_i16x32 = _mm512_setzero_si512();
|
|
614
615
|
|
|
615
616
|
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
616
617
|
if (row_idx < valid_rows) {
|
|
617
|
-
__m256i
|
|
618
|
-
__m512i
|
|
619
|
-
_mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
618
|
+
__m256i e5m2_row_u8x32 = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
|
|
619
|
+
__m512i bf16_row_i16x32 = nk_e5m2x32_to_bf16x32_icelake_(e5m2_row_u8x32);
|
|
620
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row_i16x32);
|
|
620
621
|
}
|
|
621
|
-
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
622
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
|
|
622
623
|
}
|
|
623
624
|
nk_compiler_barrier_sapphireamx_();
|
|
624
625
|
}
|
|
@@ -630,115 +631,115 @@ NK_INTERNAL void nk_dots_pack_bf16_transposed_sapphireamx_( //
|
|
|
630
631
|
|
|
631
632
|
// Load all 16 rows - each row is 32 BF16 = 64 bytes = 1 ZMM
|
|
632
633
|
// Treat as 16 × 32-bit elements per row (each 32-bit = pair of BF16)
|
|
633
|
-
__m512i
|
|
634
|
-
__m512i
|
|
635
|
-
__m512i
|
|
636
|
-
__m512i
|
|
637
|
-
__m512i
|
|
638
|
-
__m512i
|
|
639
|
-
__m512i
|
|
640
|
-
__m512i
|
|
641
|
-
__m512i
|
|
642
|
-
__m512i
|
|
643
|
-
__m512i
|
|
644
|
-
__m512i
|
|
645
|
-
__m512i
|
|
646
|
-
__m512i
|
|
647
|
-
__m512i
|
|
648
|
-
__m512i
|
|
634
|
+
__m512i row00_i32x16 = _mm512_load_si512(&a_tile->data[0][0]);
|
|
635
|
+
__m512i row01_i32x16 = _mm512_load_si512(&a_tile->data[1][0]);
|
|
636
|
+
__m512i row02_i32x16 = _mm512_load_si512(&a_tile->data[2][0]);
|
|
637
|
+
__m512i row03_i32x16 = _mm512_load_si512(&a_tile->data[3][0]);
|
|
638
|
+
__m512i row04_i32x16 = _mm512_load_si512(&a_tile->data[4][0]);
|
|
639
|
+
__m512i row05_i32x16 = _mm512_load_si512(&a_tile->data[5][0]);
|
|
640
|
+
__m512i row06_i32x16 = _mm512_load_si512(&a_tile->data[6][0]);
|
|
641
|
+
__m512i row07_i32x16 = _mm512_load_si512(&a_tile->data[7][0]);
|
|
642
|
+
__m512i row08_i32x16 = _mm512_load_si512(&a_tile->data[8][0]);
|
|
643
|
+
__m512i row09_i32x16 = _mm512_load_si512(&a_tile->data[9][0]);
|
|
644
|
+
__m512i row10_i32x16 = _mm512_load_si512(&a_tile->data[10][0]);
|
|
645
|
+
__m512i row11_i32x16 = _mm512_load_si512(&a_tile->data[11][0]);
|
|
646
|
+
__m512i row12_i32x16 = _mm512_load_si512(&a_tile->data[12][0]);
|
|
647
|
+
__m512i row13_i32x16 = _mm512_load_si512(&a_tile->data[13][0]);
|
|
648
|
+
__m512i row14_i32x16 = _mm512_load_si512(&a_tile->data[14][0]);
|
|
649
|
+
__m512i row15_i32x16 = _mm512_load_si512(&a_tile->data[15][0]);
|
|
649
650
|
|
|
650
651
|
// 16×16 transpose of 32-bit elements using hierarchical unpacks
|
|
651
652
|
// Stage 1: Unpack adjacent row pairs at 32-bit granularity
|
|
652
|
-
__m512i
|
|
653
|
-
__m512i
|
|
654
|
-
__m512i
|
|
655
|
-
__m512i
|
|
656
|
-
__m512i
|
|
657
|
-
__m512i
|
|
658
|
-
__m512i
|
|
659
|
-
__m512i
|
|
660
|
-
__m512i
|
|
661
|
-
__m512i
|
|
662
|
-
__m512i
|
|
663
|
-
__m512i
|
|
664
|
-
__m512i
|
|
665
|
-
__m512i
|
|
666
|
-
__m512i
|
|
667
|
-
__m512i
|
|
653
|
+
__m512i t01_low_i32x16 = _mm512_unpacklo_epi32(row00_i32x16, row01_i32x16);
|
|
654
|
+
__m512i t01_high_i32x16 = _mm512_unpackhi_epi32(row00_i32x16, row01_i32x16);
|
|
655
|
+
__m512i t23_low_i32x16 = _mm512_unpacklo_epi32(row02_i32x16, row03_i32x16);
|
|
656
|
+
__m512i t23_high_i32x16 = _mm512_unpackhi_epi32(row02_i32x16, row03_i32x16);
|
|
657
|
+
__m512i t45_low_i32x16 = _mm512_unpacklo_epi32(row04_i32x16, row05_i32x16);
|
|
658
|
+
__m512i t45_high_i32x16 = _mm512_unpackhi_epi32(row04_i32x16, row05_i32x16);
|
|
659
|
+
__m512i t67_low_i32x16 = _mm512_unpacklo_epi32(row06_i32x16, row07_i32x16);
|
|
660
|
+
__m512i t67_high_i32x16 = _mm512_unpackhi_epi32(row06_i32x16, row07_i32x16);
|
|
661
|
+
__m512i t89_low_i32x16 = _mm512_unpacklo_epi32(row08_i32x16, row09_i32x16);
|
|
662
|
+
__m512i t89_high_i32x16 = _mm512_unpackhi_epi32(row08_i32x16, row09_i32x16);
|
|
663
|
+
__m512i tab_low_i32x16 = _mm512_unpacklo_epi32(row10_i32x16, row11_i32x16);
|
|
664
|
+
__m512i tab_high_i32x16 = _mm512_unpackhi_epi32(row10_i32x16, row11_i32x16);
|
|
665
|
+
__m512i tcd_low_i32x16 = _mm512_unpacklo_epi32(row12_i32x16, row13_i32x16);
|
|
666
|
+
__m512i tcd_high_i32x16 = _mm512_unpackhi_epi32(row12_i32x16, row13_i32x16);
|
|
667
|
+
__m512i tef_low_i32x16 = _mm512_unpacklo_epi32(row14_i32x16, row15_i32x16);
|
|
668
|
+
__m512i tef_high_i32x16 = _mm512_unpackhi_epi32(row14_i32x16, row15_i32x16);
|
|
668
669
|
|
|
669
670
|
// Stage 2: Unpack at 64-bit granularity
|
|
670
|
-
__m512i
|
|
671
|
-
__m512i
|
|
672
|
-
__m512i
|
|
673
|
-
__m512i
|
|
674
|
-
__m512i
|
|
675
|
-
__m512i
|
|
676
|
-
__m512i
|
|
677
|
-
__m512i
|
|
678
|
-
__m512i
|
|
679
|
-
__m512i
|
|
680
|
-
__m512i
|
|
681
|
-
__m512i
|
|
682
|
-
__m512i
|
|
683
|
-
__m512i
|
|
684
|
-
__m512i
|
|
685
|
-
__m512i
|
|
671
|
+
__m512i u0123_ll_i32x16 = _mm512_unpacklo_epi64(t01_low_i32x16, t23_low_i32x16);
|
|
672
|
+
__m512i u0123_lh_i32x16 = _mm512_unpackhi_epi64(t01_low_i32x16, t23_low_i32x16);
|
|
673
|
+
__m512i u0123_hl_i32x16 = _mm512_unpacklo_epi64(t01_high_i32x16, t23_high_i32x16);
|
|
674
|
+
__m512i u0123_hh_i32x16 = _mm512_unpackhi_epi64(t01_high_i32x16, t23_high_i32x16);
|
|
675
|
+
__m512i u4567_ll_i32x16 = _mm512_unpacklo_epi64(t45_low_i32x16, t67_low_i32x16);
|
|
676
|
+
__m512i u4567_lh_i32x16 = _mm512_unpackhi_epi64(t45_low_i32x16, t67_low_i32x16);
|
|
677
|
+
__m512i u4567_hl_i32x16 = _mm512_unpacklo_epi64(t45_high_i32x16, t67_high_i32x16);
|
|
678
|
+
__m512i u4567_hh_i32x16 = _mm512_unpackhi_epi64(t45_high_i32x16, t67_high_i32x16);
|
|
679
|
+
__m512i u89ab_ll_i32x16 = _mm512_unpacklo_epi64(t89_low_i32x16, tab_low_i32x16);
|
|
680
|
+
__m512i u89ab_lh_i32x16 = _mm512_unpackhi_epi64(t89_low_i32x16, tab_low_i32x16);
|
|
681
|
+
__m512i u89ab_hl_i32x16 = _mm512_unpacklo_epi64(t89_high_i32x16, tab_high_i32x16);
|
|
682
|
+
__m512i u89ab_hh_i32x16 = _mm512_unpackhi_epi64(t89_high_i32x16, tab_high_i32x16);
|
|
683
|
+
__m512i ucdef_ll_i32x16 = _mm512_unpacklo_epi64(tcd_low_i32x16, tef_low_i32x16);
|
|
684
|
+
__m512i ucdef_lh_i32x16 = _mm512_unpackhi_epi64(tcd_low_i32x16, tef_low_i32x16);
|
|
685
|
+
__m512i ucdef_hl_i32x16 = _mm512_unpacklo_epi64(tcd_high_i32x16, tef_high_i32x16);
|
|
686
|
+
__m512i ucdef_hh_i32x16 = _mm512_unpackhi_epi64(tcd_high_i32x16, tef_high_i32x16);
|
|
686
687
|
|
|
687
688
|
// Stage 3: Shuffle 128-bit lanes using permute2x128 equivalent for 512-bit
|
|
688
689
|
// Use shuffle_i32x4 to move 128-bit chunks
|
|
689
|
-
__m512i
|
|
690
|
-
__m512i
|
|
691
|
-
__m512i
|
|
692
|
-
__m512i
|
|
693
|
-
__m512i
|
|
694
|
-
__m512i
|
|
695
|
-
__m512i
|
|
696
|
-
__m512i
|
|
697
|
-
__m512i
|
|
698
|
-
__m512i
|
|
699
|
-
__m512i
|
|
700
|
-
__m512i
|
|
701
|
-
__m512i
|
|
702
|
-
__m512i
|
|
703
|
-
__m512i
|
|
704
|
-
__m512i
|
|
690
|
+
__m512i v0_a_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0x88); // lanes 0,2 from each
|
|
691
|
+
__m512i v0_b_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0xDD); // lanes 1,3 from each
|
|
692
|
+
__m512i v1_a_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0x88);
|
|
693
|
+
__m512i v1_b_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0xDD);
|
|
694
|
+
__m512i v2_a_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0x88);
|
|
695
|
+
__m512i v2_b_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0xDD);
|
|
696
|
+
__m512i v3_a_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0x88);
|
|
697
|
+
__m512i v3_b_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0xDD);
|
|
698
|
+
__m512i v4_a_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0x88);
|
|
699
|
+
__m512i v4_b_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0xDD);
|
|
700
|
+
__m512i v5_a_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0x88);
|
|
701
|
+
__m512i v5_b_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0xDD);
|
|
702
|
+
__m512i v6_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0x88);
|
|
703
|
+
__m512i v6_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0xDD);
|
|
704
|
+
__m512i v7_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0x88);
|
|
705
|
+
__m512i v7_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0xDD);
|
|
705
706
|
|
|
706
707
|
// Stage 4: Final 256-bit shuffle to complete transpose
|
|
707
|
-
__m512i
|
|
708
|
-
__m512i
|
|
709
|
-
__m512i
|
|
710
|
-
__m512i
|
|
711
|
-
__m512i
|
|
712
|
-
__m512i
|
|
713
|
-
__m512i
|
|
714
|
-
__m512i
|
|
715
|
-
__m512i
|
|
716
|
-
__m512i
|
|
717
|
-
__m512i
|
|
718
|
-
__m512i
|
|
719
|
-
__m512i
|
|
720
|
-
__m512i
|
|
721
|
-
__m512i
|
|
722
|
-
__m512i
|
|
708
|
+
__m512i out00_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0x88);
|
|
709
|
+
__m512i out01_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0x88);
|
|
710
|
+
__m512i out02_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0x88);
|
|
711
|
+
__m512i out03_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0x88);
|
|
712
|
+
__m512i out04_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0xDD);
|
|
713
|
+
__m512i out05_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0xDD);
|
|
714
|
+
__m512i out06_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0xDD);
|
|
715
|
+
__m512i out07_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0xDD);
|
|
716
|
+
__m512i out08_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0x88);
|
|
717
|
+
__m512i out09_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0x88);
|
|
718
|
+
__m512i out10_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0x88);
|
|
719
|
+
__m512i out11_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0x88);
|
|
720
|
+
__m512i out12_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0xDD);
|
|
721
|
+
__m512i out13_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0xDD);
|
|
722
|
+
__m512i out14_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0xDD);
|
|
723
|
+
__m512i out15_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0xDD);
|
|
723
724
|
|
|
724
725
|
// Store transposed results - each output row is one depth_group
|
|
725
726
|
// Output layout: B.data[depth_group][column][pair] = 16 columns × 2 BF16 = 64 bytes
|
|
726
|
-
_mm512_store_si512(&b_tile->data[0][0][0],
|
|
727
|
-
_mm512_store_si512(&b_tile->data[1][0][0],
|
|
728
|
-
_mm512_store_si512(&b_tile->data[2][0][0],
|
|
729
|
-
_mm512_store_si512(&b_tile->data[3][0][0],
|
|
730
|
-
_mm512_store_si512(&b_tile->data[4][0][0],
|
|
731
|
-
_mm512_store_si512(&b_tile->data[5][0][0],
|
|
732
|
-
_mm512_store_si512(&b_tile->data[6][0][0],
|
|
733
|
-
_mm512_store_si512(&b_tile->data[7][0][0],
|
|
734
|
-
_mm512_store_si512(&b_tile->data[8][0][0],
|
|
735
|
-
_mm512_store_si512(&b_tile->data[9][0][0],
|
|
736
|
-
_mm512_store_si512(&b_tile->data[10][0][0],
|
|
737
|
-
_mm512_store_si512(&b_tile->data[11][0][0],
|
|
738
|
-
_mm512_store_si512(&b_tile->data[12][0][0],
|
|
739
|
-
_mm512_store_si512(&b_tile->data[13][0][0],
|
|
740
|
-
_mm512_store_si512(&b_tile->data[14][0][0],
|
|
741
|
-
_mm512_store_si512(&b_tile->data[15][0][0],
|
|
727
|
+
_mm512_store_si512(&b_tile->data[0][0][0], out00_i32x16);
|
|
728
|
+
_mm512_store_si512(&b_tile->data[1][0][0], out01_i32x16);
|
|
729
|
+
_mm512_store_si512(&b_tile->data[2][0][0], out02_i32x16);
|
|
730
|
+
_mm512_store_si512(&b_tile->data[3][0][0], out03_i32x16);
|
|
731
|
+
_mm512_store_si512(&b_tile->data[4][0][0], out08_i32x16);
|
|
732
|
+
_mm512_store_si512(&b_tile->data[5][0][0], out09_i32x16);
|
|
733
|
+
_mm512_store_si512(&b_tile->data[6][0][0], out10_i32x16);
|
|
734
|
+
_mm512_store_si512(&b_tile->data[7][0][0], out11_i32x16);
|
|
735
|
+
_mm512_store_si512(&b_tile->data[8][0][0], out04_i32x16);
|
|
736
|
+
_mm512_store_si512(&b_tile->data[9][0][0], out05_i32x16);
|
|
737
|
+
_mm512_store_si512(&b_tile->data[10][0][0], out06_i32x16);
|
|
738
|
+
_mm512_store_si512(&b_tile->data[11][0][0], out07_i32x16);
|
|
739
|
+
_mm512_store_si512(&b_tile->data[12][0][0], out12_i32x16);
|
|
740
|
+
_mm512_store_si512(&b_tile->data[13][0][0], out13_i32x16);
|
|
741
|
+
_mm512_store_si512(&b_tile->data[14][0][0], out14_i32x16);
|
|
742
|
+
_mm512_store_si512(&b_tile->data[15][0][0], out15_i32x16);
|
|
742
743
|
|
|
743
744
|
nk_compiler_barrier_sapphireamx_();
|
|
744
745
|
}
|
|
@@ -750,119 +751,119 @@ NK_INTERNAL void nk_dots_pack_i8_transposed_sapphireamx_( //
|
|
|
750
751
|
|
|
751
752
|
// Load all 16 rows - each row is 64 INT8 = 64 bytes = 1 ZMM
|
|
752
753
|
// Treat as 16 × 32-bit elements per row (each 32-bit = quad of INT8)
|
|
753
|
-
__m512i
|
|
754
|
-
__m512i
|
|
755
|
-
__m512i
|
|
756
|
-
__m512i
|
|
757
|
-
__m512i
|
|
758
|
-
__m512i
|
|
759
|
-
__m512i
|
|
760
|
-
__m512i
|
|
761
|
-
__m512i
|
|
762
|
-
__m512i
|
|
763
|
-
__m512i
|
|
764
|
-
__m512i
|
|
765
|
-
__m512i
|
|
766
|
-
__m512i
|
|
767
|
-
__m512i
|
|
768
|
-
__m512i
|
|
754
|
+
__m512i row00_i32x16 = _mm512_load_si512(&a_tile->data[0][0]);
|
|
755
|
+
__m512i row01_i32x16 = _mm512_load_si512(&a_tile->data[1][0]);
|
|
756
|
+
__m512i row02_i32x16 = _mm512_load_si512(&a_tile->data[2][0]);
|
|
757
|
+
__m512i row03_i32x16 = _mm512_load_si512(&a_tile->data[3][0]);
|
|
758
|
+
__m512i row04_i32x16 = _mm512_load_si512(&a_tile->data[4][0]);
|
|
759
|
+
__m512i row05_i32x16 = _mm512_load_si512(&a_tile->data[5][0]);
|
|
760
|
+
__m512i row06_i32x16 = _mm512_load_si512(&a_tile->data[6][0]);
|
|
761
|
+
__m512i row07_i32x16 = _mm512_load_si512(&a_tile->data[7][0]);
|
|
762
|
+
__m512i row08_i32x16 = _mm512_load_si512(&a_tile->data[8][0]);
|
|
763
|
+
__m512i row09_i32x16 = _mm512_load_si512(&a_tile->data[9][0]);
|
|
764
|
+
__m512i row10_i32x16 = _mm512_load_si512(&a_tile->data[10][0]);
|
|
765
|
+
__m512i row11_i32x16 = _mm512_load_si512(&a_tile->data[11][0]);
|
|
766
|
+
__m512i row12_i32x16 = _mm512_load_si512(&a_tile->data[12][0]);
|
|
767
|
+
__m512i row13_i32x16 = _mm512_load_si512(&a_tile->data[13][0]);
|
|
768
|
+
__m512i row14_i32x16 = _mm512_load_si512(&a_tile->data[14][0]);
|
|
769
|
+
__m512i row15_i32x16 = _mm512_load_si512(&a_tile->data[15][0]);
|
|
769
770
|
|
|
770
771
|
// 16×16 transpose of 32-bit elements using hierarchical unpacks
|
|
771
772
|
// Stage 1: Unpack adjacent row pairs at 32-bit granularity
|
|
772
|
-
__m512i
|
|
773
|
-
__m512i
|
|
774
|
-
__m512i
|
|
775
|
-
__m512i
|
|
776
|
-
__m512i
|
|
777
|
-
__m512i
|
|
778
|
-
__m512i
|
|
779
|
-
__m512i
|
|
780
|
-
__m512i
|
|
781
|
-
__m512i
|
|
782
|
-
__m512i
|
|
783
|
-
__m512i
|
|
784
|
-
__m512i
|
|
785
|
-
__m512i
|
|
786
|
-
__m512i
|
|
787
|
-
__m512i
|
|
773
|
+
__m512i t01_low_i32x16 = _mm512_unpacklo_epi32(row00_i32x16, row01_i32x16);
|
|
774
|
+
__m512i t01_high_i32x16 = _mm512_unpackhi_epi32(row00_i32x16, row01_i32x16);
|
|
775
|
+
__m512i t23_low_i32x16 = _mm512_unpacklo_epi32(row02_i32x16, row03_i32x16);
|
|
776
|
+
__m512i t23_high_i32x16 = _mm512_unpackhi_epi32(row02_i32x16, row03_i32x16);
|
|
777
|
+
__m512i t45_low_i32x16 = _mm512_unpacklo_epi32(row04_i32x16, row05_i32x16);
|
|
778
|
+
__m512i t45_high_i32x16 = _mm512_unpackhi_epi32(row04_i32x16, row05_i32x16);
|
|
779
|
+
__m512i t67_low_i32x16 = _mm512_unpacklo_epi32(row06_i32x16, row07_i32x16);
|
|
780
|
+
__m512i t67_high_i32x16 = _mm512_unpackhi_epi32(row06_i32x16, row07_i32x16);
|
|
781
|
+
__m512i t89_low_i32x16 = _mm512_unpacklo_epi32(row08_i32x16, row09_i32x16);
|
|
782
|
+
__m512i t89_high_i32x16 = _mm512_unpackhi_epi32(row08_i32x16, row09_i32x16);
|
|
783
|
+
__m512i tab_low_i32x16 = _mm512_unpacklo_epi32(row10_i32x16, row11_i32x16);
|
|
784
|
+
__m512i tab_high_i32x16 = _mm512_unpackhi_epi32(row10_i32x16, row11_i32x16);
|
|
785
|
+
__m512i tcd_low_i32x16 = _mm512_unpacklo_epi32(row12_i32x16, row13_i32x16);
|
|
786
|
+
__m512i tcd_high_i32x16 = _mm512_unpackhi_epi32(row12_i32x16, row13_i32x16);
|
|
787
|
+
__m512i tef_low_i32x16 = _mm512_unpacklo_epi32(row14_i32x16, row15_i32x16);
|
|
788
|
+
__m512i tef_high_i32x16 = _mm512_unpackhi_epi32(row14_i32x16, row15_i32x16);
|
|
788
789
|
|
|
789
790
|
// Stage 2: Unpack at 64-bit granularity
|
|
790
|
-
__m512i
|
|
791
|
-
__m512i
|
|
792
|
-
__m512i
|
|
793
|
-
__m512i
|
|
794
|
-
__m512i
|
|
795
|
-
__m512i
|
|
796
|
-
__m512i
|
|
797
|
-
__m512i
|
|
798
|
-
__m512i
|
|
799
|
-
__m512i
|
|
800
|
-
__m512i
|
|
801
|
-
__m512i
|
|
802
|
-
__m512i
|
|
803
|
-
__m512i
|
|
804
|
-
__m512i
|
|
805
|
-
__m512i
|
|
791
|
+
__m512i u0123_ll_i32x16 = _mm512_unpacklo_epi64(t01_low_i32x16, t23_low_i32x16);
|
|
792
|
+
__m512i u0123_lh_i32x16 = _mm512_unpackhi_epi64(t01_low_i32x16, t23_low_i32x16);
|
|
793
|
+
__m512i u0123_hl_i32x16 = _mm512_unpacklo_epi64(t01_high_i32x16, t23_high_i32x16);
|
|
794
|
+
__m512i u0123_hh_i32x16 = _mm512_unpackhi_epi64(t01_high_i32x16, t23_high_i32x16);
|
|
795
|
+
__m512i u4567_ll_i32x16 = _mm512_unpacklo_epi64(t45_low_i32x16, t67_low_i32x16);
|
|
796
|
+
__m512i u4567_lh_i32x16 = _mm512_unpackhi_epi64(t45_low_i32x16, t67_low_i32x16);
|
|
797
|
+
__m512i u4567_hl_i32x16 = _mm512_unpacklo_epi64(t45_high_i32x16, t67_high_i32x16);
|
|
798
|
+
__m512i u4567_hh_i32x16 = _mm512_unpackhi_epi64(t45_high_i32x16, t67_high_i32x16);
|
|
799
|
+
__m512i u89ab_ll_i32x16 = _mm512_unpacklo_epi64(t89_low_i32x16, tab_low_i32x16);
|
|
800
|
+
__m512i u89ab_lh_i32x16 = _mm512_unpackhi_epi64(t89_low_i32x16, tab_low_i32x16);
|
|
801
|
+
__m512i u89ab_hl_i32x16 = _mm512_unpacklo_epi64(t89_high_i32x16, tab_high_i32x16);
|
|
802
|
+
__m512i u89ab_hh_i32x16 = _mm512_unpackhi_epi64(t89_high_i32x16, tab_high_i32x16);
|
|
803
|
+
__m512i ucdef_ll_i32x16 = _mm512_unpacklo_epi64(tcd_low_i32x16, tef_low_i32x16);
|
|
804
|
+
__m512i ucdef_lh_i32x16 = _mm512_unpackhi_epi64(tcd_low_i32x16, tef_low_i32x16);
|
|
805
|
+
__m512i ucdef_hl_i32x16 = _mm512_unpacklo_epi64(tcd_high_i32x16, tef_high_i32x16);
|
|
806
|
+
__m512i ucdef_hh_i32x16 = _mm512_unpackhi_epi64(tcd_high_i32x16, tef_high_i32x16);
|
|
806
807
|
|
|
807
808
|
// Stage 3: Shuffle 128-bit lanes
|
|
808
|
-
__m512i
|
|
809
|
-
__m512i
|
|
810
|
-
__m512i
|
|
811
|
-
__m512i
|
|
812
|
-
__m512i
|
|
813
|
-
__m512i
|
|
814
|
-
__m512i
|
|
815
|
-
__m512i
|
|
816
|
-
__m512i
|
|
817
|
-
__m512i
|
|
818
|
-
__m512i
|
|
819
|
-
__m512i
|
|
820
|
-
__m512i
|
|
821
|
-
__m512i
|
|
822
|
-
__m512i
|
|
823
|
-
__m512i
|
|
809
|
+
__m512i v0_a_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0x88);
|
|
810
|
+
__m512i v0_b_i32x16 = _mm512_shuffle_i32x4(u0123_ll_i32x16, u4567_ll_i32x16, 0xDD);
|
|
811
|
+
__m512i v1_a_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0x88);
|
|
812
|
+
__m512i v1_b_i32x16 = _mm512_shuffle_i32x4(u0123_lh_i32x16, u4567_lh_i32x16, 0xDD);
|
|
813
|
+
__m512i v2_a_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0x88);
|
|
814
|
+
__m512i v2_b_i32x16 = _mm512_shuffle_i32x4(u0123_hl_i32x16, u4567_hl_i32x16, 0xDD);
|
|
815
|
+
__m512i v3_a_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0x88);
|
|
816
|
+
__m512i v3_b_i32x16 = _mm512_shuffle_i32x4(u0123_hh_i32x16, u4567_hh_i32x16, 0xDD);
|
|
817
|
+
__m512i v4_a_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0x88);
|
|
818
|
+
__m512i v4_b_i32x16 = _mm512_shuffle_i32x4(u89ab_ll_i32x16, ucdef_ll_i32x16, 0xDD);
|
|
819
|
+
__m512i v5_a_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0x88);
|
|
820
|
+
__m512i v5_b_i32x16 = _mm512_shuffle_i32x4(u89ab_lh_i32x16, ucdef_lh_i32x16, 0xDD);
|
|
821
|
+
__m512i v6_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0x88);
|
|
822
|
+
__m512i v6_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hl_i32x16, ucdef_hl_i32x16, 0xDD);
|
|
823
|
+
__m512i v7_a_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0x88);
|
|
824
|
+
__m512i v7_b_i32x16 = _mm512_shuffle_i32x4(u89ab_hh_i32x16, ucdef_hh_i32x16, 0xDD);
|
|
824
825
|
|
|
825
826
|
// Stage 4: Final 256-bit shuffle to complete transpose
|
|
826
|
-
__m512i
|
|
827
|
-
__m512i
|
|
828
|
-
__m512i
|
|
829
|
-
__m512i
|
|
830
|
-
__m512i
|
|
831
|
-
__m512i
|
|
832
|
-
__m512i
|
|
833
|
-
__m512i
|
|
834
|
-
__m512i
|
|
835
|
-
__m512i
|
|
836
|
-
__m512i
|
|
837
|
-
__m512i
|
|
838
|
-
__m512i
|
|
839
|
-
__m512i
|
|
840
|
-
__m512i
|
|
841
|
-
__m512i
|
|
827
|
+
__m512i out00_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0x88);
|
|
828
|
+
__m512i out01_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0x88);
|
|
829
|
+
__m512i out02_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0x88);
|
|
830
|
+
__m512i out03_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0x88);
|
|
831
|
+
__m512i out04_i32x16 = _mm512_shuffle_i32x4(v0_a_i32x16, v4_a_i32x16, 0xDD);
|
|
832
|
+
__m512i out05_i32x16 = _mm512_shuffle_i32x4(v1_a_i32x16, v5_a_i32x16, 0xDD);
|
|
833
|
+
__m512i out06_i32x16 = _mm512_shuffle_i32x4(v2_a_i32x16, v6_a_i32x16, 0xDD);
|
|
834
|
+
__m512i out07_i32x16 = _mm512_shuffle_i32x4(v3_a_i32x16, v7_a_i32x16, 0xDD);
|
|
835
|
+
__m512i out08_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0x88);
|
|
836
|
+
__m512i out09_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0x88);
|
|
837
|
+
__m512i out10_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0x88);
|
|
838
|
+
__m512i out11_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0x88);
|
|
839
|
+
__m512i out12_i32x16 = _mm512_shuffle_i32x4(v0_b_i32x16, v4_b_i32x16, 0xDD);
|
|
840
|
+
__m512i out13_i32x16 = _mm512_shuffle_i32x4(v1_b_i32x16, v5_b_i32x16, 0xDD);
|
|
841
|
+
__m512i out14_i32x16 = _mm512_shuffle_i32x4(v2_b_i32x16, v6_b_i32x16, 0xDD);
|
|
842
|
+
__m512i out15_i32x16 = _mm512_shuffle_i32x4(v3_b_i32x16, v7_b_i32x16, 0xDD);
|
|
842
843
|
|
|
843
844
|
// Store transposed results - each output row is one depth_group
|
|
844
845
|
// Output layout: B.data[depth_group][column][quad] = 16 columns × 4 INT8 = 64 bytes
|
|
845
|
-
_mm512_store_si512(&b_tile->data[0][0][0],
|
|
846
|
-
_mm512_store_si512(&b_tile->data[1][0][0],
|
|
847
|
-
_mm512_store_si512(&b_tile->data[2][0][0],
|
|
848
|
-
_mm512_store_si512(&b_tile->data[3][0][0],
|
|
849
|
-
_mm512_store_si512(&b_tile->data[4][0][0],
|
|
850
|
-
_mm512_store_si512(&b_tile->data[5][0][0],
|
|
851
|
-
_mm512_store_si512(&b_tile->data[6][0][0],
|
|
852
|
-
_mm512_store_si512(&b_tile->data[7][0][0],
|
|
853
|
-
_mm512_store_si512(&b_tile->data[8][0][0],
|
|
854
|
-
_mm512_store_si512(&b_tile->data[9][0][0],
|
|
855
|
-
_mm512_store_si512(&b_tile->data[10][0][0],
|
|
856
|
-
_mm512_store_si512(&b_tile->data[11][0][0],
|
|
857
|
-
_mm512_store_si512(&b_tile->data[12][0][0],
|
|
858
|
-
_mm512_store_si512(&b_tile->data[13][0][0],
|
|
859
|
-
_mm512_store_si512(&b_tile->data[14][0][0],
|
|
860
|
-
_mm512_store_si512(&b_tile->data[15][0][0],
|
|
846
|
+
_mm512_store_si512(&b_tile->data[0][0][0], out00_i32x16);
|
|
847
|
+
_mm512_store_si512(&b_tile->data[1][0][0], out01_i32x16);
|
|
848
|
+
_mm512_store_si512(&b_tile->data[2][0][0], out02_i32x16);
|
|
849
|
+
_mm512_store_si512(&b_tile->data[3][0][0], out03_i32x16);
|
|
850
|
+
_mm512_store_si512(&b_tile->data[4][0][0], out08_i32x16);
|
|
851
|
+
_mm512_store_si512(&b_tile->data[5][0][0], out09_i32x16);
|
|
852
|
+
_mm512_store_si512(&b_tile->data[6][0][0], out10_i32x16);
|
|
853
|
+
_mm512_store_si512(&b_tile->data[7][0][0], out11_i32x16);
|
|
854
|
+
_mm512_store_si512(&b_tile->data[8][0][0], out04_i32x16);
|
|
855
|
+
_mm512_store_si512(&b_tile->data[9][0][0], out05_i32x16);
|
|
856
|
+
_mm512_store_si512(&b_tile->data[10][0][0], out06_i32x16);
|
|
857
|
+
_mm512_store_si512(&b_tile->data[11][0][0], out07_i32x16);
|
|
858
|
+
_mm512_store_si512(&b_tile->data[12][0][0], out12_i32x16);
|
|
859
|
+
_mm512_store_si512(&b_tile->data[13][0][0], out13_i32x16);
|
|
860
|
+
_mm512_store_si512(&b_tile->data[14][0][0], out14_i32x16);
|
|
861
|
+
_mm512_store_si512(&b_tile->data[15][0][0], out15_i32x16);
|
|
861
862
|
|
|
862
863
|
nk_compiler_barrier_sapphireamx_();
|
|
863
864
|
}
|
|
864
865
|
|
|
865
|
-
#pragma region
|
|
866
|
+
#pragma region F16 Floats
|
|
866
867
|
|
|
867
868
|
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sapphireamx(nk_size_t column_count, nk_size_t depth) {
|
|
868
869
|
nk_size_t const tmm_rows = 16;
|
|
@@ -890,14 +891,14 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sapphireamx(nk_size_t column_count,
|
|
|
890
891
|
|
|
891
892
|
NK_PUBLIC void nk_dots_pack_bf16_sapphireamx( //
|
|
892
893
|
nk_bf16_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
893
|
-
nk_size_t
|
|
894
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
894
895
|
|
|
895
896
|
// AMX BF16 tile dimensions: 16 rows × 32 columns (512 BF16 elements = 1KB)
|
|
896
897
|
nk_size_t const tmm_rows = 16;
|
|
897
898
|
nk_size_t const tmm_cols = 32;
|
|
898
899
|
nk_size_t const tile_elements = 512;
|
|
899
900
|
nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);
|
|
900
|
-
nk_size_t const b_stride_elements =
|
|
901
|
+
nk_size_t const b_stride_elements = b_stride_in_bytes / sizeof(nk_bf16_t);
|
|
901
902
|
|
|
902
903
|
// Compute layout dimensions
|
|
903
904
|
nk_size_t const column_tiles_count = column_count / tmm_rows;
|
|
@@ -920,36 +921,40 @@ NK_PUBLIC void nk_dots_pack_bf16_sapphireamx( //
|
|
|
920
921
|
nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
|
|
921
922
|
nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
|
|
922
923
|
|
|
923
|
-
//
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
// Pack tiles using LINEAR ordering: tile_index = column_tile × depth_tiles_count + depth_tile
|
|
927
|
-
// This provides sequential memory access when streaming along depth dimension,
|
|
928
|
-
// which is critical for cache efficiency in the compute kernel.
|
|
924
|
+
// Pack tiles using vectorized transposer: gather 16 strided rows into an aligned
|
|
925
|
+
// temporary, transpose via SIMD, then copy the result to the packed buffer.
|
|
929
926
|
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
930
927
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
931
928
|
|
|
932
|
-
// Linear tile index: all depth-tiles for one column-tile are contiguous
|
|
933
929
|
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
934
930
|
nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
935
931
|
|
|
936
|
-
// Source coordinates in original B matrix
|
|
937
932
|
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
938
933
|
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
939
934
|
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
940
935
|
: (depth - src_column_start);
|
|
941
936
|
|
|
942
|
-
//
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
937
|
+
// Gather 16 strided source rows into a contiguous aligned tile
|
|
938
|
+
nk_dots_bf16_a16x32_sapphireamx_t source_tile;
|
|
939
|
+
if (columns_to_pack == tmm_cols) {
|
|
940
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
941
|
+
nk_bf16_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
|
|
942
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], _mm512_loadu_si512(source_row));
|
|
943
|
+
}
|
|
944
|
+
}
|
|
945
|
+
else {
|
|
946
|
+
__mmask32 depth_mask = (__mmask32)((columns_to_pack < 32) ? ((1U << columns_to_pack) - 1) : ~0U);
|
|
947
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
948
|
+
nk_bf16_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
|
|
949
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], _mm512_maskz_loadu_epi16(depth_mask, source_row));
|
|
951
950
|
}
|
|
952
951
|
}
|
|
952
|
+
|
|
953
|
+
// Transpose into aligned local, then copy to (potentially unaligned) packed buffer
|
|
954
|
+
nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
|
|
955
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
|
|
956
|
+
for (nk_size_t i = 0; i < tile_bytes; i += 64)
|
|
957
|
+
_mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
|
|
953
958
|
}
|
|
954
959
|
}
|
|
955
960
|
|
|
@@ -1004,7 +1009,7 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1004
1009
|
if (depth_tiles_count == 0) return;
|
|
1005
1010
|
|
|
1006
1011
|
// Tile buffers for A (only used for edge tiles)
|
|
1007
|
-
nk_dots_bf16_a16x32_sapphireamx_t
|
|
1012
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tile_top, a_tile_bottom;
|
|
1008
1013
|
nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
|
|
1009
1014
|
|
|
1010
1015
|
// Precompute: number of full depth-tiles (no masking needed)
|
|
@@ -1033,8 +1038,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1033
1038
|
|
|
1034
1039
|
// Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
|
|
1035
1040
|
if (is_full_row_block && full_depth_tiles_count > 0) {
|
|
1036
|
-
nk_bf16_t const *
|
|
1037
|
-
nk_bf16_t const *
|
|
1041
|
+
nk_bf16_t const *a_top_base = a + row_block_start * a_stride_elements;
|
|
1042
|
+
nk_bf16_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_elements;
|
|
1038
1043
|
|
|
1039
1044
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
1040
1045
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
@@ -1042,8 +1047,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1042
1047
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
1043
1048
|
|
|
1044
1049
|
// Prologue: load first depth tile
|
|
1045
|
-
_tile_loadd(0,
|
|
1046
|
-
_tile_loadd(1,
|
|
1050
|
+
_tile_loadd(0, a_top_base, a_stride_bytes);
|
|
1051
|
+
_tile_loadd(1, a_bottom_base, a_stride_bytes);
|
|
1047
1052
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
1048
1053
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
1049
1054
|
|
|
@@ -1056,8 +1061,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1056
1061
|
_tile_dpbf16ps(6, 1, 2);
|
|
1057
1062
|
_tile_dpbf16ps(7, 1, 3);
|
|
1058
1063
|
|
|
1059
|
-
_tile_loadd(0,
|
|
1060
|
-
_tile_loadd(1,
|
|
1064
|
+
_tile_loadd(0, a_top_base + next_depth_offset, a_stride_bytes);
|
|
1065
|
+
_tile_loadd(1, a_bottom_base + next_depth_offset, a_stride_bytes);
|
|
1061
1066
|
b_tile_left = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
|
|
1062
1067
|
depth_tile_idx + 1) *
|
|
1063
1068
|
tile_size);
|
|
@@ -1078,10 +1083,10 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1078
1083
|
if (depth_remainder > 0) {
|
|
1079
1084
|
nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
|
|
1080
1085
|
|
|
1081
|
-
nk_dots_bf16_load_a_sapphireamx_(&
|
|
1082
|
-
depth_remainder);
|
|
1083
|
-
nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_elements, 16,
|
|
1086
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_top, a_top_base + depth_offset, a_stride_elements, 16,
|
|
1084
1087
|
depth_remainder);
|
|
1088
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base + depth_offset, a_stride_elements,
|
|
1089
|
+
16, depth_remainder);
|
|
1085
1090
|
|
|
1086
1091
|
b_tile_left = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
|
|
1087
1092
|
full_depth_tiles_count) *
|
|
@@ -1090,8 +1095,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1090
1095
|
full_depth_tiles_count) *
|
|
1091
1096
|
tile_size);
|
|
1092
1097
|
|
|
1093
|
-
_tile_loadd(0,
|
|
1094
|
-
_tile_loadd(1,
|
|
1098
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1099
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1095
1100
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
1096
1101
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
1097
1102
|
|
|
@@ -1103,19 +1108,19 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1103
1108
|
}
|
|
1104
1109
|
// Full row-block but only partial depth tile (depth < tile_depth)
|
|
1105
1110
|
else if (is_full_row_block) {
|
|
1106
|
-
nk_bf16_t const *
|
|
1107
|
-
nk_bf16_t const *
|
|
1111
|
+
nk_bf16_t const *a_top_base = a + row_block_start * a_stride_elements;
|
|
1112
|
+
nk_bf16_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_elements;
|
|
1108
1113
|
|
|
1109
|
-
nk_dots_bf16_load_a_sapphireamx_(&
|
|
1110
|
-
nk_dots_bf16_load_a_sapphireamx_(&
|
|
1114
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_top, a_top_base, a_stride_elements, 16, depth_remainder);
|
|
1115
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base, a_stride_elements, 16, depth_remainder);
|
|
1111
1116
|
|
|
1112
1117
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
1113
1118
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
1114
1119
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
|
|
1115
1120
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
1116
1121
|
|
|
1117
|
-
_tile_loadd(0,
|
|
1118
|
-
_tile_loadd(1,
|
|
1122
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1123
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1119
1124
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
1120
1125
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
1121
1126
|
|
|
@@ -1126,21 +1131,21 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1126
1131
|
}
|
|
1127
1132
|
// Slow path: edge row-block → buffered load with masking
|
|
1128
1133
|
else {
|
|
1129
|
-
nk_size_t const
|
|
1130
|
-
nk_size_t const
|
|
1134
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1135
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1131
1136
|
|
|
1132
1137
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1133
1138
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1134
1139
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
|
|
1135
1140
|
: depth_remainder;
|
|
1136
1141
|
|
|
1137
|
-
nk_dots_bf16_load_a_sapphireamx_(&
|
|
1142
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_top,
|
|
1138
1143
|
a + row_block_start * a_stride_elements + depth_offset,
|
|
1139
|
-
a_stride_elements,
|
|
1140
|
-
if (
|
|
1141
|
-
nk_dots_bf16_load_a_sapphireamx_(&
|
|
1144
|
+
a_stride_elements, rows_in_high_tile, valid_depth);
|
|
1145
|
+
if (rows_in_low_tile > 0) {
|
|
1146
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom,
|
|
1142
1147
|
a + (row_block_start + 16) * a_stride_elements + depth_offset,
|
|
1143
|
-
a_stride_elements,
|
|
1148
|
+
a_stride_elements, rows_in_low_tile, valid_depth);
|
|
1144
1149
|
}
|
|
1145
1150
|
|
|
1146
1151
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
@@ -1150,8 +1155,8 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1150
1155
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
1151
1156
|
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
1152
1157
|
|
|
1153
|
-
_tile_loadd(0,
|
|
1154
|
-
_tile_loadd(1,
|
|
1158
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1159
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1155
1160
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
1156
1161
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
1157
1162
|
|
|
@@ -1192,10 +1197,10 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1192
1197
|
nk_size_t const row_block_start = row_block_idx * 32;
|
|
1193
1198
|
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
|
|
1194
1199
|
: (rows_count - row_block_start);
|
|
1195
|
-
nk_size_t const
|
|
1196
|
-
nk_size_t const
|
|
1200
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1201
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1197
1202
|
|
|
1198
|
-
nk_dots_bf16_state_sapphireamx_t
|
|
1203
|
+
nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
|
|
1199
1204
|
|
|
1200
1205
|
_tile_zero(4);
|
|
1201
1206
|
_tile_zero(6);
|
|
@@ -1204,35 +1209,35 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1204
1209
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1205
1210
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
1206
1211
|
|
|
1207
|
-
nk_dots_bf16_load_a_sapphireamx_(&
|
|
1208
|
-
a_stride_elements,
|
|
1209
|
-
if (
|
|
1210
|
-
nk_dots_bf16_load_a_sapphireamx_(&
|
|
1212
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
|
|
1213
|
+
a_stride_elements, rows_in_high_tile, valid_depth);
|
|
1214
|
+
if (rows_in_low_tile > 0) {
|
|
1215
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom,
|
|
1211
1216
|
a + (row_block_start + 16) * a_stride_elements + depth_offset,
|
|
1212
|
-
a_stride_elements,
|
|
1217
|
+
a_stride_elements, rows_in_low_tile, valid_depth);
|
|
1213
1218
|
}
|
|
1214
1219
|
|
|
1215
1220
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
|
|
1216
1221
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
1217
1222
|
(b_column_base + depth_tile_idx) * tile_size);
|
|
1218
1223
|
|
|
1219
|
-
_tile_loadd(0,
|
|
1220
|
-
_tile_loadd(1,
|
|
1224
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1225
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1221
1226
|
_tile_loadd(2, b_tile->data, 64);
|
|
1222
1227
|
|
|
1223
1228
|
_tile_dpbf16ps(4, 0, 2);
|
|
1224
1229
|
_tile_dpbf16ps(6, 1, 2);
|
|
1225
1230
|
}
|
|
1226
1231
|
|
|
1227
|
-
_tile_stored(4,
|
|
1228
|
-
_tile_stored(6,
|
|
1232
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
1233
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
1229
1234
|
|
|
1230
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
1231
|
-
c_stride_elements,
|
|
1232
|
-
if (
|
|
1233
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
1235
|
+
nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
|
|
1236
|
+
c_stride_elements, rows_in_high_tile, 16);
|
|
1237
|
+
if (rows_in_low_tile > 0) {
|
|
1238
|
+
nk_dots_bf16_store_sapphireamx_(&c_low_state,
|
|
1234
1239
|
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
1235
|
-
c_stride_elements,
|
|
1240
|
+
c_stride_elements, rows_in_low_tile, 16);
|
|
1236
1241
|
}
|
|
1237
1242
|
}
|
|
1238
1243
|
}
|
|
@@ -1243,10 +1248,10 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1243
1248
|
nk_size_t const row_block_start = row_block_idx * 32;
|
|
1244
1249
|
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
|
|
1245
1250
|
: (rows_count - row_block_start);
|
|
1246
|
-
nk_size_t const
|
|
1247
|
-
nk_size_t const
|
|
1251
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1252
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1248
1253
|
|
|
1249
|
-
nk_dots_bf16_state_sapphireamx_t
|
|
1254
|
+
nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
|
|
1250
1255
|
nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
|
|
1251
1256
|
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
1252
1257
|
|
|
@@ -1257,35 +1262,35 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1257
1262
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1258
1263
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
1259
1264
|
|
|
1260
|
-
nk_dots_bf16_load_a_sapphireamx_(&
|
|
1261
|
-
a_stride_elements,
|
|
1262
|
-
if (
|
|
1263
|
-
nk_dots_bf16_load_a_sapphireamx_(&
|
|
1265
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
|
|
1266
|
+
a_stride_elements, rows_in_high_tile, valid_depth);
|
|
1267
|
+
if (rows_in_low_tile > 0) {
|
|
1268
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_bottom,
|
|
1264
1269
|
a + (row_block_start + 16) * a_stride_elements + depth_offset,
|
|
1265
|
-
a_stride_elements,
|
|
1270
|
+
a_stride_elements, rows_in_low_tile, valid_depth);
|
|
1266
1271
|
}
|
|
1267
1272
|
|
|
1268
1273
|
nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
1269
1274
|
valid_depth);
|
|
1270
1275
|
nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
1271
1276
|
|
|
1272
|
-
_tile_loadd(0,
|
|
1273
|
-
_tile_loadd(1,
|
|
1277
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1278
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1274
1279
|
_tile_loadd(2, b_tile.data, 64);
|
|
1275
1280
|
|
|
1276
1281
|
_tile_dpbf16ps(4, 0, 2);
|
|
1277
1282
|
_tile_dpbf16ps(6, 1, 2);
|
|
1278
1283
|
}
|
|
1279
1284
|
|
|
1280
|
-
_tile_stored(4,
|
|
1281
|
-
_tile_stored(6,
|
|
1285
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
1286
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
1282
1287
|
|
|
1283
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
1284
|
-
c_stride_elements,
|
|
1285
|
-
if (
|
|
1286
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
1288
|
+
nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
|
|
1289
|
+
c_stride_elements, rows_in_high_tile, column_remainder_count);
|
|
1290
|
+
if (rows_in_low_tile > 0) {
|
|
1291
|
+
nk_dots_bf16_store_sapphireamx_(&c_low_state,
|
|
1287
1292
|
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
1288
|
-
c_stride_elements,
|
|
1293
|
+
c_stride_elements, rows_in_low_tile, column_remainder_count);
|
|
1289
1294
|
}
|
|
1290
1295
|
}
|
|
1291
1296
|
}
|
|
@@ -1294,9 +1299,9 @@ NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
|
1294
1299
|
}
|
|
1295
1300
|
|
|
1296
1301
|
NK_PUBLIC void nk_dots_compact_bf16_sapphireamx( //
|
|
1297
|
-
void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t
|
|
1302
|
+
void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride_in_bytes) {
|
|
1298
1303
|
|
|
1299
|
-
nk_size_t const c_stride_f32 =
|
|
1304
|
+
nk_size_t const c_stride_f32 = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1300
1305
|
nk_f32_t const *c_f32 = (nk_f32_t const *)c;
|
|
1301
1306
|
nk_bf16_t *c_bf16 = (nk_bf16_t *)c;
|
|
1302
1307
|
|
|
@@ -1322,18 +1327,18 @@ NK_PUBLIC void nk_dots_compact_bf16_sapphireamx( //
|
|
|
1322
1327
|
}
|
|
1323
1328
|
}
|
|
1324
1329
|
|
|
1325
|
-
NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx(
|
|
1326
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
1327
|
-
nk_size_t
|
|
1330
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx( //
|
|
1331
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
1332
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
|
|
1328
1333
|
nk_size_t row_start, nk_size_t row_count) {
|
|
1329
1334
|
|
|
1330
|
-
nk_size_t const stride_elements =
|
|
1331
|
-
nk_size_t const result_stride_elements =
|
|
1335
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
1336
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1332
1337
|
|
|
1333
1338
|
// Handle row slicing: compute rows [row_start, row_end)
|
|
1334
1339
|
nk_size_t const row_end = (row_count == 0)
|
|
1335
|
-
?
|
|
1336
|
-
: (row_start + row_count <
|
|
1340
|
+
? vectors_count
|
|
1341
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
1337
1342
|
|
|
1338
1343
|
// Round depth up to multiple of 96 (3 tiles × 32 elements)
|
|
1339
1344
|
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
@@ -1349,8 +1354,8 @@ NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx( //
|
|
|
1349
1354
|
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
1350
1355
|
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
1351
1356
|
|
|
1352
|
-
for (nk_size_t col_tile = 0; col_tile <
|
|
1353
|
-
nk_size_t const valid_cols = (col_tile + 16 <=
|
|
1357
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
1358
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
1354
1359
|
|
|
1355
1360
|
nk_dots_bf16_init_sapphireamx_(&state);
|
|
1356
1361
|
|
|
@@ -1391,7 +1396,7 @@ NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx( //
|
|
|
1391
1396
|
}
|
|
1392
1397
|
}
|
|
1393
1398
|
|
|
1394
|
-
#pragma endregion
|
|
1399
|
+
#pragma endregion F16 Floats
|
|
1395
1400
|
|
|
1396
1401
|
#pragma region Signed Integers
|
|
1397
1402
|
|
|
@@ -1421,7 +1426,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sapphireamx(nk_size_t column_count, n
|
|
|
1421
1426
|
|
|
1422
1427
|
NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
|
|
1423
1428
|
nk_i8_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
1424
|
-
nk_size_t
|
|
1429
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1425
1430
|
|
|
1426
1431
|
// AMX I8 tile dimensions: 16 rows × 64 columns (1024 I8 elements = 1KB)
|
|
1427
1432
|
nk_size_t const tmm_rows = 16;
|
|
@@ -1450,34 +1455,45 @@ NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
|
|
|
1450
1455
|
nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + tiles_offset);
|
|
1451
1456
|
nk_i8_t *column_edge_ptr = (nk_i8_t *)((char *)b_packed + column_edge_offset);
|
|
1452
1457
|
|
|
1453
|
-
//
|
|
1454
|
-
|
|
1455
|
-
|
|
1456
|
-
// Pack tiles using LINEAR ordering: tile_index = column_tile × depth_tiles_count + depth_tile
|
|
1457
|
-
// This provides sequential memory access when streaming along depth dimension.
|
|
1458
|
+
// Pack tiles using vectorized transposer: gather 16 strided rows into an aligned
|
|
1459
|
+
// temporary, transpose via SIMD, then copy the result to the packed buffer.
|
|
1460
|
+
// Stack-local aligned tiles are needed because the packed buffer may not be 64-byte aligned.
|
|
1458
1461
|
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
1459
1462
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1460
1463
|
|
|
1461
|
-
// Linear tile index: all depth-tiles for one column-tile are contiguous
|
|
1462
1464
|
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
1463
1465
|
nk_i8_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
1464
1466
|
|
|
1465
|
-
// Source coordinates in original B matrix
|
|
1466
1467
|
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
1467
1468
|
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
1468
1469
|
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
1469
1470
|
: (depth - src_column_start);
|
|
1470
1471
|
|
|
1471
|
-
//
|
|
1472
|
-
|
|
1473
|
-
|
|
1474
|
-
|
|
1475
|
-
|
|
1476
|
-
|
|
1477
|
-
|
|
1478
|
-
|
|
1472
|
+
// Gather 16 strided source rows into a contiguous aligned tile
|
|
1473
|
+
nk_dots_i8_a16x64_sapphireamx_t source_tile;
|
|
1474
|
+
if (columns_to_pack == tmm_cols) {
|
|
1475
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
1476
|
+
nk_i8_t const *source_row = (nk_i8_t const *)((char const *)b +
|
|
1477
|
+
(src_row_start + row_idx) * b_stride_in_bytes) +
|
|
1478
|
+
src_column_start;
|
|
1479
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], _mm512_loadu_si512(source_row));
|
|
1480
|
+
}
|
|
1481
|
+
}
|
|
1482
|
+
else {
|
|
1483
|
+
__mmask64 depth_mask = (__mmask64)((columns_to_pack < 64) ? ((1ULL << columns_to_pack) - 1) : ~0ULL);
|
|
1484
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
1485
|
+
nk_i8_t const *source_row = (nk_i8_t const *)((char const *)b +
|
|
1486
|
+
(src_row_start + row_idx) * b_stride_in_bytes) +
|
|
1487
|
+
src_column_start;
|
|
1488
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], _mm512_maskz_loadu_epi8(depth_mask, source_row));
|
|
1479
1489
|
}
|
|
1480
1490
|
}
|
|
1491
|
+
|
|
1492
|
+
// Transpose into aligned local, then copy to (potentially unaligned) packed buffer
|
|
1493
|
+
nk_dots_i8_b64x16_sapphireamx_t transposed_tile;
|
|
1494
|
+
nk_dots_pack_i8_transposed_sapphireamx_(&source_tile, &transposed_tile);
|
|
1495
|
+
for (nk_size_t i = 0; i < tile_elements; i += 64)
|
|
1496
|
+
_mm512_storeu_si512(tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
|
|
1481
1497
|
}
|
|
1482
1498
|
}
|
|
1483
1499
|
|
|
@@ -1487,7 +1503,7 @@ NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
|
|
|
1487
1503
|
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
1488
1504
|
for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
|
|
1489
1505
|
column_edge_ptr[row_idx * depth + column_idx] =
|
|
1490
|
-
b[(remainder_start_row + row_idx) *
|
|
1506
|
+
b[(remainder_start_row + row_idx) * b_stride_in_bytes + column_idx];
|
|
1491
1507
|
}
|
|
1492
1508
|
}
|
|
1493
1509
|
}
|
|
@@ -1497,7 +1513,8 @@ NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
|
|
|
1497
1513
|
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_i8_t) : 0);
|
|
1498
1514
|
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
1499
1515
|
nk_u32_t *norms = (nk_u32_t *)((char *)b_packed + norms_offset);
|
|
1500
|
-
for (nk_size_t col = 0; col < column_count; col++)
|
|
1516
|
+
for (nk_size_t col = 0; col < column_count; col++)
|
|
1517
|
+
norms[col] = nk_dots_reduce_sumsq_i8_(b + col * b_stride_in_bytes, depth);
|
|
1501
1518
|
}
|
|
1502
1519
|
|
|
1503
1520
|
NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
@@ -1530,7 +1547,7 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1530
1547
|
if (depth_tiles_count == 0) return;
|
|
1531
1548
|
|
|
1532
1549
|
// Tile buffers for A (only used for edge tiles)
|
|
1533
|
-
nk_dots_i8_a16x64_sapphireamx_t
|
|
1550
|
+
nk_dots_i8_a16x64_sapphireamx_t a_tile_top, a_tile_bottom;
|
|
1534
1551
|
nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
|
|
1535
1552
|
|
|
1536
1553
|
// Precompute: number of full depth-tiles (no masking needed)
|
|
@@ -1562,8 +1579,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1562
1579
|
// Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
|
|
1563
1580
|
if (is_full_row_block && full_depth_tiles_count > 0) {
|
|
1564
1581
|
// A row pointers for direct load
|
|
1565
|
-
nk_i8_t const *
|
|
1566
|
-
nk_i8_t const *
|
|
1582
|
+
nk_i8_t const *a_top_base = a + row_block_start * a_stride_bytes;
|
|
1583
|
+
nk_i8_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_bytes;
|
|
1567
1584
|
|
|
1568
1585
|
// B tile pointers
|
|
1569
1586
|
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
|
|
@@ -1572,8 +1589,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1572
1589
|
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
1573
1590
|
|
|
1574
1591
|
// Prologue: load first depth tile into TMM0-3
|
|
1575
|
-
_tile_loadd(0,
|
|
1576
|
-
_tile_loadd(1,
|
|
1592
|
+
_tile_loadd(0, a_top_base, a_stride_bytes);
|
|
1593
|
+
_tile_loadd(1, a_bottom_base, a_stride_bytes);
|
|
1577
1594
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
1578
1595
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
1579
1596
|
|
|
@@ -1586,8 +1603,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1586
1603
|
_tile_dpbssd(6, 1, 2);
|
|
1587
1604
|
_tile_dpbssd(7, 1, 3);
|
|
1588
1605
|
|
|
1589
|
-
_tile_loadd(0,
|
|
1590
|
-
_tile_loadd(1,
|
|
1606
|
+
_tile_loadd(0, a_top_base + next_depth_offset, a_stride_bytes);
|
|
1607
|
+
_tile_loadd(1, a_bottom_base + next_depth_offset, a_stride_bytes);
|
|
1591
1608
|
b_tile_left = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
1592
1609
|
(b_column_left_base + depth_tile_idx + 1) *
|
|
1593
1610
|
tile_size);
|
|
@@ -1608,9 +1625,9 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1608
1625
|
if (depth_remainder > 0) {
|
|
1609
1626
|
nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
|
|
1610
1627
|
|
|
1611
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1628
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a_top_base + depth_offset, a_stride_bytes, 16,
|
|
1612
1629
|
depth_remainder);
|
|
1613
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1630
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base + depth_offset, a_stride_bytes, 16,
|
|
1614
1631
|
depth_remainder);
|
|
1615
1632
|
|
|
1616
1633
|
b_tile_left = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
|
|
@@ -1620,8 +1637,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1620
1637
|
full_depth_tiles_count) *
|
|
1621
1638
|
tile_size);
|
|
1622
1639
|
|
|
1623
|
-
_tile_loadd(0,
|
|
1624
|
-
_tile_loadd(1,
|
|
1640
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1641
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1625
1642
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
1626
1643
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
1627
1644
|
|
|
@@ -1633,19 +1650,19 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1633
1650
|
}
|
|
1634
1651
|
// Full row-block but only partial depth tile (depth < tile_depth)
|
|
1635
1652
|
else if (is_full_row_block) {
|
|
1636
|
-
nk_i8_t const *
|
|
1637
|
-
nk_i8_t const *
|
|
1653
|
+
nk_i8_t const *a_top_base = a + row_block_start * a_stride_bytes;
|
|
1654
|
+
nk_i8_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_bytes;
|
|
1638
1655
|
|
|
1639
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1640
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1656
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a_top_base, a_stride_bytes, 16, depth_remainder);
|
|
1657
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base, a_stride_bytes, 16, depth_remainder);
|
|
1641
1658
|
|
|
1642
1659
|
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
|
|
1643
1660
|
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
1644
1661
|
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
|
|
1645
1662
|
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
1646
1663
|
|
|
1647
|
-
_tile_loadd(0,
|
|
1648
|
-
_tile_loadd(1,
|
|
1664
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1665
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1649
1666
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
1650
1667
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
1651
1668
|
|
|
@@ -1656,20 +1673,20 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1656
1673
|
}
|
|
1657
1674
|
// Slow path: edge row-block → always use buffered load with masking
|
|
1658
1675
|
else {
|
|
1659
|
-
nk_size_t const
|
|
1660
|
-
nk_size_t const
|
|
1676
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1677
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1661
1678
|
|
|
1662
1679
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1663
1680
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1664
1681
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
|
|
1665
1682
|
: depth_remainder;
|
|
1666
1683
|
|
|
1667
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1668
|
-
a_stride_bytes,
|
|
1669
|
-
if (
|
|
1670
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1684
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
1685
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
1686
|
+
if (rows_in_low_tile > 0) {
|
|
1687
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom,
|
|
1671
1688
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
1672
|
-
a_stride_bytes,
|
|
1689
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
1673
1690
|
}
|
|
1674
1691
|
|
|
1675
1692
|
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
|
|
@@ -1679,8 +1696,8 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1679
1696
|
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
1680
1697
|
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
1681
1698
|
|
|
1682
|
-
_tile_loadd(0,
|
|
1683
|
-
_tile_loadd(1,
|
|
1699
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1700
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1684
1701
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
1685
1702
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
1686
1703
|
|
|
@@ -1716,11 +1733,11 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1716
1733
|
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
1717
1734
|
nk_size_t const col_start = column_tile_idx * 16;
|
|
1718
1735
|
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
1719
|
-
nk_size_t const
|
|
1720
|
-
nk_size_t const
|
|
1736
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1737
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1721
1738
|
|
|
1722
1739
|
// Use 1 × 2 blocking for single column-tile (2 row-tiles × 1 column-tile)
|
|
1723
|
-
nk_dots_i8_state_sapphireamx_t
|
|
1740
|
+
nk_dots_i8_state_sapphireamx_t c_high_state, c_low_state;
|
|
1724
1741
|
|
|
1725
1742
|
_tile_zero(4);
|
|
1726
1743
|
_tile_zero(6);
|
|
@@ -1729,44 +1746,43 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1729
1746
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1730
1747
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
1731
1748
|
|
|
1732
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1733
|
-
a_stride_bytes,
|
|
1734
|
-
if (
|
|
1735
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1749
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
1750
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
1751
|
+
if (rows_in_low_tile > 0) {
|
|
1752
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom,
|
|
1736
1753
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
1737
|
-
a_stride_bytes,
|
|
1754
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
1738
1755
|
}
|
|
1739
1756
|
|
|
1740
1757
|
nk_dots_i8_b64x16_sapphireamx_t const *b_tile =
|
|
1741
1758
|
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
1742
1759
|
(b_column_base + depth_tile_idx) * tile_size);
|
|
1743
1760
|
|
|
1744
|
-
_tile_loadd(0,
|
|
1745
|
-
_tile_loadd(1,
|
|
1761
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1762
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1746
1763
|
_tile_loadd(2, b_tile->data, 64);
|
|
1747
1764
|
|
|
1748
1765
|
_tile_dpbssd(4, 0, 2);
|
|
1749
1766
|
_tile_dpbssd(6, 1, 2);
|
|
1750
1767
|
}
|
|
1751
1768
|
|
|
1752
|
-
_tile_stored(4,
|
|
1753
|
-
_tile_stored(6,
|
|
1769
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
1770
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
1754
1771
|
|
|
1755
|
-
nk_dots_i8_store_sapphireamx_(&
|
|
1756
|
-
c_stride_elements,
|
|
1757
|
-
if (
|
|
1758
|
-
nk_dots_i8_store_sapphireamx_(&
|
|
1759
|
-
|
|
1760
|
-
c_stride_elements, rows_in_lower_tile, 16);
|
|
1772
|
+
nk_dots_i8_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
|
|
1773
|
+
c_stride_elements, rows_in_high_tile, 16);
|
|
1774
|
+
if (rows_in_low_tile > 0) {
|
|
1775
|
+
nk_dots_i8_store_sapphireamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
1776
|
+
c_stride_elements, rows_in_low_tile, 16);
|
|
1761
1777
|
}
|
|
1762
1778
|
}
|
|
1763
1779
|
|
|
1764
1780
|
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
1765
1781
|
if (column_remainder_count > 0) {
|
|
1766
|
-
nk_size_t const
|
|
1767
|
-
nk_size_t const
|
|
1782
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1783
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1768
1784
|
|
|
1769
|
-
nk_dots_i8_state_sapphireamx_t
|
|
1785
|
+
nk_dots_i8_state_sapphireamx_t c_high_state, c_low_state;
|
|
1770
1786
|
nk_dots_i8_a16x64_sapphireamx_t b_as_a;
|
|
1771
1787
|
nk_dots_i8_b64x16_sapphireamx_t b_tile;
|
|
1772
1788
|
|
|
@@ -1778,12 +1794,12 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1778
1794
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
1779
1795
|
|
|
1780
1796
|
// Load A tiles
|
|
1781
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1782
|
-
a_stride_bytes,
|
|
1783
|
-
if (
|
|
1784
|
-
nk_dots_i8_load_a_sapphireamx_(&
|
|
1797
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
1798
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
1799
|
+
if (rows_in_low_tile > 0) {
|
|
1800
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_bottom,
|
|
1785
1801
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
1786
|
-
a_stride_bytes,
|
|
1802
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
1787
1803
|
}
|
|
1788
1804
|
|
|
1789
1805
|
// Load B edge data (row-major: b_edge[row × depth + column]) and pack into B tile
|
|
@@ -1792,23 +1808,22 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1792
1808
|
valid_depth);
|
|
1793
1809
|
nk_dots_pack_i8_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
1794
1810
|
|
|
1795
|
-
_tile_loadd(0,
|
|
1796
|
-
_tile_loadd(1,
|
|
1811
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
1812
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
1797
1813
|
_tile_loadd(2, b_tile.data, 64);
|
|
1798
1814
|
|
|
1799
1815
|
_tile_dpbssd(4, 0, 2);
|
|
1800
1816
|
_tile_dpbssd(6, 1, 2);
|
|
1801
1817
|
}
|
|
1802
1818
|
|
|
1803
|
-
_tile_stored(4,
|
|
1804
|
-
_tile_stored(6,
|
|
1819
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
1820
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
1805
1821
|
|
|
1806
|
-
nk_dots_i8_store_sapphireamx_(&
|
|
1807
|
-
c_stride_elements,
|
|
1808
|
-
if (
|
|
1809
|
-
nk_dots_i8_store_sapphireamx_(&
|
|
1810
|
-
|
|
1811
|
-
c_stride_elements, rows_in_lower_tile, column_remainder_count);
|
|
1822
|
+
nk_dots_i8_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
|
|
1823
|
+
c_stride_elements, rows_in_high_tile, column_remainder_count);
|
|
1824
|
+
if (rows_in_low_tile > 0) {
|
|
1825
|
+
nk_dots_i8_store_sapphireamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
1826
|
+
c_stride_elements, rows_in_low_tile, column_remainder_count);
|
|
1812
1827
|
}
|
|
1813
1828
|
}
|
|
1814
1829
|
}
|
|
@@ -1817,10 +1832,10 @@ NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
|
1817
1832
|
}
|
|
1818
1833
|
|
|
1819
1834
|
NK_PUBLIC void nk_dots_compact_i8_sapphireamx( //
|
|
1820
|
-
void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t
|
|
1835
|
+
void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride_in_bytes, nk_i32_t const *a_squared_norms,
|
|
1821
1836
|
nk_i32_t const *b_squared_norms) {
|
|
1822
1837
|
|
|
1823
|
-
nk_size_t const c_stride_i32 =
|
|
1838
|
+
nk_size_t const c_stride_i32 = c_stride_in_bytes / sizeof(nk_i32_t);
|
|
1824
1839
|
nk_i32_t const *c_i32 = (nk_i32_t const *)c;
|
|
1825
1840
|
nk_i8_t *c_i8 = (nk_i8_t *)c;
|
|
1826
1841
|
|
|
@@ -1828,41 +1843,45 @@ NK_PUBLIC void nk_dots_compact_i8_sapphireamx( //
|
|
|
1828
1843
|
nk_f32_t *b_rsqrt = (nk_f32_t *)(c_i8 + row_count * column_count);
|
|
1829
1844
|
|
|
1830
1845
|
// Precompute rsqrt of all b_norms using AVX512 (16 at a time)
|
|
1831
|
-
__m512
|
|
1832
|
-
__m512
|
|
1846
|
+
__m512 half_vec_f32x16 = _mm512_set1_ps(0.5f);
|
|
1847
|
+
__m512 three_halves_vec_f32x16 = _mm512_set1_ps(1.5f);
|
|
1833
1848
|
nk_size_t column_idx = 0;
|
|
1834
1849
|
|
|
1835
1850
|
for (; column_idx + 16 <= column_count; column_idx += 16) {
|
|
1836
|
-
__m512i
|
|
1837
|
-
__m512
|
|
1838
|
-
__m512
|
|
1851
|
+
__m512i b_norms_i32x16 = _mm512_loadu_si512(b_squared_norms + column_idx);
|
|
1852
|
+
__m512 b_norms_f32x16 = _mm512_cvtepi32_ps(b_norms_i32x16);
|
|
1853
|
+
__m512 rsqrt_vec_f32x16 = _mm512_rsqrt14_ps(b_norms_f32x16);
|
|
1839
1854
|
// Newton-Raphson refinement
|
|
1840
|
-
|
|
1841
|
-
|
|
1842
|
-
_mm512_sub_ps(
|
|
1843
|
-
|
|
1855
|
+
rsqrt_vec_f32x16 = _mm512_mul_ps(
|
|
1856
|
+
rsqrt_vec_f32x16,
|
|
1857
|
+
_mm512_sub_ps(
|
|
1858
|
+
three_halves_vec_f32x16,
|
|
1859
|
+
_mm512_mul_ps(half_vec_f32x16,
|
|
1860
|
+
_mm512_mul_ps(b_norms_f32x16, _mm512_mul_ps(rsqrt_vec_f32x16, rsqrt_vec_f32x16)))));
|
|
1844
1861
|
// Zero out rsqrt where norm was zero
|
|
1845
|
-
__mmask16 nonzero_mask = _mm512_cmpneq_epi32_mask(
|
|
1846
|
-
|
|
1847
|
-
_mm512_storeu_ps(b_rsqrt + column_idx,
|
|
1862
|
+
__mmask16 nonzero_mask = _mm512_cmpneq_epi32_mask(b_norms_i32x16, _mm512_setzero_si512());
|
|
1863
|
+
rsqrt_vec_f32x16 = _mm512_maskz_mov_ps(nonzero_mask, rsqrt_vec_f32x16);
|
|
1864
|
+
_mm512_storeu_ps(b_rsqrt + column_idx, rsqrt_vec_f32x16);
|
|
1848
1865
|
}
|
|
1849
1866
|
|
|
1850
1867
|
// Handle remaining b_norms with masked operations
|
|
1851
1868
|
if (column_idx < column_count) {
|
|
1852
1869
|
__mmask16 tail_mask = (__mmask16)((1u << (column_count - column_idx)) - 1);
|
|
1853
|
-
__m512i
|
|
1854
|
-
__m512
|
|
1855
|
-
__m512
|
|
1856
|
-
|
|
1857
|
-
|
|
1858
|
-
_mm512_sub_ps(
|
|
1859
|
-
|
|
1860
|
-
|
|
1861
|
-
|
|
1862
|
-
|
|
1870
|
+
__m512i b_norms_i32x16 = _mm512_maskz_loadu_epi32(tail_mask, b_squared_norms + column_idx);
|
|
1871
|
+
__m512 b_norms_f32x16 = _mm512_cvtepi32_ps(b_norms_i32x16);
|
|
1872
|
+
__m512 rsqrt_vec_f32x16 = _mm512_rsqrt14_ps(b_norms_f32x16);
|
|
1873
|
+
rsqrt_vec_f32x16 = _mm512_mul_ps(
|
|
1874
|
+
rsqrt_vec_f32x16,
|
|
1875
|
+
_mm512_sub_ps(
|
|
1876
|
+
three_halves_vec_f32x16,
|
|
1877
|
+
_mm512_mul_ps(half_vec_f32x16,
|
|
1878
|
+
_mm512_mul_ps(b_norms_f32x16, _mm512_mul_ps(rsqrt_vec_f32x16, rsqrt_vec_f32x16)))));
|
|
1879
|
+
__mmask16 nonzero_mask = _mm512_cmpneq_epi32_mask(b_norms_i32x16, _mm512_setzero_si512());
|
|
1880
|
+
rsqrt_vec_f32x16 = _mm512_maskz_mov_ps(nonzero_mask & tail_mask, rsqrt_vec_f32x16);
|
|
1881
|
+
_mm512_mask_storeu_ps(b_rsqrt + column_idx, tail_mask, rsqrt_vec_f32x16);
|
|
1863
1882
|
}
|
|
1864
1883
|
|
|
1865
|
-
__m512
|
|
1884
|
+
__m512 scale_vec_f32x16 = _mm512_set1_ps(127.0f);
|
|
1866
1885
|
|
|
1867
1886
|
for (nk_size_t row_idx = 0; row_idx < row_count; row_idx++) {
|
|
1868
1887
|
nk_i32_t const *src_row = c_i32 + row_idx * c_stride_i32;
|
|
@@ -1872,55 +1891,57 @@ NK_PUBLIC void nk_dots_compact_i8_sapphireamx( //
|
|
|
1872
1891
|
nk_f32_t a_norm_f32 = (nk_f32_t)a_squared_norms[row_idx];
|
|
1873
1892
|
nk_f32_t a_rsqrt_val = 0.0f;
|
|
1874
1893
|
if (a_norm_f32 > 0.0f) {
|
|
1875
|
-
__m128
|
|
1876
|
-
__m128
|
|
1877
|
-
|
|
1878
|
-
|
|
1879
|
-
|
|
1880
|
-
|
|
1894
|
+
__m128 a_vec_f32x4 = _mm_set_ss(a_norm_f32);
|
|
1895
|
+
__m128 rsqrt_s_f32x4 = _mm_rsqrt_ss(a_vec_f32x4);
|
|
1896
|
+
rsqrt_s_f32x4 = _mm_mul_ss(
|
|
1897
|
+
rsqrt_s_f32x4,
|
|
1898
|
+
_mm_sub_ss(
|
|
1899
|
+
_mm_set_ss(1.5f),
|
|
1900
|
+
_mm_mul_ss(_mm_set_ss(0.5f), _mm_mul_ss(a_vec_f32x4, _mm_mul_ss(rsqrt_s_f32x4, rsqrt_s_f32x4)))));
|
|
1901
|
+
a_rsqrt_val = _mm_cvtss_f32(rsqrt_s_f32x4);
|
|
1881
1902
|
}
|
|
1882
|
-
__m512
|
|
1883
|
-
__m512
|
|
1903
|
+
__m512 a_rsqrt_vec_f32x16 = _mm512_set1_ps(a_rsqrt_val);
|
|
1904
|
+
__m512 row_scale_f32x16 = _mm512_mul_ps(a_rsqrt_vec_f32x16, scale_vec_f32x16);
|
|
1884
1905
|
|
|
1885
1906
|
column_idx = 0;
|
|
1886
1907
|
|
|
1887
1908
|
// Process 16 elements at a time
|
|
1888
1909
|
for (; column_idx + 16 <= column_count; column_idx += 16) {
|
|
1889
|
-
__m512i
|
|
1890
|
-
__m512
|
|
1891
|
-
__m512
|
|
1892
|
-
__m512
|
|
1893
|
-
__m512i
|
|
1910
|
+
__m512i c_vals_i32x16 = _mm512_loadu_si512(src_row + column_idx);
|
|
1911
|
+
__m512 c_f32_f32x16 = _mm512_cvtepi32_ps(c_vals_i32x16);
|
|
1912
|
+
__m512 b_rsqrt_vec_f32x16 = _mm512_loadu_ps(b_rsqrt + column_idx);
|
|
1913
|
+
__m512 normalized_f32x16 = _mm512_mul_ps(_mm512_mul_ps(c_f32_f32x16, row_scale_f32x16), b_rsqrt_vec_f32x16);
|
|
1914
|
+
__m512i result_i32x16 = _mm512_cvtps_epi32(normalized_f32x16);
|
|
1894
1915
|
// Saturating pack I32 → I8 (16 values → 16 bytes in low 128 bits)
|
|
1895
|
-
__m128i
|
|
1896
|
-
_mm_storeu_si128((__m128i *)(dst_row + column_idx),
|
|
1916
|
+
__m128i result_i8x16 = _mm512_cvtsepi32_epi8(result_i32x16);
|
|
1917
|
+
_mm_storeu_si128((__m128i *)(dst_row + column_idx), result_i8x16);
|
|
1897
1918
|
}
|
|
1898
1919
|
|
|
1899
1920
|
// Handle remaining elements with masked operations
|
|
1900
1921
|
if (column_idx < column_count) {
|
|
1901
1922
|
__mmask16 tail_mask = (__mmask16)((1u << (column_count - column_idx)) - 1);
|
|
1902
|
-
__m512i
|
|
1903
|
-
__m512
|
|
1904
|
-
__m512
|
|
1905
|
-
__m512
|
|
1906
|
-
__m512i
|
|
1907
|
-
__m128i
|
|
1908
|
-
_mm_mask_storeu_epi8(dst_row + column_idx, tail_mask,
|
|
1923
|
+
__m512i c_vals_i32x16 = _mm512_maskz_loadu_epi32(tail_mask, src_row + column_idx);
|
|
1924
|
+
__m512 c_f32_f32x16 = _mm512_cvtepi32_ps(c_vals_i32x16);
|
|
1925
|
+
__m512 b_rsqrt_vec_f32x16 = _mm512_maskz_loadu_ps(tail_mask, b_rsqrt + column_idx);
|
|
1926
|
+
__m512 normalized_f32x16 = _mm512_mul_ps(_mm512_mul_ps(c_f32_f32x16, row_scale_f32x16), b_rsqrt_vec_f32x16);
|
|
1927
|
+
__m512i result_i32x16 = _mm512_cvtps_epi32(normalized_f32x16);
|
|
1928
|
+
__m128i result_i8x16 = _mm512_cvtsepi32_epi8(result_i32x16);
|
|
1929
|
+
_mm_mask_storeu_epi8(dst_row + column_idx, tail_mask, result_i8x16);
|
|
1909
1930
|
}
|
|
1910
1931
|
}
|
|
1911
1932
|
}
|
|
1912
1933
|
|
|
1913
|
-
NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx(
|
|
1914
|
-
nk_i8_t const *vectors, nk_size_t
|
|
1915
|
-
nk_size_t
|
|
1934
|
+
NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
|
|
1935
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
1936
|
+
nk_size_t stride_in_bytes, nk_i32_t *result, nk_size_t result_stride_in_bytes, //
|
|
1916
1937
|
nk_size_t row_start, nk_size_t row_count) {
|
|
1917
1938
|
|
|
1918
|
-
nk_size_t const result_stride_elements =
|
|
1939
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_i32_t);
|
|
1919
1940
|
|
|
1920
1941
|
// Handle row slicing: compute rows [row_start, row_end)
|
|
1921
1942
|
nk_size_t const row_end = (row_count == 0)
|
|
1922
|
-
?
|
|
1923
|
-
: (row_start + row_count <
|
|
1943
|
+
? vectors_count
|
|
1944
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
1924
1945
|
|
|
1925
1946
|
// Round depth up to multiple of 192 (3 tiles × 64 elements)
|
|
1926
1947
|
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
|
|
@@ -1936,8 +1957,8 @@ NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
|
|
|
1936
1957
|
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
1937
1958
|
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
1938
1959
|
|
|
1939
|
-
for (nk_size_t col_tile = 0; col_tile <
|
|
1940
|
-
nk_size_t const valid_cols = (col_tile + 16 <=
|
|
1960
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
1961
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
1941
1962
|
|
|
1942
1963
|
nk_dots_i8_init_sapphireamx_(&state);
|
|
1943
1964
|
|
|
@@ -1950,19 +1971,19 @@ NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
|
|
|
1950
1971
|
? 64
|
|
1951
1972
|
: (depth > depth_start ? depth - depth_start : 0);
|
|
1952
1973
|
|
|
1953
|
-
nk_dots_i8_load_a_sapphireamx_(
|
|
1954
|
-
&a_tiles[tile_idx],
|
|
1955
|
-
vectors + row_tile *
|
|
1956
|
-
|
|
1974
|
+
nk_dots_i8_load_a_sapphireamx_( //
|
|
1975
|
+
&a_tiles[tile_idx], //
|
|
1976
|
+
vectors + row_tile * stride_in_bytes + depth_start, //
|
|
1977
|
+
stride_in_bytes, valid_rows, valid_depth);
|
|
1957
1978
|
|
|
1958
1979
|
if (row_tile == col_tile) {
|
|
1959
1980
|
nk_dots_pack_i8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
1960
1981
|
}
|
|
1961
1982
|
else {
|
|
1962
|
-
nk_dots_i8_load_a_sapphireamx_(
|
|
1963
|
-
&b_src_tiles[tile_idx],
|
|
1964
|
-
vectors + col_tile *
|
|
1965
|
-
|
|
1983
|
+
nk_dots_i8_load_a_sapphireamx_( //
|
|
1984
|
+
&b_src_tiles[tile_idx], //
|
|
1985
|
+
vectors + col_tile * stride_in_bytes + depth_start, //
|
|
1986
|
+
stride_in_bytes, valid_cols, valid_depth);
|
|
1966
1987
|
nk_dots_pack_i8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
1967
1988
|
}
|
|
1968
1989
|
}
|
|
@@ -1978,7 +1999,7 @@ NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
|
|
|
1978
1999
|
}
|
|
1979
2000
|
}
|
|
1980
2001
|
|
|
1981
|
-
#pragma endregion
|
|
2002
|
+
#pragma endregion Signed Integers
|
|
1982
2003
|
|
|
1983
2004
|
#pragma region Unsigned Integers
|
|
1984
2005
|
|
|
@@ -1989,7 +2010,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sapphireamx(nk_size_t column_count, n
|
|
|
1989
2010
|
|
|
1990
2011
|
NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
|
|
1991
2012
|
nk_u8_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
1992
|
-
nk_size_t
|
|
2013
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
1993
2014
|
|
|
1994
2015
|
nk_size_t const tmm_rows = 16;
|
|
1995
2016
|
nk_size_t const tmm_cols = 64;
|
|
@@ -2013,8 +2034,9 @@ NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
|
|
|
2013
2034
|
nk_u8_t *tiles_ptr = (nk_u8_t *)((char *)b_packed + tiles_offset);
|
|
2014
2035
|
nk_u8_t *column_edge_ptr = (nk_u8_t *)((char *)b_packed + column_edge_offset);
|
|
2015
2036
|
|
|
2016
|
-
|
|
2017
|
-
|
|
2037
|
+
// Pack tiles using vectorized transposer: gather 16 strided rows into an aligned
|
|
2038
|
+
// temporary, transpose via SIMD, then copy the result to the packed buffer.
|
|
2039
|
+
// Stack-local aligned tiles are needed because the packed buffer may not be 64-byte aligned.
|
|
2018
2040
|
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
2019
2041
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2020
2042
|
|
|
@@ -2026,14 +2048,31 @@ NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
|
|
|
2026
2048
|
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
2027
2049
|
: (depth - src_column_start);
|
|
2028
2050
|
|
|
2029
|
-
//
|
|
2030
|
-
|
|
2031
|
-
|
|
2032
|
-
|
|
2033
|
-
|
|
2034
|
-
|
|
2051
|
+
// Gather 16 strided source rows into a contiguous aligned tile
|
|
2052
|
+
nk_dots_u8_a16x64_sapphireamx_t source_tile;
|
|
2053
|
+
if (columns_to_pack == tmm_cols) {
|
|
2054
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
2055
|
+
nk_u8_t const *source_row = (nk_u8_t const *)((char const *)b +
|
|
2056
|
+
(src_row_start + row_idx) * b_stride_in_bytes) +
|
|
2057
|
+
src_column_start;
|
|
2058
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], _mm512_loadu_si512(source_row));
|
|
2059
|
+
}
|
|
2060
|
+
}
|
|
2061
|
+
else {
|
|
2062
|
+
__mmask64 depth_mask = (__mmask64)((columns_to_pack < 64) ? ((1ULL << columns_to_pack) - 1) : ~0ULL);
|
|
2063
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
2064
|
+
nk_u8_t const *source_row = (nk_u8_t const *)((char const *)b +
|
|
2065
|
+
(src_row_start + row_idx) * b_stride_in_bytes) +
|
|
2066
|
+
src_column_start;
|
|
2067
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], _mm512_maskz_loadu_epi8(depth_mask, source_row));
|
|
2035
2068
|
}
|
|
2036
2069
|
}
|
|
2070
|
+
|
|
2071
|
+
// Transpose into aligned local, then copy to (potentially unaligned) packed buffer
|
|
2072
|
+
nk_dots_u8_b64x16_sapphireamx_t transposed_tile;
|
|
2073
|
+
nk_dots_pack_u8_transposed_sapphireamx_(&source_tile, &transposed_tile);
|
|
2074
|
+
for (nk_size_t i = 0; i < tile_elements; i += 64)
|
|
2075
|
+
_mm512_storeu_si512(tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
|
|
2037
2076
|
}
|
|
2038
2077
|
}
|
|
2039
2078
|
|
|
@@ -2042,7 +2081,7 @@ NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
|
|
|
2042
2081
|
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
2043
2082
|
for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
|
|
2044
2083
|
column_edge_ptr[row_idx * depth + column_idx] =
|
|
2045
|
-
b[(remainder_start_row + row_idx) *
|
|
2084
|
+
b[(remainder_start_row + row_idx) * b_stride_in_bytes + column_idx];
|
|
2046
2085
|
}
|
|
2047
2086
|
}
|
|
2048
2087
|
}
|
|
@@ -2052,7 +2091,8 @@ NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
|
|
|
2052
2091
|
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_u8_t) : 0);
|
|
2053
2092
|
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
2054
2093
|
nk_u32_t *norms = (nk_u32_t *)((char *)b_packed + norms_offset);
|
|
2055
|
-
for (nk_size_t col = 0; col < column_count; col++)
|
|
2094
|
+
for (nk_size_t col = 0; col < column_count; col++)
|
|
2095
|
+
norms[col] = nk_dots_reduce_sumsq_u8_(b + col * b_stride_in_bytes, depth);
|
|
2056
2096
|
}
|
|
2057
2097
|
|
|
2058
2098
|
NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
@@ -2085,7 +2125,7 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2085
2125
|
if (depth_tiles_count == 0) return;
|
|
2086
2126
|
|
|
2087
2127
|
// Tile buffers for A (only used for edge tiles)
|
|
2088
|
-
nk_dots_u8_a16x64_sapphireamx_t
|
|
2128
|
+
nk_dots_u8_a16x64_sapphireamx_t a_tile_top, a_tile_bottom;
|
|
2089
2129
|
nk_dots_u8_state2x2_sapphireamx_t c_accum_buffer;
|
|
2090
2130
|
|
|
2091
2131
|
// Precompute: number of full depth-tiles
|
|
@@ -2116,8 +2156,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2116
2156
|
|
|
2117
2157
|
// Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
|
|
2118
2158
|
if (is_full_row_block && full_depth_tiles_count > 0) {
|
|
2119
|
-
nk_u8_t const *
|
|
2120
|
-
nk_u8_t const *
|
|
2159
|
+
nk_u8_t const *a_top_base = a + row_block_start * a_stride_bytes;
|
|
2160
|
+
nk_u8_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_bytes;
|
|
2121
2161
|
|
|
2122
2162
|
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
|
|
2123
2163
|
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
@@ -2125,8 +2165,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2125
2165
|
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
2126
2166
|
|
|
2127
2167
|
// Prologue: load first depth tile into TMM0-3
|
|
2128
|
-
_tile_loadd(0,
|
|
2129
|
-
_tile_loadd(1,
|
|
2168
|
+
_tile_loadd(0, a_top_base, a_stride_bytes);
|
|
2169
|
+
_tile_loadd(1, a_bottom_base, a_stride_bytes);
|
|
2130
2170
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
2131
2171
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
2132
2172
|
|
|
@@ -2139,8 +2179,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2139
2179
|
_tile_dpbuud(6, 1, 2);
|
|
2140
2180
|
_tile_dpbuud(7, 1, 3);
|
|
2141
2181
|
|
|
2142
|
-
_tile_loadd(0,
|
|
2143
|
-
_tile_loadd(1,
|
|
2182
|
+
_tile_loadd(0, a_top_base + next_depth_offset, a_stride_bytes);
|
|
2183
|
+
_tile_loadd(1, a_bottom_base + next_depth_offset, a_stride_bytes);
|
|
2144
2184
|
b_tile_left = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
2145
2185
|
(b_column_left_base + depth_tile_idx + 1) *
|
|
2146
2186
|
tile_size);
|
|
@@ -2161,9 +2201,9 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2161
2201
|
if (depth_remainder > 0) {
|
|
2162
2202
|
nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
|
|
2163
2203
|
|
|
2164
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2204
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a_top_base + depth_offset, a_stride_bytes, 16,
|
|
2165
2205
|
depth_remainder);
|
|
2166
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2206
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base + depth_offset, a_stride_bytes, 16,
|
|
2167
2207
|
depth_remainder);
|
|
2168
2208
|
|
|
2169
2209
|
b_tile_left = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
|
|
@@ -2173,8 +2213,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2173
2213
|
full_depth_tiles_count) *
|
|
2174
2214
|
tile_size);
|
|
2175
2215
|
|
|
2176
|
-
_tile_loadd(0,
|
|
2177
|
-
_tile_loadd(1,
|
|
2216
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2217
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2178
2218
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
2179
2219
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
2180
2220
|
|
|
@@ -2186,19 +2226,19 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2186
2226
|
}
|
|
2187
2227
|
// Full row-block but only partial depth tile (depth < tile_depth)
|
|
2188
2228
|
else if (is_full_row_block) {
|
|
2189
|
-
nk_u8_t const *
|
|
2190
|
-
nk_u8_t const *
|
|
2229
|
+
nk_u8_t const *a_top_base = a + row_block_start * a_stride_bytes;
|
|
2230
|
+
nk_u8_t const *a_bottom_base = a + (row_block_start + 16) * a_stride_bytes;
|
|
2191
2231
|
|
|
2192
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2193
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2232
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a_top_base, a_stride_bytes, 16, depth_remainder);
|
|
2233
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom, a_bottom_base, a_stride_bytes, 16, depth_remainder);
|
|
2194
2234
|
|
|
2195
2235
|
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
|
|
2196
2236
|
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
2197
2237
|
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_right =
|
|
2198
2238
|
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
2199
2239
|
|
|
2200
|
-
_tile_loadd(0,
|
|
2201
|
-
_tile_loadd(1,
|
|
2240
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2241
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2202
2242
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
2203
2243
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
2204
2244
|
|
|
@@ -2209,20 +2249,20 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2209
2249
|
}
|
|
2210
2250
|
// Slow path: edge row-block → always use buffered load
|
|
2211
2251
|
else {
|
|
2212
|
-
nk_size_t const
|
|
2213
|
-
nk_size_t const
|
|
2252
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2253
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2214
2254
|
|
|
2215
2255
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2216
2256
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2217
2257
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
|
|
2218
2258
|
: depth_remainder;
|
|
2219
2259
|
|
|
2220
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2221
|
-
a_stride_bytes,
|
|
2222
|
-
if (
|
|
2223
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2260
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2261
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
2262
|
+
if (rows_in_low_tile > 0) {
|
|
2263
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom,
|
|
2224
2264
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2225
|
-
a_stride_bytes,
|
|
2265
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
2226
2266
|
}
|
|
2227
2267
|
|
|
2228
2268
|
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
|
|
@@ -2232,8 +2272,8 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2232
2272
|
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
2233
2273
|
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
2234
2274
|
|
|
2235
|
-
_tile_loadd(0,
|
|
2236
|
-
_tile_loadd(1,
|
|
2275
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2276
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2237
2277
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
2238
2278
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
2239
2279
|
|
|
@@ -2268,10 +2308,10 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2268
2308
|
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
2269
2309
|
nk_size_t const col_start = column_tile_idx * 16;
|
|
2270
2310
|
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
2271
|
-
nk_size_t const
|
|
2272
|
-
nk_size_t const
|
|
2311
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2312
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2273
2313
|
|
|
2274
|
-
nk_dots_u8_state_sapphireamx_t
|
|
2314
|
+
nk_dots_u8_state_sapphireamx_t c_high_state, c_low_state;
|
|
2275
2315
|
|
|
2276
2316
|
_tile_zero(4);
|
|
2277
2317
|
_tile_zero(6);
|
|
@@ -2280,44 +2320,43 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2280
2320
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2281
2321
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2282
2322
|
|
|
2283
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2284
|
-
a_stride_bytes,
|
|
2285
|
-
if (
|
|
2286
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2323
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2324
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
2325
|
+
if (rows_in_low_tile > 0) {
|
|
2326
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom,
|
|
2287
2327
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2288
|
-
a_stride_bytes,
|
|
2328
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
2289
2329
|
}
|
|
2290
2330
|
|
|
2291
2331
|
nk_dots_u8_b64x16_sapphireamx_t const *b_tile =
|
|
2292
2332
|
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
2293
2333
|
(b_column_base + depth_tile_idx) * tile_size);
|
|
2294
2334
|
|
|
2295
|
-
_tile_loadd(0,
|
|
2296
|
-
_tile_loadd(1,
|
|
2335
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2336
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2297
2337
|
_tile_loadd(2, b_tile->data, 64);
|
|
2298
2338
|
|
|
2299
2339
|
_tile_dpbuud(4, 0, 2);
|
|
2300
2340
|
_tile_dpbuud(6, 1, 2);
|
|
2301
2341
|
}
|
|
2302
2342
|
|
|
2303
|
-
_tile_stored(4,
|
|
2304
|
-
_tile_stored(6,
|
|
2343
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
2344
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
2305
2345
|
|
|
2306
|
-
nk_dots_u8_store_sapphireamx_(&
|
|
2307
|
-
c_stride_elements,
|
|
2308
|
-
if (
|
|
2309
|
-
nk_dots_u8_store_sapphireamx_(&
|
|
2310
|
-
|
|
2311
|
-
c_stride_elements, rows_in_lower_tile, 16);
|
|
2346
|
+
nk_dots_u8_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
|
|
2347
|
+
c_stride_elements, rows_in_high_tile, 16);
|
|
2348
|
+
if (rows_in_low_tile > 0) {
|
|
2349
|
+
nk_dots_u8_store_sapphireamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
2350
|
+
c_stride_elements, rows_in_low_tile, 16);
|
|
2312
2351
|
}
|
|
2313
2352
|
}
|
|
2314
2353
|
|
|
2315
2354
|
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
2316
2355
|
if (column_remainder_count > 0) {
|
|
2317
|
-
nk_size_t const
|
|
2318
|
-
nk_size_t const
|
|
2356
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2357
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2319
2358
|
|
|
2320
|
-
nk_dots_u8_state_sapphireamx_t
|
|
2359
|
+
nk_dots_u8_state_sapphireamx_t c_high_state, c_low_state;
|
|
2321
2360
|
nk_dots_u8_a16x64_sapphireamx_t b_as_a;
|
|
2322
2361
|
nk_dots_u8_b64x16_sapphireamx_t b_tile;
|
|
2323
2362
|
|
|
@@ -2328,35 +2367,34 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2328
2367
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2329
2368
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2330
2369
|
|
|
2331
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2332
|
-
a_stride_bytes,
|
|
2333
|
-
if (
|
|
2334
|
-
nk_dots_u8_load_a_sapphireamx_(&
|
|
2370
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2371
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
2372
|
+
if (rows_in_low_tile > 0) {
|
|
2373
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_bottom,
|
|
2335
2374
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2336
|
-
a_stride_bytes,
|
|
2375
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
2337
2376
|
}
|
|
2338
2377
|
|
|
2339
2378
|
nk_dots_u8_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
2340
2379
|
valid_depth);
|
|
2341
2380
|
nk_dots_pack_u8_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
2342
2381
|
|
|
2343
|
-
_tile_loadd(0,
|
|
2344
|
-
_tile_loadd(1,
|
|
2382
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2383
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2345
2384
|
_tile_loadd(2, b_tile.data, 64);
|
|
2346
2385
|
|
|
2347
2386
|
_tile_dpbuud(4, 0, 2);
|
|
2348
2387
|
_tile_dpbuud(6, 1, 2);
|
|
2349
2388
|
}
|
|
2350
2389
|
|
|
2351
|
-
_tile_stored(4,
|
|
2352
|
-
_tile_stored(6,
|
|
2390
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
2391
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
2353
2392
|
|
|
2354
|
-
nk_dots_u8_store_sapphireamx_(&
|
|
2355
|
-
c_stride_elements,
|
|
2356
|
-
if (
|
|
2357
|
-
nk_dots_u8_store_sapphireamx_(&
|
|
2358
|
-
|
|
2359
|
-
c_stride_elements, rows_in_lower_tile, column_remainder_count);
|
|
2393
|
+
nk_dots_u8_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
|
|
2394
|
+
c_stride_elements, rows_in_high_tile, column_remainder_count);
|
|
2395
|
+
if (rows_in_low_tile > 0) {
|
|
2396
|
+
nk_dots_u8_store_sapphireamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
2397
|
+
c_stride_elements, rows_in_low_tile, column_remainder_count);
|
|
2360
2398
|
}
|
|
2361
2399
|
}
|
|
2362
2400
|
}
|
|
@@ -2364,17 +2402,17 @@ NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
|
2364
2402
|
_tile_release();
|
|
2365
2403
|
}
|
|
2366
2404
|
|
|
2367
|
-
NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx(
|
|
2368
|
-
nk_u8_t const *vectors, nk_size_t
|
|
2369
|
-
nk_size_t
|
|
2405
|
+
NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
|
|
2406
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
2407
|
+
nk_size_t stride_in_bytes, nk_u32_t *result, nk_size_t result_stride_in_bytes, //
|
|
2370
2408
|
nk_size_t row_start, nk_size_t row_count) {
|
|
2371
2409
|
|
|
2372
|
-
nk_size_t const result_stride_elements =
|
|
2410
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_u32_t);
|
|
2373
2411
|
|
|
2374
2412
|
// Handle row slicing: compute rows [row_start, row_end)
|
|
2375
2413
|
nk_size_t const row_end = (row_count == 0)
|
|
2376
|
-
?
|
|
2377
|
-
: (row_start + row_count <
|
|
2414
|
+
? vectors_count
|
|
2415
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
2378
2416
|
|
|
2379
2417
|
// Round depth up to multiple of 192 (3 tiles × 64 elements)
|
|
2380
2418
|
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
|
|
@@ -2390,8 +2428,8 @@ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
|
|
|
2390
2428
|
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
2391
2429
|
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
2392
2430
|
|
|
2393
|
-
for (nk_size_t col_tile = 0; col_tile <
|
|
2394
|
-
nk_size_t const valid_cols = (col_tile + 16 <=
|
|
2431
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
2432
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
2395
2433
|
|
|
2396
2434
|
nk_dots_u8_init_sapphireamx_(&state);
|
|
2397
2435
|
|
|
@@ -2404,19 +2442,19 @@ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
|
|
|
2404
2442
|
? 64
|
|
2405
2443
|
: (depth > depth_start ? depth - depth_start : 0);
|
|
2406
2444
|
|
|
2407
|
-
nk_dots_u8_load_a_sapphireamx_(
|
|
2408
|
-
&a_tiles[tile_idx],
|
|
2409
|
-
vectors + row_tile *
|
|
2410
|
-
|
|
2445
|
+
nk_dots_u8_load_a_sapphireamx_( //
|
|
2446
|
+
&a_tiles[tile_idx], //
|
|
2447
|
+
vectors + row_tile * stride_in_bytes + depth_start, //
|
|
2448
|
+
stride_in_bytes, valid_rows, valid_depth);
|
|
2411
2449
|
|
|
2412
2450
|
if (row_tile == col_tile) {
|
|
2413
2451
|
nk_dots_pack_u8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
2414
2452
|
}
|
|
2415
2453
|
else {
|
|
2416
|
-
nk_dots_u8_load_a_sapphireamx_(
|
|
2417
|
-
&b_src_tiles[tile_idx],
|
|
2418
|
-
vectors + col_tile *
|
|
2419
|
-
|
|
2454
|
+
nk_dots_u8_load_a_sapphireamx_( //
|
|
2455
|
+
&b_src_tiles[tile_idx], //
|
|
2456
|
+
vectors + col_tile * stride_in_bytes + depth_start, //
|
|
2457
|
+
stride_in_bytes, valid_cols, valid_depth);
|
|
2420
2458
|
nk_dots_pack_u8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
2421
2459
|
}
|
|
2422
2460
|
}
|
|
@@ -2432,9 +2470,9 @@ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
|
|
|
2432
2470
|
}
|
|
2433
2471
|
}
|
|
2434
2472
|
|
|
2435
|
-
#pragma endregion
|
|
2473
|
+
#pragma endregion Unsigned Integers
|
|
2436
2474
|
|
|
2437
|
-
#pragma region
|
|
2475
|
+
#pragma region E4M3 Floats
|
|
2438
2476
|
|
|
2439
2477
|
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sapphireamx(nk_size_t column_count, nk_size_t depth) {
|
|
2440
2478
|
// FP8 uses BF16 tile layout after conversion (same element count: 32 per row)
|
|
@@ -2443,7 +2481,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sapphireamx(nk_size_t column_count,
|
|
|
2443
2481
|
|
|
2444
2482
|
NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
|
|
2445
2483
|
nk_e4m3_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
2446
|
-
nk_size_t
|
|
2484
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
2447
2485
|
|
|
2448
2486
|
nk_size_t const tmm_rows = 16;
|
|
2449
2487
|
nk_size_t const tmm_cols = 32; // Same depth granularity as BF16
|
|
@@ -2467,8 +2505,7 @@ NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
|
|
|
2467
2505
|
nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
|
|
2468
2506
|
nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
|
|
2469
2507
|
|
|
2470
|
-
|
|
2471
|
-
|
|
2508
|
+
// Pack tiles using vectorized convert + SIMD transpose
|
|
2472
2509
|
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
2473
2510
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2474
2511
|
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
@@ -2479,21 +2516,19 @@ NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
|
|
|
2479
2516
|
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
2480
2517
|
: (depth - src_column_start);
|
|
2481
2518
|
|
|
2482
|
-
// Convert E4M3
|
|
2519
|
+
// Convert E4M3 → BF16 and gather into aligned source tile
|
|
2520
|
+
__mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
|
|
2521
|
+
nk_dots_bf16_a16x32_sapphireamx_t source_tile;
|
|
2483
2522
|
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
2484
|
-
|
|
2485
|
-
|
|
2486
|
-
|
|
2487
|
-
__m256i e4m3_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
|
|
2488
|
-
__m512i bf16_row = nk_e4m3x32_to_bf16x32_icelake_(e4m3_row);
|
|
2489
|
-
// Store with pair-interleaving
|
|
2490
|
-
nk_bf16_t bf16_buf[32];
|
|
2491
|
-
_mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
|
|
2492
|
-
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
2493
|
-
nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
|
|
2494
|
-
tile_output[dst_idx] = bf16_buf[column_idx];
|
|
2495
|
-
}
|
|
2523
|
+
__m256i e4m3_row_u8x32 = _mm256_maskz_loadu_epi8(
|
|
2524
|
+
column_mask, b + (src_row_start + row_idx) * b_stride_in_bytes + src_column_start);
|
|
2525
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], nk_e4m3x32_to_bf16x32_icelake_(e4m3_row_u8x32));
|
|
2496
2526
|
}
|
|
2527
|
+
|
|
2528
|
+
nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
|
|
2529
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
|
|
2530
|
+
for (nk_size_t i = 0; i < tile_bytes; i += 64)
|
|
2531
|
+
_mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
|
|
2497
2532
|
}
|
|
2498
2533
|
}
|
|
2499
2534
|
|
|
@@ -2504,10 +2539,11 @@ NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
|
|
|
2504
2539
|
for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
|
|
2505
2540
|
nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
|
|
2506
2541
|
__mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
|
|
2507
|
-
__m256i
|
|
2508
|
-
column_mask, b + (remainder_start_row + row_idx) *
|
|
2509
|
-
__m512i
|
|
2510
|
-
_mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask,
|
|
2542
|
+
__m256i e4m3_chunk_u8x32 = _mm256_maskz_loadu_epi8(
|
|
2543
|
+
column_mask, b + (remainder_start_row + row_idx) * b_stride_in_bytes + column_idx);
|
|
2544
|
+
__m512i bf16_chunk_i16x32 = nk_e4m3x32_to_bf16x32_icelake_(e4m3_chunk_u8x32);
|
|
2545
|
+
_mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask,
|
|
2546
|
+
bf16_chunk_i16x32);
|
|
2511
2547
|
}
|
|
2512
2548
|
}
|
|
2513
2549
|
}
|
|
@@ -2518,7 +2554,7 @@ NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
|
|
|
2518
2554
|
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
2519
2555
|
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
2520
2556
|
for (nk_size_t col = 0; col < column_count; col++)
|
|
2521
|
-
norms[col] = nk_dots_reduce_sumsq_e4m3_(b + col *
|
|
2557
|
+
norms[col] = nk_dots_reduce_sumsq_e4m3_(b + col * b_stride_in_bytes, depth);
|
|
2522
2558
|
}
|
|
2523
2559
|
|
|
2524
2560
|
NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
@@ -2545,7 +2581,7 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
|
2545
2581
|
|
|
2546
2582
|
if (depth_tiles_count == 0) return;
|
|
2547
2583
|
|
|
2548
|
-
nk_dots_bf16_a16x32_sapphireamx_t
|
|
2584
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tile_top, a_tile_bottom;
|
|
2549
2585
|
nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
|
|
2550
2586
|
|
|
2551
2587
|
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
@@ -2558,8 +2594,8 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
|
2558
2594
|
nk_size_t const row_block_start = row_block_idx * 32;
|
|
2559
2595
|
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
2560
2596
|
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
2561
|
-
nk_size_t const
|
|
2562
|
-
nk_size_t const
|
|
2597
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2598
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2563
2599
|
|
|
2564
2600
|
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
2565
2601
|
nk_size_t const col_block_start = column_block_idx * 32;
|
|
@@ -2578,12 +2614,12 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
|
2578
2614
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2579
2615
|
|
|
2580
2616
|
// Load A with FP8 → BF16 conversion
|
|
2581
|
-
nk_dots_e4m3_load_a_sapphireamx_(&
|
|
2582
|
-
a_stride_bytes,
|
|
2583
|
-
if (
|
|
2584
|
-
nk_dots_e4m3_load_a_sapphireamx_(&
|
|
2617
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2618
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
2619
|
+
if (rows_in_low_tile > 0) {
|
|
2620
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_bottom,
|
|
2585
2621
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2586
|
-
a_stride_bytes,
|
|
2622
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
2587
2623
|
}
|
|
2588
2624
|
|
|
2589
2625
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
@@ -2593,8 +2629,8 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
|
2593
2629
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2594
2630
|
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
2595
2631
|
|
|
2596
|
-
_tile_loadd(0,
|
|
2597
|
-
_tile_loadd(1,
|
|
2632
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2633
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2598
2634
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
2599
2635
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
2600
2636
|
|
|
@@ -2629,7 +2665,7 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
|
2629
2665
|
nk_size_t const col_start = column_tile_idx * 16;
|
|
2630
2666
|
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
2631
2667
|
|
|
2632
|
-
nk_dots_bf16_state_sapphireamx_t
|
|
2668
|
+
nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
|
|
2633
2669
|
_tile_zero(4);
|
|
2634
2670
|
_tile_zero(6);
|
|
2635
2671
|
|
|
@@ -2637,41 +2673,41 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
|
2637
2673
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2638
2674
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2639
2675
|
|
|
2640
|
-
nk_dots_e4m3_load_a_sapphireamx_(&
|
|
2641
|
-
a_stride_bytes,
|
|
2642
|
-
if (
|
|
2643
|
-
nk_dots_e4m3_load_a_sapphireamx_(&
|
|
2676
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2677
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
2678
|
+
if (rows_in_low_tile > 0) {
|
|
2679
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_bottom,
|
|
2644
2680
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2645
|
-
a_stride_bytes,
|
|
2681
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
2646
2682
|
}
|
|
2647
2683
|
|
|
2648
2684
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
|
|
2649
2685
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2650
2686
|
(b_column_base + depth_tile_idx) * tile_size);
|
|
2651
2687
|
|
|
2652
|
-
_tile_loadd(0,
|
|
2653
|
-
_tile_loadd(1,
|
|
2688
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2689
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2654
2690
|
_tile_loadd(2, b_tile->data, 64);
|
|
2655
2691
|
|
|
2656
2692
|
_tile_dpbf16ps(4, 0, 2);
|
|
2657
2693
|
_tile_dpbf16ps(6, 1, 2);
|
|
2658
2694
|
}
|
|
2659
2695
|
|
|
2660
|
-
_tile_stored(4,
|
|
2661
|
-
_tile_stored(6,
|
|
2696
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
2697
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
2662
2698
|
|
|
2663
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
2664
|
-
c_stride_elements,
|
|
2665
|
-
if (
|
|
2666
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
2699
|
+
nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
|
|
2700
|
+
c_stride_elements, rows_in_high_tile, 16);
|
|
2701
|
+
if (rows_in_low_tile > 0) {
|
|
2702
|
+
nk_dots_bf16_store_sapphireamx_(&c_low_state,
|
|
2667
2703
|
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
2668
|
-
c_stride_elements,
|
|
2704
|
+
c_stride_elements, rows_in_low_tile, 16);
|
|
2669
2705
|
}
|
|
2670
2706
|
}
|
|
2671
2707
|
|
|
2672
2708
|
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
2673
2709
|
if (column_remainder_count > 0) {
|
|
2674
|
-
nk_dots_bf16_state_sapphireamx_t
|
|
2710
|
+
nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
|
|
2675
2711
|
nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
|
|
2676
2712
|
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
2677
2713
|
|
|
@@ -2682,12 +2718,12 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
|
2682
2718
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2683
2719
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2684
2720
|
|
|
2685
|
-
nk_dots_e4m3_load_a_sapphireamx_(&
|
|
2686
|
-
a_stride_bytes,
|
|
2687
|
-
if (
|
|
2688
|
-
nk_dots_e4m3_load_a_sapphireamx_(&
|
|
2721
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2722
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
2723
|
+
if (rows_in_low_tile > 0) {
|
|
2724
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_bottom,
|
|
2689
2725
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2690
|
-
a_stride_bytes,
|
|
2726
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
2691
2727
|
}
|
|
2692
2728
|
|
|
2693
2729
|
// B edge data is already in BF16 format
|
|
@@ -2695,23 +2731,23 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
|
2695
2731
|
valid_depth);
|
|
2696
2732
|
nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
2697
2733
|
|
|
2698
|
-
_tile_loadd(0,
|
|
2699
|
-
_tile_loadd(1,
|
|
2734
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2735
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2700
2736
|
_tile_loadd(2, b_tile.data, 64);
|
|
2701
2737
|
|
|
2702
2738
|
_tile_dpbf16ps(4, 0, 2);
|
|
2703
2739
|
_tile_dpbf16ps(6, 1, 2);
|
|
2704
2740
|
}
|
|
2705
2741
|
|
|
2706
|
-
_tile_stored(4,
|
|
2707
|
-
_tile_stored(6,
|
|
2742
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
2743
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
2708
2744
|
|
|
2709
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
2710
|
-
c_stride_elements,
|
|
2711
|
-
if (
|
|
2712
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
2745
|
+
nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
|
|
2746
|
+
c_stride_elements, rows_in_high_tile, column_remainder_count);
|
|
2747
|
+
if (rows_in_low_tile > 0) {
|
|
2748
|
+
nk_dots_bf16_store_sapphireamx_(&c_low_state,
|
|
2713
2749
|
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
2714
|
-
c_stride_elements,
|
|
2750
|
+
c_stride_elements, rows_in_low_tile, column_remainder_count);
|
|
2715
2751
|
}
|
|
2716
2752
|
}
|
|
2717
2753
|
}
|
|
@@ -2719,9 +2755,9 @@ NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
|
2719
2755
|
_tile_release();
|
|
2720
2756
|
}
|
|
2721
2757
|
|
|
2722
|
-
#pragma endregion
|
|
2758
|
+
#pragma endregion E4M3 Floats
|
|
2723
2759
|
|
|
2724
|
-
#pragma region
|
|
2760
|
+
#pragma region E5M2 Floats
|
|
2725
2761
|
|
|
2726
2762
|
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sapphireamx(nk_size_t column_count, nk_size_t depth) {
|
|
2727
2763
|
return nk_dots_packed_size_bf16_sapphireamx(column_count, depth);
|
|
@@ -2729,7 +2765,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sapphireamx(nk_size_t column_count,
|
|
|
2729
2765
|
|
|
2730
2766
|
NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
|
|
2731
2767
|
nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
2732
|
-
nk_size_t
|
|
2768
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
2733
2769
|
|
|
2734
2770
|
nk_size_t const tmm_rows = 16;
|
|
2735
2771
|
nk_size_t const tmm_cols = 32;
|
|
@@ -2753,8 +2789,7 @@ NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
|
|
|
2753
2789
|
nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
|
|
2754
2790
|
nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
|
|
2755
2791
|
|
|
2756
|
-
|
|
2757
|
-
|
|
2792
|
+
// Pack tiles using vectorized convert + SIMD transpose
|
|
2758
2793
|
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
2759
2794
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2760
2795
|
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
@@ -2765,18 +2800,18 @@ NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
|
|
|
2765
2800
|
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
2766
2801
|
: (depth - src_column_start);
|
|
2767
2802
|
|
|
2803
|
+
__mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
|
|
2804
|
+
nk_dots_bf16_a16x32_sapphireamx_t source_tile;
|
|
2768
2805
|
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
2769
|
-
|
|
2770
|
-
|
|
2771
|
-
|
|
2772
|
-
__m512i bf16_row = nk_e5m2x32_to_bf16x32_icelake_(e5m2_row);
|
|
2773
|
-
nk_bf16_t bf16_buf[32];
|
|
2774
|
-
_mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
|
|
2775
|
-
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
2776
|
-
nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
|
|
2777
|
-
tile_output[dst_idx] = bf16_buf[column_idx];
|
|
2778
|
-
}
|
|
2806
|
+
__m256i e5m2_row_u8x32 = _mm256_maskz_loadu_epi8(
|
|
2807
|
+
column_mask, b + (src_row_start + row_idx) * b_stride_in_bytes + src_column_start);
|
|
2808
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], nk_e5m2x32_to_bf16x32_icelake_(e5m2_row_u8x32));
|
|
2779
2809
|
}
|
|
2810
|
+
|
|
2811
|
+
nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
|
|
2812
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
|
|
2813
|
+
for (nk_size_t i = 0; i < tile_bytes; i += 64)
|
|
2814
|
+
_mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
|
|
2780
2815
|
}
|
|
2781
2816
|
}
|
|
2782
2817
|
|
|
@@ -2786,10 +2821,11 @@ NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
|
|
|
2786
2821
|
for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
|
|
2787
2822
|
nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
|
|
2788
2823
|
__mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
|
|
2789
|
-
__m256i
|
|
2790
|
-
column_mask, b + (remainder_start_row + row_idx) *
|
|
2791
|
-
__m512i
|
|
2792
|
-
_mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask,
|
|
2824
|
+
__m256i e5m2_chunk_u8x32 = _mm256_maskz_loadu_epi8(
|
|
2825
|
+
column_mask, b + (remainder_start_row + row_idx) * b_stride_in_bytes + column_idx);
|
|
2826
|
+
__m512i bf16_chunk_i16x32 = nk_e5m2x32_to_bf16x32_icelake_(e5m2_chunk_u8x32);
|
|
2827
|
+
_mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask,
|
|
2828
|
+
bf16_chunk_i16x32);
|
|
2793
2829
|
}
|
|
2794
2830
|
}
|
|
2795
2831
|
}
|
|
@@ -2800,7 +2836,7 @@ NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
|
|
|
2800
2836
|
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
2801
2837
|
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
2802
2838
|
for (nk_size_t col = 0; col < column_count; col++)
|
|
2803
|
-
norms[col] = nk_dots_reduce_sumsq_e5m2_(b + col *
|
|
2839
|
+
norms[col] = nk_dots_reduce_sumsq_e5m2_(b + col * b_stride_in_bytes, depth);
|
|
2804
2840
|
}
|
|
2805
2841
|
|
|
2806
2842
|
NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
@@ -2826,7 +2862,7 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
|
2826
2862
|
|
|
2827
2863
|
if (depth_tiles_count == 0) return;
|
|
2828
2864
|
|
|
2829
|
-
nk_dots_bf16_a16x32_sapphireamx_t
|
|
2865
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tile_top, a_tile_bottom;
|
|
2830
2866
|
nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
|
|
2831
2867
|
|
|
2832
2868
|
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
@@ -2839,8 +2875,8 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
|
2839
2875
|
nk_size_t const row_block_start = row_block_idx * 32;
|
|
2840
2876
|
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
2841
2877
|
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
2842
|
-
nk_size_t const
|
|
2843
|
-
nk_size_t const
|
|
2878
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2879
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2844
2880
|
|
|
2845
2881
|
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
2846
2882
|
nk_size_t const col_block_start = column_block_idx * 32;
|
|
@@ -2859,12 +2895,12 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
|
2859
2895
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2860
2896
|
|
|
2861
2897
|
// Load A with FP8 → BF16 conversion
|
|
2862
|
-
nk_dots_e5m2_load_a_sapphireamx_(&
|
|
2863
|
-
a_stride_bytes,
|
|
2864
|
-
if (
|
|
2865
|
-
nk_dots_e5m2_load_a_sapphireamx_(&
|
|
2898
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2899
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
2900
|
+
if (rows_in_low_tile > 0) {
|
|
2901
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_bottom,
|
|
2866
2902
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2867
|
-
a_stride_bytes,
|
|
2903
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
2868
2904
|
}
|
|
2869
2905
|
|
|
2870
2906
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
@@ -2874,8 +2910,8 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
|
2874
2910
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2875
2911
|
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
2876
2912
|
|
|
2877
|
-
_tile_loadd(0,
|
|
2878
|
-
_tile_loadd(1,
|
|
2913
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2914
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2879
2915
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
2880
2916
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
2881
2917
|
|
|
@@ -2910,7 +2946,7 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
|
2910
2946
|
nk_size_t const col_start = column_tile_idx * 16;
|
|
2911
2947
|
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
2912
2948
|
|
|
2913
|
-
nk_dots_bf16_state_sapphireamx_t
|
|
2949
|
+
nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
|
|
2914
2950
|
_tile_zero(4);
|
|
2915
2951
|
_tile_zero(6);
|
|
2916
2952
|
|
|
@@ -2918,41 +2954,41 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
|
2918
2954
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2919
2955
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2920
2956
|
|
|
2921
|
-
nk_dots_e5m2_load_a_sapphireamx_(&
|
|
2922
|
-
a_stride_bytes,
|
|
2923
|
-
if (
|
|
2924
|
-
nk_dots_e5m2_load_a_sapphireamx_(&
|
|
2957
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2958
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
2959
|
+
if (rows_in_low_tile > 0) {
|
|
2960
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_bottom,
|
|
2925
2961
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2926
|
-
a_stride_bytes,
|
|
2962
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
2927
2963
|
}
|
|
2928
2964
|
|
|
2929
2965
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
|
|
2930
2966
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2931
2967
|
(b_column_base + depth_tile_idx) * tile_size);
|
|
2932
2968
|
|
|
2933
|
-
_tile_loadd(0,
|
|
2934
|
-
_tile_loadd(1,
|
|
2969
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
2970
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2935
2971
|
_tile_loadd(2, b_tile->data, 64);
|
|
2936
2972
|
|
|
2937
2973
|
_tile_dpbf16ps(4, 0, 2);
|
|
2938
2974
|
_tile_dpbf16ps(6, 1, 2);
|
|
2939
2975
|
}
|
|
2940
2976
|
|
|
2941
|
-
_tile_stored(4,
|
|
2942
|
-
_tile_stored(6,
|
|
2977
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
2978
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
2943
2979
|
|
|
2944
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
2945
|
-
c_stride_elements,
|
|
2946
|
-
if (
|
|
2947
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
2980
|
+
nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
|
|
2981
|
+
c_stride_elements, rows_in_high_tile, 16);
|
|
2982
|
+
if (rows_in_low_tile > 0) {
|
|
2983
|
+
nk_dots_bf16_store_sapphireamx_(&c_low_state,
|
|
2948
2984
|
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
2949
|
-
c_stride_elements,
|
|
2985
|
+
c_stride_elements, rows_in_low_tile, 16);
|
|
2950
2986
|
}
|
|
2951
2987
|
}
|
|
2952
2988
|
|
|
2953
2989
|
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
2954
2990
|
if (column_remainder_count > 0) {
|
|
2955
|
-
nk_dots_bf16_state_sapphireamx_t
|
|
2991
|
+
nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
|
|
2956
2992
|
nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
|
|
2957
2993
|
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
2958
2994
|
|
|
@@ -2963,35 +2999,35 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
|
2963
2999
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2964
3000
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2965
3001
|
|
|
2966
|
-
nk_dots_e5m2_load_a_sapphireamx_(&
|
|
2967
|
-
a_stride_bytes,
|
|
2968
|
-
if (
|
|
2969
|
-
nk_dots_e5m2_load_a_sapphireamx_(&
|
|
3002
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3003
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
3004
|
+
if (rows_in_low_tile > 0) {
|
|
3005
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_bottom,
|
|
2970
3006
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2971
|
-
a_stride_bytes,
|
|
3007
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
2972
3008
|
}
|
|
2973
3009
|
|
|
2974
3010
|
nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
2975
3011
|
valid_depth);
|
|
2976
3012
|
nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
2977
3013
|
|
|
2978
|
-
_tile_loadd(0,
|
|
2979
|
-
_tile_loadd(1,
|
|
3014
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
3015
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
2980
3016
|
_tile_loadd(2, b_tile.data, 64);
|
|
2981
3017
|
|
|
2982
3018
|
_tile_dpbf16ps(4, 0, 2);
|
|
2983
3019
|
_tile_dpbf16ps(6, 1, 2);
|
|
2984
3020
|
}
|
|
2985
3021
|
|
|
2986
|
-
_tile_stored(4,
|
|
2987
|
-
_tile_stored(6,
|
|
3022
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
3023
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
2988
3024
|
|
|
2989
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
2990
|
-
c_stride_elements,
|
|
2991
|
-
if (
|
|
2992
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
3025
|
+
nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
|
|
3026
|
+
c_stride_elements, rows_in_high_tile, column_remainder_count);
|
|
3027
|
+
if (rows_in_low_tile > 0) {
|
|
3028
|
+
nk_dots_bf16_store_sapphireamx_(&c_low_state,
|
|
2993
3029
|
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
2994
|
-
c_stride_elements,
|
|
3030
|
+
c_stride_elements, rows_in_low_tile, column_remainder_count);
|
|
2995
3031
|
}
|
|
2996
3032
|
}
|
|
2997
3033
|
}
|
|
@@ -2999,17 +3035,17 @@ NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
|
2999
3035
|
_tile_release();
|
|
3000
3036
|
}
|
|
3001
3037
|
|
|
3002
|
-
NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx(
|
|
3003
|
-
nk_e5m2_t const *vectors, nk_size_t
|
|
3004
|
-
nk_size_t
|
|
3038
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
|
|
3039
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
3040
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
|
|
3005
3041
|
nk_size_t row_start, nk_size_t row_count) {
|
|
3006
3042
|
|
|
3007
|
-
nk_size_t const result_stride_elements =
|
|
3043
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
3008
3044
|
|
|
3009
3045
|
// Handle row slicing: compute rows [row_start, row_end)
|
|
3010
3046
|
nk_size_t const row_end = (row_count == 0)
|
|
3011
|
-
?
|
|
3012
|
-
: (row_start + row_count <
|
|
3047
|
+
? vectors_count
|
|
3048
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
3013
3049
|
|
|
3014
3050
|
// Round depth up to multiple of 96 (3 tiles × 32 elements)
|
|
3015
3051
|
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
@@ -3025,8 +3061,8 @@ NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
|
|
|
3025
3061
|
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
3026
3062
|
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
3027
3063
|
|
|
3028
|
-
for (nk_size_t col_tile = 0; col_tile <
|
|
3029
|
-
nk_size_t const valid_cols = (col_tile + 16 <=
|
|
3064
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
3065
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
3030
3066
|
|
|
3031
3067
|
nk_dots_bf16_init_sapphireamx_(&state);
|
|
3032
3068
|
|
|
@@ -3039,19 +3075,19 @@ NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
|
|
|
3039
3075
|
? 32
|
|
3040
3076
|
: (depth > depth_start ? depth - depth_start : 0);
|
|
3041
3077
|
|
|
3042
|
-
nk_dots_e5m2_load_a_sapphireamx_(
|
|
3043
|
-
&a_tiles[tile_idx],
|
|
3044
|
-
vectors + row_tile *
|
|
3045
|
-
|
|
3078
|
+
nk_dots_e5m2_load_a_sapphireamx_( //
|
|
3079
|
+
&a_tiles[tile_idx], //
|
|
3080
|
+
vectors + row_tile * stride_in_bytes + depth_start, //
|
|
3081
|
+
stride_in_bytes, valid_rows, valid_depth);
|
|
3046
3082
|
|
|
3047
3083
|
if (row_tile == col_tile) {
|
|
3048
3084
|
nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3049
3085
|
}
|
|
3050
3086
|
else {
|
|
3051
|
-
nk_dots_e5m2_load_a_sapphireamx_(
|
|
3052
|
-
&b_src_tiles[tile_idx],
|
|
3053
|
-
vectors + col_tile *
|
|
3054
|
-
|
|
3087
|
+
nk_dots_e5m2_load_a_sapphireamx_( //
|
|
3088
|
+
&b_src_tiles[tile_idx], //
|
|
3089
|
+
vectors + col_tile * stride_in_bytes + depth_start, //
|
|
3090
|
+
stride_in_bytes, valid_cols, valid_depth);
|
|
3055
3091
|
nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3056
3092
|
}
|
|
3057
3093
|
}
|
|
@@ -3067,17 +3103,17 @@ NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
|
|
|
3067
3103
|
}
|
|
3068
3104
|
}
|
|
3069
3105
|
|
|
3070
|
-
NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx(
|
|
3071
|
-
nk_e4m3_t const *vectors, nk_size_t
|
|
3072
|
-
nk_size_t
|
|
3106
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
|
|
3107
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
3108
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
|
|
3073
3109
|
nk_size_t row_start, nk_size_t row_count) {
|
|
3074
3110
|
|
|
3075
|
-
nk_size_t const result_stride_elements =
|
|
3111
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
3076
3112
|
|
|
3077
3113
|
// Handle row slicing: compute rows [row_start, row_end)
|
|
3078
3114
|
nk_size_t const row_end = (row_count == 0)
|
|
3079
|
-
?
|
|
3080
|
-
: (row_start + row_count <
|
|
3115
|
+
? vectors_count
|
|
3116
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
3081
3117
|
|
|
3082
3118
|
// Round depth up to multiple of 96 (3 tiles × 32 elements)
|
|
3083
3119
|
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
@@ -3093,8 +3129,8 @@ NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
|
|
|
3093
3129
|
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
3094
3130
|
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
3095
3131
|
|
|
3096
|
-
for (nk_size_t col_tile = 0; col_tile <
|
|
3097
|
-
nk_size_t const valid_cols = (col_tile + 16 <=
|
|
3132
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
3133
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
3098
3134
|
|
|
3099
3135
|
nk_dots_bf16_init_sapphireamx_(&state);
|
|
3100
3136
|
|
|
@@ -3107,19 +3143,19 @@ NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
|
|
|
3107
3143
|
? 32
|
|
3108
3144
|
: (depth > depth_start ? depth - depth_start : 0);
|
|
3109
3145
|
|
|
3110
|
-
nk_dots_e4m3_load_a_sapphireamx_(
|
|
3111
|
-
&a_tiles[tile_idx],
|
|
3112
|
-
vectors + row_tile *
|
|
3113
|
-
|
|
3146
|
+
nk_dots_e4m3_load_a_sapphireamx_( //
|
|
3147
|
+
&a_tiles[tile_idx], //
|
|
3148
|
+
vectors + row_tile * stride_in_bytes + depth_start, //
|
|
3149
|
+
stride_in_bytes, valid_rows, valid_depth);
|
|
3114
3150
|
|
|
3115
3151
|
if (row_tile == col_tile) {
|
|
3116
3152
|
nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3117
3153
|
}
|
|
3118
3154
|
else {
|
|
3119
|
-
nk_dots_e4m3_load_a_sapphireamx_(
|
|
3120
|
-
&b_src_tiles[tile_idx],
|
|
3121
|
-
vectors + col_tile *
|
|
3122
|
-
|
|
3155
|
+
nk_dots_e4m3_load_a_sapphireamx_( //
|
|
3156
|
+
&b_src_tiles[tile_idx], //
|
|
3157
|
+
vectors + col_tile * stride_in_bytes + depth_start, //
|
|
3158
|
+
stride_in_bytes, valid_cols, valid_depth);
|
|
3123
3159
|
nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3124
3160
|
}
|
|
3125
3161
|
}
|
|
@@ -3135,9 +3171,9 @@ NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
|
|
|
3135
3171
|
}
|
|
3136
3172
|
}
|
|
3137
3173
|
|
|
3138
|
-
#pragma endregion
|
|
3174
|
+
#pragma endregion E5M2 Floats
|
|
3139
3175
|
|
|
3140
|
-
#pragma region
|
|
3176
|
+
#pragma region E2M3 Floats
|
|
3141
3177
|
|
|
3142
3178
|
/* Load E2M3 A tile with E2M3 to signed I8 conversion via VPERMB LUT.
|
|
3143
3179
|
* Each E2M3 byte encodes: bit 5 = sign, bits 4:0 = magnitude (5-bit index).
|
|
@@ -3194,12 +3230,12 @@ NK_INTERNAL void nk_dots_e2m3_store_sapphireamx_( //
|
|
|
3194
3230
|
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
3195
3231
|
|
|
3196
3232
|
__mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
|
|
3197
|
-
__m512
|
|
3233
|
+
__m512 scale_f32x16 = _mm512_set1_ps(1.0f / 256.0f);
|
|
3198
3234
|
|
|
3199
3235
|
for (nk_size_t row = 0; row < valid_rows; row++) {
|
|
3200
|
-
__m512i
|
|
3201
|
-
__m512
|
|
3202
|
-
_mm512_mask_storeu_ps(dst + row * dst_stride_elements, column_mask,
|
|
3236
|
+
__m512i i32_row_i32x16 = _mm512_load_si512(state->data[row]);
|
|
3237
|
+
__m512 f32_row_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(i32_row_i32x16), scale_f32x16);
|
|
3238
|
+
_mm512_mask_storeu_ps(dst + row * dst_stride_elements, column_mask, f32_row_f32x16);
|
|
3203
3239
|
}
|
|
3204
3240
|
}
|
|
3205
3241
|
|
|
@@ -3209,23 +3245,22 @@ NK_INTERNAL void nk_dots_e2m3_output2x2_sapphireamx_( //
|
|
|
3209
3245
|
nk_f32_t *dst, nk_size_t dst_stride_elements, //
|
|
3210
3246
|
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
3211
3247
|
|
|
3212
|
-
nk_size_t const
|
|
3248
|
+
nk_size_t const rows_high = (valid_rows > 16) ? 16 : valid_rows;
|
|
3213
3249
|
nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
|
|
3214
3250
|
nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
|
|
3215
3251
|
|
|
3216
|
-
if (
|
|
3217
|
-
nk_dots_e2m3_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements,
|
|
3218
|
-
if (
|
|
3219
|
-
nk_dots_e2m3_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements,
|
|
3252
|
+
if (rows_high > 0 && cols_left > 0)
|
|
3253
|
+
nk_dots_e2m3_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_high, cols_left);
|
|
3254
|
+
if (rows_high > 0 && cols_right > 0)
|
|
3255
|
+
nk_dots_e2m3_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_high, cols_right);
|
|
3220
3256
|
|
|
3221
3257
|
if (valid_rows > 16) {
|
|
3222
|
-
nk_size_t const
|
|
3223
|
-
nk_f32_t *
|
|
3258
|
+
nk_size_t const rows_low = valid_rows - 16;
|
|
3259
|
+
nk_f32_t *dst_low = dst + 16 * dst_stride_elements;
|
|
3224
3260
|
if (cols_left > 0)
|
|
3225
|
-
nk_dots_e2m3_store_sapphireamx_(&state->c[1][0],
|
|
3261
|
+
nk_dots_e2m3_store_sapphireamx_(&state->c[1][0], dst_low, dst_stride_elements, rows_low, cols_left);
|
|
3226
3262
|
if (cols_right > 0)
|
|
3227
|
-
nk_dots_e2m3_store_sapphireamx_(&state->c[1][1],
|
|
3228
|
-
cols_right);
|
|
3263
|
+
nk_dots_e2m3_store_sapphireamx_(&state->c[1][1], dst_low + 16, dst_stride_elements, rows_low, cols_right);
|
|
3229
3264
|
}
|
|
3230
3265
|
}
|
|
3231
3266
|
|
|
@@ -3236,7 +3271,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sapphireamx(nk_size_t column_count,
|
|
|
3236
3271
|
|
|
3237
3272
|
NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx( //
|
|
3238
3273
|
nk_e2m3_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
3239
|
-
nk_size_t
|
|
3274
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
3240
3275
|
|
|
3241
3276
|
// AMX I8 tile dimensions: 16 rows x 64 columns (1024 I8 elements = 1KB)
|
|
3242
3277
|
nk_size_t const tmm_rows = 16;
|
|
@@ -3261,16 +3296,7 @@ NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx( //
|
|
|
3261
3296
|
nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + tiles_offset);
|
|
3262
3297
|
nk_i8_t *column_edge_ptr = (nk_i8_t *)((char *)b_packed + column_edge_offset);
|
|
3263
3298
|
|
|
3264
|
-
//
|
|
3265
|
-
for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
|
|
3266
|
-
|
|
3267
|
-
// E2M3 magnitude-to-value LUT (value * 16)
|
|
3268
|
-
static nk_u8_t const lut_magnitude[32] = {
|
|
3269
|
-
0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, //
|
|
3270
|
-
32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, //
|
|
3271
|
-
};
|
|
3272
|
-
|
|
3273
|
-
// Pack tiles with E2M3 -> I8 conversion and quad-interleaving
|
|
3299
|
+
// Pack tiles using vectorized E2M3 → I8 conversion + SIMD transpose
|
|
3274
3300
|
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
3275
3301
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3276
3302
|
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
@@ -3281,26 +3307,44 @@ NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx( //
|
|
|
3281
3307
|
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
3282
3308
|
: (depth - src_column_start);
|
|
3283
3309
|
|
|
3284
|
-
|
|
3285
|
-
|
|
3286
|
-
|
|
3287
|
-
|
|
3288
|
-
|
|
3289
|
-
|
|
3290
|
-
|
|
3291
|
-
|
|
3292
|
-
|
|
3310
|
+
// Convert E2M3 → I8 and gather into aligned source tile
|
|
3311
|
+
nk_dots_i8_a16x64_sapphireamx_t source_tile;
|
|
3312
|
+
if (columns_to_pack == tmm_cols) {
|
|
3313
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
3314
|
+
__m512i raw_row = _mm512_loadu_si512(
|
|
3315
|
+
(nk_e2m3_t const *)((char const *)b + (src_row_start + row_idx) * b_stride_in_bytes) +
|
|
3316
|
+
src_column_start);
|
|
3317
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], nk_e2m3x64_to_i8x64_skylake_(raw_row));
|
|
3318
|
+
}
|
|
3319
|
+
}
|
|
3320
|
+
else {
|
|
3321
|
+
__mmask64 depth_mask = (__mmask64)((columns_to_pack < 64) ? ((1ULL << columns_to_pack) - 1) : ~0ULL);
|
|
3322
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
3323
|
+
__m512i raw_row = _mm512_maskz_loadu_epi8(
|
|
3324
|
+
depth_mask,
|
|
3325
|
+
(nk_e2m3_t const *)((char const *)b + (src_row_start + row_idx) * b_stride_in_bytes) +
|
|
3326
|
+
src_column_start);
|
|
3327
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], nk_e2m3x64_to_i8x64_skylake_(raw_row));
|
|
3293
3328
|
}
|
|
3294
3329
|
}
|
|
3330
|
+
|
|
3331
|
+
nk_dots_i8_b64x16_sapphireamx_t transposed_tile;
|
|
3332
|
+
nk_dots_pack_i8_transposed_sapphireamx_(&source_tile, &transposed_tile);
|
|
3333
|
+
for (nk_size_t i = 0; i < tile_elements; i += 64)
|
|
3334
|
+
_mm512_storeu_si512(tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
|
|
3295
3335
|
}
|
|
3296
3336
|
}
|
|
3297
3337
|
|
|
3298
|
-
// Pack column-remainder rows (convert E2M3 to I8)
|
|
3338
|
+
// Pack column-remainder rows (convert E2M3 to I8) using scalar LUT
|
|
3339
|
+
static nk_u8_t const lut_magnitude[32] = {
|
|
3340
|
+
0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, //
|
|
3341
|
+
32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, //
|
|
3342
|
+
};
|
|
3299
3343
|
if (column_remainder_count > 0) {
|
|
3300
3344
|
nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
|
|
3301
3345
|
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
3302
3346
|
for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
|
|
3303
|
-
nk_u8_t raw = b[(remainder_start_row + row_idx) *
|
|
3347
|
+
nk_u8_t raw = b[(remainder_start_row + row_idx) * b_stride_in_bytes + column_idx];
|
|
3304
3348
|
nk_u8_t magnitude = raw & 0x1F;
|
|
3305
3349
|
nk_i8_t val = (nk_i8_t)lut_magnitude[magnitude];
|
|
3306
3350
|
if (raw & 0x20) val = -val;
|
|
@@ -3315,7 +3359,7 @@ NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx( //
|
|
|
3315
3359
|
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
3316
3360
|
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
3317
3361
|
for (nk_size_t col = 0; col < column_count; col++)
|
|
3318
|
-
norms[col] = nk_dots_reduce_sumsq_e2m3_(b + col *
|
|
3362
|
+
norms[col] = nk_dots_reduce_sumsq_e2m3_(b + col * b_stride_in_bytes, depth);
|
|
3319
3363
|
}
|
|
3320
3364
|
|
|
3321
3365
|
NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
@@ -3342,7 +3386,7 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
|
3342
3386
|
|
|
3343
3387
|
if (depth_tiles_count == 0) return;
|
|
3344
3388
|
|
|
3345
|
-
nk_dots_i8_a16x64_sapphireamx_t
|
|
3389
|
+
nk_dots_i8_a16x64_sapphireamx_t a_tile_top, a_tile_bottom;
|
|
3346
3390
|
nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
|
|
3347
3391
|
|
|
3348
3392
|
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
@@ -3355,8 +3399,8 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
|
3355
3399
|
nk_size_t const row_block_start = row_block_idx * 32;
|
|
3356
3400
|
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
3357
3401
|
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
3358
|
-
nk_size_t const
|
|
3359
|
-
nk_size_t const
|
|
3402
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
3403
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
3360
3404
|
|
|
3361
3405
|
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
3362
3406
|
nk_size_t const col_block_start = column_block_idx * 32;
|
|
@@ -3375,12 +3419,12 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
|
3375
3419
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3376
3420
|
|
|
3377
3421
|
// Load A with E2M3 -> I8 conversion
|
|
3378
|
-
nk_dots_e2m3_load_a_sapphireamx_(&
|
|
3379
|
-
a_stride_bytes,
|
|
3380
|
-
if (
|
|
3381
|
-
nk_dots_e2m3_load_a_sapphireamx_(&
|
|
3422
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3423
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
3424
|
+
if (rows_in_low_tile > 0) {
|
|
3425
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_bottom,
|
|
3382
3426
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3383
|
-
a_stride_bytes,
|
|
3427
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
3384
3428
|
}
|
|
3385
3429
|
|
|
3386
3430
|
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
|
|
@@ -3390,8 +3434,8 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
|
3390
3434
|
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
3391
3435
|
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
3392
3436
|
|
|
3393
|
-
_tile_loadd(0,
|
|
3394
|
-
_tile_loadd(1,
|
|
3437
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
3438
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
3395
3439
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
3396
3440
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
3397
3441
|
|
|
@@ -3429,7 +3473,7 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
|
3429
3473
|
nk_size_t const col_start = column_tile_idx * 16;
|
|
3430
3474
|
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
3431
3475
|
|
|
3432
|
-
nk_dots_i8_state_sapphireamx_t
|
|
3476
|
+
nk_dots_i8_state_sapphireamx_t c_high_state, c_low_state;
|
|
3433
3477
|
_tile_zero(4);
|
|
3434
3478
|
_tile_zero(6);
|
|
3435
3479
|
|
|
@@ -3437,41 +3481,41 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
|
3437
3481
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3438
3482
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3439
3483
|
|
|
3440
|
-
nk_dots_e2m3_load_a_sapphireamx_(&
|
|
3441
|
-
a_stride_bytes,
|
|
3442
|
-
if (
|
|
3443
|
-
nk_dots_e2m3_load_a_sapphireamx_(&
|
|
3484
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3485
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
3486
|
+
if (rows_in_low_tile > 0) {
|
|
3487
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_bottom,
|
|
3444
3488
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3445
|
-
a_stride_bytes,
|
|
3489
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
3446
3490
|
}
|
|
3447
3491
|
|
|
3448
3492
|
nk_dots_i8_b64x16_sapphireamx_t const *b_tile =
|
|
3449
3493
|
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
3450
3494
|
(b_column_base + depth_tile_idx) * tile_size);
|
|
3451
3495
|
|
|
3452
|
-
_tile_loadd(0,
|
|
3453
|
-
_tile_loadd(1,
|
|
3496
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
3497
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
3454
3498
|
_tile_loadd(2, b_tile->data, 64);
|
|
3455
3499
|
|
|
3456
3500
|
_tile_dpbssd(4, 0, 2);
|
|
3457
3501
|
_tile_dpbssd(6, 1, 2);
|
|
3458
3502
|
}
|
|
3459
3503
|
|
|
3460
|
-
_tile_stored(4,
|
|
3461
|
-
_tile_stored(6,
|
|
3504
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
3505
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
3462
3506
|
|
|
3463
|
-
nk_dots_e2m3_store_sapphireamx_(&
|
|
3464
|
-
c_stride_elements,
|
|
3465
|
-
if (
|
|
3466
|
-
nk_dots_e2m3_store_sapphireamx_(&
|
|
3507
|
+
nk_dots_e2m3_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
|
|
3508
|
+
c_stride_elements, rows_in_high_tile, 16);
|
|
3509
|
+
if (rows_in_low_tile > 0) {
|
|
3510
|
+
nk_dots_e2m3_store_sapphireamx_(&c_low_state,
|
|
3467
3511
|
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
3468
|
-
c_stride_elements,
|
|
3512
|
+
c_stride_elements, rows_in_low_tile, 16);
|
|
3469
3513
|
}
|
|
3470
3514
|
}
|
|
3471
3515
|
|
|
3472
3516
|
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
3473
3517
|
if (column_remainder_count > 0) {
|
|
3474
|
-
nk_dots_i8_state_sapphireamx_t
|
|
3518
|
+
nk_dots_i8_state_sapphireamx_t c_high_state, c_low_state;
|
|
3475
3519
|
nk_dots_i8_a16x64_sapphireamx_t b_as_a;
|
|
3476
3520
|
nk_dots_i8_b64x16_sapphireamx_t b_tile;
|
|
3477
3521
|
|
|
@@ -3482,12 +3526,12 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
|
3482
3526
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3483
3527
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3484
3528
|
|
|
3485
|
-
nk_dots_e2m3_load_a_sapphireamx_(&
|
|
3486
|
-
a_stride_bytes,
|
|
3487
|
-
if (
|
|
3488
|
-
nk_dots_e2m3_load_a_sapphireamx_(&
|
|
3529
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3530
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
3531
|
+
if (rows_in_low_tile > 0) {
|
|
3532
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_bottom,
|
|
3489
3533
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3490
|
-
a_stride_bytes,
|
|
3534
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
3491
3535
|
}
|
|
3492
3536
|
|
|
3493
3537
|
// B edge data is already in I8 format
|
|
@@ -3495,23 +3539,23 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
|
3495
3539
|
valid_depth);
|
|
3496
3540
|
nk_dots_pack_i8_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
3497
3541
|
|
|
3498
|
-
_tile_loadd(0,
|
|
3499
|
-
_tile_loadd(1,
|
|
3542
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
3543
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
3500
3544
|
_tile_loadd(2, b_tile.data, 64);
|
|
3501
3545
|
|
|
3502
3546
|
_tile_dpbssd(4, 0, 2);
|
|
3503
3547
|
_tile_dpbssd(6, 1, 2);
|
|
3504
3548
|
}
|
|
3505
3549
|
|
|
3506
|
-
_tile_stored(4,
|
|
3507
|
-
_tile_stored(6,
|
|
3550
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
3551
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
3508
3552
|
|
|
3509
|
-
nk_dots_e2m3_store_sapphireamx_(&
|
|
3510
|
-
c_stride_elements,
|
|
3511
|
-
if (
|
|
3512
|
-
nk_dots_e2m3_store_sapphireamx_(&
|
|
3553
|
+
nk_dots_e2m3_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
|
|
3554
|
+
c_stride_elements, rows_in_high_tile, column_remainder_count);
|
|
3555
|
+
if (rows_in_low_tile > 0) {
|
|
3556
|
+
nk_dots_e2m3_store_sapphireamx_(&c_low_state,
|
|
3513
3557
|
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
3514
|
-
c_stride_elements,
|
|
3558
|
+
c_stride_elements, rows_in_low_tile, column_remainder_count);
|
|
3515
3559
|
}
|
|
3516
3560
|
}
|
|
3517
3561
|
}
|
|
@@ -3519,17 +3563,17 @@ NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
|
3519
3563
|
_tile_release();
|
|
3520
3564
|
}
|
|
3521
3565
|
|
|
3522
|
-
NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx(
|
|
3523
|
-
nk_e2m3_t const *vectors, nk_size_t
|
|
3524
|
-
nk_size_t
|
|
3566
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
|
|
3567
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
3568
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
|
|
3525
3569
|
nk_size_t row_start, nk_size_t row_count) {
|
|
3526
3570
|
|
|
3527
|
-
nk_size_t const result_stride_elements =
|
|
3571
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
3528
3572
|
|
|
3529
3573
|
// Handle row slicing: compute rows [row_start, row_end)
|
|
3530
3574
|
nk_size_t const row_end = (row_count == 0)
|
|
3531
|
-
?
|
|
3532
|
-
: (row_start + row_count <
|
|
3575
|
+
? vectors_count
|
|
3576
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
3533
3577
|
|
|
3534
3578
|
// Round depth up to multiple of 192 (3 tiles x 64 elements)
|
|
3535
3579
|
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
|
|
@@ -3545,8 +3589,8 @@ NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
|
|
|
3545
3589
|
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
3546
3590
|
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
3547
3591
|
|
|
3548
|
-
for (nk_size_t col_tile = 0; col_tile <
|
|
3549
|
-
nk_size_t const valid_cols = (col_tile + 16 <=
|
|
3592
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
3593
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
3550
3594
|
|
|
3551
3595
|
nk_dots_i8_init_sapphireamx_(&state);
|
|
3552
3596
|
|
|
@@ -3559,19 +3603,19 @@ NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
|
|
|
3559
3603
|
? 64
|
|
3560
3604
|
: (depth > depth_start ? depth - depth_start : 0);
|
|
3561
3605
|
|
|
3562
|
-
nk_dots_e2m3_load_a_sapphireamx_(
|
|
3563
|
-
&a_tiles[tile_idx],
|
|
3564
|
-
vectors + row_tile *
|
|
3565
|
-
|
|
3606
|
+
nk_dots_e2m3_load_a_sapphireamx_( //
|
|
3607
|
+
&a_tiles[tile_idx], //
|
|
3608
|
+
vectors + row_tile * stride_in_bytes + depth_start, //
|
|
3609
|
+
stride_in_bytes, valid_rows, valid_depth);
|
|
3566
3610
|
|
|
3567
3611
|
if (row_tile == col_tile) {
|
|
3568
3612
|
nk_dots_pack_i8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3569
3613
|
}
|
|
3570
3614
|
else {
|
|
3571
|
-
nk_dots_e2m3_load_a_sapphireamx_(
|
|
3572
|
-
&b_src_tiles[tile_idx],
|
|
3573
|
-
vectors + col_tile *
|
|
3574
|
-
|
|
3615
|
+
nk_dots_e2m3_load_a_sapphireamx_( //
|
|
3616
|
+
&b_src_tiles[tile_idx], //
|
|
3617
|
+
vectors + col_tile * stride_in_bytes + depth_start, //
|
|
3618
|
+
stride_in_bytes, valid_cols, valid_depth);
|
|
3575
3619
|
nk_dots_pack_i8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3576
3620
|
}
|
|
3577
3621
|
}
|
|
@@ -3587,9 +3631,9 @@ NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
|
|
|
3587
3631
|
}
|
|
3588
3632
|
}
|
|
3589
3633
|
|
|
3590
|
-
#pragma endregion
|
|
3634
|
+
#pragma endregion E2M3 Floats
|
|
3591
3635
|
|
|
3592
|
-
#pragma region
|
|
3636
|
+
#pragma region E3M2 Floats
|
|
3593
3637
|
|
|
3594
3638
|
/* Load E3M2 A tile with FP8 to BF16 conversion */
|
|
3595
3639
|
NK_INTERNAL void nk_dots_e3m2_load_a_sapphireamx_( //
|
|
@@ -3598,15 +3642,15 @@ NK_INTERNAL void nk_dots_e3m2_load_a_sapphireamx_( //
|
|
|
3598
3642
|
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
3599
3643
|
|
|
3600
3644
|
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
3601
|
-
__m512i
|
|
3645
|
+
__m512i zero_i16x32 = _mm512_setzero_si512();
|
|
3602
3646
|
|
|
3603
3647
|
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
3604
3648
|
if (row_idx < valid_rows) {
|
|
3605
|
-
__m256i
|
|
3606
|
-
__m512i
|
|
3607
|
-
_mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
3649
|
+
__m256i e3m2_row_u8x32 = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
|
|
3650
|
+
__m512i bf16_row_i16x32 = nk_e3m2x32_to_bf16x32_icelake_(e3m2_row_u8x32);
|
|
3651
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row_i16x32);
|
|
3608
3652
|
}
|
|
3609
|
-
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx],
|
|
3653
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
|
|
3610
3654
|
}
|
|
3611
3655
|
nk_compiler_barrier_sapphireamx_();
|
|
3612
3656
|
}
|
|
@@ -3617,7 +3661,7 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_sapphireamx(nk_size_t column_count,
|
|
|
3617
3661
|
|
|
3618
3662
|
NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
|
|
3619
3663
|
nk_e3m2_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
3620
|
-
nk_size_t
|
|
3664
|
+
nk_size_t b_stride_in_bytes, void *b_packed) {
|
|
3621
3665
|
|
|
3622
3666
|
nk_size_t const tmm_rows = 16;
|
|
3623
3667
|
nk_size_t const tmm_cols = 32;
|
|
@@ -3641,8 +3685,7 @@ NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
|
|
|
3641
3685
|
nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
|
|
3642
3686
|
nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
|
|
3643
3687
|
|
|
3644
|
-
|
|
3645
|
-
|
|
3688
|
+
// Pack tiles using vectorized convert + SIMD transpose
|
|
3646
3689
|
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
3647
3690
|
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3648
3691
|
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
@@ -3653,18 +3696,18 @@ NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
|
|
|
3653
3696
|
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
3654
3697
|
: (depth - src_column_start);
|
|
3655
3698
|
|
|
3699
|
+
__mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
|
|
3700
|
+
nk_dots_bf16_a16x32_sapphireamx_t source_tile;
|
|
3656
3701
|
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
3657
|
-
|
|
3658
|
-
|
|
3659
|
-
|
|
3660
|
-
__m512i bf16_row = nk_e3m2x32_to_bf16x32_icelake_(e3m2_row);
|
|
3661
|
-
nk_bf16_t bf16_buf[32];
|
|
3662
|
-
_mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
|
|
3663
|
-
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
3664
|
-
nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
|
|
3665
|
-
tile_output[dst_idx] = bf16_buf[column_idx];
|
|
3666
|
-
}
|
|
3702
|
+
__m256i e3m2_row_u8x32 = _mm256_maskz_loadu_epi8(
|
|
3703
|
+
column_mask, b + (src_row_start + row_idx) * b_stride_in_bytes + src_column_start);
|
|
3704
|
+
_mm512_store_si512(&source_tile.data[row_idx][0], nk_e3m2x32_to_bf16x32_icelake_(e3m2_row_u8x32));
|
|
3667
3705
|
}
|
|
3706
|
+
|
|
3707
|
+
nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
|
|
3708
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
|
|
3709
|
+
for (nk_size_t i = 0; i < tile_bytes; i += 64)
|
|
3710
|
+
_mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
|
|
3668
3711
|
}
|
|
3669
3712
|
}
|
|
3670
3713
|
|
|
@@ -3674,10 +3717,11 @@ NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
|
|
|
3674
3717
|
for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
|
|
3675
3718
|
nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
|
|
3676
3719
|
__mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
|
|
3677
|
-
__m256i
|
|
3678
|
-
column_mask, b + (remainder_start_row + row_idx) *
|
|
3679
|
-
__m512i
|
|
3680
|
-
_mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask,
|
|
3720
|
+
__m256i e3m2_chunk_u8x32 = _mm256_maskz_loadu_epi8(
|
|
3721
|
+
column_mask, b + (remainder_start_row + row_idx) * b_stride_in_bytes + column_idx);
|
|
3722
|
+
__m512i bf16_chunk_i16x32 = nk_e3m2x32_to_bf16x32_icelake_(e3m2_chunk_u8x32);
|
|
3723
|
+
_mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask,
|
|
3724
|
+
bf16_chunk_i16x32);
|
|
3681
3725
|
}
|
|
3682
3726
|
}
|
|
3683
3727
|
}
|
|
@@ -3688,7 +3732,7 @@ NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
|
|
|
3688
3732
|
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
3689
3733
|
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
3690
3734
|
for (nk_size_t col = 0; col < column_count; col++)
|
|
3691
|
-
norms[col] = nk_dots_reduce_sumsq_e3m2_(b + col *
|
|
3735
|
+
norms[col] = nk_dots_reduce_sumsq_e3m2_(b + col * b_stride_in_bytes, depth);
|
|
3692
3736
|
}
|
|
3693
3737
|
|
|
3694
3738
|
NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
@@ -3714,7 +3758,7 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
|
3714
3758
|
|
|
3715
3759
|
if (depth_tiles_count == 0) return;
|
|
3716
3760
|
|
|
3717
|
-
nk_dots_bf16_a16x32_sapphireamx_t
|
|
3761
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tile_top, a_tile_bottom;
|
|
3718
3762
|
nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
|
|
3719
3763
|
|
|
3720
3764
|
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
@@ -3727,8 +3771,8 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
|
3727
3771
|
nk_size_t const row_block_start = row_block_idx * 32;
|
|
3728
3772
|
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
3729
3773
|
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
3730
|
-
nk_size_t const
|
|
3731
|
-
nk_size_t const
|
|
3774
|
+
nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
3775
|
+
nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
3732
3776
|
|
|
3733
3777
|
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
3734
3778
|
nk_size_t const col_block_start = column_block_idx * 32;
|
|
@@ -3747,12 +3791,12 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
|
3747
3791
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3748
3792
|
|
|
3749
3793
|
// Load A with FP8 -> BF16 conversion
|
|
3750
|
-
nk_dots_e3m2_load_a_sapphireamx_(&
|
|
3751
|
-
a_stride_bytes,
|
|
3752
|
-
if (
|
|
3753
|
-
nk_dots_e3m2_load_a_sapphireamx_(&
|
|
3794
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3795
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
3796
|
+
if (rows_in_low_tile > 0) {
|
|
3797
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_bottom,
|
|
3754
3798
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3755
|
-
a_stride_bytes,
|
|
3799
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
3756
3800
|
}
|
|
3757
3801
|
|
|
3758
3802
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
@@ -3762,8 +3806,8 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
|
3762
3806
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
3763
3807
|
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
3764
3808
|
|
|
3765
|
-
_tile_loadd(0,
|
|
3766
|
-
_tile_loadd(1,
|
|
3809
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
3810
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
3767
3811
|
_tile_loadd(2, b_tile_left->data, 64);
|
|
3768
3812
|
_tile_loadd(3, b_tile_right->data, 64);
|
|
3769
3813
|
|
|
@@ -3798,7 +3842,7 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
|
3798
3842
|
nk_size_t const col_start = column_tile_idx * 16;
|
|
3799
3843
|
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
3800
3844
|
|
|
3801
|
-
nk_dots_bf16_state_sapphireamx_t
|
|
3845
|
+
nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
|
|
3802
3846
|
_tile_zero(4);
|
|
3803
3847
|
_tile_zero(6);
|
|
3804
3848
|
|
|
@@ -3806,41 +3850,41 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
|
3806
3850
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3807
3851
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3808
3852
|
|
|
3809
|
-
nk_dots_e3m2_load_a_sapphireamx_(&
|
|
3810
|
-
a_stride_bytes,
|
|
3811
|
-
if (
|
|
3812
|
-
nk_dots_e3m2_load_a_sapphireamx_(&
|
|
3853
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3854
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
3855
|
+
if (rows_in_low_tile > 0) {
|
|
3856
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_bottom,
|
|
3813
3857
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3814
|
-
a_stride_bytes,
|
|
3858
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
3815
3859
|
}
|
|
3816
3860
|
|
|
3817
3861
|
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
|
|
3818
3862
|
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
3819
3863
|
(b_column_base + depth_tile_idx) * tile_size);
|
|
3820
3864
|
|
|
3821
|
-
_tile_loadd(0,
|
|
3822
|
-
_tile_loadd(1,
|
|
3865
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
3866
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
3823
3867
|
_tile_loadd(2, b_tile->data, 64);
|
|
3824
3868
|
|
|
3825
3869
|
_tile_dpbf16ps(4, 0, 2);
|
|
3826
3870
|
_tile_dpbf16ps(6, 1, 2);
|
|
3827
3871
|
}
|
|
3828
3872
|
|
|
3829
|
-
_tile_stored(4,
|
|
3830
|
-
_tile_stored(6,
|
|
3873
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
3874
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
3831
3875
|
|
|
3832
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
3833
|
-
c_stride_elements,
|
|
3834
|
-
if (
|
|
3835
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
3876
|
+
nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
|
|
3877
|
+
c_stride_elements, rows_in_high_tile, 16);
|
|
3878
|
+
if (rows_in_low_tile > 0) {
|
|
3879
|
+
nk_dots_bf16_store_sapphireamx_(&c_low_state,
|
|
3836
3880
|
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
3837
|
-
c_stride_elements,
|
|
3881
|
+
c_stride_elements, rows_in_low_tile, 16);
|
|
3838
3882
|
}
|
|
3839
3883
|
}
|
|
3840
3884
|
|
|
3841
3885
|
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
3842
3886
|
if (column_remainder_count > 0) {
|
|
3843
|
-
nk_dots_bf16_state_sapphireamx_t
|
|
3887
|
+
nk_dots_bf16_state_sapphireamx_t c_high_state, c_low_state;
|
|
3844
3888
|
nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
|
|
3845
3889
|
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
3846
3890
|
|
|
@@ -3851,35 +3895,35 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
|
3851
3895
|
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3852
3896
|
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3853
3897
|
|
|
3854
|
-
nk_dots_e3m2_load_a_sapphireamx_(&
|
|
3855
|
-
a_stride_bytes,
|
|
3856
|
-
if (
|
|
3857
|
-
nk_dots_e3m2_load_a_sapphireamx_(&
|
|
3898
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_top, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3899
|
+
a_stride_bytes, rows_in_high_tile, valid_depth);
|
|
3900
|
+
if (rows_in_low_tile > 0) {
|
|
3901
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_bottom,
|
|
3858
3902
|
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3859
|
-
a_stride_bytes,
|
|
3903
|
+
a_stride_bytes, rows_in_low_tile, valid_depth);
|
|
3860
3904
|
}
|
|
3861
3905
|
|
|
3862
3906
|
nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
3863
3907
|
valid_depth);
|
|
3864
3908
|
nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
3865
3909
|
|
|
3866
|
-
_tile_loadd(0,
|
|
3867
|
-
_tile_loadd(1,
|
|
3910
|
+
_tile_loadd(0, a_tile_top.data, 64);
|
|
3911
|
+
_tile_loadd(1, a_tile_bottom.data, 64);
|
|
3868
3912
|
_tile_loadd(2, b_tile.data, 64);
|
|
3869
3913
|
|
|
3870
3914
|
_tile_dpbf16ps(4, 0, 2);
|
|
3871
3915
|
_tile_dpbf16ps(6, 1, 2);
|
|
3872
3916
|
}
|
|
3873
3917
|
|
|
3874
|
-
_tile_stored(4,
|
|
3875
|
-
_tile_stored(6,
|
|
3918
|
+
_tile_stored(4, c_high_state.data, 64);
|
|
3919
|
+
_tile_stored(6, c_low_state.data, 64);
|
|
3876
3920
|
|
|
3877
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
3878
|
-
c_stride_elements,
|
|
3879
|
-
if (
|
|
3880
|
-
nk_dots_bf16_store_sapphireamx_(&
|
|
3921
|
+
nk_dots_bf16_store_sapphireamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
|
|
3922
|
+
c_stride_elements, rows_in_high_tile, column_remainder_count);
|
|
3923
|
+
if (rows_in_low_tile > 0) {
|
|
3924
|
+
nk_dots_bf16_store_sapphireamx_(&c_low_state,
|
|
3881
3925
|
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
3882
|
-
c_stride_elements,
|
|
3926
|
+
c_stride_elements, rows_in_low_tile, column_remainder_count);
|
|
3883
3927
|
}
|
|
3884
3928
|
}
|
|
3885
3929
|
}
|
|
@@ -3887,18 +3931,18 @@ NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
|
3887
3931
|
_tile_release();
|
|
3888
3932
|
}
|
|
3889
3933
|
|
|
3890
|
-
NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx(
|
|
3891
|
-
nk_e3m2_t const *vectors, nk_size_t
|
|
3892
|
-
nk_size_t
|
|
3934
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx( //
|
|
3935
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
3936
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
|
|
3893
3937
|
nk_size_t row_start, nk_size_t row_count) {
|
|
3894
3938
|
|
|
3895
|
-
nk_size_t const stride_elements =
|
|
3896
|
-
nk_size_t const result_stride_elements =
|
|
3939
|
+
nk_size_t const stride_elements = stride_in_bytes; // sizeof(nk_e3m2_t) == 1, so bytes == elements
|
|
3940
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
3897
3941
|
|
|
3898
3942
|
// Handle row slicing: compute rows [row_start, row_end)
|
|
3899
3943
|
nk_size_t const row_end = (row_count == 0)
|
|
3900
|
-
?
|
|
3901
|
-
: (row_start + row_count <
|
|
3944
|
+
? vectors_count
|
|
3945
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
3902
3946
|
|
|
3903
3947
|
// Round depth up to multiple of 96 (3 tiles x 32 bf16 elements)
|
|
3904
3948
|
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
@@ -3914,8 +3958,8 @@ NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx( //
|
|
|
3914
3958
|
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
3915
3959
|
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
3916
3960
|
|
|
3917
|
-
for (nk_size_t col_tile = 0; col_tile <
|
|
3918
|
-
nk_size_t const valid_cols = (col_tile + 16 <=
|
|
3961
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
3962
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
3919
3963
|
|
|
3920
3964
|
nk_dots_bf16_init_sapphireamx_(&state);
|
|
3921
3965
|
|
|
@@ -3956,7 +4000,7 @@ NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx( //
|
|
|
3956
4000
|
}
|
|
3957
4001
|
}
|
|
3958
4002
|
|
|
3959
|
-
#pragma endregion
|
|
4003
|
+
#pragma endregion E3M2 Floats
|
|
3960
4004
|
|
|
3961
4005
|
#if defined(__clang__)
|
|
3962
4006
|
#pragma clang attribute pop
|