cui-llama.rn 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. package/android/src/main/CMakeLists.txt +5 -7
  2. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -4
  3. package/android/src/main/jni.cpp +9 -9
  4. package/cpp/common.cpp +28 -44
  5. package/cpp/common.h +35 -14
  6. package/cpp/ggml-alloc.c +0 -1
  7. package/cpp/ggml-backend-impl.h +38 -20
  8. package/cpp/ggml-backend-reg.cpp +246 -92
  9. package/cpp/ggml-backend.h +1 -0
  10. package/cpp/ggml-common.h +42 -48
  11. package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +642 -223
  12. package/cpp/ggml-cpu-aarch64.h +2 -26
  13. package/cpp/ggml-cpu-traits.cpp +36 -0
  14. package/cpp/ggml-cpu-traits.h +38 -0
  15. package/cpp/ggml-cpu.c +14122 -13971
  16. package/cpp/ggml-cpu.cpp +627 -715
  17. package/cpp/ggml-cpu.h +0 -17
  18. package/cpp/ggml-impl.h +22 -6
  19. package/cpp/ggml-metal.m +482 -24
  20. package/cpp/ggml-quants.c +0 -9
  21. package/cpp/ggml-threading.h +4 -2
  22. package/cpp/ggml.c +284 -178
  23. package/cpp/ggml.h +73 -25
  24. package/cpp/llama-grammar.cpp +15 -15
  25. package/cpp/llama-grammar.h +2 -5
  26. package/cpp/llama-sampling.cpp +35 -90
  27. package/cpp/llama-vocab.cpp +7 -2
  28. package/cpp/llama-vocab.h +1 -1
  29. package/cpp/llama.cpp +1782 -586
  30. package/cpp/llama.h +20 -19
  31. package/cpp/sampling.cpp +11 -16
  32. package/cpp/sgemm.cpp +265 -258
  33. package/cpp/sgemm.h +2 -2
  34. package/cpp/speculative.cpp +4 -0
  35. package/cpp/unicode.cpp +51 -51
  36. package/cpp/unicode.h +9 -10
  37. package/lib/commonjs/index.js +38 -1
  38. package/lib/commonjs/index.js.map +1 -1
  39. package/lib/module/index.js +36 -0
  40. package/lib/module/index.js.map +1 -1
  41. package/lib/typescript/NativeRNLlama.d.ts +2 -3
  42. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  43. package/lib/typescript/index.d.ts +36 -2
  44. package/lib/typescript/index.d.ts.map +1 -1
  45. package/package.json +1 -1
  46. package/src/NativeRNLlama.ts +3 -3
  47. package/src/index.ts +46 -2
  48. package/cpp/amx/amx.cpp +0 -196
  49. package/cpp/amx/amx.h +0 -20
  50. package/cpp/amx/common.h +0 -101
  51. package/cpp/amx/mmq.cpp +0 -2524
  52. package/cpp/amx/mmq.h +0 -16
  53. package/cpp/ggml-aarch64.c +0 -129
  54. package/cpp/ggml-aarch64.h +0 -19
@@ -1,20 +1,57 @@
1
- #define LM_GGML_COMMON_IMPL_C
1
+ #define LM_GGML_COMMON_IMPL_CPP
2
+ #define LM_GGML_COMMON_DECL_CPP
2
3
  #include "ggml-common.h"
4
+ #include "ggml-backend-impl.h"
3
5
 
4
6
  #include "ggml-quants.h"
5
7
  #include "ggml-impl.h"
6
8
  #include "ggml-cpu.h"
7
9
  #include "ggml-cpu-impl.h"
10
+ #include "ggml-cpu-traits.h"
8
11
 
9
- #include <math.h>
10
- #include <string.h>
11
- #include <assert.h>
12
- #include <float.h>
13
- #include <stdlib.h> // for qsort
14
- #include <stdio.h> // for LM_GGML_ASSERT
12
+ #include <cmath>
13
+ #include <cstring>
14
+ #include <cassert>
15
+ #include <cfloat>
16
+ #include <cstdlib> // for qsort
17
+ #include <cstdio> // for LM_GGML_ASSERT
15
18
 
16
19
  #include "ggml-cpu-aarch64.h"
17
20
 
21
+ // TODO: move to include file?
22
+ template <int K> constexpr int QK_0() {
23
+ if constexpr (K == 4) {
24
+ return QK4_0;
25
+ }
26
+ if constexpr (K == 8) {
27
+ return QK8_0;
28
+ }
29
+ return -1;
30
+ }
31
+
32
+ template <int K, int N> struct block {
33
+ lm_ggml_half d[N]; // deltas for N qK_0 blocks
34
+ int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
35
+ };
36
+
37
+ // control size
38
+ static_assert(sizeof(block<4, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
39
+ static_assert(sizeof(block<4, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
40
+ static_assert(sizeof(block<8, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
41
+ static_assert(sizeof(block<8, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
42
+
43
+ using block_q4_0x4 = block<4, 4>;
44
+ using block_q4_0x8 = block<4, 8>;
45
+ using block_q8_0x4 = block<8, 4>;
46
+ using block_q8_0x8 = block<8, 8>;
47
+
48
+ struct block_iq4_nlx4 {
49
+ lm_ggml_half d[4]; // deltas for 4 iq4_nl blocks
50
+ uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
51
+ };
52
+
53
+ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(lm_ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
54
+
18
55
  #if defined(__GNUC__)
19
56
  #pragma GCC diagnostic ignored "-Woverlength-strings"
20
57
  #elif defined(_MSC_VER)
@@ -128,7 +165,7 @@ static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
128
165
  }
129
166
 
130
167
  static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
131
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
168
+ #if defined(__AVX512VNNI__)
132
169
  const __m512i zero = _mm512_setzero_si512();
133
170
  return _mm512_dpbusd_epi32(zero, ax, sy);
134
171
  #else
@@ -185,12 +222,12 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
185
222
 
186
223
  static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
187
224
 
188
- static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
225
+ static void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
189
226
  assert(QK8_0 == 32);
190
227
  assert(k % QK8_0 == 0);
191
228
  const int nb = k / QK8_0;
192
229
 
193
- block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
230
+ block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
194
231
 
195
232
  #if defined(__ARM_NEON)
196
233
  float32x4_t srcv[4][8];
@@ -279,12 +316,12 @@ static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int6
279
316
  #endif
280
317
  }
281
318
 
282
- static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
319
+ static void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
283
320
  assert(QK8_0 == 32);
284
321
  assert(k % QK8_0 == 0);
285
322
  const int nb = k / QK8_0;
286
323
 
287
- block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
324
+ block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
288
325
 
289
326
  #if defined(__ARM_NEON)
290
327
  float32x4_t srcv[4][8];
@@ -494,7 +531,7 @@ static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int6
494
531
  #endif
495
532
  }
496
533
 
497
- void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
534
+ static void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
498
535
  assert(nrow == 4);
499
536
  UNUSED(nrow);
500
537
  if (blck_size_interleave == 4) {
@@ -506,7 +543,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
506
543
  }
507
544
  }
508
545
 
509
- void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
546
+ static void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
510
547
  const int qk = QK8_0;
511
548
  const int nb = n / qk;
512
549
  const int ncols_interleaved = 4;
@@ -527,21 +564,21 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
527
564
 
528
565
  #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
529
566
  if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
530
- const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
567
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
531
568
 
532
569
  for (int c = 0; c < nc; c += ncols_interleaved) {
533
- const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
570
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
534
571
  float32x4_t acc = vdupq_n_f32(0);
535
572
  for (int b = 0; b < nb; b++) {
536
- int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
537
- int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
538
- int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
539
- int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
540
- float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
573
+ int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
574
+ int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
575
+ int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
576
+ int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
577
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
541
578
 
542
579
  int8x16_t a0 = vld1q_s8(a_ptr->qs);
543
580
  int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
544
- float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
581
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
545
582
 
546
583
  int32x4_t ret = vdupq_n_s32(0);
547
584
 
@@ -591,7 +628,7 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
591
628
  }
592
629
  }
593
630
 
594
- void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
631
+ static void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
595
632
  const int qk = QK8_0;
596
633
  const int nb = n / qk;
597
634
  const int ncols_interleaved = 4;
@@ -610,72 +647,52 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
610
647
  UNUSED(ncols_interleaved);
611
648
  UNUSED(blocklen);
612
649
 
613
- #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
614
- if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) {
615
- const void * b_ptr = vx;
616
- const void * a_ptr = vy;
617
- float * res_ptr = s;
650
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
651
+ if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
652
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
618
653
 
619
- __asm__ __volatile__(
620
- "movi v2.16b, #0x4\n"
621
- "movi v1.16b, #0xf0\n"
622
- "add %x[b_ptr], %x[b_ptr], #0x8\n"
623
- "1:" // Column loop
624
- "add x23, %x[a_ptr], #0x2\n"
625
- "movi v0.16b, #0x0\n"
626
- "mov x22, %x[nb]\n"
627
- "2:" // Block loop
628
- "ldr q31, [%x[b_ptr], #0x0]\n"
629
- "ldr q30, [%x[b_ptr], #0x10]\n"
630
- "mov x21, x23\n"
631
- "movi v29.4s, #0x0\n"
632
- "ldr q28, [%x[b_ptr], #0x20]\n"
633
- "ldr q27, [%x[b_ptr], #0x30]\n"
634
- "movi v26.4s, #0x0\n"
635
- "sub x20, x23, #0x2\n"
636
- "ld1r { v25.8h }, [x20]\n"
637
- "ldr q24, [%x[b_ptr], #-0x8]\n"
638
- "sub x22, x22, #0x1\n"
639
- "add x23, x23, #0x22\n"
640
- "ld1r { v23.2d }, [x21], #0x8\n"
641
- "sshl v22.16b, v31.16b, v2.16b\n"
642
- "sshl v16.16b, v30.16b, v2.16b\n"
643
- "add %x[b_ptr], %x[b_ptr], #0x48\n"
644
- "ld1r { v21.2d }, [x21], #0x8\n"
645
- "sshl v20.16b, v28.16b, v2.16b\n"
646
- "sshl v19.16b, v27.16b, v2.16b\n"
647
- "ld1r { v18.2d }, [x21], #0x8\n"
648
- "ld1r { v17.2d }, [x21], #0x8\n"
649
- "and v31.16b, v31.16b, v1.16b\n"
650
- "and v30.16b, v30.16b, v1.16b\n"
651
- ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n"
652
- ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n"
653
- "and v28.16b, v28.16b, v1.16b\n"
654
- "and v27.16b, v27.16b, v1.16b\n"
655
- "fcvtl v25.4s, v25.4h\n"
656
- "fcvtl v16.4s, v24.4h\n"
657
- ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n"
658
- ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
659
- "fmul v16.4s, v16.4s, v25.4s\n"
660
- ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
661
- ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
662
- ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
663
- ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
664
- "addp v29.4s, v29.4s, v26.4s\n"
665
- "scvtf v29.4s, v29.4s, #0x4\n"
666
- "fmla v0.4s, v29.4s, v16.4s\n"
667
- "cbnz x22, 2b\n"
668
- "sub %x[nc], %x[nc], #0x4\n"
669
- "str q0, [%x[res_ptr], #0x0]\n"
670
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
671
- "cbnz %x[nc], 1b\n"
672
- : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
673
- : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
674
- : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
675
- );
654
+ for (int c = 0; c < nc; c += ncols_interleaved) {
655
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
656
+ float32x4_t acc = vdupq_n_f32(0);
657
+ for (int b = 0; b < nb; b++) {
658
+ int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
659
+ int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
660
+ int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
661
+ int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
662
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
663
+
664
+ int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
665
+ int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
666
+ int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
667
+ int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
668
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
669
+
670
+ int32x4_t ret0 = vdupq_n_s32(0);
671
+ int32x4_t ret1 = vdupq_n_s32(0);
672
+
673
+ ret0 = vdotq_s32(ret0, b0 << 4, a0);
674
+ ret1 = vdotq_s32(ret1, b1 << 4, a0);
675
+ ret0 = vdotq_s32(ret0, b2 << 4, a1);
676
+ ret1 = vdotq_s32(ret1, b3 << 4, a1);
677
+
678
+ ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
679
+ ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
680
+ ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
681
+ ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
682
+
683
+ int32x4_t ret = vpaddq_s32(ret0, ret1);
684
+
685
+ acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
686
+ vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
687
+ a_ptr++;
688
+ b_ptr++;
689
+ }
690
+ vst1q_f32(s, acc);
691
+ s += ncols_interleaved;
692
+ }
676
693
  return;
677
694
  }
678
- #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
695
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
679
696
  float sumf[4];
680
697
  int sumi;
681
698
 
@@ -701,7 +718,7 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
701
718
  }
702
719
  }
703
720
 
704
- void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
721
+ static void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
705
722
  const int qk = QK8_0;
706
723
  const int nb = n / qk;
707
724
  const int ncols_interleaved = 8;
@@ -974,7 +991,7 @@ void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
974
991
  }
975
992
  }
976
993
 
977
- void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
994
+ static void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
978
995
  const int qk = QK8_0;
979
996
  const int nb = n / qk;
980
997
  const int ncols_interleaved = 4;
@@ -1070,7 +1087,7 @@ void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const vo
1070
1087
  }
1071
1088
  }
1072
1089
 
1073
- void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
1090
+ static void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
1074
1091
  const int qk = QK8_0;
1075
1092
  const int nb = n / qk;
1076
1093
  const int ncols_interleaved = 4;
@@ -1586,7 +1603,7 @@ void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
1586
1603
  }
1587
1604
  }
1588
1605
 
1589
- void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
1606
+ static void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
1590
1607
  const int qk = QK8_0;
1591
1608
  const int nb = n / qk;
1592
1609
  const int ncols_interleaved = 4;
@@ -2040,7 +2057,7 @@ void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
2040
2057
  }
2041
2058
  }
2042
2059
 
2043
- void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
2060
+ static void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
2044
2061
  const int qk = QK8_0;
2045
2062
  const int nb = n / qk;
2046
2063
  const int ncols_interleaved = 8;
@@ -2560,31 +2577,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2560
2577
  const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
2561
2578
 
2562
2579
  // Shuffle pattern one - right side input
2563
- const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2564
- const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2580
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2581
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2565
2582
 
2566
- const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2567
- const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2583
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2584
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2568
2585
 
2569
- const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2570
- const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2586
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2587
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2571
2588
 
2572
- const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2573
- const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2589
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2590
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2574
2591
 
2575
2592
  // Shuffle pattern two - right side input
2576
2593
 
2577
- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2578
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2594
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2595
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2579
2596
 
2580
- const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2581
- const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2597
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2598
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2582
2599
 
2583
- const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2584
- const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2600
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2601
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2585
2602
 
2586
- const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2587
- const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2603
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2604
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2588
2605
 
2589
2606
  // Scale values - Load the weight scale values of two block_q4_0x8
2590
2607
  const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
@@ -2618,31 +2635,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2618
2635
 
2619
2636
  // Shuffle pattern one - left side input
2620
2637
 
2621
- const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2622
- const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2638
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2639
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2623
2640
 
2624
- const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2625
- const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2641
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2642
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2626
2643
 
2627
- const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2628
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2644
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2645
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2629
2646
 
2630
- const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2631
- const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2647
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2648
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2632
2649
 
2633
2650
  // Shuffle pattern two - left side input
2634
2651
 
2635
- const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2636
- const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2652
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2653
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2637
2654
 
2638
- const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2639
- const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2655
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2656
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2640
2657
 
2641
- const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2642
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2658
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2659
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2643
2660
 
2644
- const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2645
- const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2661
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2662
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2646
2663
 
2647
2664
  // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
2648
2665
  // Resembles MMLAs into 2x2 matrices in ARM Version
@@ -2671,10 +2688,10 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2671
2688
 
2672
2689
 
2673
2690
  // Straighten out to make 4 row vectors
2674
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
2675
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
2676
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
2677
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
2691
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
2692
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
2693
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
2694
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
2678
2695
 
2679
2696
  // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
2680
2697
  const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
@@ -2753,31 +2770,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2753
2770
  const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
2754
2771
 
2755
2772
  // Shuffle pattern one - right side input
2756
- const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2757
- const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2773
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
2774
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
2758
2775
 
2759
- const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2760
- const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2776
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
2777
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
2761
2778
 
2762
- const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2763
- const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2779
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
2780
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
2764
2781
 
2765
- const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2766
- const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2782
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
2783
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
2767
2784
 
2768
2785
  // Shuffle pattern two - right side input
2769
2786
 
2770
- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2771
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2787
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
2788
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
2772
2789
 
2773
- const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2774
- const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2790
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
2791
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
2775
2792
 
2776
- const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2777
- const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2793
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
2794
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
2778
2795
 
2779
- const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2780
- const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2796
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
2797
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
2781
2798
 
2782
2799
 
2783
2800
  // Scale values - Load the weight scale values of two block_q4_0x8
@@ -2809,31 +2826,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2809
2826
 
2810
2827
  // Shuffle pattern one - left side input
2811
2828
 
2812
- const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2813
- const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2829
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
2830
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
2814
2831
 
2815
- const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2816
- const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2832
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
2833
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
2817
2834
 
2818
- const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2819
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2835
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
2836
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
2820
2837
 
2821
- const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2822
- const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2838
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
2839
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
2823
2840
 
2824
2841
  // Shuffle pattern two - left side input
2825
2842
 
2826
- const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2827
- const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2843
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
2844
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
2828
2845
 
2829
- const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2830
- const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2846
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
2847
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
2831
2848
 
2832
- const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2833
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2849
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
2850
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
2834
2851
 
2835
- const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2836
- const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2852
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
2853
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
2837
2854
 
2838
2855
  // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
2839
2856
  // Resembles MMLAs into 2x2 matrices in ARM Version
@@ -2862,10 +2879,10 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
2862
2879
 
2863
2880
 
2864
2881
  // Straighten out to make 4 row vectors
2865
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
2866
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
2867
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
2868
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
2882
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
2883
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
2884
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
2885
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
2869
2886
 
2870
2887
  // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
2871
2888
  const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
@@ -3460,7 +3477,7 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
3460
3477
  }
3461
3478
  }
3462
3479
 
3463
- void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
3480
+ static void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
3464
3481
  const int qk = QK8_0;
3465
3482
  const int nb = n / qk;
3466
3483
  const int ncols_interleaved = 4;
@@ -3571,7 +3588,6 @@ void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const vo
3571
3588
  }
3572
3589
  }
3573
3590
 
3574
- // FIXME: this code is duplicated from ggml-aarch64.c
3575
3591
  static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
3576
3592
  block_q4_0x4 out;
3577
3593
 
@@ -3641,20 +3657,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
3641
3657
  return out;
3642
3658
  }
3643
3659
 
3644
- static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
3660
+ static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
3645
3661
  LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
3646
3662
  LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
3663
+ constexpr int nrows_interleaved = 4;
3647
3664
 
3648
3665
  block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
3649
3666
  const block_q4_0 * src = (const block_q4_0 *)data;
3650
3667
  block_q4_0 dst_tmp[4];
3651
- int nrow = t->ne[1]; // Number of rows
3652
- int nrows_interleaved = 4;
3668
+ int nrow = lm_ggml_nrows(t);
3653
3669
  int nblocks = t->ne[0] / QK4_0;
3654
3670
 
3655
3671
  LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
3656
3672
 
3657
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3673
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3658
3674
  return -1;
3659
3675
  }
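The change above swaps the row count used for the divisibility check (t->ne[1], rows per matrix) for the one used to size the data (lm_ggml_nrows(t), rows across all batch dimensions), so the size assertion and the interleave check no longer have to use the same quantity. A simplified sketch of the same pair of checks, with illustrative names and assuming the block/row layout described in this file:

#include <cstddef>
#include <cstdint>

// Illustrative validation for an N-row interleaved repack (not the package's code).
// rows_per_matrix ~ t->ne[1], total_rows ~ lm_ggml_nrows(t), cols ~ t->ne[0].
static int check_repack_shape(int64_t rows_per_matrix, int64_t total_rows,
                              int64_t cols, int64_t nrows_interleaved,
                              int64_t block_size /* e.g. QK4_0 == 32 */,
                              size_t src_block_bytes, size_t data_size) {
    const int64_t nblocks = cols / block_size;
    // every interleaved row group must be complete, and columns must fill whole blocks
    if (rows_per_matrix % nrows_interleaved != 0 || cols % 8 != 0) {
        return -1;
    }
    // the source buffer must hold exactly total_rows * nblocks source blocks
    if (data_size != (size_t) total_rows * (size_t) nblocks * src_block_bytes) {
        return -1;
    }
    return 0;
}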
3660
3676
 
@@ -3672,20 +3688,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_bl
3672
3688
  LM_GGML_UNUSED(data_size);
3673
3689
  }
3674
3690
 
3675
- static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) {
3691
+ static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
3676
3692
  LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
3677
3693
  LM_GGML_ASSERT(interleave_block == 8);
3694
+ constexpr int nrows_interleaved = 8;
3678
3695
 
3679
3696
  block_q4_0x8 * dst = (block_q4_0x8*)t->data;
3680
3697
  const block_q4_0 * src = (const block_q4_0*) data;
3681
3698
  block_q4_0 dst_tmp[8];
3682
- int nrow = t->ne[1]; // Number of rows
3683
- int nrows_interleaved = 8;
3699
+ int nrow = lm_ggml_nrows(t);
3684
3700
  int nblocks = t->ne[0] / QK4_0;
3685
3701
 
3686
3702
  LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
3687
3703
 
3688
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3704
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3689
3705
  return -1;
3690
3706
  }
3691
3707
 
@@ -3712,16 +3728,18 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
3712
3728
 
3713
3729
  const int end = QK4_NL * 2 / blck_size_interleave;
3714
3730
 
3715
- if (blck_size_interleave == 8) {
3716
- for (int i = 0; i < end; ++i) {
3717
- int src_id = i % 4;
3718
- int src_offset = (i / 4) * blck_size_interleave;
3719
- int dst_offset = i * blck_size_interleave;
3720
-
3721
- // Using memcpy to avoid unaligned memory accesses
3722
- memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
3723
- }
3724
- } else if (blck_size_interleave == 4) {
3731
+ // TODO: this branch seems wrong
3732
+ //if (blck_size_interleave == 8) {
3733
+ // for (int i = 0; i < end; ++i) {
3734
+ // int src_id = i % 4;
3735
+ // int src_offset = (i / 4) * blck_size_interleave;
3736
+ // int dst_offset = i * blck_size_interleave;
3737
+
3738
+ // // Using memcpy to avoid unaligned memory accesses
3739
+ // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
3740
+ // }
3741
+ //} else
3742
+ if (blck_size_interleave == 4) {
3725
3743
  for (int i = 0; i < end; ++i) {
3726
3744
  int src_id = i % 4;
3727
3745
  int src_offset = (i / 4) * blck_size_interleave;
@@ -3736,20 +3754,21 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
3736
3754
  return out;
3737
3755
  }
3738
3756
 
3739
- static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
3757
+ static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
3740
3758
  LM_GGML_ASSERT(t->type == LM_GGML_TYPE_IQ4_NL);
3741
- LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
3759
+ //LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
3760
+ LM_GGML_ASSERT(interleave_block == 4);
3742
3761
 
3743
3762
  block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
3744
3763
  const block_iq4_nl * src = (const block_iq4_nl *)data;
3745
3764
  block_iq4_nl dst_tmp[4];
3746
- int nrow = t->ne[1]; // Number of rows
3765
+ int nrow = lm_ggml_nrows(t);
3747
3766
  int nrows_interleaved = 4;
3748
3767
  int nblocks = t->ne[0] / QK4_0;
3749
3768
 
3750
3769
  LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
3751
3770
 
3752
- if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3771
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
3753
3772
  return -1;
3754
3773
  }
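All of the repack_* routines above share the same loop shape: walk the rows in groups of nrows_interleaved, gather one source block from each row of the group, and emit a single interleaved destination block. A stripped-down sketch of that structure with placeholder block types; the real make_block_* helpers also shuffle the quantized nibbles within the group, which the placeholder below omits:

#include <cstdint>

struct src_block { uint8_t qs[18]; };        // placeholder for e.g. block_q4_0
struct dst_block_x4 { src_block rows[4]; };  // placeholder interleaved block

// Hypothetical helper standing in for make_block_q4_0x4 / make_block_iq4_nlx4.
static dst_block_x4 interleave4(const src_block in[4]) {
    dst_block_x4 out;
    for (int r = 0; r < 4; ++r) out.rows[r] = in[r];
    return out;
}

// Sketch of the repack loop: group rows by 4, interleave block-by-block.
static void repack_4rows(dst_block_x4 * dst, const src_block * src,
                         int64_t nrow, int64_t nblocks) {
    for (int64_t b = 0; b < nrow; b += 4) {
        for (int64_t x = 0; x < nblocks; ++x) {
            src_block tmp[4];
            for (int r = 0; r < 4; ++r) {
                tmp[r] = src[(b + r) * nblocks + x];  // one block from each of 4 rows
            }
            *dst++ = interleave4(tmp);
        }
    }
}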
3755
3774
 
@@ -3767,57 +3786,457 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleav
3767
3786
  LM_GGML_UNUSED(data_size);
3768
3787
  }
3769
3788
 
3770
- // Prepare for optimized kernels if applicable
3771
- void lm_ggml_aarch64_repack_tensor(struct lm_ggml_tensor * cur, enum lm_ggml_type repack_type, const void * restrict data, size_t data_size) {
3772
- if (cur->type == repack_type) {
3773
- memcpy(cur->data, data, data_size);
3774
- return;
3789
+ namespace ggml::cpu::aarch64 {
3790
+ // repack
3791
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
3792
+ int repack(struct lm_ggml_tensor *, const void *, size_t);
3793
+
3794
+ // TODO: generalise.
3795
+ template <> int repack<block_q4_0, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
3796
+ return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
3797
+ }
3798
+
3799
+ template <> int repack<block_q4_0, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
3800
+ return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
3801
+ }
3802
+
3803
+ template <> int repack<block_q4_0, 8, 8>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
3804
+ return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
3805
+ }
3806
+
3807
+ template <> int repack<block_iq4_nl, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
3808
+ return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
3809
+ }
3810
+
3811
+ // TODO: needs to be revisited
3812
+ //template <> int repack<block_iq4_nl, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
3813
+ // return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
3814
+ //}
3815
+
3816
+ // gemv
3817
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
3818
+ void gemv(int, float *, size_t, const void *, const void *, int, int);
3819
+
3820
+ template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3821
+ lm_ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
3822
+ }
3823
+
3824
+ template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3825
+ lm_ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
3826
+ }
3827
+
3828
+ template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3829
+ lm_ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
3830
+ }
3831
+
3832
+ template <>
3833
+ void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3834
+ lm_ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
3835
+ }
3836
+
3837
+ // gemm
3838
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
3839
+ void gemm(int, float *, size_t, const void *, const void *, int, int);
3840
+
3841
+ template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3842
+ lm_ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
3843
+ }
3844
+
3845
+ template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3846
+ lm_ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
3847
+ }
3848
+
3849
+ template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3850
+ lm_ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
3851
+ }
3852
+
3853
+ template <>
3854
+ void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
3855
+ lm_ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
3856
+ }
3857
+
3858
+ class tensor_traits_base : public ggml::cpu::tensor_traits {
3859
+ public:
3860
+ virtual int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) = 0;
3861
+ };
3862
+
3863
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
3864
+
3865
+ bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
3866
+ // not really LM_GGML_TYPE_Q8_0, but the same size.
3867
+ switch (op->op) {
3868
+ case LM_GGML_OP_MUL_MAT:
3869
+ size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
3870
+ return true;
3871
+ case LM_GGML_OP_MUL_MAT_ID:
3872
+ size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
3873
+ size = LM_GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
3874
+ size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
3875
+ return true;
3876
+ default:
3877
+ // LM_GGML_ABORT("fatal error");
3878
+ break;
3879
+ }
3880
+ return false;
3775
3881
  }
3776
3882
 
3777
- if (cur->type == LM_GGML_TYPE_Q4_0) {
3778
- switch (repack_type) {
3779
- case LM_GGML_TYPE_Q4_0_8_8:
3780
- repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
3781
- break;
3782
- case LM_GGML_TYPE_Q4_0_4_8:
3783
- repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
3784
- break;
3785
- case LM_GGML_TYPE_Q4_0_4_4:
3786
- repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
3787
- break;
3788
- default:
3789
- LM_GGML_ABORT("Unsupported type");
3883
+ bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
3884
+ switch (op->op) {
3885
+ case LM_GGML_OP_MUL_MAT:
3886
+ forward_mul_mat(params, op);
3887
+ return true;
3888
+ case LM_GGML_OP_MUL_MAT_ID:
3889
+ forward_mul_mat_id(params, op);
3890
+ return true;
3891
+ default:
3892
+ // LM_GGML_ABORT("fatal error");
3893
+ break;
3790
3894
  }
3791
- } else if (cur->type == LM_GGML_TYPE_IQ4_NL) {
3792
- switch (repack_type) {
3793
- case LM_GGML_TYPE_IQ4_NL_4_4:
3794
- repack_iq4_nl_to_iq4_nl_4_bl(cur, 4, data, data_size);
3795
- break;
3796
- default:
3797
- LM_GGML_ABORT("Unsupported type");
3895
+ return false;
3896
+ }
3897
+
3898
+ void forward_mul_mat(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
3899
+ const lm_ggml_tensor * src0 = op->src[0];
3900
+ const lm_ggml_tensor * src1 = op->src[1];
3901
+ lm_ggml_tensor * dst = op;
3902
+
3903
+ LM_GGML_TENSOR_BINARY_OP_LOCALS
3904
+
3905
+ const int ith = params->ith;
3906
+ const int nth = params->nth;
3907
+
3908
+ LM_GGML_ASSERT(ne0 == ne01);
3909
+ LM_GGML_ASSERT(ne1 == ne11);
3910
+ LM_GGML_ASSERT(ne2 == ne12);
3911
+ LM_GGML_ASSERT(ne3 == ne13);
3912
+
3913
+ // dst cannot be transposed or permuted
3914
+ LM_GGML_ASSERT(nb0 == sizeof(float));
3915
+ LM_GGML_ASSERT(nb0 <= nb1);
3916
+ LM_GGML_ASSERT(nb1 <= nb2);
3917
+ LM_GGML_ASSERT(nb2 <= nb3);
3918
+
3919
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
3920
+
3921
+ LM_GGML_ASSERT(lm_ggml_n_dims(op->src[0]) == 2);
3922
+ // LM_GGML_ASSERT(lm_ggml_n_dims(op->src[1]) == 2);
3923
+
3924
+ char * wdata = static_cast<char *>(params->wdata);
3925
+ const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
3926
+
3927
+ assert(params->wsize >= nbw1 * ne11);
3928
+
3929
+ const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
3930
+
3931
+ int64_t i11_processed = 0;
3932
+ for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
3933
+ quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
3934
+ INTER_SIZE);
3935
+ }
3936
+ i11_processed = ne11 - ne11 % 4;
3937
+ for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
3938
+ from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
3939
+ }
3940
+
3941
+ lm_ggml_barrier(params->threadpool);
3942
+
3943
+ const void * src1_wdata = params->wdata;
3944
+ const size_t src1_col_stride = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
3945
+ int64_t src0_start = (ith * ne01) / nth;
3946
+ int64_t src0_end = ((ith + 1) * ne01) / nth;
3947
+ src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
3948
+ src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
3949
+ if (src0_start >= src0_end) {
3950
+ return;
3951
+ }
3952
+
3953
+ // If there are more than three rows in src1, use gemm; otherwise, use gemv.
3954
+ if (ne11 > 3) {
3955
+ gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
3956
+ (const char *) src0->data + src0_start * nb01,
3957
+ (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
3958
+ }
3959
+ for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
3960
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
3961
+ (const char *) src0->data + src0_start * nb01,
3962
+ (const char *) src1_wdata + (src1_col_stride * iter), 1,
3963
+ src0_end - src0_start);
3798
3964
  }
3799
- } else {
3800
- LM_GGML_ABORT("Unsupported type");
3801
3965
  }
3802
- }
3803
3966
 
3804
- enum lm_ggml_type lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur) {
3967
+ void forward_mul_mat_id(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
3968
+ const lm_ggml_tensor * src0 = op->src[0];
3969
+ const lm_ggml_tensor * src1 = op->src[1];
3970
+ const lm_ggml_tensor * ids = op->src[2];
3971
+ lm_ggml_tensor * dst = op;
3972
+
3973
+ LM_GGML_TENSOR_BINARY_OP_LOCALS
3974
+
3975
+ const int ith = params->ith;
3976
+ const int nth = params->nth;
3977
+
3978
+ const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
3979
+
3980
+ // we don't support permuted src0 or src1
3981
+ LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type));
3982
+ LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type));
3983
+
3984
+ // dst cannot be transposed or permuted
3985
+ LM_GGML_ASSERT(nb0 == sizeof(float));
3986
+ LM_GGML_ASSERT(nb0 <= nb1);
3987
+ LM_GGML_ASSERT(nb1 <= nb2);
3988
+ LM_GGML_ASSERT(nb2 <= nb3);
3989
+
3990
+ LM_GGML_ASSERT(ne03 == 1);
3991
+ LM_GGML_ASSERT(ne13 == 1);
3992
+ LM_GGML_ASSERT(ne3 == 1);
3993
+
3994
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
3995
+
3996
+ // row groups
3997
+ const int n_ids = ids->ne[0]; // n_expert_used
3998
+ const int n_as = ne02; // n_expert
3999
+
4000
+ const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
4001
+ const size_t nbw2 = nbw1*ne11;
4002
+ const size_t nbw3 = nbw2*ne12;
4003
+
4004
+ struct mmid_row_mapping {
4005
+ int32_t i1;
4006
+ int32_t i2;
4007
+ };
4008
+
4009
+ LM_GGML_ASSERT(params->wsize >= (LM_GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
4010
+ n_as * ne12 * sizeof(mmid_row_mapping)));
4011
+
4012
+ auto wdata = (char *) params->wdata;
4013
+ auto wdata_src1_end = (char *) wdata + LM_GGML_PAD(nbw3, sizeof(int64_t));
4014
+ int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
4015
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
4016
+
4017
+ // src1: float32 => block_q8_0
4018
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
4019
+ for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
4020
+ from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
4021
+ (void *) (wdata + i12 * nbw2 + i11 * nbw1),
4022
+ ne10);
4023
+ }
4024
+ }
4025
+
4026
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
4027
+
4028
+ if (ith == 0) {
4029
+ // initialize matrix_row_counts
4030
+ memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
4031
+
4032
+ // group rows by src0 matrix
4033
+ for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
4034
+ for (int32_t id = 0; id < n_ids; ++id) {
4035
+ const int32_t i02 =
4036
+ *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
4037
+
4038
+ LM_GGML_ASSERT(i02 >= 0 && i02 < n_as);
4039
+
4040
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
4041
+ matrix_row_counts[i02] += 1;
4042
+ }
4043
+ }
4044
+ }
4045
+
4046
+ lm_ggml_barrier(params->threadpool);
4047
+
4048
+ // compute each matrix multiplication in sequence
4049
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
4050
+ const int64_t cne1 = matrix_row_counts[cur_a];
4051
+
4052
+ if (cne1 == 0) {
4053
+ continue;
4054
+ }
4055
+
4056
+ auto src0_cur = (const char *) src0->data + cur_a*nb02;
4057
+
4058
+ //const int64_t nr0 = ne01; // src0 rows
4059
+ const int64_t nr1 = cne1; // src1 rows
4060
+
4061
+ int64_t src0_cur_start = (ith * ne01) / nth;
4062
+ int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
4063
+ src0_cur_start =
4064
+ (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
4065
+ src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
4066
+
4067
+ if (src0_cur_start >= src0_cur_end) return;
4068
+
4069
+ for (int ir1 = 0; ir1 < nr1; ir1++) {
4070
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
4071
+ const int id = row_mapping.i1; // selected expert index
4072
+
4073
+ const int64_t i11 = id % ne11;
4074
+ const int64_t i12 = row_mapping.i2; // row index in src1
4075
+
4076
+ const int64_t i1 = id; // selected expert index
4077
+ const int64_t i2 = i12; // row
4078
+
4079
+ auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
4080
+
4081
+ gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
4082
+ ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
4083
+ ne01, src0_cur + src0_cur_start * nb01,
4084
+ src1_col, 1, src0_cur_end - src0_cur_start);
4085
+ }
4086
+ }
4087
+ #undef MMID_MATRIX_ROW
4088
+ }
4089
+
4090
+ int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) override {
4091
+ LM_GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, lm_ggml_type_name(t->type),
4092
+ (int) NB_COLS, (int) INTER_SIZE);
4093
+ return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
4094
+ }
4095
+ };
4096
+
4097
+ // instance for Q4
4098
+ static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
4099
+ static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
4100
+ static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
4101
+
4102
+ // instance for IQ4
4103
+ static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
4104
+
4105
+ } // namespace ggml::cpu::aarch64
4106
+
4107
+ static const ggml::cpu::tensor_traits * lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur) {
3805
4108
  if (cur->type == LM_GGML_TYPE_Q4_0) {
3806
- // TODO: enable for AVX2 - currently disabled due to bad gemv performance
3807
- if (/* lm_ggml_cpu_has_avx2() || */ (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) {
3808
- return LM_GGML_TYPE_Q4_0_8_8;
4109
+ if (lm_ggml_cpu_has_avx2() || (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) {
4110
+ if (cur->ne[1] % 8 == 0) {
4111
+ return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
4112
+ }
3809
4113
  }
3810
4114
  if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) {
3811
- return LM_GGML_TYPE_Q4_0_4_8;
4115
+ if (cur->ne[1] % 4 == 0) {
4116
+ return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
4117
+ }
3812
4118
  }
3813
4119
  if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
3814
- return LM_GGML_TYPE_Q4_0_4_4;
4120
+ if (cur->ne[1] % 4 == 0) {
4121
+ return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
4122
+ }
3815
4123
  }
3816
4124
  } else if (cur->type == LM_GGML_TYPE_IQ4_NL) {
3817
4125
  if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
3818
- return LM_GGML_TYPE_IQ4_NL_4_4;
4126
+ if (cur->ne[1] % 4 == 0) {
4127
+ return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
4128
+ }
4129
+ }
4130
+ }
4131
+
4132
+ return nullptr;
4133
+ }
4134
+
4135
+ static void lm_ggml_backend_cpu_aarch64_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) {
4136
+ tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(lm_ggml_aarch64_get_optimal_repack_type(tensor));
4137
+
4138
+ LM_GGML_UNUSED(buffer);
4139
+ }
4140
+
4141
+ static void lm_ggml_backend_cpu_aarch64_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
4142
+ const void * data, size_t offset, size_t size) {
4143
+ LM_GGML_ASSERT(offset == 0);
4144
+ LM_GGML_ASSERT(size == lm_ggml_nbytes(tensor));
4145
+
4146
+ auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
4147
+ auto OK = tensor_traits->repack(tensor, data, size);
4148
+
4149
+ LM_GGML_ASSERT(OK == 0);
4150
+ LM_GGML_UNUSED(buffer);
4151
+ }
4152
+
4153
+ static const char * lm_ggml_backend_cpu_aarch64_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
4154
+ return "CPU_AARCH64";
4155
+
4156
+ LM_GGML_UNUSED(buft);
4157
+ }
4158
+
4159
+ static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
4160
+ lm_ggml_backend_buffer_t buffer = lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_cpu_buffer_type(), size);
4161
+
4162
+ if (buffer == nullptr) {
4163
+ return nullptr;
4164
+ }
4165
+
4166
+ buffer->buft = buft;
4167
+ buffer->iface.init_tensor = lm_ggml_backend_cpu_aarch64_buffer_init_tensor;
4168
+ buffer->iface.set_tensor = lm_ggml_backend_cpu_aarch64_buffer_set_tensor;
4169
+ return buffer;
4170
+ }
4171
+
4172
+ static size_t lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
4173
+ return TENSOR_ALIGNMENT;
4174
+
4175
+ LM_GGML_UNUSED(buft);
4176
+ }
4177
+
4178
+ namespace ggml::cpu::aarch64 {
4179
+ class extra_buffer_type : ggml::cpu::extra_buffer_type {
4180
+ bool supports_op(lm_ggml_backend_dev_t, const struct lm_ggml_tensor * op) override {
4181
+ if ( op->op == LM_GGML_OP_MUL_MAT &&
4182
+ op->src[0]->buffer &&
4183
+ (lm_ggml_n_dims(op->src[0]) == 2) &&
4184
+ op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type() &&
4185
+ lm_ggml_aarch64_get_optimal_repack_type(op->src[0])
4186
+ ) {
4187
+ if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
4188
+ return false;
4189
+ }
4190
+ if (op->src[1]->type == LM_GGML_TYPE_F32) {
4191
+ return true;
4192
+ }
4193
+ //if (op->src[1]->type == LM_GGML_TYPE_Q8_0) {
4194
+ // return true;
4195
+ //}
4196
+ // may be possible if Q8_0 packed...
4197
+ } else if (op->op == LM_GGML_OP_MUL_MAT_ID
4198
+ && op->src[0]->buffer
4199
+ && (lm_ggml_n_dims(op->src[0]) == 3)
4200
+ && op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type()
4201
+ && lm_ggml_aarch64_get_optimal_repack_type(op->src[0])
4202
+ ) {
4203
+ if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
4204
+ return false;
4205
+ }
4206
+ if (op->src[1]->type == LM_GGML_TYPE_F32) {
4207
+ return true;
4208
+ }
4209
+ //if (op->src[1]->type == LM_GGML_TYPE_Q8_0) {
4210
+ // return true;
4211
+ //}
3819
4212
  }
4213
+ return false;
3820
4214
  }
3821
4215
 
3822
- return cur->type;
4216
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) override {
4217
+ if (op->op == LM_GGML_OP_MUL_MAT || op->op == LM_GGML_OP_MUL_MAT_ID) {
4218
+ if (op->src[0]->buffer && op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type()) {
4219
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
4220
+ }
4221
+ }
4222
+ return nullptr;
4223
+ }
4224
+ };
4225
+ } // namespace ggml::cpu::aarch64
4226
+
4227
+ lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void) {
4228
+ static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_aarch64 = {
4229
+ /* .iface = */ {
4230
+ /* .get_name = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_name,
4231
+ /* .alloc_buffer = */ lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
4232
+ /* .get_alignment = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment,
4233
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
4234
+ /* .get_alloc_size = */ nullptr, // defaults to lm_ggml_nbytes
4235
+ /* .is_host = */ nullptr,
4236
+ },
4237
+ /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0),
4238
+ /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
4239
+ };
4240
+
4241
+ return &lm_ggml_backend_cpu_buffer_type_aarch64;
3823
4242
  }
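For reference, the dispatch inside forward_mul_mat above reduces to: quantize src1 to Q8_0 in the work buffer, run the tiled gemm kernel over the rows that form complete groups of four, and finish the remaining rows one at a time with gemv. A condensed sketch of that control flow, with function pointers standing in for the template specialisations and simplified strides:

#include <cstddef>
#include <cstdint>

using gemm_fn = void (*)(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc);
using gemv_fn = void (*)(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc);

// Illustrative dispatch only; pointer arithmetic is simplified relative to forward_mul_mat.
static void mul_mat_dispatch(gemm_fn gemm, gemv_fn gemv,
                             int ne00, int64_t ne01, int64_t ne11,
                             float * dst, size_t dst_row_stride,
                             const void * src0, const char * src1_q8,
                             size_t src1_col_stride) {
    const int64_t n_full = ne11 - ne11 % 4;   // rows handled by the 4-row gemm kernel
    if (ne11 > 3) {
        gemm(ne00, dst, (size_t) ne01, src0, src1_q8, (int) n_full, (int) ne01);
    }
    for (int64_t i = n_full; i < ne11; ++i) { // leftover rows: one gemv call each
        gemv(ne00, (float *) ((char *) dst + i * dst_row_stride), (size_t) ne01,
             src0, src1_q8 + i * src1_col_stride, 1, (int) ne01);
    }
}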