@fugood/llama.node 0.3.15 → 0.3.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31):
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/examples/server/server.cpp +5 -0
  19. package/src/llama.cpp/examples/tts/tts.cpp +8 -0
  20. package/src/llama.cpp/ggml/src/CMakeLists.txt +5 -1
  21. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
  22. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +31 -27
  23. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +32 -12
  24. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +27 -1
  25. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  26. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +6 -6
  27. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +46 -12
  28. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +4 -2
  29. package/src/llama.cpp/src/llama-arch.cpp +1 -0
  30. package/src/llama.cpp/src/llama-model.cpp +65 -38
  31. package/src/llama.cpp/tests/test-backend-ops.cpp +57 -14
@@ -45,6 +45,24 @@ using block_q4_0x8 = block<4, 8>;
45
45
  using block_q8_0x4 = block<8, 4>;
46
46
  using block_q8_0x8 = block<8, 8>;
47
47
 
48
+
49
+ struct block_q4_Kx8 { // eight q4_K super-blocks packed into one interleaved block
50
+ ggml_half d[8]; // super-block scale for quantized scales (one entry per interleaved block)
51
+ ggml_half dmin[8]; // super-block scale for quantized mins (one entry per interleaved block)
52
+ uint8_t scales[96]; // scales and mins, quantized with 6 bits (K_SCALE_SIZE * 8 bytes per the size check below)
53
+ uint8_t qs[1024]; // 4-bit quants, two per byte (QK_K * 4 bytes per the size check below)
54
+ };
55
+ // Compile-time guard: the struct must have no padding beyond the four arrays above
56
+ static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
57
+
58
+ struct block_q8_Kx4 { // four q8_K super-blocks (one per input row) packed into one interleaved block
59
+ float d[4]; // delta (one scale per interleaved row)
60
+ int8_t qs[QK_K * 4]; // quants of the four rows, interleaved
61
+ int16_t bsums[QK_K / 4]; // sum of quants in groups of 16 (all four rows share this array)
62
+ };
63
+ // Compile-time guard: the struct must have no padding beyond the three arrays above
64
+ static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
65
+
48
66
  struct block_iq4_nlx4 {
49
67
  ggml_half d[4]; // deltas for 4 iq4_nl blocks
50
68
  uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
@@ -60,6 +78,13 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wro
60
78
 
61
79
  #define UNUSED GGML_UNUSED
62
80
 
81
// Rounds fval to the nearest integer without a float->int conversion instruction.
//
// Adding 12582912.f (1.5 * 2^23) moves the integer part of fval into the low mantissa
// bits of the sum; the rounding itself is performed by that float addition, so it follows
// the active floating-point rounding mode. The result is recovered by masking the 23
// mantissa bits and removing the 2^22 offset contributed by the magic constant.
//
// fval must satisfy |fval| <= 4194303 (checked with assert); outside that range the
// mantissa no longer holds the biased integer.
static inline int nearest_int(float fval) {
    assert(fabsf(fval) <= 4194303.f);
    const float biased = fval + 12582912.f;
    // Inspect the bit pattern through a fixed-width unsigned integer so the masking below
    // does not depend on the width or signedness of plain int.
    uint32_t bits;
    static_assert(sizeof(bits) == sizeof(biased), "nearest_int requires a 32-bit float");
    memcpy(&bits, &biased, sizeof(bits));
    return (int) (bits & 0x007fffffu) - 0x00400000;
}
87
+
63
88
  // Functions to create the interleaved data layout formats
64
89
 
65
90
  // interleave 4 block_q4_0s in blocks of blck_size_interleave
@@ -534,6 +559,270 @@ static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRIC
534
559
  #endif
535
560
  }
536
561
 
