cui-llama.rn 1.4.6 → 1.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -2
- package/android/src/main/jni.cpp +52 -34
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/cpp/binary-ops.cpp +158 -0
- package/cpp/binary-ops.h +16 -0
- package/cpp/chat.cpp +1769 -1779
- package/cpp/chat.h +9 -1
- package/cpp/common.cpp +20 -522
- package/cpp/common.h +13 -36
- package/cpp/cpu-common.h +72 -0
- package/cpp/ggml-common.h +12 -6
- package/cpp/ggml-cpu-aarch64.cpp +1557 -80
- package/cpp/ggml-cpu-impl.h +2 -21
- package/cpp/ggml-cpu-quants.c +904 -405
- package/cpp/ggml-cpu.c +909 -13237
- package/cpp/ggml-impl.h +50 -23
- package/cpp/ggml-metal-impl.h +77 -3
- package/cpp/ggml-metal.m +794 -580
- package/cpp/ggml.c +92 -3
- package/cpp/ggml.h +29 -5
- package/cpp/gguf.cpp +1 -0
- package/cpp/llama-adapter.cpp +55 -20
- package/cpp/llama-adapter.h +11 -9
- package/cpp/llama-arch.cpp +217 -16
- package/cpp/llama-arch.h +25 -0
- package/cpp/llama-batch.h +2 -2
- package/cpp/llama-chat.cpp +54 -2
- package/cpp/llama-chat.h +3 -0
- package/cpp/llama-context.cpp +2294 -1238
- package/cpp/llama-context.h +214 -77
- package/cpp/llama-cparams.h +1 -0
- package/cpp/llama-graph.cpp +1695 -0
- package/cpp/llama-graph.h +592 -0
- package/cpp/llama-hparams.cpp +8 -0
- package/cpp/llama-hparams.h +17 -0
- package/cpp/llama-io.cpp +15 -0
- package/cpp/llama-io.h +35 -0
- package/cpp/llama-kv-cache.cpp +965 -303
- package/cpp/llama-kv-cache.h +145 -151
- package/cpp/llama-memory.cpp +1 -0
- package/cpp/llama-memory.h +21 -0
- package/cpp/llama-mmap.cpp +1 -1
- package/cpp/llama-model-loader.cpp +10 -5
- package/cpp/llama-model-loader.h +5 -3
- package/cpp/llama-model.cpp +9194 -201
- package/cpp/llama-model.h +40 -1
- package/cpp/llama-sampling.cpp +5 -0
- package/cpp/llama-vocab.cpp +36 -5
- package/cpp/llama.cpp +51 -9984
- package/cpp/llama.h +102 -22
- package/cpp/log.cpp +34 -0
- package/cpp/minja/chat-template.hpp +15 -7
- package/cpp/minja/minja.hpp +120 -94
- package/cpp/ops.cpp +8723 -0
- package/cpp/ops.h +128 -0
- package/cpp/rn-llama.cpp +44 -53
- package/cpp/rn-llama.h +2 -12
- package/cpp/sampling.cpp +3 -0
- package/cpp/sgemm.cpp +533 -88
- package/cpp/simd-mappings.h +888 -0
- package/cpp/speculative.cpp +4 -4
- package/cpp/unary-ops.cpp +186 -0
- package/cpp/unary-ops.h +28 -0
- package/cpp/vec.cpp +258 -0
- package/cpp/vec.h +802 -0
- package/ios/CMakeLists.txt +5 -2
- package/ios/RNLlama.mm +2 -2
- package/ios/RNLlamaContext.mm +40 -24
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +6 -4
- package/src/index.ts +3 -1
- package/cpp/chat-template.hpp +0 -529
- package/cpp/minja.hpp +0 -2915
package/cpp/ggml-cpu-aarch64.cpp
CHANGED
@@ -45,6 +45,24 @@ using block_q4_0x8 = block<4, 8>;
|
|
45
45
|
using block_q8_0x4 = block<8, 4>;
|
46
46
|
using block_q8_0x8 = block<8, 8>;
|
47
47
|
|
48
|
+
|
49
|
+
struct block_q4_Kx8 {
|
50
|
+
lm_ggml_half d[8]; // super-block scale for quantized scales
|
51
|
+
lm_ggml_half dmin[8]; // super-block scale for quantized mins
|
52
|
+
uint8_t scales[96]; // scales and mins, quantized with 6 bits
|
53
|
+
uint8_t qs[1024]; // 4-bit quants
|
54
|
+
};
|
55
|
+
|
56
|
+
static_assert(sizeof(block_q4_Kx8) == sizeof(lm_ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
57
|
+
|
58
|
+
struct block_q8_Kx4 {
|
59
|
+
float d[4]; // delta
|
60
|
+
int8_t qs[QK_K * 4]; // quants
|
61
|
+
int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
|
62
|
+
};
|
63
|
+
|
64
|
+
static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
|
65
|
+
|
48
66
|
struct block_iq4_nlx4 {
|
49
67
|
lm_ggml_half d[4]; // deltas for 4 iq4_nl blocks
|
50
68
|
uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
|
@@ -60,6 +78,13 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(lm_ggml_half) + QK4_NL * 2, "
|
|
60
78
|
|
61
79
|
#define UNUSED LM_GGML_UNUSED
|
62
80
|
|
81
|
+
static inline int nearest_int(float fval) {
|
82
|
+
assert(fabsf(fval) <= 4194303.f);
|
83
|
+
float val = fval + 12582912.f;
|
84
|
+
int i; memcpy(&i, &val, sizeof(int));
|
85
|
+
return (i & 0x007fffff) - 0x00400000;
|
86
|
+
}
|
87
|
+
|
63
88
|
// Functions to create the interleaved data layout formats
|
64
89
|
|
65
90
|
// interleave 4 block_q4_0s in blocks of blck_size_interleave
|
@@ -225,7 +250,7 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
|
|
225
250
|
|
226
251
|
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
227
252
|
|
228
|
-
static void
|
253
|
+
static void lm_ggml_quantize_mat_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
|
229
254
|
assert(QK8_0 == 32);
|
230
255
|
assert(k % QK8_0 == 0);
|
231
256
|
const int nb = k / QK8_0;
|
@@ -319,7 +344,7 @@ static void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_R
|
|
319
344
|
#endif
|
320
345
|
}
|
321
346
|
|
322
|
-
static void
|
347
|
+
static void lm_ggml_quantize_mat_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
|
323
348
|
assert(QK8_0 == 32);
|
324
349
|
assert(k % QK8_0 == 0);
|
325
350
|
const int nb = k / QK8_0;
|
@@ -534,16 +559,289 @@ static void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_R
|
|
534
559
|
#endif
|
535
560
|
}
|
536
561
|
|
537
|
-
static void
|
562
|
+
static void lm_ggml_quantize_mat_q8_K_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
|
563
|
+
assert(QK_K == 256);
|
564
|
+
assert(k % QK_K == 0);
|
565
|
+
const int nb = k / QK_K;
|
566
|
+
|
567
|
+
block_q8_Kx4 * LM_GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
568
|
+
|
569
|
+
#if defined(__AVX2__)
|
570
|
+
float iscale[4];
|
571
|
+
__m256 srcv[4][32];
|
572
|
+
__m256 iscale_vec[4];
|
573
|
+
|
574
|
+
for (int i = 0; i < nb; i++) {
|
575
|
+
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
576
|
+
// Load elements into 4 AVX vectors
|
577
|
+
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
|
578
|
+
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
|
579
|
+
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
|
580
|
+
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
|
581
|
+
|
582
|
+
// Compute max(abs(e)) for the block
|
583
|
+
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
584
|
+
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
|
585
|
+
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
|
586
|
+
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
|
587
|
+
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
|
588
|
+
|
589
|
+
__m256 maxAbs = _mm256_max_ps( abs0, abs1 );
|
590
|
+
maxAbs = _mm256_max_ps( maxAbs, abs2 );
|
591
|
+
maxAbs = _mm256_max_ps( maxAbs, abs3 );
|
592
|
+
|
593
|
+
__m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
|
594
|
+
__m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
|
595
|
+
__m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
|
596
|
+
__m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
|
597
|
+
|
598
|
+
__m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
|
599
|
+
|
600
|
+
srcv[row_iter][0] = v0;
|
601
|
+
srcv[row_iter][1] = v1;
|
602
|
+
srcv[row_iter][2] = v2;
|
603
|
+
srcv[row_iter][3] = v3;
|
604
|
+
|
605
|
+
for (int sb = 1; sb < 8; sb++) {
|
606
|
+
// Temporarily stores absolute quant values
|
607
|
+
__m256 tempAbs = maxAbs;
|
608
|
+
|
609
|
+
// Load elements into 4 AVX vectors
|
610
|
+
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
|
611
|
+
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
|
612
|
+
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
|
613
|
+
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
|
614
|
+
|
615
|
+
// Compute max(abs(e)) for the block
|
616
|
+
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
|
617
|
+
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
|
618
|
+
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
|
619
|
+
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
|
620
|
+
|
621
|
+
maxAbs = _mm256_max_ps( maxAbs, abs0 );
|
622
|
+
maxAbs = _mm256_max_ps( maxAbs, abs1 );
|
623
|
+
maxAbs = _mm256_max_ps( maxAbs, abs2 );
|
624
|
+
maxAbs = _mm256_max_ps( maxAbs, abs3 );
|
625
|
+
|
626
|
+
__m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
|
627
|
+
maskAbs = _mm256_and_ps( maskAbs, mask_prev );
|
628
|
+
|
629
|
+
mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
|
630
|
+
mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
|
631
|
+
mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
|
632
|
+
mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
|
633
|
+
|
634
|
+
__m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
|
635
|
+
maskAbs = _mm256_or_ps(maskAbs, mask_curr);
|
636
|
+
|
637
|
+
srcv[row_iter][sb * 4] = v0;
|
638
|
+
srcv[row_iter][sb * 4 + 1] = v1;
|
639
|
+
srcv[row_iter][sb * 4 + 2] = v2;
|
640
|
+
srcv[row_iter][sb * 4 + 3] = v3;
|
641
|
+
}
|
642
|
+
|
643
|
+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
644
|
+
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
645
|
+
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
646
|
+
const float maxScalar = _mm_cvtss_f32( max4 );
|
647
|
+
|
648
|
+
__m256 maxScalarVec = _mm256_set1_ps(maxScalar);
|
649
|
+
|
650
|
+
__m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
|
651
|
+
__m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
|
652
|
+
|
653
|
+
const int mask = _mm256_movemask_ps(finalMask);
|
654
|
+
iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
|
655
|
+
|
656
|
+
if(mask) {
|
657
|
+
iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
|
658
|
+
}
|
659
|
+
|
660
|
+
y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
|
661
|
+
iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
|
662
|
+
}
|
663
|
+
|
664
|
+
__m256i quants_interleaved[32];
|
665
|
+
for (int j = 0; j < 32; j++) {
|
666
|
+
// Apply the multiplier
|
667
|
+
__m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
|
668
|
+
__m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
|
669
|
+
__m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
|
670
|
+
__m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
|
671
|
+
|
672
|
+
// Round to nearest integer
|
673
|
+
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
674
|
+
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
675
|
+
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
676
|
+
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
677
|
+
|
678
|
+
// Convert floats to integers
|
679
|
+
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
680
|
+
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
681
|
+
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
682
|
+
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
683
|
+
|
684
|
+
// Convert int32 to int16
|
685
|
+
i0 = _mm256_packs_epi32( i0, i1 );
|
686
|
+
i2 = _mm256_packs_epi32( i2, i3 );
|
687
|
+
// Convert int16 to int8
|
688
|
+
i0 = _mm256_packs_epi16( i0, i2 );
|
689
|
+
|
690
|
+
// Permute and store the quantized weights in the required order after the pack instruction
|
691
|
+
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
692
|
+
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
693
|
+
|
694
|
+
_mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
|
695
|
+
quants_interleaved[j] = i0;
|
696
|
+
}
|
697
|
+
|
698
|
+
// Masks to shuffle the quants of corresponding sub blocks, rearranging the quants for vectorized bsums computation
|
699
|
+
__m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
|
700
|
+
shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
|
701
|
+
__m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
|
702
|
+
shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
|
703
|
+
__m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
|
704
|
+
shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
|
705
|
+
|
706
|
+
for (int k = 0; k < 4; k++) {
|
707
|
+
// Quants from four different sub blocks are taken
|
708
|
+
__m256i q0 = quants_interleaved[k * 8 + 0];
|
709
|
+
__m256i q1 = quants_interleaved[k * 8 + 1];
|
710
|
+
__m256i q2 = quants_interleaved[k * 8 + 2];
|
711
|
+
__m256i q3 = quants_interleaved[k * 8 + 3];
|
712
|
+
__m256i q4 = quants_interleaved[k * 8 + 4];
|
713
|
+
__m256i q5 = quants_interleaved[k * 8 + 5];
|
714
|
+
__m256i q6 = quants_interleaved[k * 8 + 6];
|
715
|
+
__m256i q7 = quants_interleaved[k * 8 + 7];
|
716
|
+
|
717
|
+
|
718
|
+
// The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
|
719
|
+
__m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
|
720
|
+
__m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
|
721
|
+
__m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
|
722
|
+
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
|
723
|
+
__m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
|
724
|
+
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
|
725
|
+
|
726
|
+
__m256i one = _mm256_set1_epi8(1);
|
727
|
+
__m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
|
728
|
+
|
729
|
+
for (int l = 0; l < 3; l++) {
|
730
|
+
// Quants value shifted to process next two values from each sub block
|
731
|
+
q0 = _mm256_srli_epi64(q0, 16);
|
732
|
+
q2 = _mm256_srli_epi64(q2, 16);
|
733
|
+
q4 = _mm256_srli_epi64(q4, 16);
|
734
|
+
q6 = _mm256_srli_epi64(q6, 16);
|
735
|
+
|
736
|
+
sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
|
737
|
+
sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
|
738
|
+
sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
|
739
|
+
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
|
740
|
+
sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
|
741
|
+
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
|
742
|
+
|
743
|
+
bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
|
744
|
+
}
|
745
|
+
|
746
|
+
// The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
|
747
|
+
__m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
|
748
|
+
__m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
|
749
|
+
__m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
|
750
|
+
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
|
751
|
+
__m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
|
752
|
+
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
|
753
|
+
|
754
|
+
__m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
|
755
|
+
|
756
|
+
for (int l = 0; l < 3; l++) {
|
757
|
+
// Quants value shifted to process next two values from each sub block
|
758
|
+
q1 = _mm256_srli_epi64(q1, 16);
|
759
|
+
q3 = _mm256_srli_epi64(q3, 16);
|
760
|
+
q5 = _mm256_srli_epi64(q5, 16);
|
761
|
+
q7 = _mm256_srli_epi64(q7, 16);
|
762
|
+
|
763
|
+
sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
|
764
|
+
sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
|
765
|
+
sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
|
766
|
+
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
|
767
|
+
sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
|
768
|
+
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
|
769
|
+
|
770
|
+
bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
|
771
|
+
}
|
772
|
+
|
773
|
+
// Overall bsums in interleaved fashion computed by adding results of both halves
|
774
|
+
__m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
|
775
|
+
_mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
|
776
|
+
}
|
777
|
+
}
|
778
|
+
|
779
|
+
#else
|
780
|
+
|
781
|
+
// scalar
|
782
|
+
const int blck_size_interleave = 8;
|
783
|
+
float srcv[4][QK_K];
|
784
|
+
float iscale[4];
|
785
|
+
|
786
|
+
for (int i = 0; i < nb; i++) {
|
787
|
+
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
788
|
+
float amax = 0.0f; // absolute max
|
789
|
+
float max = 0;
|
790
|
+
|
791
|
+
for (int j = 0; j < QK_K; j++) {
|
792
|
+
srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
|
793
|
+
// Update the maximum value of the corresponding super block
|
794
|
+
if(amax < fabsf(srcv[row_iter][j])) {
|
795
|
+
amax = fabsf(srcv[row_iter][j]);
|
796
|
+
max = srcv[row_iter][j];
|
797
|
+
}
|
798
|
+
}
|
799
|
+
|
800
|
+
iscale[row_iter] = amax ? -127.f/max : 0;
|
801
|
+
|
802
|
+
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
|
803
|
+
}
|
804
|
+
|
805
|
+
for (int j = 0; j < QK_K / 4; j++) {
|
806
|
+
y[i].bsums[j] = 0;
|
807
|
+
}
|
808
|
+
|
809
|
+
// Quants values are interleaved in sequence of eight bytes from corresponding super blocks
|
810
|
+
// Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
|
811
|
+
// i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
|
812
|
+
for (int j = 0; j < QK_K * 4; j++) {
|
813
|
+
int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
|
814
|
+
int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
|
815
|
+
src_offset += (j % blck_size_interleave);
|
816
|
+
int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
|
817
|
+
|
818
|
+
float x0 = srcv[src_id][src_offset] * iscale[src_id];
|
819
|
+
y[i].qs[j] = nearest_int(x0);
|
820
|
+
y[i].bsums[index] += y[i].qs[j];
|
821
|
+
}
|
822
|
+
}
|
823
|
+
#endif
|
824
|
+
}
|
825
|
+
|
826
|
+
template <int64_t INTER_SIZE, lm_ggml_type PARAM_TYPE>
|
827
|
+
void lm_ggml_quantize_mat_t(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
|
828
|
+
|
829
|
+
template <> void lm_ggml_quantize_mat_t<4, LM_GGML_TYPE_Q8_0>(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
538
830
|
assert(nrow == 4);
|
539
831
|
UNUSED(nrow);
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
832
|
+
lm_ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
|
833
|
+
}
|
834
|
+
|
835
|
+
template <> void lm_ggml_quantize_mat_t<8, LM_GGML_TYPE_Q8_0>(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
836
|
+
assert(nrow == 4);
|
837
|
+
UNUSED(nrow);
|
838
|
+
lm_ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
|
839
|
+
}
|
840
|
+
|
841
|
+
template <> void lm_ggml_quantize_mat_t<8, LM_GGML_TYPE_Q8_K>(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
842
|
+
assert(nrow == 4);
|
843
|
+
UNUSED(nrow);
|
844
|
+
lm_ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
|
547
845
|
}
|
548
846
|
|
549
847
|
static void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
@@ -994,6 +1292,281 @@ static void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t
|
|
994
1292
|
}
|
995
1293
|
}
|
996
1294
|
|
1295
|
+
static void lm_ggml_gemv_q4_K_8x8_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
1296
|
+
const int qk = QK_K;
|
1297
|
+
const int nb = n / qk;
|
1298
|
+
const int ncols_interleaved = 8;
|
1299
|
+
const int blocklen = 8;
|
1300
|
+
static const uint32_t kmask1 = 0x3f3f3f3f;
|
1301
|
+
static const uint32_t kmask2 = 0x0f0f0f0f;
|
1302
|
+
static const uint32_t kmask3 = 0x03030303;
|
1303
|
+
|
1304
|
+
assert (n % qk == 0);
|
1305
|
+
assert (nc % ncols_interleaved == 0);
|
1306
|
+
|
1307
|
+
UNUSED(s);
|
1308
|
+
UNUSED(bs);
|
1309
|
+
UNUSED(vx);
|
1310
|
+
UNUSED(vy);
|
1311
|
+
UNUSED(nr);
|
1312
|
+
UNUSED(nc);
|
1313
|
+
UNUSED(nb);
|
1314
|
+
UNUSED(ncols_interleaved);
|
1315
|
+
UNUSED(blocklen);
|
1316
|
+
|
1317
|
+
#if defined(__AVX2__)
|
1318
|
+
// Lookup table to convert signed nibbles to signed bytes
|
1319
|
+
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
|
1320
|
+
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
|
1321
|
+
// Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
|
1322
|
+
__m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
|
1323
|
+
__m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
|
1324
|
+
// Permute mask used for easier vector processing at later stages
|
1325
|
+
__m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
|
1326
|
+
|
1327
|
+
// Mask to extract nibbles from bytes
|
1328
|
+
const __m256i m4b = _mm256_set1_epi8(0x0F);
|
1329
|
+
|
1330
|
+
int64_t b_nb = n / QK_K;
|
1331
|
+
|
1332
|
+
const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;
|
1333
|
+
const block_q8_K * a_ptr_start = (const block_q8_K *)vy;
|
1334
|
+
|
1335
|
+
// Process Q8_K blocks one by one
|
1336
|
+
for (int64_t y = 0; y < nr; y++) {
|
1337
|
+
|
1338
|
+
// Pointers to LHS blocks of block_q8_K format
|
1339
|
+
const block_q8_K * a_ptr = a_ptr_start + (y * nb);
|
1340
|
+
|
1341
|
+
// Take group of eight interleaved block_q4_K structures at each pass of the loop and perform dot product operation
|
1342
|
+
for (int64_t x = 0; x < nc / 8; x++) {
|
1343
|
+
|
1344
|
+
// Pointers to RHS blocks
|
1345
|
+
const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
|
1346
|
+
|
1347
|
+
// Master FP accumulators
|
1348
|
+
__m256 acc_row = _mm256_setzero_ps();
|
1349
|
+
__m256 acc_min_rows = _mm256_setzero_ps();
|
1350
|
+
|
1351
|
+
for (int64_t b = 0; b < nb; b++) {
|
1352
|
+
|
1353
|
+
// Load and convert to FP32 scale from block_q8_K
|
1354
|
+
const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));
|
1355
|
+
|
1356
|
+
// Load the scale values for the 8 blocks interleaved in block_q4_Kx8
|
1357
|
+
// col_scale_f32 rearranged so as to multiply with appropriate quants
|
1358
|
+
const __m256 col_scale_f32 = LM_GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
|
1359
|
+
const __m256 col_dmin_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].dmin);
|
1360
|
+
|
1361
|
+
__m256i iacc_b = _mm256_setzero_si256();
|
1362
|
+
__m256i iacc_min_b = _mm256_setzero_si256();
|
1363
|
+
|
1364
|
+
const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
|
1365
|
+
__m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
|
1366
|
+
q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
|
1367
|
+
|
1368
|
+
// Processes two sub blocks from each Q4_K in each iteration
|
1369
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
1370
|
+
|
1371
|
+
// Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
|
1372
|
+
const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
|
1373
|
+
const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
|
1374
|
+
const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
|
1375
|
+
const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
|
1376
|
+
const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
|
1377
|
+
const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
|
1378
|
+
const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
|
1379
|
+
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
|
1380
|
+
|
1381
|
+
// 4-bit -> 8-bit
|
1382
|
+
// Values of the first sub block of eight block_q4_K structures for the sb loop
|
1383
|
+
const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
|
1384
|
+
const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
|
1385
|
+
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
|
1386
|
+
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
|
1387
|
+
const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
|
1388
|
+
const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
|
1389
|
+
const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
|
1390
|
+
const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
|
1391
|
+
|
1392
|
+
// Values of the second sub block of eight block_q4_K structures for the sb loop
|
1393
|
+
const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
|
1394
|
+
const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
|
1395
|
+
const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
|
1396
|
+
const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
|
1397
|
+
const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
|
1398
|
+
const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
|
1399
|
+
const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
|
1400
|
+
const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
|
1401
|
+
|
1402
|
+
uint32_t utmp_0[4], utmp_1[4];
|
1403
|
+
|
1404
|
+
// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
|
1405
|
+
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
|
1406
|
+
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
|
1407
|
+
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
|
1408
|
+
const uint32_t uaux_0 = utmp_0[1] & kmask1;
|
1409
|
+
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
|
1410
|
+
utmp_0[2] = uaux_0;
|
1411
|
+
utmp_0[0] &= kmask1;
|
1412
|
+
|
1413
|
+
// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
|
1414
|
+
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
|
1415
|
+
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
|
1416
|
+
const uint32_t uaux_1 = utmp_1[1] & kmask1;
|
1417
|
+
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
|
1418
|
+
utmp_1[2] = uaux_1;
|
1419
|
+
utmp_1[0] &= kmask1;
|
1420
|
+
|
1421
|
+
// Scales of first sub block in the sb loop
|
1422
|
+
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
|
1423
|
+
__m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
|
1424
|
+
__m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
|
1425
|
+
|
1426
|
+
// Scales of second sub block in the sb loop
|
1427
|
+
__m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
|
1428
|
+
__m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
|
1429
|
+
__m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
|
1430
|
+
|
1431
|
+
// Mins of first and second sub block of Q4_K block are arranged side by side
|
1432
|
+
__m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
|
1433
|
+
|
1434
|
+
// Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
|
1435
|
+
__m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
|
1436
|
+
__m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
|
1437
|
+
__m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
|
1438
|
+
__m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
|
1439
|
+
|
1440
|
+
lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
|
1441
|
+
lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
|
1442
|
+
lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
|
1443
|
+
lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
|
1444
|
+
|
1445
|
+
// Dot product done within 32 bit lanes and accumulated in the same vector
|
1446
|
+
// First done for first sub block and then for second sub block in each sb
|
1447
|
+
// B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
|
1448
|
+
// B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
|
1449
|
+
// ...........................................................................
|
1450
|
+
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
|
1451
|
+
|
1452
|
+
|
1453
|
+
__m256i iacc_0 = _mm256_setzero_si256();
|
1454
|
+
__m256i iacc_1 = _mm256_setzero_si256();
|
1455
|
+
|
1456
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
|
1457
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
|
1458
|
+
|
1459
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
|
1460
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
|
1461
|
+
|
1462
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
|
1463
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
|
1464
|
+
|
1465
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
|
1466
|
+
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
|
1467
|
+
|
1468
|
+
iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
|
1469
|
+
|
1470
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
|
1471
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
|
1472
|
+
|
1473
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
|
1474
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
|
1475
|
+
|
1476
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
|
1477
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
|
1478
|
+
|
1479
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
|
1480
|
+
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
|
1481
|
+
|
1482
|
+
iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
|
1483
|
+
|
1484
|
+
// Accumulate the iacc value for one sb
|
1485
|
+
__m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
|
1486
|
+
|
1487
|
+
// Broadcast the bsums of the two sub blocks of the iteration of Q8_K across the vector
|
1488
|
+
// Multiply-Add with corresponding mins of Q4_Kx8 with bsums
|
1489
|
+
__m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
|
1490
|
+
__m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
|
1491
|
+
q8s = _mm256_bsrli_epi128(q8s, 4);
|
1492
|
+
|
1493
|
+
// Accumulate for the complete block
|
1494
|
+
iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
|
1495
|
+
iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
|
1496
|
+
}
|
1497
|
+
|
1498
|
+
// Multiply-Add with scale values for the complete super block
|
1499
|
+
acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
|
1500
|
+
acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
|
1501
|
+
|
1502
|
+
}
|
1503
|
+
|
1504
|
+
// Accumulated output values permuted so as to be stored in appropriate order post accumulation
|
1505
|
+
acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
|
1506
|
+
_mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
|
1507
|
+
}
|
1508
|
+
}
|
1509
|
+
|
1510
|
+
#else
|
1511
|
+
|
1512
|
+
float sumf[8];
|
1513
|
+
float sum_minf[8];
|
1514
|
+
uint32_t utmp[32];
|
1515
|
+
int sumi1;
|
1516
|
+
int sumi2;
|
1517
|
+
int sumi;
|
1518
|
+
|
1519
|
+
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
1520
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
1521
|
+
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
1522
|
+
|
1523
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
1524
|
+
sumf[j] = 0.0;
|
1525
|
+
sum_minf[j] = 0.0;
|
1526
|
+
}
|
1527
|
+
for (int l = 0; l < nb; l++) {
|
1528
|
+
for (int sb = 0; sb < 8; sb++) {
|
1529
|
+
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
1530
|
+
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
1531
|
+
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
1532
|
+
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
1533
|
+
utmp[sb * 4 + 2] = uaux_0;
|
1534
|
+
utmp[sb * 4 + 0] &= kmask1;
|
1535
|
+
}
|
1536
|
+
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
1537
|
+
uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
|
1538
|
+
uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
|
1539
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
1540
|
+
sumi1 = 0;
|
1541
|
+
sumi2 = 0;
|
1542
|
+
sumi = 0;
|
1543
|
+
for (int i = 0; i < blocklen; ++i) {
|
1544
|
+
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
|
1545
|
+
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
|
1546
|
+
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
|
1547
|
+
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
|
1548
|
+
sumi1 = sumi1 * scales_0[j];
|
1549
|
+
sumi2 = sumi2 * scales_1[j];
|
1550
|
+
sumi += sumi1 + sumi2;
|
1551
|
+
}
|
1552
|
+
sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
1553
|
+
}
|
1554
|
+
}
|
1555
|
+
for (int sb = 0; sb < 8; sb++) {
|
1556
|
+
uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
|
1557
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
1558
|
+
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * LM_GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
1559
|
+
}
|
1560
|
+
}
|
1561
|
+
}
|
1562
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
1563
|
+
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
1564
|
+
}
|
1565
|
+
}
|
1566
|
+
#endif
|
1567
|
+
}
|
1568
|
+
|
1569
|
+
|
997
1570
|
static void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
998
1571
|
const int qk = QK8_0;
|
999
1572
|
const int nb = n / qk;
|
@@ -3480,6 +4053,781 @@ static void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t
|
|
3480
4053
|
}
|
3481
4054
|
}
|
3482
4055
|
|
4056
|
+
static void lm_ggml_gemm_q4_K_8x8_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
4057
|
+
const int qk = QK_K;
|
4058
|
+
const int nb = n / qk;
|
4059
|
+
const int ncols_interleaved = 8;
|
4060
|
+
const int blocklen = 8;
|
4061
|
+
static const uint32_t kmask1 = 0x3f3f3f3f;
|
4062
|
+
static const uint32_t kmask2 = 0x0f0f0f0f;
|
4063
|
+
static const uint32_t kmask3 = 0x03030303;
|
4064
|
+
|
4065
|
+
assert (n % qk == 0);
|
4066
|
+
assert (nr % 4 == 0);
|
4067
|
+
assert (nc % ncols_interleaved == 0);
|
4068
|
+
|
4069
|
+
UNUSED(s);
|
4070
|
+
UNUSED(bs);
|
4071
|
+
UNUSED(vx);
|
4072
|
+
UNUSED(vy);
|
4073
|
+
UNUSED(nr);
|
4074
|
+
UNUSED(nc);
|
4075
|
+
UNUSED(nb);
|
4076
|
+
UNUSED(ncols_interleaved);
|
4077
|
+
UNUSED(blocklen);
|
4078
|
+
|
4079
|
+
#if defined(__AVX2__)
|
4080
|
+
const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx;
|
4081
|
+
const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;
|
4082
|
+
int64_t b_nb = n / QK_K;
|
4083
|
+
int64_t y = 0;
|
4084
|
+
|
4085
|
+
// Mask to mask out nibbles from packed bytes
|
4086
|
+
const __m256i m4b = _mm256_set1_epi8(0x0F);
|
4087
|
+
// Permute mask used for easier vector processing at later stages
|
4088
|
+
__m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
|
4089
|
+
|
4090
|
+
int anr = nr - nr % 16;; // Used to align nr with boundary of 16
|
4091
|
+
// Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
|
4092
|
+
for (; y < anr / 4; y += 4) {
|
4093
|
+
|
4094
|
+
const block_q8_Kx4 * a_ptrs[4];
|
4095
|
+
|
4096
|
+
a_ptrs[0] = a_ptr_start + (y * nb);
|
4097
|
+
for (int i = 0; i < 3; ++i) {
|
4098
|
+
a_ptrs[i + 1] = a_ptrs[i] + nb;
|
4099
|
+
}
|
4100
|
+
|
4101
|
+
// Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
|
4102
|
+
for (int64_t x = 0; x < nc / 8; x++) {
|
4103
|
+
|
4104
|
+
const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
|
4105
|
+
|
4106
|
+
// Master FP accumulators
|
4107
|
+
__m256 acc_rows[16];
|
4108
|
+
for (int i = 0; i < 16; i++) {
|
4109
|
+
acc_rows[i] = _mm256_setzero_ps();
|
4110
|
+
}
|
4111
|
+
|
4112
|
+
__m256 acc_min_rows[16];
|
4113
|
+
for (int i = 0; i < 16; i++) {
|
4114
|
+
acc_min_rows[i] = _mm256_setzero_ps();
|
4115
|
+
}
|
4116
|
+
|
4117
|
+
// For super block
|
4118
|
+
for (int64_t b = 0; b < nb; b++) {
|
4119
|
+
|
4120
|
+
// Scale values - Load the eight scale values of block_q4_kx8
|
4121
|
+
const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d);
|
4122
|
+
|
4123
|
+
// dmin values - Load the eight dmin values of block_q4_kx8
|
4124
|
+
const __m256 col_dmin_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].dmin);
|
4125
|
+
|
4126
|
+
// Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
|
4127
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
4128
|
+
|
4129
|
+
// Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
|
4130
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
|
4131
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
|
4132
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
|
4133
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
|
4134
|
+
const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
|
4135
|
+
const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
|
4136
|
+
const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
|
4137
|
+
const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
|
4138
|
+
|
4139
|
+
// Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
|
4140
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
4141
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
4142
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
4143
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
4144
|
+
const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
|
4145
|
+
const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
|
4146
|
+
const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
|
4147
|
+
const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
|
4148
|
+
|
4149
|
+
// 4-bit -> 8-bit
|
4150
|
+
// First sub block of the two sub blocks processed in the iteration
|
4151
|
+
const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
|
4152
|
+
const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
|
4153
|
+
|
4154
|
+
const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
|
4155
|
+
const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
|
4156
|
+
|
4157
|
+
const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
|
4158
|
+
const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
|
4159
|
+
|
4160
|
+
const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
|
4161
|
+
const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
|
4162
|
+
|
4163
|
+
// Second sub block of the two sub blocks processed in the iteration
|
4164
|
+
const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
|
4165
|
+
const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
|
4166
|
+
|
4167
|
+
const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
|
4168
|
+
const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
|
4169
|
+
|
4170
|
+
const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
|
4171
|
+
const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
|
4172
|
+
|
4173
|
+
const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
|
4174
|
+
const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
|
4175
|
+
|
4176
|
+
// Shuffle pattern one - right side input
|
4177
|
+
const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
|
4178
|
+
const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
|
4179
|
+
|
4180
|
+
const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
|
4181
|
+
const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
|
4182
|
+
|
4183
|
+
const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
|
4184
|
+
const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
|
4185
|
+
|
4186
|
+
const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
|
4187
|
+
const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
|
4188
|
+
|
4189
|
+
const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
|
4190
|
+
const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
|
4191
|
+
|
4192
|
+
const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
|
4193
|
+
const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
|
4194
|
+
|
4195
|
+
const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
|
4196
|
+
const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
|
4197
|
+
|
4198
|
+
const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
|
4199
|
+
const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
|
4200
|
+
|
4201
|
+
|
4202
|
+
// Shuffle pattern two - right side input
|
4203
|
+
const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
|
4204
|
+
const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
|
4205
|
+
|
4206
|
+
const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
|
4207
|
+
const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
|
4208
|
+
|
4209
|
+
const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
|
4210
|
+
const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
|
4211
|
+
|
4212
|
+
const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
|
4213
|
+
const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
|
4214
|
+
|
4215
|
+
const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
|
4216
|
+
const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
|
4217
|
+
|
4218
|
+
const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
|
4219
|
+
const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
|
4220
|
+
|
4221
|
+
const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
|
4222
|
+
const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
|
4223
|
+
|
4224
|
+
const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
|
4225
|
+
const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
|
4226
|
+
|
4227
|
+
uint32_t utmp_0[4], utmp_1[4];
|
4228
|
+
|
4229
|
+
// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
|
4230
|
+
// The block below extracts, e.g., the first sub block's scales and mins from the different Q4_K structures for the sb loop
|
4231
|
+
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
|
4232
|
+
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
|
4233
|
+
const uint32_t uaux_0 = utmp_0[1] & kmask1;
|
4234
|
+
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
|
4235
|
+
utmp_0[2] = uaux_0;
|
4236
|
+
utmp_0[0] &= kmask1;
|
4237
|
+
|
4238
|
+
// The block below extracts, e.g., the second sub block's scales and mins from the different Q4_K structures for the sb loop
|
4239
|
+
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
|
4240
|
+
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
|
4241
|
+
const uint32_t uaux_1 = utmp_1[1] & kmask1;
|
4242
|
+
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
|
4243
|
+
utmp_1[2] = uaux_1;
|
4244
|
+
utmp_1[0] &= kmask1;
|
4245
|
+
|
4246
|
+
// Scales of first sub block in the sb loop
|
4247
|
+
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
|
4248
|
+
const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
|
4249
|
+
|
4250
|
+
// Scales of second sub block in the sb loop
|
4251
|
+
const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
|
4252
|
+
const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
|
4253
|
+
|
4254
|
+
// Mins of first and second sub block of Q4_K block are arranged side by side
|
4255
|
+
const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
|
4256
|
+
|
4257
|
+
const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
|
4258
|
+
const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
|
4259
|
+
|
4260
|
+
const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
|
4261
|
+
const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
|
4262
|
+
|
4263
|
+
for (int rp = 0; rp < 4; rp++) {
|
4264
|
+
|
4265
|
+
// Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
|
4266
|
+
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
|
4267
|
+
__m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
|
4268
|
+
__m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
|
4269
|
+
__m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
|
4270
|
+
__m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
|
4271
|
+
__m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
|
4272
|
+
__m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
|
4273
|
+
__m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
|
4274
|
+
__m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
|
4275
|
+
__m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
|
4276
|
+
__m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
|
4277
|
+
__m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
|
4278
|
+
__m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
|
4279
|
+
__m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
|
4280
|
+
__m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
|
4281
|
+
__m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
|
4282
|
+
__m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
|
4283
|
+
__m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
|
4284
|
+
__m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
|
4285
|
+
__m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
|
4286
|
+
__m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
|
4287
|
+
__m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
|
4288
|
+
__m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
|
4289
|
+
__m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
|
4290
|
+
__m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
|
4291
|
+
|
4292
|
+
// Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
|
4293
|
+
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
|
4294
|
+
__m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
|
4295
|
+
lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
|
4296
|
+
|
4297
|
+
// Shuffle pattern one - left side input
|
4298
|
+
const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
|
4299
|
+
const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
|
4300
|
+
|
4301
|
+
const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
|
4302
|
+
const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
|
4303
|
+
|
4304
|
+
const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
|
4305
|
+
const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160); //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
|
4306
|
+
|
4307
|
+
const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
|
4308
|
+
const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
|
4309
|
+
|
4310
|
+
const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
|
4311
|
+
const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
|
4312
|
+
|
4313
|
+
const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
|
4314
|
+
const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
|
4315
|
+
|
4316
|
+
const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
|
4317
|
+
const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160); //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
|
4318
|
+
|
4319
|
+
const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
|
4320
|
+
const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
|
4321
|
+
|
4322
|
+
// Shuffle pattern two- left side input
|
4323
|
+
const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
|
4324
|
+
const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
|
4325
|
+
|
4326
|
+
const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
|
4327
|
+
const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
|
4328
|
+
|
4329
|
+
const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
|
4330
|
+
const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
|
4331
|
+
|
4332
|
+
const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
|
4333
|
+
const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
|
4334
|
+
|
4335
|
+
const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
|
4336
|
+
const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
|
4337
|
+
|
4338
|
+
const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
|
4339
|
+
const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
|
4340
|
+
|
4341
|
+
const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
|
4342
|
+
const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
|
4343
|
+
|
4344
|
+
const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
|
4345
|
+
const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
|
4346
|
+
|
4347
|
+
// The values arranged in the shuffle patterns are combined with a dot product operation: corresponding bytes are multiplied and adjacent products are added into 16 bit integers (maddubs), and the partial results are then summed with 16 bit adds
|
4348
|
+
__m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
|
4349
|
+
__m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
|
4350
|
+
__m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
|
4351
|
+
__m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
|
4352
|
+
__m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
|
4353
|
+
__m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
|
4354
|
+
__m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
|
4355
|
+
__m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
|
4356
|
+
|
4357
|
+
__m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
|
4358
|
+
__m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
|
4359
|
+
__m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
|
4360
|
+
__m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
|
4361
|
+
__m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
|
4362
|
+
__m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
|
4363
|
+
__m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
|
4364
|
+
__m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
|
4365
|
+
|
4366
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
4367
|
+
__m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
|
4368
|
+
__m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
|
4369
|
+
__m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
|
4370
|
+
__m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
|
4371
|
+
|
4372
|
+
__m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
|
4373
|
+
__m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
|
4374
|
+
__m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
|
4375
|
+
__m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
|
4376
|
+
|
4377
|
+
// Multiply the 16 bit dot product sums by the corresponding sub block scales; adjacent products are added into 32 bit integers
|
4378
|
+
iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
|
4379
|
+
iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
|
4380
|
+
iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
|
4381
|
+
iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
|
4382
|
+
|
4383
|
+
iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
|
4384
|
+
iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
|
4385
|
+
iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
|
4386
|
+
iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
|
4387
|
+
|
4388
|
+
// Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
|
4389
|
+
__m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
|
4390
|
+
__m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
|
4391
|
+
__m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
|
4392
|
+
__m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
|
4393
|
+
__m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
|
4394
|
+
__m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
|
4395
|
+
__m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
|
4396
|
+
__m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
|
4397
|
+
|
4398
|
+
__m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
|
4399
|
+
__m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
|
4400
|
+
__m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
|
4401
|
+
__m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
|
4402
|
+
|
4403
|
+
// Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
|
4404
|
+
const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
|
4405
|
+
const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//LM_GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
|
4406
|
+
|
4407
|
+
// Multiply with appropriate scales and accumulate (for both d and dmin) below
|
4408
|
+
acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
|
4409
|
+
acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
|
4410
|
+
acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
|
4411
|
+
acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
|
4412
|
+
|
4413
|
+
__m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
|
4414
|
+
__m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
|
4415
|
+
__m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
|
4416
|
+
__m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
|
4417
|
+
|
4418
|
+
acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
|
4419
|
+
acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
|
4420
|
+
acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
|
4421
|
+
acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
|
4422
|
+
|
4423
|
+
}
|
4424
|
+
}
|
4425
|
+
}
|
4426
|
+
// Store the accumulated values
|
4427
|
+
for (int i = 0; i < 16; i++) {
|
4428
|
+
_mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
|
4429
|
+
}
|
4430
|
+
}
|
4431
|
+
}
|
4432
|
+
for (; y < nr / 4; y++) {
|
4433
|
+
|
4434
|
+
const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
|
4435
|
+
|
4436
|
+
for (int64_t x = 0; x < nc / 8; x++) {
|
4437
|
+
|
4438
|
+
const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
|
4439
|
+
|
4440
|
+
// Master FP accumulators
|
4441
|
+
__m256 acc_rows[4];
|
4442
|
+
for (int i = 0; i < 4; i++) {
|
4443
|
+
acc_rows[i] = _mm256_setzero_ps();
|
4444
|
+
}
|
4445
|
+
|
4446
|
+
__m256 acc_min_rows[4];
|
4447
|
+
for (int i = 0; i < 4; i++) {
|
4448
|
+
acc_min_rows[i] = _mm256_setzero_ps();
|
4449
|
+
}
|
4450
|
+
|
4451
|
+
for (int64_t b = 0; b < nb; b++) {
|
4452
|
+
|
4453
|
+
// Scale values - Load the eight scale values of block_q4_Kx8
|
4454
|
+
const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d);
|
4455
|
+
|
4456
|
+
// dmin values - Load the eight dmin values of block_q4_Kx8
|
4457
|
+
const __m256 col_dmin_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].dmin);
|
4458
|
+
|
4459
|
+
// Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
|
4460
|
+
for (int sb = 0; sb < QK_K / 64; sb++) {
|
4461
|
+
|
4462
|
+
// Load the eight block_q4_k for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
|
4463
|
+
const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
|
4464
|
+
const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
|
4465
|
+
const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
|
4466
|
+
const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
|
4467
|
+
const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
|
4468
|
+
const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
|
4469
|
+
const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
|
4470
|
+
const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
|
4471
|
+
|
4472
|
+
// Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
|
4473
|
+
const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
|
4474
|
+
const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
|
4475
|
+
const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
|
4476
|
+
const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
|
4477
|
+
const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
|
4478
|
+
const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
|
4479
|
+
const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
|
4480
|
+
const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
|
4481
|
+
|
4482
|
+
// 4-bit -> 8-bit
|
4483
|
+
// First sub block of the two sub blocks processed in the iteration
|
4484
|
+
const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
|
4485
|
+
const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
|
4486
|
+
|
4487
|
+
const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
|
4488
|
+
const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
|
4489
|
+
|
4490
|
+
const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
|
4491
|
+
const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
|
4492
|
+
|
4493
|
+
const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
|
4494
|
+
const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
|
4495
|
+
|
4496
|
+
// Second sub block of the two sub blocks processed in the iteration
|
4497
|
+
const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
|
4498
|
+
const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
|
4499
|
+
|
4500
|
+
const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
|
4501
|
+
const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
|
4502
|
+
|
4503
|
+
const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
|
4504
|
+
const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
|
4505
|
+
|
4506
|
+
const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
|
4507
|
+
const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
|
4508
|
+
|
4509
|
+
// Shuffle pattern one - right side input
|
4510
|
+
const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
|
4511
|
+
const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
|
4512
|
+
|
4513
|
+
const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
|
4514
|
+
const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
|
4515
|
+
|
4516
|
+
const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
|
4517
|
+
const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
|
4518
|
+
|
4519
|
+
const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
|
4520
|
+
const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
|
4521
|
+
|
4522
|
+
const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
|
4523
|
+
const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
|
4524
|
+
|
4525
|
+
const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
|
4526
|
+
const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
|
4527
|
+
|
4528
|
+
const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
|
4529
|
+
const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
|
4530
|
+
|
4531
|
+
const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
|
4532
|
+
const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
|
4533
|
+
|
4534
|
+
// Shuffle pattern two - right side input
|
4535
|
+
const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
|
4536
|
+
const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
|
4537
|
+
|
4538
|
+
const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
|
4539
|
+
const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
|
4540
|
+
|
4541
|
+
const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
|
4542
|
+
const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
|
4543
|
+
|
4544
|
+
const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
|
4545
|
+
const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
|
4546
|
+
|
4547
|
+
const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
|
4548
|
+
const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
|
4549
|
+
|
4550
|
+
const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
|
4551
|
+
const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
|
4552
|
+
|
4553
|
+
const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
|
4554
|
+
const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
|
4555
|
+
|
4556
|
+
const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
|
4557
|
+
const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
|
4558
|
+
|
4559
|
+
uint32_t utmp_0[4], utmp_1[4];
|
4560
|
+
|
4561
|
+
// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
|
4562
|
+
// The block below extracts the scales and mins of the first of the two sub blocks handled in this sb iteration, from the different Q4_K structures
|
4563
|
+
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
|
4564
|
+
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
|
4565
|
+
const uint32_t uaux_0 = utmp_0[1] & kmask1;
|
4566
|
+
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
|
4567
|
+
utmp_0[2] = uaux_0;
|
4568
|
+
utmp_0[0] &= kmask1;
|
4569
|
+
|
4570
|
+
// The block below extracts the scales and mins of the second of the two sub blocks handled in this sb iteration, from the different Q4_K structures
|
4571
|
+
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
|
4572
|
+
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
|
4573
|
+
const uint32_t uaux_1 = utmp_1[1] & kmask1;
|
4574
|
+
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
|
4575
|
+
utmp_1[2] = uaux_1;
|
4576
|
+
utmp_1[0] &= kmask1;
|
4577
|
+
|
4578
|
+
// Scales of first sub block in the sb loop
|
4579
|
+
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
|
4580
|
+
const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
|
4581
|
+
|
4582
|
+
// Scales of second sub block in the sb loop
|
4583
|
+
const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
|
4584
|
+
const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
|
4585
|
+
|
4586
|
+
// Mins of first and second sub block of Q4_K block are arranged side by side
|
4587
|
+
const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
|
4588
|
+
|
4589
|
+
const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
|
4590
|
+
const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
|
4591
|
+
|
4592
|
+
const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
|
4593
|
+
const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
|
4594
|
+
|
4595
|
+
// Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
|
4596
|
+
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
|
4597
|
+
__m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
|
4598
|
+
__m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
|
4599
|
+
__m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
|
4600
|
+
__m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
|
4601
|
+
__m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
|
4602
|
+
__m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
|
4603
|
+
__m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
|
4604
|
+
__m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
|
4605
|
+
__m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
|
4606
|
+
__m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
|
4607
|
+
__m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
|
4608
|
+
__m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
|
4609
|
+
__m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
|
4610
|
+
__m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
|
4611
|
+
__m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
|
4612
|
+
__m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
|
4613
|
+
__m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
|
4614
|
+
__m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
|
4615
|
+
__m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
|
4616
|
+
__m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
|
4617
|
+
__m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
|
4618
|
+
__m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
|
4619
|
+
__m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
|
4620
|
+
__m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
|
4621
|
+
|
4622
|
+
// Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
|
4623
|
+
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
|
4624
|
+
__m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
|
4625
|
+
lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
|
4626
|
+
|
4627
|
+
// Shuffle pattern one - left side input
|
4628
|
+
const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
|
4629
|
+
const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
|
4630
|
+
|
4631
|
+
const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
|
4632
|
+
const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
|
4633
|
+
|
4634
|
+
const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
|
4635
|
+
const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160); //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
|
4636
|
+
|
4637
|
+
const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
|
4638
|
+
const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
|
4639
|
+
|
4640
|
+
const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
|
4641
|
+
const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
|
4642
|
+
|
4643
|
+
const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
|
4644
|
+
const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
|
4645
|
+
|
4646
|
+
const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
|
4647
|
+
const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160); //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
|
4648
|
+
|
4649
|
+
const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
|
4650
|
+
const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
|
4651
|
+
|
4652
|
+
// Shuffle pattern two- left side input
|
4653
|
+
const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
|
4654
|
+
const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
|
4655
|
+
|
4656
|
+
const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
|
4657
|
+
const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
|
4658
|
+
|
4659
|
+
const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
|
4660
|
+
const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
|
4661
|
+
|
4662
|
+
const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
|
4663
|
+
const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
|
4664
|
+
|
4665
|
+
const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
|
4666
|
+
const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
|
4667
|
+
|
4668
|
+
const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
|
4669
|
+
const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
|
4670
|
+
|
4671
|
+
const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
|
4672
|
+
const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
|
4673
|
+
|
4674
|
+
const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
|
4675
|
+
const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
|
4676
|
+
|
4677
|
+
// The values arranged in shuffle patterns are combined with a dot product operation within each 32 bit lane, i.e. corresponding bytes are multiplied and adjacent products are added into 16 bit integers within the 32 bit lane
|
4678
|
+
__m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
|
4679
|
+
__m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
|
4680
|
+
__m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
|
4681
|
+
__m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
|
4682
|
+
__m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
|
4683
|
+
__m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
|
4684
|
+
__m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
|
4685
|
+
__m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
|
4686
|
+
|
4687
|
+
__m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
|
4688
|
+
__m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
|
4689
|
+
__m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
|
4690
|
+
__m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
|
4691
|
+
__m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
|
4692
|
+
__m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
|
4693
|
+
__m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
|
4694
|
+
__m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
|
4695
|
+
|
4696
|
+
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
|
4697
|
+
__m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
|
4698
|
+
__m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
|
4699
|
+
__m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
|
4700
|
+
__m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
|
4701
|
+
|
4702
|
+
__m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
|
4703
|
+
__m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
|
4704
|
+
__m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
|
4705
|
+
__m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
|
4706
|
+
|
4707
|
+
// Multiply the 16 bit sums with the corresponding sub block scales and add adjacent products into 32 bit integers
|
4708
|
+
iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
|
4709
|
+
iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
|
4710
|
+
iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
|
4711
|
+
iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
|
4712
|
+
|
4713
|
+
iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
|
4714
|
+
iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
|
4715
|
+
iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
|
4716
|
+
iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
|
4717
|
+
|
4718
|
+
// Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
|
4719
|
+
__m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
|
4720
|
+
__m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
|
4721
|
+
__m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
|
4722
|
+
__m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
|
4723
|
+
__m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
|
4724
|
+
__m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
|
4725
|
+
__m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
|
4726
|
+
__m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
|
4727
|
+
|
4728
|
+
__m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
|
4729
|
+
__m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
|
4730
|
+
__m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
|
4731
|
+
__m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
|
4732
|
+
|
4733
|
+
// Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
|
4734
|
+
const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
|
4735
|
+
const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); //LM_GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
|
4736
|
+
|
4737
|
+
// Multiply with appropriate scales and accumulate (for both d and dmin) below
|
4738
|
+
acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
|
4739
|
+
acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
|
4740
|
+
acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
|
4741
|
+
acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
|
4742
|
+
|
4743
|
+
__m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
|
4744
|
+
__m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
|
4745
|
+
__m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
|
4746
|
+
__m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
|
4747
|
+
|
4748
|
+
acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
|
4749
|
+
acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
|
4750
|
+
acc_min_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
|
4751
|
+
acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
|
4752
|
+
}
|
4753
|
+
}
|
4754
|
+
|
4755
|
+
// Store the accumulated values
|
4756
|
+
for (int i = 0; i < 4; i++) {
|
4757
|
+
_mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
|
4758
|
+
}
|
4759
|
+
}
|
4760
|
+
}
|
4761
|
+
|
4762
|
+
#else
|
4763
|
+
|
4764
|
+
float sumf[4][8];
|
4765
|
+
float sum_minf[4][8];
|
4766
|
+
uint32_t utmp[32];
|
4767
|
+
int sumi1;
|
4768
|
+
int sumi2;
|
4769
|
+
int sumi;
|
4770
|
+
|
4771
|
+
for (int y = 0; y < nr / 4; y++) {
|
4772
|
+
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
4773
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
4774
|
+
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
4775
|
+
for (int m = 0; m < 4; m++) {
|
4776
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
4777
|
+
sumf[m][j] = 0.0;
|
4778
|
+
sum_minf[m][j] = 0.0;
|
4779
|
+
}
|
4780
|
+
}
|
4781
|
+
for (int l = 0; l < nb; l++) {
|
4782
|
+
for (int sb = 0; sb < 8; sb++) {
|
4783
|
+
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
4784
|
+
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
4785
|
+
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
4786
|
+
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
4787
|
+
utmp[sb * 4 + 2] = uaux_0;
|
4788
|
+
utmp[sb * 4 + 0] &= kmask1;
|
4789
|
+
}
|
4790
|
+
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
4791
|
+
uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
|
4792
|
+
uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
|
4793
|
+
for (int m = 0; m < 4; m++) {
|
4794
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
4795
|
+
sumi1 = 0;
|
4796
|
+
sumi2 = 0;
|
4797
|
+
sumi = 0;
|
4798
|
+
for (int i = 0; i < blocklen; ++i) {
|
4799
|
+
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
|
4800
|
+
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
|
4801
|
+
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
|
4802
|
+
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
|
4803
|
+
sumi1 = sumi1 * scales_0[j];
|
4804
|
+
sumi2 = sumi2 * scales_1[j];
|
4805
|
+
sumi += sumi1 + sumi2;
|
4806
|
+
}
|
4807
|
+
sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
4808
|
+
}
|
4809
|
+
}
|
4810
|
+
}
|
4811
|
+
for (int sb = 0; sb < 8; sb++) {
|
4812
|
+
uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
|
4813
|
+
for(int m = 0; m < 4; m++) {
|
4814
|
+
const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
4815
|
+
for(int j = 0; j < ncols_interleaved; j++) {
|
4816
|
+
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * LM_GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
4817
|
+
}
|
4818
|
+
}
|
4819
|
+
}
|
4820
|
+
}
|
4821
|
+
for (int m = 0; m < 4; m++) {
|
4822
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
4823
|
+
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
4824
|
+
}
|
4825
|
+
}
|
4826
|
+
}
|
4827
|
+
}
|
4828
|
+
#endif
|
4829
|
+
}
|
4830
|
+
|
3483
4831
|
static void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
3484
4832
|
const int qk = QK8_0;
|
3485
4833
|
const int nb = n / qk;
|
@@ -3660,6 +5008,82 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
|
|
3660
5008
|
return out;
|
3661
5009
|
}
|
3662
5010
|
|
5011
|
+
// Builds one interleaved block_q4_Kx8 from eight block_q4_K structures.
//
//   in                   - array of (at least) eight block_q4_K; in[0..7] are read.
//   blck_size_interleave - number of quant bytes taken from one source block before
//                          moving on to the next one. Must be 8 (see assert below).
//
// Returns the interleaved block by value. Layout produced by this function:
//   - d[i] / dmin[i] are copied from in[i].
//   - qs holds 8-byte chunks taken round-robin from in[0..7].
//   - scales holds 8 groups of 12 bytes; group g carries the 6-bit scale and 6-bit min
//     of sub block g of each of the eight source blocks, packed in the same
//     "low 6 bits + split high values" scheme that the per-block Q4_K scales use.
static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
    // Offsets below are computed in units of blck_size_interleave, but every step copies
    // exactly sizeof(uint64_t) == 8 bytes. The two only agree for an interleave size of 8;
    // any other value would produce overlapping / out-of-range copies.
    LM_GGML_ASSERT(blck_size_interleave == 8);

    block_q4_Kx8 out;

    // Delta (scale) and dmin values of the eight Q4_K structures are copied as-is
    for (int i = 0; i < 8; i++) {
        out.d[i]    = in[i].LM_GGML_COMMON_AGGR_U.LM_GGML_COMMON_AGGR_S.d;
        out.dmin[i] = in[i].LM_GGML_COMMON_AGGR_U.LM_GGML_COMMON_AGGR_S.dmin;
    }

    // Interleave the Q4_K quants by taking 8 bytes at a time from each source block in turn
    const int end = QK_K * 4 / blck_size_interleave;
    for (int i = 0; i < end; ++i) {
        const int src_id     = i % 8;
        const int src_offset = (i / 8) * blck_size_interleave;
        const int dst_offset = i * blck_size_interleave;

        uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
    }

    // Unpack the scales and mins of the source blocks and re-pack them per sub block.
    // A Q4_K block stores 8 scales and 8 mins in 12 bytes (6 bits per value).
    // The output stores 12 bytes per sub block: the 8 scales and 8 mins of that sub block,
    // one from each of the eight source blocks.
    uint8_t s[8], m[8];

    // Sub blocks 0..3: the 6-bit values are the low 6 bits of scales[i] / scales[i + 4]
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = in[j].scales[i] & 63;
            m[j] = in[j].scales[i + 4] & 63;
        }

        const int base = i * 12;
        for (int j = 0; j < 4; j++) {
            // bytes 0..3: scale of block j, plus the top 2 bits of the scale of block j + 4
            out.scales[base + j]     = (s[j] & 63) + ((s[j + 4] & 48) << 2);
            // bytes 4..7: same arrangement for the mins
            out.scales[base + j + 4] = (m[j] & 63) + ((m[j + 4] & 48) << 2);
            // bytes 8..11: low 4 bits of scale (low nibble) and min (high nibble) of block j + 4
            out.scales[base + j + 8] = (s[j + 4] & 15) + ((m[j + 4] & 15) << 4);
        }
    }

    // Sub blocks 4..7: the 6-bit values are split between the top 2 bits of
    // scales[i] / scales[i + 4] and the nibbles of scales[i + 8]
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
        }

        const int base = i * 12 + 48;
        for (int j = 0; j < 4; j++) {
            out.scales[base + j]     = (s[j] & 63) + ((s[j + 4] & 48) << 2);
            out.scales[base + j + 4] = (m[j] & 63) + ((m[j + 4] & 48) << 2);
            out.scales[base + j + 8] = (s[j + 4] & 15) + ((m[j + 4] & 15) << 4);
        }
    }

    return out;
}
|
5086
|
+
|
3663
5087
|
static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
|
3664
5088
|
LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
|
3665
5089
|
LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
@@ -3690,6 +5114,36 @@ static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_bl
|
|
3690
5114
|
|
3691
5115
|
LM_GGML_UNUSED(data_size);
|
3692
5116
|
}
|
5117
|
+
static int repack_q4_K_to_q4_K_8_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
|
5118
|
+
LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_K);
|
5119
|
+
LM_GGML_ASSERT(interleave_block == 8);
|
5120
|
+
constexpr int nrows_interleaved = 8;
|
5121
|
+
|
5122
|
+
block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
|
5123
|
+
const block_q4_K * src = (const block_q4_K*) data;
|
5124
|
+
block_q4_K dst_tmp[8];
|
5125
|
+
int nrow = lm_ggml_nrows(t);
|
5126
|
+
int nblocks = t->ne[0] / QK_K;
|
5127
|
+
|
5128
|
+
LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
|
5129
|
+
|
5130
|
+
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
5131
|
+
return -1;
|
5132
|
+
}
|
5133
|
+
|
5134
|
+
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
5135
|
+
for (int64_t x = 0; x < nblocks; x++) {
|
5136
|
+
for (int i = 0; i < nrows_interleaved; i++ ) {
|
5137
|
+
dst_tmp[i] = src[x + i * nblocks];
|
5138
|
+
}
|
5139
|
+
*dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
|
5140
|
+
}
|
5141
|
+
src += nrows_interleaved * nblocks;
|
5142
|
+
}
|
5143
|
+
return 0;
|
5144
|
+
|
5145
|
+
LM_GGML_UNUSED(data_size);
|
5146
|
+
}
|
3693
5147
|
|
3694
5148
|
static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
|
3695
5149
|
LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
|
@@ -3807,6 +5261,10 @@ template <> int repack<block_q4_0, 8, 8>(struct lm_ggml_tensor * t, const void *
|
|
3807
5261
|
return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
|
3808
5262
|
}
|
3809
5263
|
|
5264
|
+
// repack specialization for block_q4_K with template arguments <8, 8>:
// forwards to repack_q4_K_to_q4_K_8_bl with an interleave block of 8 and returns its result.
template <> int repack<block_q4_K, 8, 8>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
|
5265
|
+
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
|
5266
|
+
}
|
5267
|
+
|
3810
5268
|
// repack specialization for block_iq4_nl with template arguments <4, 4>:
// forwards to repack_iq4_nl_to_iq4_nl_4_bl with an interleave block of 4 and returns its result.
template <> int repack<block_iq4_nl, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
|
3811
5269
|
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
3812
5270
|
}
|
@@ -3817,44 +5275,50 @@ template <> int repack<block_iq4_nl, 4, 4>(struct lm_ggml_tensor * t, const void
|
|
3817
5275
|
//}
|
3818
5276
|
|
3819
5277
|
// gemv
|
3820
|
-
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
5278
|
+
// Primary template of the gemv dispatcher. It is only declared; the explicit
// specializations that follow select the concrete kernel from
// (BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE). The specializations in this file
// use LM_GGML_TYPE_Q8_0 or LM_GGML_TYPE_Q8_K as PARAM_TYPE.
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, lm_ggml_type PARAM_TYPE>
|
3821
5279
|
void gemv(int, float *, size_t, const void *, const void *, int, int);
|
3822
5280
|
|
3823
|
-
template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
5281
|
+
// gemv for <block_q4_0, INTER_SIZE 4, NB_COLS 4, Q8_0>: forwards all arguments unchanged to lm_ggml_gemv_q4_0_4x4_q8_0.
template <> void gemv<block_q4_0, 4, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
3824
5282
|
lm_ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
3825
5283
|
}
|
3826
5284
|
|
3827
|
-
template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
5285
|
+
// gemv for <block_q4_0, INTER_SIZE 8, NB_COLS 4, Q8_0>: forwards all arguments unchanged to lm_ggml_gemv_q4_0_4x8_q8_0
// (the kernel name lists the column count first, then the interleave size).
template <> void gemv<block_q4_0, 8, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
3828
5286
|
lm_ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
3829
5287
|
}
|
3830
5288
|
|
3831
|
-
template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
5289
|
+
// gemv for <block_q4_0, INTER_SIZE 8, NB_COLS 8, Q8_0>: forwards all arguments unchanged to lm_ggml_gemv_q4_0_8x8_q8_0.
template <> void gemv<block_q4_0, 8, 8, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
3832
5290
|
lm_ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
3833
5291
|
}
|
3834
5292
|
|
3835
|
-
template <>
|
3836
|
-
|
5293
|
+
// gemv for <block_q4_K, INTER_SIZE 8, NB_COLS 8, Q8_K>: forwards all arguments unchanged to lm_ggml_gemv_q4_K_8x8_q8_K.
template <> void gemv<block_q4_K, 8, 8, LM_GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
5294
|
+
lm_ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
5295
|
+
}
|
5296
|
+
|
5297
|
+
// gemv for <block_iq4_nl, INTER_SIZE 4, NB_COLS 4, Q8_0>: forwards all arguments unchanged to lm_ggml_gemv_iq4_nl_4x4_q8_0.
template <> void gemv<block_iq4_nl, 4, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
3837
5298
|
lm_ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
3838
5299
|
}
|
3839
5300
|
|
3840
5301
|
// gemm
|
3841
|
-
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
5302
|
+
// Primary template of the gemm dispatcher. It is only declared; the explicit
// specializations that follow select the concrete kernel from
// (BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE). The specializations in this file
// use LM_GGML_TYPE_Q8_0 or LM_GGML_TYPE_Q8_K as PARAM_TYPE.
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, lm_ggml_type PARAM_TYPE>
|
3842
5303
|
void gemm(int, float *, size_t, const void *, const void *, int, int);
|
3843
5304
|
|
3844
|
-
template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
5305
|
+
// gemm for <block_q4_0, INTER_SIZE 4, NB_COLS 4, Q8_0>: forwards all arguments unchanged to lm_ggml_gemm_q4_0_4x4_q8_0.
template <> void gemm<block_q4_0, 4, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
3845
5306
|
lm_ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
3846
5307
|
}
|
3847
5308
|
|
3848
|
-
template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
5309
|
+
// gemm for <block_q4_0, INTER_SIZE 8, NB_COLS 4, Q8_0>: forwards all arguments unchanged to lm_ggml_gemm_q4_0_4x8_q8_0
// (the kernel name lists the column count first, then the interleave size).
template <> void gemm<block_q4_0, 8, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
3849
5310
|
lm_ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
3850
5311
|
}
|
3851
5312
|
|
3852
|
-
template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
5313
|
+
// gemm for <block_q4_0, INTER_SIZE 8, NB_COLS 8, Q8_0>: forwards all arguments unchanged to lm_ggml_gemm_q4_0_8x8_q8_0.
template <> void gemm<block_q4_0, 8, 8, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
3853
5314
|
lm_ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
3854
5315
|
}
|
3855
5316
|
|
3856
|
-
template <>
|
3857
|
-
|
5317
|
+
// gemm for <block_q4_K, INTER_SIZE 8, NB_COLS 8, Q8_K>: forwards all arguments unchanged to lm_ggml_gemm_q4_K_8x8_q8_K.
template <> void gemm<block_q4_K, 8, 8, LM_GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
5318
|
+
lm_ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
5319
|
+
}
|
5320
|
+
|
5321
|
+
// gemm for <block_iq4_nl, INTER_SIZE 4, NB_COLS 4, Q8_0>: forwards all arguments unchanged to lm_ggml_gemm_iq4_nl_4x4_q8_0.
template <> void gemm<block_iq4_nl, 4, 4, LM_GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
3858
5322
|
lm_ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
3859
5323
|
}
|
3860
5324
|
|
@@ -3863,37 +5327,37 @@ class tensor_traits_base : public ggml::cpu::tensor_traits {
|
|
3863
5327
|
virtual int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) = 0;
|
3864
5328
|
};
|
3865
5329
|
|
3866
|
-
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
|
5330
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, lm_ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {
|
3867
5331
|
|
3868
5332
|
bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
|
3869
5333
|
// not realy a LM_GGML_TYPE_Q8_0 but same size.
|
3870
5334
|
switch (op->op) {
|
3871
|
-
|
3872
|
-
|
3873
|
-
|
3874
|
-
|
3875
|
-
|
3876
|
-
|
3877
|
-
|
3878
|
-
|
3879
|
-
|
3880
|
-
|
3881
|
-
|
5335
|
+
case LM_GGML_OP_MUL_MAT:
|
5336
|
+
size = lm_ggml_row_size(PARAM_TYPE, lm_ggml_nelements(op->src[1]));
|
5337
|
+
return true;
|
5338
|
+
case LM_GGML_OP_MUL_MAT_ID:
|
5339
|
+
size = lm_ggml_row_size(PARAM_TYPE, lm_ggml_nelements(op->src[1]));
|
5340
|
+
size = LM_GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
|
5341
|
+
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
|
5342
|
+
return true;
|
5343
|
+
default:
|
5344
|
+
// LM_GGML_ABORT("fatal error");
|
5345
|
+
break;
|
3882
5346
|
}
|
3883
5347
|
return false;
|
3884
5348
|
}
|
3885
5349
|
|
3886
5350
|
bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
|
3887
5351
|
switch (op->op) {
|
3888
|
-
|
3889
|
-
|
3890
|
-
|
3891
|
-
|
3892
|
-
|
3893
|
-
|
3894
|
-
|
3895
|
-
|
3896
|
-
|
5352
|
+
case LM_GGML_OP_MUL_MAT:
|
5353
|
+
forward_mul_mat(params, op);
|
5354
|
+
return true;
|
5355
|
+
case LM_GGML_OP_MUL_MAT_ID:
|
5356
|
+
forward_mul_mat_id(params, op);
|
5357
|
+
return true;
|
5358
|
+
default:
|
5359
|
+
// LM_GGML_ABORT("fatal error");
|
5360
|
+
break;
|
3897
5361
|
}
|
3898
5362
|
return false;
|
3899
5363
|
}
|
@@ -3925,17 +5389,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
3925
5389
|
// LM_GGML_ASSERT(lm_ggml_n_dims(op->src[1]) == 2);
|
3926
5390
|
|
3927
5391
|
char * wdata = static_cast<char *>(params->wdata);
|
3928
|
-
const size_t nbw1 = lm_ggml_row_size(
|
5392
|
+
const size_t nbw1 = lm_ggml_row_size(PARAM_TYPE, ne10);
|
3929
5393
|
|
3930
5394
|
assert(params->wsize >= nbw1 * ne11);
|
3931
5395
|
|
3932
|
-
const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(
|
5396
|
+
const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
|
3933
5397
|
|
3934
5398
|
int64_t i11_processed = 0;
|
3935
5399
|
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
3936
|
-
|
3937
|
-
INTER_SIZE);
|
5400
|
+
lm_ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
|
3938
5401
|
}
|
5402
|
+
|
3939
5403
|
i11_processed = ne11 - ne11 % 4;
|
3940
5404
|
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
|
3941
5405
|
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
|
@@ -3944,26 +5408,28 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
3944
5408
|
lm_ggml_barrier(params->threadpool);
|
3945
5409
|
|
3946
5410
|
const void * src1_wdata = params->wdata;
|
3947
|
-
const size_t src1_col_stride = lm_ggml_row_size(
|
5411
|
+
const size_t src1_col_stride = lm_ggml_row_size(PARAM_TYPE, ne10);
|
3948
5412
|
int64_t src0_start = (ith * ne01) / nth;
|
3949
5413
|
int64_t src0_end = ((ith + 1) * ne01) / nth;
|
3950
5414
|
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
|
3951
|
-
src0_end = (src0_end
|
5415
|
+
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
|
3952
5416
|
if (src0_start >= src0_end) {
|
3953
5417
|
return;
|
3954
5418
|
}
|
3955
5419
|
|
3956
5420
|
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
3957
5421
|
if (ne11 > 3) {
|
3958
|
-
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00,
|
3959
|
-
|
3960
|
-
|
5422
|
+
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
5423
|
+
(float *) ((char *) dst->data) + src0_start, ne01,
|
5424
|
+
(const char *) src0->data + src0_start * nb01,
|
5425
|
+
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
3961
5426
|
}
|
3962
5427
|
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
|
3963
|
-
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00,
|
3964
|
-
|
3965
|
-
|
3966
|
-
|
5428
|
+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
5429
|
+
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
5430
|
+
(const char *) src0->data + src0_start * nb01,
|
5431
|
+
(const char *) src1_wdata + (src1_col_stride * iter), 1,
|
5432
|
+
src0_end - src0_start);
|
3967
5433
|
}
|
3968
5434
|
}
|
3969
5435
|
|
@@ -3978,7 +5444,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
3978
5444
|
const int ith = params->ith;
|
3979
5445
|
const int nth = params->nth;
|
3980
5446
|
|
3981
|
-
const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(
|
5447
|
+
const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
|
3982
5448
|
|
3983
5449
|
// we don't support permuted src0 or src1
|
3984
5450
|
LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type));
|
@@ -4000,7 +5466,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
4000
5466
|
const int n_ids = ids->ne[0]; // n_expert_used
|
4001
5467
|
const int n_as = ne02; // n_expert
|
4002
5468
|
|
4003
|
-
const size_t nbw1 = lm_ggml_row_size(
|
5469
|
+
const size_t nbw1 = lm_ggml_row_size(PARAM_TYPE, ne10);
|
4004
5470
|
const size_t nbw2 = nbw1*ne11;
|
4005
5471
|
const size_t nbw3 = nbw2*ne12;
|
4006
5472
|
|
@@ -4012,12 +5478,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
4012
5478
|
LM_GGML_ASSERT(params->wsize >= (LM_GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
|
4013
5479
|
n_as * ne12 * sizeof(mmid_row_mapping)));
|
4014
5480
|
|
4015
|
-
auto
|
4016
|
-
auto
|
4017
|
-
|
5481
|
+
auto * wdata = (char *) params->wdata;
|
5482
|
+
auto * wdata_src1_end = (char *) wdata + LM_GGML_PAD(nbw3, sizeof(int64_t));
|
5483
|
+
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
5484
|
+
|
4018
5485
|
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
|
4019
5486
|
|
4020
|
-
// src1: float32 =>
|
5487
|
+
// src1: float32 => param type
|
4021
5488
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
4022
5489
|
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
|
4023
5490
|
from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
|
@@ -4056,34 +5523,37 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
4056
5523
|
continue;
|
4057
5524
|
}
|
4058
5525
|
|
4059
|
-
auto src0_cur = (const char *) src0->data + cur_a*nb02;
|
5526
|
+
const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
|
4060
5527
|
|
4061
5528
|
//const int64_t nr0 = ne01; // src0 rows
|
4062
5529
|
const int64_t nr1 = cne1; // src1 rows
|
4063
5530
|
|
4064
5531
|
int64_t src0_cur_start = (ith * ne01) / nth;
|
4065
5532
|
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
|
4066
|
-
src0_cur_start =
|
4067
|
-
(src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
|
4068
|
-
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
|
4069
5533
|
|
4070
|
-
|
5534
|
+
src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
|
5535
|
+
src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
|
5536
|
+
|
5537
|
+
if (src0_cur_start >= src0_cur_end) {
|
5538
|
+
return;
|
5539
|
+
}
|
4071
5540
|
|
4072
5541
|
for (int ir1 = 0; ir1 < nr1; ir1++) {
|
4073
5542
|
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
4074
|
-
const int id = row_mapping.i1; // selected expert index
|
4075
5543
|
|
4076
|
-
const
|
4077
|
-
const int64_t i12 = row_mapping.i2; // row index in src1
|
5544
|
+
const int id = row_mapping.i1; // selected expert index
|
4078
5545
|
|
4079
|
-
const int64_t
|
4080
|
-
const int64_t
|
5546
|
+
const int64_t i11 = id % ne11;
|
5547
|
+
const int64_t i12 = row_mapping.i2; // row index in src1
|
4081
5548
|
|
4082
|
-
|
5549
|
+
const int64_t i1 = id; // selected expert index
|
5550
|
+
const int64_t i2 = i12; // row
|
4083
5551
|
|
4084
|
-
|
4085
|
-
|
4086
|
-
|
5552
|
+
const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
|
5553
|
+
|
5554
|
+
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
5555
|
+
(float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
|
5556
|
+
src0_cur + src0_cur_start * nb01,
|
4087
5557
|
src1_col, 1, src0_cur_end - src0_cur_start);
|
4088
5558
|
}
|
4089
5559
|
}
|
@@ -4098,12 +5568,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|
4098
5568
|
};
|
4099
5569
|
|
4100
5570
|
// instance for Q4
|
4101
|
-
static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
|
4102
|
-
static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
|
4103
|
-
static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
|
5571
|
+
static const tensor_traits<block_q4_0, 4, 4, LM_GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
|
5572
|
+
static const tensor_traits<block_q4_0, 8, 4, LM_GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
|
5573
|
+
static const tensor_traits<block_q4_0, 8, 8, LM_GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
|
5574
|
+
static const tensor_traits<block_q4_K, 8, 8, LM_GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
4104
5575
|
|
4105
5576
|
// instance for IQ4
|
4106
|
-
static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
|
5577
|
+
static const tensor_traits<block_iq4_nl, 4, 4, LM_GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
4107
5578
|
|
4108
5579
|
} // namespace ggml::cpu::aarch64
|
4109
5580
|
|
@@ -4124,6 +5595,12 @@ static const ggml::cpu::tensor_traits * lm_ggml_aarch64_get_optimal_repack_type(
|
|
4124
5595
|
return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
|
4125
5596
|
}
|
4126
5597
|
}
|
5598
|
+
} else if (cur->type == LM_GGML_TYPE_Q4_K) {
|
5599
|
+
if (lm_ggml_cpu_has_avx2()) {
|
5600
|
+
if (cur->ne[1] % 8 == 0) {
|
5601
|
+
return &ggml::cpu::aarch64::q4_K_8x8_q8_K;
|
5602
|
+
}
|
5603
|
+
}
|
4127
5604
|
} else if (cur->type == LM_GGML_TYPE_IQ4_NL) {
|
4128
5605
|
if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
|
4129
5606
|
if (cur->ne[1] % 4 == 0) {
|