@fugood/llama.node 1.1.6 → 1.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/binding.ts +4 -0
- package/lib/index.js +6 -1
- package/lib/index.ts +6 -0
- package/lib/version.js +5 -0
- package/lib/version.ts +2 -0
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +9 -9
- package/src/LlamaCompletionWorker.cpp +73 -20
- package/src/LlamaCompletionWorker.h +8 -0
- package/src/LlamaContext.cpp +9 -0
- package/src/common.hpp +8 -1
- package/src/llama.cpp/CMakeLists.txt +2 -0
- package/src/llama.cpp/common/arg.cpp +132 -41
- package/src/llama.cpp/common/chat-parser.cpp +9 -1
- package/src/llama.cpp/common/chat.cpp +311 -9
- package/src/llama.cpp/common/chat.h +4 -1
- package/src/llama.cpp/common/common.cpp +54 -0
- package/src/llama.cpp/common/common.h +46 -9
- package/src/llama.cpp/ggml/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/src/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/src/llama.cpp/ggml/include/ggml.h +28 -2
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +66 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +1136 -1077
- package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +14 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +63 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +200 -51
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +11 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/src/llama.cpp/include/llama.h +25 -0
- package/src/llama.cpp/src/llama-batch.cpp +1 -1
- package/src/llama.cpp/src/llama-chat.cpp +2 -4
- package/src/llama.cpp/src/llama-context.cpp +29 -22
- package/src/llama.cpp/src/llama-context.h +6 -5
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +12 -6
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +2 -2
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +89 -69
- package/src/llama.cpp/src/llama-kv-cache-unified.h +2 -2
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +6 -2
- package/src/llama.cpp/src/llama-memory-hybrid.h +2 -2
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -2
- package/src/llama.cpp/src/llama-memory-recurrent.h +2 -2
- package/src/llama.cpp/src/llama-memory.h +2 -2
- package/src/llama.cpp/src/llama-model.cpp +81 -70
- package/src/llama.cpp/src/llama-model.h +2 -0
- package/src/llama.cpp/src/llama-quant.cpp +1 -1
- package/src/llama.cpp/src/llama-vocab.cpp +2 -1
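
Most of this release's churn is in `package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp`, and the hunks quoted below appear to come from that file: the hand-written AVX2/AVX-512 GEMV and GEMM kernels for 8-column-interleaved 4-bit blocks are consolidated into shared templates (`gemv_q4_b32_8x8_q8_0_lut_avx` and `gemm_q4_b32_8x8_q8_0_lut_avx`) parameterized on the interleaved block type, with the nibble lookup table passed in as an argument. The snippet below is a minimal, self-contained sketch of the compile-time dispatch idiom those templates use (`static_assert` over `std::is_same_v`, then `if constexpr`); the toy types and function names are hypothetical and not part of the package.

```cpp
#include <cstdio>
#include <type_traits>

// Toy stand-ins for the two interleaved block layouts (hypothetical, illustration only).
struct toy_block_q4_0x8   { float d[8]; };
struct toy_block_iq4_nlx8 { float d[8]; };

// Same shape as the new kernels: reject unsupported types at compile time,
// then branch on the concrete block type with no runtime cost.
template <typename block_tx8>
float first_col_scale(const block_tx8 & blk) {
    static_assert(
        std::is_same_v<block_tx8, toy_block_q4_0x8> ||
        std::is_same_v<block_tx8, toy_block_iq4_nlx8>,
        "Unsupported block type");

    float scale = 0.0f;
    if constexpr (std::is_same_v<block_tx8, toy_block_q4_0x8> ||
                  std::is_same_v<block_tx8, toy_block_iq4_nlx8>) {
        scale = blk.d[0];  // both toy layouts keep per-column scales in d[]
    }
    return scale;
}

int main() {
    toy_block_q4_0x8   a{{1.5f}};
    toy_block_iq4_nlx8 b{{2.5f}};
    std::printf("%.1f %.1f\n", first_col_scale(a), first_col_scale(b));
    return 0;
}
```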
@@ -511,38 +511,34 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
 #endif
 }

-
+//
+// GEMV/GEMM templates
+//
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+
+// GEMV for 8x blocks of 32 4-bit quants with a single scale factor per block
+template<typename block_tx8>
+static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
+    static_assert(
+        std::is_same_v<block_tx8, block_q4_0x8> ||
+        std::is_same_v<block_tx8, block_iq4_nlx8>,
+        "Unsupported block type");
+
     const int qk = QK8_0;
     const int nb = n / qk;
-    const int ncols_interleaved = 8;
-    const int blocklen = 8;
-
-    assert (n % qk == 0);
-    assert (nc % ncols_interleaved == 0);

-    UNUSED(s);
     UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);

-#if defined(__AVX2__)
-    // Lookup table to convert signed nibbles to signed bytes
-    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
-    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
     __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
     __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);

     // Permute mask used for easier vector processing at later stages
     const __m256i m4b = _mm256_set1_epi8(0x0F);

-    int64_t b_nb = n /
+    int64_t b_nb = n / 32;

-    const
+    const block_tx8 * b_ptr_start = (const block_tx8 *)vx;
     const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy;

     // Process Q8_0 blocks one by one
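In the hunk above, the sign-extension lookup table becomes a `signextendlut` parameter of the templated GEMV instead of being rebuilt inside `ggml_gemv_q4_0_8x8_q8_0`. For the Q4_0 path, the removed lines built a table mapping each 4-bit code to the signed byte it represents; presumably the IQ4_NL instantiation passes its own 16-entry codebook through the same parameter, which is why the table is now an argument. A scalar sketch of the Q4_0-style mapping, for orientation only (not code from the package):

```cpp
#include <cstdint>
#include <cstdio>

// Scalar equivalent of the nibble lookup table the removed AVX2 code built with
// _mm256_set_epi8(-1, -2, ..., -8, 7, 6, ..., 0): entry i holds the signed value
// of the 4-bit code i (0..7 stay positive, 8..15 wrap to -8..-1).
static const int8_t sign_extend_lut[16] = {
     0,  1,  2,  3,  4,  5,  6,  7,
    -8, -7, -6, -5, -4, -3, -2, -1,
};

int main() {
    for (int code = 0; code < 16; ++code) {
        std::printf("0x%X -> %d\n", code, sign_extend_lut[code]);
    }
    return 0;
}
```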
@@ -551,17 +547,17 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
         // Pointers to LHS blocks of block_q8_0 format
         const block_q8_0 * a_ptr = a_ptr_start + (y * nb);

-        // Take group of eight
+        // Take group of eight blocks at each pass of the loop and perform dot product operation
         for (int64_t x = 0; x < nc / 8; x++) {

             // Pointers to RHS blocks
-            const
+            const block_tx8 * b_ptr = b_ptr_start + (x * b_nb);

             // Master FP accumulator
             __m256 acc_row = _mm256_setzero_ps();

             for (int64_t b = 0; b < nb; b++) {
-                // Load 8 blocks of
+                // Load 8 blocks of 32 interleaved as 8 bytes (B0 - B7)
                 const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
                 const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1);
                 const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2);
@@ -578,8 +574,13 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
                 const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31)
                 const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31)

-                // Load the scale values for the 8 blocks interleaved in
-
+                // Load the scale values for the 8 blocks interleaved in block_tx8
+                __m256 col_scale_f32;
+                if constexpr (
+                    std::is_same_v<block_tx8, block_q4_0x8> ||
+                    std::is_same_v<block_tx8, block_iq4_nlx8>) {
+                    col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
+                }

                 // Load and convert to FP32 scale from block_q8_0
                 const __m256 row_scale_f32 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(a_ptr[b].d));
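The hunk above makes the column-scale load compile-time dispatched (`col_scale_f32` comes from the weight block's `d` array), while the row scale still comes from the `block_q8_0` activations. Together they implement the usual blockwise dot product: each 32-element block contributes `d_col * d_row * Σ q_col[i] * q_row[i]`, with the 4-bit weights already expanded to signed bytes by the lookup table. A scalar reference of that accumulation, as a readable counterpart to the surrounding AVX code (simplified stand-in types, not the package's structs):

```cpp
#include <cstdint>
#include <cstddef>

// Scalar reference for the per-block math the AVX GEMV/GEMM kernels vectorize:
// sum the int8 products of one 32-value block, then scale by both block scales.
struct ref_block { float d; int8_t q[32]; };  // simplified stand-in

float dot_blocks(const ref_block * col, const ref_block * row, size_t nblocks) {
    float acc = 0.0f;
    for (size_t blk = 0; blk < nblocks; ++blk) {
        int32_t isum = 0;
        for (int i = 0; i < 32; ++i) {
            isum += (int32_t)col[blk].q[i] * (int32_t)row[blk].q[i];
        }
        acc += col[blk].d * row[blk].d * (float)isum;  // d_col * d_row * sum(q*q)
    }
    return acc;
}

int main() {
    ref_block a{0.5f, {}};
    ref_block b{0.25f, {}};
    for (int i = 0; i < 32; ++i) { a.q[i] = 1; b.q[i] = 2; }
    // Expect 0.5 * 0.25 * (32 * 1 * 2) = 8.0
    return dot_blocks(&a, &b, 1) == 8.0f ? 0 : 1;
}
```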
@@ -620,240 +621,782 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|
|
620
621
|
_mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
|
|
621
622
|
}
|
|
622
623
|
}
|
|
623
|
-
return;
|
|
624
|
-
|
|
625
|
-
#endif
|
|
626
|
-
ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
627
624
|
}
|
|
628
625
|
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
static const uint32_t kmask3 = 0x03030303;
|
|
637
|
-
|
|
638
|
-
assert (n % qk == 0);
|
|
639
|
-
assert (nc % ncols_interleaved == 0);
|
|
626
|
+
// GEMM for 8x blocks of 32 4-bit quants with a single scale factor per block
|
|
627
|
+
template<typename block_tx8>
|
|
628
|
+
static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
|
|
629
|
+
static_assert(
|
|
630
|
+
std::is_same_v<block_tx8, block_q4_0x8> ||
|
|
631
|
+
std::is_same_v<block_tx8, block_iq4_nlx8>,
|
|
632
|
+
"Unsupported block type");
|
|
640
633
|
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
UNUSED(vx);
|
|
644
|
-
UNUSED(vy);
|
|
645
|
-
UNUSED(nr);
|
|
646
|
-
UNUSED(nc);
|
|
647
|
-
UNUSED(nb);
|
|
648
|
-
UNUSED(ncols_interleaved);
|
|
649
|
-
UNUSED(blocklen);
|
|
634
|
+
const int qk = QK8_0;
|
|
635
|
+
const int nb = n / qk;
|
|
650
636
|
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
|
|
654
|
-
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
|
655
|
-
// Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
|
|
656
|
-
__m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
|
|
657
|
-
__m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
|
|
658
|
-
// Permute mask used for easier vector processing at later stages
|
|
659
|
-
__m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
|
|
637
|
+
const block_tx8 * b_ptr_start = (const block_tx8 *)vx;
|
|
638
|
+
const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
|
|
660
639
|
|
|
661
|
-
|
|
640
|
+
int64_t b_nb = n / 32;
|
|
641
|
+
int64_t y = 0;
|
|
642
|
+
// Mask to mask out nibbles from packed bytes
|
|
662
643
|
const __m256i m4b = _mm256_set1_epi8(0x0F);
|
|
644
|
+
const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
|
|
645
|
+
// Permute mask used for easier vector processing at later stages
|
|
646
|
+
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
|
647
|
+
int64_t xstart = 0;
|
|
648
|
+
int anr = nr - nr%16; // Used to align nr with boundary of 16
|
|
649
|
+
#ifdef __AVX512F__
|
|
650
|
+
int anc = nc - nc%16; // Used to align nc with boundary of 16
|
|
651
|
+
// Mask to mask out nibbles from packed bytes expanded to 512 bit length
|
|
652
|
+
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
|
|
653
|
+
// Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length
|
|
654
|
+
__m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
|
|
663
655
|
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;
|
|
667
|
-
const block_q8_K * a_ptr_start = (const block_q8_K *)vy;
|
|
656
|
+
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
|
657
|
+
for (; y < anr / 4; y += 4) {
|
|
668
658
|
|
|
669
|
-
|
|
670
|
-
for (int64_t y = 0; y < nr; y++) {
|
|
659
|
+
const block_q8_0x4 * a_ptrs[4];
|
|
671
660
|
|
|
672
|
-
|
|
673
|
-
|
|
661
|
+
a_ptrs[0] = a_ptr_start + (y * nb);
|
|
662
|
+
for (int i = 0; i < 3; ++i) {
|
|
663
|
+
a_ptrs[i + 1] = a_ptrs[i] + nb;
|
|
664
|
+
}
|
|
674
665
|
|
|
675
|
-
// Take group of
|
|
676
|
-
for (int64_t x = 0; x <
|
|
666
|
+
// Take group of two block_tx8 structures at each pass of the loop and perform dot product operation
|
|
667
|
+
for (int64_t x = 0; x < anc / 8; x += 2) {
|
|
677
668
|
|
|
678
|
-
|
|
679
|
-
const
|
|
669
|
+
const block_tx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
|
|
670
|
+
const block_tx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
|
|
680
671
|
|
|
681
672
|
// Master FP accumulators
|
|
682
|
-
|
|
683
|
-
|
|
673
|
+
__m512 acc_rows[16];
|
|
674
|
+
for (int i = 0; i < 16; i++) {
|
|
675
|
+
acc_rows[i] = _mm512_setzero_ps();
|
|
676
|
+
}
|
|
684
677
|
|
|
685
678
|
for (int64_t b = 0; b < nb; b++) {
|
|
679
|
+
// Load the sixteen blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
|
|
680
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
|
|
681
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
|
|
682
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
|
|
683
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
|
|
684
|
+
|
|
685
|
+
const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
|
|
686
|
+
const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
|
|
687
|
+
const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
|
|
688
|
+
const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
|
|
689
|
+
|
|
690
|
+
// Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
|
|
691
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
692
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
693
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
694
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
695
|
+
|
|
696
|
+
const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
|
|
697
|
+
const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
|
|
698
|
+
const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
|
|
699
|
+
const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
|
|
700
|
+
|
|
701
|
+
const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
|
|
702
|
+
const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
|
|
703
|
+
const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
|
|
704
|
+
const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
|
|
686
705
|
|
|
687
|
-
//
|
|
688
|
-
const
|
|
689
|
-
|
|
690
|
-
// Load the scale values for the 8 blocks interleaved in block_q4_Kx8
|
|
691
|
-
// col_scale_f32 rearranged so as to multiply with appropriate quants
|
|
692
|
-
const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
|
|
693
|
-
const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
|
|
706
|
+
// 4-bit -> 8-bit - Sign is maintained
|
|
707
|
+
const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
|
|
708
|
+
const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
|
|
694
709
|
|
|
695
|
-
|
|
696
|
-
|
|
710
|
+
const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
|
|
711
|
+
const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
|
|
697
712
|
|
|
698
|
-
const
|
|
699
|
-
|
|
700
|
-
q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
|
|
713
|
+
const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
|
|
714
|
+
const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
|
|
701
715
|
|
|
702
|
-
|
|
703
|
-
|
|
716
|
+
const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
|
|
717
|
+
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
|
|
704
718
|
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
|
|
709
|
-
const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
|
|
710
|
-
const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
|
|
711
|
-
const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
|
|
712
|
-
const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
|
|
713
|
-
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
|
|
719
|
+
// Shuffle pattern one - right side input
|
|
720
|
+
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
|
721
|
+
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
|
714
722
|
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
|
|
718
|
-
const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
|
|
719
|
-
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
|
|
720
|
-
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
|
|
721
|
-
const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
|
|
722
|
-
const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
|
|
723
|
-
const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
|
|
724
|
-
const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
|
|
723
|
+
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
|
724
|
+
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
|
725
725
|
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
|
|
729
|
-
const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
|
|
730
|
-
const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
|
|
731
|
-
const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
|
|
732
|
-
const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
|
|
733
|
-
const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
|
|
734
|
-
const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
|
|
726
|
+
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
|
727
|
+
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
|
735
728
|
|
|
736
|
-
|
|
729
|
+
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
|
730
|
+
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
|
737
731
|
|
|
738
|
-
|
|
739
|
-
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
|
|
740
|
-
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
|
|
741
|
-
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
|
|
742
|
-
const uint32_t uaux_0 = utmp_0[1] & kmask1;
|
|
743
|
-
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
|
|
744
|
-
utmp_0[2] = uaux_0;
|
|
745
|
-
utmp_0[0] &= kmask1;
|
|
732
|
+
// Shuffle pattern two - right side input
|
|
746
733
|
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
|
|
750
|
-
const uint32_t uaux_1 = utmp_1[1] & kmask1;
|
|
751
|
-
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
|
|
752
|
-
utmp_1[2] = uaux_1;
|
|
753
|
-
utmp_1[0] &= kmask1;
|
|
734
|
+
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
|
735
|
+
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
|
754
736
|
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
__m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
|
|
758
|
-
__m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
|
|
737
|
+
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
|
738
|
+
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
|
759
739
|
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
__m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
|
|
763
|
-
__m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
|
|
740
|
+
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
|
741
|
+
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
|
764
742
|
|
|
765
|
-
|
|
766
|
-
|
|
743
|
+
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
|
744
|
+
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
|
767
745
|
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
746
|
+
// Scale values - Load the weight scale values of two block_tx8
|
|
747
|
+
__m512 col_scale_f32;
|
|
748
|
+
if constexpr (
|
|
749
|
+
std::is_same_v<block_tx8, block_q4_0x8> ||
|
|
750
|
+
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
|
751
|
+
col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
|
752
|
+
}
|
|
773
753
|
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
|
|
777
|
-
lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
|
|
754
|
+
// Process LHS in pairs of rows
|
|
755
|
+
for (int rp = 0; rp < 4; rp++) {
|
|
778
756
|
|
|
779
|
-
//
|
|
780
|
-
//
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
757
|
+
// Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
|
758
|
+
// Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
|
|
759
|
+
__m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
|
|
760
|
+
__m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
|
|
761
|
+
__m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
|
|
762
|
+
__m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
|
|
763
|
+
__m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
|
|
764
|
+
__m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
|
|
765
|
+
__m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
|
|
766
|
+
__m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
|
|
767
|
+
__m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
|
|
768
|
+
__m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
|
|
769
|
+
__m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
|
|
770
|
+
__m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
|
|
785
771
|
|
|
772
|
+
__m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
|
|
773
|
+
__m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
|
|
774
|
+
__m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
|
|
775
|
+
__m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
|
|
776
|
+
__m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
|
|
777
|
+
__m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
|
|
778
|
+
__m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
|
|
779
|
+
__m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
|
|
786
780
|
|
|
787
|
-
|
|
788
|
-
__m256i iacc_1 = _mm256_setzero_si256();
|
|
781
|
+
// Shuffle pattern one - left side input
|
|
789
782
|
|
|
790
|
-
|
|
791
|
-
|
|
783
|
+
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
|
784
|
+
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
|
792
785
|
|
|
793
|
-
|
|
794
|
-
|
|
786
|
+
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
|
787
|
+
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
|
795
788
|
|
|
796
|
-
|
|
797
|
-
|
|
789
|
+
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
|
790
|
+
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
|
798
791
|
|
|
799
|
-
|
|
800
|
-
|
|
792
|
+
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
|
793
|
+
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
|
801
794
|
|
|
802
|
-
|
|
795
|
+
// Shuffle pattern two - left side input
|
|
803
796
|
|
|
804
|
-
|
|
805
|
-
|
|
797
|
+
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
|
798
|
+
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
|
806
799
|
|
|
807
|
-
|
|
808
|
-
|
|
800
|
+
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
|
801
|
+
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
|
809
802
|
|
|
810
|
-
|
|
811
|
-
|
|
803
|
+
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
|
804
|
+
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
|
812
805
|
|
|
813
|
-
|
|
814
|
-
|
|
806
|
+
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
807
|
+
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
815
808
|
|
|
816
|
-
|
|
809
|
+
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
810
|
+
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
811
|
+
const __m512i zero = _mm512_setzero_epi32();
|
|
812
|
+
__m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
813
|
+
__m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
814
|
+
__m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
815
|
+
__m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
816
|
+
__m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
817
|
+
__m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
818
|
+
__m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
819
|
+
__m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
817
820
|
|
|
818
|
-
//
|
|
819
|
-
|
|
821
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
822
|
+
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
|
823
|
+
__m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
|
|
824
|
+
__m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
|
|
825
|
+
__m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
|
|
820
826
|
|
|
821
|
-
// Broadcast the bsums of the two sub blocks of the iteration of Q8_K across the vector
|
|
822
|
-
// Multiply-Add with corresponding mins of Q4_Kx8 with bsums
|
|
823
|
-
__m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
|
|
824
|
-
__m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
|
|
825
|
-
q8s = _mm256_bsrli_epi128(q8s, 4);
|
|
826
827
|
|
|
827
|
-
//
|
|
828
|
-
|
|
829
|
-
|
|
828
|
+
// Straighten out to make 4 row vectors
|
|
829
|
+
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
|
|
830
|
+
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
|
|
831
|
+
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
|
|
832
|
+
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
|
|
833
|
+
|
|
834
|
+
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
|
835
|
+
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
|
|
836
|
+
const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
|
|
837
|
+
|
|
838
|
+
// Multiply with appropiate scales and accumulate
|
|
839
|
+
acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
|
|
840
|
+
acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
|
|
841
|
+
acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
|
|
842
|
+
acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
|
|
830
843
|
}
|
|
844
|
+
}
|
|
831
845
|
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
846
|
+
// Store the accumulated values
|
|
847
|
+
for (int i = 0; i < 16; i++) {
|
|
848
|
+
_mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
|
|
849
|
+
}
|
|
850
|
+
}
|
|
851
|
+
}
|
|
852
|
+
|
|
853
|
+
// Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
|
854
|
+
for (; y < nr / 4; y ++) {
|
|
855
|
+
const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
|
|
856
|
+
|
|
857
|
+
// Take group of two block_tx8 structures at each pass of the loop and perform dot product operation
|
|
858
|
+
for (int64_t x = 0; x < anc / 8; x += 2) {
|
|
859
|
+
|
|
860
|
+
const block_tx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
|
|
861
|
+
const block_tx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
|
|
835
862
|
|
|
863
|
+
// Master FP accumulators
|
|
864
|
+
__m512 acc_rows[4];
|
|
865
|
+
for (int i = 0; i < 4; i++) {
|
|
866
|
+
acc_rows[i] = _mm512_setzero_ps();
|
|
836
867
|
}
|
|
837
868
|
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
869
|
+
for (int64_t b = 0; b < nb; b++) {
|
|
870
|
+
// Load the sixteen blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
|
|
871
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
|
|
872
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
|
|
873
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
|
|
874
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
|
|
875
|
+
|
|
876
|
+
const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
|
|
877
|
+
const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
|
|
878
|
+
const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
|
|
879
|
+
const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
|
|
880
|
+
|
|
881
|
+
// Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess
|
|
882
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
883
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
884
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
885
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
886
|
+
|
|
887
|
+
const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
|
|
888
|
+
const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
|
|
889
|
+
const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
|
|
890
|
+
const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
|
|
891
|
+
|
|
892
|
+
const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
|
|
893
|
+
const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
|
|
894
|
+
const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
|
|
895
|
+
const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
|
|
896
|
+
|
|
897
|
+
// 4-bit -> 8-bit - Sign is maintained
|
|
898
|
+
const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
|
|
899
|
+
const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
|
|
900
|
+
|
|
901
|
+
const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
|
|
902
|
+
const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
|
|
903
|
+
|
|
904
|
+
const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
|
|
905
|
+
const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
|
|
906
|
+
|
|
907
|
+
const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
|
|
908
|
+
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
|
|
909
|
+
|
|
910
|
+
// Shuffle pattern one - right side input
|
|
911
|
+
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
|
912
|
+
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
|
913
|
+
|
|
914
|
+
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
|
915
|
+
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
|
916
|
+
|
|
917
|
+
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
|
918
|
+
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
|
919
|
+
|
|
920
|
+
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
|
921
|
+
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
|
922
|
+
|
|
923
|
+
// Shuffle pattern two - right side input
|
|
924
|
+
|
|
925
|
+
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
|
926
|
+
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
|
927
|
+
|
|
928
|
+
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
|
929
|
+
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
|
930
|
+
|
|
931
|
+
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
|
932
|
+
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
|
933
|
+
|
|
934
|
+
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
|
935
|
+
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
|
936
|
+
|
|
937
|
+
|
|
938
|
+
// Scale values - Load the weight scale values of two block_tx8
|
|
939
|
+
__m512 col_scale_f32;
|
|
940
|
+
if constexpr (
|
|
941
|
+
std::is_same_v<block_tx8, block_q4_0x8> ||
|
|
942
|
+
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
|
943
|
+
col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
|
944
|
+
}
|
|
945
|
+
|
|
946
|
+
// Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
|
947
|
+
// Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
|
|
948
|
+
__m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
|
|
949
|
+
__m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
|
|
950
|
+
__m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
|
|
951
|
+
__m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
|
|
952
|
+
__m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
|
|
953
|
+
__m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
|
|
954
|
+
__m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
|
|
955
|
+
__m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
|
|
956
|
+
__m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
|
|
957
|
+
__m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
|
|
958
|
+
__m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
|
|
959
|
+
__m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
|
|
960
|
+
|
|
961
|
+
__m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
|
|
962
|
+
__m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
|
|
963
|
+
__m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
|
|
964
|
+
__m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
|
|
965
|
+
__m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
|
|
966
|
+
__m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
|
|
967
|
+
__m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
|
|
968
|
+
__m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
|
|
969
|
+
|
|
970
|
+
// Shuffle pattern one - left side input
|
|
971
|
+
|
|
972
|
+
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
|
973
|
+
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
|
974
|
+
|
|
975
|
+
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
|
976
|
+
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
|
977
|
+
|
|
978
|
+
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
|
979
|
+
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
|
980
|
+
|
|
981
|
+
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
|
982
|
+
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
|
983
|
+
|
|
984
|
+
// Shuffle pattern two - left side input
|
|
985
|
+
|
|
986
|
+
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
|
987
|
+
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
|
988
|
+
|
|
989
|
+
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
|
990
|
+
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
|
991
|
+
|
|
992
|
+
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
|
993
|
+
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
|
994
|
+
|
|
995
|
+
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
996
|
+
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
997
|
+
|
|
998
|
+
// The values arranged in the shuffle patterns are combined with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and their products added into 32 bit integers within the 32 bit lane
|
|
999
|
+
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
1000
|
+
const __m512i zero = _mm512_setzero_epi32();
|
|
1001
|
+
__m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
1002
|
+
__m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
1003
|
+
__m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
1004
|
+
__m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
1005
|
+
__m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
1006
|
+
__m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
1007
|
+
__m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
1008
|
+
__m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
1009
|
+
|
|
1010
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
1011
|
+
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
|
1012
|
+
__m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
|
|
1013
|
+
__m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
|
|
1014
|
+
__m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
// Straighten out to make 4 row vectors
|
|
1018
|
+
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
|
|
1019
|
+
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
|
|
1020
|
+
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
|
|
1021
|
+
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
|
|
1022
|
+
|
|
1023
|
+
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
|
1024
|
+
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
|
|
1025
|
+
const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
|
|
1026
|
+
|
|
1027
|
+
// Multiply with appropriate scales and accumulate
|
|
1028
|
+
acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
|
|
1029
|
+
acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
|
|
1030
|
+
acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
|
|
1031
|
+
acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
|
|
1032
|
+
}
|
|
1033
|
+
|
|
1034
|
+
// Store the accumulated values
|
|
1035
|
+
for (int i = 0; i < 4; i++) {
|
|
1036
|
+
_mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
|
|
1037
|
+
}
|
|
841
1038
|
}
|
|
842
1039
|
}
|
|
1040
|
+
if (anc != nc) {
|
|
1041
|
+
xstart = anc/8;
|
|
1042
|
+
y = 0;
|
|
1043
|
+
}
|
|
1044
|
+
#endif // __AVX512F__
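The AVX-512 path above ends by converting each 4x8 tile of int32 dot products to float and scaling it with the product of the per-row and per-column deltas via `_mm512_fmadd_ps`. A minimal scalar sketch of that step follows; the function and array names are illustrative only and do not appear in the source.

```cpp
#include <cstdint>

// Scalar model of the per-tile scaling step: each int32 dot product is
// weighted by the product of its row delta (from the Q8_0 block) and its
// column delta (from the interleaved 4-bit block), then accumulated.
static void scale_and_accumulate_tile(float         acc[4][8],
                                      const int32_t iacc[4][8],
                                      const float   row_scale[4],
                                      const float   col_scale[8]) {
    for (int r = 0; r < 4; r++) {
        for (int c = 0; c < 8; c++) {
            acc[r][c] += (float) iacc[r][c] * row_scale[r] * col_scale[c];
        }
    }
}
```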
|
|
843
1045
|
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
1046
|
+
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
|
1047
|
+
|
|
1048
|
+
for (; y < anr / 4; y += 4) {
|
|
1049
|
+
const block_q8_0x4 * a_ptrs[4];
|
|
1050
|
+
|
|
1051
|
+
a_ptrs[0] = a_ptr_start + (y * nb);
|
|
1052
|
+
for (int i = 0; i < 3; ++i) {
|
|
1053
|
+
a_ptrs[i + 1] = a_ptrs[i] + nb;
|
|
1054
|
+
}
|
|
1055
|
+
|
|
1056
|
+
// Take group of eight block_tx8 structures at each pass of the loop and perform dot product operation
|
|
1057
|
+
for (int64_t x = xstart; x < nc / 8; x++) {
|
|
1058
|
+
|
|
1059
|
+
const block_tx8 * b_ptr = b_ptr_start + (x * b_nb);
|
|
1060
|
+
|
|
1061
|
+
// Master FP accumulators
|
|
1062
|
+
__m256 acc_rows[16];
|
|
1063
|
+
for (int i = 0; i < 16; i++) {
|
|
1064
|
+
acc_rows[i] = _mm256_setzero_ps();
|
|
1065
|
+
}
|
|
1066
|
+
|
|
1067
|
+
for (int64_t b = 0; b < nb; b++) {
|
|
1068
|
+
// Load the eight blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
|
|
1069
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
|
|
1070
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
|
|
1071
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
|
|
1072
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
|
|
1073
|
+
|
|
1074
|
+
// Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
|
|
1075
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
1076
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
1077
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
1078
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
1079
|
+
|
|
1080
|
+
// 4-bit -> 8-bit - Sign is maintained
|
|
1081
|
+
const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
|
|
1082
|
+
const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
|
|
1083
|
+
|
|
1084
|
+
const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
|
|
1085
|
+
const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
|
|
1086
|
+
|
|
1087
|
+
const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
|
|
1088
|
+
const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
|
|
1089
|
+
|
|
1090
|
+
const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
|
|
1091
|
+
const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
|
|
1092
|
+
|
|
1093
|
+
// Shuffle pattern one - right side input
|
|
1094
|
+
const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
|
|
1095
|
+
const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
|
|
1096
|
+
|
|
1097
|
+
const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
|
|
1098
|
+
const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
|
|
1099
|
+
|
|
1100
|
+
const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
|
|
1101
|
+
const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
|
|
1102
|
+
|
|
1103
|
+
const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
|
|
1104
|
+
const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
|
|
1105
|
+
|
|
1106
|
+
// Shuffle pattern two - right side input
|
|
1107
|
+
|
|
1108
|
+
const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
|
|
1109
|
+
const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
|
|
1110
|
+
|
|
1111
|
+
const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
|
|
1112
|
+
const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
|
|
1113
|
+
|
|
1114
|
+
const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
|
|
1115
|
+
const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
|
|
1116
|
+
|
|
1117
|
+
const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
|
|
1118
|
+
const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
|
|
1119
|
+
|
|
1120
|
+
// Scale values - Load the weight scale values of block_tx8
|
|
1121
|
+
__m256 col_scale_f32;
|
|
1122
|
+
if constexpr (
|
|
1123
|
+
std::is_same_v<block_tx8, block_q4_0x8> ||
|
|
1124
|
+
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
|
1125
|
+
col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
|
|
1126
|
+
}
|
|
1127
|
+
|
|
1128
|
+
// Process LHS in groups of four
|
|
1129
|
+
for (int rp = 0; rp < 4; rp++) {
|
|
1130
|
+
// Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
|
1131
|
+
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
|
|
1132
|
+
__m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
|
|
1133
|
+
__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
|
|
1134
|
+
__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
|
|
1135
|
+
__m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
|
|
1136
|
+
__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
|
|
1137
|
+
__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
|
|
1138
|
+
__m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
|
|
1139
|
+
__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
|
|
1140
|
+
__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
|
|
1141
|
+
__m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
|
|
1142
|
+
__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
|
|
1143
|
+
__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
|
|
1144
|
+
|
|
1145
|
+
// Shuffle pattern one - left side input
|
|
1146
|
+
const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
|
1147
|
+
const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
|
1148
|
+
|
|
1149
|
+
const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
|
1150
|
+
const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
|
1151
|
+
|
|
1152
|
+
const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
|
1153
|
+
const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
|
1154
|
+
|
|
1155
|
+
const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
|
1156
|
+
const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
|
1157
|
+
|
|
1158
|
+
// Shuffle pattern two - left side input
|
|
1159
|
+
const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
|
1160
|
+
const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
|
1161
|
+
|
|
1162
|
+
const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
|
1163
|
+
const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
|
1164
|
+
|
|
1165
|
+
const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
|
1166
|
+
const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
|
1167
|
+
|
|
1168
|
+
const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
1169
|
+
const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
1170
|
+
|
|
1171
|
+
// The values arranged in the shuffle patterns are combined with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and their products added into 32 bit integers within the 32 bit lane
|
|
1172
|
+
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
1173
|
+
const __m256i zero = _mm256_setzero_si256();
|
|
1174
|
+
__m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
|
|
1175
|
+
__m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
|
|
1176
|
+
__m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
|
|
1177
|
+
__m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
|
|
1178
|
+
__m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
|
|
1179
|
+
__m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
|
|
1180
|
+
__m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
|
|
1181
|
+
__m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
|
|
1182
|
+
|
|
1183
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
1184
|
+
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
|
1185
|
+
__m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
|
|
1186
|
+
__m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
|
|
1187
|
+
__m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
|
|
1188
|
+
|
|
1189
|
+
// Straighten out to make 4 row vectors
|
|
1190
|
+
__m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
|
|
1191
|
+
__m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
|
|
1192
|
+
__m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
|
|
1193
|
+
__m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
|
|
1194
|
+
|
|
1195
|
+
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
|
1196
|
+
const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
|
|
1197
|
+
|
|
1198
|
+
// Multiply with appropriate scales and accumulate
|
|
1199
|
+
acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
|
|
1200
|
+
acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
|
|
1201
|
+
acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
|
|
1202
|
+
acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
|
|
1203
|
+
}
|
|
1204
|
+
}
|
|
1205
|
+
|
|
1206
|
+
// Store the accumulated values
|
|
1207
|
+
for (int i = 0; i < 16; i++) {
|
|
1208
|
+
_mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
|
|
1209
|
+
}
|
|
1210
|
+
}
|
|
1211
|
+
}
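The loop that just closed and the one that follows implement a two-level row tiling over `block_q8_0x4` groups. Below is a minimal sketch of that control flow only, assuming (as the surrounding code suggests) that `nr` is the number of LHS rows and `anr` is `nr` rounded down to a multiple of 16; the loop bodies are elided.

```cpp
// Sketch of the row tiling used by the kernel above and below.
static void tile_rows(int nr, int anr /* assumed: nr - nr % 16 */) {
    int y = 0;
    for (; y < anr / 4; y += 4) {
        // 16-row macro tile: four block_q8_0x4 groups reuse the loaded RHS
    }
    for (; y < nr / 4; y++) {
        // 4-row tail tile: one block_q8_0x4 group at a time
    }
}
```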
|
|
1212
|
+
|
|
1213
|
+
// Take one block_q8_0x4 structure at each pass of the loop and perform dot product operation
|
|
1214
|
+
for (; y < nr / 4; y ++) {
|
|
1215
|
+
const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
|
|
1216
|
+
|
|
1217
|
+
// Load the eight blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
|
|
1218
|
+
for (int64_t x = xstart; x < nc / 8; x++) {
|
|
1219
|
+
const block_tx8 * b_ptr = b_ptr_start + (x * b_nb);
|
|
1220
|
+
|
|
1221
|
+
// Master FP accumulators
|
|
1222
|
+
__m256 acc_rows[4];
|
|
1223
|
+
for (int i = 0; i < 4; i++) {
|
|
1224
|
+
acc_rows[i] = _mm256_setzero_ps();
|
|
1225
|
+
}
|
|
1226
|
+
|
|
1227
|
+
for (int64_t b = 0; b < nb; b++) {
|
|
1228
|
+
// Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
|
|
1229
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
|
|
1230
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
|
|
1231
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
|
|
1232
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
|
|
1233
|
+
|
|
1234
|
+
// Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
|
|
1235
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
1236
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
1237
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
1238
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
1239
|
+
|
|
1240
|
+
// 4-bit -> 8-bit - Sign is maintained
|
|
1241
|
+
const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
|
|
1242
|
+
const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
|
|
1243
|
+
|
|
1244
|
+
const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
|
|
1245
|
+
const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
|
|
1246
|
+
|
|
1247
|
+
const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
|
|
1248
|
+
const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
|
|
1249
|
+
|
|
1250
|
+
const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
|
|
1251
|
+
const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
|
|
1252
|
+
|
|
1253
|
+
// Shuffle pattern one - right side input
|
|
1254
|
+
const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
|
|
1255
|
+
const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
|
|
1256
|
+
|
|
1257
|
+
const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
|
|
1258
|
+
const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
|
|
1259
|
+
|
|
1260
|
+
const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
|
|
1261
|
+
const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
|
|
1262
|
+
|
|
1263
|
+
const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
|
|
1264
|
+
const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
|
|
1265
|
+
|
|
1266
|
+
// Shuffle pattern two - right side input
|
|
1267
|
+
|
|
1268
|
+
const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
|
|
1269
|
+
const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
|
|
1270
|
+
|
|
1271
|
+
const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
|
|
1272
|
+
const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
|
|
1273
|
+
|
|
1274
|
+
const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
|
|
1275
|
+
const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
|
|
1276
|
+
|
|
1277
|
+
const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
|
|
1278
|
+
const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
|
|
1279
|
+
|
|
1280
|
+
// Scale values - Load the weight scale values of block_tx8
|
|
1281
|
+
__m256 col_scale_f32;
|
|
1282
|
+
if constexpr (
|
|
1283
|
+
std::is_same_v<block_tx8, block_q4_0x8> ||
|
|
1284
|
+
std::is_same_v<block_tx8, block_iq4_nlx8>) {
|
|
1285
|
+
col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
|
|
1286
|
+
}
|
|
1287
|
+
|
|
1288
|
+
// Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
|
1289
|
+
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
|
|
1290
|
+
__m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
|
|
1291
|
+
__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
|
|
1292
|
+
__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
|
|
1293
|
+
__m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
|
|
1294
|
+
__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
|
|
1295
|
+
__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
|
|
1296
|
+
__m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
|
|
1297
|
+
__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
|
|
1298
|
+
__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
|
|
1299
|
+
__m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
|
|
1300
|
+
__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
|
|
1301
|
+
__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
|
|
1302
|
+
|
|
1303
|
+
// Shuffle pattern one - left side input
|
|
1304
|
+
|
|
1305
|
+
const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
|
1306
|
+
const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
|
1307
|
+
|
|
1308
|
+
const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
|
1309
|
+
const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
|
1310
|
+
|
|
1311
|
+
const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
|
1312
|
+
const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
|
1313
|
+
|
|
1314
|
+
const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
|
1315
|
+
const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
|
1316
|
+
|
|
1317
|
+
// Shuffle pattern two - left side input
|
|
1318
|
+
|
|
1319
|
+
const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
|
1320
|
+
const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
|
1321
|
+
|
|
1322
|
+
const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
|
1323
|
+
const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
|
1324
|
+
|
|
1325
|
+
const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
|
1326
|
+
const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
|
1327
|
+
|
|
1328
|
+
const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
1329
|
+
const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
1330
|
+
|
|
1331
|
+
// The values arranged in the shuffle patterns are combined with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and their products added into 32 bit integers within the 32 bit lane
|
|
1332
|
+
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
1333
|
+
const __m256i zero = _mm256_setzero_si256();
|
|
1334
|
+
__m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
|
|
1335
|
+
__m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
|
|
1336
|
+
__m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
|
|
1337
|
+
__m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
|
|
1338
|
+
__m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
|
|
1339
|
+
__m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
|
|
1340
|
+
__m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
|
|
1341
|
+
__m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
|
|
1342
|
+
|
|
1343
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
1344
|
+
__m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
|
1345
|
+
__m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
|
|
1346
|
+
__m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
|
|
1347
|
+
__m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
|
|
1348
|
+
|
|
1349
|
+
|
|
1350
|
+
// Straighten out to make 4 row vectors
|
|
1351
|
+
__m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
|
|
1352
|
+
__m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
|
|
1353
|
+
__m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
|
|
1354
|
+
__m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
|
|
1355
|
+
|
|
1356
|
+
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
|
1357
|
+
const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
|
|
1358
|
+
|
|
1359
|
+
// Multiply with appropriate scales and accumulate
|
|
1360
|
+
acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
|
|
1361
|
+
acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
|
|
1362
|
+
acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
|
|
1363
|
+
acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
|
|
1364
|
+
}
|
|
1365
|
+
|
|
1366
|
+
// Store the accumulated values
|
|
1367
|
+
for (int i = 0; i < 4; i++) {
|
|
1368
|
+
_mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
|
|
1369
|
+
}
|
|
1370
|
+
}
|
|
1371
|
+
}
|
|
1372
|
+
}
|
|
1373
|
+
|
|
1374
|
+
#endif // defined(__AVX2__) || defined(__AVX512F__)
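The templates above share one kernel body across several interleaved RHS block layouts; an `if constexpr` test on `block_tx8` only changes how the eight per-column scales are loaded. Here is a stand-alone sketch of that dispatch shape, using hypothetical block types rather than the real `block_q4_0x8`/`block_iq4_nlx8` definitions.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for the interleaved RHS blocks; only the delta
// array that matters for scale loading is modelled.
struct fake_q4_0x8   { uint16_t d[8]; };
struct fake_iq4_nlx8 { uint16_t d[8]; };

// Mirrors the dispatch shape used above: the scale load is chosen at
// compile time from the block type, the rest of the kernel is shared.
template <typename block_tx8>
static void load_col_scales(const block_tx8 & blk, float out[8]) {
    if constexpr (std::is_same_v<block_tx8, fake_q4_0x8> ||
                  std::is_same_v<block_tx8, fake_iq4_nlx8>) {
        for (int i = 0; i < 8; i++) {
            out[i] = (float) blk.d[i];  // the real code converts fp16 -> fp32 here
        }
    }
}
```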
|
|
1375
|
+
|
|
1376
|
+
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
1377
|
+
#if defined(__AVX2__) || defined(__AVX512F__)
|
|
1378
|
+
{
|
|
1379
|
+
// Lookup table to convert signed nibbles to signed bytes
|
|
1380
|
+
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
|
|
1381
|
+
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
|
1382
|
+
|
|
1383
|
+
gemv_q4_b32_8x8_q8_0_lut_avx<block_q4_0x8>(n, s, bs, vx, vy, nr, nc, signextendlut);
|
|
1384
|
+
|
|
1385
|
+
return;
|
|
1386
|
+
}
|
|
849
1387
|
#endif
|
|
1388
|
+
|
|
1389
|
+
ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
|
850
1390
|
}
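The `signextendlut` built in this wrapper is a 16-entry table applied with `_mm256_shuffle_epi8`: it widens each 4-bit quant, read as a two's-complement nibble, to a signed byte. A scalar model of the same mapping (the helper name is illustrative only):

```cpp
#include <cstdint>

// Scalar model of the sign-extension lookup: nibbles 0..7 stay positive,
// nibbles 8..15 map to -8..-1, matching the _mm_set_epi8 table above.
static inline int8_t sign_extend_nibble(uint8_t q4) {
    static const int8_t lut[16] = {
        0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1
    };
    return lut[q4 & 0x0F];
}
```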
|
|
851
1391
|
|
|
-void
+void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
 const int qk = QK_K;
 const int nb = n / qk;
 const int ncols_interleaved = 8;
 const int blocklen = 8;
+static const uint32_t kmask1 = 0x3f3f3f3f;
+static const uint32_t kmask2 = 0x0f0f0f0f;
+static const uint32_t kmask3 = 0x03030303;

 assert (n % qk == 0);
 assert (nc % ncols_interleaved == 0);
|
@@ -872,21 +1415,18 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 // Lookup table to convert signed nibbles to signed bytes
 __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
 signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
-// Shuffle masks to rearrange delta values to multiply with appropriate scales
+// Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
 __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
+__m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
 // Permute mask used for easier vector processing at later stages
 __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);

-
-const
-
-//Mask to get appropriate scales
-__m128i scalemask1 = _mm_set_epi8(14,14,6,6,12,12,4,4,10,10,2,2,8,8,0,0);
-__m128i scalemask2 = _mm_set_epi8(15,15,7,7,13,13,5,5,11,11,3,3,9,9,1,1);
+// Mask to extract nibbles from bytes
+const __m256i m4b = _mm256_set1_epi8(0x0F);

 int64_t b_nb = n / QK_K;

-const
+const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;
 const block_q8_K * a_ptr_start = (const block_q8_K *)vy;

 // Process Q8_K blocks one by one
|
@@ -895,11 +1435,11 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 // Pointers to LHS blocks of block_q8_K format
 const block_q8_K * a_ptr = a_ptr_start + (y * nb);

-// Take group of eight interleaved
-for(int64_t x = 0; x < nc / 8; x++) {
+// Take group of eight interleaved block_q4_K structures at each pass of the loop and perform dot product operation
+for (int64_t x = 0; x < nc / 8; x++) {

 // Pointers to RHS blocks
-const
+const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);

 // Master FP accumulators
 __m256 acc_row = _mm256_setzero_ps();
|
|
@@ -907,10 +1447,10 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 for (int64_t b = 0; b < nb; b++) {

-// Load and convert to FP32
+// Load and convert to FP32 scale from block_q8_K
 const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));

-// Load the
+// Load the scale values for the 8 blocks interleaved in block_q4_Kx8
 // col_scale_f32 rearranged so as to multiply with appropriate quants
 const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
 const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
|
|
@@ -918,11 +1458,15 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 __m256i iacc_b = _mm256_setzero_si256();
 __m256i iacc_min_b = _mm256_setzero_si256();

-
-
+const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
+__m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
+q8s = _mm256_permute2f128_si256(q8s, q8s, 0);

-
-
+// Processes two sub blocks from each Q4_K in each iteration
+for (int sb = 0; sb < QK_K / 64; sb++) {
+
+// Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
 const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
 const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
 const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
|
|
@@ -931,982 +1475,482 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|
|
931
1475
|
const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
|
|
932
1476
|
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
|
|
933
1477
|
|
|
934
|
-
//
|
|
935
|
-
// Values of the
|
|
936
|
-
const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0,
|
|
937
|
-
const __m256i
|
|
938
|
-
const __m256i
|
|
939
|
-
const __m256i
|
|
940
|
-
|
|
941
|
-
const __m256i
|
|
942
|
-
const __m256i
|
|
943
|
-
const __m256i
|
|
944
|
-
const __m256i rhs_vec_4567_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 6), m3b); //B64(0-7) B65(0-7) B66(0-7) B67(0-7)
|
|
945
|
-
|
|
946
|
-
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m3b); //B00(8-15) B01(8-15) B02(8-15) B03(8-15)
|
|
947
|
-
const __m256i rhs_vec_0123_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 2), m3b); //B20(8-15) B21(8-15) B22(8-15) B23(8-15)
|
|
948
|
-
const __m256i rhs_vec_0123_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m3b); //B40(8-15) B41(8-15) B42(8-15) B43(8-15)
|
|
949
|
-
const __m256i rhs_vec_0123_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 6), m3b); //B60(8-15) B61(8-15) B62(8-15) B63(8-15)
|
|
950
|
-
|
|
951
|
-
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m3b); //B04(8-15) B05(8-15) B06(8-15) B07(8-15)
|
|
952
|
-
const __m256i rhs_vec_4567_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 2), m3b); //B24(8-15) B25(8-15) B26(8-15) B27(8-15)
|
|
953
|
-
const __m256i rhs_vec_4567_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m3b); //B44(8-15) B45(8-15) B46(8-15) B47(8-15)
|
|
954
|
-
const __m256i rhs_vec_4567_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 6), m3b); //B64(8-15) B65(8-15) B66(8-15) B67(8-15)
|
|
955
|
-
|
|
956
|
-
// Values of the 1st,3rd,5th,7th sub blocks of eight block_q2_K structures for the sb loop
|
|
957
|
-
const __m256i rhs_vec_0123_10 = _mm256_and_si256(rhs_raw_vec_0123_2, m3b); //B10(0-7) B11(0-7) B12(0-7) B13(0-7)
|
|
958
|
-
const __m256i rhs_vec_0123_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 2), m3b); //B30(0-7) B31(0-7) B32(0-7) B33(0-7)
|
|
959
|
-
const __m256i rhs_vec_0123_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m3b); //B50(0-7) B51(0-7) B52(0-7) B53(0-7)
|
|
960
|
-
const __m256i rhs_vec_0123_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 6), m3b); //B70(0-7) B71(0-7) B72(0-7) B73(0-7)
|
|
961
|
-
|
|
962
|
-
const __m256i rhs_vec_4567_10 = _mm256_and_si256(rhs_raw_vec_4567_2, m3b); //B14(0-7) B15(0-7) B16(0-7) B17(0-7)
|
|
963
|
-
const __m256i rhs_vec_4567_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 2), m3b); //B34(0-7) B35(0-7) B36(0-7) B37(0-7)
|
|
964
|
-
const __m256i rhs_vec_4567_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m3b); //B54(0-7) B55(0-7) B56(0-7) B57(0-7)
|
|
965
|
-
const __m256i rhs_vec_4567_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 6), m3b); //B74(0-7) B75(0-7) B76(0-7) B77(0-7)
|
|
966
|
-
|
|
967
|
-
const __m256i rhs_vec_0123_11 = _mm256_and_si256(rhs_raw_vec_0123_3, m3b); //B10(8-15) B11(8-15) B12(8-15) B13(8-15)
|
|
968
|
-
const __m256i rhs_vec_0123_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 2), m3b); //B30(8-15) B31(8-15) B32(8-15) B33(8-15)
|
|
969
|
-
const __m256i rhs_vec_0123_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m3b); //B50(8-15) B51(8-15) B52(8-15) B53(8-15)
|
|
970
|
-
const __m256i rhs_vec_0123_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 6), m3b); //B70(8-15) B71(8-15) B72(8-15) B73(8-15)
|
|
971
|
-
|
|
972
|
-
const __m256i rhs_vec_4567_11 = _mm256_and_si256(rhs_raw_vec_4567_3, m3b); //B14(8-15) B15(8-15) B16(8-15) B17(8-15)
|
|
973
|
-
const __m256i rhs_vec_4567_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 2), m3b); //B34(8-15) B35(8-15) B36(8-15) B37(8-15)
|
|
974
|
-
const __m256i rhs_vec_4567_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m3b); //B54(8-15) B55(8-15) B56(8-15) B57(8-15)
|
|
975
|
-
const __m256i rhs_vec_4567_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 6), m3b); //B74(8-15) B75(8-15) B76(8-15) B77(8-15)
|
|
976
|
-
|
|
977
|
-
//Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together
|
|
978
|
-
//s00 m00 s01 m01 s10 m10 s11 m11 s20 m20 s21 m21 s30 m30 s31 m31 s40 m40 s41 m41 s50 m50 s51 m51 s60 m60 s61 m61 s70 m70 s71 m71
|
|
979
|
-
|
|
980
|
-
const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64));
|
|
981
|
-
const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64));
|
|
982
|
-
const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64));
|
|
983
|
-
const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64));
|
|
984
|
-
|
|
985
|
-
// Extract scales which is lower half from mins_and_scales
|
|
986
|
-
const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse);
|
|
987
|
-
const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse);
|
|
988
|
-
const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse);
|
|
989
|
-
const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse);
|
|
990
|
-
|
|
991
|
-
// Extract mins which is upper half from mins_and_scales
|
|
992
|
-
const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse));
|
|
993
|
-
const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse));
|
|
994
|
-
const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse));
|
|
995
|
-
const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse));
|
|
996
|
-
|
|
997
|
-
// Scales of sub blocks in the sb loop
|
|
998
|
-
// Scales of the 0th sub block from each super block
|
|
999
|
-
__m128i scales_rearrange_0 = _mm_shuffle_epi8(scales_01, scalemask1);
|
|
1000
|
-
__m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
|
|
1001
|
-
|
|
1002
|
-
// Scales of the 1st sub block from each super block
|
|
1003
|
-
__m128i scales_rearrange_1 = _mm_shuffle_epi8(scales_01, scalemask2);
|
|
1004
|
-
__m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
|
|
1005
|
-
|
|
1006
|
-
// Scales of the 2nd sub block from each super block
|
|
1007
|
-
__m128i scales_rearrange_2 = _mm_shuffle_epi8(scales_23, scalemask1);
|
|
1008
|
-
__m256i scales_2 = _mm256_cvtepu8_epi16(scales_rearrange_2);
|
|
1009
|
-
|
|
1010
|
-
// Scales of the 3rd sub block from each super block
|
|
1011
|
-
__m128i scales_rearrange_3 = _mm_shuffle_epi8(scales_23, scalemask2);
|
|
1012
|
-
__m256i scales_3 = _mm256_cvtepu8_epi16(scales_rearrange_3);
|
|
1013
|
-
|
|
1014
|
-
// Scales of the 4th sub block from each super block
|
|
1015
|
-
__m128i scales_rearrange_4 = _mm_shuffle_epi8(scales_45, scalemask1);
|
|
1016
|
-
__m256i scales_4 = _mm256_cvtepu8_epi16(scales_rearrange_4);
|
|
1017
|
-
|
|
1018
|
-
// Scales of the 5th sub block from each super block
|
|
1019
|
-
__m128i scales_rearrange_5 = _mm_shuffle_epi8(scales_45, scalemask2);
|
|
1020
|
-
__m256i scales_5 = _mm256_cvtepu8_epi16(scales_rearrange_5);
|
|
1021
|
-
|
|
1022
|
-
// Scales of the 6th sub block from each super block
|
|
1023
|
-
__m128i scales_rearrange_6 = _mm_shuffle_epi8(scales_67, scalemask1);
|
|
1024
|
-
__m256i scales_6 = _mm256_cvtepu8_epi16(scales_rearrange_6);
|
|
1025
|
-
|
|
1026
|
-
// Scales of the 7th sub block from each super block
|
|
1027
|
-
__m128i scales_rearrange_7 = _mm_shuffle_epi8(scales_67, scalemask2);
|
|
1028
|
-
__m256i scales_7 = _mm256_cvtepu8_epi16(scales_rearrange_7);
|
|
1029
|
-
|
|
1030
|
-
// Load the sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
|
|
1031
|
-
__m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 128)));
|
|
1032
|
-
__m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 128)));
|
|
1033
|
-
__m256i lhs_vec_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 128)));
|
|
1034
|
-
__m256i lhs_vec_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 128)));
|
|
1035
|
-
__m256i lhs_vec_4 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 64 + sb * 128)));
|
|
1036
|
-
__m256i lhs_vec_5 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 80 + sb * 128)));
|
|
1037
|
-
- __m256i lhs_vec_6 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 96 + sb * 128)));
- __m256i lhs_vec_7 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 112 + sb * 128)));
-
- lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0);
- lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0);
- lhs_vec_2 = _mm256_permute2f128_si256(lhs_vec_2, lhs_vec_2, 0);
- lhs_vec_3 = _mm256_permute2f128_si256(lhs_vec_3, lhs_vec_3, 0);
- lhs_vec_4 = _mm256_permute2f128_si256(lhs_vec_4, lhs_vec_4, 0);
- lhs_vec_5 = _mm256_permute2f128_si256(lhs_vec_5, lhs_vec_5, 0);
- lhs_vec_6 = _mm256_permute2f128_si256(lhs_vec_6, lhs_vec_6, 0);
- lhs_vec_7 = _mm256_permute2f128_si256(lhs_vec_7, lhs_vec_7, 0);
-
- __m256i iacc_0 = _mm256_setzero_si256();
- __m256i iacc_1 = _mm256_setzero_si256();
- __m256i iacc_2 = _mm256_setzero_si256();
- __m256i iacc_3 = _mm256_setzero_si256();
- __m256i iacc_4 = _mm256_setzero_si256();
- __m256i iacc_5 = _mm256_setzero_si256();
- __m256i iacc_6 = _mm256_setzero_si256();
- __m256i iacc_7 = _mm256_setzero_si256();
-
- // Dot product done within 32 bit lanes and accumulated in the same vector
- // First done for 0th sub block and then for seven (1st - 7th) other sub blocks processed for each sb (sb < QK_K/128 loop) // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
- // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
- // B0(8-11) B4(8-11) B1(8-11) B5(8-11) B2(8-11) B6(8-11) B3(8-11) B7(8-11) with A0(8-11)
- // B0(12-15) B4(12-15) B1(12-15) B5(12-15) B2(12-15) B6(12-15) B3(12-15) B7(12-15) with A0(12-15)
-
- iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
- iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
-
- iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
- iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
-
- iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
-
- iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
- iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
-
- iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
- iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
-
- iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
-
- iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_20 ,_mm256_shuffle_epi32(rhs_vec_4567_20, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 0)));
- iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_20, 177) ,rhs_vec_4567_20, 170), _mm256_shuffle_epi32(lhs_vec_2, 85)));
-
- iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_21 ,_mm256_shuffle_epi32(rhs_vec_4567_21, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 170)));
- iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_21, 177) ,rhs_vec_4567_21, 170), _mm256_shuffle_epi32(lhs_vec_2, 255)));
-
- iacc_2 = _mm256_madd_epi16(iacc_2, scales_2);
-
- iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_30 ,_mm256_shuffle_epi32(rhs_vec_4567_30, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 0)));
- iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_30, 177) ,rhs_vec_4567_30, 170), _mm256_shuffle_epi32(lhs_vec_3, 85)));
-
- iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_31 ,_mm256_shuffle_epi32(rhs_vec_4567_31, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 170)));
- iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_31, 177) ,rhs_vec_4567_31, 170), _mm256_shuffle_epi32(lhs_vec_3, 255)));
-
- iacc_3 = _mm256_madd_epi16(iacc_3, scales_3);
-
- iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_40 ,_mm256_shuffle_epi32(rhs_vec_4567_40, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 0)));
- iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_40, 177) ,rhs_vec_4567_40, 170), _mm256_shuffle_epi32(lhs_vec_4, 85)));
-
- iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_41 ,_mm256_shuffle_epi32(rhs_vec_4567_41, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 170)));
- iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_41, 177) ,rhs_vec_4567_41, 170), _mm256_shuffle_epi32(lhs_vec_4, 255)));
-
- iacc_4 = _mm256_madd_epi16(iacc_4, scales_4);
-
- iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_50 ,_mm256_shuffle_epi32(rhs_vec_4567_50, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 0)));
- iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_50, 177) ,rhs_vec_4567_50, 170), _mm256_shuffle_epi32(lhs_vec_5, 85)));
-
- iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_51 ,_mm256_shuffle_epi32(rhs_vec_4567_51, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 170)));
- iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_51, 177) ,rhs_vec_4567_51, 170), _mm256_shuffle_epi32(lhs_vec_5, 255)));
-
- iacc_5 = _mm256_madd_epi16(iacc_5, scales_5);
-
- iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_60 ,_mm256_shuffle_epi32(rhs_vec_4567_60, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 0)));
- iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_60, 177) ,rhs_vec_4567_60, 170), _mm256_shuffle_epi32(lhs_vec_6, 85)));
-
- iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_61 ,_mm256_shuffle_epi32(rhs_vec_4567_61, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 170)));
- iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_61, 177) ,rhs_vec_4567_61, 170), _mm256_shuffle_epi32(lhs_vec_6, 255)));
-
- iacc_6 = _mm256_madd_epi16(iacc_6, scales_6);
-
- iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_70 ,_mm256_shuffle_epi32(rhs_vec_4567_70, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 0)));
- iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_70, 177) ,rhs_vec_4567_70, 170), _mm256_shuffle_epi32(lhs_vec_7, 85)));
-
- iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_71 ,_mm256_shuffle_epi32(rhs_vec_4567_71, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 170)));
- iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_71, 177) ,rhs_vec_4567_71, 170), _mm256_shuffle_epi32(lhs_vec_7, 255)));
-
- iacc_7 = _mm256_madd_epi16(iacc_7, scales_7);
-
- // Accumulate the iacc value for one sb
- __m256i iacc_sb = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_0, iacc_1), _mm256_add_epi32(iacc_2, iacc_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_4, iacc_5), _mm256_add_epi32(iacc_6, iacc_7)));
-
- __m128i q8sums = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + sb * 8));
- __m256i q8s = _mm256_castsi128_si256(q8sums);
- q8s= _mm256_permute2f128_si256(q8s, q8s, 0);
-
- // Broadcast the bsums of the two corresponding subblocks of q8_k
- // Multiply-Add with corresponding mins of Q2_Kx8 with bsums
- __m256i iacc_min_sb_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 0), mins_01);
- __m256i iacc_min_sb_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 85), mins_23);
- __m256i iacc_min_sb_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 170), mins_45);
- __m256i iacc_min_sb_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 255), mins_67);
-
- __m256i iacc_min_sb = _mm256_add_epi32(_mm256_add_epi32(iacc_min_sb_01, iacc_min_sb_23), _mm256_add_epi32(iacc_min_sb_45,iacc_min_sb_67));
-
- // Accumulate for the complete block
- iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
- iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
- }
-
- //Multiply-Add with scale values for complete super block
- acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
- acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
- }
- // Accumulated output values permuted so as to be stored in appropriate order post accumulation
- acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
- _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
- }
- }
- #else
-
- ggml_gemv_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
-
- #endif
- }
-
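The removed AVX2 body above, like the reworked kernels added later in this diff, vectorizes one K-quant identity: a super-block dot product is a scale-weighted sum of sub-block dot products minus a dmin-weighted sum of sub-block mins times the Q8 block sums (bsums). That is why the kernels keep two accumulators (iacc_b/acc_row and iacc_min_b/acc_min_rows) and subtract them at the end. Below is a minimal scalar sketch of that identity for orientation only; the names are illustrative and the code is not part of the package.

#include <stdint.h>

// Scalar reference for the accumulator split used by the SIMD kernels.
// Assumptions: scales[]/mins[] are the already-unpacked per-sub-block values,
// dot[j] is the integer dot product of sub-block j with the Q8 quants,
// bsum[j] is the Q8 sub-block sum of quants, d/dmin/d8 are the block deltas.
static float kquant_superblock_dot(float d, float dmin, float d8,
                                   const int32_t *scales, const int32_t *mins,
                                   const int32_t *dot, const int32_t *bsum, int nsub) {
    int32_t acc = 0, acc_min = 0;
    for (int j = 0; j < nsub; ++j) {
        acc     += scales[j] * dot[j];   // corresponds to the iacc_b / acc_row path
        acc_min += mins[j]   * bsum[j];  // corresponds to the iacc_min_b / acc_min_rows path
    }
    return d * d8 * (float) acc - dmin * d8 * (float) acc_min;
}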
- void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
- const int qk = QK8_0;
- const int nb = n / qk;
- const int ncols_interleaved = 8;
- const int blocklen = 8;
-
- assert (n % qk == 0);
- assert (nr % 4 == 0);
- assert (nc % ncols_interleaved == 0);
-
- UNUSED(s);
- UNUSED(bs);
- UNUSED(vx);
- UNUSED(vy);
- UNUSED(nr);
- UNUSED(nc);
- UNUSED(nb);
- UNUSED(ncols_interleaved);
- UNUSED(blocklen);
-
- #if defined(__AVX2__) || defined(__AVX512F__)
- {
- const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
- const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
- int64_t b_nb = n / QK4_0;
- int64_t y = 0;
- // Mask to mask out nibbles from packed bytes
- const __m256i m4b = _mm256_set1_epi8(0x0F);
- const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
- // Lookup table to convert signed nibbles to signed bytes
- __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
- signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
- // Permute mask used for easier vector processing at later stages
- __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
- int64_t xstart = 0;
- int anr = nr - nr%16; // Used to align nr with boundary of 16
- #ifdef __AVX512F__
- int anc = nc - nc%16; // Used to align nc with boundary of 16
- // Mask to mask out nibbles from packed bytes expanded to 512 bit length
- const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
- // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length
- __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
-
- // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
- for (; y < anr / 4; y += 4) {
-
- const block_q8_0x4 * a_ptrs[4];
-
- a_ptrs[0] = a_ptr_start + (y * nb);
- for (int i = 0; i < 3; ++i) {
- a_ptrs[i + 1] = a_ptrs[i] + nb;
- }
|
|
1217
|
-
|
|
1218
|
-
// Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
|
|
1219
|
-
for (int64_t x = 0; x < anc / 8; x += 2) {
|
|
1220
|
-
|
|
1221
|
-
const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
|
|
1222
|
-
const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
|
|
1223
|
-
|
|
1224
|
-
// Master FP accumulators
|
|
1225
|
-
__m512 acc_rows[16];
|
|
1226
|
-
for (int i = 0; i < 16; i++) {
|
|
1227
|
-
acc_rows[i] = _mm512_setzero_ps();
|
|
1228
|
-
}
|
|
1229
|
-
|
|
1230
|
-
for (int64_t b = 0; b < nb; b++) {
|
|
1231
|
-
// Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
|
|
1232
|
-
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
|
|
1233
|
-
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
|
|
1234
|
-
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
|
|
1235
|
-
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
|
|
1236
|
-
|
|
1237
|
-
const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
|
|
1238
|
-
const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
|
|
1239
|
-
const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
|
|
1240
|
-
const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
|
|
1241
|
-
|
|
1242
|
-
// Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
|
|
1243
|
-
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
1244
|
-
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
1245
|
-
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
1246
|
-
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
1247
|
-
|
|
1248
|
-
const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
|
|
1249
|
-
const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
|
|
1250
|
-
const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
|
|
1251
|
-
const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
|
|
1252
|
-
|
|
1253
|
-
const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
|
|
1254
|
-
const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
|
|
1255
|
-
const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
|
|
1256
|
-
const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
|
|
1257
|
-
|
|
1258
|
-
// 4-bit -> 8-bit - Sign is maintained
|
|
1259
|
-
const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
|
|
1260
|
-
const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
|
|
1261
|
-
|
|
1262
|
-
const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
|
|
1263
|
-
const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
|
|
1264
|
-
|
|
1265
|
-
const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
|
|
1266
|
-
const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
|
|
1267
|
-
|
|
1268
|
-
const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
|
|
1269
|
-
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
|
|
1270
|
-
|
|
1271
|
-
// Shuffle pattern one - right side input
|
|
1272
|
-
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
|
1273
|
-
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
|
1274
|
-
|
|
1275
|
-
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
|
1276
|
-
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
|
1277
|
-
|
|
1278
|
-
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
|
1279
|
-
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
|
1280
|
-
|
|
1281
|
-
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
|
1282
|
-
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
|
1283
|
-
|
|
1284
|
-
// Shuffle pattern two - right side input
|
|
1285
|
-
|
|
1286
|
-
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
|
1287
|
-
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
|
1288
|
-
|
|
1289
|
-
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
|
1290
|
-
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
|
1291
|
-
|
|
1292
|
-
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
|
1293
|
-
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
|
1294
|
-
|
|
1295
|
-
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
|
1296
|
-
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
|
1297
|
-
|
|
1298
|
-
// Scale values - Load the weight scale values of two block_q4_0x8
|
|
1299
|
-
const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
|
1300
|
-
|
|
1301
|
-
// Process LHS in pairs of rows
|
|
1302
|
-
for (int rp = 0; rp < 4; rp++) {
|
|
1303
|
-
|
|
1304
|
-
// Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
|
|
1305
|
-
// Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
|
|
1306
|
-
__m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
|
|
1307
|
-
__m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
|
|
1308
|
-
__m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
|
|
1309
|
-
__m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
|
|
1310
|
-
__m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
|
|
1311
|
-
__m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
|
|
1312
|
-
__m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
|
|
1313
|
-
__m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
|
|
1314
|
-
__m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
|
|
1315
|
-
__m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
|
|
1316
|
-
__m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
|
|
1317
|
-
__m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
|
|
1318
|
-
|
|
1319
|
-
__m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
|
|
1320
|
-
__m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
|
|
1321
|
-
__m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
|
|
1322
|
-
__m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
|
|
1323
|
-
__m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
|
|
1324
|
-
__m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
|
|
1325
|
-
__m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
|
|
1326
|
-
__m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
|
|
1327
|
-
|
|
1328
|
-
// Shuffle pattern one - left side input
|
|
1329
|
-
|
|
1330
|
-
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
|
1331
|
-
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
|
1332
|
-
|
|
1333
|
-
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
|
1334
|
-
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
|
1335
|
-
|
|
1336
|
-
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
|
1337
|
-
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
|
1338
|
-
|
|
1339
|
-
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
|
1340
|
-
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
|
1341
|
-
|
|
1342
|
-
// Shuffle pattern two - left side input
|
|
1343
|
-
|
|
1344
|
-
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
|
1345
|
-
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
|
1346
|
-
|
|
1347
|
-
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
|
1348
|
-
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
|
1349
|
-
|
|
1350
|
-
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
|
1351
|
-
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
|
1352
|
-
|
|
1353
|
-
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
1354
|
-
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
1355
|
-
|
|
1356
|
-
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
1357
|
-
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
1358
|
-
const __m512i zero = _mm512_setzero_epi32();
|
|
1359
|
-
__m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
1360
|
-
__m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
1361
|
-
__m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
|
|
1362
|
-
__m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
|
|
1363
|
-
__m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
1364
|
-
__m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
1365
|
-
__m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
|
|
1366
|
-
__m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
|
|
1367
|
-
|
|
1368
|
-
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
|
1369
|
-
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
|
|
1370
|
-
__m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
|
|
1371
|
-
__m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
|
|
1372
|
-
__m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
|
|
1373
|
-
|
|
1374
|
-
|
|
1375
|
-
// Straighten out to make 4 row vectors
|
|
1376
|
-
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
|
|
1377
|
-
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
|
|
1378
|
-
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
|
|
1379
|
-
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
|
|
1380
|
-
|
|
1381
|
-
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
|
1382
|
-
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
|
|
1383
|
-
const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
|
|
1384
|
-
|
|
1385
|
-
// Multiply with appropiate scales and accumulate
|
|
1386
|
-
acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
|
|
1387
|
-
acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
|
|
1388
|
-
acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
|
|
1389
|
-
acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
|
|
1390
|
-
}
|
|
1391
|
-
}
|
|
1392
|
-
|
|
1393
|
-
// Store the accumulated values
|
|
1394
|
-
for (int i = 0; i < 16; i++) {
|
|
1395
|
-
_mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
|
|
1396
|
-
}
|
|
1397
|
-
}
|
|
1398
|
-
}
|
|
1399
|
-
// Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation
|
|
1400
|
-
for (; y < nr / 4; y ++) {
|
|
1401
|
-
|
|
1402
|
-
const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
|
|
1403
|
-
|
|
1404
|
-
// Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
|
|
1405
|
-
for (int64_t x = 0; x < anc / 8; x += 2) {
|
|
1406
|
-
|
|
1407
|
-
const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
|
|
1408
|
-
const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
|
|
1409
|
-
|
|
1410
|
-
// Master FP accumulators
|
|
1411
|
-
__m512 acc_rows[4];
|
|
1412
|
-
for (int i = 0; i < 4; i++) {
|
|
1413
|
-
acc_rows[i] = _mm512_setzero_ps();
|
|
1414
|
-
}
|
|
1415
|
-
|
|
1416
|
-
for (int64_t b = 0; b < nb; b++) {
|
|
1417
|
-
// Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
|
|
1418
|
-
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
|
|
1419
|
-
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
|
|
1420
|
-
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
|
|
1421
|
-
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
|
|
1422
|
-
|
|
1423
|
-
const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
|
|
1424
|
-
const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
|
|
1425
|
-
const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
|
|
1426
|
-
const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
|
|
1427
|
-
|
|
1428
|
-
// Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess
|
|
1429
|
-
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
|
1430
|
-
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
|
1431
|
-
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
|
1432
|
-
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
|
1433
|
-
|
|
1434
|
-
const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
|
|
1435
|
-
const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
|
|
1436
|
-
const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
|
|
1437
|
-
const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
|
|
+ // 4-bit -> 8-bit
+ // Values of the first sub block of eight block_q4_K structures for the sb loop
+ const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
+ const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
+ const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
+ const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
+ const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
+ const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
+ const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
+ const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);

-
- const
- const
- const
+ // Values of the second sub block of eight block_q4_K structures when sb = 1
+ const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
+ const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
+ const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
+ const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
+ const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
+ const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
+ const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
+ const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);

-
- const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
- const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
+ uint32_t utmp_0[4], utmp_1[4];

-
-
+ // Scales and Mins of corresponding sub blocks from different Q8_K structures are stored together
+ // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
+ memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
+ utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux_0 = utmp_0[1] & kmask1;
+ utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
+ utmp_0[2] = uaux_0;
+ utmp_0[0] &= kmask1;

-
-
+ // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
+ memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
+ utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
+ const uint32_t uaux_1 = utmp_1[1] & kmask1;
+ utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
+ utmp_1[2] = uaux_1;
+ utmp_1[0] &= kmask1;

-
- const
+ // Scales of first sub block in the sb loop
+ const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
+ __m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
+ __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);

- //
-
-
+ // Scales of second sub block in the sb loop
+ __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
+ __m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
+ __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);

-
-
+ // Mins of first and second sub block of Q4_K block are arranged side by side
+ __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));

-
-
+ // Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
+ __m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
+ __m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
+ __m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
+ __m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));

-
-
+ lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
+ lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
+ lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
+ lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);

- //
+ // Dot product done within 32 bit lanes and accumulated in the same vector
+ // First done for first sub block and thenn for second sub block in each sb
+ // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
+ // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
+ // ...........................................................................
+ // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)

- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)

-
-
+ __m256i iacc_0 = _mm256_setzero_si256();
+ __m256i iacc_1 = _mm256_setzero_si256();

-
-
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));

-
-
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));

+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));

-
-
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));

-
- // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
- __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
- __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
- __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
- __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
- __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
- __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
- __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
- __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
- __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
- __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
- __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
- __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+ iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);

-
-
- __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
- __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
- __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
- __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
- __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
- __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));

-
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));

-
-
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));

-
-
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));

-
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+ iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);

-
-
+ // Accumulate the iacc value for one sb
+ __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);

- //
+ // Broadcast the bsums of the two sub blocks of the iteration of Q8_K across the vector
+ // Multiply-Add with corresponding mins of Q4_Kx8 with bsums
+ __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
+ __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
+ q8s = _mm256_bsrli_epi128(q8s, 4);

-
-
+ // Accumulate for the complete block
+ iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
+ iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
+ }

-
-
+ // Multiply-Add with scale values for the complete super block
+ acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
+ acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);

-
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+ }

-
-
+ // Accumulated output values permuted so as to be stored in appropriate order post accumulation
+ acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
+ _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
+ }
+ }

-
-
-
-
-
-
-
- __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
- __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
- __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
- __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
+ #else
+ UNUSED(kmask1);
+ UNUSED(kmask2);
+ UNUSED(kmask3);
+ ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+ #endif
+ }

-
-
-
-
- __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+ #if defined(__AVX2__)
+ __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_iq4nl));
+ signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);

+ gemv_q4_b32_8x8_q8_0_lut_avx<block_iq4_nlx8>(n, s, bs, vx, vy, nr, nc, signextendlut);

-
-
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
+ return;
+ #endif

-
-
- const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+ ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+ }
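As the added lines show, the new ggml_gemv_iq4_nl_8x8_q8_0 is now a thin wrapper: it builds a lookup table from kvalues_iq4nl and defers to the shared gemv_q4_b32_8x8_q8_0_lut_avx template. For orientation, the sketch below (illustrative names only, not code from the package) shows the nibble-to-byte LUT expansion that such AVX2 paths are built on; the table has to be replicated into both 128-bit lanes because _mm256_shuffle_epi8 indexes within each lane independently.

#include <immintrin.h>
#include <stdint.h>

// Illustrative sketch: decode 32 packed 4-bit codes into bytes via a 16-entry LUT
// (e.g. a table like kvalues_iq4nl), producing the low-nibble and high-nibble halves.
static inline void decode_nibbles_lut_avx2(const uint8_t *packed32, const int8_t *lut16,
                                           __m256i *lo, __m256i *hi) {
    const __m256i m4b = _mm256_set1_epi8(0x0F);
    __m256i lut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *) lut16));
    lut = _mm256_permute2f128_si256(lut, lut, 0);              // same 16 bytes in both lanes
    const __m256i raw = _mm256_loadu_si256((const __m256i *) packed32);
    *lo = _mm256_shuffle_epi8(lut, _mm256_and_si256(raw, m4b));                         // low nibbles
    *hi = _mm256_shuffle_epi8(lut, _mm256_and_si256(_mm256_srli_epi16(raw, 4), m4b));   // high nibbles
}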

-
-
-
-
-
- }
+ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+ const int qk = QK_K;
+ const int nb = n / qk;
+ const int ncols_interleaved = 8;
+ const int blocklen = 8;

-
-
- _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
- }
- }
- }
- if (anc != nc) {
- xstart = anc/8;
- y = 0;
- }
- #endif // __AVX512F__
+ assert (n % qk == 0);
+ assert (nc % ncols_interleaved == 0);

-
+ UNUSED(s);
+ UNUSED(bs);
+ UNUSED(vx);
+ UNUSED(vy);
+ UNUSED(nr);
+ UNUSED(nc);
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);

-
-
+ #if defined(__AVX2__)
+ // Lookup table to convert signed nibbles to signed bytes
+ __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+ signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+ // Shuffle masks to rearrange delta values to multiply with appropriate scales
+ __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
+ // Permute mask used for easier vector processing at later stages
+ __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);

-
-
- a_ptrs[i + 1] = a_ptrs[i] + nb;
- }
+ const __m256i m3b = _mm256_set1_epi8(3);
+ const __m128i m4b_sse = _mm_set1_epi8(0xF);

-
-
+ //Mask to get appropriate scales
+ __m128i scalemask1 = _mm_set_epi8(14,14,6,6,12,12,4,4,10,10,2,2,8,8,0,0);
+ __m128i scalemask2 = _mm_set_epi8(15,15,7,7,13,13,5,5,11,11,3,3,9,9,1,1);

-
+ int64_t b_nb = n / QK_K;

-
-
- for (int i = 0; i < 16; i++) {
- acc_rows[i] = _mm256_setzero_ps();
- }
+ const block_q2_Kx8 * b_ptr_start = (const block_q2_Kx8 *)vx;
+ const block_q8_K * a_ptr_start = (const block_q8_K *)vy;

-
-
- const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
- const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
- const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
- const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+ // Process Q8_K blocks one by one
+ for (int64_t y = 0; y < nr; y++) {

-
-
- const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
- const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
- const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+ // Pointers to LHS blocks of block_q8_K format
+ const block_q8_K * a_ptr = a_ptr_start + (y * nb);

-
-
- const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+ // Take group of eight interleaved block_q2_K structures at each pass of the loop and perform dot product operation
+ for(int64_t x = 0; x < nc / 8; x++) {

-
-
+ // Pointers to RHS blocks
+ const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb);

-
-
+ // Master FP accumulators
+ __m256 acc_row = _mm256_setzero_ps();
+ __m256 acc_min_rows = _mm256_setzero_ps();

-
- const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+ for (int64_t b = 0; b < nb; b++) {

-
-
- const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+ // Load and convert to FP32 delta from block_q8_K
+ const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));

-
-
+ // Load the delta values for the 8 blocks interleaved in block_q2_Kx8
+ // col_scale_f32 rearranged so as to multiply with appropriate quants
+ const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
+ const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);

-
-
+ __m256i iacc_b = _mm256_setzero_si256();
+ __m256i iacc_min_b = _mm256_setzero_si256();

-
-
+ // Processes eight sub blocks from each Q2_K in each iteration
+ for(int sb = 0; sb < QK_K / 128; sb++) {

- //
+ // Load the eight block_q2_K for eight sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+ const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
+ const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
+ const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
+ const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
+ const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
|
|
1706
|
+
const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
|
|
1707
|
+
const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
|
|
1708
|
+
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
|
|
1649
1709
|
|
|
1650
|
-
|
|
1651
|
-
|
|
1710
|
+
// 2-bit -> 8-bit
|
|
1711
|
+
// Values of the 0th,2nd,4th,6th sub blocks of eight block_q2_K structures for the sb loop
|
|
1712
|
+
const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m3b); //B00(0-7) B01(0-7) B02(0-7) B03(0-7)
|
|
1713
|
+
const __m256i rhs_vec_0123_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 2), m3b); //B20(0-7) B21(0-7) B22(0-7) B23(0-7)
|
|
1714
|
+
const __m256i rhs_vec_0123_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m3b); //B40(0-7) B41(0-7) B42(0-7) B43(0-7)
|
|
1715
|
+
const __m256i rhs_vec_0123_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 6), m3b); //B60(0-7) B61(0-7) B62(0-7) B63(0-7)
|
|
1652
1716
|
|
|
1653
|
-
const __m256i
|
|
1654
|
-
const __m256i
|
|
1717
|
+
const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m3b); //B04(0-7) B05(0-7) B06(0-7) B07(0-7)
|
|
1718
|
+
const __m256i rhs_vec_4567_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 2), m3b); //B24(0-7) B25(0-7) B26(0-7) B27(0-7)
|
|
1719
|
+
const __m256i rhs_vec_4567_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m3b); //B44(0-7) B45(0-7) B46(0-7) B47(0-7)
|
|
1720
|
+
const __m256i rhs_vec_4567_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 6), m3b); //B64(0-7) B65(0-7) B66(0-7) B67(0-7)
|
|
1655
1721
|
|
|
1656
|
-
const __m256i
|
|
1657
|
-
const __m256i
|
|
1722
|
+
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m3b); //B00(8-15) B01(8-15) B02(8-15) B03(8-15)
|
|
1723
|
+
const __m256i rhs_vec_0123_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 2), m3b); //B20(8-15) B21(8-15) B22(8-15) B23(8-15)
|
|
1724
|
+
const __m256i rhs_vec_0123_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m3b); //B40(8-15) B41(8-15) B42(8-15) B43(8-15)
|
|
1725
|
+
const __m256i rhs_vec_0123_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 6), m3b); //B60(8-15) B61(8-15) B62(8-15) B63(8-15)
|
|
1658
1726
|
|
|
1659
|
-
const __m256i
|
|
1660
|
-
const __m256i
|
|
1727
|
+
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m3b); //B04(8-15) B05(8-15) B06(8-15) B07(8-15)
|
|
1728
|
+
const __m256i rhs_vec_4567_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 2), m3b); //B24(8-15) B25(8-15) B26(8-15) B27(8-15)
|
|
1729
|
+
const __m256i rhs_vec_4567_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m3b); //B44(8-15) B45(8-15) B46(8-15) B47(8-15)
|
|
1730
|
+
const __m256i rhs_vec_4567_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 6), m3b); //B64(8-15) B65(8-15) B66(8-15) B67(8-15)
|
|
1661
1731
|
|
|
1662
|
-
//
|
|
1663
|
-
const
|
|
1732
|
+
// Values of the 1st,3rd,5th,7th sub blocks of eight block_q2_K structures for the sb loop
|
|
1733
|
+
const __m256i rhs_vec_0123_10 = _mm256_and_si256(rhs_raw_vec_0123_2, m3b); //B10(0-7) B11(0-7) B12(0-7) B13(0-7)
|
|
1734
|
+
const __m256i rhs_vec_0123_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 2), m3b); //B30(0-7) B31(0-7) B32(0-7) B33(0-7)
|
|
1735
|
+
const __m256i rhs_vec_0123_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m3b); //B50(0-7) B51(0-7) B52(0-7) B53(0-7)
|
|
1736
|
+
const __m256i rhs_vec_0123_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 6), m3b); //B70(0-7) B71(0-7) B72(0-7) B73(0-7)
|
|
1664
1737
|
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
|
|
1668
|
-
|
|
1669
|
-
__m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
|
|
1670
|
-
__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
|
|
1671
|
-
__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
|
|
1672
|
-
__m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
|
|
1673
|
-
__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
|
|
1674
|
-
__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
|
|
1675
|
-
__m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
|
|
1676
|
-
__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
|
|
1677
|
-
__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
|
|
1678
|
-
__m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
|
|
1679
|
-
__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
|
|
1680
|
-
__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
|
|
1738
|
+
const __m256i rhs_vec_4567_10 = _mm256_and_si256(rhs_raw_vec_4567_2, m3b); //B14(0-7) B15(0-7) B16(0-7) B17(0-7)
|
|
1739
|
+
const __m256i rhs_vec_4567_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 2), m3b); //B34(0-7) B35(0-7) B36(0-7) B37(0-7)
|
|
1740
|
+
const __m256i rhs_vec_4567_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m3b); //B54(0-7) B55(0-7) B56(0-7) B57(0-7)
|
|
1741
|
+
const __m256i rhs_vec_4567_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 6), m3b); //B74(0-7) B75(0-7) B76(0-7) B77(0-7)
|
|
1681
1742
|
|
|
1682
|
-
|
|
1683
|
-
|
|
1684
|
-
|
|
1743
|
+
const __m256i rhs_vec_0123_11 = _mm256_and_si256(rhs_raw_vec_0123_3, m3b); //B10(8-15) B11(8-15) B12(8-15) B13(8-15)
|
|
1744
|
+
const __m256i rhs_vec_0123_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 2), m3b); //B30(8-15) B31(8-15) B32(8-15) B33(8-15)
|
|
1745
|
+
const __m256i rhs_vec_0123_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m3b); //B50(8-15) B51(8-15) B52(8-15) B53(8-15)
|
|
1746
|
+
const __m256i rhs_vec_0123_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 6), m3b); //B70(8-15) B71(8-15) B72(8-15) B73(8-15)
|
|
1685
1747
|
|
|
1686
|
-
|
|
1687
|
-
|
|
1748
|
+
const __m256i rhs_vec_4567_11 = _mm256_and_si256(rhs_raw_vec_4567_3, m3b); //B14(8-15) B15(8-15) B16(8-15) B17(8-15)
|
|
1749
|
+
const __m256i rhs_vec_4567_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 2), m3b); //B34(8-15) B35(8-15) B36(8-15) B37(8-15)
|
|
1750
|
+
const __m256i rhs_vec_4567_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m3b); //B54(8-15) B55(8-15) B56(8-15) B57(8-15)
|
|
1751
|
+
const __m256i rhs_vec_4567_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 6), m3b); //B74(8-15) B75(8-15) B76(8-15) B77(8-15)
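// Editorial sketch (not part of the patch): the m3b mask and the 2/4/6-bit
// shifts above are the vector form of the scalar unpack below — every byte of
// the interleaved qs stream carries four 2-bit quants, and each shift selects
// one of the four sub-block "planes" at once for 32 bytes.
static inline void unpack_four_q2(unsigned char byte, int out[4]) {
    for (int shift = 0; shift < 4; ++shift) {
        out[shift] = (byte >> (2 * shift)) & 3;   // values in [0, 3]
    }
}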

-
-
+//Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together
+//s00 m00 s01 m01 s10 m10 s11 m11 s20 m20 s21 m21 s30 m30 s31 m31 s40 m40 s41 m41 s50 m50 s51 m51 s60 m60 s61 m61 s70 m70 s71 m71

-
-
+const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64));
+const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64));
+const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64));
+const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64));

-
-
-
+// Extract scales which is lower half from mins_and_scales
+const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse);
+const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse);
+const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse);
+const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse);

-
-
+// Extract mins which is upper half from mins_and_scales
+const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse));
+const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse));
+const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse));
+const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse));
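// Editorial sketch (not part of the patch): each byte of the Q2_K scales array
// packs a 4-bit scale in the low nibble and a 4-bit min in the high nibble;
// the m4b_sse mask and the 4-bit shift above perform the same split for
// sixteen packed bytes at a time before widening them to 16-bit lanes.
static inline void split_scale_min(unsigned char packed, int *scale, int *min) {
    *scale = packed & 0x0F;   // low nibble  -> sub-block scale
    *min   = packed >> 4;     // high nibble -> sub-block min
}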

-
-
+// Scales of sub blocks in the sb loop
+// Scales of the 0th sub block from each super block
+__m128i scales_rearrange_0 = _mm_shuffle_epi8(scales_01, scalemask1);
+__m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);

-
-
+// Scales of the 1st sub block from each super block
+__m128i scales_rearrange_1 = _mm_shuffle_epi8(scales_01, scalemask2);
+__m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);

-
-
-
-__m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
-__m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
-__m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
-__m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
-__m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
-__m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
-__m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
-__m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
+// Scales of the 2nd sub block from each super block
+__m128i scales_rearrange_2 = _mm_shuffle_epi8(scales_23, scalemask1);
+__m256i scales_2 = _mm256_cvtepu8_epi16(scales_rearrange_2);

-
-
-
-__m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
-__m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+// Scales of the 3rd sub block from each super block
+__m128i scales_rearrange_3 = _mm_shuffle_epi8(scales_23, scalemask2);
+__m256i scales_3 = _mm256_cvtepu8_epi16(scales_rearrange_3);

-
-
-
-__m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
-__m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+// Scales of the 4th sub block from each super block
+__m128i scales_rearrange_4 = _mm_shuffle_epi8(scales_45, scalemask1);
+__m256i scales_4 = _mm256_cvtepu8_epi16(scales_rearrange_4);

-
-
+// Scales of the 5th sub block from each super block
+__m128i scales_rearrange_5 = _mm_shuffle_epi8(scales_45, scalemask2);
+__m256i scales_5 = _mm256_cvtepu8_epi16(scales_rearrange_5);

-
-
-
-acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
-acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
-}
-}
+// Scales of the 6th sub block from each super block
+__m128i scales_rearrange_6 = _mm_shuffle_epi8(scales_67, scalemask1);
+__m256i scales_6 = _mm256_cvtepu8_epi16(scales_rearrange_6);

-
-
-
-}
-}
-}
+// Scales of the 7th sub block from each super block
+__m128i scales_rearrange_7 = _mm_shuffle_epi8(scales_67, scalemask2);
+__m256i scales_7 = _mm256_cvtepu8_epi16(scales_rearrange_7);

-
-
+// Load the sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
+__m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 128)));
+__m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 128)));
+__m256i lhs_vec_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 128)));
+__m256i lhs_vec_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 128)));
+__m256i lhs_vec_4 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 64 + sb * 128)));
+__m256i lhs_vec_5 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 80 + sb * 128)));
+__m256i lhs_vec_6 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 96 + sb * 128)));
+__m256i lhs_vec_7 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 112 + sb * 128)));

-
+lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0);
+lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0);
+lhs_vec_2 = _mm256_permute2f128_si256(lhs_vec_2, lhs_vec_2, 0);
+lhs_vec_3 = _mm256_permute2f128_si256(lhs_vec_3, lhs_vec_3, 0);
+lhs_vec_4 = _mm256_permute2f128_si256(lhs_vec_4, lhs_vec_4, 0);
+lhs_vec_5 = _mm256_permute2f128_si256(lhs_vec_5, lhs_vec_5, 0);
+lhs_vec_6 = _mm256_permute2f128_si256(lhs_vec_6, lhs_vec_6, 0);
+lhs_vec_7 = _mm256_permute2f128_si256(lhs_vec_7, lhs_vec_7, 0);

-
-
+__m256i iacc_0 = _mm256_setzero_si256();
+__m256i iacc_1 = _mm256_setzero_si256();
+__m256i iacc_2 = _mm256_setzero_si256();
+__m256i iacc_3 = _mm256_setzero_si256();
+__m256i iacc_4 = _mm256_setzero_si256();
+__m256i iacc_5 = _mm256_setzero_si256();
+__m256i iacc_6 = _mm256_setzero_si256();
+__m256i iacc_7 = _mm256_setzero_si256();

-
+// Dot product done within 32 bit lanes and accumulated in the same vector
+// First done for 0th sub block and then for seven (1st - 7th) other sub blocks processed for each sb (sb < QK_K/128 loop) // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
+// B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
+// B0(8-11) B4(8-11) B1(8-11) B5(8-11) B2(8-11) B6(8-11) B3(8-11) B7(8-11) with A0(8-11)
+// B0(12-15) B4(12-15) B1(12-15) B5(12-15) B2(12-15) B6(12-15) B3(12-15) B7(12-15) with A0(12-15)

-
-
-for (int i = 0; i < 4; i++) {
-acc_rows[i] = _mm256_setzero_ps();
-}
+iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
+iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));

-
-
-const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
-const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
-const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
-const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
+iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));

-
-const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
-const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
-const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
-const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
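// Editorial sketch (standalone helper, assumes only <immintrin.h>; not part of
// the patch): the iacc_0 sequence above is the usual AVX2 u8 x s8 dot-product
// idiom — _mm256_maddubs_epi16 multiplies the unsigned 2-bit quants with the
// signed q8 bytes and sums adjacent pairs into 16-bit lanes, after which
// _mm256_madd_epi16 applies the 16-bit sub-block scales and widens the result
// to 32-bit partial sums. With quants in [0,3] the 16-bit pair sums cannot
// saturate.
#include <immintrin.h>
static inline __m256i dot_u8s8_scaled(__m256i quants_u8, __m256i act_s8, __m256i scale_s16) {
    __m256i pairs = _mm256_maddubs_epi16(quants_u8, act_s8); // 16 x 16-bit pair sums
    return _mm256_madd_epi16(pairs, scale_s16);              // 8 x 32-bit scaled sums
}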

-
-
-const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
+iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));

-
-
+iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
+iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));

-
-const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);

-
-
+iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_20 ,_mm256_shuffle_epi32(rhs_vec_4567_20, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 0)));
+iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_20, 177) ,rhs_vec_4567_20, 170), _mm256_shuffle_epi32(lhs_vec_2, 85)));

-
-
-const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_21 ,_mm256_shuffle_epi32(rhs_vec_4567_21, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 170)));
+iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_21, 177) ,rhs_vec_4567_21, 170), _mm256_shuffle_epi32(lhs_vec_2, 255)));

-
-const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+iacc_2 = _mm256_madd_epi16(iacc_2, scales_2);

-
-
+iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_30 ,_mm256_shuffle_epi32(rhs_vec_4567_30, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 0)));
+iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_30, 177) ,rhs_vec_4567_30, 170), _mm256_shuffle_epi32(lhs_vec_3, 85)));

-
-
+iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_31 ,_mm256_shuffle_epi32(rhs_vec_4567_31, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 170)));
+iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_31, 177) ,rhs_vec_4567_31, 170), _mm256_shuffle_epi32(lhs_vec_3, 255)));

-
+iacc_3 = _mm256_madd_epi16(iacc_3, scales_3);

-
-
+iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_40 ,_mm256_shuffle_epi32(rhs_vec_4567_40, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 0)));
+iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_40, 177) ,rhs_vec_4567_40, 170), _mm256_shuffle_epi32(lhs_vec_4, 85)));

-
-
+iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_41 ,_mm256_shuffle_epi32(rhs_vec_4567_41, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 170)));
+iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_41, 177) ,rhs_vec_4567_41, 170), _mm256_shuffle_epi32(lhs_vec_4, 255)));

-
-const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+iacc_4 = _mm256_madd_epi16(iacc_4, scales_4);

-
-
+iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_50 ,_mm256_shuffle_epi32(rhs_vec_4567_50, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 0)));
+iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_50, 177) ,rhs_vec_4567_50, 170), _mm256_shuffle_epi32(lhs_vec_5, 85)));

-
-
+iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_51 ,_mm256_shuffle_epi32(rhs_vec_4567_51, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 170)));
+iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_51, 177) ,rhs_vec_4567_51, 170), _mm256_shuffle_epi32(lhs_vec_5, 255)));

-
-// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
-__m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
-__m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
-__m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
-__m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
-__m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
-__m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
-__m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
-__m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
-__m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
-__m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
-__m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
-__m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+iacc_5 = _mm256_madd_epi16(iacc_5, scales_5);

-
+iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_60 ,_mm256_shuffle_epi32(rhs_vec_4567_60, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 0)));
+iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_60, 177) ,rhs_vec_4567_60, 170), _mm256_shuffle_epi32(lhs_vec_6, 85)));

-
-
+iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_61 ,_mm256_shuffle_epi32(rhs_vec_4567_61, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 170)));
+iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_61, 177) ,rhs_vec_4567_61, 170), _mm256_shuffle_epi32(lhs_vec_6, 255)));

-
-const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+iacc_6 = _mm256_madd_epi16(iacc_6, scales_6);

-
-
+iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_70 ,_mm256_shuffle_epi32(rhs_vec_4567_70, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 0)));
+iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_70, 177) ,rhs_vec_4567_70, 170), _mm256_shuffle_epi32(lhs_vec_7, 85)));

-
-
+iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_71 ,_mm256_shuffle_epi32(rhs_vec_4567_71, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 170)));
+iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_71, 177) ,rhs_vec_4567_71, 170), _mm256_shuffle_epi32(lhs_vec_7, 255)));

-
+iacc_7 = _mm256_madd_epi16(iacc_7, scales_7);

-
-
+// Accumulate the iacc value for one sb
+__m256i iacc_sb = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_0, iacc_1), _mm256_add_epi32(iacc_2, iacc_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_4, iacc_5), _mm256_add_epi32(iacc_6, iacc_7)));

-
-
+__m128i q8sums = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + sb * 8));
+__m256i q8s = _mm256_castsi128_si256(q8sums);
+q8s= _mm256_permute2f128_si256(q8s, q8s, 0);

-
-
+// Broadcast the bsums of the two corresponding subblocks of q8_k
+// Multiply-Add with corresponding mins of Q2_Kx8 with bsums
+__m256i iacc_min_sb_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 0), mins_01);
+__m256i iacc_min_sb_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 85), mins_23);
+__m256i iacc_min_sb_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 170), mins_45);
+__m256i iacc_min_sb_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 255), mins_67);

-
-const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+__m256i iacc_min_sb = _mm256_add_epi32(_mm256_add_epi32(iacc_min_sb_01, iacc_min_sb_23), _mm256_add_epi32(iacc_min_sb_45,iacc_min_sb_67));

-//
-
-
-
-__m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
-__m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
-__m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
-__m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
-__m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
-__m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
-__m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
+// Accumulate for the complete block
+iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
+iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
+}

-
-
-
-
-
+//Multiply-Add with scale values for complete super block
+acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
+acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
+}
+// Accumulated output values permuted so as to be stored in appropriate order post accumulation
+acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
+_mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
+}
+}
+#else

+ggml_gemv_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);

-
-
-__m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
-__m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
-__m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+#endif
+}
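// Editorial sketch (standalone scalar reference; flat arrays stand in for the
// packed and interleaved blocks, so every name below is illustrative only):
// per output element the GEMV above accumulates
//   sum_b d8*d2 * sum_sb scale_sb * (q2 . q8)  -  sum_b d8*dmin * sum_sb min_sb * bsum_sb
// which is exactly what the two _mm256_fmadd_ps calls and the final
// _mm256_sub_ps(acc_row, acc_min_rows) compute for eight columns at once.
static float q2k_q8k_dot_ref(int nb,
                             const float *d2, const float *dmin,   // per super block (RHS)
                             const unsigned char *scales,          // nb*16: low nibble scale, high nibble min
                             const unsigned char *q2,              // nb*256: quants already unpacked to 0..3
                             const float *d8,                      // per super block (LHS)
                             const signed char *q8,                // nb*256: q8 activations
                             const short *bsums) {                 // nb*16: per-16 sums of q8
    float acc = 0.0f, acc_min = 0.0f;
    for (int b = 0; b < nb; ++b) {
        int isum = 0, imin = 0;
        for (int sb = 0; sb < 16; ++sb) {
            const int sc = scales[b * 16 + sb] & 0x0F;
            const int mn = scales[b * 16 + sb] >> 4;
            imin += mn * bsums[b * 16 + sb];
            int s = 0;
            for (int k = 0; k < 16; ++k)
                s += q2[b * 256 + sb * 16 + k] * q8[b * 256 + sb * 16 + k];
            isum += sc * s;
        }
        acc     += d8[b] * d2[b]   * (float) isum;   // acc_row path
        acc_min += d8[b] * dmin[b] * (float) imin;   // acc_min_rows path
    }
    return acc - acc_min;
}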

-
-
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__) || defined(__AVX512F__)
+{
+// Lookup table to convert signed nibbles to signed bytes
+__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);

-
-acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
-acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
-acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
-acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
-}
+gemm_q4_b32_8x8_q8_0_lut_avx<block_q4_0x8>(n, s, bs, vx, vy, nr, nc, signextendlut);

-// Store the accumulated values
-for (int i = 0; i < 4; i++) {
-_mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
-}
-}
-}
return;
}
+#endif // defined(__AVX2__) || defined(__AVX512F__)

-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

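// Editorial sketch (standalone demo, compile with -mavx2; not part of the
// patch): the signextendlut built above is a 16-entry table replicated into
// both 128-bit lanes, and _mm256_shuffle_epi8 uses each low nibble as an index
// into it, mapping 0..7 -> 0..7 and 8..15 -> -8..-1 (4-bit two's-complement
// values widened to signed bytes). How the repacked Q4_0 nibbles are put into
// that two's-complement form happens outside this hunk.
#include <immintrin.h>
#include <stdio.h>

int main(void) {
    __m256i lut = _mm256_castsi128_si256(
        _mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
    lut = _mm256_permute2f128_si256(lut, lut, 0);       // same table in both lanes

    __m256i nib = _mm256_setr_epi8(0, 1, 7, 8, 15, 3, 9, 12, 0, 1, 7, 8, 15, 3, 9, 12,
                                   0, 1, 7, 8, 15, 3, 9, 12, 0, 1, 7, 8, 15, 3, 9, 12);
    signed char out[32];
    _mm256_storeu_si256((__m256i *) out, _mm256_shuffle_epi8(lut, nib));

    for (int i = 0; i < 8; ++i) printf("%d ", out[i]);  // prints: 0 1 7 -8 -1 3 -7 -4
    printf("\n");
    return 0;
}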
@@ -3364,6 +3408,21 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
#endif
}

+void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__) || defined(__AVX512F__)
+{
+__m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_iq4nl));
+signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+
+gemm_q4_b32_8x8_q8_0_lut_avx<block_iq4_nlx8>(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+return;
+}
+#endif // defined(__AVX2__) || defined(__AVX512F__)
+
+ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
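// Editorial sketch (illustration only — the real 16-entry codebook lives in
// ggml's kvalues_iq4nl; the helper names here are hypothetical): IQ4_NL can
// reuse the same 8x8 template because its nibbles are plain indices into a
// signed 8-bit codebook, so the only difference from Q4_0 is which table gets
// broadcast into both 128-bit lanes before _mm256_shuffle_epi8.
#include <immintrin.h>

static inline __m256i broadcast_codebook16(const signed char table[16]) {
    __m256i lut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *) table));
    return _mm256_permute2f128_si256(lut, lut, 0);      // same 16 bytes in both lanes
}

static inline __m256i lookup_nibbles(__m256i lut, __m256i nibbles) {
    return _mm256_shuffle_epi8(lut, nibbles);           // per-byte table lookup
}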
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;