562
// Quantizes four rows of k floats each (row r starts at x + r * k) into nb = k / QK_K
// block_q8_Kx4 blocks written to vy. Each output block stores one scale per row, the
// rows' int8 quants interleaved in runs of eight bytes, and the interleaved bsums.
// k must be a multiple of QK_K (asserted); QK_K is asserted to be 256.
+ static void quantize_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
563
+ assert(QK_K == 256);
564
+ assert(k % QK_K == 0);
565
+ const int nb = k / QK_K;
566
+
567
+ block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
568
+
569
+ #if defined(__AVX2__)
570
+ float iscale[4];
571
+ __m256 srcv[4][32];
572
+ __m256 iscale_vec[4];
573
+ // One output block per 256-float slice; the four rows are scanned and scaled independently
574
+ for (int i = 0; i < nb; i++) {
575
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
576
+ // Load elements into 4 AVX vectors
577
+ __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
578
+ __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
579
+ __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
580
+ __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
581
+
582
+ // Compute max(abs(e)) for the block
583
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
584
+ __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
585
+ __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
586
+ __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
587
+ __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
588
+
589
+ __m256 maxAbs = _mm256_max_ps( abs0, abs1 );
590
+ maxAbs = _mm256_max_ps( maxAbs, abs2 );
591
+ maxAbs = _mm256_max_ps( maxAbs, abs3 );
592
+ // Per lane: set where the running maximum magnitude equals a signed input, i.e. it came from a non-negative value
593
+ __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
594
+ __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
595
+ __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
596
+ __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
597
+
598
+ __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
599
+ // Keep the raw floats; they are scaled after the row's maximum is known
600
+ srcv[row_iter][0] = v0;
601
+ srcv[row_iter][1] = v1;
602
+ srcv[row_iter][2] = v2;
603
+ srcv[row_iter][3] = v3;
604
+ // Remaining seven groups of 32 floats: extend the running maximum and keep the sign mask in step
605
+ for (int sb = 1; sb < 8; sb++) {
606
+ // Temporarily stores absolute quant values
607
+ __m256 tempAbs = maxAbs;
608
+
609
+ // Load elements into 4 AVX vectors
610
+ __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
611
+ __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
612
+ __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
613
+ __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
614
+
615
+ // Compute max(abs(e)) for the block
616
+ __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
617
+ __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
618
+ __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
619
+ __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
620
+
621
+ maxAbs = _mm256_max_ps( maxAbs, abs0 );
622
+ maxAbs = _mm256_max_ps( maxAbs, abs1 );
623
+ maxAbs = _mm256_max_ps( maxAbs, abs2 );
624
+ maxAbs = _mm256_max_ps( maxAbs, abs3 );
625
+ // Lanes whose maximum grew in this group lose their previous sign information
626
+ __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
627
+ maskAbs = _mm256_and_ps( maskAbs, mask_prev );
628
+
629
+ mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
630
+ mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
631
+ mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
632
+ mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
633
+
634
+ __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
635
+ maskAbs = _mm256_or_ps(maskAbs, mask_curr);
636
+
637
+ srcv[row_iter][sb * 4] = v0;
638
+ srcv[row_iter][sb * 4 + 1] = v1;
639
+ srcv[row_iter][sb * 4 + 2] = v2;
640
+ srcv[row_iter][sb * 4 + 3] = v3;
641
+ }
642
+ // Horizontal reduction of the eight lanes to a single scalar maximum
643
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
644
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
645
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
646
+ const float maxScalar = _mm_cvtss_f32( max4 );
647
+
648
+ __m256 maxScalarVec = _mm256_set1_ps(maxScalar);
649
+ // Only lanes that actually hold the overall maximum may decide the sign of the scale
650
+ __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
651
+ __m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
652
+
653
+ const int mask = _mm256_movemask_ps(finalMask);
654
+ iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
655
+ // A non-zero mask negates the scale; the scalar fallback below gets the same sign from -127 / max
656
+ if(mask) {
657
+ iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
658
+ }
659
+
660
+ y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
661
+ iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
662
+ }
663
+ // Scale, round and narrow: vector j of each of the four rows is packed into 32 output bytes
664
+ __m256i quants_interleaved[32];
665
+ for (int j = 0; j < 32; j++) {
666
+ // Apply the multiplier
667
+ __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
668
+ __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
669
+ __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
670
+ __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
671
+
672
+ // Round to nearest integer
673
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
674
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
675
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
676
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
677
+
678
+ // Convert floats to integers
679
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
680
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
681
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
682
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
683
+
684
+ // Convert int32 to int16
685
+ i0 = _mm256_packs_epi32( i0, i1 );
686
+ i2 = _mm256_packs_epi32( i2, i3 );
687
+ // Convert int16 to int8
688
+ i0 = _mm256_packs_epi16( i0, i2 );
689
+
690
+ // Permute and store the quantized weights in the required order after the pack instruction
691
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
692
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
693
+
694
+ _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
695
+ quants_interleaved[j] = i0;
696
+ }
697
+
698
+ // Masks to shuffle the quants of corresponding sub blocks, rearranging them for vectorized bsums computation
699
+ __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
700
+ shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
701
+ __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
702
+ shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
703
+ __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
704
+ shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
705
+ // Each pass consumes eight quant vectors and emits sixteen int16 bsums; note this k shadows the function parameter k
706
+ for (int k = 0; k < 4; k++) {
707
+ // Quants from four different sub blocks are taken
708
+ __m256i q0 = quants_interleaved[k * 8 + 0];
709
+ __m256i q1 = quants_interleaved[k * 8 + 1];
710
+ __m256i q2 = quants_interleaved[k * 8 + 2];
711
+ __m256i q3 = quants_interleaved[k * 8 + 3];
712
+ __m256i q4 = quants_interleaved[k * 8 + 4];
713
+ __m256i q5 = quants_interleaved[k * 8 + 5];
714
+ __m256i q6 = quants_interleaved[k * 8 + 6];
715
+ __m256i q7 = quants_interleaved[k * 8 + 7];
716
+
717
+
718
+ // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
719
+ __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
720
+ __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
721
+ __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
722
+ sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
723
+ __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
724
+ sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
725
+ // Multiplying by a vector of ones makes maddubs add adjacent byte pairs into int16 lanes
726
+ __m256i one = _mm256_set1_epi8(1);
727
+ __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
728
+
729
+ for (int l = 0; l < 3; l++) {
730
+ // Quants value shifted to process next two values from each sub block
731
+ q0 = _mm256_srli_epi64(q0, 16);
732
+ q2 = _mm256_srli_epi64(q2, 16);
733
+ q4 = _mm256_srli_epi64(q4, 16);
734
+ q6 = _mm256_srli_epi64(q6, 16);
735
+
736
+ sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
737
+ sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
738
+ sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
739
+ sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
740
+ sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
741
+ sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
742
+
743
+ bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
744
+ }
745
+
746
+ // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
747
+ __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
748
+ __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
749
+ __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
750
+ sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
751
+ __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
752
+ sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
753
+
754
+ __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
755
+
756
+ for (int l = 0; l < 3; l++) {
757
+ // Quants value shifted to process next two values from each sub block
758
+ q1 = _mm256_srli_epi64(q1, 16);
759
+ q3 = _mm256_srli_epi64(q3, 16);
760
+ q5 = _mm256_srli_epi64(q5, 16);
761
+ q7 = _mm256_srli_epi64(q7, 16);
762
+
763
+ sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
764
+ sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
765
+ sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
766
+ sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
767
+ sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
768
+ sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
769
+
770
+ bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
771
+ }
772
+
773
+ // Overall bsums in interleaved fashion computed by adding results of both halves
774
+ __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
775
+ _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
776
+ }
777
+ }
778
+
779
+ #else
780
+
781
+ // scalar fallback: same output layout as the AVX2 path, produced one quant at a time
782
+ const int blck_size_interleave = 8;
783
+ float srcv[4][QK_K];
784
+ float iscale[4];
785
+
786
+ for (int i = 0; i < nb; i++) {
787
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
788
+ float amax = 0.0f; // absolute max
789
+ float max = 0; // signed value that attains amax
790
+
791
+ for (int j = 0; j < QK_K; j++) {
792
+ srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
793
+ // Update the maximum value of the corresponding super block
794
+ if(amax < fabsf(srcv[row_iter][j])) {
795
+ amax = fabsf(srcv[row_iter][j]);
796
+ max = srcv[row_iter][j];
797
+ }
798
+ }
799
+ // The value with the largest magnitude maps to -127; an all-zero row gets scale 0
800
+ iscale[row_iter] = amax ? -127.f/max : 0;
801
+
802
+ y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
803
+ }
804
+ // bsums are accumulated in the loop below, so they must start from zero
805
+ for (int j = 0; j < QK_K / 4; j++) {
806
+ y[i].bsums[j] = 0;
807
+ }
808
+
809
+ // Quants values are interleaved in sequence of eight bytes from corresponding super blocks
810
+ // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
811
+ // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
812
+ for (int j = 0; j < QK_K * 4; j++) {
813
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; // start of the 8-float run inside the source row
814
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; // source row (0..3) this output byte comes from
815
+ src_offset += (j % blck_size_interleave);
816
+ int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3); // position of the matching interleaved bsum
817
+
818
+ float x0 = srcv[src_id][src_offset] * iscale[src_id];
819
+ y[i].qs[j] = nearest_int(x0);
820
+ y[i].bsums[index] += y[i].qs[j];
821
+ }
822
+ }
823
+ #endif
824
+ }
825
+
537
826
  static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
