cui-llama.rn 1.4.6 → 1.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. package/android/src/main/CMakeLists.txt +9 -2
  2. package/android/src/main/jni.cpp +52 -34
  3. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  11. package/cpp/binary-ops.cpp +158 -0
  12. package/cpp/binary-ops.h +16 -0
  13. package/cpp/chat.cpp +1769 -1779
  14. package/cpp/chat.h +9 -1
  15. package/cpp/common.cpp +20 -522
  16. package/cpp/common.h +13 -36
  17. package/cpp/cpu-common.h +72 -0
  18. package/cpp/ggml-common.h +12 -6
  19. package/cpp/ggml-cpu-aarch64.cpp +1557 -80
  20. package/cpp/ggml-cpu-impl.h +2 -21
  21. package/cpp/ggml-cpu-quants.c +904 -405
  22. package/cpp/ggml-cpu.c +909 -13237
  23. package/cpp/ggml-impl.h +50 -23
  24. package/cpp/ggml-metal-impl.h +77 -3
  25. package/cpp/ggml-metal.m +794 -580
  26. package/cpp/ggml.c +92 -3
  27. package/cpp/ggml.h +29 -5
  28. package/cpp/gguf.cpp +1 -0
  29. package/cpp/llama-adapter.cpp +55 -20
  30. package/cpp/llama-adapter.h +11 -9
  31. package/cpp/llama-arch.cpp +217 -16
  32. package/cpp/llama-arch.h +25 -0
  33. package/cpp/llama-batch.h +2 -2
  34. package/cpp/llama-chat.cpp +54 -2
  35. package/cpp/llama-chat.h +3 -0
  36. package/cpp/llama-context.cpp +2294 -1238
  37. package/cpp/llama-context.h +214 -77
  38. package/cpp/llama-cparams.h +1 -0
  39. package/cpp/llama-graph.cpp +1695 -0
  40. package/cpp/llama-graph.h +592 -0
  41. package/cpp/llama-hparams.cpp +8 -0
  42. package/cpp/llama-hparams.h +17 -0
  43. package/cpp/llama-io.cpp +15 -0
  44. package/cpp/llama-io.h +35 -0
  45. package/cpp/llama-kv-cache.cpp +965 -303
  46. package/cpp/llama-kv-cache.h +145 -151
  47. package/cpp/llama-memory.cpp +1 -0
  48. package/cpp/llama-memory.h +21 -0
  49. package/cpp/llama-mmap.cpp +1 -1
  50. package/cpp/llama-model-loader.cpp +10 -5
  51. package/cpp/llama-model-loader.h +5 -3
  52. package/cpp/llama-model.cpp +9194 -201
  53. package/cpp/llama-model.h +40 -1
  54. package/cpp/llama-sampling.cpp +5 -0
  55. package/cpp/llama-vocab.cpp +36 -5
  56. package/cpp/llama.cpp +51 -9984
  57. package/cpp/llama.h +102 -22
  58. package/cpp/log.cpp +34 -0
  59. package/cpp/minja/chat-template.hpp +15 -7
  60. package/cpp/minja/minja.hpp +120 -94
  61. package/cpp/ops.cpp +8723 -0
  62. package/cpp/ops.h +128 -0
  63. package/cpp/rn-llama.cpp +44 -53
  64. package/cpp/rn-llama.h +2 -12
  65. package/cpp/sampling.cpp +3 -0
  66. package/cpp/sgemm.cpp +533 -88
  67. package/cpp/simd-mappings.h +888 -0
  68. package/cpp/speculative.cpp +4 -4
  69. package/cpp/unary-ops.cpp +186 -0
  70. package/cpp/unary-ops.h +28 -0
  71. package/cpp/vec.cpp +258 -0
  72. package/cpp/vec.h +802 -0
  73. package/ios/CMakeLists.txt +5 -2
  74. package/ios/RNLlama.mm +2 -2
  75. package/ios/RNLlamaContext.mm +40 -24
  76. package/package.json +1 -1
  77. package/src/NativeRNLlama.ts +6 -4
  78. package/src/index.ts +3 -1
  79. package/cpp/chat-template.hpp +0 -529
  80. package/cpp/minja.hpp +0 -2915
@@ -45,6 +45,24 @@ using block_q4_0x8 = block<4, 8>;
  using block_q8_0x4 = block<8, 4>;
  using block_q8_0x8 = block<8, 8>;
 
+
+ struct block_q4_Kx8 {
+     lm_ggml_half d[8];      // super-block scale for quantized scales
+     lm_ggml_half dmin[8];   // super-block scale for quantized mins
+     uint8_t scales[96];     // scales and mins, quantized with 6 bits
+     uint8_t qs[1024];       // 4-bit quants
+ };
+
+ static_assert(sizeof(block_q4_Kx8) == sizeof(lm_ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+
+ struct block_q8_Kx4 {
+     float d[4];              // delta
+     int8_t qs[QK_K * 4];     // quants
+     int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
+ };
+
+ static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
+
  struct block_iq4_nlx4 {
      lm_ggml_half d[4];      // deltas for 4 iq4_nl blocks
      uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
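The two new structures above are the interleaved counterparts of the existing block_q4_K and block_q8_K types: eight Q4_K super-blocks (and, on the activation side, four Q8_K rows) are packed side by side so that the AVX2 GEMV/GEMM kernels added later in this file can process eight output columns at a time. As a quick size-arithmetic sketch for the two static_asserts, assuming the usual ggml-common.h constants QK_K = 256 and K_SCALE_SIZE = 12 and a 2-byte lm_ggml_half (illustration only, not part of the diff):

#include <cstdio>

int main() {
    const int QK_K = 256, K_SCALE_SIZE = 12, half_size = 2;
    // block_q4_Kx8: 8 d halves + 8 dmin halves, 8 x 12 packed scale/min bytes, 8 x 128 nibble bytes
    const int q4_Kx8 = half_size * 16 + K_SCALE_SIZE * 8 + QK_K * 4;   // 32 + 96 + 1024 = 1152
    // block_q8_Kx4: 4 float deltas, 4 x 256 int8 quants, 4 x 16 int16 bsums
    const int q8_Kx4 = 4 * 4 + QK_K * 4 + (QK_K / 4) * 2;              // 16 + 1024 + 128 = 1168
    printf("block_q4_Kx8 = %d bytes, block_q8_Kx4 = %d bytes\n", q4_Kx8, q8_Kx4);
    return 0;
}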
@@ -60,6 +78,13 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(lm_ggml_half) + QK4_NL * 2, "
 
  #define UNUSED LM_GGML_UNUSED
 
+ static inline int nearest_int(float fval) {
+     assert(fabsf(fval) <= 4194303.f);
+     float val = fval + 12582912.f;
+     int i; memcpy(&i, &val, sizeof(int));
+     return (i & 0x007fffff) - 0x00400000;
+ }
+
  // Functions to create the interleaved data layout formats
 
  // interleave 4 block_q4_0s in blocks of blck_size_interleave
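nearest_int() above is the usual float-bias rounding trick: adding 12582912.0f (1.5 * 2^23) forces the rounded integer into the low 23 mantissa bits, offset by 0x00400000, so a mask and a subtraction recover round-to-nearest for |fval| <= 2^22 - 1. A self-contained check of that behaviour, assuming the default round-to-nearest-even FP mode (illustration only, not part of the diff):

#include <cassert>
#include <cmath>
#include <cstring>

// Same bit trick as nearest_int() in the hunk above.
static int nearest_int_sketch(float fval) {
    assert(fabsf(fval) <= 4194303.f);       // 4194303 = 2^22 - 1
    float val = fval + 12582912.f;          // 1.5 * 2^23 pushes the integer into the mantissa
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;   // low 23 bits minus the 0x00400000 bias
}

int main() {
    // Quarter steps are exactly representable, so this agrees with lrintf() (ties to even).
    for (float f = -1000.0f; f <= 1000.0f; f += 0.25f) {
        assert(nearest_int_sketch(f) == (int) lrintf(f));
    }
    return 0;
}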
@@ -225,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 
  static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
- static void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
+ static void lm_ggml_quantize_mat_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
      assert(QK8_0 == 32);
      assert(k % QK8_0 == 0);
      const int nb = k / QK8_0;
@@ -319,7 +344,7 @@ static void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_R
  #endif
  }
 
- static void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
+ static void lm_ggml_quantize_mat_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
      assert(QK8_0 == 32);
      assert(k % QK8_0 == 0);
      const int nb = k / QK8_0;
@@ -534,16 +559,289 @@ static void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_R
534
559
  #endif
535
560
  }
536
561
 
537
- static void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
562
+ static void lm_ggml_quantize_mat_q8_K_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
563
+ assert(QK_K == 256);
564
+ assert(k % QK_K == 0);
565
+ const int nb = k / QK_K;
566
+
567
+ block_q8_Kx4 * LM_GGML_RESTRICT y = (block_q8_Kx4 *) vy;
568
+
569
+ #if defined(__AVX2__)
570
+ float iscale[4];
571
+ __m256 srcv[4][32];
572
+ __m256 iscale_vec[4];
573
+
574
+ for (int i = 0; i < nb; i++) {
575
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
576
+ // Load elements into 4 AVX vectors
577
+ __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
578
+ __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
579
+ __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
580
+ __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
581
+
582
+ // Compute max(abs(e)) for the block
583
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
584
+ __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
585
+ __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
586
+ __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
587
+ __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
588
+
589
+ __m256 maxAbs = _mm256_max_ps( abs0, abs1 );
590
+ maxAbs = _mm256_max_ps( maxAbs, abs2 );
591
+ maxAbs = _mm256_max_ps( maxAbs, abs3 );
592
+
593
+ __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
594
+ __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
595
+ __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
596
+ __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
597
+
598
+ __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
599
+
600
+ srcv[row_iter][0] = v0;
601
+ srcv[row_iter][1] = v1;
602
+ srcv[row_iter][2] = v2;
603
+ srcv[row_iter][3] = v3;
604
+
605
+ for (int sb = 1; sb < 8; sb++) {
606
+ // Temporarily stores absolute quant values
607
+ __m256 tempAbs = maxAbs;
608
+
609
+ // Load elements into 4 AVX vectors
610
+ __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
611
+ __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
612
+ __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
613
+ __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
614
+
615
+ // Compute max(abs(e)) for the block
616
+ __m256 abs0 = _mm256_andnot_ps( signBit, v0 );
617
+ __m256 abs1 = _mm256_andnot_ps( signBit, v1 );
618
+ __m256 abs2 = _mm256_andnot_ps( signBit, v2 );
619
+ __m256 abs3 = _mm256_andnot_ps( signBit, v3 );
620
+
621
+ maxAbs = _mm256_max_ps( maxAbs, abs0 );
622
+ maxAbs = _mm256_max_ps( maxAbs, abs1 );
623
+ maxAbs = _mm256_max_ps( maxAbs, abs2 );
624
+ maxAbs = _mm256_max_ps( maxAbs, abs3 );
625
+
626
+ __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
627
+ maskAbs = _mm256_and_ps( maskAbs, mask_prev );
628
+
629
+ mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
630
+ mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
631
+ mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
632
+ mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
633
+
634
+ __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
635
+ maskAbs = _mm256_or_ps(maskAbs, mask_curr);
636
+
637
+ srcv[row_iter][sb * 4] = v0;
638
+ srcv[row_iter][sb * 4 + 1] = v1;
639
+ srcv[row_iter][sb * 4 + 2] = v2;
640
+ srcv[row_iter][sb * 4 + 3] = v3;
641
+ }
642
+
643
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
644
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
645
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
646
+ const float maxScalar = _mm_cvtss_f32( max4 );
647
+
648
+ __m256 maxScalarVec = _mm256_set1_ps(maxScalar);
649
+
650
+ __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
651
+ __m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
652
+
653
+ const int mask = _mm256_movemask_ps(finalMask);
654
+ iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
655
+
656
+ if(mask) {
657
+ iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
658
+ }
659
+
660
+ y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
661
+ iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
662
+ }
663
+
664
+ __m256i quants_interleaved[32];
665
+ for (int j = 0; j < 32; j++) {
666
+ // Apply the multiplier
667
+ __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
668
+ __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
669
+ __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
670
+ __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
671
+
672
+ // Round to nearest integer
673
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
674
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
675
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
676
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
677
+
678
+ // Convert floats to integers
679
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
680
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
681
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
682
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
683
+
684
+ // Convert int32 to int16
685
+ i0 = _mm256_packs_epi32( i0, i1 );
686
+ i2 = _mm256_packs_epi32( i2, i3 );
687
+ // Convert int16 to int8
688
+ i0 = _mm256_packs_epi16( i0, i2 );
689
+
690
+ // Permute and store the quantized weights in the required order after the pack instruction
691
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
692
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
693
+
694
+ _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
695
+ quants_interleaved[j] = i0;
696
+ }
697
+
698
+ // Masks to shuffle the quants of corresponding sub blocks for rearranging quants for vectorized bsums computation
699
+ __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
700
+ shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
701
+ __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
702
+ shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
703
+ __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
704
+ shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
705
+
706
+ for (int k = 0; k < 4; k++) {
707
+ // Quants from four different sub blocks are taken
708
+ __m256i q0 = quants_interleaved[k * 8 + 0];
709
+ __m256i q1 = quants_interleaved[k * 8 + 1];
710
+ __m256i q2 = quants_interleaved[k * 8 + 2];
711
+ __m256i q3 = quants_interleaved[k * 8 + 3];
712
+ __m256i q4 = quants_interleaved[k * 8 + 4];
713
+ __m256i q5 = quants_interleaved[k * 8 + 5];
714
+ __m256i q6 = quants_interleaved[k * 8 + 6];
715
+ __m256i q7 = quants_interleaved[k * 8 + 7];
716
+
717
+
718
+ // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
719
+ __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
720
+ __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
721
+ __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
722
+ sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
723
+ __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
724
+ sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
725
+
726
+ __m256i one = _mm256_set1_epi8(1);
727
+ __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
728
+
729
+ for (int l = 0; l < 3; l++) {
730
+ // Quants value shifted to process next two values from each sub block
731
+ q0 = _mm256_srli_epi64(q0, 16);
732
+ q2 = _mm256_srli_epi64(q2, 16);
733
+ q4 = _mm256_srli_epi64(q4, 16);
734
+ q6 = _mm256_srli_epi64(q6, 16);
735
+
736
+ sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
737
+ sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
738
+ sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
739
+ sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
740
+ sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
741
+ sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
742
+
743
+ bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
744
+ }
745
+
746
+ // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
747
+ __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
748
+ __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
749
+ __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
750
+ sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
751
+ __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
752
+ sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
753
+
754
+ __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
755
+
756
+ for (int l = 0; l < 3; l++) {
757
+ // Quants value shifted to process next two values from each sub block
758
+ q1 = _mm256_srli_epi64(q1, 16);
759
+ q3 = _mm256_srli_epi64(q3, 16);
760
+ q5 = _mm256_srli_epi64(q5, 16);
761
+ q7 = _mm256_srli_epi64(q7, 16);
762
+
763
+ sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
764
+ sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
765
+ sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
766
+ sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
767
+ sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
768
+ sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
769
+
770
+ bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
771
+ }
772
+
773
+ // Overall bsums in interleaved fashion computed by adding results of both halves
774
+ __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
775
+ _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
776
+ }
777
+ }
778
+
779
+ #else
780
+
781
+ // scalar
782
+ const int blck_size_interleave = 8;
783
+ float srcv[4][QK_K];
784
+ float iscale[4];
785
+
786
+ for (int i = 0; i < nb; i++) {
787
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
788
+ float amax = 0.0f; // absolute max
789
+ float max = 0;
790
+
791
+ for (int j = 0; j < QK_K; j++) {
792
+ srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
793
+ // Update the maximum value of the corresponding super block
794
+ if(amax < fabsf(srcv[row_iter][j])) {
795
+ amax = fabsf(srcv[row_iter][j]);
796
+ max = srcv[row_iter][j];
797
+ }
798
+ }
799
+
800
+ iscale[row_iter] = amax ? -127.f/max : 0;
801
+
802
+ y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
803
+ }
804
+
805
+ for (int j = 0; j < QK_K / 4; j++) {
806
+ y[i].bsums[j] = 0;
807
+ }
808
+
809
+ // Quants values are interleaved in sequence of eight bytes from corresponding super blocks
810
+ // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
811
+ // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
812
+ for (int j = 0; j < QK_K * 4; j++) {
813
+ int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
814
+ int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
815
+ src_offset += (j % blck_size_interleave);
816
+ int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
817
+
818
+ float x0 = srcv[src_id][src_offset] * iscale[src_id];
819
+ y[i].qs[j] = nearest_int(x0);
820
+ y[i].bsums[index] += y[i].qs[j];
821
+ }
822
+ }
823
+ #endif
824
+ }
+
+ template <int64_t INTER_SIZE, lm_ggml_type PARAM_TYPE>
+ void lm_ggml_quantize_mat_t(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
+
+ template <> void lm_ggml_quantize_mat_t<4, LM_GGML_TYPE_Q8_0>(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
      assert(nrow == 4);
      UNUSED(nrow);
-     if (blck_size_interleave == 4) {
-         quantize_q8_0_4x4(x, vy, n_per_row);
-     } else if (blck_size_interleave == 8) {
-         quantize_q8_0_4x8(x, vy, n_per_row);
-     } else {
-         assert(false);
-     }
+     lm_ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
+ }
+
+ template <> void lm_ggml_quantize_mat_t<8, LM_GGML_TYPE_Q8_0>(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+     assert(nrow == 4);
+     UNUSED(nrow);
+     lm_ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
+ }
+
+ template <> void lm_ggml_quantize_mat_t<8, LM_GGML_TYPE_Q8_K>(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+     assert(nrow == 4);
+     UNUSED(nrow);
+     lm_ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
  }
 
  static void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
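The tail of this hunk also replaces the old quantize_mat_q8_0(), which branched at run time on blck_size_interleave, with compile-time specializations of lm_ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>. Callers now pick the kernel through the template arguments; a usage sketch with placeholder names (src, dst, n_per_row), not taken from the diff:

// Quantize 4 rows of activations (nrow must be 4, as asserted in the specializations).
static void quantize_activations_sketch(const float * src, void * dst, int64_t n_per_row) {
    // 8-wide q8_K layout (block_q8_Kx4), consumed by the new lm_ggml_gemm_q4_K_8x8_q8_K kernel further down:
    lm_ggml_quantize_mat_t<8, LM_GGML_TYPE_Q8_K>(src, dst, 4, n_per_row);
    // The q8_0 paths keep their 4-wide and 8-wide variants:
    // lm_ggml_quantize_mat_t<4, LM_GGML_TYPE_Q8_0>(src, dst, 4, n_per_row);
    // lm_ggml_quantize_mat_t<8, LM_GGML_TYPE_Q8_0>(src, dst, 4, n_per_row);
}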
@@ -994,6 +1292,281 @@ static void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t
994
1292
  }
995
1293
  }
996
1294
 
1295
+ static void lm_ggml_gemv_q4_K_8x8_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
1296
+ const int qk = QK_K;
1297
+ const int nb = n / qk;
1298
+ const int ncols_interleaved = 8;
1299
+ const int blocklen = 8;
1300
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1301
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1302
+ static const uint32_t kmask3 = 0x03030303;
1303
+
1304
+ assert (n % qk == 0);
1305
+ assert (nc % ncols_interleaved == 0);
1306
+
1307
+ UNUSED(s);
1308
+ UNUSED(bs);
1309
+ UNUSED(vx);
1310
+ UNUSED(vy);
1311
+ UNUSED(nr);
1312
+ UNUSED(nc);
1313
+ UNUSED(nb);
1314
+ UNUSED(ncols_interleaved);
1315
+ UNUSED(blocklen);
1316
+
1317
+ #if defined(__AVX2__)
1318
+ // Lookup table to convert signed nibbles to signed bytes
1319
+ __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
1320
+ signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
1321
+ // Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
1322
+ __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
1323
+ __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
1324
+ // Permute mask used for easier vector processing at later stages
1325
+ __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
1326
+
1327
+ // Mask to extract nibbles from bytes
1328
+ const __m256i m4b = _mm256_set1_epi8(0x0F);
1329
+
1330
+ int64_t b_nb = n / QK_K;
1331
+
1332
+ const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;
1333
+ const block_q8_K * a_ptr_start = (const block_q8_K *)vy;
1334
+
1335
+ // Process Q8_K blocks one by one
1336
+ for (int64_t y = 0; y < nr; y++) {
1337
+
1338
+ // Pointers to LHS blocks of block_q8_K format
1339
+ const block_q8_K * a_ptr = a_ptr_start + (y * nb);
1340
+
1341
+ // Take group of eight interleaved block_q4_K structures at each pass of the loop and perform dot product operation
1342
+ for (int64_t x = 0; x < nc / 8; x++) {
1343
+
1344
+ // Pointers to RHS blocks
1345
+ const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
1346
+
1347
+ // Master FP accumulators
1348
+ __m256 acc_row = _mm256_setzero_ps();
1349
+ __m256 acc_min_rows = _mm256_setzero_ps();
1350
+
1351
+ for (int64_t b = 0; b < nb; b++) {
1352
+
1353
+ // Load and convert to FP32 scale from block_q8_K
1354
+ const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));
1355
+
1356
+ // Load the scale values for the 8 blocks interleaved in block_q4_Kx8
1357
+ // col_scale_f32 rearranged so as to multiply with appropriate quants
1358
+ const __m256 col_scale_f32 = LM_GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
1359
+ const __m256 col_dmin_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].dmin);
1360
+
1361
+ __m256i iacc_b = _mm256_setzero_si256();
1362
+ __m256i iacc_min_b = _mm256_setzero_si256();
1363
+
1364
+ const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
1365
+ __m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
1366
+ q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
1367
+
1368
+ // Processes two sub blocks from each Q4_K in each iteration
1369
+ for (int sb = 0; sb < QK_K / 64; sb++) {
1370
+
1371
+ // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
1372
+ const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
1373
+ const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
1374
+ const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
1375
+ const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
1376
+ const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
1377
+ const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
1378
+ const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
1379
+ const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
1380
+
1381
+ // 4-bit -> 8-bit
1382
+ // Values of the first sub block of eight block_q4_K structures for the sb loop
1383
+ const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
1384
+ const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
1385
+ const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
1386
+ const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
1387
+ const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
1388
+ const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
1389
+ const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
1390
+ const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
1391
+
1392
+ // Values of the second sub block of eight block_q4_K structures when sb = 1
1393
+ const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
1394
+ const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
1395
+ const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
1396
+ const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
1397
+ const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
1398
+ const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
1399
+ const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
1400
+ const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
1401
+
1402
+ uint32_t utmp_0[4], utmp_1[4];
1403
+
1404
+ // Scales and Mins of corresponding sub blocks from different Q8_K structures are stored together
1405
+ // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
1406
+ memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
1407
+ utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
1408
+ const uint32_t uaux_0 = utmp_0[1] & kmask1;
1409
+ utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
1410
+ utmp_0[2] = uaux_0;
1411
+ utmp_0[0] &= kmask1;
1412
+
1413
+ // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
1414
+ memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
1415
+ utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
1416
+ const uint32_t uaux_1 = utmp_1[1] & kmask1;
1417
+ utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
1418
+ utmp_1[2] = uaux_1;
1419
+ utmp_1[0] &= kmask1;
1420
+
1421
+ // Scales of first sub block in the sb loop
1422
+ const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
1423
+ __m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
1424
+ __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
1425
+
1426
+ // Scales of second sub block in the sb loop
1427
+ __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
1428
+ __m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
1429
+ __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
1430
+
1431
+ // Mins of first and second sub block of Q4_K block are arranged side by side
1432
+ __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
1433
+
1434
+ // Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
1435
+ __m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
1436
+ __m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
1437
+ __m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
1438
+ __m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
1439
+
1440
+ lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
1441
+ lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
1442
+ lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
1443
+ lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
1444
+
1445
+ // Dot product done within 32 bit lanes and accumulated in the same vector
1446
+ // First done for the first sub block and then for the second sub block in each sb
1447
+ // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
1448
+ // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
1449
+ // ...........................................................................
1450
+ // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
1451
+
1452
+
1453
+ __m256i iacc_0 = _mm256_setzero_si256();
1454
+ __m256i iacc_1 = _mm256_setzero_si256();
1455
+
1456
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
1457
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
1458
+
1459
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
1460
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
1461
+
1462
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
1463
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
1464
+
1465
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
1466
+ iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
1467
+
1468
+ iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
1469
+
1470
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
1471
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
1472
+
1473
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
1474
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
1475
+
1476
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
1477
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
1478
+
1479
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
1480
+ iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
1481
+
1482
+ iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
1483
+
1484
+ // Accumulate the iacc value for one sb
1485
+ __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
1486
+
1487
+ // Broadcast the bsums of the two sub blocks of the iteration of Q8_K across the vector
1488
+ // Multiply-Add with corresponding mins of Q4_Kx8 with bsums
1489
+ __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
1490
+ __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
1491
+ q8s = _mm256_bsrli_epi128(q8s, 4);
1492
+
1493
+ // Accumulate for the complete block
1494
+ iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
1495
+ iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
1496
+ }
1497
+
1498
+ // Multiply-Add with scale values for the complete super block
1499
+ acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
1500
+ acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
1501
+
1502
+ }
1503
+
1504
+ // Accumulated output values permuted so as to be stored in appropriate order post accumulation
1505
+ acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
1506
+ _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
1507
+ }
1508
+ }
1509
+
1510
+ #else
1511
+
1512
+ float sumf[8];
1513
+ float sum_minf[8];
1514
+ uint32_t utmp[32];
1515
+ int sumi1;
1516
+ int sumi2;
1517
+ int sumi;
1518
+
1519
+ const block_q8_K * a_ptr = (const block_q8_K *) vy;
1520
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
1521
+ const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
1522
+
1523
+ for (int j = 0; j < ncols_interleaved; j++) {
1524
+ sumf[j] = 0.0;
1525
+ sum_minf[j] = 0.0;
1526
+ }
1527
+ for (int l = 0; l < nb; l++) {
1528
+ for (int sb = 0; sb < 8; sb++) {
1529
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
1530
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
1531
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
1532
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
1533
+ utmp[sb * 4 + 2] = uaux_0;
1534
+ utmp[sb * 4 + 0] &= kmask1;
1535
+ }
1536
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
1537
+ uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
1538
+ uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
1539
+ for (int j = 0; j < ncols_interleaved; j++) {
1540
+ sumi1 = 0;
1541
+ sumi2 = 0;
1542
+ sumi = 0;
1543
+ for (int i = 0; i < blocklen; ++i) {
1544
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
1545
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
1546
+ sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
1547
+ sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
1548
+ sumi1 = sumi1 * scales_0[j];
1549
+ sumi2 = sumi2 * scales_1[j];
1550
+ sumi += sumi1 + sumi2;
1551
+ }
1552
+ sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
1553
+ }
1554
+ }
1555
+ for (int sb = 0; sb < 8; sb++) {
1556
+ uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
1557
+ for (int j = 0; j < ncols_interleaved; j++) {
1558
+ sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * LM_GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
1559
+ }
1560
+ }
1561
+ }
1562
+ for (int j = 0; j < ncols_interleaved; j++) {
1563
+ s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
1564
+ }
1565
+ }
1566
+ #endif
1567
+ }
1568
+
1569
+
997
1570
  static void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
998
1571
  const int qk = QK8_0;
999
1572
  const int nb = n / qk;
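The scalar fallback at the end of lm_ggml_gemv_q4_K_8x8_q8_K above spells out what the AVX2 path computes for each group of eight interleaved columns: the per-sub-block integer dot products are scaled by the 6-bit sub-block scales and the super-block delta, and a mins correction built from the Q8_K bsums is subtracted. Restated compactly (an illustration of the scalar loops above, not new code):

// For each output column j (0..7) of an interleaved group, with activation row a:
//   s[j] =  sum over super-blocks l, sub-blocks sb:
//              a[l].d * FP32(d[l][j])    * scale[l][sb][j] * (sum_q q4[l][sb][j][q] * q8[l][sb][q])
//         - sum over super-blocks l, sub-blocks sb:
//              a[l].d * FP32(dmin[l][j]) * min[l][sb][j]   * (bsums[l][2*sb] + bsums[l][2*sb+1])
// i.e. exactly the sumf[j] - sum_minf[j] accumulated by the scalar code.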
@@ -3480,6 +4053,781 @@ static void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t
3480
4053
  }
3481
4054
  }
3482
4055
 
4056
+ static void lm_ggml_gemm_q4_K_8x8_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
4057
+ const int qk = QK_K;
4058
+ const int nb = n / qk;
4059
+ const int ncols_interleaved = 8;
4060
+ const int blocklen = 8;
4061
+ static const uint32_t kmask1 = 0x3f3f3f3f;
4062
+ static const uint32_t kmask2 = 0x0f0f0f0f;
4063
+ static const uint32_t kmask3 = 0x03030303;
4064
+
4065
+ assert (n % qk == 0);
4066
+ assert (nr % 4 == 0);
4067
+ assert (nc % ncols_interleaved == 0);
4068
+
4069
+ UNUSED(s);
4070
+ UNUSED(bs);
4071
+ UNUSED(vx);
4072
+ UNUSED(vy);
4073
+ UNUSED(nr);
4074
+ UNUSED(nc);
4075
+ UNUSED(nb);
4076
+ UNUSED(ncols_interleaved);
4077
+ UNUSED(blocklen);
4078
+
4079
+ #if defined(__AVX2__)
4080
+ const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx;
4081
+ const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;
4082
+ int64_t b_nb = n / QK_K;
4083
+ int64_t y = 0;
4084
+
4085
+ // Mask to mask out nibbles from packed bytes
4086
+ const __m256i m4b = _mm256_set1_epi8(0x0F);
4087
+ // Permute mask used for easier vector processing at later stages
4088
+ __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
4089
+
4090
+ int anr = nr - nr % 16;; // Used to align nr with boundary of 16
4091
+ // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
4092
+ for (; y < anr / 4; y += 4) {
4093
+
4094
+ const block_q8_Kx4 * a_ptrs[4];
4095
+
4096
+ a_ptrs[0] = a_ptr_start + (y * nb);
4097
+ for (int i = 0; i < 3; ++i) {
4098
+ a_ptrs[i + 1] = a_ptrs[i] + nb;
4099
+ }
4100
+
4101
+ // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
4102
+ for (int64_t x = 0; x < nc / 8; x++) {
4103
+
4104
+ const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
4105
+
4106
+ // Master FP accumulators
4107
+ __m256 acc_rows[16];
4108
+ for (int i = 0; i < 16; i++) {
4109
+ acc_rows[i] = _mm256_setzero_ps();
4110
+ }
4111
+
4112
+ __m256 acc_min_rows[16];
4113
+ for (int i = 0; i < 16; i++) {
4114
+ acc_min_rows[i] = _mm256_setzero_ps();
4115
+ }
4116
+
4117
+ // For super block
4118
+ for (int64_t b = 0; b < nb; b++) {
4119
+
4120
+ // Scale values - Load the eight scale values of block_q4_kx8
4121
+ const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d);
4122
+
4123
+ // dmin values - Load the eight dmin values of block_q4_kx8
4124
+ const __m256 col_dmin_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].dmin);
4125
+
4126
+ // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
4127
+ for (int sb = 0; sb < QK_K / 64; sb++) {
4128
+
4129
+ // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
4130
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
4131
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
4132
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
4133
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
4134
+ const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
4135
+ const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
4136
+ const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
4137
+ const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
4138
+
4139
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
4140
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
4141
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
4142
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
4143
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
4144
+ const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
4145
+ const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
4146
+ const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
4147
+ const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
4148
+
4149
+ // 4-bit -> 8-bit
4150
+ // First sub block of the two sub blocks processed in the iteration
4151
+ const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
4152
+ const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
4153
+
4154
+ const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
4155
+ const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
4156
+
4157
+ const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
4158
+ const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
4159
+
4160
+ const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
4161
+ const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
4162
+
4163
+ // Second sub block of the two sub blocks processed in the iteration
4164
+ const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
4165
+ const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
4166
+
4167
+ const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
4168
+ const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
4169
+
4170
+ const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
4171
+ const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
4172
+
4173
+ const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
4174
+ const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
4175
+
4176
+ // Shuffle pattern one - right side input
4177
+ const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
4178
+ const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
4179
+
4180
+ const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
4181
+ const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
4182
+
4183
+ const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
4184
+ const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
4185
+
4186
+ const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
4187
+ const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
4188
+
4189
+ const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
4190
+ const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
4191
+
4192
+ const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
4193
+ const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
4194
+
4195
+ const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
4196
+ const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
4197
+
4198
+ const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
4199
+ const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
4200
+
4201
+
4202
+ // Shuffle pattern two - right side input
4203
+ const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
4204
+ const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
4205
+
4206
+ const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
4207
+ const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
4208
+
4209
+ const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
4210
+ const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
4211
+
4212
+ const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
4213
+ const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
4214
+
4215
+ const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
4216
+ const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
4217
+
4218
+ const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
4219
+ const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
4220
+
4221
+ const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
4222
+ const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
4223
+
4224
+ const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
4225
+ const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
4226
+
4227
+ uint32_t utmp_0[4], utmp_1[4];
4228
+
4229
+ // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
4230
+ // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
4231
+ memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
4232
+ utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
4233
+ const uint32_t uaux_0 = utmp_0[1] & kmask1;
4234
+ utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
4235
+ utmp_0[2] = uaux_0;
4236
+ utmp_0[0] &= kmask1;
4237
+
4238
+ // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
4239
+ memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
4240
+ utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
4241
+ const uint32_t uaux_1 = utmp_1[1] & kmask1;
4242
+ utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
4243
+ utmp_1[2] = uaux_1;
4244
+ utmp_1[0] &= kmask1;
4245
+
4246
+ // Scales of first sub block in the sb loop
4247
+ const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
4248
+ const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
4249
+
4250
+ // Scales of second sub block in the sb loop
4251
+ const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
4252
+ const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
4253
+
4254
+ // Mins of first and second sub block of Q4_K block are arranged side by side
4255
+ const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
4256
+
4257
+ const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
4258
+ const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
4259
+
4260
+ const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
4261
+ const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
4262
+
4263
+ for (int rp = 0; rp < 4; rp++) {
4264
+
4265
+ // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
4266
+ // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
4267
+ __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
4268
+ __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
4269
+ __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
4270
+ __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
4271
+ __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
4272
+ __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
4273
+ __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
4274
+ __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
4275
+ __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
4276
+ __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
4277
+ __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
4278
+ __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
4279
+ __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
4280
+ __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
4281
+ __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
4282
+ __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
4283
+ __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
4284
+ __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
4285
+ __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
4286
+ __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
4287
+ __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
4288
+ __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
4289
+ __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
4290
+ __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
4291
+
4292
+ // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
4293
+ __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
4294
+ __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
4295
+ lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
4296
+
4297
+ // Shuffle pattern one - left side input
4298
+ const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
4299
+ const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
4300
+
4301
+ const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
4302
+ const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
4303
+
4304
+ const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
4305
+ const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160); //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
4306
+
4307
+ const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
4308
+ const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
4309
+
4310
+ const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
4311
+ const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
4312
+
4313
+ const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
4314
+ const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
4315
+
4316
+ const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
4317
+ const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160); //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
4318
+
4319
+ const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
4320
+ const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
4321
+
4322
+ // Shuffle pattern two - left side input
4323
+ const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
4324
+ const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
4325
+
4326
+ const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
4327
+ const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
4328
+
4329
+ const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
4330
+ const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
4331
+
4332
+ const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
4333
+ const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
4334
+
4335
+ const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
4336
+ const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
4337
+
4338
+ const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
4339
+ const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
4340
+
4341
+ const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
4342
+ const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
4343
+
4344
+ const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
4345
+ const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
4346
+
4347
+ // The shuffled values are combined with dot-product style operations: within each 32 bit lane the corresponding bytes are multiplied and their products accumulated, and the partial sums of each sub block are added together
4348
+ __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
4349
+ __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
4350
+ __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
4351
+ __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
4352
+ __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
4353
+ __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
4354
+ __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
4355
+ __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
4356
+
4357
+ __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
4358
+ __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
4359
+ __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
4360
+ __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
4361
+ __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
4362
+ __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
4363
+ __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
4364
+ __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
4365
+
4366
+ // The outputs of both shuffle patterns are added together so the dot products cover all 32 values of each sub block
4367
+ __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
4368
+ __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
4369
+ __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
4370
+ __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
4371
+
4372
+ __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
4373
+ __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
4374
+ __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
4375
+ __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
4376
+
4377
+ // Multiply the 16 bit dot product accumulators with the corresponding sub block scales (madd pairs the products into 32 bit sums)
4378
+ iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
4379
+ iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
4380
+ iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
4381
+ iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
4382
+
4383
+ iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
4384
+ iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
4385
+ iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
4386
+ iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
4387
+
4388
+ // Straighten out to make 4 row vectors per sub block (the two sub blocks' row vectors are accumulated together in the next step)
4389
+ __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
4390
+ __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
4391
+ __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
4392
+ __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
4393
+ __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
4394
+ __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
4395
+ __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
4396
+ __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
4397
+
4398
+ __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
4399
+ __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
4400
+ __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
4401
+ __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
4402
+
4403
+ // Load the scale values (d) of the four Q8_K blocks and repeat them across lanes
4404
+ const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
4405
+ const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//LM_GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
4406
+
4407
+ // Multiply with the appropriate scales and accumulate (for both d and dmin) below
4408
+ acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
4409
+ acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
4410
+ acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
4411
+ acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
4412
+
4413
+ __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
4414
+ __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
4415
+ __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
4416
+ __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
4417
+
4418
+ acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
4419
+ acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
4420
+ acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
4421
+ acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
4422
+
4423
+ }
4424
+ }
4425
+ }
4426
+ // Store the accumulated values
4427
+ for (int i = 0; i < 16; i++) {
4428
+ _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
4429
+ }
4430
+ }
4431
+ }
4432
+ for (; y < nr / 4; y++) {
4433
+
4434
+ const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
4435
+
4436
+ for (int64_t x = 0; x < nc / 8; x++) {
4437
+
4438
+ const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
4439
+
4440
+ // Master FP accumulators
4441
+ __m256 acc_rows[4];
4442
+ for (int i = 0; i < 4; i++) {
4443
+ acc_rows[i] = _mm256_setzero_ps();
4444
+ }
4445
+
4446
+ __m256 acc_min_rows[4];
4447
+ for (int i = 0; i < 4; i++) {
4448
+ acc_min_rows[i] = _mm256_setzero_ps();
4449
+ }
4450
+
4451
+ for (int64_t b = 0; b < nb; b++) {
4452
+
4453
+ // Scale values - Load the eight scale values of block_q4_Kx8
4454
+ const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d);
4455
+
4456
+ // dmin values - Load the eight dmin values of block_q4_Kx8
4457
+ const __m256 col_dmin_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].dmin);
4458
+
4459
+ // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
4460
+ for (int sb = 0; sb < QK_K / 64; sb++) {
4461
+
4462
+ // Load the quantized values of the eight block_q4_K structures for two sub blocks, interleaved with each other in chunks of eight bytes - B0,B1 .... B6,B7
4463
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
4464
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
4465
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
4466
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
4467
+ const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
4468
+ const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
4469
+ const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
4470
+ const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
4471
+
4472
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
4473
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
4474
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
4475
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
4476
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
4477
+ const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
4478
+ const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
4479
+ const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
4480
+ const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
4481
+
4482
+ // 4-bit -> 8-bit
4483
+ // First sub block of the two sub blocks processed in the iteration
4484
+ const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
4485
+ const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
4486
+
4487
+ const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
4488
+ const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
4489
+
4490
+ const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
4491
+ const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
4492
+
4493
+ const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
4494
+ const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
4495
+
4496
+ // Second sub block of the two sub blocks processed in the iteration
4497
+ const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
4498
+ const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
4499
+
4500
+ const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
4501
+ const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
4502
+
4503
+ const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
4504
+ const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
4505
+
4506
+ const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
4507
+ const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
4508
+
4509
+ // Shuffle pattern one - right side input
4510
+ const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
4511
+ const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
4512
+
4513
+ const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
4514
+ const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
4515
+
4516
+ const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
4517
+ const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
4518
+
4519
+ const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
4520
+ const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
4521
+
4522
+ const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
4523
+ const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
4524
+
4525
+ const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
4526
+ const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
4527
+
4528
+ const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
4529
+ const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
4530
+
4531
+ const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
4532
+ const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
4533
+
4534
+ // Shuffle pattern two - right side input
4535
+ const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
4536
+ const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
4537
+
4538
+ const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
4539
+ const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
4540
+
4541
+ const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
4542
+ const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
4543
+
4544
+ const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
4545
+ const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
4546
+
4547
+ const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
4548
+ const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
4549
+
4550
+ const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
4551
+ const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
4552
+
4553
+ const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
4554
+ const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
4555
+
4556
+ const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
4557
+ const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
4558
+
4559
+ uint32_t utmp_0[4], utmp_1[4];
4560
+
4561
+ // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
4562
+ // The block below extracts the scales and mins of the first sub block processed in this sb iteration from the different Q4_K structures
4563
+ memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
4564
+ utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
4565
+ const uint32_t uaux_0 = utmp_0[1] & kmask1;
4566
+ utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
4567
+ utmp_0[2] = uaux_0;
4568
+ utmp_0[0] &= kmask1;
4569
+
4570
+ // The block below extracts the scales and mins of the second sub block processed in this sb iteration from the different Q4_K structures
4571
+ memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
4572
+ utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
4573
+ const uint32_t uaux_1 = utmp_1[1] & kmask1;
4574
+ utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
4575
+ utmp_1[2] = uaux_1;
4576
+ utmp_1[0] &= kmask1;
4577
+
4578
+ // Scales of first sub block in the sb loop
4579
+ const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
4580
+ const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
4581
+
4582
+ // Scales of second sub block in the sb loop
4583
+ const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
4584
+ const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
4585
+
4586
+ // Mins of first and second sub block of Q4_K block are arranged side by side
4587
+ const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
4588
+
4589
+ const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
4590
+ const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
4591
+
4592
+ const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
4593
+ const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
4594
+
4595
+ // Load the quantized values of the four block_q8_K structures, interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
4596
+ // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
4597
+ __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
4598
+ __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
4599
+ __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
4600
+ __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
4601
+ __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
4602
+ __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
4603
+ __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
4604
+ __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
4605
+ __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
4606
+ __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
4607
+ __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
4608
+ __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
4609
+ __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
4610
+ __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
4611
+ __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
4612
+ __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
4613
+ __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
4614
+ __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
4615
+ __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
4616
+ __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
4617
+ __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
4618
+ __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
4619
+ __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
4620
+ __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
4621
+
4622
+ // Load the bsums - four bsums (covering the two sub blocks) for each of the different Q8_K blocks
4623
+ __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
4624
+ __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
4625
+ lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
4626
+
4627
+ // Shuffle pattern one - left side input
4628
+ const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
4629
+ const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
4630
+
4631
+ const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
4632
+ const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
4633
+
4634
+ const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
4635
+ const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160); //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
4636
+
4637
+ const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
4638
+ const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
4639
+
4640
+ const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
4641
+ const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
4642
+
4643
+ const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
4644
+ const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
4645
+
4646
+ const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
4647
+ const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160); //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
4648
+
4649
+ const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
4650
+ const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
4651
+
4652
+ // Shuffle pattern two - left side input
4653
+ const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
4654
+ const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
4655
+
4656
+ const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
4657
+ const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
4658
+
4659
+ const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
4660
+ const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
4661
+
4662
+ const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
4663
+ const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
4664
+
4665
+ const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
4666
+ const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
4667
+
4668
+ const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
4669
+ const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
4670
+
4671
+ const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
4672
+ const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
4673
+
4674
+ const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
4675
+ const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
4676
+
4677
+ // The shuffled values are combined with dot-product style operations: within each 32 bit lane the corresponding bytes are multiplied and their products accumulated, and the partial sums of each sub block are added together
4678
+ __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
4679
+ __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
4680
+ __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
4681
+ __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
4682
+ __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
4683
+ __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
4684
+ __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
4685
+ __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
4686
+
4687
+ __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
4688
+ __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
4689
+ __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
4690
+ __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
4691
+ __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
4692
+ __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
4693
+ __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
4694
+ __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
4695
+
4696
+ // The outputs of both shuffle patterns are added together so the dot products cover all 32 values of each sub block
4697
+ __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
4698
+ __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
4699
+ __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
4700
+ __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
4701
+
4702
+ __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
4703
+ __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
4704
+ __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
4705
+ __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
4706
+
4707
+ // Multiply the 16 bit dot product accumulators with the corresponding sub block scales (madd pairs the products into 32 bit sums)
4708
+ iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
4709
+ iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
4710
+ iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
4711
+ iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
4712
+
4713
+ iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
4714
+ iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
4715
+ iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
4716
+ iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
4717
+
4718
+ // Straighten out to make 4 row vectors per sub block (the two sub blocks' row vectors are accumulated together in the next step)
4719
+ __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
4720
+ __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
4721
+ __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
4722
+ __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
4723
+ __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
4724
+ __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
4725
+ __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
4726
+ __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
4727
+
4728
+ __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
4729
+ __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
4730
+ __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
4731
+ __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
4732
+
4733
+ // Load the scale values (d) of the four Q8_K blocks and repeat them across lanes
4734
+ const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
4735
+ const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); //LM_GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
4736
+
4737
+ // Multiply with the appropriate scales and accumulate (for both d and dmin) below
4738
+ acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
4739
+ acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
4740
+ acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
4741
+ acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
4742
+
4743
+ __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
4744
+ __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
4745
+ __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
4746
+ __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
4747
+
4748
+ acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
4749
+ acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
4750
+ acc_min_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
4751
+ acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
4752
+ }
4753
+ }
4754
+
4755
+ // Store the accumulated values
4756
+ for (int i = 0; i < 4; i++) {
4757
+ _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
4758
+ }
4759
+ }
4760
+ }
4761
+
4762
+ #else
4763
+
4764
+ float sumf[4][8];
4765
+ float sum_minf[4][8];
4766
+ uint32_t utmp[32];
4767
+ int sumi1;
4768
+ int sumi2;
4769
+ int sumi;
4770
+
4771
+ for (int y = 0; y < nr / 4; y++) {
4772
+ const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
4773
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
4774
+ const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
4775
+ for (int m = 0; m < 4; m++) {
4776
+ for (int j = 0; j < ncols_interleaved; j++) {
4777
+ sumf[m][j] = 0.0;
4778
+ sum_minf[m][j] = 0.0;
4779
+ }
4780
+ }
4781
+ for (int l = 0; l < nb; l++) {
4782
+ for (int sb = 0; sb < 8; sb++) {
4783
+ memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
4784
+ utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
4785
+ const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
4786
+ utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
4787
+ utmp[sb * 4 + 2] = uaux_0;
4788
+ utmp[sb * 4 + 0] &= kmask1;
4789
+ }
4790
+ for (int k = 0; k < (qk / (2 * blocklen)); k++) {
4791
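+ // utmp holds, for each sub block, its 8 decoded scales followed by its 8 decoded mins (16 bytes); k/4 selects the pair of sub blocks processed here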
+ uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
4792
+ uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
4793
+ for (int m = 0; m < 4; m++) {
4794
+ for (int j = 0; j < ncols_interleaved; j++) {
4795
+ sumi1 = 0;
4796
+ sumi2 = 0;
4797
+ sumi = 0;
4798
+ for (int i = 0; i < blocklen; ++i) {
4799
+ const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
4800
+ const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
4801
+ sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
4802
+ sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
4803
+ sumi1 = sumi1 * scales_0[j];
4804
+ sumi2 = sumi2 * scales_1[j];
4805
+ sumi += sumi1 + sumi2;
4806
+ }
4807
+ sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
4808
+ }
4809
+ }
4810
+ }
4811
+ for (int sb = 0; sb < 8; sb++) {
4812
+ uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
4813
+ for(int m = 0; m < 4; m++) {
4814
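+ // Q8_K stores one bsum per 16 values, so bsums[0] + bsums[1] below is the sum of the 32 activations this sub block's min applies to; the pointer offset follows block_q8_Kx4's interleaved bsums layout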
+ const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
4815
+ for(int j = 0; j < ncols_interleaved; j++) {
4816
+ sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * LM_GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
4817
+ }
4818
+ }
4819
+ }
4820
+ }
4821
+ for (int m = 0; m < 4; m++) {
4822
+ for (int j = 0; j < ncols_interleaved; j++) {
4823
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
4824
+ }
4825
+ }
4826
+ }
4827
+ }
4828
+ #endif
4829
+ }
4830
+
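For reference, the per-element math that the AVX2 kernel above vectorizes can be written as a short scalar sketch. This assumes the 6-bit scales/mins and the 4-bit quants have already been unpacked into plain arrays; the helper name is illustrative and not part of ggml:

    #include <cstdint>

    // One Q4_K super block of a B column against one Q8_K super block of an A row.
    // q4           : 256 unpacked 4-bit quants (0..15)
    // q8           : 256 int8 activation quants
    // scales6/mins6: 8 decoded 6-bit values, one per 32-value sub block
    // bsums        : 16 sums of 16 consecutive q8 values (Q8_K layout)
    static float q4k_dot_one_block(const uint8_t q4[256], const int8_t q8[256],
                                   const uint8_t scales6[8], const uint8_t mins6[8],
                                   const int16_t bsums[16],
                                   float d_b, float dmin_b, float d_a) {
        int32_t sum = 0, minsum = 0;
        for (int sb = 0; sb < 8; sb++) {                        // 8 sub blocks of 32 values
            int32_t dot = 0;
            for (int i = 0; i < 32; i++) {
                dot += (int32_t) q4[sb * 32 + i] * q8[sb * 32 + i];
            }
            sum    += scales6[sb] * dot;                        // d term, weighted by the sub block scale
            minsum += mins6[sb] * (bsums[2 * sb] + bsums[2 * sb + 1]); // dmin term, from the activation sums
        }
        return d_b * d_a * (float) sum - dmin_b * d_a * (float) minsum;
    }

This mirrors the two accumulators in the kernel: acc_rows collects the d term, acc_min_rows collects the dmin term, and the final store subtracts the two.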
3483
4831
  static void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
3484
4832
  const int qk = QK8_0;
3485
4833
  const int nb = n / qk;
@@ -3660,6 +5008,82 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
3660
5008
  return out;
3661
5009
  }
3662
5010
 
5011
+ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
5012
+ block_q4_Kx8 out;
5013
+ // The delta (scale) and dmin values of the eight Q4_K structures are copied into the interleaved output structure
5014
+ for (int i = 0; i < 8; i++) {
5015
+ out.d[i] = in[i].LM_GGML_COMMON_AGGR_U.LM_GGML_COMMON_AGGR_S.d;
5016
+ }
5017
+
5018
+ for (int i = 0; i < 8; i++) {
5019
+ out.dmin[i] = in[i].LM_GGML_COMMON_AGGR_U.LM_GGML_COMMON_AGGR_S.dmin;
5020
+ }
5021
+
5022
+ const int end = QK_K * 4 / blck_size_interleave;
5023
+
5024
+ // Interleave Q4_K quants by taking 8 bytes at a time
5025
+ for (int i = 0; i < end; ++i) {
5026
+ int src_id = i % 8;
5027
+ int src_offset = (i / 8) * blck_size_interleave;
5028
+ int dst_offset = i * blck_size_interleave;
5029
+
5030
+ uint64_t elems;
5031
+ memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
5032
+ memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
5033
+ }
5034
+
5035
+ // The logic below unpacks and rearranges the scales and mins of the Q4_K structures
5036
+ // Each Q4_K structure packs its 8 scales and 8 mins into 12 bytes (6 bits per value)
5037
+ // The scales field of the output Q4_Kx8 structure is 96 bytes
5038
+ // Each 12 byte group packs the scales and mins of the corresponding sub block taken from the different Q4_K structures
5039
+ // For example, the first 12 bytes hold the 8 scales and 8 mins of the first sub block of the eight Q4_K structures
5040
+ uint8_t s[8], m[8];
5041
+
5042
+ for (int i = 0; i < 4; i++) {
5043
+ for (int j = 0; j < 8; j++) {
5044
+ s[j] = in[j].scales[i] & 63;
5045
+ m[j] = in[j].scales[i + 4] & 63;
5046
+ }
5047
+
5048
+ out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
5049
+ out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
5050
+ out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
5051
+ out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
5052
+ out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
5053
+ out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
5054
+ out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
5055
+ out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
5056
+ out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
5057
+ out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
5058
+ out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
5059
+ out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
5060
+
5061
+ }
5062
+
5063
+ for (int i = 0; i < 4; i++) {
5064
+ for (int j = 0; j < 8; j++) {
5065
+ s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15);
5066
+ m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4);
5067
+ }
5068
+
5069
+ out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
5070
+ out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
5071
+ out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
5072
+ out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
5073
+ out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
5074
+ out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
5075
+ out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
5076
+ out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
5077
+ out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
5078
+ out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
5079
+ out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
5080
+ out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
5081
+
5082
+ }
5083
+
5084
+ return out;
5085
+ }
5086
+
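The 96-byte scales field written above keeps the usual Q4_K 6-bit packing, but each 12-byte group holds one sub block's scales and mins for all eight interleaved rows. A hedged decoding sketch (the helper name is illustrative, not part of ggml):

    #include <cstdint>

    // Recover the 6-bit scale and min of sub block sb (0..7) for interleaved row j (0..7).
    static void q4_Kx8_get_scale_min(const uint8_t scales[96], int sb, int j,
                                     uint8_t * scale, uint8_t * min) {
        const uint8_t * g = scales + (sb < 4 ? sb * 12 : (sb - 4) * 12 + 48);
        if (j < 4) {
            *scale = g[j]     & 63;   // bytes 0..3 : low 6 bits of the rows 0..3 scales
            *min   = g[j + 4] & 63;   // bytes 4..7 : low 6 bits of the rows 0..3 mins
        } else {
            *scale = (g[j + 4] & 0x0F) | ((g[j - 4] >> 6) << 4); // bytes 8..11 plus the spill bits
            *min   = (g[j + 4] >>   4) | ((g[j]     >> 6) << 4);
        }
    }

This is the same kmask1/kmask2/kmask3 unpacking that the GEMM kernel applies to each 12-byte group it copies into utmp_0/utmp_1.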
3663
5087
  static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
3664
5088
  LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
3665
5089
  LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
@@ -3690,6 +5114,36 @@ static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_bl
3690
5114
 
3691
5115
  LM_GGML_UNUSED(data_size);
3692
5116
  }
5117
+ static int repack_q4_K_to_q4_K_8_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
5118
+ LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_K);
5119
+ LM_GGML_ASSERT(interleave_block == 8);
5120
+ constexpr int nrows_interleaved = 8;
5121
+
5122
+ block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
5123
+ const block_q4_K * src = (const block_q4_K*) data;
5124
+ block_q4_K dst_tmp[8];
5125
+ int nrow = lm_ggml_nrows(t);
5126
+ int nblocks = t->ne[0] / QK_K;
5127
+
5128
+ LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
5129
+
5130
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
5131
+ return -1;
5132
+ }
5133
+
5134
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
5135
+ for (int64_t x = 0; x < nblocks; x++) {
5136
+ for (int i = 0; i < nrows_interleaved; i++ ) {
5137
+ dst_tmp[i] = src[x + i * nblocks];
5138
+ }
5139
+ *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
5140
+ }
5141
+ src += nrows_interleaved * nblocks;
5142
+ }
5143
+ return 0;
5144
+
5145
+ LM_GGML_UNUSED(data_size);
5146
+ }

  static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
  LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
@@ -3807,6 +5261,10 @@ template <> int repack<block_q4_0, 8, 8>(struct lm_ggml_tensor * t, const void *
  return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
  }

+ template <> int repack<block_q4_K, 8, 8>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
+ }
+
  template <> int repack<block_iq4_nl, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
  return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
  }
@@ -3817,44 +5275,50 @@ template <> int repack<block_iq4_nl, 4, 4>(struct lm_ggml_tensor * t, const void
  //}

  // gemv
- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, lm_ggml_type PARAM_TYPE>
  void gemv(int, float *, size_t, const void *, const void *, int, int);

- template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ template <> void gemv<block_q4_0, 4, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
  lm_ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
  }

- template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ template <> void gemv<block_q4_0, 8, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
  lm_ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
  }

- template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ template <> void gemv<block_q4_0, 8, 8, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
  lm_ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
  }

- template <>
- void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ template <> void gemv<block_q4_K, 8, 8, LM_GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+ }
+
+ template <> void gemv<block_iq4_nl, 4, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
  lm_ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
  }

  // gemm
- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, lm_ggml_type PARAM_TYPE>
  void gemm(int, float *, size_t, const void *, const void *, int, int);

- template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ template <> void gemm<block_q4_0, 4, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
  lm_ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
  }

- template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ template <> void gemm<block_q4_0, 8, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
  lm_ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
  }

- template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ template <> void gemm<block_q4_0, 8, 8, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
  lm_ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
  }

- template <>
- void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ template <> void gemm<block_q4_K, 8, 8, LM_GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+ }
+
+ template <> void gemm<block_iq4_nl, 4, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
  lm_ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
  }
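
The gemv/gemm declarations above gain a fourth, non-type template parameter, PARAM_TYPE, so each specialization also states which activation quantization it pairs with (LM_GGML_TYPE_Q8_0 for the Q4_0 and IQ4_NL kernels, LM_GGML_TYPE_Q8_K for the new Q4_K kernel). A minimal, self-contained illustration of this pattern with toy names (not the package's types):

    #include <cstdio>

    // Toy stand-in for lm_ggml_type: the enum value is carried as a compile-time
    // template parameter, so downstream code can depend on it with no runtime cost.
    enum toy_type { TOY_TYPE_Q8_0, TOY_TYPE_Q8_K };

    template <int NB_COLS, toy_type PARAM_TYPE>
    void toy_gemv(int n) {
        // PARAM_TYPE is a constant here, just like in the tensor_traits template:
        // row sizes, scratch sizes and kernel choice can all hinge on it.
        std::printf("gemv: n=%d cols=%d param=%s\n", n, NB_COLS,
                    PARAM_TYPE == TOY_TYPE_Q8_K ? "q8_K" : "q8_0");
    }

    int main() {
        toy_gemv<8, TOY_TYPE_Q8_0>(256);
        toy_gemv<8, TOY_TYPE_Q8_K>(256);
        return 0;
    }
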
 
@@ -3863,37 +5327,37 @@ class tensor_traits_base : public ggml::cpu::tensor_traits {
  virtual int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) = 0;
  };

- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, lm_ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {

  bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
  // not realy a LM_GGML_TYPE_Q8_0 but same size.
  switch (op->op) {
- case LM_GGML_OP_MUL_MAT:
- size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
- return true;
- case LM_GGML_OP_MUL_MAT_ID:
- size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
- size = LM_GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
- size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
- return true;
- default:
- // LM_GGML_ABORT("fatal error");
- break;
+ case LM_GGML_OP_MUL_MAT:
+ size = lm_ggml_row_size(PARAM_TYPE, lm_ggml_nelements(op->src[1]));
+ return true;
+ case LM_GGML_OP_MUL_MAT_ID:
+ size = lm_ggml_row_size(PARAM_TYPE, lm_ggml_nelements(op->src[1]));
+ size = LM_GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
+ size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+ return true;
+ default:
+ // LM_GGML_ABORT("fatal error");
+ break;
  }
  return false;
  }

  bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
  switch (op->op) {
- case LM_GGML_OP_MUL_MAT:
- forward_mul_mat(params, op);
- return true;
- case LM_GGML_OP_MUL_MAT_ID:
- forward_mul_mat_id(params, op);
- return true;
- default:
- // LM_GGML_ABORT("fatal error");
- break;
+ case LM_GGML_OP_MUL_MAT:
+ forward_mul_mat(params, op);
+ return true;
+ case LM_GGML_OP_MUL_MAT_ID:
+ forward_mul_mat_id(params, op);
+ return true;
+ default:
+ // LM_GGML_ABORT("fatal error");
+ break;
  }
  return false;
  }
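
In work_size the scratch estimate is now driven by PARAM_TYPE rather than a hard-coded LM_GGML_TYPE_Q8_0, so a Q4_K tensor reserves enough room for Q8_K-quantized activations, which have a larger per-row footprint than Q8_0. A rough standalone mirror of the two branches, with lm_ggml_row_size and LM_GGML_PAD replaced by explicit stand-ins (illustrative assumptions, not part of this diff):

    #include <cstddef>
    #include <cstdint>

    // Stand-in for lm_ggml_row_size: bytes needed for n elements of a block
    // quantized type (assumes n is a multiple of the block size).
    static size_t row_bytes(int64_t n, int64_t block_elems, size_t block_bytes) {
        return (size_t)(n / block_elems) * block_bytes;
    }

    // Stand-in for LM_GGML_PAD.
    static size_t pad_to(size_t size, size_t align) {
        return (size + align - 1) / align * align;
    }

    // MUL_MAT: the whole of src1 is re-quantized into the work buffer.
    static size_t mul_mat_work_size(int64_t src1_nelements,
                                    int64_t block_elems, size_t block_bytes) {
        return row_bytes(src1_nelements, block_elems, block_bytes);
    }

    // MUL_MAT_ID: same buffer, padded, plus sizeof(int64_t) * (1 + n_expert) * ne12
    // bytes of bookkeeping, matching the expression in the switch above.
    static size_t mul_mat_id_work_size(int64_t src1_nelements, int64_t n_expert, int64_t ne12,
                                       int64_t block_elems, size_t block_bytes) {
        size_t size = row_bytes(src1_nelements, block_elems, block_bytes);
        size = pad_to(size, sizeof(int64_t));
        size += sizeof(int64_t) * (size_t)((1 + n_expert) * ne12);
        return size;
    }
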
@@ -3925,17 +5389,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
  // LM_GGML_ASSERT(lm_ggml_n_dims(op->src[1]) == 2);

  char * wdata = static_cast<char *>(params->wdata);
- const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+ const size_t nbw1 = lm_ggml_row_size(PARAM_TYPE, ne10);

  assert(params->wsize >= nbw1 * ne11);

- const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
+ const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

  int64_t i11_processed = 0;
  for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
- quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
- INTER_SIZE);
+ lm_ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
  }
+
  i11_processed = ne11 - ne11 % 4;
  for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
  from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
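
forward_mul_mat quantizes src1 into the work buffer in two passes: groups of four rows go through the batched lm_ggml_quantize_mat_t path, and any remainder rows are converted one at a time with from_float. A small standalone sketch of how the rows get split across threads (example sizes only, not part of the package sources):

    #include <cstdio>

    // Show which thread quantizes which src1 rows, mirroring the two loops above.
    int main() {
        const int ne11 = 11; // example: 11 src1 rows
        const int nth  = 2;  // example: 2 threads

        for (int ith = 0; ith < nth; ith++) {
            // batched pass: groups of 4 rows, strided across threads
            for (int i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
                std::printf("thread %d: rows %d..%d (batched)\n", ith, i11, i11 + 3);
            }
            // remainder pass: leftover rows handed out round-robin
            for (int i11 = ne11 - ne11 % 4 + ith; i11 < ne11; i11 += nth) {
                std::printf("thread %d: row %d (single)\n", ith, i11);
            }
        }
        return 0;
    }
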
@@ -3944,26 +5408,28 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
  lm_ggml_barrier(params->threadpool);

  const void * src1_wdata = params->wdata;
- const size_t src1_col_stride = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+ const size_t src1_col_stride = lm_ggml_row_size(PARAM_TYPE, ne10);
  int64_t src0_start = (ith * ne01) / nth;
  int64_t src0_end = ((ith + 1) * ne01) / nth;
  src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
- src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
  if (src0_start >= src0_end) {
  return;
  }

  // If there are more than three rows in src1, use gemm; otherwise, use gemv.
  if (ne11 > 3) {
- gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
- (const char *) src0->data + src0_start * nb01,
- (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+ (float *) ((char *) dst->data) + src0_start, ne01,
+ (const char *) src0->data + src0_start * nb01,
+ (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
  }
  for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
- (const char *) src0->data + src0_start * nb01,
- (const char *) src1_wdata + (src1_col_stride * iter), 1,
- src0_end - src0_start);
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+ (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+ (const char *) src0->data + src0_start * nb01,
+ (const char *) src1_wdata + (src1_col_stride * iter), 1,
+ src0_end - src0_start);
  }
  }

@@ -3978,7 +5444,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
  const int ith = params->ith;
  const int nth = params->nth;

- const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
+ const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

  // we don't support permuted src0 or src1
  LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type));
@@ -4000,7 +5466,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
  const int n_ids = ids->ne[0]; // n_expert_used
  const int n_as = ne02; // n_expert

- const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+ const size_t nbw1 = lm_ggml_row_size(PARAM_TYPE, ne10);
  const size_t nbw2 = nbw1*ne11;
  const size_t nbw3 = nbw2*ne12;

@@ -4012,12 +5478,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
  LM_GGML_ASSERT(params->wsize >= (LM_GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
  n_as * ne12 * sizeof(mmid_row_mapping)));

- auto wdata = (char *) params->wdata;
- auto wdata_src1_end = (char *) wdata + LM_GGML_PAD(nbw3, sizeof(int64_t));
- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+ auto * wdata = (char *) params->wdata;
+ auto * wdata_src1_end = (char *) wdata + LM_GGML_PAD(nbw3, sizeof(int64_t));
+ auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+
  struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]

- // src1: float32 => block_q8_0
+ // src1: float32 => param type
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
  for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
  from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
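
The MUL_MAT_ID path carves its work buffer into three consecutive regions: the PARAM_TYPE-quantized copy of src1 (nbw3 bytes, padded to int64_t alignment), then one int64_t row count per expert, then the mmid_row_mapping table, which is exactly what the LM_GGML_ASSERT on wsize above checks. A rough standalone sketch of those offsets, with a stand-in mapping struct (assumed layout for illustration, not part of this diff):

    #include <cstddef>
    #include <cstdint>

    // Stand-in for the row-mapping entry used by the MUL_MAT_ID path.
    struct toy_row_mapping { int32_t i1; int32_t i2; };

    static size_t pad_to(size_t size, size_t align) {
        return (size + align - 1) / align * align;
    }

    // Offsets of the three regions carved out of wdata, mirroring the pointer
    // arithmetic above (nbw3 = quantized src1 bytes, n_as = number of experts,
    // ne12 = src1 batch dimension).
    struct toy_wdata_layout {
        size_t src1_q_offset;      // quantized src1 (PARAM_TYPE)
        size_t row_counts_offset;  // int64_t[n_as]
        size_t row_maps_offset;    // toy_row_mapping[n_as][ne12]
        size_t total;
    };

    static toy_wdata_layout make_layout(size_t nbw3, int64_t n_as, int64_t ne12) {
        toy_wdata_layout l;
        l.src1_q_offset     = 0;
        l.row_counts_offset = pad_to(nbw3, sizeof(int64_t));
        l.row_maps_offset   = l.row_counts_offset + sizeof(int64_t) * (size_t)n_as;
        l.total             = l.row_maps_offset + sizeof(toy_row_mapping) * (size_t)(n_as * ne12);
        return l;
    }
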
@@ -4056,34 +5523,37 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
  continue;
  }

- auto src0_cur = (const char *) src0->data + cur_a*nb02;
+ const auto * src0_cur = (const char *) src0->data + cur_a*nb02;

  //const int64_t nr0 = ne01; // src0 rows
  const int64_t nr1 = cne1; // src1 rows

  int64_t src0_cur_start = (ith * ne01) / nth;
  int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
- src0_cur_start =
- (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
- src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;

- if (src0_cur_start >= src0_cur_end) return;
+ src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+ src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+
+ if (src0_cur_start >= src0_cur_end) {
+ return;
+ }

  for (int ir1 = 0; ir1 < nr1; ir1++) {
  struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
- const int id = row_mapping.i1; // selected expert index

- const int64_t i11 = id % ne11;
- const int64_t i12 = row_mapping.i2; // row index in src1
+ const int id = row_mapping.i1; // selected expert index

- const int64_t i1 = id; // selected expert index
- const int64_t i2 = i12; // row
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = row_mapping.i2; // row index in src1

- auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+ const int64_t i1 = id; // selected expert index
+ const int64_t i2 = i12; // row

- gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
- ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
- ne01, src0_cur + src0_cur_start * nb01,
+ const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
+ (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
+ src0_cur + src0_cur_start * nb01,
  src1_col, 1, src0_cur_end - src0_cur_start);
  }
  }
@@ -4098,12 +5568,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
  };

  // instance for Q4
- static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
- static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
- static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
+ static const tensor_traits<block_q4_0, 4, 4, LM_GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
+ static const tensor_traits<block_q4_0, 8, 4, LM_GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
+ static const tensor_traits<block_q4_0, 8, 8, LM_GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
+ static const tensor_traits<block_q4_K, 8, 8, LM_GGML_TYPE_Q8_K> q4_K_8x8_q8_K;

  // instance for IQ4
- static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
+ static const tensor_traits<block_iq4_nl, 4, 4, LM_GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;

  } // namespace ggml::cpu::aarch64

@@ -4124,6 +5595,12 @@ static const ggml::cpu::tensor_traits * lm_ggml_aarch64_get_optimal_repack_type(
  return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
  }
  }
+ } else if (cur->type == LM_GGML_TYPE_Q4_K) {
+ if (lm_ggml_cpu_has_avx2()) {
+ if (cur->ne[1] % 8 == 0) {
+ return &ggml::cpu::aarch64::q4_K_8x8_q8_K;
+ }
+ }
  } else if (cur->type == LM_GGML_TYPE_IQ4_NL) {
  if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
  if (cur->ne[1] % 4 == 0) {