numkong 7.5.0 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/binding.gyp +18 -0
- package/c/dispatch_e5m2.c +23 -3
- package/include/numkong/capabilities.h +1 -1
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +434 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +23 -8
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots.h +12 -0
- package/include/numkong/each/serial.h +18 -1
- package/include/numkong/geospatial/serial.h +14 -3
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +204 -162
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +128 -0
- package/include/numkong/spatials/serial.h +18 -1
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials.h +17 -0
- package/include/numkong/tensor.hpp +107 -23
- package/javascript/numkong.c +3 -2
- package/package.json +7 -7
- package/wasm/numkong.wasm +0 -0
|
@@ -156,6 +156,36 @@ NK_INTERNAL void nk_dot_through_f32_update_skylake_(nk_dot_through_f32_state_sky
|
|
|
156
156
|
state->sum_f32x16 = _mm512_fmadd_ps(a.zmm_ps, b.zmm_ps, state->sum_f32x16);
|
|
157
157
|
}
|
|
158
158
|
|
|
159
|
+
/**
|
|
160
|
+
* @brief E5M2 byte-batched update: consumes 64 raw E5M2 bytes per call and widens inline.
|
|
161
|
+
* Two independent FMA chains (each 2-deep) merge into the single state accumulator at exit.
|
|
162
|
+
* Keeps register pressure at one __m512 across calls while breaking the FMA dep chain.
|
|
163
|
+
*/
|
|
164
|
+
NK_INTERNAL void nk_dot_e5m2x64_update_skylake_(nk_dot_through_f32_state_skylake_t_ *state, nk_b512_vec_t a_bytes,
|
|
165
|
+
nk_b512_vec_t b_bytes, nk_size_t depth_offset,
|
|
166
|
+
nk_size_t active_dimensions) {
|
|
167
|
+
nk_unused_(depth_offset);
|
|
168
|
+
nk_unused_(active_dimensions);
|
|
169
|
+
__m512i const zero_u8x64 = _mm512_setzero_si512();
|
|
170
|
+
__m512i a_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, a_bytes.zmm);
|
|
171
|
+
__m512i a_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, a_bytes.zmm);
|
|
172
|
+
__m512i b_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, b_bytes.zmm);
|
|
173
|
+
__m512i b_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, b_bytes.zmm);
|
|
174
|
+
__m512 a_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_even_f16x32));
|
|
175
|
+
__m512 a_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_even_f16x32, 1));
|
|
176
|
+
__m512 a_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_odd_f16x32));
|
|
177
|
+
__m512 a_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_odd_f16x32, 1));
|
|
178
|
+
__m512 b_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_even_f16x32));
|
|
179
|
+
__m512 b_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_even_f16x32, 1));
|
|
180
|
+
__m512 b_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_odd_f16x32));
|
|
181
|
+
__m512 b_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_odd_f16x32, 1));
|
|
182
|
+
__m512 first_chain_f32x16 = _mm512_mul_ps(a_first_f32x16, b_first_f32x16);
|
|
183
|
+
__m512 second_chain_f32x16 = _mm512_mul_ps(a_second_f32x16, b_second_f32x16);
|
|
184
|
+
first_chain_f32x16 = _mm512_fmadd_ps(a_third_f32x16, b_third_f32x16, first_chain_f32x16);
|
|
185
|
+
second_chain_f32x16 = _mm512_fmadd_ps(a_fourth_f32x16, b_fourth_f32x16, second_chain_f32x16);
|
|
186
|
+
state->sum_f32x16 = _mm512_add_ps(state->sum_f32x16, _mm512_add_ps(first_chain_f32x16, second_chain_f32x16));
|
|
187
|
+
}
|
|
188
|
+
|
|
159
189
|
/**
|
|
160
190
|
* @brief Finalizes 4x low-precision dot-products placing them into 4x consecutive 32-bit slots.
|
|
161
191
|
* @sa nk_dot_f16x16_udpate_skylake, nk_dot_bf16x16_udpate_skylake
|
|
@@ -543,7 +573,7 @@ NK_PUBLIC void nk_dot_e4m3_skylake(nk_e4m3_t const *a_scalars, nk_e4m3_t const *
|
|
|
543
573
|
|
|
544
574
|
nk_dot_e4m3_skylake_cycle:
|
|
545
575
|
if (count_scalars < 16) {
|
|
546
|
-
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
|
|
576
|
+
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)count_scalars);
|
|
547
577
|
a_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
|
|
548
578
|
b_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
|
|
549
579
|
count_scalars = 0;
|
|
@@ -563,27 +593,44 @@ nk_dot_e4m3_skylake_cycle:
|
|
|
563
593
|
|
|
564
594
|
NK_PUBLIC void nk_dot_e5m2_skylake(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
565
595
|
nk_f32_t *result) {
|
|
566
|
-
|
|
567
|
-
|
|
596
|
+
// E5M2 shares F16 bias (15): vpunpck*bw against zero places the byte as F16 encoding,
|
|
597
|
+
// so we inline the widen rather than calling the helper 4× — same ops, cleaner code.
|
|
598
|
+
__m512 first_chain_f32x16 = _mm512_setzero_ps();
|
|
599
|
+
__m512 second_chain_f32x16 = _mm512_setzero_ps();
|
|
600
|
+
__m512i const zero_u8x64 = _mm512_setzero_si512();
|
|
601
|
+
__m512i a_u8x64, b_u8x64;
|
|
568
602
|
|
|
569
603
|
nk_dot_e5m2_skylake_cycle:
|
|
570
|
-
if (count_scalars <
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
604
|
+
if (count_scalars < 64) {
|
|
605
|
+
__mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)count_scalars);
|
|
606
|
+
a_u8x64 = _mm512_maskz_loadu_epi8(mask, a_scalars);
|
|
607
|
+
b_u8x64 = _mm512_maskz_loadu_epi8(mask, b_scalars);
|
|
574
608
|
count_scalars = 0;
|
|
575
609
|
}
|
|
576
610
|
else {
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
a_scalars +=
|
|
611
|
+
a_u8x64 = _mm512_loadu_si512((__m512i const *)a_scalars);
|
|
612
|
+
b_u8x64 = _mm512_loadu_si512((__m512i const *)b_scalars);
|
|
613
|
+
a_scalars += 64, b_scalars += 64, count_scalars -= 64;
|
|
580
614
|
}
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
615
|
+
__m512i a_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, a_u8x64);
|
|
616
|
+
__m512i a_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, a_u8x64);
|
|
617
|
+
__m512i b_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, b_u8x64);
|
|
618
|
+
__m512i b_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, b_u8x64);
|
|
619
|
+
__m512 a_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_even_f16x32));
|
|
620
|
+
__m512 a_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_even_f16x32, 1));
|
|
621
|
+
__m512 a_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_odd_f16x32));
|
|
622
|
+
__m512 a_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_odd_f16x32, 1));
|
|
623
|
+
__m512 b_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_even_f16x32));
|
|
624
|
+
__m512 b_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_even_f16x32, 1));
|
|
625
|
+
__m512 b_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_odd_f16x32));
|
|
626
|
+
__m512 b_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_odd_f16x32, 1));
|
|
627
|
+
first_chain_f32x16 = _mm512_fmadd_ps(a_first_f32x16, b_first_f32x16, first_chain_f32x16);
|
|
628
|
+
second_chain_f32x16 = _mm512_fmadd_ps(a_second_f32x16, b_second_f32x16, second_chain_f32x16);
|
|
629
|
+
first_chain_f32x16 = _mm512_fmadd_ps(a_third_f32x16, b_third_f32x16, first_chain_f32x16);
|
|
630
|
+
second_chain_f32x16 = _mm512_fmadd_ps(a_fourth_f32x16, b_fourth_f32x16, second_chain_f32x16);
|
|
584
631
|
if (count_scalars) goto nk_dot_e5m2_skylake_cycle;
|
|
585
632
|
|
|
586
|
-
*result = nk_reduce_add_f32x16_skylake_(
|
|
633
|
+
*result = nk_reduce_add_f32x16_skylake_(_mm512_add_ps(first_chain_f32x16, second_chain_f32x16));
|
|
587
634
|
}
|
|
588
635
|
|
|
589
636
|
NK_PUBLIC void nk_dot_e2m3_skylake(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -72,6 +72,8 @@ Int8 data is quad-interleaved: [a₀, a₁, a₂, a₃, a₀, a₁, a₂, a₃,
|
|
|
72
72
|
Tile configuration via `LDTILECFG` sets row counts and column byte-widths per tile — allows undersized tiles at matrix edges without masking.
|
|
73
73
|
Morton Z-curve ordering for tile traversal improves cache reuse when both A and B exceed L2.
|
|
74
74
|
This eliminates the explicit M×N×K loop nesting and register file pressure of vector ISAs — the entire dot-product reduction happens inside the tile instruction.
|
|
75
|
+
FP8 inputs on Sapphire AMX go through an on-the-fly E4M3/E5M2 → BF16 pack via the Ice Lake `VPERMI2W` LUT helpers — port-5-bound but the simplest correct route to feed `TDPBF16PS` tiles.
|
|
76
|
+
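Granite Rapids adds `TDPFP16PS` (same tile shape, FP16 operands). The E5M2 variant widens inputs to FP16 by zero-extending each byte to a word and shifting it into the high byte (`VPMOVZXBW` + `VPSLLW`). This is exact because E5M2 shares FP16's exponent bias. The widening happens both when packing B and when loading every A tile, after which the native FP16 compute loop is reused — keeping the intermediate at FP16 precision instead of truncating to BF16 like the Sapphire path.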
|
|
75
77
|
|
|
76
78
|
### SME Outer-Product Streaming
|
|
77
79
|
|
|
@@ -718,6 +718,440 @@ NK_PUBLIC void nk_dots_symmetric_f16_graniteamx(
|
|
|
718
718
|
|
|
719
719
|
#pragma endregion F16 Native
|
|
720
720
|
|
|
721
|
+
#pragma region E5M2 Source (widened to FP16 tiles)
|
|
722
|
+
|
|
723
|
+
/* E5M2 Granite AMX kernels: same F16 tile shapes, same TDPFP16PS compute body as the F16 path.
 * The only difference is byte-to-word widening during A-load and B-pack: E5M2 shares F16's
 * exponent bias (15), so `(byte << 8)` is the exact F16 bit pattern for every E5M2 value,
 * including zero/subnormals/Inf/NaN. Tile buffers hold F16 after widening — "e5m2" in the
 * typedefs refers to the source dtype, not the on-tile representation.
 *
 * The tile types below are plain typedef aliases of the F16 types (identical memory layout), so the
 * F16 init/store/output2x2/update internal helpers accept them directly. Pointer casts are only needed
 * at the boundary with the shared BF16 Sapphire transposer. Only the public entry points and the
 * load-A widen helper are new.
 */

typedef nk_dots_f16_a16x32_graniteamx_t nk_dots_e5m2_a16x32_graniteamx_t;     // A tile: 16 rows x 32 F16 (widened)
typedef nk_dots_f16_b32x16_graniteamx_t nk_dots_e5m2_b32x16_graniteamx_t;     // B tile in transposed F16 layout
typedef nk_dots_f16_state_graniteamx_t nk_dots_e5m2_state_graniteamx_t;       // single F32 accumulator tile
typedef nk_dots_f16_state2x2_graniteamx_t nk_dots_e5m2_state2x2_graniteamx_t; // 2x2 block of accumulator tiles
|
|
738
|
+
|
|
739
|
+
/* Load A tile from E5M2 row-major source, widen to F16 via `(byte << 8)` into the F16 tile buffer. */
|
|
740
|
+
NK_INTERNAL void nk_dots_e5m2_load_a_graniteamx_( //
|
|
741
|
+
nk_dots_e5m2_a16x32_graniteamx_t *a_tile, //
|
|
742
|
+
nk_e5m2_t const *src, nk_size_t src_stride_elements, //
|
|
743
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
744
|
+
|
|
745
|
+
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
746
|
+
__m512i zero_i16x32 = _mm512_setzero_si512();
|
|
747
|
+
|
|
748
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
749
|
+
if (row_idx < valid_rows) {
|
|
750
|
+
__m256i e5m2_u8x32 = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride_elements);
|
|
751
|
+
__m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
|
|
752
|
+
__m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
|
|
753
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], f16_bits_i16x32);
|
|
754
|
+
}
|
|
755
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
|
|
756
|
+
}
|
|
757
|
+
nk_compiler_barrier_sapphireamx_();
|
|
758
|
+
}
|
|
759
|
+
|
|
760
|
+
/**
 * Returns the number of bytes nk_dots_pack_e5m2_graniteamx writes for a B matrix with `column_count`
 * rows of `depth` E5M2 values. After widening, every full tile holds 16x32 F16 values (1 KiB), exactly
 * like the F16 layout. Layout: header, full tiles, row-major F16 edge rows, then one F32 norm per row.
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_graniteamx(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const tile_rows = 16, tile_depth = 32;
    nk_size_t const bytes_per_tile = 512 * sizeof(nk_f16_t);

    nk_size_t const column_tiles = column_count / tile_rows;
    nk_size_t const leftover_columns = column_count % tile_rows;
    nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, tile_depth);

    return sizeof(nk_dots_amx_packed_header_t)           // header
           + column_tiles * depth_tiles * bytes_per_tile // full 16x32 F16 tiles
           + leftover_columns * depth * sizeof(nk_f16_t) // edge rows kept row-major as F16
           + column_count * sizeof(nk_f32_t);            // per-row norms
}
|
|
776
|
+
|
|
777
|
+
/**
 * Packs an E5M2 matrix B (`column_count` rows of `depth` bytes, row stride `b_stride_in_bytes`) into the
 * Granite AMX layout read by nk_dots_packed_e5m2_graniteamx, widening every value to F16 (`byte << 8`).
 *
 * Layout written into `b_packed`:
 *   [header][full 16x32 F16 tiles, column-tile major, transposed][row-major F16 edge rows][F32 norm per row]
 * `b_packed` presumably must hold nk_dots_packed_size_e5m2_graniteamx(column_count, depth) bytes — TODO confirm
 * against callers.
 */
NK_PUBLIC void nk_dots_pack_e5m2_graniteamx( //
    nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth, //
    nk_size_t b_stride_in_bytes, void *b_packed) {

    nk_size_t const tmm_rows = 16;
    nk_size_t const tmm_cols = 32;
    nk_size_t const tile_elements = 512;
    nk_size_t const tile_bytes = tile_elements * sizeof(nk_f16_t);
    nk_size_t const b_stride_elements = b_stride_in_bytes; // E5M2: 1 byte per element

    nk_size_t const column_tiles_count = column_count / tmm_rows;
    // Rounded up: the last depth tile may be partial and is zero-padded below.
    nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
    nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
    nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;

    nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
    header->full_column_tiles = (nk_u32_t)column_tiles_count;
    header->full_depth_tiles = (nk_u32_t)depth_tiles_count; // includes the partial last depth tile
    header->column_remainder_count = (nk_u32_t)column_remainder_count;

    nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
    header->column_edge_offset = (nk_u32_t)column_edge_offset;

    nk_f16_t *tiles_ptr = (nk_f16_t *)((char *)b_packed + tiles_offset);
    nk_f16_t *column_edge_ptr = (nk_f16_t *)((char *)b_packed + column_edge_offset);

    // Pack tiles: widen-and-gather 16 strided E5M2 rows into an F16-aligned source tile,
    // then transpose via the shared BF16/F16 transposer (same 16-bit lane layout).
    for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
        for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {

            nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
            nk_f16_t *tile_output = tiles_ptr + tile_index * tile_elements;

            nk_size_t const src_row_start = column_tile_idx * tmm_rows;
            nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
            nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
                                                                                      : (depth - src_column_start);

            nk_dots_bf16_a16x32_sapphireamx_t source_tile;
            // Masked-off bytes load as zero, so the padded depth lanes widen to +0.0 in F16.
            __mmask32 depth_mask = (__mmask32)((columns_to_pack < 32) ? ((1U << columns_to_pack) - 1) : ~0U);
            for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
                nk_e5m2_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
                __m256i e5m2_u8x32 = _mm256_maskz_loadu_epi8(depth_mask, source_row);
                __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
                __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
                // Aligned store: relies on source_tile being 64-byte aligned (presumably via its type — TODO confirm).
                _mm512_store_si512(&source_tile.data[row_idx][0], f16_bits_i16x32);
            }

            nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
            nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
            // Copy the 1 KiB transposed tile out 64 bytes at a time; the destination may be unaligned.
            for (nk_size_t i = 0; i < tile_bytes; i += 64)
                _mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
        }
    }

    // Column-remainder: widen and store row-major (depth contiguous per column)
    if (column_remainder_count > 0) {
        nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
        for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
            nk_e5m2_t const *src_row = b + (remainder_start_row + row_idx) * b_stride_elements;
            nk_f16_t *dst_row = column_edge_ptr + row_idx * depth;
            nk_size_t column_idx = 0;
            for (; column_idx + 32 <= depth; column_idx += 32) {
                __m256i e5m2_u8x32 = _mm256_loadu_si256((__m256i const *)(src_row + column_idx));
                __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
                __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
                _mm512_storeu_si512(dst_row + column_idx, f16_bits_i16x32);
            }
            // Tail of fewer than 32 elements: masked load and masked 16-bit store.
            if (column_idx < depth) {
                __mmask32 tail_mask = (__mmask32)((1U << (depth - column_idx)) - 1);
                __m256i e5m2_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, src_row + column_idx);
                __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
                __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
                _mm512_mask_storeu_epi16(dst_row + column_idx, tail_mask, f16_bits_i16x32);
            }
        }
    }

    // Norms follow the edge rows. NOTE(review): when column_remainder_count * depth is odd this offset is
    // only 2-byte aligned, so the F32 stores below are misaligned — confirm this matches the F16 path.
    nk_size_t norms_offset = column_edge_offset +
                             (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_f16_t) : 0);
    header->norms_byte_offset = (nk_u32_t)norms_offset;
    nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
    // One norm per B row, computed from the original E5M2 data by nk_dots_reduce_sumsq_e5m2_
    // (presumably the sum of squares, per its name).
    for (nk_size_t col = 0; col < column_count; col++)
        norms[col] = nk_dots_reduce_sumsq_e5m2_(b + col * b_stride_elements, depth);
}
|
|
864
|
+
|
|
865
|
+
/**
 * Multiplies an E5M2 matrix A (`rows_count` rows of `depth` bytes, row stride `a_stride_bytes`) with a B
 * matrix packed by nk_dots_pack_e5m2_graniteamx. The F32 dot product of every A row with every packed B row
 * is written into C (row stride `c_stride_bytes`). A tiles are widened to F16 on every load and multiplied
 * with TDPFP16PS.
 *
 * Tile registers: 0/1 = A top/bottom 16 rows, 2/3 = B left/right column tiles, 4..7 = F32 accumulators.
 * Work is split into three passes:
 *   1. 32x32 output blocks over pairs of full column tiles.
 *   2. A trailing odd full column tile, handled as 32x16 blocks.
 *   3. The (< 16) edge columns kept row-major in the packed buffer, transposed per depth tile on the fly.
 * `cols_count` is ignored; the column extent comes from the packed header.
 * NOTE(review): when the packed header reports zero depth tiles (depth == 0) the function returns
 * without writing C — confirm callers never rely on C being zeroed in that case.
 */
NK_PUBLIC void nk_dots_packed_e5m2_graniteamx( //
    nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
    nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
    nk_unused_(cols_count);

    nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
    nk_size_t const column_tiles_count = header->full_column_tiles;
    nk_size_t const depth_tiles_count = header->full_depth_tiles;
    nk_size_t const column_remainder_count = header->column_remainder_count;

    nk_f16_t const *b_tiles_base = (nk_f16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
    nk_f16_t const *col_edge_ptr = (nk_f16_t const *)((char const *)b_packed + header->column_edge_offset);

    nk_size_t const a_stride_elements = a_stride_bytes; // E5M2: 1 byte per element
    nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);

    nk_size_t const tile_depth = 32;
    nk_size_t const tile_size = 512; // F16 elements per packed B tile
    nk_size_t const full_cols = column_tiles_count * 16;

    nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
    nk_size_t const col_blocks_count = column_tiles_count / 2;

    if (depth_tiles_count == 0) return;

    nk_dots_e5m2_a16x32_graniteamx_t a_tile_top, a_tile_bottom;
    nk_dots_e5m2_state2x2_graniteamx_t c_accum_buffer;

    nk_size_t const full_depth_tiles_count = depth / tile_depth;
    nk_size_t const depth_remainder = depth % tile_depth;

    nk_amx_tile_configure_sapphireamx_();

    // Pass 1: 32x32 output blocks over pairs of full column tiles.
    for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
        nk_size_t const row_block_start = row_block_idx * 32;
        nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
        nk_size_t const is_full_row_block = (valid_rows_count == 32);

        for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
            nk_size_t const col_block_start = column_block_idx * 32;
            nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
            nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;

            _tile_zero(4);
            _tile_zero(5);
            _tile_zero(6);
            _tile_zero(7);

            // E5M2 A needs widen-then-tile-load every time (no fast in-place path), so we always
            // route through the widen buffer a_tile_{top,bottom} rather than `_tile_loadd` from A.
            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
                nk_size_t const rows_in_high_tile = is_full_row_block
                                                        ? 16
                                                        : ((valid_rows_count > 16) ? 16 : valid_rows_count);
                nk_size_t const rows_in_low_tile = is_full_row_block
                                                       ? 16
                                                       : ((valid_rows_count > 16) ? valid_rows_count - 16 : 0);

                nk_dots_e5m2_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
                                                a_stride_elements, rows_in_high_tile, valid_depth);
                if (rows_in_low_tile > 0) {
                    nk_dots_e5m2_load_a_graniteamx_(&a_tile_bottom,
                                                    a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                    a_stride_elements, rows_in_low_tile, valid_depth);
                }

                nk_dots_f16_b32x16_graniteamx_t const *b_tile_left = //
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                              (b_column_left_base + depth_tile_idx) * tile_size);
                nk_dots_f16_b32x16_graniteamx_t const *b_tile_right = //
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                              (b_column_right_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_top.data, 64);
                if (rows_in_low_tile > 0) _tile_loadd(1, a_tile_bottom.data, 64);
                _tile_loadd(2, b_tile_left->data, 64);
                _tile_loadd(3, b_tile_right->data, 64);

                _tile_dpfp16ps(4, 0, 2);
                _tile_dpfp16ps(5, 0, 3);
                if (rows_in_low_tile > 0) {
                    _tile_dpfp16ps(6, 1, 2);
                    _tile_dpfp16ps(7, 1, 3);
                }
            }

            // Full 32-row blocks go straight to C; partial blocks go through a buffer clipped to valid rows.
            if (is_full_row_block) {
                nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
                _tile_stored(4, c_block, c_stride_bytes);
                _tile_stored(5, c_block + 16, c_stride_bytes);
                _tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
                _tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
            }
            else {
                _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
                _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
                _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
                _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
                nk_dots_f16_output2x2_graniteamx_(&c_accum_buffer,
                                                  c + row_block_start * c_stride_elements + col_block_start,
                                                  c_stride_elements, valid_rows_count, 32);
            }
        }
    }

    // Odd column tile
    // Pass 2: the last full column tile when their count is odd, computed as 32x16 blocks.
    if (column_tiles_count % 2 == 1) {
        nk_size_t const column_tile_idx = column_tiles_count - 1;
        nk_size_t const col_start = column_tile_idx * 16;
        nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;

        for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
            nk_size_t const row_block_start = row_block_idx * 32;
            nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
                                                                                   : (rows_count - row_block_start);
            nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
            nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

            nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e5m2_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
                                                a_stride_elements, rows_in_high_tile, valid_depth);
                if (rows_in_low_tile > 0) {
                    nk_dots_e5m2_load_a_graniteamx_(&a_tile_bottom,
                                                    a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                    a_stride_elements, rows_in_low_tile, valid_depth);
                }

                nk_dots_f16_b32x16_graniteamx_t const *b_tile = //
                    (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
                                                              (b_column_base + depth_tile_idx) * tile_size);

                _tile_loadd(0, a_tile_top.data, 64);
                if (rows_in_low_tile > 0) _tile_loadd(1, a_tile_bottom.data, 64);
                _tile_loadd(2, b_tile->data, 64);

                _tile_dpfp16ps(4, 0, 2);
                if (rows_in_low_tile > 0) _tile_dpfp16ps(6, 1, 2);
            }

            _tile_stored(4, c_high_state.data, 64);
            if (rows_in_low_tile > 0) _tile_stored(6, c_low_state.data, 64);

            nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
                                          c_stride_elements, rows_in_high_tile, 16);
            if (rows_in_low_tile > 0) {
                nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + col_start,
                                              c_stride_elements, rows_in_low_tile, 16);
            }
        }
    }

    // Column edge (fewer than 16 extra columns, stored row-major as F16 after pack)
    // Pass 3: each depth slice of the edge rows is loaded like an A tile, then transposed into B layout.
    if (column_remainder_count > 0) {
        for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
            nk_size_t const row_block_start = row_block_idx * 32;
            nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
                                                                                   : (rows_count - row_block_start);
            nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
            nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;

            nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;
            nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
            nk_dots_bf16_b32x16_sapphireamx_t b_tile;

            _tile_zero(4);
            _tile_zero(6);

            for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
                nk_size_t const depth_offset = depth_tile_idx * tile_depth;
                nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;

                nk_dots_e5m2_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
                                                a_stride_elements, rows_in_high_tile, valid_depth);
                if (rows_in_low_tile > 0) {
                    nk_dots_e5m2_load_a_graniteamx_(&a_tile_bottom,
                                                    a + (row_block_start + 16) * a_stride_elements + depth_offset,
                                                    a_stride_elements, rows_in_low_tile, valid_depth);
                }

                // Edge rows are already F16, so the BF16 loader/transposer (16-bit lanes) can move them as-is.
                nk_dots_bf16_load_a_sapphireamx_(&b_as_a, (nk_bf16_t const *)(col_edge_ptr + depth_offset), depth,
                                                 column_remainder_count, valid_depth);
                nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);

                _tile_loadd(0, a_tile_top.data, 64);
                if (rows_in_low_tile > 0) _tile_loadd(1, a_tile_bottom.data, 64);
                _tile_loadd(2, b_tile.data, 64);

                _tile_dpfp16ps(4, 0, 2);
                if (rows_in_low_tile > 0) _tile_dpfp16ps(6, 1, 2);
            }

            _tile_stored(4, c_high_state.data, 64);
            if (rows_in_low_tile > 0) _tile_stored(6, c_low_state.data, 64);

            nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
                                          c_stride_elements, rows_in_high_tile, column_remainder_count);
            if (rows_in_low_tile > 0) {
                nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + full_cols,
                                              c_stride_elements, rows_in_low_tile, column_remainder_count);
            }
        }
    }

    _tile_release();
}
|
|
1080
|
+
|
|
1081
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_graniteamx( //
|
|
1082
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
|
|
1083
|
+
nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
|
|
1084
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
1085
|
+
|
|
1086
|
+
nk_size_t const stride_elements = stride_in_bytes; // E5M2: 1 byte per element
|
|
1087
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1088
|
+
|
|
1089
|
+
nk_size_t const row_end = (row_count == 0)
|
|
1090
|
+
? vectors_count
|
|
1091
|
+
: (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
|
|
1092
|
+
|
|
1093
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
1094
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
1095
|
+
|
|
1096
|
+
nk_dots_e5m2_a16x32_graniteamx_t a_tiles[3];
|
|
1097
|
+
nk_dots_e5m2_a16x32_graniteamx_t b_src_tiles[3];
|
|
1098
|
+
nk_dots_f16_b32x16_graniteamx_t b_tiles[3];
|
|
1099
|
+
nk_dots_f16_state_graniteamx_t state;
|
|
1100
|
+
|
|
1101
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
1102
|
+
|
|
1103
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
1104
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
1105
|
+
|
|
1106
|
+
for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
|
|
1107
|
+
nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
|
|
1108
|
+
|
|
1109
|
+
nk_dots_f16_init_graniteamx_(&state);
|
|
1110
|
+
|
|
1111
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
1112
|
+
nk_size_t const depth_base = depth_group_idx * 96;
|
|
1113
|
+
|
|
1114
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
1115
|
+
nk_size_t const depth_start = depth_base + tile_idx * 32;
|
|
1116
|
+
nk_size_t const valid_depth = (depth_start + 32 <= depth)
|
|
1117
|
+
? 32
|
|
1118
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
1119
|
+
|
|
1120
|
+
nk_dots_e5m2_load_a_graniteamx_( //
|
|
1121
|
+
&a_tiles[tile_idx], //
|
|
1122
|
+
vectors + row_tile * stride_elements + depth_start, //
|
|
1123
|
+
stride_elements, valid_rows, valid_depth);
|
|
1124
|
+
|
|
1125
|
+
if (row_tile == col_tile) {
|
|
1126
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(
|
|
1127
|
+
(nk_dots_bf16_a16x32_sapphireamx_t const *)&a_tiles[tile_idx],
|
|
1128
|
+
(nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
|
|
1129
|
+
}
|
|
1130
|
+
else {
|
|
1131
|
+
nk_dots_e5m2_load_a_graniteamx_( //
|
|
1132
|
+
&b_src_tiles[tile_idx], //
|
|
1133
|
+
vectors + col_tile * stride_elements + depth_start, //
|
|
1134
|
+
stride_elements, valid_cols, valid_depth);
|
|
1135
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(
|
|
1136
|
+
(nk_dots_bf16_a16x32_sapphireamx_t const *)&b_src_tiles[tile_idx],
|
|
1137
|
+
(nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
|
|
1138
|
+
}
|
|
1139
|
+
}
|
|
1140
|
+
|
|
1141
|
+
nk_dots_f16_update_graniteamx_( //
|
|
1142
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], //
|
|
1143
|
+
&b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
1144
|
+
}
|
|
1145
|
+
|
|
1146
|
+
nk_dots_f16_store_graniteamx_( //
|
|
1147
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
1148
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
1149
|
+
}
|
|
1150
|
+
}
|
|
1151
|
+
}
|
|
1152
|
+
|
|
1153
|
+
#pragma endregion E5M2 Source
|
|
1154
|
+
|
|
721
1155
|
#if defined(__clang__)
|
|
722
1156
|
#pragma clang attribute pop
|
|
723
1157
|
#elif defined(__GNUC__)
|
|
@@ -115,45 +115,45 @@ nk_define_cross_packed_(dots, bf16, haswell, bf16, bf16, f32, nk_b256_vec_t, nk_
|
|
|
115
115
|
nk_partial_store_b32x4_haswell_,
|
|
116
116
|
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
117
117
|
|
|
118
|
-
/* E4M3 GEMM: depth_simd_dimensions=
|
|
119
|
-
nk_define_cross_pack_size_(dots, e4m3, haswell, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/
|
|
118
|
+
/* E4M3 GEMM: depth_simd_dimensions=32 (byte-level batch; widen inside the update helper) */
|
|
119
|
+
nk_define_cross_pack_size_(dots, e4m3, haswell, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
|
|
120
120
|
/*dimensions_per_value=*/1)
|
|
121
|
-
nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t,
|
|
122
|
-
|
|
123
|
-
/*simd_width=*/
|
|
124
|
-
/*depth_simd_dimensions=*/
|
|
121
|
+
nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_load_b256_haswell_,
|
|
122
|
+
nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
|
|
123
|
+
/*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
|
|
124
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
125
125
|
nk_define_cross_symmetric_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
|
|
126
|
-
nk_b128_vec_t, nk_dot_through_f32_init_haswell_,
|
|
127
|
-
|
|
126
|
+
nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
|
|
127
|
+
nk_partial_load_b8x32_serial_, nk_dot_e4m3x32_update_haswell_,
|
|
128
128
|
nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
|
|
129
129
|
nk_partial_store_b32x4_haswell_,
|
|
130
|
-
/*depth_simd_dimensions=*/
|
|
130
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
131
131
|
nk_define_cross_packed_(dots, e4m3, haswell, e4m3, f32, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
|
|
132
|
-
nk_b128_vec_t, nk_dot_through_f32_init_haswell_,
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
/*depth_simd_dimensions=*/
|
|
132
|
+
nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
|
|
133
|
+
nk_partial_load_b8x32_serial_, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
|
|
134
|
+
nk_dot_e4m3x32_update_haswell_, nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
|
|
135
|
+
nk_partial_store_b32x4_haswell_,
|
|
136
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
137
137
|
|
|
138
|
-
/* E5M2 GEMM: depth_simd_dimensions=
|
|
139
|
-
nk_define_cross_pack_size_(dots, e5m2, haswell, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/
|
|
138
|
+
/* E5M2 GEMM: depth_simd_dimensions=32 (byte-level batch; widen inside the update helper) */
|
|
139
|
+
nk_define_cross_pack_size_(dots, e5m2, haswell, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
|
|
140
140
|
/*dimensions_per_value=*/1)
|
|
141
|
-
nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t,
|
|
142
|
-
|
|
143
|
-
/*simd_width=*/
|
|
144
|
-
/*depth_simd_dimensions=*/
|
|
141
|
+
nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_load_b256_haswell_,
|
|
142
|
+
nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
|
|
143
|
+
/*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
|
|
144
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
145
145
|
nk_define_cross_symmetric_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
|
|
146
|
-
nk_b128_vec_t, nk_dot_through_f32_init_haswell_,
|
|
147
|
-
|
|
146
|
+
nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
|
|
147
|
+
nk_partial_load_b8x32_serial_, nk_dot_e5m2x32_update_haswell_,
|
|
148
148
|
nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
|
|
149
149
|
nk_partial_store_b32x4_haswell_,
|
|
150
|
-
/*depth_simd_dimensions=*/
|
|
150
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
151
151
|
nk_define_cross_packed_(dots, e5m2, haswell, e5m2, f32, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
|
|
152
|
-
nk_b128_vec_t, nk_dot_through_f32_init_haswell_,
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
/*depth_simd_dimensions=*/
|
|
152
|
+
nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
|
|
153
|
+
nk_partial_load_b8x32_serial_, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
|
|
154
|
+
nk_dot_e5m2x32_update_haswell_, nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
|
|
155
|
+
nk_partial_store_b32x4_haswell_,
|
|
156
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
157
157
|
|
|
158
158
|
/* E2M3 GEMM: integer LUT path, depth_simd_dimensions=32 (32 e2m3s = 32 bytes = AVX2 register width) */
|
|
159
159
|
nk_define_cross_pack_size_(dots, e2m3, haswell, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
|
|
@@ -73,7 +73,7 @@
|
|
|
73
73
|
#if NK_TARGET_SAPPHIREAMX
|
|
74
74
|
|
|
75
75
|
#include "numkong/cast/icelake.h" // For FP8 ↔ BF16 conversions
|
|
76
|
-
#include "numkong/dots/serial.h" //
|
|
76
|
+
#include "numkong/dots/serial.h" // `nk_dots_reduce_sumsq_bf16_`
|
|
77
77
|
|
|
78
78
|
#if defined(__cplusplus)
|
|
79
79
|
extern "C" {
|