538
827
  assert(nrow == 4);
539
828
  UNUSED(nrow);
@@ -546,6 +835,16 @@ static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRIC
546
835
  }
547
836
  }
548
837
 
838
+ static void quantize_mat_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
839
+ assert(nrow == 4);
840
+ UNUSED(nrow);
841
+ if (blck_size_interleave == 8) {
842
+ quantize_q8_K_4x8(x, vy, n_per_row);
843
+ } else {
844
+ assert(false);
845
+ }
846
+ }
847
+
549
848
  static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
550
849
  const int qk = QK8_0;
551
850
  const int nb = n / qk;
@@ -994,6 +1293,281 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
994
1293
  }
995
1294
  }
996
1295
 
1296
// Matrix-vector kernel: for every group of eight interleaved q4_K columns in vx
// (block_q4_Kx8), accumulates the dot product with the q8_K-quantized input in vy over
// nb = n / QK_K super-blocks and writes eight floats to s. The result for each column is
// (scaled quant dot product) - (scaled mins * bsums). n must be a multiple of QK_K and
// nc a multiple of 8 (both asserted). bs is not used by this function.
+ static void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1297
+ const int qk = QK_K;
1298
+ const int nb = n / qk;
1299
+ const int ncols_interleaved = 8;
1300
+ const int blocklen = 8;
1301
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1302
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1303
+ static const uint32_t kmask3 = 0x03030303;
1304
+
1305
+ assert (n % qk == 0);
1306
+ assert (nc % ncols_interleaved == 0);
1307
+ // Silence unused-variable warnings for whichever names the selected code path does not touch
1308
+ UNUSED(s);
1309
+ UNUSED(bs);
1310
+ UNUSED(vx);
1311
+ UNUSED(vy);
1312
+ UNUSED(nr);
1313
+ UNUSED(nc);
1314
+ UNUSED(nb);
1315
+ UNUSED(ncols_interleaved);
1316
+ UNUSED(blocklen);
1317
+
1318
+ #if defined(__AVX2__)
1319
+ // Lookup table to convert signed nibbles to signed bytes
1320
+ __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
1321
+ signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
1322
+ // Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
1323
+ __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
1324
+ __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
1325
+ // Permute mask used for easier vector processing at later stages
1326
+ __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
1327
+
1328
+ // Mask to extract nibbles from bytes
1329
+ const __m256i m4b = _mm256_set1_epi8(0x0F);
1330
+
1331
+ int64_t b_nb = n / QK_K;
1332
+
1333
+ const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;
1334
+ const block_q8_K * a_ptr_start = (const block_q8_K *)vy;
1335
+
1336
+ // Process Q8_K blocks one by one
1337
+ for (int64_t y = 0; y < nr; y++) {
1338
+
1339
+ // Pointers to LHS blocks of block_q8_K format
1340
+ const block_q8_K * a_ptr = a_ptr_start + (y * nb);
1341
+
1342
+ // Take group of eight interleaved block_q4_K structures at each pass of the loop and perform dot product operation
1343
+ for (int64_t x = 0; x < nc / 8; x++) {
1344
+
1345
+ // Pointers to RHS blocks
1346
+ const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
1347
+
1348
+ // Master FP accumulators
1349
+ __m256 acc_row = _mm256_setzero_ps();
1350
+ __m256 acc_min_rows = _mm256_setzero_ps();
1351
+
1352
+ for (int64_t b = 0; b < nb; b++) {
1353
+
1354
+ // Load and convert to FP32 scale from block_q8_K
1355
+ const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));
1356
+
1357
+ // Load the scale values for the 8 blocks interleaved in block_q4_Kx8
1358
+ // col_scale_f32 rearranged so as to multiply with appropriate quants
1359
+ const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
1360
+ const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
1361
+
1362
+ __m256i iacc_b = _mm256_setzero_si256();
1363
+ __m256i iacc_min_b = _mm256_setzero_si256();
1364
+ // Horizontally add adjacent bsums so each int16 holds the sum of one pair of entries, then replicate across both 128-bit halves
1365
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
1366
+ __m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
1367
+ q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
1368
+
1369
+ // Processes two sub blocks from each Q4_K in each iteration
1370
+ for (int sb = 0; sb < QK_K / 64; sb++) {
1371
+
1372
+ // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
1373
+ const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
1374
+ const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
1375
+ const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
1376
+ const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
1377
+ const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
1378
+ const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
1379
+ const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
1380
+ const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
1381
+
1382
+ // 4-bit -> 8-bit
1383
+ // Values of the first sub block of eight block_q4_K structures for the sb loop
1384
+ const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
1385
+ const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
1386
+ const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
1387
+ const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
1388
+ const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
1389
+ const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
1390
+ const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
1391
+ const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
1392
+
1393
+ // Values of the second sub block of eight block_q4_K structures when sb = 1
1394
+ const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
1395
+ const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
1396
+ const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
1397
+ const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
1398
+ const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
1399
+ const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
1400
+ const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
1401
+ const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
1402
+
1403
+ uint32_t utmp_0[4], utmp_1[4];
1404
+
1405
+ // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
1406
+ // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
1407
+ memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
1408
+ utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
1409
+ const uint32_t uaux_0 = utmp_0[1] & kmask1;
1410
+ utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
1411
+ utmp_0[2] = uaux_0;
1412
+ utmp_0[0] &= kmask1;
1413
+
1414
+ // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
1415
+ memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
1416
+ utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
1417
+ const uint32_t uaux_1 = utmp_1[1] & kmask1;
1418
+ utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
1419
+ utmp_1[2] = uaux_1;
1420
+ utmp_1[0] &= kmask1;
1421
+
1422
+ // Scales of first sub block in the sb loop
1423
+ const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
1424
+ __m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
1425
+ __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
1426
+
1427
+ // Scales of second sub block in the sb loop
1428
+ __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
1429
+ __m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
1430
+ __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
1431
+
1432
+ // Mins of first and second sub block of Q4_K block are arranged side by side
1433
+ __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
1434
+
1435
+ // Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
1436
+ __m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
1437
+ __m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
1438
+ __m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
1439
+ __m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
1440
+
1441
+ lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
1442
+ lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
1443
+ lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
1444
+ lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
1445
+
1446
+ // Dot product done within 32 bit lanes and accumulated in the same vector
1447
+ // First done for first sub block and then for second sub block in each sb
1448
+ // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
1449
+ // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
1450
+ // ...........................................................................
1451
+ // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
1452
+
1453
+
1454
+ __m256i iacc_0 = _mm256_setzero_si256();
1455
+ __m256i iacc_1 = _mm256_setzero_si256();
1456
+
1457
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
1458
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
1459
+
1460
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
1461
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
1462
+
1463
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
1464
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
1465
+
1466
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
1467
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
1468
+ // Weight the int16 partial sums by the first sub block's scales and widen to int32
1469
+ iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
1470
+
1471
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
1472
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
1473
+
1474
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
1475
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
1476
+
1477
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
1478
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
1479
+
1480
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
1481
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
1482
+ // Same weighting for the second sub block
1483
+ iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
1484
+
1485
+ // Accumulate the iacc value for one sb
1486
+ __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
1487
+
1488
+ // Broadcast the bsums of the two sub blocks of the iteration of Q8_K across the vector
1489
+ // Multiply-Add with corresponding mins of Q4_Kx8 with bsums
1490
+ __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
1491
+ __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
1492
+ q8s = _mm256_bsrli_epi128(q8s, 4); // shift out the two sums just consumed so the next iteration sees the following pair
1493
+
1494
+ // Accumulate for the complete block
1495
+ iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
1496
+ iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
1497
+ }
1498
+
1499
+ // Multiply-Add with scale values for the complete super block
1500
+ acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
1501
+ acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
1502
+
1503
+ }
1504
+
1505
+ // Accumulated output values permuted so as to be stored in appropriate order post accumulation
1506
+ acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
1507
+ _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows)); // NOTE(review): row offset is y * nr and bs is unused; the offset is 0 when nr == 1 - confirm the intended stride for nr > 1
1508
+ }
1509
+ }
1510
+
1511
+ #else
1512
+ // Scalar fallback
1513
+ float sumf[8];
1514
+ float sum_minf[8];
1515
+ uint32_t utmp[32];
1516
+ int sumi1;
1517
+ int sumi2;
1518
+ int sumi;
1519
+
1520
+ const block_q8_K * a_ptr = (const block_q8_K *) vy; // NOTE(review): this path never advances a_ptr or s by a row index, so only one run of nb q8_K blocks is processed regardless of nr
1521
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1522
+ const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
1523
+
1524
+ for (int j = 0; j < ncols_interleaved; j++) {
1525
+ sumf[j] = 0.0;
1526
+ sum_minf[j] = 0.0;
1527
+ }
1528
+ for (int l = 0; l < nb; l++) {
1529
+ for (int sb = 0; sb < 8; sb++) { // expand each 12-byte packed group into 16 bytes: 8 scales followed by 8 mins
1530
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
1531
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
1532
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
1533
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
1534
+ utmp[sb * 4 + 2] = uaux_0;
1535
+ utmp[sb * 4 + 0] &= kmask1;
1536
+ }
1537
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
1538
+ uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32; // scales applied to the low nibbles
1539
+ uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16; // scales applied to the high nibbles
1540
+ for (int j = 0; j < ncols_interleaved; j++) {
1541
+ sumi1 = 0;
1542
+ sumi2 = 0;
1543
+ sumi = 0;
1544
+ for (int i = 0; i < blocklen; ++i) {
1545
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
1546
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
1547
+ sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
1548
+ sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
1549
+ sumi1 = sumi1 * scales_0[j];
1550
+ sumi2 = sumi2 * scales_1[j];
1551
+ sumi += sumi1 + sumi2;
1552
+ }
1553
+ sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
1554
+ }
1555
+ }
1556
+ for (int sb = 0; sb < 8; sb++) {
1557
+ uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16; // the 8 mins follow the 8 scales inside each 16-byte group
1558
+ for (int j = 0; j < ncols_interleaved; j++) {
1559
+ sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
1560
+ }
1561
+ }
1562
+ }
1563
+ for (int j = 0; j < ncols_interleaved; j++) {
1564
+ s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
1565
+ }
1566
+ }
1567
+ #endif
1568
+ }
1569
+
1570
+
997
1571
  static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
