cui-llama.rn 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/android/src/main/CMakeLists.txt +5 -7
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -4
  3. package/android/src/main/jni.cpp +9 -9
  4. package/cpp/common.cpp +21 -40
  5. package/cpp/common.h +21 -12
  6. package/cpp/ggml-backend-impl.h +38 -20
  7. package/cpp/ggml-backend-reg.cpp +216 -87
  8. package/cpp/ggml-backend.h +1 -0
  9. package/cpp/ggml-common.h +42 -48
  10. package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +591 -152
  11. package/cpp/ggml-cpu-aarch64.h +2 -26
  12. package/cpp/ggml-cpu-traits.cpp +36 -0
  13. package/cpp/ggml-cpu-traits.h +38 -0
  14. package/cpp/ggml-cpu.c +14122 -13971
  15. package/cpp/ggml-cpu.cpp +618 -715
  16. package/cpp/ggml-cpu.h +0 -17
  17. package/cpp/ggml-impl.h +6 -6
  18. package/cpp/ggml-metal.m +482 -24
  19. package/cpp/ggml-quants.c +0 -9
  20. package/cpp/ggml-threading.h +4 -2
  21. package/cpp/ggml.c +132 -43
  22. package/cpp/ggml.h +44 -13
  23. package/cpp/llama-sampling.cpp +35 -90
  24. package/cpp/llama-vocab.cpp +2 -1
  25. package/cpp/llama.cpp +737 -233
  26. package/cpp/llama.h +20 -16
  27. package/cpp/sampling.cpp +11 -16
  28. package/cpp/speculative.cpp +4 -0
  29. package/cpp/unicode.cpp +51 -51
  30. package/cpp/unicode.h +9 -10
  31. package/lib/commonjs/index.js +38 -1
  32. package/lib/commonjs/index.js.map +1 -1
  33. package/lib/module/index.js +36 -0
  34. package/lib/module/index.js.map +1 -1
  35. package/lib/typescript/NativeRNLlama.d.ts +2 -3
  36. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  37. package/lib/typescript/index.d.ts +36 -2
  38. package/lib/typescript/index.d.ts.map +1 -1
  39. package/package.json +1 -1
  40. package/src/NativeRNLlama.ts +3 -3
  41. package/src/index.ts +46 -2
  42. package/cpp/amx/amx.cpp +0 -196
  43. package/cpp/amx/amx.h +0 -20
  44. package/cpp/amx/common.h +0 -101
  45. package/cpp/amx/mmq.cpp +0 -2524
  46. package/cpp/amx/mmq.h +0 -16
  47. package/cpp/ggml-aarch64.c +0 -129
  48. package/cpp/ggml-aarch64.h +0 -19
@@ -1,20 +1,57 @@
- #define LM_GGML_COMMON_IMPL_C
+ #define LM_GGML_COMMON_IMPL_CPP
+ #define LM_GGML_COMMON_DECL_CPP
  #include "ggml-common.h"
+ #include "ggml-backend-impl.h"
 
  #include "ggml-quants.h"
  #include "ggml-impl.h"
  #include "ggml-cpu.h"
  #include "ggml-cpu-impl.h"
+ #include "ggml-cpu-traits.h"
 
- #include <math.h>
- #include <string.h>
- #include <assert.h>
- #include <float.h>
- #include <stdlib.h> // for qsort
- #include <stdio.h> // for LM_GGML_ASSERT
+ #include <cmath>
+ #include <cstring>
+ #include <cassert>
+ #include <cfloat>
+ #include <cstdlib> // for qsort
+ #include <cstdio> // for LM_GGML_ASSERT
 
  #include "ggml-cpu-aarch64.h"
 
+ // TODO: move to include file?
+ template <int K> constexpr int QK_0() {
+ if constexpr (K == 4) {
+ return QK4_0;
+ }
+ if constexpr (K == 8) {
+ return QK8_0;
+ }
+ return -1;
+ }
+
+ template <int K, int N> struct block {
+ lm_ggml_half d[N]; // deltas for N qK_0 blocks
+ int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
+ };
+
+ // control size
+ static_assert(sizeof(block<4, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+ static_assert(sizeof(block<4, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+ static_assert(sizeof(block<8, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+ static_assert(sizeof(block<8, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+
+ using block_q4_0x4 = block<4, 4>;
+ using block_q4_0x8 = block<4, 8>;
+ using block_q8_0x4 = block<8, 4>;
+ using block_q8_0x8 = block<8, 8>;
+
+ struct block_iq4_nlx4 {
+ lm_ggml_half d[4]; // deltas for 4 iq4_nl blocks
+ uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
+ };
+
+ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(lm_ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
  #if defined(__GNUC__)
  #pragma GCC diagnostic ignored "-Woverlength-strings"
  #elif defined(_MSC_VER)
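The static_asserts in the hunk above pin down the memory layout of the new interleaved blocks. As a sanity check, here is a minimal standalone C++ sketch (not part of the package) that reproduces the size arithmetic; it assumes QK4_0 == QK8_0 == 32, as defined in ggml-common.h, and a 16-bit lm_ggml_half:

// Standalone sketch of the interleaved-block size check (assumptions noted above).
#include <cstdint>
#include <cstdio>

using lm_ggml_half = uint16_t;  // assumed 16-bit half type, as in ggml
constexpr int QK4_0 = 32;       // assumed block sizes from ggml-common.h
constexpr int QK8_0 = 32;

template <int K> constexpr int QK_0() {
    if constexpr (K == 4) { return QK4_0; }
    if constexpr (K == 8) { return QK8_0; }
    return -1;
}

// N interleaved q4_0/q8_0 blocks: N deltas plus the quants of N blocks,
// where each of the QK_0<K>() weights occupies K bits.
template <int K, int N> struct block {
    lm_ggml_half d[N];
    int8_t       qs[(QK_0<K>() * N * K) / 8];
};

int main() {
    // block<4, 4>: 4 * 2 bytes of deltas + (32 * 4 * 4) / 8 = 64 bytes of quants = 72 bytes total
    printf("sizeof(block<4, 4>) = %zu\n", sizeof(block<4, 4>));
    // block<8, 8>: 8 * 2 bytes of deltas + (32 * 8 * 8) / 8 = 256 bytes of quants = 272 bytes total
    printf("sizeof(block<8, 8>) = %zu\n", sizeof(block<8, 8>));
    static_assert(sizeof(block<4, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 2, "wrong block<4,4> size");
    static_assert(sizeof(block<8, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 8, "wrong block<8,8> size");
    return 0;
}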
@@ -128,7 +165,7 @@ static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
  }
 
  static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+ #if defined(__AVX512VNNI__)
  const __m512i zero = _mm512_setzero_si512();
  return _mm512_dpbusd_epi32(zero, ax, sy);
  #else
@@ -185,12 +222,12 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 
  static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
- static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
+ static void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
  assert(QK8_0 == 32);
  assert(k % QK8_0 == 0);
  const int nb = k / QK8_0;
 
- block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+ block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
 
  #if defined(__ARM_NEON)
  float32x4_t srcv[4][8];
@@ -279,12 +316,12 @@ static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int6
  #endif
  }
 
- static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
+ static void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
  assert(QK8_0 == 32);
  assert(k % QK8_0 == 0);
  const int nb = k / QK8_0;
 
- block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+ block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
 
  #if defined(__ARM_NEON)
  float32x4_t srcv[4][8];
@@ -494,7 +531,7 @@ static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int6
  #endif
  }
 
- void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+ static void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
  assert(nrow == 4);
  UNUSED(nrow);
  if (blck_size_interleave == 4) {
@@ -506,7 +543,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
  }
  }
 
- void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
@@ -591,7 +628,7 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
  }
  }
 
- void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
@@ -701,7 +738,7 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
  }
  }
 
- void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 8;
@@ -974,7 +1011,7 @@ void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
  }
  }
 
- void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
@@ -1070,7 +1107,7 @@ void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const vo
  }
  }
 
- void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
@@ -1586,7 +1623,7 @@ void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
  }
  }
 
- void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
@@ -2040,7 +2077,7 @@ void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
  }
  }
 
- void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 8;
@@ -2560,31 +2597,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2560
2597
  const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
2561
2598
 
2562
2599
  // Shuffle pattern one - right side input
2563
- const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2564
- const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2600
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2601
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2565
2602
 
2566
- const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2567
- const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2603
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2604
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2568
2605
 
2569
- const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2570
- const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2606
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2607
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2571
2608
 
2572
- const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2573
- const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2609
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2610
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2574
2611
 
2575
2612
  // Shuffle pattern two - right side input
2576
2613
 
2577
- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2578
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2614
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2615
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2579
2616
 
2580
- const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2581
- const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2617
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2618
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2582
2619
 
2583
- const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2584
- const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2620
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2621
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2585
2622
 
2586
- const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2587
- const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2623
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2624
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2588
2625
 
2589
2626
  // Scale values - Load the weight scale values of two block_q4_0x8
2590
2627
  const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
@@ -2618,31 +2655,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2618
2655
 
2619
2656
  // Shuffle pattern one - left side input
2620
2657
 
2621
- const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2622
- const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2658
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2659
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2623
2660
 
2624
- const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2625
- const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2661
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2662
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2626
2663
 
2627
- const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2628
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2664
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2665
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2629
2666
 
2630
- const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2631
- const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2667
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2668
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2632
2669
 
2633
2670
  // Shuffle pattern two - left side input
2634
2671
 
2635
- const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2636
- const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2672
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2673
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2637
2674
 
2638
- const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2639
- const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2675
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2676
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2640
2677
 
2641
- const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2642
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2678
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2679
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2643
2680
 
2644
- const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2645
- const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2681
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2682
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2646
2683
 
2647
2684
  // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
2648
2685
  // Resembles MMLAs into 2x2 matrices in ARM Version
@@ -2671,10 +2708,10 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
 
 
  // Straighten out to make 4 row vectors
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
 
  // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
  const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
@@ -2753,31 +2790,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2753
2790
  const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
2754
2791
 
2755
2792
  // Shuffle pattern one - right side input
2756
- const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2757
- const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2793
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2794
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2758
2795
 
2759
- const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2760
- const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2796
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2797
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2761
2798
 
2762
- const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2763
- const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2799
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2800
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2764
2801
 
2765
- const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2766
- const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2802
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2803
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2767
2804
 
2768
2805
  // Shuffle pattern two - right side input
2769
2806
 
2770
- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2771
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2807
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2808
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2772
2809
 
2773
- const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2774
- const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2810
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2811
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2775
2812
 
2776
- const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2777
- const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2813
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2814
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2778
2815
 
2779
- const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2780
- const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2816
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2817
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2781
2818
 
2782
2819
 
2783
2820
  // Scale values - Load the weight scale values of two block_q4_0x8
@@ -2809,31 +2846,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2809
2846
 
2810
2847
  // Shuffle pattern one - left side input
2811
2848
 
2812
- const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2813
- const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2849
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2850
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2814
2851
 
2815
- const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2816
- const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2852
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2853
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2817
2854
 
2818
- const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2819
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2855
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2856
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2820
2857
 
2821
- const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2822
- const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2858
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2859
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2823
2860
 
2824
2861
  // Shuffle pattern two - left side input
2825
2862
 
2826
- const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2827
- const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2863
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2864
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2828
2865
 
2829
- const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2830
- const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2866
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2867
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2831
2868
 
2832
- const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2833
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2869
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2870
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2834
2871
 
2835
- const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2836
- const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2872
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2873
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2837
2874
 
2838
2875
  // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
2839
2876
  // Resembles MMLAs into 2x2 matrices in ARM Version
@@ -2862,10 +2899,10 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
 
 
  // Straighten out to make 4 row vectors
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
 
  // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
  const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
@@ -3460,7 +3497,7 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
  }
  }
 
- void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+ static void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
  const int qk = QK8_0;
  const int nb = n / qk;
  const int ncols_interleaved = 4;
@@ -3571,7 +3608,6 @@ void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const vo
  }
  }
 
- // FIXME: this code is duplicated from ggml-aarch64.c
  static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
  block_q4_0x4 out;
 
@@ -3641,20 +3677,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
  return out;
  }
 
- static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+ static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
  LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
  LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+ constexpr int nrows_interleaved = 4;
 
  block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
  const block_q4_0 * src = (const block_q4_0 *)data;
  block_q4_0 dst_tmp[4];
- int nrow = t->ne[1]; // Number of rows
- int nrows_interleaved = 4;
+ int nrow = lm_ggml_nrows(t);
  int nblocks = t->ne[0] / QK4_0;
 
  LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
 
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
  return -1;
  }
 
@@ -3672,20 +3708,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_bl
  LM_GGML_UNUSED(data_size);
  }
 
- static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) {
+ static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
  LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
  LM_GGML_ASSERT(interleave_block == 8);
+ constexpr int nrows_interleaved = 8;
 
  block_q4_0x8 * dst = (block_q4_0x8*)t->data;
  const block_q4_0 * src = (const block_q4_0*) data;
  block_q4_0 dst_tmp[8];
- int nrow = t->ne[1]; // Number of rows
- int nrows_interleaved = 8;
+ int nrow = lm_ggml_nrows(t);
  int nblocks = t->ne[0] / QK4_0;
 
  LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
 
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
  return -1;
  }
 
@@ -3712,16 +3748,18 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
 
  const int end = QK4_NL * 2 / blck_size_interleave;
 
- if (blck_size_interleave == 8) {
- for (int i = 0; i < end; ++i) {
- int src_id = i % 4;
- int src_offset = (i / 4) * blck_size_interleave;
- int dst_offset = i * blck_size_interleave;
-
- // Using memcpy to avoid unaligned memory accesses
- memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
- }
- } else if (blck_size_interleave == 4) {
+ // TODO: this branch seems wrong
+ //if (blck_size_interleave == 8) {
+ // for (int i = 0; i < end; ++i) {
+ // int src_id = i % 4;
+ // int src_offset = (i / 4) * blck_size_interleave;
+ // int dst_offset = i * blck_size_interleave;
+
+ // // Using memcpy to avoid unaligned memory accesses
+ // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+ // }
+ //} else
+ if (blck_size_interleave == 4) {
  for (int i = 0; i < end; ++i) {
  int src_id = i % 4;
  int src_offset = (i / 4) * blck_size_interleave;
@@ -3736,20 +3774,21 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
  return out;
  }
 
- static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+ static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
  LM_GGML_ASSERT(t->type == LM_GGML_TYPE_IQ4_NL);
- LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+ //LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+ LM_GGML_ASSERT(interleave_block == 4);
 
  block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
  const block_iq4_nl * src = (const block_iq4_nl *)data;
  block_iq4_nl dst_tmp[4];
- int nrow = t->ne[1]; // Number of rows
+ int nrow = lm_ggml_nrows(t);
  int nrows_interleaved = 4;
  int nblocks = t->ne[0] / QK4_0;
 
  LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
 
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
  return -1;
  }
 
@@ -3767,57 +3806,457 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleav
  LM_GGML_UNUSED(data_size);
  }

- // Prepare for optimized kernels if applicable
- void lm_ggml_aarch64_repack_tensor(struct lm_ggml_tensor * cur, enum lm_ggml_type repack_type, const void * restrict data, size_t data_size) {
- if (cur->type == repack_type) {
- memcpy(cur->data, data, data_size);
- return;
+ namespace ggml::cpu::aarch64 {
+ // repack
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+ int repack(struct lm_ggml_tensor *, const void *, size_t);
+
+ // TODO: generalise.
+ template <> int repack<block_q4_0, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
+ }
+
+ template <> int repack<block_q4_0, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
+ }
+
+ template <> int repack<block_q4_0, 8, 8>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
+ }
+
+ template <> int repack<block_iq4_nl, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
+ }
+
+ // TODO: needs to be revisited
+ //template <> int repack<block_iq4_nl, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+ // return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
+ //}
+
+ // gemv
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+ void gemv(int, float *, size_t, const void *, const void *, int, int);
+
+ template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+ }
+
+ template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+ }
+
+ template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+ }
+
+ template <>
+ void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+ }
+
+ // gemm
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+ void gemm(int, float *, size_t, const void *, const void *, int, int);
+
+ template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+ }
+
+ template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+ }
+
+ template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+ }
+
+ template <>
+ void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ lm_ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+ }
+
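Note: each <block type, INTER_SIZE, NB_COLS> triple above is just a compile-time name for one hand-written kernel. A minimal usage sketch (inside the same namespace; the wrapper name and the mapping comments are illustrative, not part of the diff):

    // Forwarding through a specialization resolves to the matching kernel:
    //   gemv<block_q4_0, 8, 8>   -> lm_ggml_gemv_q4_0_8x8_q8_0
    //   gemm<block_iq4_nl, 4, 4> -> lm_ggml_gemm_iq4_nl_4x4_q8_0
    static void example_q4_0_8x8_gemv(int n, float * s, size_t bs,
                                      const void * vx, const void * vy, int nr, int nc) {
        gemv<block_q4_0, 8, 8>(n, s, bs, vx, vy, nr, nc);
    }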
+ class tensor_traits_base : public ggml::cpu::tensor_traits {
+ public:
+ virtual int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) = 0;
+ };
+
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+
+ bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
+ // not really a LM_GGML_TYPE_Q8_0 but same size.
+ switch (op->op) {
+ case LM_GGML_OP_MUL_MAT:
+ size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
+ return true;
+ case LM_GGML_OP_MUL_MAT_ID:
+ size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
+ size = LM_GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
+ size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+ return true;
+ default:
+ // LM_GGML_ABORT("fatal error");
+ break;
+ }
+ return false;
  }
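Note: for LM_GGML_OP_MUL_MAT the scratch size above is simply the Q8_0 footprint of src1. A worked example, assuming the standard Q8_0 layout (QK8_0 = 32 values per block, 34 bytes per block: one lm_ggml_half delta plus 32 int8 quants):

    // src1 with ne10 = 4096, ne11 = 8  =>  32768 elements
    //   blocks = 32768 / 32      = 1024
    //   size   = 1024 * 34 bytes = 34816 bytes
    size_t wsize = (4096 * 8 / 32) * 34;   // the value lm_ggml_row_size(LM_GGML_TYPE_Q8_0, 32768) should return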

- if (cur->type == LM_GGML_TYPE_Q4_0) {
- switch (repack_type) {
- case LM_GGML_TYPE_Q4_0_8_8:
- repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
- break;
- case LM_GGML_TYPE_Q4_0_4_8:
- repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
- break;
- case LM_GGML_TYPE_Q4_0_4_4:
- repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
- break;
- default:
- LM_GGML_ABORT("Unsupported type");
+ bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
+ switch (op->op) {
+ case LM_GGML_OP_MUL_MAT:
+ forward_mul_mat(params, op);
+ return true;
+ case LM_GGML_OP_MUL_MAT_ID:
+ forward_mul_mat_id(params, op);
+ return true;
+ default:
+ // LM_GGML_ABORT("fatal error");
+ break;
  }
- } else if (cur->type == LM_GGML_TYPE_IQ4_NL) {
- switch (repack_type) {
- case LM_GGML_TYPE_IQ4_NL_4_4:
- repack_iq4_nl_to_iq4_nl_4_bl(cur, 4, data, data_size);
- break;
- default:
- LM_GGML_ABORT("Unsupported type");
+ return false;
+ }
+
+ void forward_mul_mat(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
+ const lm_ggml_tensor * src0 = op->src[0];
+ const lm_ggml_tensor * src1 = op->src[1];
+ lm_ggml_tensor * dst = op;
+
+ LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ LM_GGML_ASSERT(ne0 == ne01);
+ LM_GGML_ASSERT(ne1 == ne11);
+ LM_GGML_ASSERT(ne2 == ne12);
+ LM_GGML_ASSERT(ne3 == ne13);
+
+ // dst cannot be transposed or permuted
+ LM_GGML_ASSERT(nb0 == sizeof(float));
+ LM_GGML_ASSERT(nb0 <= nb1);
+ LM_GGML_ASSERT(nb1 <= nb2);
+ LM_GGML_ASSERT(nb2 <= nb3);
+
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+
+ LM_GGML_ASSERT(lm_ggml_n_dims(op->src[0]) == 2);
+ // LM_GGML_ASSERT(lm_ggml_n_dims(op->src[1]) == 2);
+
+ char * wdata = static_cast<char *>(params->wdata);
+ const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+
+ assert(params->wsize >= nbw1 * ne11);
+
+ const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
+
+ int64_t i11_processed = 0;
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+ quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
+ INTER_SIZE);
+ }
+ i11_processed = ne11 - ne11 % 4;
+ for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+ from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+ }
+
+ lm_ggml_barrier(params->threadpool);
+
+ const void * src1_wdata = params->wdata;
+ const size_t src1_col_stride = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+ int64_t src0_start = (ith * ne01) / nth;
+ int64_t src0_end = ((ith + 1) * ne01) / nth;
+ src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+ if (src0_start >= src0_end) {
+ return;
+ }
+
+ // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+ if (ne11 > 3) {
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
+ (const char *) src0->data + src0_start * nb01,
+ (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+ }
+ for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+ (const char *) src0->data + src0_start * nb01,
+ (const char *) src1_wdata + (src1_col_stride * iter), 1,
+ src0_end - src0_start);
  }
- } else {
- LM_GGML_ABORT("Unsupported type");
  }
- }
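Note: forward_mul_mat above splits the src0 rows across threads and rounds each slice boundary up to a multiple of NB_COLS, so an interleaved row group is never split between threads. A minimal sketch of that rounding with a worked example (the helper name is illustrative):

    // Round a slice boundary up to the next multiple of nb_cols.
    static int64_t round_up(int64_t x, int64_t nb_cols) {
        return (x % nb_cols) ? x + nb_cols - (x % nb_cols) : x;
    }
    // e.g. ne01 = 96 rows, nth = 5 threads, NB_COLS = 8 gives the slices
    //   [0,24) [24,40) [40,64) [64,80) [80,96)
    // every boundary is a multiple of 8; a thread whose rounded start reaches
    // its rounded end simply returns early.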
 
- enum lm_ggml_type lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur) {
+ void forward_mul_mat_id(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
+ const lm_ggml_tensor * src0 = op->src[0];
+ const lm_ggml_tensor * src1 = op->src[1];
+ const lm_ggml_tensor * ids = op->src[2];
+ lm_ggml_tensor * dst = op;
+
+ LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
+
+ // we don't support permuted src0 or src1
+ LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type));
+ LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type));
+
+ // dst cannot be transposed or permuted
+ LM_GGML_ASSERT(nb0 == sizeof(float));
+ LM_GGML_ASSERT(nb0 <= nb1);
+ LM_GGML_ASSERT(nb1 <= nb2);
+ LM_GGML_ASSERT(nb2 <= nb3);
+
+ LM_GGML_ASSERT(ne03 == 1);
+ LM_GGML_ASSERT(ne13 == 1);
+ LM_GGML_ASSERT(ne3 == 1);
+
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+
+ // row groups
+ const int n_ids = ids->ne[0]; // n_expert_used
+ const int n_as = ne02; // n_expert
+
+ const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+ const size_t nbw2 = nbw1*ne11;
+ const size_t nbw3 = nbw2*ne12;
+
+ struct mmid_row_mapping {
+ int32_t i1;
+ int32_t i2;
+ };
+
+ LM_GGML_ASSERT(params->wsize >= (LM_GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
+ n_as * ne12 * sizeof(mmid_row_mapping)));
+
+ auto wdata = (char *) params->wdata;
+ auto wdata_src1_end = (char *) wdata + LM_GGML_PAD(nbw3, sizeof(int64_t));
+ int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
+
+ // src1: float32 => block_q8_0
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+ from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
+ (void *) (wdata + i12 * nbw2 + i11 * nbw1),
+ ne10);
+ }
+ }
+
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
+
+ if (ith == 0) {
+ // initialize matrix_row_counts
+ memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
+
+ // group rows by src0 matrix
+ for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+ for (int32_t id = 0; id < n_ids; ++id) {
+ const int32_t i02 =
+ *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+
+ LM_GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
+ matrix_row_counts[i02] += 1;
+ }
+ }
+ }
+
+ lm_ggml_barrier(params->threadpool);
+
+ // compute each matrix multiplication in sequence
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+ const int64_t cne1 = matrix_row_counts[cur_a];
+
+ if (cne1 == 0) {
+ continue;
+ }
+
+ auto src0_cur = (const char *) src0->data + cur_a*nb02;
+
+ //const int64_t nr0 = ne01; // src0 rows
+ const int64_t nr1 = cne1; // src1 rows
+
+ int64_t src0_cur_start = (ith * ne01) / nth;
+ int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
+ src0_cur_start =
+ (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+ src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+
+ if (src0_cur_start >= src0_cur_end) return;
+
+ for (int ir1 = 0; ir1 < nr1; ir1++) {
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
+ const int id = row_mapping.i1; // selected expert index
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = row_mapping.i2; // row index in src1
+
+ const int64_t i1 = id; // selected expert index
+ const int64_t i2 = i12; // row
+
+ auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
+ ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
+ ne01, src0_cur + src0_cur_start * nb01,
+ src1_col, 1, src0_cur_end - src0_cur_start);
+ }
+ }
+ #undef MMID_MATRIX_ROW
+ }
+
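Note: forward_mul_mat_id above first buckets every (token, expert-slot) pair by the expert it routes to, then runs one gemv pass per expert over that expert's repacked weights. A small worked illustration of the grouping, assuming two experts are used per token (n_ids = 2; the concrete ids values are hypothetical):

    // ids: token 0 selects experts {1, 3}, token 1 selects experts {3, 0}
    // after the ith == 0 grouping loop:
    //   matrix_row_counts = { [0] = 1, [1] = 1, [2] = 0, [3] = 2 }
    //   matrix_rows for expert 3 = { { .i1 = 1, .i2 = 0 },   // slot 1 of token 0
    //                                { .i1 = 0, .i2 = 1 } }  // slot 0 of token 1
    // experts with a zero count (expert 2) are skipped by the cne1 == 0 check.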
+ int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) override {
+ LM_GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, lm_ggml_type_name(t->type),
+ (int) NB_COLS, (int) INTER_SIZE);
+ return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+ }
+ };
+
+ // instance for Q4
+ static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
+ static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
+ static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
+
+ // instance for IQ4
+ static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
+
+ } // namespace ggml::cpu::aarch64
+
+ static const ggml::cpu::tensor_traits * lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur) {
  if (cur->type == LM_GGML_TYPE_Q4_0) {
- // TODO: enable for AVX2 - currently disabled due to bad gemv performance
- if (/* lm_ggml_cpu_has_avx2() || */ (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) {
- return LM_GGML_TYPE_Q4_0_8_8;
+ if (lm_ggml_cpu_has_avx2() || (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) {
+ if (cur->ne[1] % 8 == 0) {
+ return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
+ }
  }
  if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) {
- return LM_GGML_TYPE_Q4_0_4_8;
+ if (cur->ne[1] % 4 == 0) {
+ return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
+ }
  }
  if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
- return LM_GGML_TYPE_Q4_0_4_4;
+ if (cur->ne[1] % 4 == 0) {
+ return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
+ }
  }
  } else if (cur->type == LM_GGML_TYPE_IQ4_NL) {
  if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
- return LM_GGML_TYPE_IQ4_NL_4_4;
+ if (cur->ne[1] % 4 == 0) {
+ return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
+ }
+ }
+ }
+
+ return nullptr;
+ }
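Note: the selection above is gated on both CPU features and the tensor's row count: Q4_0 goes to the 8x8 layout on AVX2 or 256-bit SVE with int8 matmul, to 4x8 on NEON with i8mm, and to 4x4 on NEON with dotprod; IQ4_NL only has a 4x4 path. A minimal sketch that reports which Q4_0 path the running host would pick (the function name is illustrative; the feature probes are the ones already used above):

    static const char * q4_0_repack_variant(void) {
        if (lm_ggml_cpu_has_avx2() ||
            (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) {
            return "q4_0_8x8_q8_0";   // also needs ne[1] % 8 == 0
        }
        if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) {
            return "q4_0_4x8_q8_0";   // also needs ne[1] % 4 == 0
        }
        if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
            return "q4_0_4x4_q8_0";   // also needs ne[1] % 4 == 0
        }
        return "no repack";
    }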
+
+ static void lm_ggml_backend_cpu_aarch64_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) {
+ tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(lm_ggml_aarch64_get_optimal_repack_type(tensor));
+
+ LM_GGML_UNUSED(buffer);
+ }
+
+ static void lm_ggml_backend_cpu_aarch64_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
+ const void * data, size_t offset, size_t size) {
+ LM_GGML_ASSERT(offset == 0);
+ LM_GGML_ASSERT(size == lm_ggml_nbytes(tensor));
+
+ auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
+ auto OK = tensor_traits->repack(tensor, data, size);
+
+ LM_GGML_ASSERT(OK == 0);
+ LM_GGML_UNUSED(buffer);
+ }
+
+ static const char * lm_ggml_backend_cpu_aarch64_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
+ return "CPU_AARCH64";
+
+ LM_GGML_UNUSED(buft);
+ }
+
+ static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
+ lm_ggml_backend_buffer_t buffer = lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_cpu_buffer_type(), size);
+
+ if (buffer == nullptr) {
+ return nullptr;
+ }
+
+ buffer->buft = buft;
+ buffer->iface.init_tensor = lm_ggml_backend_cpu_aarch64_buffer_init_tensor;
+ buffer->iface.set_tensor = lm_ggml_backend_cpu_aarch64_buffer_set_tensor;
+ return buffer;
+ }
+
+ static size_t lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
+ return TENSOR_ALIGNMENT;
+
+ LM_GGML_UNUSED(buft);
+ }
+
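Note: init_tensor and set_tensor above mean that merely uploading weights into a CPU_AARCH64 buffer performs the repack. A minimal usage sketch, assuming the usual ggml backend-buffer helpers (lm_ggml_backend_alloc_ctx_tensors_from_buft and lm_ggml_backend_tensor_set are not part of this diff; their availability here is an assumption):

    // Allocate the context's tensors in the repacking buffer type...
    lm_ggml_backend_buffer_t buf =
        lm_ggml_backend_alloc_ctx_tensors_from_buft(ctx, lm_ggml_backend_cpu_aarch64_buffer_type());
    // ...then a plain upload routes through set_tensor, which calls tensor_traits::repack().
    lm_ggml_backend_tensor_set(weight, weight_data, 0, lm_ggml_nbytes(weight));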
+ namespace ggml::cpu::aarch64 {
+ class extra_buffer_type : ggml::cpu::extra_buffer_type {
+ bool supports_op(lm_ggml_backend_dev_t, const struct lm_ggml_tensor * op) override {
+ if ( op->op == LM_GGML_OP_MUL_MAT &&
+ op->src[0]->buffer &&
+ (lm_ggml_n_dims(op->src[0]) == 2) &&
+ op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type() &&
+ lm_ggml_aarch64_get_optimal_repack_type(op->src[0])
+ ) {
+ if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+ return false;
+ }
+ if (op->src[1]->type == LM_GGML_TYPE_F32) {
+ return true;
+ }
+ //if (op->src[1]->type == LM_GGML_TYPE_Q8_0) {
+ // return true;
+ //}
+ // may be possible if Q8_0 packed...
+ } else if (op->op == LM_GGML_OP_MUL_MAT_ID
+ && op->src[0]->buffer
+ && (lm_ggml_n_dims(op->src[0]) == 3)
+ && op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type()
+ && lm_ggml_aarch64_get_optimal_repack_type(op->src[0])
+ ) {
+ if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+ return false;
+ }
+ if (op->src[1]->type == LM_GGML_TYPE_F32) {
+ return true;
+ }
+ //if (op->src[1]->type == LM_GGML_TYPE_Q8_0) {
+ // return true;
+ //}
  }
+ return false;
  }

- return cur->type;
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) override {
+ if (op->op == LM_GGML_OP_MUL_MAT || op->op == LM_GGML_OP_MUL_MAT_ID) {
+ if (op->src[0]->buffer && op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type()) {
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+ }
+ }
+ return nullptr;
+ }
+ };
+ } // namespace ggml::cpu::aarch64
+
+ lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void) {
+ static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_aarch64 = {
+ /* .iface = */ {
+ /* .get_name = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_name,
+ /* .alloc_buffer = */ lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
+ /* .get_alignment = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment,
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ nullptr, // defaults to lm_ggml_nbytes
+ /* .is_host = */ nullptr,
+ },
+ /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0),
+ /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
+ };
+
+ return &lm_ggml_backend_cpu_buffer_type_aarch64;
  }
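Note: the buffer type above is a function-local static, so every caller shares one instance (and one extra_buffer_type context). A quick sanity sketch, assuming the standard lm_ggml_backend_buft_name accessor (not shown in this diff):

    lm_ggml_backend_buffer_type_t buft = lm_ggml_backend_cpu_aarch64_buffer_type();
    assert(buft == lm_ggml_backend_cpu_aarch64_buffer_type());             // same singleton on every call
    assert(strcmp(lm_ggml_backend_buft_name(buft), "CPU_AARCH64") == 0);   // name from get_name above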