cui-llama.rn 1.1.2 → 1.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,6 +36,84 @@
36
36
  // from bias offset form to pure sign form (this saves subtract
37
37
  // operations during unpacking)
38
38
  //
39
+ #if defined(__AVX__)
40
+ #if defined(__F16C__)
41
+ // the _mm256_cvt intrinsics require F16C
42
+ #define LM_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
43
+ #define LM_GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68))
44
+ #define LM_GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask))
45
+ #else
46
+ static inline __m256 __avx_f32cx8_load(lm_ggml_fp16_t *x) {
47
+ float tmp[8];
48
+
49
+ for (int i = 0; i < 8; i++) {
50
+ tmp[i] = LM_GGML_FP16_TO_FP32(x[i]);
51
+ }
52
+
53
+ return _mm256_loadu_ps(tmp);
54
+ }
55
+ static inline __m256 __avx_repeat_f32cx8_load(lm_ggml_fp16_t *x) {
56
+ float tmp[8];
57
+
58
+ for (int i = 0; i < 4; i++) {
59
+ tmp[i] = LM_GGML_FP16_TO_FP32(x[i]);
60
+ tmp[i + 4] = LM_GGML_FP16_TO_FP32(x[i]);
61
+ }
62
+
63
+ return _mm256_loadu_ps(tmp);
64
+ }
65
+ static inline __m256 __avx_rearranged_f32cx8_load(lm_ggml_fp16_t *x, __m128i arrangeMask) {
66
+ uint16_t tmphalf[8];
67
+ float tmp[8];
68
+
69
+ _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask));
70
+ for (int i = 0; i < 8; i++) {
71
+ tmp[i] = LM_GGML_FP16_TO_FP32(tmphalf[i]);
72
+ }
73
+
74
+ return _mm256_loadu_ps(tmp);
75
+ }
76
+
77
+ #define LM_GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
78
+ #define LM_GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x)
79
+ #define LM_GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask)
80
+ #endif
81
+ #endif
82
+
83
+
84
+ #if defined(__AVX2__) || defined(__AVX512F__)
85
+ static inline __m256i sum_i16_pairs_int(const __m256i x) {
86
+ const __m256i ones = _mm256_set1_epi16(1);
87
+ return _mm256_madd_epi16(ones, x);
88
+ }
89
+
90
+ static inline __m256i mul_sum_us8_pairs_int(const __m256i ax, const __m256i sy) {
91
+ #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
92
+ const __m256i zero = _mm256_setzero_si256();
93
+ return _mm256_dpbusd_epi32(zero, ax, sy);
94
+ #else
95
+ // Perform multiplication and create 16-bit values
96
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
97
+ return sum_i16_pairs_int(dot);
98
+ #endif
99
+ }
100
+
101
+ // Integer variant of the function defined in ggml-quants.c
102
+ // multiply int8_t, add results pairwise twice and return as a 256-bit integer vector
103
+ static inline __m256i mul_sum_i8_pairs_int(const __m256i x, const __m256i y) {
104
+ #if __AVXVNNIINT8__
105
+ const __m256i zero = _mm256_setzero_si256();
106
+ return _mm256_dpbssd_epi32(zero, x, y);
107
+ #else
108
+ // Get absolute values of x vectors
109
+ const __m256i ax = _mm256_sign_epi8(x, x);
110
+ // Sign the values of the y vectors
111
+ const __m256i sy = _mm256_sign_epi8(y, x);
112
+ return mul_sum_us8_pairs_int(ax, sy);
113
+ #endif
114
+ }
115
+ #endif
116
+
39
117
  static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
40
118
  block_q4_0x4 out;
41
119
 
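The hunk above introduces F16C-gated FP16-to-FP32 load macros and AVX2 helpers that reduce packed int8 products into 32-bit lanes, with VNNI fast paths where available. The sketch below is illustrative only and is not part of the package: it checks the non-VNNI fallback path against a scalar reference, assuming a hypothetical standalone file built with something like `gcc -mavx2`; the helper name mul_sum_i8_pairs_int_sketch is made up for the example.

    #include <immintrin.h>
    #include <stdint.h>
    #include <stdio.h>

    // Same structure as the fallback above: |x| * (y * sign(x)) via maddubs,
    // then the int16 pair sums reduced to int32 lanes with madd against ones.
    static __m256i mul_sum_i8_pairs_int_sketch(const __m256i x, const __m256i y) {
        const __m256i ax   = _mm256_sign_epi8(x, x);        // |x|, consumed as unsigned below
        const __m256i sy   = _mm256_sign_epi8(y, x);        // y with the sign of x applied
        const __m256i dot  = _mm256_maddubs_epi16(ax, sy);  // 16 x int16: sums of byte pairs
        const __m256i ones = _mm256_set1_epi16(1);
        return _mm256_madd_epi16(ones, dot);                // 8 x int32: sums of 4 products
    }

    int main(void) {
        int8_t a[32], b[32];
        for (int i = 0; i < 32; i++) { a[i] = (int8_t)(i - 16); b[i] = (int8_t)(3 * i - 40); }

        int32_t out[8];
        _mm256_storeu_si256((__m256i *)out,
            mul_sum_i8_pairs_int_sketch(_mm256_loadu_si256((const __m256i *)a),
                                        _mm256_loadu_si256((const __m256i *)b)));

        for (int lane = 0; lane < 8; lane++) {
            int32_t ref = 0;
            for (int j = 0; j < 4; j++) ref += (int32_t)a[4 * lane + j] * b[4 * lane + j];
            printf("lane %d: simd=%d scalar=%d\n", lane, out[lane], ref);
        }
        return 0;
    }

Note that _mm256_maddubs_epi16 saturates its int16 sums; the trick stays exact in the Q4_0 x Q8_0 case because the nibble operands lie in [-8, 7], which keeps every pair sum well inside the int16 range.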
@@ -255,6 +333,103 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k)
255
333
  y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
256
334
  }
257
335
  }
336
+ #elif defined(__AVX2__) || defined(__AVX__)
337
+ float id[4];
338
+ __m256 srcv[4][4];
339
+ __m256 idvec[4];
340
+
341
+ for (int i = 0; i < nb; i++) {
342
+ for (int row_iter = 0; row_iter < 4; row_iter++) {
343
+ // Load elements into 4 AVX vectors
344
+ __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 );
345
+ __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 );
346
+ __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 );
347
+ __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 );
348
+
349
+ // Compute max(abs(e)) for the block
350
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
351
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
352
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
353
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
354
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
355
+
356
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
357
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
358
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
359
+ const float maxScalar = _mm_cvtss_f32( max4 );
360
+
361
+ // Divided by 127.f to mirror results in quantize_row_q8_0
362
+ const float d = maxScalar / 127.f;
363
+ id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f;
364
+
365
+ // Store the scale for the individual block
366
+ y[i].d[row_iter] = LM_GGML_FP32_TO_FP16(d);
367
+
368
+ // Store the values in blocks of eight values - Aim is to use these later for block interleaving
369
+ srcv[row_iter][0] = v0;
370
+ srcv[row_iter][1] = v1;
371
+ srcv[row_iter][2] = v2;
372
+ srcv[row_iter][3] = v3;
373
+ idvec[row_iter] = _mm256_set1_ps(id[row_iter]);
374
+ }
375
+
376
+ // The loop iterates four times - The aim is to get 4 corresponding chunks of eight bytes from the original weight blocks that are interleaved
377
+ for (int j = 0; j < 4; j++) {
378
+ // Apply the multiplier
379
+ __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]);
380
+ __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]);
381
+ __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]);
382
+ __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]);
383
+
384
+ // Round to nearest integer
385
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
386
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
387
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
388
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
389
+
390
+ // Convert floats to integers
391
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
392
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
393
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
394
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
395
+
396
+ #if defined(__AVX2__)
397
+ // Convert int32 to int16
398
+ i0 = _mm256_packs_epi32( i0, i1 );
399
+ i2 = _mm256_packs_epi32( i2, i3 );
400
+ // Convert int16 to int8
401
+ i0 = _mm256_packs_epi16( i0, i2 );
402
+
403
+ // Permute and store the quantized weights in the required order after the pack instruction
404
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
405
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
406
+
407
+ _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
408
+ #else
409
+ // Since AVX alone lacks some of the integer intrinsics needed here,
410
+ // we split the registers in half and use the SSE equivalents of the AVX2 intrinsics
411
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
412
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
413
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
414
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
415
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
416
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
417
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
418
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
419
+
420
+ // Convert int32 to int16
421
+ ni0 = _mm_packs_epi32( ni0, ni1 );
422
+ ni2 = _mm_packs_epi32( ni2, ni3 );
423
+ ni4 = _mm_packs_epi32( ni4, ni5 );
424
+ ni6 = _mm_packs_epi32( ni6, ni7 );
425
+ // Convert int16 to int8
426
+ ni0 = _mm_packs_epi16( ni0, ni2 );
427
+ ni4 = _mm_packs_epi16( ni4, ni6 );
428
+ _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0);
429
+ _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4);
430
+ #endif
431
+ }
432
+ }
258
433
  #else
259
434
  // scalar
260
435
  const int blck_size_interleave = 8;
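The AVX path added to quantize_q8_0_4x8 above computes max(|x|) over each 32-value block, derives the Q8_0 scale d = max/127 together with its inverse, then rounds, converts and packs the scaled values down to int8. The standalone sketch below is illustrative only and not package code; it assumes a file compiled with `-mavx`, and the helper name hmax_abs_ps is made up. It isolates the horizontal max-abs reduction and the scale computation:

    #include <immintrin.h>
    #include <math.h>
    #include <stdio.h>

    // Horizontal max(|x|) over 8 floats, mirroring the reduction in the hunk above.
    static float hmax_abs_ps(__m256 v) {
        const __m256 sign_bit = _mm256_set1_ps(-0.0f);
        __m256 abs_v = _mm256_andnot_ps(sign_bit, v);              // clear the sign bits
        __m128 max4  = _mm_max_ps(_mm256_extractf128_ps(abs_v, 1),
                                  _mm256_castps256_ps128(abs_v));  // 8 -> 4
        max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));        // 4 -> 2
        max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));            // 2 -> 1
        return _mm_cvtss_f32(max4);
    }

    int main(void) {
        const float x[8] = { 0.5f, -7.25f, 3.0f, -1.5f, 2.0f, 6.0f, -0.25f, 4.5f };
        const float m  = hmax_abs_ps(_mm256_loadu_ps(x));
        const float d  = m / 127.0f;                      // Q8_0 block scale
        const float id = (m != 0.0f) ? 127.0f / m : 0.0f; // inverse used when quantizing
        // The largest-magnitude element should map to +/-127 after scaling.
        printf("max|x| = %g, d = %g, quantized max = %g\n", m, d, fabsf(x[1]) * id);
        return 0;
    }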
@@ -684,6 +859,96 @@ void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
684
859
  LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) &&
685
860
  "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
686
861
  "performance");
862
+ #elif defined(__AVX2__)
863
+ // Lookup table to convert signed nibbles to signed bytes
864
+ __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
865
+ signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
866
+ __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
867
+ __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
868
+
869
+ // Mask to mask out nibbles from packed bytes
870
+ const __m256i m4b = _mm256_set1_epi8(0x0F);
871
+
872
+ int64_t b_nb = n / QK4_0;
873
+
874
+ const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
875
+ const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy;
876
+
877
+ // Process Q8_0 blocks one by one
878
+ for (int64_t y = 0; y < nr; y++) {
879
+
880
+ // Pointers to LHS blocks of block_q8_0 format
881
+ const block_q8_0 * a_ptr = a_ptr_start + (y * nb);
882
+
883
+ // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
884
+ for (int64_t x = 0; x < nc / 8; x++) {
885
+
886
+ // Pointers to RHS blocks
887
+ const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
888
+
889
+ // Master FP accumulator
890
+ __m256 acc_row = _mm256_setzero_ps();
891
+
892
+ for (int64_t b = 0; b < nb; b++) {
893
+ // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7)
894
+ const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
895
+ const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1);
896
+ const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2);
897
+ const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3);
898
+
899
+ // 4-bit -> 8-bit - Sign is maintained
900
+ const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7)
901
+ const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7)
902
+ const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15)
903
+ const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B4(8-15) B5(8-15) B6(8-15) B7(8-15)
904
+
905
+ const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23)
906
+ const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23)
907
+ const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31)
908
+ const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31)
909
+
910
+ // Load the scale values for the 8 blocks interleaved in block_q4_0x8
911
+ const __m256 col_scale_f32 = LM_GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
912
+
913
+ // Load and convert to FP32 scale from block_q8_0
914
+ const __m256 row_scale_f32 = _mm256_set1_ps(LM_GGML_FP16_TO_FP32(a_ptr[b].d));
915
+
916
+ // Load the block values in block_q8_0 in batches of 16 bytes and replicate the same across 256 bit vector
917
+ __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs));
918
+ __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16)));
919
+
920
+ lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0(0-15) A0(0-15)
921
+ lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0(16-31) A0(16-31)
922
+
923
+ __m256i iacc = _mm256_setzero_si256();
924
+
925
+ // Dot product done within 32 bit lanes and accumulated in the same vector
926
+ // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
927
+ // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
928
+ // ...........................................................................
929
+ // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
930
+
931
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
932
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
933
+
934
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
935
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
936
+
937
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
938
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
939
+
940
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
941
+ iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
942
+
943
+ // Accumulated values multiplied with appropriate scales
944
+ acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
945
+ }
946
+
947
+ // Accumulated output values permuted so as to be stored in appropriate order post accumulation
948
+ acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
949
+ _mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
950
+ }
951
+ }
687
952
  #else
688
953
  float sumf[8];
689
954
  int sumi;
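The GEMV path above decodes interleaved Q4_0 nibbles with a pshufb lookup table: each byte is masked with 0x0F and used as an index into a 16-entry table that returns the corresponding signed value, relying on the nibbles already being in "pure sign form" (see the bias-offset comment at the top of the first hunk). The standalone sketch below is illustrative only, not package code; it assumes compilation with `-mavx2` and uses arbitrary test bytes. It checks the trick against a scalar reference:

    #include <immintrin.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // 16-entry sign-extension table, replicated into both 128-bit halves
        // so the lane-local shuffle sees the same table everywhere.
        __m256i lut = _mm256_castsi128_si256(
            _mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
        lut = _mm256_permute2f128_si256(lut, lut, 0);
        const __m256i m4b = _mm256_set1_epi8(0x0F);

        uint8_t packed[32];
        for (int i = 0; i < 32; i++) packed[i] = (uint8_t)((i * 7 + 3) & 0xFF); // arbitrary bytes

        const __m256i raw = _mm256_loadu_si256((const __m256i *)packed);
        const __m256i lo  = _mm256_shuffle_epi8(lut, _mm256_and_si256(raw, m4b));                       // low nibbles
        const __m256i hi  = _mm256_shuffle_epi8(lut, _mm256_and_si256(_mm256_srli_epi16(raw, 4), m4b)); // high nibbles

        int8_t lo_b[32], hi_b[32];
        _mm256_storeu_si256((__m256i *)lo_b, lo);
        _mm256_storeu_si256((__m256i *)hi_b, hi);

        for (int i = 0; i < 32; i++) {
            int lo_ref = packed[i] & 0x0F; lo_ref = lo_ref < 8 ? lo_ref : lo_ref - 16;
            int hi_ref = packed[i] >> 4;   hi_ref = hi_ref < 8 ? hi_ref : hi_ref - 16;
            if (lo_b[i] != lo_ref || hi_b[i] != hi_ref) { printf("mismatch at %d\n", i); return 1; }
        }
        printf("nibble -> signed byte LUT matches the scalar reference\n");
        return 0;
    }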
@@ -2143,6 +2408,353 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2143
2408
  LM_GGML_ASSERT((lm_ggml_cpu_has_sve() || lm_ggml_cpu_has_matmul_int8()) &&
2144
2409
  "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal "
2145
2410
  "performance");
2411
+ #elif defined(__AVX2__) || defined(__AVX512F__)
2412
+ const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
2413
+ const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
2414
+ int64_t b_nb = n / QK4_0;
2415
+ int64_t y = 0;
2416
+ // Mask to mask out nibbles from packed bytes
2417
+ const __m256i m4b = _mm256_set1_epi8(0x0F);
2418
+ const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
2419
+ // Lookup table to convert signed nibbles to signed bytes
2420
+ __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
2421
+ signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
2422
+ // Permute mask used for easier vector processing at later stages
2423
+ __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
2424
+
2425
+ // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
2426
+ int anr = nr - nr % 16; // nr rounded down to a multiple of 16
2427
+
2428
+ for (; y < anr / 4; y += 4) {
2429
+ const block_q8_0x4 * a_ptrs[4];
2430
+
2431
+ a_ptrs[0] = a_ptr_start + (y * nb);
2432
+ for (int i = 0; i < 3; ++i) {
2433
+ a_ptrs[i + 1] = a_ptrs[i] + nb;
2434
+ }
2435
+
2436
+ // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
2437
+ for (int64_t x = 0; x < nc / 8; x++) {
2438
+
2439
+ const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
2440
+
2441
+ // Master FP accumulators
2442
+ __m256 acc_rows[16];
2443
+ for (int i = 0; i < 16; i++) {
2444
+ acc_rows[i] = _mm256_setzero_ps();
2445
+ }
2446
+
2447
+ for (int64_t b = 0; b < nb; b++) {
2448
+ // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
2449
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
2450
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
2451
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
2452
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
2453
+
2454
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
2455
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
2456
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
2457
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
2458
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
2459
+
2460
+ // 4-bit -> 8-bit - Sign is maintained
2461
+ const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
2462
+ const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
2463
+
2464
+ const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
2465
+ const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
2466
+
2467
+ const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
2468
+ const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
2469
+
2470
+ const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
2471
+ const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
2472
+
2473
+ // Shuffle pattern one - right side input
2474
+ const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
2475
+ const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
2476
+
2477
+ const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
2478
+ const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
2479
+
2480
+ const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
2481
+ const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
2482
+
2483
+ const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
2484
+ const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
2485
+
2486
+ // Shuffle pattern two - right side input
2487
+
2488
+ const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
2489
+ const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
2490
+
2491
+ const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
2492
+ const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
2493
+
2494
+ const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
2495
+ const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
2496
+
2497
+ const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
2498
+ const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
2499
+
2500
+ // Scale values - Load the weight scale values of block_q4_0x8
2501
+ const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d);
2502
+
2503
+ // Process LHS in groups of four
2504
+ for (int rp = 0; rp < 4; rp++) {
2505
+ // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
2506
+ // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
2507
+ __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
2508
+ __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
2509
+ __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
2510
+ __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
2511
+ __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
2512
+ __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
2513
+ __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
2514
+ __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
2515
+ __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
2516
+ __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
2517
+ __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
2518
+ __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
2519
+
2520
+ // Shuffle pattern one - left side input
2521
+ const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2522
+ const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2523
+
2524
+ const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2525
+ const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2526
+
2527
+ const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2528
+ const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2529
+
2530
+ const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2531
+ const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2532
+
2533
+ // Shuffle pattern two - left side input
2534
+ const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2535
+ const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2536
+
2537
+ const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2538
+ const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2539
+
2540
+ const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2541
+ const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2542
+
2543
+ const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2544
+ const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2545
+
2546
+ // The shuffled values are combined with dot products within each 32-bit lane: corresponding bytes are multiplied and the products accumulated into a 32-bit integer per lane
2547
+ // Resembles the 2x2 MMLA tiles used in the ARM version
2548
+ __m256i iacc_mat_00_sp1 =
2549
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
2550
+ __m256i iacc_mat_01_sp1 =
2551
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
2552
+ __m256i iacc_mat_10_sp1 =
2553
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
2554
+ __m256i iacc_mat_11_sp1 =
2555
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
2556
+ __m256i iacc_mat_00_sp2 =
2557
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
2558
+ __m256i iacc_mat_01_sp2 =
2559
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
2560
+ __m256i iacc_mat_10_sp2 =
2561
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
2562
+ __m256i iacc_mat_11_sp2 =
2563
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
2564
+
2565
+ // Outputs of both shuffle patterns are added to sum the dot products over all 32 values in the block
2566
+ __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
2567
+ __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
2568
+ __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
2569
+ __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
2570
+
2571
+ // Straighten out to make 4 row vectors
2572
+ __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
2573
+ __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
2574
+ __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
2575
+ __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
2576
+
2577
+ // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
2578
+ const __m256 row_scale_f32 = LM_GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
2579
+
2580
+ // Multiply with appropriate scales and accumulate
2581
+ acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
2582
+ acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
2583
+ acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
2584
+ acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
2585
+ }
2586
+ }
2587
+
2588
+ // Store the accumulated values
2589
+ for (int i = 0; i < 16; i++) {
2590
+ _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
2591
+ }
2592
+ }
2593
+ }
2594
+
2595
+ // Take one block_q8_0x4 structure at each pass of the loop and perform the dot product operation
2596
+ for (; y < nr / 4; y ++) {
2597
+
2598
+ const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
2599
+
2600
+ // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
2601
+ for (int64_t x = 0; x < nc / 8; x++) {
2602
+
2603
+ const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
2604
+
2605
+ // Master FP accumulators
2606
+ __m256 acc_rows[4];
2607
+ for (int i = 0; i < 4; i++) {
2608
+ acc_rows[i] = _mm256_setzero_ps();
2609
+ }
2610
+
2611
+ for (int64_t b = 0; b < nb; b++) {
2612
+ // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
2613
+ const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
2614
+ const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
2615
+ const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
2616
+ const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
2617
+
2618
+ // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
2619
+ const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
2620
+ const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
2621
+ const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
2622
+ const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
2623
+
2624
+ // 4-bit -> 8-bit - Sign is maintained
2625
+ const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
2626
+ const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
2627
+
2628
+ const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
2629
+ const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
2630
+
2631
+ const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
2632
+ const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
2633
+
2634
+ const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
2635
+ const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
2636
+
2637
+ // Shuffle pattern one - right side input
2638
+ const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
2639
+ const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
2640
+
2641
+ const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
2642
+ const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
2643
+
2644
+ const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
2645
+ const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
2646
+
2647
+ const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
2648
+ const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
2649
+
2650
+ // Shuffle pattern two - right side input
2651
+
2652
+ const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
2653
+ const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
2654
+
2655
+ const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
2656
+ const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
2657
+
2658
+ const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
2659
+ const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
2660
+
2661
+ const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
2662
+ const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
2663
+
2664
+ // Scale values - Load the weight scale values of block_q4_0x8
2665
+ const __m256 col_scale_f32 = LM_GGML_F32Cx8_LOAD(b_ptr[b].d);
2666
+
2667
+ // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
2668
+ // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
2669
+ __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
2670
+ __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
2671
+ __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
2672
+ __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
2673
+ __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
2674
+ __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
2675
+ __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
2676
+ __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
2677
+ __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
2678
+ __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
2679
+ __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
2680
+ __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
2681
+
2682
+ // Shuffle pattern one - left side input
2683
+
2684
+ const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2685
+ const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2686
+
2687
+ const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2688
+ const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2689
+
2690
+ const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2691
+ const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2692
+
2693
+ const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2694
+ const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2695
+
2696
+ // Shuffle pattern two - left side input
2697
+
2698
+ const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2699
+ const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2700
+
2701
+ const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2702
+ const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2703
+
2704
+ const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2705
+ const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2706
+
2707
+ const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2708
+ const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2709
+
2710
+ // The shuffled values are combined with dot products within each 32-bit lane: corresponding bytes are multiplied and the products accumulated into a 32-bit integer per lane
2711
+ // Resembles the 2x2 MMLA tiles used in the ARM version
2712
+ __m256i iacc_mat_00_sp1 =
2713
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
2714
+ __m256i iacc_mat_01_sp1 =
2715
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
2716
+ __m256i iacc_mat_10_sp1 =
2717
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
2718
+ __m256i iacc_mat_11_sp1 =
2719
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
2720
+ __m256i iacc_mat_00_sp2 =
2721
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
2722
+ __m256i iacc_mat_01_sp2 =
2723
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
2724
+ __m256i iacc_mat_10_sp2 =
2725
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
2726
+ __m256i iacc_mat_11_sp2 =
2727
+ _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
2728
+
2729
+ // Outputs of both shuffle patterns are added to sum the dot products over all 32 values in the block
2730
+ __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
2731
+ __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
2732
+ __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
2733
+ __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
2734
+
2735
+
2736
+ // Straighten out to make 4 row vectors
2737
+ __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
2738
+ __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
2739
+ __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
2740
+ __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
2741
+
2742
+ // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
2743
+ const __m256 row_scale_f32 = LM_GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
2744
+
2745
+ // Multiply with appropriate scales and accumulate
2746
+ acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
2747
+ acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
2748
+ acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
2749
+ acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
2750
+ }
2751
+
2752
+ // Store the accumulated values
2753
+ for (int i = 0; i < 4; i++) {
2754
+ _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
2755
+ }
2756
+ }
2757
+ }
2146
2758
  #else
2147
2759
  float sumf[4][8];
2148
2760
  int sumi;
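In the GEMM path added above, the int32 accumulators are folded into the float accumulators with _mm256_fmadd_ps, scaled by col_scale_f32 * row_scale_f32, where one of the four Q8_0 row scales is broadcast per output row using _mm256_shuffle_ps with the immediates 0, 85, 170 and 255. The standalone sketch below is illustrative only and not package code; it assumes `-mavx` and hand-builds the repeated [d0 d1 d2 d3 d0 d1 d2 d3] layout that LM_GGML_F32Cx8_REPEAT_LOAD produces after conversion:

    #include <immintrin.h>
    #include <stdio.h>

    int main(void) {
        const float d[4] = { 0.50f, 0.25f, 2.00f, 1.25f };
        // Row scales repeated across both 128-bit halves, as after the REPEAT_LOAD macro.
        const __m256 row_scale = _mm256_setr_ps(d[0], d[1], d[2], d[3], d[0], d[1], d[2], d[3]);

        float out[8];
        _mm256_storeu_ps(out, _mm256_shuffle_ps(row_scale, row_scale, 0));   // broadcasts d0
        printf("imm 0   -> %g .. %g (expect %g)\n", out[0], out[7], d[0]);
        _mm256_storeu_ps(out, _mm256_shuffle_ps(row_scale, row_scale, 85));  // broadcasts d1
        printf("imm 85  -> %g .. %g (expect %g)\n", out[0], out[7], d[1]);
        _mm256_storeu_ps(out, _mm256_shuffle_ps(row_scale, row_scale, 170)); // broadcasts d2
        printf("imm 170 -> %g .. %g (expect %g)\n", out[0], out[7], d[2]);
        _mm256_storeu_ps(out, _mm256_shuffle_ps(row_scale, row_scale, 255)); // broadcasts d3
        printf("imm 255 -> %g .. %g (expect %g)\n", out[0], out[7], d[3]);
        return 0;
    }

Because _mm256_shuffle_ps selects within each 128-bit half and both halves hold the same four scales, each immediate yields a full-width broadcast of one row's scale, which is then multiplied with the column scales before the FMA into acc_rows.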