998
1572
  const int qk = QK8_0;
999
1573
  const int nb = n / qk;
@@ -3480,6 +4054,781 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
3480
4054
  }
3481
4055
  }
3482
4056
 
4057
+ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
4058
+ const int qk = QK_K;
4059
+ const int nb = n / qk;
4060
+ const int ncols_interleaved = 8;
4061
+ const int blocklen = 8;
4062
+ static const uint32_t kmask1 = 0x3f3f3f3f;
4063
+ static const uint32_t kmask2 = 0x0f0f0f0f;
4064
+ static const uint32_t kmask3 = 0x03030303;
4065
+
4066
+ assert (n % qk == 0);
4067
+ assert (nr % 4 == 0);
4068
+ assert (nc % ncols_interleaved == 0);
4069
+
4070
+ UNUSED(s);
4071
+ UNUSED(bs);
4072
+ UNUSED(vx);
4073
+ UNUSED(vy);
4074
+ UNUSED(nr);
4075
+ UNUSED(nc);
4076
+ UNUSED(nb);
4077
+ UNUSED(ncols_interleaved);
4078
+ UNUSED(blocklen);
4079
+
4080
+ #if defined(__AVX2__)
4081
+ const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx;
4082
+ const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;
4083
+ int64_t b_nb = n / QK_K;
4084
+ int64_t y = 0;
4085
+
4086
+ // Mask to mask out nibbles from packed bytes
4087
+ const __m256i m4b = _mm256_set1_epi8(0x0F);
4088
+ // Permute mask used for easier vector processing at later stages
4089
+ __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
4090
+
4091
+ int anr = nr - nr % 16;; // Used to align nr with boundary of 16
4092
+ // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
4093
+ for (; y < anr / 4; y += 4) {
4094
+
4095
+ const block_q8_Kx4 * a_ptrs[4];
4096
+
4097
+ a_ptrs[0] = a_ptr_start + (y * nb);
4098
+ for (int i = 0; i < 3; ++i) {
4099
+ a_ptrs[i + 1] = a_ptrs[i] + nb;
4100
+ }
4101
+
4102
+ // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
4103
+ for (int64_t x = 0; x < nc / 8; x++) {
4104
+
4105
+ const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
4106
+
4107
+ // Master FP accumulators
4108
+ __m256 acc_rows[16];
4109
+ for (int i = 0; i < 16; i++) {
4110
+ acc_rows[i] = _mm256_setzero_ps();
4111
+ }
4112
+
4113
+ __m256 acc_min_rows[16];
4114
+ for (int i = 0; i < 16; i++) {
4115
+ acc_min_rows[i] = _mm256_setzero_ps();
4116
+ }
4117
+
4118
+ // For super block
4119
+ for (int64_t b = 0; b < nb; b++) {
4120
+
4121
+ // Scale values - Load the eight scale values of block_q4_kx8
4122
+ const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
4123
+
4124
+ // dmin values - Load the eight dmin values of block_q4_kx8
4125
+ const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
4126
+
4127
+ // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
4128
+ for (int sb = 0; sb < QK_K / 64; sb++) {
4129
+
4130
+ // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
4131
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
4132
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
4133
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
4134
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
4135
+ const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
4136
+ const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
4137
+ const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
4138
+ const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
4139
+
4140
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
4141
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
4142
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
4143
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
4144
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
4145
+ const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
4146
+ const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
4147
+ const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
4148
+ const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
4149
+
4150
+ // 4-bit -> 8-bit
4151
+ // First sub block of the two sub blocks processed in the iteration
4152
+ const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
4153
+ const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
4154
+
4155
+ const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
4156
+ const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
4157
+
4158
+ const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
4159
+ const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
4160
+
4161
+ const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
4162
+ const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
4163
+
4164
+ // Second sub block of the two sub blocks processed in the iteration
4165
+ const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
4166
+ const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
4167
+
4168
+ const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
4169
+ const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
4170
+
4171
+ const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
4172
+ const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
4173
+
4174
+ const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
4175
+ const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
4176
+
4177
+ // Shuffle pattern one - right side input
4178
+ const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
4179
+ const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
4180
+
4181
+ const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
4182
+ const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
4183
+
4184
+ const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
4185
+ const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
4186
+
4187
+ const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
4188
+ const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
4189
+
4190
+ const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
4191
+ const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
4192
+
4193
+ const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
4194
+ const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
4195
+
4196
+ const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
4197
+ const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
4198
+
4199
+ const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
4200
+ const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
4201
+
4202
+
4203
+ // Shuffle pattern two - right side input
4204
+ const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
4205
+ const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
4206
+
4207
+ const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
4208
+ const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
4209
+
4210
+ const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
4211
+ const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
4212
+
4213
+ const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
4214
+ const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
4215
+
4216
+ const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
4217
+ const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
4218
+
4219
+ const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
4220
+ const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
4221
+
4222
+ const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
4223
+ const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
4224
+
4225
+ const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
4226
+ const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
4227
+
4228
+ uint32_t utmp_0[4], utmp_1[4];
4229
+
4230
+ // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
4231
+ // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
4232
+ memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
4233
+ utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
4234
+ const uint32_t uaux_0 = utmp_0[1] & kmask1;
4235
+ utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
4236
+ utmp_0[2] = uaux_0;
4237
+ utmp_0[0] &= kmask1;
4238
+
4239
+ // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
4240
+ memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
4241
+ utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
4242
+ const uint32_t uaux_1 = utmp_1[1] & kmask1;
4243
+ utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
4244
+ utmp_1[2] = uaux_1;
4245
+ utmp_1[0] &= kmask1;
4246
+
4247
+ // Scales of first sub block in the sb loop
4248
+ const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
4249
+ const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
4250
+
4251
+ // Scales of second sub block in the sb loop
4252
+ const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
4253
+ const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
4254
+
4255
+ // Mins of first and second sub block of Q4_K block are arranged side by side
4256
+ const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
4257
+
4258
+ const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
4259
+ const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
4260
+
4261
+ const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
4262
+ const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
4263
+
4264
+ for (int rp = 0; rp < 4; rp++) {
4265
+
4266
+ // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
4267
+ // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
4268
+ __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
4269
+ __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
4270
+ __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
4271
+ __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
4272
+ __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
4273
+ __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
4274
+ __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
4275
+ __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
4276
+ __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
4277
+ __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
4278
+ __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
4279
+ __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
4280
+ __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
4281
+ __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
4282
+ __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
4283
+ __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
4284
+ __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
4285
+ __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
4286
+ __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
4287
+ __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
4288
+ __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
4289
+ __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
4290
+ __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
4291
+ __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
4292
+
4293
+ // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
4294
+ __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
4295
+ __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
4296
+ lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
4297
+
4298
+ // Shuffle pattern one - left side input
4299
+ const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
4300
+ const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
4301
+
4302
+ const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
4303
+ const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
4304
+
4305
+ const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
4306
+ const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160); //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
4307
+
4308
+ const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
4309
+ const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
4310
+
4311
+ const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
4312
+ const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
4313
+
4314
+ const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
4315
+ const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
4316
+
4317
+ const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
4318
+ const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160); //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
4319
+
4320
+ const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
4321
+ const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
4322
+
4323
+ // Shuffle pattern two- left side input
4324
+ const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
4325
+ const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
4326
+
4327
+ const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
4328
+ const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
4329
+
4330
+ const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
4331
+ const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
4332
+
4333
+ const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
4334
+ const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
4335
+
4336
+ const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
4337
+ const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
4338
+
4339
+ const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
4340
+ const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
4341
+
4342
+ const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
4343
+ const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
4344
+
4345
+ const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
4346
+ const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
4347
+
4348
+ // The values arranged in shuffle patterns are combined with a dot product operation, i.e. corresponding bytes are multiplied and adjacent products are added into 16 bit integers; the partial sums of the four 8-byte chunks are then accumulated in 16 bit lanes
4349
+ __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
4350
+ __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
4351
+ __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
4352
+ __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
4353
+ __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
4354
+ __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
4355
+ __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
4356
+ __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
4357
+
4358
+ __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
4359
+ __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
4360
+ __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
4361
+ __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
4362
+ __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
4363
+ __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
4364
+ __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
4365
+ __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
4366
+
4367
+ // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
4368
+ __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
4369
+ __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
4370
+ __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
4371
+ __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
4372
+
4373
+ __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
4374
+ __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
4375
+ __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
4376
+ __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
4377
+
4378
+ // Multiply the 16 bit sums of each sub block by the corresponding sub block scales; adjacent products are pairwise added into 32 bit integers
4379
+ iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
4380
+ iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
4381
+ iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
4382
+ iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
4383
+
4384
+ iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
4385
+ iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
4386
+ iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
4387
+ iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
4388
+
4389
+ // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
4390
+ __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
4391
+ __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
4392
+ __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
4393
+ __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
4394
+ __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
4395
+ __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
4396
+ __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
4397
+ __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
4398
+
4399
+ __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
4400
+ __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
4401
+ __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
4402
+ __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
4403
+
4404
+ // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
4405
+ const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
4406
+ const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
4407
+
4408
+ // Multiply with appropriate scales and accumulate (for both d and dmin) below
4409
+ acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
4410
+ acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
4411
+ acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
4412
+ acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
4413
+
4414
+ __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
4415
+ __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
4416
+ __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
4417
+ __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
4418
+
4419
+ acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
4420
+ acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
4421
+ acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
4422
+ acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
4423
+
4424
+ }
4425
+ }
4426
+ }
4427
+ // Store the accumulated values
4428
+ for (int i = 0; i < 16; i++) {
4429
+ _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
4430
+ }
4431
+ }
4432
+ }
4433
+ for (; y < nr / 4; y++) {
4434
+
4435
+ const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
4436
+
4437
+ for (int64_t x = 0; x < nc / 8; x++) {
4438
+
4439
+ const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
4440
+
4441
+ // Master FP accumulators
4442
+ __m256 acc_rows[4];
4443
+ for (int i = 0; i < 4; i++) {
4444
+ acc_rows[i] = _mm256_setzero_ps();
4445
+ }
4446
+
4447
+ __m256 acc_min_rows[4];
4448
+ for (int i = 0; i < 4; i++) {
4449
+ acc_min_rows[i] = _mm256_setzero_ps();
4450
+ }
4451
+
4452
+ for (int64_t b = 0; b < nb; b++) {
4453
+
4454
+ // Scale values - Load the eight scale values of block_q4_Kx8
4455
+ const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
4456
+
4457
+ // dmin values - Load the eight dmin values of block_q4_Kx8
4458
+ const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
4459
+
4460
+ // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
4461
+ for (int sb = 0; sb < QK_K / 64; sb++) {
4462
+
4463
+ // Load the eight block_q4_k for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
4464
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
4465
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
4466
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
4467
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
4468
+ const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
4469
+ const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
4470
+ const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
4471
+ const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
4472
+
4473
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
4474
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
4475
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
4476
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
4477
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
4478
+ const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
4479
+ const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
4480
+ const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
4481
+ const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
4482
+
4483
+ // 4-bit -> 8-bit
4484
+ // First sub block of the two sub blocks processed in the iteration
4485
+ const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
4486
+ const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
4487
+
4488
+ const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
4489
+ const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
4490
+
4491
+ const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
4492
+ const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
4493
+
4494
+ const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
4495
+ const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
4496
+
4497
+ // Second sub block of the two sub blocks processed in the iteration
4498
+ const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
4499
+ const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
4500
+
4501
+ const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
4502
+ const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
4503
+
4504
+ const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
4505
+ const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
4506
+
4507
+ const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
4508
+ const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
4509
+
4510
+ // Shuffle pattern one - right side input
4511
+ const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
4512
+ const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
4513
+
4514
+ const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
4515
+ const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
4516
+
4517
+ const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
4518
+ const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
4519
+
4520
+ const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
4521
+ const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
4522
+
4523
+ const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
4524
+ const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
4525
+
4526
+ const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
4527
+ const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
4528
+
4529
+ const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
4530
+ const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
4531
+
4532
+ const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
4533
+ const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
4534
+
4535
+ // Shuffle pattern two - right side input
4536
+ const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
4537
+ const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
4538
+
4539
+ const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
4540
+ const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
4541
+
4542
+ const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
4543
+ const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
4544
+
4545
+ const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
4546
+ const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
4547
+
4548
+ const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
4549
+ const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
4550
+
4551
+ const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
4552
+ const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
4553
+
4554
+ const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
4555
+ const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
4556
+
4557
+ const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
4558
+ const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
4559
+
4560
+ uint32_t utmp_0[4], utmp_1[4];
4561
+
4562
+ // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
4563
+ // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
4564
+ memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
4565
+ utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
4566
+ const uint32_t uaux_0 = utmp_0[1] & kmask1;
4567
+ utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
4568
+ utmp_0[2] = uaux_0;
4569
+ utmp_0[0] &= kmask1;
4570
+
4571
+ // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures when sb = 1
4572
+ memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
4573
+ utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
4574
+ const uint32_t uaux_1 = utmp_1[1] & kmask1;
4575
+ utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
4576
+ utmp_1[2] = uaux_1;
4577
+ utmp_1[0] &= kmask1;
4578
+
4579
+ // Scales of first sub block in the sb loop
4580
+ const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
4581
+ const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
4582
+
4583
+ // Scales of second sub block in the sb loop
4584
+ const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
4585
+ const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
4586
+
4587
+ // Mins of first and second sub block of Q4_K block are arranged side by side
4588
+ const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
4589
+
4590
+ const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
4591
+ const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
4592
+
4593
+ const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
4594
+ const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
4595
+
4596
+ // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
4597
+ // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
4598
+ __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
4599
+ __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
4600
+ __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
4601
+ __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
4602
+ __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
4603
+ __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
4604
+ __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
4605
+ __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
4606
+ __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
4607
+ __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
4608
+ __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
4609
+ __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
4610
+ __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
4611
+ __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
4612
+ __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
4613
+ __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
4614
+ __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
4615
+ __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
4616
+ __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
4617
+ __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
4618
+ __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
4619
+ __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
4620
+ __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
4621
+ __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
4622
+
4623
+ // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
4624
+ __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
4625
+ __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
4626
+ lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
4627
+
4628
+ // Shuffle pattern one - left side input
4629
+ const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
4630
+ const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
4631
+
4632
+ const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
4633
+ const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
4634
+
4635
+ const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
4636
+ const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160); //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
4637
+
4638
+ const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
4639
+ const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
4640
+
4641
+ const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
4642
+ const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
4643
+
4644
+ const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
4645
+ const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
4646
+
4647
+ const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
4648
+ const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160); //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
4649
+
4650
+ const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
4651
+ const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
4652
+
4653
+ // Shuffle pattern two- left side input
4654
+ const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
4655
+ const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
4656
+
4657
+ const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
4658
+ const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
4659
+
4660
+ const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
4661
+ const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
4662
+
4663
+ const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
4664
+ const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
4665
+
4666
+ const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
4667
+ const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
4668
+
4669
+ const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
4670
+ const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
4671
+
4672
+ const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
4673
+ const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
4674
+
4675
+ const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
4676
+ const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
4677
+
4678
+ // The values arranged in shuffle patterns are combined with a dot product operation, i.e. corresponding bytes are multiplied and adjacent products are added into 16 bit integers, which are then summed across the four chunks of the sub block
4679
+ __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
4680
+ __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
4681
+ __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
4682
+ __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
4683
+ __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
4684
+ __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
4685
+ __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
4686
+ __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
4687
+
4688
+ __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
4689
+ __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
4690
+ __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
4691
+ __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
4692
+ __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
4693
+ __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
4694
+ __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
4695
+ __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
4696
+
4697
+ // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
4698
+ __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
4699
+ __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
4700
+ __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
4701
+ __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
4702
+
4703
+ __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
4704
+ __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
4705
+ __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
4706
+ __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
4707
+
4708
+ // Multiply the 16 bit dot product sums with the scales of the corresponding sub block - madd adds adjacent pairs of products into 32 bit integers
4709
+ iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
4710
+ iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
4711
+ iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
4712
+ iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
4713
+
4714
+ iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
4715
+ iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
4716
+ iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
4717
+ iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
4718
+
4719
+ // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
4720
+ __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
4721
+ __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
4722
+ __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
4723
+ __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
4724
+ __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
4725
+ __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
4726
+ __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
4727
+ __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
4728
+
4729
+ __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
4730
+ __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
4731
+ __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
4732
+ __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
4733
+
4734
+ // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
4735
+ const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
4736
+ const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); //GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
4737
+
4738
+ // Multiply with appropriate scales and accumulate (for both d and dmin) below
4739
+ acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
4740
+ acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
4741
+ acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
4742
+ acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
4743
+
4744
+ __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
4745
+ __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
4746
+ __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
4747
+ __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
4748
+
4749
+ acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
4750
+ acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
4751
+ acc_min_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
4752
+ acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
4753
+ }
4754
+ }
4755
+
4756
+ // Store the accumulated values
4757
+ for (int i = 0; i < 4; i++) {
4758
+ _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
4759
+ }
4760
+ }
4761
+ }
4762
+
4763
+ #else
4764
+
4765
+ float sumf[4][8];
4766
+ float sum_minf[4][8];
4767
+ uint32_t utmp[32];
4768
+ int sumi1;
4769
+ int sumi2;
4770
+ int sumi;
4771
+
4772
+ for (int y = 0; y < nr / 4; y++) {
4773
+ const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
4774
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
4775
+ const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
4776
+ for (int m = 0; m < 4; m++) {
4777
+ for (int j = 0; j < ncols_interleaved; j++) {
4778
+ sumf[m][j] = 0.0;
4779
+ sum_minf[m][j] = 0.0;
4780
+ }
4781
+ }
4782
+ for (int l = 0; l < nb; l++) {
4783
+ for (int sb = 0; sb < 8; sb++) {
4784
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
4785
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
4786
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
4787
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
4788
+ utmp[sb * 4 + 2] = uaux_0;
4789
+ utmp[sb * 4 + 0] &= kmask1;
4790
+ }
4791
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
4792
+ uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
4793
+ uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
4794
+ for (int m = 0; m < 4; m++) {
4795
+ for (int j = 0; j < ncols_interleaved; j++) {
4796
+ sumi1 = 0;
4797
+ sumi2 = 0;
4798
+ sumi = 0;
4799
+ for (int i = 0; i < blocklen; ++i) {
4800
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
4801
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
4802
+ sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
4803
+ sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
4804
+ sumi1 = sumi1 * scales_0[j];
4805
+ sumi2 = sumi2 * scales_1[j];
4806
+ sumi += sumi1 + sumi2;
4807
+ }
4808
+ sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
4809
+ }
4810
+ }
4811
+ }
4812
+ for (int sb = 0; sb < 8; sb++) {
4813
+ uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
4814
+ for(int m = 0; m < 4; m++) {
4815
+ const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
4816
+ for(int j = 0; j < ncols_interleaved; j++) {
4817
+ sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
4818
+ }
4819
+ }
4820
+ }
4821
+ }
4822
+ for (int m = 0; m < 4; m++) {
4823
+ for (int j = 0; j < ncols_interleaved; j++) {
4824
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
4825
+ }
4826
+ }
4827
+ }
4828
+ }
4829
+ #endif
4830
+ }
4831
+
3483
4832
  static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
