numkong 7.5.0 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/binding.gyp +18 -0
  2. package/c/dispatch_e5m2.c +23 -3
  3. package/include/numkong/capabilities.h +1 -1
  4. package/include/numkong/cast/README.md +3 -0
  5. package/include/numkong/cast/haswell.h +28 -64
  6. package/include/numkong/cast/serial.h +17 -0
  7. package/include/numkong/cast/skylake.h +67 -52
  8. package/include/numkong/cast.h +1 -0
  9. package/include/numkong/dot/README.md +1 -0
  10. package/include/numkong/dot/haswell.h +92 -13
  11. package/include/numkong/dot/serial.h +15 -0
  12. package/include/numkong/dot/skylake.h +61 -14
  13. package/include/numkong/dots/README.md +2 -0
  14. package/include/numkong/dots/graniteamx.h +434 -0
  15. package/include/numkong/dots/haswell.h +28 -28
  16. package/include/numkong/dots/sapphireamx.h +1 -1
  17. package/include/numkong/dots/serial.h +23 -8
  18. package/include/numkong/dots/skylake.h +28 -23
  19. package/include/numkong/dots.h +12 -0
  20. package/include/numkong/each/serial.h +18 -1
  21. package/include/numkong/geospatial/serial.h +14 -3
  22. package/include/numkong/maxsim/serial.h +15 -0
  23. package/include/numkong/mesh/README.md +50 -44
  24. package/include/numkong/mesh/genoa.h +462 -0
  25. package/include/numkong/mesh/haswell.h +806 -933
  26. package/include/numkong/mesh/neon.h +871 -943
  27. package/include/numkong/mesh/neonbfdot.h +382 -522
  28. package/include/numkong/mesh/neonfhm.h +676 -0
  29. package/include/numkong/mesh/rvv.h +404 -319
  30. package/include/numkong/mesh/serial.h +204 -162
  31. package/include/numkong/mesh/skylake.h +1029 -1585
  32. package/include/numkong/mesh/v128relaxed.h +403 -377
  33. package/include/numkong/mesh.h +38 -0
  34. package/include/numkong/reduce/serial.h +15 -1
  35. package/include/numkong/sparse/serial.h +17 -2
  36. package/include/numkong/spatial/genoa.h +0 -68
  37. package/include/numkong/spatial/haswell.h +98 -56
  38. package/include/numkong/spatial/serial.h +15 -0
  39. package/include/numkong/spatial/skylake.h +114 -54
  40. package/include/numkong/spatial.h +0 -12
  41. package/include/numkong/spatials/graniteamx.h +128 -0
  42. package/include/numkong/spatials/serial.h +18 -1
  43. package/include/numkong/spatials/skylake.h +2 -2
  44. package/include/numkong/spatials.h +17 -0
  45. package/include/numkong/tensor.hpp +107 -23
  46. package/javascript/numkong.c +3 -2
  47. package/package.json +7 -7
  48. package/wasm/numkong.wasm +0 -0
@@ -156,6 +156,36 @@ NK_INTERNAL void nk_dot_through_f32_update_skylake_(nk_dot_through_f32_state_sky
  state->sum_f32x16 = _mm512_fmadd_ps(a.zmm_ps, b.zmm_ps, state->sum_f32x16);
  }
 
+ /**
+ * @brief E5M2 byte-batched update: consumes 64 raw E5M2 bytes per call and widens inline.
+ * Two independent FMA chains (each 2-deep) merge into the single state accumulator at exit.
+ * Keeps register pressure at one __m512 across calls while breaking the FMA dep chain.
+ */
+ NK_INTERNAL void nk_dot_e5m2x64_update_skylake_(nk_dot_through_f32_state_skylake_t_ *state, nk_b512_vec_t a_bytes,
+ nk_b512_vec_t b_bytes, nk_size_t depth_offset,
+ nk_size_t active_dimensions) {
+ nk_unused_(depth_offset);
+ nk_unused_(active_dimensions);
+ __m512i const zero_u8x64 = _mm512_setzero_si512();
+ __m512i a_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, a_bytes.zmm);
+ __m512i a_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, a_bytes.zmm);
+ __m512i b_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, b_bytes.zmm);
+ __m512i b_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, b_bytes.zmm);
+ __m512 a_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_even_f16x32));
+ __m512 a_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_even_f16x32, 1));
+ __m512 a_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_odd_f16x32));
+ __m512 a_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_odd_f16x32, 1));
+ __m512 b_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_even_f16x32));
+ __m512 b_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_even_f16x32, 1));
+ __m512 b_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_odd_f16x32));
+ __m512 b_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_odd_f16x32, 1));
+ __m512 first_chain_f32x16 = _mm512_mul_ps(a_first_f32x16, b_first_f32x16);
+ __m512 second_chain_f32x16 = _mm512_mul_ps(a_second_f32x16, b_second_f32x16);
+ first_chain_f32x16 = _mm512_fmadd_ps(a_third_f32x16, b_third_f32x16, first_chain_f32x16);
+ second_chain_f32x16 = _mm512_fmadd_ps(a_fourth_f32x16, b_fourth_f32x16, second_chain_f32x16);
+ state->sum_f32x16 = _mm512_add_ps(state->sum_f32x16, _mm512_add_ps(first_chain_f32x16, second_chain_f32x16));
+ }
+
  /**
  * @brief Finalizes 4x low-precision dot-products placing them into 4x consecutive 32-bit slots.
  * @sa nk_dot_f16x16_udpate_skylake, nk_dot_bf16x16_udpate_skylake
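For readers unfamiliar with the accumulator-splitting idea the new helper's comment describes, here is a minimal scalar analogue (an illustrative sketch with our own names, not library code): with one accumulator every fused multiply-add waits on the previous one, while two independent chains overlap and merge once at exit, exactly as `first_chain_f32x16`/`second_chain_f32x16` do above.

```c
#include <stddef.h>

/* Illustrative sketch only: scalar analogue of the two-chain FMA split. */
float dot_two_chains(float const *a, float const *b, size_t n) {
    float chain0 = 0.0f, chain1 = 0.0f; /* independent dependency chains */
    size_t i = 0;
    for (; i + 2 <= n; i += 2) {
        chain0 += a[i] * b[i];         /* waits only on the prior chain0 */
        chain1 += a[i + 1] * b[i + 1]; /* waits only on the prior chain1 */
    }
    if (i < n) chain0 += a[i] * b[i];  /* odd-length tail */
    return chain0 + chain1;            /* single merge at exit */
}
```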
@@ -543,7 +573,7 @@ NK_PUBLIC void nk_dot_e4m3_skylake(nk_e4m3_t const *a_scalars, nk_e4m3_t const *
 
  nk_dot_e4m3_skylake_cycle:
  if (count_scalars < 16) {
- __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
+ __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, (unsigned int)count_scalars);
  a_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
  b_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
  count_scalars = 0;
@@ -563,27 +593,44 @@ nk_dot_e4m3_skylake_cycle:
 
  NK_PUBLIC void nk_dot_e5m2_skylake(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- __m128i a_e5m2_u8x16, b_e5m2_u8x16;
- __m512 sum_f32x16 = _mm512_setzero_ps();
+ // E5M2 shares F16 bias (15): vpunpck*bw against zero places the byte as F16 encoding,
+ // so we inline the widen rather than calling the helper 4× — same ops, cleaner code.
+ __m512 first_chain_f32x16 = _mm512_setzero_ps();
+ __m512 second_chain_f32x16 = _mm512_setzero_ps();
+ __m512i const zero_u8x64 = _mm512_setzero_si512();
+ __m512i a_u8x64, b_u8x64;
 
  nk_dot_e5m2_skylake_cycle:
- if (count_scalars < 16) {
- __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, count_scalars);
- a_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
- b_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
+ if (count_scalars < 64) {
+ __mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)count_scalars);
+ a_u8x64 = _mm512_maskz_loadu_epi8(mask, a_scalars);
+ b_u8x64 = _mm512_maskz_loadu_epi8(mask, b_scalars);
  count_scalars = 0;
  }
  else {
- a_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)a_scalars);
- b_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)b_scalars);
- a_scalars += 16, b_scalars += 16, count_scalars -= 16;
+ a_u8x64 = _mm512_loadu_si512((__m512i const *)a_scalars);
+ b_u8x64 = _mm512_loadu_si512((__m512i const *)b_scalars);
+ a_scalars += 64, b_scalars += 64, count_scalars -= 64;
  }
- __m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2_u8x16);
- __m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2_u8x16);
- sum_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, sum_f32x16);
+ __m512i a_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, a_u8x64);
+ __m512i a_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, a_u8x64);
+ __m512i b_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, b_u8x64);
+ __m512i b_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, b_u8x64);
+ __m512 a_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_even_f16x32));
+ __m512 a_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_even_f16x32, 1));
+ __m512 a_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_odd_f16x32));
+ __m512 a_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_odd_f16x32, 1));
+ __m512 b_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_even_f16x32));
+ __m512 b_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_even_f16x32, 1));
+ __m512 b_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_odd_f16x32));
+ __m512 b_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_odd_f16x32, 1));
+ first_chain_f32x16 = _mm512_fmadd_ps(a_first_f32x16, b_first_f32x16, first_chain_f32x16);
+ second_chain_f32x16 = _mm512_fmadd_ps(a_second_f32x16, b_second_f32x16, second_chain_f32x16);
+ first_chain_f32x16 = _mm512_fmadd_ps(a_third_f32x16, b_third_f32x16, first_chain_f32x16);
+ second_chain_f32x16 = _mm512_fmadd_ps(a_fourth_f32x16, b_fourth_f32x16, second_chain_f32x16);
  if (count_scalars) goto nk_dot_e5m2_skylake_cycle;
 
- *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
+ *result = nk_reduce_add_f32x16_skylake_(_mm512_add_ps(first_chain_f32x16, second_chain_f32x16));
  }
 
  NK_PUBLIC void nk_dot_e2m3_skylake(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
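The bias-sharing claim in the new comment can be checked directly: E5M2 is 1 sign + 5 exponent + 2 mantissa bits with the same exponent bias (15) as IEEE binary16 (1 + 5 + 10), so placing the byte in the high half of a 16-bit word — which is what `vpunpcklbw`/`vpunpckhbw` against a zero register do per lane — zero-extends the mantissa and yields the exact F16 encoding. A minimal standalone sketch (helper name and sample values are ours):

```c
#include <stdint.h>
#include <stdio.h>

/* Illustrative only: an E5M2 byte shifted into the top 8 bits of a 16-bit
 * word is the exact IEEE binary16 bit pattern for the same value. */
static uint16_t e5m2_to_f16_bits(uint8_t e5m2_bits) { return (uint16_t)e5m2_bits << 8; }

int main(void) {
    /* Sample E5M2 encodings: 0x3C = +1.0, 0xC0 = -2.0, 0x7C = +Inf. */
    uint8_t samples[] = {0x3C, 0xC0, 0x7C};
    for (int i = 0; i < 3; i++) /* prints 0x3C00, 0xC000, 0x7C00: the F16 encodings */
        printf("E5M2 0x%02X -> F16 bits 0x%04X\n", samples[i], e5m2_to_f16_bits(samples[i]));
    return 0;
}
```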
@@ -72,6 +72,8 @@ Int8 data is quad-interleaved: [a₀, a₁, a₂, a₃, a₀, a₁, a₂, a₃,
  Tile configuration via `LDTILECFG` sets row counts and column byte-widths per tile — allows undersized tiles at matrix edges without masking.
  Morton Z-curve ordering for tile traversal improves cache reuse when both A and B exceed L2.
  This eliminates the explicit M×N×K loop nesting and register file pressure of vector ISAs — the entire dot-product reduction happens inside the tile instruction.
+ FP8 inputs on Sapphire AMX go through an on-the-fly E4M3/E5M2 → BF16 pack via the Ice Lake `VPERMI2W` LUT helpers — port-5-bound but the simplest correct route to feed `TDPBF16PS` tiles.
+ Granite Rapids adds `TDPFP16PS` (same tile shape, FP16 operands); the E5M2 variant widens inputs with a single `VPUNPCK*BW` against zero into FP16 tiles at pack time and then reuses the native FP16 compute loop — keeps the intermediate at FP16 precision instead of truncating to BF16 like the Sapphire path.
 
  ### SME Outer-Product Streaming
 
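The Morton Z-curve traversal mentioned above is the classic bit-interleave that keeps consecutive tile indices close in both row and column, which is what preserves L2 reuse when both operands overflow the cache. A minimal sketch (our own helper names, assuming 16-bit tile coordinates; the library's actual traversal may differ):

```c
#include <stdint.h>

/* Spread the low 16 bits of x apart: abcd -> 0a0b0c0d. */
static uint32_t part1by1(uint32_t x) {
    x &= 0x0000FFFF;
    x = (x | (x << 8)) & 0x00FF00FF;
    x = (x | (x << 4)) & 0x0F0F0F0F;
    x = (x | (x << 2)) & 0x33333333;
    x = (x | (x << 1)) & 0x55555555;
    return x;
}

/* Morton index of a (row, col) tile coordinate: row bits in even
 * positions, col bits in odd positions, so index order walks a Z. */
static uint32_t morton2d(uint32_t row, uint32_t col) {
    return (part1by1(col) << 1) | part1by1(row);
}
```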
@@ -718,6 +718,440 @@ NK_PUBLIC void nk_dots_symmetric_f16_graniteamx(
 
  #pragma endregion F16 Native
 
+ #pragma region E5M2 Source (widened to FP16 tiles)
+
+ /* E5M2 Granite AMX kernels: same F16 tile shapes, same TDPFP16PS compute body as the F16 path.
+ * The only difference is byte-to-word widening during A-load and B-pack: E5M2 shares F16's
+ * exponent bias (15), so `(byte << 8)` is the exact F16 bit pattern for every E5M2 value,
+ * including zero/subnormals/Inf/NaN. Tile buffers hold F16 after widen — "e5m2" in the
+ * typedefs refers to the source dtype, not the on-tile representation.
+ *
+ * The tile types below alias the F16 types (identical memory layout) so we can reuse the
+ * F16 init/store/output2x2/update internal helpers by pointer cast at the boundary. Only the
+ * public entry points and the load-A widen helper are new.
+ */
+
+ typedef nk_dots_f16_a16x32_graniteamx_t nk_dots_e5m2_a16x32_graniteamx_t;
+ typedef nk_dots_f16_b32x16_graniteamx_t nk_dots_e5m2_b32x16_graniteamx_t;
+ typedef nk_dots_f16_state_graniteamx_t nk_dots_e5m2_state_graniteamx_t;
+ typedef nk_dots_f16_state2x2_graniteamx_t nk_dots_e5m2_state2x2_graniteamx_t;
+
+ /* Load A tile from E5M2 row-major source, widen to F16 via `(byte << 8)` into the F16 tile buffer. */
+ NK_INTERNAL void nk_dots_e5m2_load_a_graniteamx_( //
+ nk_dots_e5m2_a16x32_graniteamx_t *a_tile, //
+ nk_e5m2_t const *src, nk_size_t src_stride_elements, //
+ nk_size_t valid_rows, nk_size_t valid_cols) {
+
+ __mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
+ __m512i zero_i16x32 = _mm512_setzero_si512();
+
+ for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
+ if (row_idx < valid_rows) {
+ __m256i e5m2_u8x32 = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride_elements);
+ __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
+ __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
+ _mm512_store_si512((__m512i *)a_tile->data[row_idx], f16_bits_i16x32);
+ }
+ else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero_i16x32); }
+ }
+ nk_compiler_barrier_sapphireamx_();
+ }
+
+ NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_graniteamx(nk_size_t column_count, nk_size_t depth) {
+ nk_size_t const tmm_rows = 16;
+ nk_size_t const tmm_cols = 32;
+ nk_size_t const tile_bytes = 512 * sizeof(nk_f16_t); // Tiles hold F16 after widen: same 1KB as F16.
+
+ nk_size_t const full_column_tiles = column_count / tmm_rows;
+ nk_size_t const tiles_along_depth = nk_size_divide_round_up_(depth, tmm_cols);
+ nk_size_t const column_remainder_count = column_count - full_column_tiles * tmm_rows;
+
+ nk_size_t size = sizeof(nk_dots_amx_packed_header_t);
+ size += full_column_tiles * tiles_along_depth * tile_bytes;
+ if (column_remainder_count > 0) size += column_remainder_count * depth * sizeof(nk_f16_t);
+ size += column_count * sizeof(nk_f32_t);
+
+ return size;
+ }
+
+ NK_PUBLIC void nk_dots_pack_e5m2_graniteamx( //
+ nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth, //
+ nk_size_t b_stride_in_bytes, void *b_packed) {
+
+ nk_size_t const tmm_rows = 16;
+ nk_size_t const tmm_cols = 32;
+ nk_size_t const tile_elements = 512;
+ nk_size_t const tile_bytes = tile_elements * sizeof(nk_f16_t);
+ nk_size_t const b_stride_elements = b_stride_in_bytes; // E5M2: 1 byte per element
+
+ nk_size_t const column_tiles_count = column_count / tmm_rows;
+ nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
+ nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
+ nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
+
+ nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
+ header->full_column_tiles = (nk_u32_t)column_tiles_count;
+ header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
+ header->column_remainder_count = (nk_u32_t)column_remainder_count;
+
+ nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
+ nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
+ header->column_edge_offset = (nk_u32_t)column_edge_offset;
+
+ nk_f16_t *tiles_ptr = (nk_f16_t *)((char *)b_packed + tiles_offset);
+ nk_f16_t *column_edge_ptr = (nk_f16_t *)((char *)b_packed + column_edge_offset);
+
+ // Pack tiles: widen-and-gather 16 strided E5M2 rows into an F16-aligned source tile,
+ // then transpose via the shared BF16/F16 transposer (same 16-bit lane layout).
+ for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
+
+ nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
+ nk_f16_t *tile_output = tiles_ptr + tile_index * tile_elements;
+
+ nk_size_t const src_row_start = column_tile_idx * tmm_rows;
+ nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
+ nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
+ : (depth - src_column_start);
+
+ nk_dots_bf16_a16x32_sapphireamx_t source_tile;
+ __mmask32 depth_mask = (__mmask32)((columns_to_pack < 32) ? ((1U << columns_to_pack) - 1) : ~0U);
+ for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
+ nk_e5m2_t const *source_row = b + (src_row_start + row_idx) * b_stride_elements + src_column_start;
+ __m256i e5m2_u8x32 = _mm256_maskz_loadu_epi8(depth_mask, source_row);
+ __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
+ __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
+ _mm512_store_si512(&source_tile.data[row_idx][0], f16_bits_i16x32);
+ }
+
+ nk_dots_bf16_b32x16_sapphireamx_t transposed_tile;
+ nk_dots_pack_bf16_transposed_sapphireamx_(&source_tile, &transposed_tile);
+ for (nk_size_t i = 0; i < tile_bytes; i += 64)
+ _mm512_storeu_si512((char *)tile_output + i, _mm512_load_si512((char const *)&transposed_tile + i));
+ }
+ }
+
+ // Column-remainder: widen and store row-major (depth contiguous per column)
+ if (column_remainder_count > 0) {
+ nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
+ for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
+ nk_e5m2_t const *src_row = b + (remainder_start_row + row_idx) * b_stride_elements;
+ nk_f16_t *dst_row = column_edge_ptr + row_idx * depth;
+ nk_size_t column_idx = 0;
+ for (; column_idx + 32 <= depth; column_idx += 32) {
+ __m256i e5m2_u8x32 = _mm256_loadu_si256((__m256i const *)(src_row + column_idx));
+ __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
+ __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
+ _mm512_storeu_si512(dst_row + column_idx, f16_bits_i16x32);
+ }
+ if (column_idx < depth) {
+ __mmask32 tail_mask = (__mmask32)((1U << (depth - column_idx)) - 1);
+ __m256i e5m2_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, src_row + column_idx);
+ __m512i word_u16x32 = _mm512_cvtepu8_epi16(e5m2_u8x32);
+ __m512i f16_bits_i16x32 = _mm512_slli_epi16(word_u16x32, 8);
+ _mm512_mask_storeu_epi16(dst_row + column_idx, tail_mask, f16_bits_i16x32);
+ }
+ }
+ }
+
+ nk_size_t norms_offset = column_edge_offset +
+ (column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_f16_t) : 0);
+ header->norms_byte_offset = (nk_u32_t)norms_offset;
+ nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
+ for (nk_size_t col = 0; col < column_count; col++)
+ norms[col] = nk_dots_reduce_sumsq_e5m2_(b + col * b_stride_elements, depth);
+ }
+
+ NK_PUBLIC void nk_dots_packed_e5m2_graniteamx( //
+ nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
+ nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
+ nk_unused_(cols_count);
+
+ nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
+ nk_size_t const column_tiles_count = header->full_column_tiles;
+ nk_size_t const depth_tiles_count = header->full_depth_tiles;
+ nk_size_t const column_remainder_count = header->column_remainder_count;
+
+ nk_f16_t const *b_tiles_base = (nk_f16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
+ nk_f16_t const *col_edge_ptr = (nk_f16_t const *)((char const *)b_packed + header->column_edge_offset);
+
+ nk_size_t const a_stride_elements = a_stride_bytes; // E5M2: 1 byte per element
+ nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
+
+ nk_size_t const tile_depth = 32;
+ nk_size_t const tile_size = 512;
+ nk_size_t const full_cols = column_tiles_count * 16;
+
+ nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
+ nk_size_t const col_blocks_count = column_tiles_count / 2;
+
+ if (depth_tiles_count == 0) return;
+
+ nk_dots_e5m2_a16x32_graniteamx_t a_tile_top, a_tile_bottom;
+ nk_dots_e5m2_state2x2_graniteamx_t c_accum_buffer;
+
+ nk_size_t const full_depth_tiles_count = depth / tile_depth;
+ nk_size_t const depth_remainder = depth % tile_depth;
+
+ nk_amx_tile_configure_sapphireamx_();
+
+ for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
+ nk_size_t const row_block_start = row_block_idx * 32;
+ nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
+ nk_size_t const is_full_row_block = (valid_rows_count == 32);
+
+ for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
+ nk_size_t const col_block_start = column_block_idx * 32;
+ nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
+ nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
+
+ _tile_zero(4);
+ _tile_zero(5);
+ _tile_zero(6);
+ _tile_zero(7);
+
+ // E5M2 A needs widen-then-tile-load every time (no fast in-place path), so we always
+ // route through the widen buffer a_tile_{top,bottom} rather than `_tile_loadd` from A.
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
+ nk_size_t const rows_in_high_tile = is_full_row_block
+ ? 16
+ : ((valid_rows_count > 16) ? 16 : valid_rows_count);
+ nk_size_t const rows_in_low_tile = is_full_row_block
+ ? 16
+ : ((valid_rows_count > 16) ? valid_rows_count - 16 : 0);
+
+ nk_dots_e5m2_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
+ a_stride_elements, rows_in_high_tile, valid_depth);
+ if (rows_in_low_tile > 0) {
+ nk_dots_e5m2_load_a_graniteamx_(&a_tile_bottom,
+ a + (row_block_start + 16) * a_stride_elements + depth_offset,
+ a_stride_elements, rows_in_low_tile, valid_depth);
+ }
+
+ nk_dots_f16_b32x16_graniteamx_t const *b_tile_left = //
+ (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
+ (b_column_left_base + depth_tile_idx) * tile_size);
+ nk_dots_f16_b32x16_graniteamx_t const *b_tile_right = //
+ (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
+ (b_column_right_base + depth_tile_idx) * tile_size);
+
+ _tile_loadd(0, a_tile_top.data, 64);
+ if (rows_in_low_tile > 0) _tile_loadd(1, a_tile_bottom.data, 64);
+ _tile_loadd(2, b_tile_left->data, 64);
+ _tile_loadd(3, b_tile_right->data, 64);
+
+ _tile_dpfp16ps(4, 0, 2);
+ _tile_dpfp16ps(5, 0, 3);
+ if (rows_in_low_tile > 0) {
+ _tile_dpfp16ps(6, 1, 2);
+ _tile_dpfp16ps(7, 1, 3);
+ }
+ }
+
+ if (is_full_row_block) {
+ nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
+ _tile_stored(4, c_block, c_stride_bytes);
+ _tile_stored(5, c_block + 16, c_stride_bytes);
+ _tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
+ _tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
+ }
+ else {
+ _tile_stored(4, c_accum_buffer.c[0][0].data, 64);
+ _tile_stored(5, c_accum_buffer.c[0][1].data, 64);
+ _tile_stored(6, c_accum_buffer.c[1][0].data, 64);
+ _tile_stored(7, c_accum_buffer.c[1][1].data, 64);
+ nk_dots_f16_output2x2_graniteamx_(&c_accum_buffer,
+ c + row_block_start * c_stride_elements + col_block_start,
+ c_stride_elements, valid_rows_count, 32);
+ }
+ }
+ }
+
+ // Odd column tile
+ if (column_tiles_count % 2 == 1) {
+ nk_size_t const column_tile_idx = column_tiles_count - 1;
+ nk_size_t const col_start = column_tile_idx * 16;
+ nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
+
+ for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
+ nk_size_t const row_block_start = row_block_idx * 32;
+ nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
+ : (rows_count - row_block_start);
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
+
+ nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;
+
+ _tile_zero(4);
+ _tile_zero(6);
+
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
+
+ nk_dots_e5m2_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
+ a_stride_elements, rows_in_high_tile, valid_depth);
+ if (rows_in_low_tile > 0) {
+ nk_dots_e5m2_load_a_graniteamx_(&a_tile_bottom,
+ a + (row_block_start + 16) * a_stride_elements + depth_offset,
+ a_stride_elements, rows_in_low_tile, valid_depth);
+ }
+
+ nk_dots_f16_b32x16_graniteamx_t const *b_tile = //
+ (nk_dots_f16_b32x16_graniteamx_t const *)(b_tiles_base +
+ (b_column_base + depth_tile_idx) * tile_size);
+
+ _tile_loadd(0, a_tile_top.data, 64);
+ if (rows_in_low_tile > 0) _tile_loadd(1, a_tile_bottom.data, 64);
+ _tile_loadd(2, b_tile->data, 64);
+
+ _tile_dpfp16ps(4, 0, 2);
+ if (rows_in_low_tile > 0) _tile_dpfp16ps(6, 1, 2);
+ }
+
+ _tile_stored(4, c_high_state.data, 64);
+ if (rows_in_low_tile > 0) _tile_stored(6, c_low_state.data, 64);
+
+ nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + col_start,
+ c_stride_elements, rows_in_high_tile, 16);
+ if (rows_in_low_tile > 0) {
+ nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + col_start,
+ c_stride_elements, rows_in_low_tile, 16);
+ }
+ }
+ }
+
+ // Column edge (fewer than 16 extra columns, stored row-major as F16 after pack)
+ if (column_remainder_count > 0) {
+ for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
+ nk_size_t const row_block_start = row_block_idx * 32;
+ nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
+ : (rows_count - row_block_start);
+ nk_size_t const rows_in_high_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
+ nk_size_t const rows_in_low_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
+
+ nk_dots_f16_state_graniteamx_t c_high_state, c_low_state;
+ nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
+ nk_dots_bf16_b32x16_sapphireamx_t b_tile;
+
+ _tile_zero(4);
+ _tile_zero(6);
+
+ for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
+ nk_size_t const depth_offset = depth_tile_idx * tile_depth;
+ nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
+
+ nk_dots_e5m2_load_a_graniteamx_(&a_tile_top, a + row_block_start * a_stride_elements + depth_offset,
+ a_stride_elements, rows_in_high_tile, valid_depth);
+ if (rows_in_low_tile > 0) {
+ nk_dots_e5m2_load_a_graniteamx_(&a_tile_bottom,
+ a + (row_block_start + 16) * a_stride_elements + depth_offset,
+ a_stride_elements, rows_in_low_tile, valid_depth);
+ }
+
+ nk_dots_bf16_load_a_sapphireamx_(&b_as_a, (nk_bf16_t const *)(col_edge_ptr + depth_offset), depth,
+ column_remainder_count, valid_depth);
+ nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
+
+ _tile_loadd(0, a_tile_top.data, 64);
+ if (rows_in_low_tile > 0) _tile_loadd(1, a_tile_bottom.data, 64);
+ _tile_loadd(2, b_tile.data, 64);
+
+ _tile_dpfp16ps(4, 0, 2);
+ if (rows_in_low_tile > 0) _tile_dpfp16ps(6, 1, 2);
+ }
+
+ _tile_stored(4, c_high_state.data, 64);
+ if (rows_in_low_tile > 0) _tile_stored(6, c_low_state.data, 64);
+
+ nk_dots_f16_store_graniteamx_(&c_high_state, c + row_block_start * c_stride_elements + full_cols,
+ c_stride_elements, rows_in_high_tile, column_remainder_count);
+ if (rows_in_low_tile > 0) {
+ nk_dots_f16_store_graniteamx_(&c_low_state, c + (row_block_start + 16) * c_stride_elements + full_cols,
+ c_stride_elements, rows_in_low_tile, column_remainder_count);
+ }
+ }
+ }
+
+ _tile_release();
+ }
+
+ NK_PUBLIC void nk_dots_symmetric_e5m2_graniteamx( //
+ nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, //
+ nk_size_t stride_in_bytes, nk_f32_t *result, nk_size_t result_stride_in_bytes, //
+ nk_size_t row_start, nk_size_t row_count) {
+
+ nk_size_t const stride_elements = stride_in_bytes; // E5M2: 1 byte per element
+ nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
+
+ nk_size_t const row_end = (row_count == 0)
+ ? vectors_count
+ : (row_start + row_count < vectors_count ? row_start + row_count : vectors_count);
+
+ nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
+ nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
+
+ nk_dots_e5m2_a16x32_graniteamx_t a_tiles[3];
+ nk_dots_e5m2_a16x32_graniteamx_t b_src_tiles[3];
+ nk_dots_f16_b32x16_graniteamx_t b_tiles[3];
+ nk_dots_f16_state_graniteamx_t state;
+
+ nk_amx_tile_configure_sapphireamx_();
+
+ for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
+ nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
+
+ for (nk_size_t col_tile = 0; col_tile < vectors_count; col_tile += 16) {
+ nk_size_t const valid_cols = (col_tile + 16 <= vectors_count) ? 16 : (vectors_count - col_tile);
+
+ nk_dots_f16_init_graniteamx_(&state);
+
+ for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
+ nk_size_t const depth_base = depth_group_idx * 96;
+
+ for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
+ nk_size_t const depth_start = depth_base + tile_idx * 32;
+ nk_size_t const valid_depth = (depth_start + 32 <= depth)
+ ? 32
+ : (depth > depth_start ? depth - depth_start : 0);
+
+ nk_dots_e5m2_load_a_graniteamx_( //
+ &a_tiles[tile_idx], //
+ vectors + row_tile * stride_elements + depth_start, //
+ stride_elements, valid_rows, valid_depth);
+
+ if (row_tile == col_tile) {
+ nk_dots_pack_bf16_transposed_sapphireamx_(
+ (nk_dots_bf16_a16x32_sapphireamx_t const *)&a_tiles[tile_idx],
+ (nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
+ }
+ else {
+ nk_dots_e5m2_load_a_graniteamx_( //
+ &b_src_tiles[tile_idx], //
+ vectors + col_tile * stride_elements + depth_start, //
+ stride_elements, valid_cols, valid_depth);
+ nk_dots_pack_bf16_transposed_sapphireamx_(
+ (nk_dots_bf16_a16x32_sapphireamx_t const *)&b_src_tiles[tile_idx],
+ (nk_dots_bf16_b32x16_sapphireamx_t *)&b_tiles[tile_idx]);
+ }
+ }
+
+ nk_dots_f16_update_graniteamx_( //
+ &state, &a_tiles[0], &a_tiles[1], &a_tiles[2], //
+ &b_tiles[0], &b_tiles[1], &b_tiles[2]);
+ }
+
+ nk_dots_f16_store_graniteamx_( //
+ &state, result + row_tile * result_stride_elements + col_tile, //
+ result_stride_elements, valid_rows, valid_cols);
+ }
+ }
+ }
+
+ #pragma endregion E5M2 Source
+
  #if defined(__clang__)
  #pragma clang attribute pop
  #elif defined(__GNUC__)
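A hypothetical calling sequence for the three new public entry points added above (size query, one-time pack, then the tiled kernel); the strides, allocation details, and alignment here are our assumptions, not library guidance:

```c
#include <stdlib.h> /* aligned_alloc, free */
/* Assumes the relevant numkong headers are already included. */

/* Sketch: row-major, tightly packed inputs (stride = depth bytes), and
 * C[i][j] = dot(A row i, B row j) in f32. */
void example_dots_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_f32_t *c,
                       nk_size_t rows, nk_size_t cols, nk_size_t depth) {
    nk_size_t packed_bytes = nk_dots_packed_size_e5m2_graniteamx(cols, depth);
    /* Round up so the size is a multiple of the alignment, as C11 requires. */
    void *b_packed = aligned_alloc(64, (packed_bytes + 63) / 64 * 64);
    if (!b_packed) return;
    /* Widen + transpose B once; reused across every A it multiplies. */
    nk_dots_pack_e5m2_graniteamx(b, cols, depth, /*b_stride_in_bytes=*/depth, b_packed);
    nk_dots_packed_e5m2_graniteamx(a, b_packed, c, rows, cols, depth,
                                   /*a_stride_bytes=*/depth,
                                   /*c_stride_bytes=*/cols * sizeof(nk_f32_t));
    free(b_packed);
}
```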
@@ -115,45 +115,45 @@ nk_define_cross_packed_(dots, bf16, haswell, bf16, bf16, f32, nk_b256_vec_t, nk_
  nk_partial_store_b32x4_haswell_,
  /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
 
- /* E4M3 GEMM: depth_simd_dimensions=8 (8 e4m3s = 8 bytes) upcasted to 8×f32 (256-bit) */
- nk_define_cross_pack_size_(dots, e4m3, haswell, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
+ /* E4M3 GEMM: depth_simd_dimensions=32 (byte-level batch; widen inside the update helper) */
+ nk_define_cross_pack_size_(dots, e4m3, haswell, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
  /*dimensions_per_value=*/1)
- nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_load_e4m3x8_to_f32x8_haswell_,
- nk_partial_load_e4m3x8_to_f32x8_haswell_, nk_store_b256_haswell_, nk_partial_store_b32x8_serial_,
- /*simd_width=*/8, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
+ nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_load_b256_haswell_,
+ nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
+ /*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
  nk_define_cross_symmetric_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e4m3x8_to_f32x8_haswell_,
- nk_partial_load_e4m3x8_to_f32x8_haswell_, nk_dot_through_f32_update_haswell_,
+ nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
+ nk_partial_load_b8x32_serial_, nk_dot_e4m3x32_update_haswell_,
  nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
  nk_partial_store_b32x4_haswell_,
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
  nk_define_cross_packed_(dots, e4m3, haswell, e4m3, f32, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e4m3x8_to_f32x8_haswell_,
- nk_partial_load_e4m3x8_to_f32x8_haswell_, nk_load_b256_haswell_, nk_partial_load_b32x8_serial_,
- nk_dot_through_f32_update_haswell_, nk_dot_through_f32_finalize_haswell_,
- nk_store_b128_haswell_, nk_partial_store_b32x4_haswell_,
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
+ nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
+ nk_partial_load_b8x32_serial_, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
+ nk_dot_e4m3x32_update_haswell_, nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
+ nk_partial_store_b32x4_haswell_,
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
 
- /* E5M2 GEMM: depth_simd_dimensions=8 (8 e5m2s = 8 bytes) upcasted to 8×f32 (256-bit) */
- nk_define_cross_pack_size_(dots, e5m2, haswell, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
+ /* E5M2 GEMM: depth_simd_dimensions=32 (byte-level batch; widen inside the update helper) */
+ nk_define_cross_pack_size_(dots, e5m2, haswell, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
  /*dimensions_per_value=*/1)
- nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_load_e5m2x8_to_f32x8_haswell_,
- nk_partial_load_e5m2x8_to_f32x8_haswell_, nk_store_b256_haswell_, nk_partial_store_b32x8_serial_,
- /*simd_width=*/8, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
+ nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_load_b256_haswell_,
+ nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
+ /*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
  nk_define_cross_symmetric_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e5m2x8_to_f32x8_haswell_,
- nk_partial_load_e5m2x8_to_f32x8_haswell_, nk_dot_through_f32_update_haswell_,
+ nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
+ nk_partial_load_b8x32_serial_, nk_dot_e5m2x32_update_haswell_,
  nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
  nk_partial_store_b32x4_haswell_,
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
  nk_define_cross_packed_(dots, e5m2, haswell, e5m2, f32, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e5m2x8_to_f32x8_haswell_,
- nk_partial_load_e5m2x8_to_f32x8_haswell_, nk_load_b256_haswell_, nk_partial_load_b32x8_serial_,
- nk_dot_through_f32_update_haswell_, nk_dot_through_f32_finalize_haswell_,
- nk_store_b128_haswell_, nk_partial_store_b32x4_haswell_,
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
+ nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
+ nk_partial_load_b8x32_serial_, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
+ nk_dot_e5m2x32_update_haswell_, nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
+ nk_partial_store_b32x4_haswell_,
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
 
  /* E2M3 GEMM: integer LUT path, depth_simd_dimensions=32 (32 e2m3s = 32 bytes = AVX2 register width) */
  nk_define_cross_pack_size_(dots, e2m3, haswell, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
@@ -73,7 +73,7 @@
  #if NK_TARGET_SAPPHIREAMX
 
  #include "numkong/cast/icelake.h" // For FP8 ↔ BF16 conversions
- #include "numkong/dots/serial.h" // For nk_dots_reduce_sumsq_bf16_
+ #include "numkong/dots/serial.h" // `nk_dots_reduce_sumsq_bf16_`
 
  #if defined(__cplusplus)
  extern "C" {