cui-llama.rn 1.3.3 → 1.3.5
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +5 -7
- package/android/src/main/java/com/rnllama/LlamaContext.java +4 -4
- package/android/src/main/jni.cpp +9 -9
- package/cpp/common.cpp +28 -44
- package/cpp/common.h +35 -14
- package/cpp/ggml-alloc.c +0 -1
- package/cpp/ggml-backend-impl.h +38 -20
- package/cpp/ggml-backend-reg.cpp +246 -92
- package/cpp/ggml-backend.h +1 -0
- package/cpp/ggml-common.h +42 -48
- package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +642 -223
- package/cpp/ggml-cpu-aarch64.h +2 -26
- package/cpp/ggml-cpu-traits.cpp +36 -0
- package/cpp/ggml-cpu-traits.h +38 -0
- package/cpp/ggml-cpu.c +14122 -13971
- package/cpp/ggml-cpu.cpp +627 -715
- package/cpp/ggml-cpu.h +0 -17
- package/cpp/ggml-impl.h +22 -6
- package/cpp/ggml-metal.m +482 -24
- package/cpp/ggml-quants.c +0 -9
- package/cpp/ggml-threading.h +4 -2
- package/cpp/ggml.c +284 -178
- package/cpp/ggml.h +73 -25
- package/cpp/llama-grammar.cpp +15 -15
- package/cpp/llama-grammar.h +2 -5
- package/cpp/llama-sampling.cpp +35 -90
- package/cpp/llama-vocab.cpp +7 -2
- package/cpp/llama-vocab.h +1 -1
- package/cpp/llama.cpp +1782 -586
- package/cpp/llama.h +20 -19
- package/cpp/sampling.cpp +11 -16
- package/cpp/sgemm.cpp +265 -258
- package/cpp/sgemm.h +2 -2
- package/cpp/speculative.cpp +4 -0
- package/cpp/unicode.cpp +51 -51
- package/cpp/unicode.h +9 -10
- package/lib/commonjs/index.js +38 -1
- package/lib/commonjs/index.js.map +1 -1
- package/lib/module/index.js +36 -0
- package/lib/module/index.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +2 -3
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +36 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +3 -3
- package/src/index.ts +46 -2
- package/cpp/amx/amx.cpp +0 -196
- package/cpp/amx/amx.h +0 -20
- package/cpp/amx/common.h +0 -101
- package/cpp/amx/mmq.cpp +0 -2524
- package/cpp/amx/mmq.h +0 -16
- package/cpp/ggml-aarch64.c +0 -129
- package/cpp/ggml-aarch64.h +0 -19
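The largest native change in this release is the move of the aarch64 repacking and GEMV/GEMM kernels from C to C++ (package/cpp/ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp, plus the new ggml-cpu-traits.{h,cpp}), which replaces the old per-type structs with a templated interleaved block. The sketch below mirrors that layout outside the package to show why the size checks added in the diff hold; it is an illustration only — the lm_half alias stands in for lm_ggml_half, and QK4_0/QK8_0 are assumed to be 32 as in ggml-common.h.

// Minimal, self-contained C++17 illustration of the interleaved block
// layout introduced by ggml-cpu-aarch64.cpp: the scales of N quantized
// blocks come first, followed by their packed quants.
#include <cstdint>
#include <cstdio>

using lm_half = std::uint16_t;   // stand-in for lm_ggml_half (16-bit storage)
constexpr int QK4_0 = 32;        // weights per q4_0 block
constexpr int QK8_0 = 32;        // weights per q8_0 block

template <int K> constexpr int QK_0() {
    if constexpr (K == 4) { return QK4_0; }
    if constexpr (K == 8) { return QK8_0; }
    return -1;
}

template <int K, int N> struct block {
    lm_half     d[N];                        // one delta (scale) per interleaved block
    std::int8_t qs[(QK_0<K>() * N * K) / 8]; // K-bit quants of N blocks, bit-packed
};

// The same size checks the diff adds: block<4,4> is 4 scales (8 bytes)
// plus 32 weights * 4 blocks * 4 bits / 8 = 64 bytes of nibbles.
static_assert(sizeof(block<4, 4>) == 4 * sizeof(lm_half) + QK8_0 * 2, "block<4,4>");
static_assert(sizeof(block<8, 8>) == 8 * sizeof(lm_half) + QK8_0 * 8, "block<8,8>");

int main() {
    std::printf("block<4,4> = %zu bytes\n", sizeof(block<4, 4>)); // 72
    std::printf("block<8,8> = %zu bytes\n", sizeof(block<8, 8>)); // 272
    return 0;
}

Grouping the scales and quants of ncols_interleaved blocks contiguously is what lets the NEON and AVX-512 kernels in the diff below load one group of columns per iteration: b_ptr advances by whole interleaved blocks rather than by single q4_0 blocks.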
@@ -1,20 +1,57 @@
|
|
1
|
-
#define
|
1
|
+
#define LM_GGML_COMMON_IMPL_CPP
|
2
|
+
#define LM_GGML_COMMON_DECL_CPP
|
2
3
|
#include "ggml-common.h"
|
4
|
+
#include "ggml-backend-impl.h"
|
3
5
|
|
4
6
|
#include "ggml-quants.h"
|
5
7
|
#include "ggml-impl.h"
|
6
8
|
#include "ggml-cpu.h"
|
7
9
|
#include "ggml-cpu-impl.h"
|
10
|
+
#include "ggml-cpu-traits.h"
|
8
11
|
|
9
|
-
#include <
|
10
|
-
#include <
|
11
|
-
#include <
|
12
|
-
#include <
|
13
|
-
#include <
|
14
|
-
#include <
|
12
|
+
#include <cmath>
|
13
|
+
#include <cstring>
|
14
|
+
#include <cassert>
|
15
|
+
#include <cfloat>
|
16
|
+
#include <cstdlib> // for qsort
|
17
|
+
#include <cstdio> // for LM_GGML_ASSERT
|
15
18
|
|
16
19
|
#include "ggml-cpu-aarch64.h"
|
17
20
|
|
21
|
+
// TODO: move to include file?
|
22
|
+
template <int K> constexpr int QK_0() {
|
23
|
+
if constexpr (K == 4) {
|
24
|
+
return QK4_0;
|
25
|
+
}
|
26
|
+
if constexpr (K == 8) {
|
27
|
+
return QK8_0;
|
28
|
+
}
|
29
|
+
return -1;
|
30
|
+
}
|
31
|
+
|
32
|
+
template <int K, int N> struct block {
|
33
|
+
lm_ggml_half d[N]; // deltas for N qK_0 blocks
|
34
|
+
int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
|
35
|
+
};
|
36
|
+
|
37
|
+
// control size
|
38
|
+
static_assert(sizeof(block<4, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
|
39
|
+
static_assert(sizeof(block<4, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
|
40
|
+
static_assert(sizeof(block<8, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
|
41
|
+
static_assert(sizeof(block<8, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
|
42
|
+
|
43
|
+
using block_q4_0x4 = block<4, 4>;
|
44
|
+
using block_q4_0x8 = block<4, 8>;
|
45
|
+
using block_q8_0x4 = block<8, 4>;
|
46
|
+
using block_q8_0x8 = block<8, 8>;
|
47
|
+
|
48
|
+
struct block_iq4_nlx4 {
|
49
|
+
lm_ggml_half d[4]; // deltas for 4 iq4_nl blocks
|
50
|
+
uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
|
51
|
+
};
|
52
|
+
|
53
|
+
static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(lm_ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
|
54
|
+
|
18
55
|
#if defined(__GNUC__)
|
19
56
|
#pragma GCC diagnostic ignored "-Woverlength-strings"
|
20
57
|
#elif defined(_MSC_VER)
|
@@ -128,7 +165,7 @@ static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
|
|
128
165
|
}
|
129
166
|
|
130
167
|
static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
|
131
|
-
#if defined(
|
168
|
+
#if defined(__AVX512VNNI__)
|
132
169
|
const __m512i zero = _mm512_setzero_si512();
|
133
170
|
return _mm512_dpbusd_epi32(zero, ax, sy);
|
134
171
|
#else
|
@@ -185,12 +222,12 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
|
|
185
222
|
|
186
223
|
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
187
224
|
|
188
|
-
static void quantize_q8_0_4x4(const float *
|
225
|
+
static void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
|
189
226
|
assert(QK8_0 == 32);
|
190
227
|
assert(k % QK8_0 == 0);
|
191
228
|
const int nb = k / QK8_0;
|
192
229
|
|
193
|
-
block_q8_0x4 *
|
230
|
+
block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
|
194
231
|
|
195
232
|
#if defined(__ARM_NEON)
|
196
233
|
float32x4_t srcv[4][8];
|
@@ -279,12 +316,12 @@ static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int6
|
|
279
316
|
#endif
|
280
317
|
}
|
281
318
|
|
282
|
-
static void quantize_q8_0_4x8(const float *
|
319
|
+
static void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
|
283
320
|
assert(QK8_0 == 32);
|
284
321
|
assert(k % QK8_0 == 0);
|
285
322
|
const int nb = k / QK8_0;
|
286
323
|
|
287
|
-
block_q8_0x4 *
|
324
|
+
block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
|
288
325
|
|
289
326
|
#if defined(__ARM_NEON)
|
290
327
|
float32x4_t srcv[4][8];
|
@@ -494,7 +531,7 @@ static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int6
|
|
494
531
|
#endif
|
495
532
|
}
|
496
533
|
|
497
|
-
void quantize_mat_q8_0(const float *
|
534
|
+
static void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
|
498
535
|
assert(nrow == 4);
|
499
536
|
UNUSED(nrow);
|
500
537
|
if (blck_size_interleave == 4) {
|
@@ -506,7 +543,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
|
|
506
543
|
}
|
507
544
|
}
|
508
545
|
|
509
|
-
void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float *
|
546
|
+
static void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
510
547
|
const int qk = QK8_0;
|
511
548
|
const int nb = n / qk;
|
512
549
|
const int ncols_interleaved = 4;
|
@@ -527,21 +564,21 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
|
|
527
564
|
|
528
565
|
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
529
566
|
if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
|
530
|
-
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
|
567
|
+
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
|
531
568
|
|
532
569
|
for (int c = 0; c < nc; c += ncols_interleaved) {
|
533
|
-
const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
|
570
|
+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
534
571
|
float32x4_t acc = vdupq_n_f32(0);
|
535
572
|
for (int b = 0; b < nb; b++) {
|
536
|
-
int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
|
537
|
-
int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
|
538
|
-
int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
|
539
|
-
int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
|
540
|
-
float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
|
573
|
+
int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
|
574
|
+
int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
|
575
|
+
int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
|
576
|
+
int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
|
577
|
+
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
|
541
578
|
|
542
579
|
int8x16_t a0 = vld1q_s8(a_ptr->qs);
|
543
580
|
int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
|
544
|
-
float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
|
581
|
+
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
|
545
582
|
|
546
583
|
int32x4_t ret = vdupq_n_s32(0);
|
547
584
|
|
@@ -591,7 +628,7 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
|
|
591
628
|
}
|
592
629
|
}
|
593
630
|
|
594
|
-
void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float *
|
631
|
+
static void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
595
632
|
const int qk = QK8_0;
|
596
633
|
const int nb = n / qk;
|
597
634
|
const int ncols_interleaved = 4;
|
@@ -610,72 +647,52 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
610
647
|
UNUSED(ncols_interleaved);
|
611
648
|
UNUSED(blocklen);
|
612
649
|
|
613
|
-
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(
|
614
|
-
if (lm_ggml_cpu_has_neon() &&
|
615
|
-
const
|
616
|
-
const void * a_ptr = vy;
|
617
|
-
float * res_ptr = s;
|
650
|
+
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
651
|
+
if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
|
652
|
+
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
|
618
653
|
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
|
659
|
-
"fmul v16.4s, v16.4s, v25.4s\n"
|
660
|
-
".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
|
661
|
-
".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
|
662
|
-
".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
|
663
|
-
".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
|
664
|
-
"addp v29.4s, v29.4s, v26.4s\n"
|
665
|
-
"scvtf v29.4s, v29.4s, #0x4\n"
|
666
|
-
"fmla v0.4s, v29.4s, v16.4s\n"
|
667
|
-
"cbnz x22, 2b\n"
|
668
|
-
"sub %x[nc], %x[nc], #0x4\n"
|
669
|
-
"str q0, [%x[res_ptr], #0x0]\n"
|
670
|
-
"add %x[res_ptr], %x[res_ptr], #0x10\n"
|
671
|
-
"cbnz %x[nc], 1b\n"
|
672
|
-
: [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
|
673
|
-
: [a_ptr] "r" (a_ptr), [nb] "r" (nb)
|
674
|
-
: "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
|
675
|
-
);
|
654
|
+
for (int c = 0; c < nc; c += ncols_interleaved) {
|
655
|
+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
656
|
+
float32x4_t acc = vdupq_n_f32(0);
|
657
|
+
for (int b = 0; b < nb; b++) {
|
658
|
+
int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
|
659
|
+
int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
|
660
|
+
int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
|
661
|
+
int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
|
662
|
+
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
|
663
|
+
|
664
|
+
int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
|
665
|
+
int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
|
666
|
+
int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
|
667
|
+
int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
|
668
|
+
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
|
669
|
+
|
670
|
+
int32x4_t ret0 = vdupq_n_s32(0);
|
671
|
+
int32x4_t ret1 = vdupq_n_s32(0);
|
672
|
+
|
673
|
+
ret0 = vdotq_s32(ret0, b0 << 4, a0);
|
674
|
+
ret1 = vdotq_s32(ret1, b1 << 4, a0);
|
675
|
+
ret0 = vdotq_s32(ret0, b2 << 4, a1);
|
676
|
+
ret1 = vdotq_s32(ret1, b3 << 4, a1);
|
677
|
+
|
678
|
+
ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
|
679
|
+
ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
|
680
|
+
ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
|
681
|
+
ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
|
682
|
+
|
683
|
+
int32x4_t ret = vpaddq_s32(ret0, ret1);
|
684
|
+
|
685
|
+
acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
|
686
|
+
vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
|
687
|
+
a_ptr++;
|
688
|
+
b_ptr++;
|
689
|
+
}
|
690
|
+
vst1q_f32(s, acc);
|
691
|
+
s += ncols_interleaved;
|
692
|
+
}
|
676
693
|
return;
|
677
694
|
}
|
678
|
-
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(
|
695
|
+
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
679
696
|
float sumf[4];
|
680
697
|
int sumi;
|
681
698
|
|
@@ -701,7 +718,7 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
701
718
|
}
|
702
719
|
}
|
703
720
|
|
704
|
-
void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float *
|
721
|
+
static void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
705
722
|
const int qk = QK8_0;
|
706
723
|
const int nb = n / qk;
|
707
724
|
const int ncols_interleaved = 8;
|
@@ -974,7 +991,7 @@ void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
974
991
|
}
|
975
992
|
}
|
976
993
|
|
977
|
-
void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float *
|
994
|
+
static void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
978
995
|
const int qk = QK8_0;
|
979
996
|
const int nb = n / qk;
|
980
997
|
const int ncols_interleaved = 4;
|
@@ -1070,7 +1087,7 @@ void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const vo
|
|
1070
1087
|
}
|
1071
1088
|
}
|
1072
1089
|
|
1073
|
-
void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float *
|
1090
|
+
static void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
1074
1091
|
const int qk = QK8_0;
|
1075
1092
|
const int nb = n / qk;
|
1076
1093
|
const int ncols_interleaved = 4;
|
@@ -1586,7 +1603,7 @@ void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void
|
|
1586
1603
|
}
|
1587
1604
|
}
|
1588
1605
|
|
1589
|
-
void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float *
|
1606
|
+
static void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
1590
1607
|
const int qk = QK8_0;
|
1591
1608
|
const int nb = n / qk;
|
1592
1609
|
const int ncols_interleaved = 4;
|
@@ -2040,7 +2057,7 @@ void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
2040
2057
|
}
|
2041
2058
|
}
|
2042
2059
|
|
2043
|
-
void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float *
|
2060
|
+
static void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
2044
2061
|
const int qk = QK8_0;
|
2045
2062
|
const int nb = n / qk;
|
2046
2063
|
const int ncols_interleaved = 8;
|
@@ -2560,31 +2577,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
2560
2577
|
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
|
2561
2578
|
|
2562
2579
|
// Shuffle pattern one - right side input
|
2563
|
-
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
2564
|
-
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
2580
|
+
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
2581
|
+
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
2565
2582
|
|
2566
|
-
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
2567
|
-
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
2583
|
+
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
2584
|
+
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
2568
2585
|
|
2569
|
-
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
2570
|
-
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
2586
|
+
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
2587
|
+
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
2571
2588
|
|
2572
|
-
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
2573
|
-
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
2589
|
+
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
2590
|
+
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
2574
2591
|
|
2575
2592
|
// Shuffle pattern two - right side input
|
2576
2593
|
|
2577
|
-
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
2578
|
-
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
2594
|
+
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
2595
|
+
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
2579
2596
|
|
2580
|
-
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
2581
|
-
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
2597
|
+
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
2598
|
+
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
2582
2599
|
|
2583
|
-
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
2584
|
-
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
2600
|
+
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
2601
|
+
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
2585
2602
|
|
2586
|
-
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
2587
|
-
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
2603
|
+
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
2604
|
+
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
2588
2605
|
|
2589
2606
|
// Scale values - Load the weight scale values of two block_q4_0x8
|
2590
2607
|
const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
|
@@ -2618,31 +2635,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
2618
2635
|
|
2619
2636
|
// Shuffle pattern one - left side input
|
2620
2637
|
|
2621
|
-
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
2622
|
-
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
2638
|
+
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
2639
|
+
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
2623
2640
|
|
2624
|
-
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
2625
|
-
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
2641
|
+
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
2642
|
+
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
2626
2643
|
|
2627
|
-
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
2628
|
-
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
2644
|
+
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
2645
|
+
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
2629
2646
|
|
2630
|
-
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
2631
|
-
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
2647
|
+
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
2648
|
+
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
2632
2649
|
|
2633
2650
|
// Shuffle pattern two - left side input
|
2634
2651
|
|
2635
|
-
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
2636
|
-
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
2652
|
+
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
2653
|
+
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
2637
2654
|
|
2638
|
-
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
2639
|
-
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
2655
|
+
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
2656
|
+
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
2640
2657
|
|
2641
|
-
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
2642
|
-
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
2658
|
+
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
2659
|
+
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
2643
2660
|
|
2644
|
-
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
2645
|
-
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
2661
|
+
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
2662
|
+
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
2646
2663
|
|
2647
2664
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
2648
2665
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
@@ -2671,10 +2688,10 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
2671
2688
|
|
2672
2689
|
|
2673
2690
|
// Straighten out to make 4 row vectors
|
2674
|
-
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
|
2675
|
-
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
|
2676
|
-
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
|
2677
|
-
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
|
2691
|
+
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
|
2692
|
+
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
|
2693
|
+
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
|
2694
|
+
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
|
2678
2695
|
|
2679
2696
|
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
2680
2697
|
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
|
@@ -2753,31 +2770,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
2753
2770
|
const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
|
2754
2771
|
|
2755
2772
|
// Shuffle pattern one - right side input
|
2756
|
-
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
2757
|
-
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
2773
|
+
const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
|
2774
|
+
const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
|
2758
2775
|
|
2759
|
-
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
2760
|
-
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
2776
|
+
const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
|
2777
|
+
const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
|
2761
2778
|
|
2762
|
-
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
2763
|
-
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
2779
|
+
const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
|
2780
|
+
const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
|
2764
2781
|
|
2765
|
-
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
2766
|
-
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
2782
|
+
const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
|
2783
|
+
const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
|
2767
2784
|
|
2768
2785
|
// Shuffle pattern two - right side input
|
2769
2786
|
|
2770
|
-
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
2771
|
-
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
2787
|
+
const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
|
2788
|
+
const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
|
2772
2789
|
|
2773
|
-
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
2774
|
-
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
2790
|
+
const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
|
2791
|
+
const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
|
2775
2792
|
|
2776
|
-
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
2777
|
-
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
2793
|
+
const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
|
2794
|
+
const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
|
2778
2795
|
|
2779
|
-
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
2780
|
-
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
2796
|
+
const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
|
2797
|
+
const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
|
2781
2798
|
|
2782
2799
|
|
2783
2800
|
// Scale values - Load the weight scale values of two block_q4_0x8
|
@@ -2809,31 +2826,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
2809
2826
|
|
2810
2827
|
// Shuffle pattern one - left side input
|
2811
2828
|
|
2812
|
-
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
2813
|
-
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
2829
|
+
const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
|
2830
|
+
const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
|
2814
2831
|
|
2815
|
-
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
2816
|
-
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
2832
|
+
const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
|
2833
|
+
const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
|
2817
2834
|
|
2818
|
-
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
2819
|
-
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
2835
|
+
const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
|
2836
|
+
const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
|
2820
2837
|
|
2821
|
-
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
2822
|
-
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
2838
|
+
const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
|
2839
|
+
const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
|
2823
2840
|
|
2824
2841
|
// Shuffle pattern two - left side input
|
2825
2842
|
|
2826
|
-
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
2827
|
-
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
2843
|
+
const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
|
2844
|
+
const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
|
2828
2845
|
|
2829
|
-
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
2830
|
-
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
2846
|
+
const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
|
2847
|
+
const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
|
2831
2848
|
|
2832
|
-
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
2833
|
-
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
2849
|
+
const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
|
2850
|
+
const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
|
2834
2851
|
|
2835
|
-
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
2836
|
-
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
2852
|
+
const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
2853
|
+
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
2837
2854
|
|
2838
2855
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
2839
2856
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
@@ -2862,10 +2879,10 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
2862
2879
|
|
2863
2880
|
|
2864
2881
|
// Straighten out to make 4 row vectors
|
2865
|
-
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
|
2866
|
-
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
|
2867
|
-
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
|
2868
|
-
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
|
2882
|
+
__m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
|
2883
|
+
__m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
|
2884
|
+
__m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
|
2885
|
+
__m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
|
2869
2886
|
|
2870
2887
|
// Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
|
2871
2888
|
const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
|
@@ -3460,7 +3477,7 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void
|
|
3460
3477
|
}
|
3461
3478
|
}
|
3462
3479
|
|
3463
|
-
void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float *
|
3480
|
+
static void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) {
|
3464
3481
|
const int qk = QK8_0;
|
3465
3482
|
const int nb = n / qk;
|
3466
3483
|
const int ncols_interleaved = 4;
|
@@ -3571,7 +3588,6 @@ void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const vo
|
|
3571
3588
|
}
|
3572
3589
|
}
|
3573
3590
|
|
3574
|
-
// FIXME: this code is duplicated from ggml-aarch64.c
|
3575
3591
|
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
|
3576
3592
|
block_q4_0x4 out;
|
3577
3593
|
|
@@ -3641,20 +3657,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
|
|
3641
3657
|
return out;
|
3642
3658
|
}
|
3643
3659
|
|
3644
|
-
static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void *
|
3660
|
+
static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
|
3645
3661
|
LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
|
3646
3662
|
LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
3663
|
+
constexpr int nrows_interleaved = 4;
|
3647
3664
|
|
3648
3665
|
block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
|
3649
3666
|
const block_q4_0 * src = (const block_q4_0 *)data;
|
3650
3667
|
block_q4_0 dst_tmp[4];
|
3651
|
-
int nrow = t
|
3652
|
-
int nrows_interleaved = 4;
|
3668
|
+
int nrow = lm_ggml_nrows(t);
|
3653
3669
|
int nblocks = t->ne[0] / QK4_0;
|
3654
3670
|
|
3655
3671
|
LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
|
3656
3672
|
|
3657
|
-
if (
|
3673
|
+
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
3658
3674
|
return -1;
|
3659
3675
|
}
|
3660
3676
|
|
@@ -3672,20 +3688,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_bl
 LM_GGML_UNUSED(data_size);
 }

-static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor *t, int interleave_block, const void *
+static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
 LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0);
 LM_GGML_ASSERT(interleave_block == 8);
+constexpr int nrows_interleaved = 8;

 block_q4_0x8 * dst = (block_q4_0x8*)t->data;
 const block_q4_0 * src = (const block_q4_0*) data;
 block_q4_0 dst_tmp[8];
-int nrow = t
-int nrows_interleaved = 8;
+int nrow = lm_ggml_nrows(t);
 int nblocks = t->ne[0] / QK4_0;

 LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));

-if (
+if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
 return -1;
 }

@@ -3712,16 +3728,18 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s

 const int end = QK4_NL * 2 / blck_size_interleave;

-
-
-
-
-
-
-
-
-
-}
+// TODO: this branch seems wrong
+//if (blck_size_interleave == 8) {
+// for (int i = 0; i < end; ++i) {
+// int src_id = i % 4;
+// int src_offset = (i / 4) * blck_size_interleave;
+// int dst_offset = i * blck_size_interleave;
+
+// // Using memcpy to avoid unaligned memory accesses
+// memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+// }
+//} else
+if (blck_size_interleave == 4) {
 for (int i = 0; i < end; ++i) {
 int src_id = i % 4;
 int src_offset = (i / 4) * blck_size_interleave;
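Note: the interleave_block == 8 branch of make_block_iq4_nlx4 is commented out here and only the 4-byte interleave path survives. The remaining loop takes chunk i of the packed output from source row i % 4 at chunk offset i / 4, which is the general four-row weaving pattern used throughout these packers. A standalone sketch of that pattern (illustrative only, simplified names):

    #include <cstring>
    #include <cstdint>

    // Weave four source rows into one packed row, chunk_size bytes at a time:
    // output chunk i comes from row i % 4, chunk i / 4 of that row.
    static void interleave_4rows(uint8_t * dst, const uint8_t * const src[4],
                                 int chunk_size, int chunks_per_row) {
        for (int i = 0; i < 4 * chunks_per_row; ++i) {
            const int src_id     = i % 4;
            const int src_offset = (i / 4) * chunk_size;
            const int dst_offset = i * chunk_size;
            memcpy(dst + dst_offset, src[src_id] + src_offset, chunk_size);
        }
    }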
@@ -3736,20 +3754,21 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
 return out;
 }

-static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void *
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
 LM_GGML_ASSERT(t->type == LM_GGML_TYPE_IQ4_NL);
-LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+//LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+LM_GGML_ASSERT(interleave_block == 4);

 block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
 const block_iq4_nl * src = (const block_iq4_nl *)data;
 block_iq4_nl dst_tmp[4];
-int nrow = t
+int nrow = lm_ggml_nrows(t);
 int nrows_interleaved = 4;
 int nblocks = t->ne[0] / QK4_0;

 LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));

-if (
+if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
 return -1;
 }

@@ -3767,57 +3786,457 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleav
 LM_GGML_UNUSED(data_size);
 }

-
-
-
-
-
+namespace ggml::cpu::aarch64 {
+// repack
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+int repack(struct lm_ggml_tensor *, const void *, size_t);
+
+// TODO: generalise.
+template <> int repack<block_q4_0, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 8>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_iq4_nl, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
+}
+
+// TODO: needs to be revisited
+//template <> int repack<block_iq4_nl, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
+//}
+
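Note: these explicit specializations are the dispatch table of the new C++ path; a <BLOC_TYPE, INTER_SIZE, NB_COLS> triple selects the matching packer. A hedged usage sketch (t, raw_q4_0_data and data_size are assumed to come from the caller; not code from the package):

    // Repack a Q4_0 tensor into the 8-wide, 8-column interleaved layout.
    int rc = ggml::cpu::aarch64::repack<block_q4_0, 8, 8>(t, raw_q4_0_data, data_size);
    if (rc != 0) {
        // shape is not compatible with the interleaved layout; keep the plain Q4_0 data
    }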
+// gemv
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemv(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+lm_ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+lm_ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+lm_ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+lm_ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+// gemm
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemm(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+lm_ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+lm_ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+lm_ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+lm_ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+class tensor_traits_base : public ggml::cpu::tensor_traits {
+public:
+virtual int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) = 0;
+};
+
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+
+bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
+// not realy a LM_GGML_TYPE_Q8_0 but same size.
+switch (op->op) {
+case LM_GGML_OP_MUL_MAT:
+size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
+return true;
+case LM_GGML_OP_MUL_MAT_ID:
+size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
+size = LM_GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
+size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+return true;
+default:
+// LM_GGML_ABORT("fatal error");
+break;
+}
+return false;
 }

-
-switch (
-
-
-
-
-
-
-
-
-
-default:
-LM_GGML_ABORT("Unsupported type");
+bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
+switch (op->op) {
+case LM_GGML_OP_MUL_MAT:
+forward_mul_mat(params, op);
+return true;
+case LM_GGML_OP_MUL_MAT_ID:
+forward_mul_mat_id(params, op);
+return true;
+default:
+// LM_GGML_ABORT("fatal error");
+break;
 }
-
-
-
-
-
-
-
+return false;
+}
+
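Note: work_size reserves a Q8_0-sized shadow of src1, since the activations are requantized to Q8_0 before the int8 kernels run; the MUL_MAT_ID case additionally pads the buffer and reserves the per-expert row-count and row-mapping bookkeeping used further below. A rough worked example (assuming ggml's usual Q8_0 block of 32 values in 34 bytes, i.e. a 2-byte scale plus 32 int8 quants; illustrative only):

    #include <cstddef>

    // MUL_MAT scratch for a 4096 x 512 f32 src1 under the assumption above.
    constexpr std::size_t kQ8_0RowBytes  = 4096 / 32 * 34;       // 4352 bytes per row
    constexpr std::size_t kMulMatScratch = kQ8_0RowBytes * 512;  // 2228224 bytes, about 2.1 MiB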
+void forward_mul_mat(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
+const lm_ggml_tensor * src0 = op->src[0];
+const lm_ggml_tensor * src1 = op->src[1];
+lm_ggml_tensor * dst = op;
+
+LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+const int ith = params->ith;
+const int nth = params->nth;
+
+LM_GGML_ASSERT(ne0 == ne01);
+LM_GGML_ASSERT(ne1 == ne11);
+LM_GGML_ASSERT(ne2 == ne12);
+LM_GGML_ASSERT(ne3 == ne13);
+
+// dst cannot be transposed or permuted
+LM_GGML_ASSERT(nb0 == sizeof(float));
+LM_GGML_ASSERT(nb0 <= nb1);
+LM_GGML_ASSERT(nb1 <= nb2);
+LM_GGML_ASSERT(nb2 <= nb3);
+
+LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+
+LM_GGML_ASSERT(lm_ggml_n_dims(op->src[0]) == 2);
+// LM_GGML_ASSERT(lm_ggml_n_dims(op->src[1]) == 2);
+
+char * wdata = static_cast<char *>(params->wdata);
+const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+
+assert(params->wsize >= nbw1 * ne11);
+
+const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
+
+int64_t i11_processed = 0;
+for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
+INTER_SIZE);
+}
+i11_processed = ne11 - ne11 % 4;
+for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+}
+
+lm_ggml_barrier(params->threadpool);
+
+const void * src1_wdata = params->wdata;
+const size_t src1_col_stride = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+int64_t src0_start = (ith * ne01) / nth;
+int64_t src0_end = ((ith + 1) * ne01) / nth;
+src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+if (src0_start >= src0_end) {
+return;
+}
+
+// If there are more than three rows in src1, use gemm; otherwise, use gemv.
+if (ne11 > 3) {
+gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
+(const char *) src0->data + src0_start * nb01,
+(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+}
+for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+(const char *) src0->data + src0_start * nb01,
+(const char *) src1_wdata + (src1_col_stride * iter), 1,
+src0_end - src0_start);
 }
-} else {
-LM_GGML_ABORT("Unsupported type");
 }
-}

-
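Note: forward_mul_mat splits the rows of src0 evenly across threads and then rounds both ends of each slice up to a multiple of NB_COLS, because the interleaved kernels can only start on a column-group boundary; a thread whose rounded slice becomes empty returns early. The rounding step in isolation (illustrative only):

    #include <cstdint>

    // Round a half-open [start, end) row range up to NB_COLS boundaries,
    // mirroring the partitioning above.
    static void align_slice(int64_t & start, int64_t & end, int64_t nb_cols) {
        start = (start % nb_cols) ? start + nb_cols - (start % nb_cols) : start;
        end   = (end   % nb_cols) ? end   + nb_cols - (end   % nb_cols) : end;
    }
    // e.g. 64 rows over 3 threads with nb_cols = 8: raw slices [0,21) [21,42) [42,64)
    // become [0,24) [24,48) [48,64).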
+void forward_mul_mat_id(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
+const lm_ggml_tensor * src0 = op->src[0];
+const lm_ggml_tensor * src1 = op->src[1];
+const lm_ggml_tensor * ids = op->src[2];
+lm_ggml_tensor * dst = op;
+
+LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+const int ith = params->ith;
+const int nth = params->nth;
+
+const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
+
+// we don't support permuted src0 or src1
+LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type));
+LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type));
+
+// dst cannot be transposed or permuted
+LM_GGML_ASSERT(nb0 == sizeof(float));
+LM_GGML_ASSERT(nb0 <= nb1);
+LM_GGML_ASSERT(nb1 <= nb2);
+LM_GGML_ASSERT(nb2 <= nb3);
+
+LM_GGML_ASSERT(ne03 == 1);
+LM_GGML_ASSERT(ne13 == 1);
+LM_GGML_ASSERT(ne3 == 1);
+
+LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+
+// row groups
+const int n_ids = ids->ne[0]; // n_expert_used
+const int n_as = ne02; // n_expert
+
+const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+const size_t nbw2 = nbw1*ne11;
+const size_t nbw3 = nbw2*ne12;
+
+struct mmid_row_mapping {
+int32_t i1;
+int32_t i2;
+};
+
+LM_GGML_ASSERT(params->wsize >= (LM_GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
+n_as * ne12 * sizeof(mmid_row_mapping)));
+
+auto wdata = (char *) params->wdata;
+auto wdata_src1_end = (char *) wdata + LM_GGML_PAD(nbw3, sizeof(int64_t));
+int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
+
+// src1: float32 => block_q8_0
+for (int64_t i12 = 0; i12 < ne12; ++i12) {
+for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
+(void *) (wdata + i12 * nbw2 + i11 * nbw1),
+ne10);
+}
+}
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
+
+if (ith == 0) {
+// initialize matrix_row_counts
+memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
+
+// group rows by src0 matrix
+for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+for (int32_t id = 0; id < n_ids; ++id) {
+const int32_t i02 =
+*(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+
+LM_GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
+matrix_row_counts[i02] += 1;
+}
+}
+}
+
+lm_ggml_barrier(params->threadpool);
+
+// compute each matrix multiplication in sequence
+for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+const int64_t cne1 = matrix_row_counts[cur_a];
+
+if (cne1 == 0) {
+continue;
+}
+
+auto src0_cur = (const char *) src0->data + cur_a*nb02;
+
+//const int64_t nr0 = ne01; // src0 rows
+const int64_t nr1 = cne1; // src1 rows
+
+int64_t src0_cur_start = (ith * ne01) / nth;
+int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
+src0_cur_start =
+(src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+
+if (src0_cur_start >= src0_cur_end) return;
+
+for (int ir1 = 0; ir1 < nr1; ir1++) {
+struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
+const int id = row_mapping.i1; // selected expert index
+
+const int64_t i11 = id % ne11;
+const int64_t i12 = row_mapping.i2; // row index in src1
+
+const int64_t i1 = id; // selected expert index
+const int64_t i2 = i12; // row
+
+auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+
+gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
+ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
+ne01, src0_cur + src0_cur_start * nb01,
+src1_col, 1, src0_cur_end - src0_cur_start);
+}
+}
+#undef MMID_MATRIX_ROW
+}
+
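Note: forward_mul_mat_id first quantizes all of src1 to Q8_0, then (on thread 0) walks the ids tensor and buckets every (expert slot, token) pair into matrix_rows while counting rows per expert, so that each expert's gemv calls run over a contiguous list. The grouping step reduced to its essentials (illustrative sketch using std containers instead of the work buffer):

    #include <cstdint>
    #include <vector>

    struct row_ref { int32_t expert_slot; int32_t token; };

    // Bucket (token, slot) pairs by the expert each slot selects -- the same
    // counting/grouping the ith == 0 block above performs in place.
    static std::vector<std::vector<row_ref>> group_rows_by_expert(
            const std::vector<std::vector<int32_t>> & ids, int n_expert) {
        std::vector<std::vector<row_ref>> rows(n_expert);
        for (int32_t token = 0; token < (int32_t) ids.size(); ++token) {
            for (int32_t slot = 0; slot < (int32_t) ids[token].size(); ++slot) {
                rows[ids[token][slot]].push_back({slot, token});
            }
        }
        return rows;
    }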
+int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) override {
+LM_GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, lm_ggml_type_name(t->type),
+(int) NB_COLS, (int) INTER_SIZE);
+return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+}
+};
+
+// instance for Q4
+static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
+static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
+static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
+
+// instance for IQ4
+static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
+
+} // namespace ggml::cpu::aarch64
+
+static const ggml::cpu::tensor_traits * lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur) {
 if (cur->type == LM_GGML_TYPE_Q4_0) {
-
-
-
+if (lm_ggml_cpu_has_avx2() || (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) {
+if (cur->ne[1] % 8 == 0) {
+return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
+}
 }
 if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) {
-
+if (cur->ne[1] % 4 == 0) {
+return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
+}
 }
 if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
-
+if (cur->ne[1] % 4 == 0) {
+return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
+}
 }
 } else if (cur->type == LM_GGML_TYPE_IQ4_NL) {
 if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
-
+if (cur->ne[1] % 4 == 0) {
+return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
+}
+}
+}
+
+return nullptr;
+}
+
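Note: lm_ggml_aarch64_get_optimal_repack_type picks the widest interleaved layout that both the CPU features and the tensor's row count allow, and a null result leaves the tensor in its plain layout. A condensed restatement of the priority order (my reading of the checks above; the SVE condition compares the reported vector length against QK8_0):

    // First match wins for LM_GGML_TYPE_Q4_0:
    //   AVX2, or SVE + int8 matmul with vector length == QK8_0, and ne[1] % 8 == 0 -> q4_0_8x8_q8_0
    //   NEON + int8 matmul,                                     and ne[1] % 4 == 0 -> q4_0_4x8_q8_0
    //   NEON + dotprod,                                         and ne[1] % 4 == 0 -> q4_0_4x4_q8_0
    // LM_GGML_TYPE_IQ4_NL: NEON + dotprod and ne[1] % 4 == 0                       -> iq4_nl_4x4_q8_0
    // otherwise -> nullptr (no repacking)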
+static void lm_ggml_backend_cpu_aarch64_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) {
+tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(lm_ggml_aarch64_get_optimal_repack_type(tensor));
+
+LM_GGML_UNUSED(buffer);
+}
+
+static void lm_ggml_backend_cpu_aarch64_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
+const void * data, size_t offset, size_t size) {
+LM_GGML_ASSERT(offset == 0);
+LM_GGML_ASSERT(size == lm_ggml_nbytes(tensor));
+
+auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
+auto OK = tensor_traits->repack(tensor, data, size);
+
+LM_GGML_ASSERT(OK == 0);
+LM_GGML_UNUSED(buffer);
+}
+
+static const char * lm_ggml_backend_cpu_aarch64_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
+return "CPU_AARCH64";
+
+LM_GGML_UNUSED(buft);
+}
+
+static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
+lm_ggml_backend_buffer_t buffer = lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_cpu_buffer_type(), size);
+
+if (buffer == nullptr) {
+return nullptr;
+}
+
+buffer->buft = buft;
+buffer->iface.init_tensor = lm_ggml_backend_cpu_aarch64_buffer_init_tensor;
+buffer->iface.set_tensor = lm_ggml_backend_cpu_aarch64_buffer_set_tensor;
+return buffer;
+}
+
+static size_t lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
+return TENSOR_ALIGNMENT;
+
+LM_GGML_UNUSED(buft);
+}
+
+namespace ggml::cpu::aarch64 {
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+bool supports_op(lm_ggml_backend_dev_t, const struct lm_ggml_tensor * op) override {
+if ( op->op == LM_GGML_OP_MUL_MAT &&
+op->src[0]->buffer &&
+(lm_ggml_n_dims(op->src[0]) == 2) &&
+op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type() &&
+lm_ggml_aarch64_get_optimal_repack_type(op->src[0])
+) {
+if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+return false;
+}
+if (op->src[1]->type == LM_GGML_TYPE_F32) {
+return true;
+}
+//if (op->src[1]->type == LM_GGML_TYPE_Q8_0) {
+// return true;
+//}
+// may be possible if Q8_0 packed...
+} else if (op->op == LM_GGML_OP_MUL_MAT_ID
+&& op->src[0]->buffer
+&& (lm_ggml_n_dims(op->src[0]) == 3)
+&& op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type()
+&& lm_ggml_aarch64_get_optimal_repack_type(op->src[0])
+) {
+if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+return false;
+}
+if (op->src[1]->type == LM_GGML_TYPE_F32) {
+return true;
+}
+//if (op->src[1]->type == LM_GGML_TYPE_Q8_0) {
+// return true;
+//}
 }
+return false;
 }

-
+ggml::cpu::tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) override {
+if (op->op == LM_GGML_OP_MUL_MAT || op->op == LM_GGML_OP_MUL_MAT_ID) {
+if (op->src[0]->buffer && op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type()) {
+return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+}
+}
+return nullptr;
+}
+};
+} // namespace ggml::cpu::aarch64
+
+lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void) {
+static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_aarch64 = {
+/* .iface = */ {
+/* .get_name = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_name,
+/* .alloc_buffer = */ lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
+/* .get_alignment = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment,
+/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
+/* .get_alloc_size = */ nullptr, // defaults to lm_ggml_nbytes
+/* .is_host = */ nullptr,
+},
+/* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0),
+/* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
+};
+
+return &lm_ggml_backend_cpu_buffer_type_aarch64;
 }
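Note: the CPU_AARCH64 buffer type is an ordinary CPU buffer whose init_tensor/set_tensor hooks attach the optimal tensor_traits and repack the weights as they are uploaded. A hedged usage sketch (assumes a ggml-backend setup; error handling elided, size illustrative):

    // Allocate a buffer with the aarch64 repacking buffer type; tensors placed in
    // it get repacked when their data is set.
    lm_ggml_backend_buffer_type_t buft = lm_ggml_backend_cpu_aarch64_buffer_type();
    lm_ggml_backend_buffer_t      buf  = lm_ggml_backend_buft_alloc_buffer(buft, 64u * 1024 * 1024);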