3484
4833
  const int qk = QK8_0;
3485
4834
  const int nb = n / qk;
@@ -3660,6 +5009,82 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
3660
5009
  return out;
3661
5010
  }
3662
5011
 
5012
+ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
5013
+ block_q4_Kx8 out;
5014
+ //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure
5015
+ for (int i = 0; i < 8; i++) {
5016
+ out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
5017
+ }
5018
+
5019
+ for (int i = 0; i < 8; i++) {
5020
+ out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
5021
+ }
5022
+
5023
+ const int end = QK_K * 4 / blck_size_interleave;
5024
+
5025
+ // Interleave Q4_K quants by taking 8 bytes at a time
5026
+ for (int i = 0; i < end; ++i) {
5027
+ int src_id = i % 8;
5028
+ int src_offset = (i / 8) * blck_size_interleave;
5029
+ int dst_offset = i * blck_size_interleave;
5030
+
5031
+ uint64_t elems;
5032
+ memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
5033
+ memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
5034
+ }
5035
+
5036
+ // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
5037
+ // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)
5038
+ // The output Q4_Kx8 structure has 96 bytes
5039
+ // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure
5040
+ // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures
5041
+ uint8_t s[8], m[8];
5042
+
5043
+ for (int i = 0; i < 4; i++) {
5044
+ for (int j = 0; j < 8; j++) {
5045
+ s[j] = in[j].scales[i] & 63;
5046
+ m[j] = in[j].scales[i + 4] & 63;
5047
+ }
5048
+
5049
+ out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
5050
+ out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
5051
+ out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
5052
+ out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
5053
+ out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
5054
+ out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
5055
+ out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
5056
+ out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
5057
+ out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
5058
+ out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
5059
+ out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
5060
+ out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
5061
+
5062
+ }
5063
+
5064
+ for (int i = 0; i < 4; i++) {
5065
+ for (int j = 0; j < 8; j++) {
5066
+ s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15);
5067
+ m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4);
5068
+ }
5069
+
5070
+ out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
5071
+ out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
5072
+ out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
5073
+ out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
5074
+ out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
5075
+ out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
5076
+ out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
5077
+ out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
5078
+ out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
5079
+ out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
5080
+ out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
5081
+ out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
5082
+
5083
+ }
5084
+
5085
+ return out;
5086
+ }
5087
+
3663
5088
  static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
3664
5089
  GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
3665
5090
  GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
@@ -3690,6 +5115,36 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
3690
5115
 
3691
5116
  GGML_UNUSED(data_size);
3692
5117
  }
5118
+ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
5119
+ GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
5120
+ GGML_ASSERT(interleave_block == 8);
5121
+ constexpr int nrows_interleaved = 8;
5122
+
5123
+ block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
5124
+ const block_q4_K * src = (const block_q4_K*) data;
5125
+ block_q4_K dst_tmp[8];
5126
+ int nrow = ggml_nrows(t);
5127
+ int nblocks = t->ne[0] / QK_K;
5128
+
5129
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
5130
+
5131
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
5132
+ return -1;
5133
+ }
5134
+
5135
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
5136
+ for (int64_t x = 0; x < nblocks; x++) {
5137
+ for (int i = 0; i < nrows_interleaved; i++ ) {
5138
+ dst_tmp[i] = src[x + i * nblocks];
5139
+ }
5140
+ *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
5141
+ }
5142
+ src += nrows_interleaved * nblocks;
5143
+ }
5144
+ return 0;
5145
+
5146
+ GGML_UNUSED(data_size);
5147
+ }
3693
5148
 
3694
5149
  static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
3695
5150
  GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
@@ -3807,6 +5262,10 @@ template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * da
3807
5262
  return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
3808
5263
  }
3809
5264
 
5265
+ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
5266
+ return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
5267
+ }
5268
+
3810
5269
  template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
3811
5270
  return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
3812
5271
  }
@@ -3832,6 +5291,10 @@ template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void
3832
5291
  ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
3833
5292
  }
3834
5293
 
5294
+ template <> void gemv<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5295
+ ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
5296
+ }
5297
+
3835
5298
  template <>
3836
5299
  void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3837
5300
  ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
@@ -3853,6 +5316,10 @@ template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void
3853
5316
  ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
3854
5317
  }
3855
5318
 
5319
+ template <> void gemm<block_q4_K, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
5320
+ ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
5321
+ }
5322
+
3856
5323
  template <>
3857
5324
  void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3858
5325
  ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
@@ -3863,16 +5330,16 @@ class tensor_traits_base : public ggml::cpu::tensor_traits {
3863
5330
  virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
3864
5331
  };
3865
5332
 
3866
- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
5333
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {
3867
5334
 
3868
5335
  bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
3869
5336
  // not really a GGML_TYPE_Q8_0 but same size.
3870
5337
  switch (op->op) {
3871
5338
  case GGML_OP_MUL_MAT:
3872
- size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
5339
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
3873
5340
  return true;
3874
5341
  case GGML_OP_MUL_MAT_ID:
3875
- size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
5342
+ size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
3876
5343
  size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
3877
5344
  size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
3878
5345
  return true;
@@ -3925,16 +5392,23 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
3925
5392
  // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
3926
5393
 
3927
5394
  char * wdata = static_cast<char *>(params->wdata);
3928
- const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
5395
+ const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
3929
5396
 
3930
5397
  assert(params->wsize >= nbw1 * ne11);
3931
5398
 
3932
- const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
5399
+ const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
3933
5400
 
3934
5401
  int64_t i11_processed = 0;
3935
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
3936
- quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
5402
+ if(PARAM_TYPE == GGML_TYPE_Q8_K) {
5403
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
5404
+ quantize_mat_q8_K((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
3937
5405
  INTER_SIZE);
5406
+ }
5407
+ } else {
5408
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
5409
+ quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
5410
+ INTER_SIZE);
5411
+ }
3938
5412
  }
3939
5413
  i11_processed = ne11 - ne11 % 4;
3940
5414
  for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
@@ -3944,7 +5418,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
3944
5418
  ggml_barrier(params->threadpool);
3945
5419
 
3946
5420
  const void * src1_wdata = params->wdata;
3947
- const size_t src1_col_stride = ggml_row_size(GGML_TYPE_Q8_0, ne10);
5421
+ const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
3948
5422
  int64_t src0_start = (ith * ne01) / nth;
3949
5423
  int64_t src0_end = ((ith + 1) * ne01) / nth;
3950
5424
  src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
@@ -4098,12 +5572,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
4098
5572
  };
4099
5573
 
4100
5574
  // instance for Q4
4101
- static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
4102
- static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
4103
- static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
5575
+ static const tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
5576
+ static const tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
5577
+ static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
5578
+ static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
4104
5579
 
4105
5580
  // instance for IQ4
4106
- static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
5581
+ static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_IQ4_NL> iq4_nl_4x4_q8_0;
4107
5582
 
4108
5583
  } // namespace ggml::cpu::aarch64
4109
5584
 
@@ -4124,6 +5599,12 @@ static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(con
4124
5599
  return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
4125
5600
  }
4126
5601
  }
5602
+ } else if (cur->type == GGML_TYPE_Q4_K) {
5603
+ if (ggml_cpu_has_avx2()) {
5604
+ if (cur->ne[1] % 8 == 0) {
5605
+ return &ggml::cpu::aarch64::q4_K_8x8_q8_K;
5606
+ }
5607
+ }
4127
5608
  } else if (cur->type == GGML_TYPE_IQ4_NL) {
4128
5609
  if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
4129
5610
  if (cur->ne[1] % 4 == 0) {