whisper.rn 0.4.0-rc.10 → 0.4.0-rc.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -3
- package/cpp/amx/amx.cpp +220 -0
- package/cpp/amx/amx.h +8 -0
- package/cpp/amx/common.h +91 -0
- package/cpp/amx/mmq.cpp +2511 -0
- package/cpp/amx/mmq.h +10 -0
- package/cpp/ggml-alloc.c +6 -14
- package/cpp/ggml-backend-impl.h +50 -11
- package/cpp/ggml-backend-reg.cpp +409 -31
- package/cpp/ggml-backend.cpp +9 -3
- package/cpp/ggml-backend.h +18 -0
- package/cpp/ggml-common.h +41 -43
- package/cpp/ggml-cpp.h +1 -0
- package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +941 -254
- package/cpp/ggml-cpu-aarch64.h +2 -24
- package/cpp/ggml-cpu-impl.h +171 -11
- package/cpp/ggml-cpu-quants.c +1812 -389
- package/cpp/ggml-cpu-traits.cpp +36 -0
- package/cpp/ggml-cpu-traits.h +38 -0
- package/cpp/ggml-cpu.c +1432 -610
- package/cpp/ggml-cpu.cpp +131 -141
- package/cpp/ggml-cpu.h +10 -50
- package/cpp/ggml-impl.h +27 -11
- package/cpp/ggml-metal-impl.h +39 -0
- package/cpp/ggml-metal.h +1 -1
- package/cpp/ggml-metal.m +1031 -359
- package/cpp/ggml-opt.cpp +854 -0
- package/cpp/ggml-opt.h +216 -0
- package/cpp/ggml-quants.c +0 -9
- package/cpp/ggml-threading.h +4 -2
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +501 -1537
- package/cpp/ggml.h +144 -171
- package/cpp/gguf.cpp +1329 -0
- package/cpp/gguf.h +202 -0
- package/cpp/whisper.cpp +254 -114
- package/cpp/whisper.h +6 -3
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +2 -1
- package/src/version.json +1 -1
- package/whisper-rn.podspec +2 -2
- package/cpp/README.md +0 -4
- package/cpp/ggml-aarch64.c +0 -129
- package/cpp/ggml-aarch64.h +0 -19
- package/cpp/ggml-backend.cpp.rej +0 -12
--- package/cpp/ggml-cpu-aarch64.c
+++ package/cpp/ggml-cpu-aarch64.cpp
@@ -1,24 +1,57 @@
-
-
-//
-
-#define WSP_GGML_COMMON_IMPL_C
+#define WSP_GGML_COMMON_IMPL_CPP
+#define WSP_GGML_COMMON_DECL_CPP
 #include "ggml-common.h"
+#include "ggml-backend-impl.h"
 
 #include "ggml-quants.h"
 #include "ggml-impl.h"
 #include "ggml-cpu.h"
 #include "ggml-cpu-impl.h"
+#include "ggml-cpu-traits.h"
 
-#include <
-#include <
-#include <
-#include <
-#include <
-#include <
+#include <cmath>
+#include <cstring>
+#include <cassert>
+#include <cfloat>
+#include <cstdlib> // for qsort
+#include <cstdio>  // for WSP_GGML_ASSERT
 
 #include "ggml-cpu-aarch64.h"
 
+// TODO: move to include file?
+template <int K> constexpr int QK_0() {
+    if constexpr (K == 4) {
+        return QK4_0;
+    }
+    if constexpr (K == 8) {
+        return QK8_0;
+    }
+    return -1;
+}
+
+template <int K, int N> struct block {
+    wsp_ggml_half d[N];                 // deltas for N qK_0 blocks
+    int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
+};
+
+// control size
+static_assert(sizeof(block<4, 4>) == 4 * sizeof(wsp_ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+static_assert(sizeof(block<4, 8>) == 8 * sizeof(wsp_ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+static_assert(sizeof(block<8, 4>) == 4 * sizeof(wsp_ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+static_assert(sizeof(block<8, 8>) == 8 * sizeof(wsp_ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+
+using block_q4_0x4 = block<4, 4>;
+using block_q4_0x8 = block<4, 8>;
+using block_q8_0x4 = block<8, 4>;
+using block_q8_0x8 = block<8, 8>;
+
+struct block_iq4_nlx4 {
+    wsp_ggml_half d[4];     // deltas for 4 iq4_nl blocks
+    uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(wsp_ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
 #if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Woverlength-strings"
 #elif defined(_MSC_VER)
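
The hunk above introduces templated interleaved block types (block<4,4>, block<4,8>, block<8,4>, block<8,8>) that pack the scales and quants of several q4_0/q8_0 blocks side by side for the repacked GEMV/GEMM kernels. Below is a standalone sketch of that layout, for illustration only; the wsp_ggml_half stub and the QK4_0/QK8_0 values (32, as in upstream ggml) are assumptions made for this sketch and are not part of the package.

    // Standalone sketch of the interleaved block layout (C++17).
    #include <cstdint>
    #include <cstdio>

    using wsp_ggml_half = uint16_t;   // assumption: fp16 scale stored as a 16-bit value
    constexpr int QK4_0 = 32;         // assumption: upstream ggml block sizes
    constexpr int QK8_0 = 32;

    template <int K> constexpr int QK_0() {
        if constexpr (K == 4) { return QK4_0; }
        if constexpr (K == 8) { return QK8_0; }
        return -1;
    }

    template <int K, int N> struct block {
        wsp_ggml_half d[N];                  // one scale per interleaved sub-block
        int8_t qs[(QK_0<K>() * N * K) / 8];  // N sub-blocks of K-bit quants, interleaved
    };

    int main() {
        // block<4, 4>: four q4_0 blocks -> 4 scales + 4*32 nibbles = 8 + 64 bytes
        printf("sizeof(block<4,4>) = %zu\n", sizeof(block<4, 4>));
        // block<8, 8>: eight q8_0 blocks -> 8 scales + 8*32 bytes = 16 + 256 bytes
        printf("sizeof(block<8,8>) = %zu\n", sizeof(block<8, 8>));
        return 0;
    }
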
@@ -132,7 +165,7 @@ static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
 }
 
 static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
-#if defined(
+#if defined(__AVX512VNNI__)
     const __m512i zero = _mm512_setzero_si512();
     return _mm512_dpbusd_epi32(zero, ax, sy);
 #else
@@ -161,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
 }
 
 static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
-#if defined(
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
     const __m256i zero = _mm256_setzero_si256();
     return _mm256_dpbusd_epi32(zero, ax, sy);
+#elif defined(__AVXVNNI__)
+    const __m256i zero = _mm256_setzero_si256();
+    return _mm256_dpbusd_avx_epi32(zero, ax, sy);
 #else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);
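
For reference, the new __AVXVNNI__ branch above computes the same thing as the existing AVX512-VNNI branch: each 32-bit lane of the result becomes the sum of four products of an unsigned byte from ax and a signed byte from sy (non-saturating). A scalar sketch of that per-lane dot product, written for illustration only and not taken from the package:

    // Scalar reference for the VNNI-style unsigned*signed byte dot product.
    #include <cstdint>

    static void mul_sum_us8_pairs_scalar(const uint8_t *ax, const int8_t *sy,
                                         int32_t *out, int n_lanes) {
        for (int lane = 0; lane < n_lanes; ++lane) {
            int32_t acc = 0;
            for (int k = 0; k < 4; ++k) {   // four byte pairs feed each 32-bit lane
                acc += (int32_t) ax[lane * 4 + k] * (int32_t) sy[lane * 4 + k];
            }
            out[lane] = acc;                // the intrinsic accumulates onto `zero` here
        }
    }
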
@@ -187,12 +223,14 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 }
 #endif
 
-static
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+static void wsp_quantize_q8_0_4x4(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
-    block_q8_0x4 *
+    block_q8_0x4 * WSP_GGML_RESTRICT y = (block_q8_0x4 *) vy;
 
 #if defined(__ARM_NEON)
     float32x4_t srcv[4][8];
@@ -281,12 +319,12 @@ static void wsp_quantize_q8_0_4x4(const float * restrict x, void * restrict vy,
 #endif
 }
 
-static void wsp_quantize_q8_0_4x8(const float *
+static void wsp_quantize_q8_0_4x8(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
-    block_q8_0x4 *
+    block_q8_0x4 * WSP_GGML_RESTRICT y = (block_q8_0x4 *) vy;
 
 #if defined(__ARM_NEON)
     float32x4_t srcv[4][8];
@@ -496,7 +534,7 @@ static void wsp_quantize_q8_0_4x8(const float * restrict x, void * restrict vy,
 #endif
 }
 
-void wsp_quantize_mat_q8_0(const float *
+static void wsp_quantize_mat_q8_0(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
     assert(nrow == 4);
     UNUSED(nrow);
     if (blck_size_interleave == 4) {
@@ -508,7 +546,7 @@ void wsp_quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t
     }
 }
 
-void wsp_ggml_gemv_q4_0_4x4_q8_0(int n, float *
+static void wsp_ggml_gemv_q4_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -527,67 +565,47 @@ void wsp_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const voi
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    if (wsp_ggml_cpu_has_neon()) {
-        const
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        "fcvtl v16.4s, v20.4h\n"
-        ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
-        "fmul v16.4s, v16.4s, v21.4s\n"
-        ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
-        ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
-        ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
-        ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
-        ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
-        ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
-        "scvtf v26.4s, v26.4s, #0x4\n"
-        "fmla v29.4s, v26.4s, v16.4s\n"
-        "cbnz x21, 2b\n"
-        "sub %x[nc], %x[nc], #0x4\n"
-        "str q29, [%x[res_ptr], #0x0]\n"
-        "add %x[res_ptr], %x[res_ptr], #0x10\n"
-        "cbnz %x[nc], 1b\n"
-        : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
-        : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
-        : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
-    );
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_dotprod()) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
+
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+            float32x4_t acc = vdupq_n_f32(0);
+            for (int b = 0; b < nb; b++) {
+                int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+                int8x16_t a0 = vld1q_s8(a_ptr->qs);
+                int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+                int32x4_t ret = vdupq_n_s32(0);
+
+                ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
+                ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
+                ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
+                ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);
+
+                ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
+                ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
+                ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
+                ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);
+
+                acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+                                vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+                a_ptr++;
+                b_ptr++;
+            }
+            vst1q_f32(s, acc);
+            s += ncols_interleaved;
+        }
         return;
     }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     float sumf[4];
     int sumi;
 
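
The replacement above swaps the hand-written sdot inline assembly for NEON dot-product intrinsics. The core operation, vdotq_laneq_s32, adds into each 32-bit accumulator lane the dot product of four signed bytes from the first vector with the four bytes selected from the second vector by the lane index; the b << 4 / b & 0xf0U pairing keeps both nibbles pre-scaled by 16, which the later vcvtq_n_f32_s32(ret, 4) divides back out. A scalar sketch of the intrinsic's semantics, for illustration only:

    // Scalar model of vdotq_laneq_s32(acc, a, b, lane).
    #include <cstdint>

    static void vdotq_laneq_s32_scalar(int32_t acc[4], const int8_t a[16],
                                       const int8_t b[16], int lane) {
        for (int i = 0; i < 4; ++i) {
            int32_t sum = 0;
            for (int k = 0; k < 4; ++k) {
                // four bytes of a's lane i, dotted with the four bytes of b's selected lane
                sum += (int32_t) a[i * 4 + k] * (int32_t) b[lane * 4 + k];
            }
            acc[i] += sum;
        }
    }
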
@@ -613,7 +631,7 @@ void wsp_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const voi
     }
 }
 
-void wsp_ggml_gemv_q4_0_4x8_q8_0(int n, float *
+static void wsp_ggml_gemv_q4_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -632,72 +650,52 @@ void wsp_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const voi
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(
-    if (wsp_ggml_cpu_has_neon() &&
-        const
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        "fcvtl v16.4s, v24.4h\n"
-        ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n"
-        ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
-        "fmul v16.4s, v16.4s, v25.4s\n"
-        ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
-        ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
-        ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
-        ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
-        "addp v29.4s, v29.4s, v26.4s\n"
-        "scvtf v29.4s, v29.4s, #0x4\n"
-        "fmla v0.4s, v29.4s, v16.4s\n"
-        "cbnz x22, 2b\n"
-        "sub %x[nc], %x[nc], #0x4\n"
-        "str q0, [%x[res_ptr], #0x0]\n"
-        "add %x[res_ptr], %x[res_ptr], #0x10\n"
-        "cbnz %x[nc], 1b\n"
-        : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
-        : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
-        : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
-    );
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_dotprod()) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
+
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+            float32x4_t acc = vdupq_n_f32(0);
+            for (int b = 0; b < nb; b++) {
+                int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+                int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
+                int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
+                int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
+                int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+                int32x4_t ret0 = vdupq_n_s32(0);
+                int32x4_t ret1 = vdupq_n_s32(0);
+
+                ret0 = vdotq_s32(ret0, b0 << 4, a0);
+                ret1 = vdotq_s32(ret1, b1 << 4, a0);
+                ret0 = vdotq_s32(ret0, b2 << 4, a1);
+                ret1 = vdotq_s32(ret1, b3 << 4, a1);
+
+                ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
+                ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
+                ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
+                ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
+
+                int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+                acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+                                vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+                a_ptr++;
+                b_ptr++;
+            }
+            vst1q_f32(s, acc);
+            s += ncols_interleaved;
+        }
         return;
     }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     float sumf[4];
     int sumi;
 
@@ -723,7 +721,7 @@ void wsp_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const voi
     }
 }
 
-void wsp_ggml_gemv_q4_0_8x8_q8_0(int n, float *
+static void wsp_ggml_gemv_q4_0_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 8;
@@ -996,7 +994,103 @@ void wsp_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const voi
     }
 }
 
-void
+static void wsp_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_dotprod()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float * res_ptr = s;
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            float32x4_t sumf = vdupq_n_f32(0);
+            for (int l = 0; l < nb; l++) {
+                uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+                uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+                uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+                uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+                int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+                int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+                int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+                int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+                int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+                int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+                int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+                int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+                int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+                int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+                int32x4_t sumi = vdupq_n_s32(0);
+                sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+                sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+                sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+                sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+                sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+                sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+                sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+                sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+                float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                float32x4_t d = a_d * b_d;
+
+                sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+            }
+
+            vst1q_f32(res_ptr + x * 4, sumf);
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4];
+        int sumi;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi = 0;
+                        for (int i = 0; i < blocklen; ++i) {
+                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                        }
+                        sumf[j] += sumi * WSP_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * WSP_GGML_FP16_TO_FP32(a_ptr[l].d);
+                    }
+                }
+            }
+            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
+static void wsp_ggml_gemm_q4_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
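
The new wsp_ggml_gemv_iq4_nl_4x4_q8_0 kernel above relies on the kvalues_iq4nl table: each 4-bit code selects one of 16 non-uniform signed levels, which is then scaled by the block's fp16 delta (vqtbl1q_s8 performs the lookup in the NEON path, and the scalar fallback indexes the array directly). A minimal scalar sketch of that lookup-and-scale step, for illustration only and not taken from the package:

    // Dequantize one packed byte (two 4-bit iq4_nl codes) using the lookup table.
    #include <cstdint>

    static const int8_t kvalues_iq4nl[16] = {
        -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
    };

    // `d` is the block scale, already converted from fp16 to float.
    static inline void dequant_iq4_nl_pair(uint8_t packed, float d, float *lo, float *hi) {
        *lo = d * (float) kvalues_iq4nl[packed & 0x0F];  // low nibble
        *hi = d * (float) kvalues_iq4nl[packed >> 4];    // high nibble
    }
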
@@ -1017,7 +1111,7 @@ void wsp_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const voi
     UNUSED(blocklen);
 
 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    if (wsp_ggml_cpu_has_neon()) {
+    if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_dotprod()) {
         const void * b_ptr = vx;
         const void * a_ptr = vy;
         float * res_ptr = s;
@@ -1512,7 +1606,7 @@ void wsp_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const voi
     }
 }
 
-void wsp_ggml_gemm_q4_0_4x8_q8_0(int n, float *
+static void wsp_ggml_gemm_q4_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -1966,7 +2060,7 @@ void wsp_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const voi
     }
 }
 
-void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float *
+static void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 8;
@@ -2486,31 +2580,31 @@ void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const voi
             const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
 
             // Shuffle pattern one - right side input
-            const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
-            const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+            const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+            const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
 
-            const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
-            const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+            const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+            const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
 
-            const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
-            const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+            const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+            const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
 
-            const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
-            const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+            const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+            const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
 
             // Shuffle pattern two - right side input
 
-            const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
-            const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+            const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+            const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
 
-            const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
-            const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+            const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+            const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
 
-            const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
-            const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+            const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+            const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
 
-            const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
-            const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+            const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+            const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
 
             // Scale values - Load the weight scale values of two block_q4_0x8
             const __m512 col_scale_f32 = WSP_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
@@ -2544,31 +2638,31 @@ void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const voi
 
                 // Shuffle pattern one - left side input
 
-                const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-                const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+                const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
 
-                const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-                const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+                const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
 
-                const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-                const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+                const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
 
-                const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-                const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+                const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
 
                 // Shuffle pattern two - left side input
 
-                const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-                const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+                const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
 
-                const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-                const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+                const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
 
-                const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-                const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+                const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
 
-                const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-                const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+                const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
 
                 // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
                 // Resembles MMLAs into 2x2 matrices in ARM Version
@@ -2597,10 +2691,10 @@ void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const voi
 
 
                 // Straighten out to make 4 row vectors
-                __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
-                __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
-                __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
-                __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+                __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+                __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+                __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+                __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
 
                 // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
                 const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
@@ -2679,31 +2773,31 @@ void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const voi
             const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
 
             // Shuffle pattern one - right side input
-            const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
-            const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+            const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+            const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
 
-            const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
-            const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+            const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+            const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
 
-            const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
-            const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+            const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+            const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
 
-            const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
-            const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+            const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+            const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
 
             // Shuffle pattern two - right side input
 
-            const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
-            const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+            const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+            const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
 
-            const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
-            const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+            const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+            const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
 
-            const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
-            const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+            const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+            const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
 
-            const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
-            const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+            const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+            const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
 
 
             // Scale values - Load the weight scale values of two block_q4_0x8
@@ -2735,31 +2829,31 @@ void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const voi
 
             // Shuffle pattern one - left side input
 
-            const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-            const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+            const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+            const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
 
-            const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-            const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+            const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+            const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
 
-            const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-            const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+            const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+            const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
 
-            const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-            const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+            const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+            const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
 
             // Shuffle pattern two - left side input
 
-            const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-            const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+            const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+            const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
 
-            const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-            const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+            const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+            const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
 
-            const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-            const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+            const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+            const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
 
-            const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-            const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+            const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31)
|
|
2856
|
+
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
2763
2857
|
|
|
2764
2858
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
2765
2859
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
@@ -2788,10 +2882,10 @@ void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const voi


         // Straighten out to make 4 row vectors
-        __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
-        __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
-        __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
-        __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+        __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+        __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+        __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+        __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);

         // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
         const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
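The functional content of the hunks above is unchanged: the shuffle controls only gain an explicit (_MM_PERM_ENUM) cast, which stricter C++ front ends expect for the control argument of _mm512_shuffle_epi32 now that this source is compiled as C++. For reference, a minimal sketch of what the control bytes used here (160, 245, and 78) select within each 128-bit lane; sel is a hypothetical helper written for this note, not code from the package, and the casts do not change these values.

#include <cstdint>

// result element i takes source element ((imm >> (2*i)) & 3) of the same 128-bit lane
constexpr int sel(uint8_t imm, int i) { return (imm >> (2 * i)) & 3; }

static_assert(sel(160, 0) == 0 && sel(160, 1) == 0 && sel(160, 2) == 2 && sel(160, 3) == 2, "160 -> (0,0,2,2): duplicate even dwords");
static_assert(sel(245, 0) == 1 && sel(245, 1) == 1 && sel(245, 2) == 3 && sel(245, 3) == 3, "245 -> (1,1,3,3): duplicate odd dwords");
static_assert(sel( 78, 0) == 2 && sel( 78, 1) == 3 && sel( 78, 2) == 0 && sel( 78, 3) == 1, "78 -> (2,3,0,1): swap 64-bit halves");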
@@ -3386,7 +3480,117 @@ void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const voi
     }
 }

-
+static void wsp_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_dotprod()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+                float32x4_t sumf[4];
+                for (int m = 0; m < 4; m++) {
+                    sumf[m] = vdupq_n_f32(0);
+                }
+
+                for (int l = 0; l < nb; l++) {
+                    float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                    float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+
+                    int32x4_t sumi_0 = vdupq_n_s32(0);
+                    int32x4_t sumi_1 = vdupq_n_s32(0);
+                    int32x4_t sumi_2 = vdupq_n_s32(0);
+                    int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                    for (int k = 0; k < 4; k++) {
+                        int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                        int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                        int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                        int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                    }
+
+                    sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                    sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                    sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                    sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+                }
+
+                for (int m = 0; m < 4; m++) {
+                    vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+                }
+            }
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                             (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                                }
+                                sumf[m][j] += sumi * WSP_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * WSP_GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
 static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
     block_q4_0x4 out;

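The new wsp_ggml_gemm_iq4_nl_4x4_q8_0 kernel added above decodes each packed IQ4_NL byte into two signed values through the kvalues_iq4nl lookup table (vqtbl1q_s8 on the NEON path, plain array indexing in the scalar fallback) before taking integer dot products against the Q8_0 activations. A small stand-alone sketch of that nibble decode follows; the kvalues table here is an illustrative stand-in, not necessarily the package's exact kvalues_iq4nl contents.

#include <cstdint>
#include <cstdio>

static const int8_t kvalues[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 // illustrative values
};

// Decode one packed byte into the two signed weights it contributes.
static void decode_iq4_nl_byte(uint8_t packed, int8_t & lo, int8_t & hi) {
    lo = kvalues[packed & 0x0F]; // low nibble -> first half of the block
    hi = kvalues[packed >> 4];   // high nibble -> second half of the block
}

int main() {
    int8_t lo, hi;
    decode_iq4_nl_byte(0x2B, lo, hi);     // indices 11 (low) and 2 (high)
    std::printf("lo=%d hi=%d\n", lo, hi); // prints "lo=38 hi=-83" with the table above
    return 0;
}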
@@ -3456,20 +3660,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
     return out;
 }

-static int repack_q4_0_to_q4_0_4_bl(struct wsp_ggml_tensor * t, int interleave_block, const void *
+static int repack_q4_0_to_q4_0_4_bl(struct wsp_ggml_tensor * t, int interleave_block, const void * WSP_GGML_RESTRICT data, size_t data_size) {
     WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_Q4_0);
     WSP_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    constexpr int nrows_interleaved = 4;

     block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
     const block_q4_0 * src = (const block_q4_0 *)data;
     block_q4_0 dst_tmp[4];
-    int nrow = t
-    int nrows_interleaved = 4;
+    int nrow = wsp_ggml_nrows(t);
     int nblocks = t->ne[0] / QK4_0;

     WSP_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));

-    if (
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
         return -1;
     }

@@ -3487,20 +3691,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct wsp_ggml_tensor * t, int interleave_b
     WSP_GGML_UNUSED(data_size);
 }

-static int repack_q4_0_to_q4_0_8_bl(struct wsp_ggml_tensor *t, int interleave_block, const void *
+static int repack_q4_0_to_q4_0_8_bl(struct wsp_ggml_tensor * t, int interleave_block, const void * WSP_GGML_RESTRICT data, size_t data_size) {
     WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_Q4_0);
     WSP_GGML_ASSERT(interleave_block == 8);
+    constexpr int nrows_interleaved = 8;

     block_q4_0x8 * dst = (block_q4_0x8*)t->data;
     const block_q4_0 * src = (const block_q4_0*) data;
     block_q4_0 dst_tmp[8];
-    int nrow = t
-    int nrows_interleaved = 8;
+    int nrow = wsp_ggml_nrows(t);
     int nblocks = t->ne[0] / QK4_0;

     WSP_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));

-    if (
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
         return -1;
     }

@@ -3518,43 +3722,526 @@ static int repack_q4_0_to_q4_0_8_bl(struct wsp_ggml_tensor *t, int interleave_bl
     WSP_GGML_UNUSED(data_size);
 }

-
-
-
-
-
+static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
     }

-
+    const int end = QK4_NL * 2 / blck_size_interleave;

-
-
-
-
-
-
-
-
-
+    // TODO: this branch seems wrong
+    //if (blck_size_interleave == 8) {
+    //    for (int i = 0; i < end; ++i) {
+    //        int src_id = i % 4;
+    //        int src_offset = (i / 4) * blck_size_interleave;
+    //        int dst_offset = i * blck_size_interleave;
+
+    //        // Using memcpy to avoid unaligned memory accesses
+    //        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+    //    }
+    //} else
+    if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        WSP_GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct wsp_ggml_tensor * t, int interleave_block, const void * WSP_GGML_RESTRICT data, size_t data_size) {
+    WSP_GGML_ASSERT(t->type == WSP_GGML_TYPE_IQ4_NL);
+    //WSP_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    WSP_GGML_ASSERT(interleave_block == 4);
+
+    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
+    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nl dst_tmp[4];
+    int nrow = wsp_ggml_nrows(t);
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    WSP_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    WSP_GGML_UNUSED(data_size);
+}
+
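repack_iq4_nl_to_iq4_nl_4_bl above gathers the same block index from four consecutive rows and hands that group to make_block_iq4_nlx4, so the gemv/gemm kernels can later load four output columns at once. A sketch of the same index math on plain integers, using hypothetical stand-in types rather than the package's block structs:

#include <cstdio>
#include <vector>

struct Block  { int payload; };    // stands in for block_iq4_nl
struct Block4 { int payload[4]; }; // stands in for block_iq4_nlx4

// Block x of row group g gathers block x from each of the 4 source rows g*4 .. g*4+3.
std::vector<Block4> interleave_rows(const std::vector<Block> & src, int nrow, int nblocks) {
    std::vector<Block4> dst;
    dst.reserve((nrow / 4) * nblocks);
    for (int r = 0; r < nrow; r += 4) {
        for (int x = 0; x < nblocks; x++) {
            Block4 out;
            for (int i = 0; i < 4; i++) {
                out.payload[i] = src[(r + i) * nblocks + x].payload;
            }
            dst.push_back(out);
        }
    }
    return dst;
}

int main() {
    // 4 rows x 2 blocks, numbered row*10 + block
    std::vector<Block> src = {{0},{1},{10},{11},{20},{21},{30},{31}};
    auto dst = interleave_rows(src, 4, 2);
    for (const auto & b : dst) {
        std::printf("[%d %d %d %d] ", b.payload[0], b.payload[1], b.payload[2], b.payload[3]);
    }
    std::printf("\n"); // prints: [0 10 20 30] [1 11 21 31]
    return 0;
}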
+namespace ggml::cpu::aarch64 {
+// repack
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+int repack(struct wsp_ggml_tensor *, const void *, size_t);
+
+// TODO: generalise.
+template <> int repack<block_q4_0, 4, 4>(struct wsp_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 4>(struct wsp_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 8>(struct wsp_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_iq4_nl, 4, 4>(struct wsp_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
+}
+
+// TODO: needs to be revisited
+//template <> int repack<block_iq4_nl, 8, 4>(struct wsp_ggml_tensor * t, const void * data, size_t data_size) {
+//    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
+//}
+
+// gemv
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemv(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    wsp_ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    wsp_ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    wsp_ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    wsp_ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+// gemm
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemm(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    wsp_ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    wsp_ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    wsp_ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    wsp_ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
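The repack/gemv/gemm function templates above are declared once and then explicitly specialized per (block type, interleave size, column count) triple; the tensor_traits class that follows forwards to whichever specialization matches its own template arguments. A minimal, self-contained sketch of that dispatch pattern with made-up names (kernel_a, block_a, and so on are not from the package):

#include <cstdio>

struct block_a {};
struct block_b {};

static void kernel_a() { std::printf("kernel for block_a, 4x4\n"); }
static void kernel_b() { std::printf("kernel for block_b, 8x8\n"); }

template <typename BLOCK, int INTER_SIZE, int NB_COLS> void gemm(); // primary: declared only

template <> void gemm<block_a, 4, 4>() { kernel_a(); }
template <> void gemm<block_b, 8, 8>() { kernel_b(); }

template <typename BLOCK, int INTER_SIZE, int NB_COLS>
struct traits {
    void run() const { gemm<BLOCK, INTER_SIZE, NB_COLS>(); } // one instantiation per combination
};

int main() {
    traits<block_a, 4, 4>{}.run(); // kernel for block_a, 4x4
    traits<block_b, 8, 8>{}.run(); // kernel for block_b, 8x8
    return 0;
}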
+class tensor_traits_base : public ggml::cpu::tensor_traits {
+  public:
+    virtual int repack(struct wsp_ggml_tensor * t, const void * data, size_t data_size) = 0;
+};
+
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+
+    bool work_size(int /* n_threads */, const struct wsp_ggml_tensor * op, size_t & size) override {
+        // not realy a WSP_GGML_TYPE_Q8_0 but same size.
+        switch (op->op) {
+            case WSP_GGML_OP_MUL_MAT:
+                size = wsp_ggml_row_size(WSP_GGML_TYPE_Q8_0, wsp_ggml_nelements(op->src[1]));
+                return true;
+            case WSP_GGML_OP_MUL_MAT_ID:
+                size = wsp_ggml_row_size(WSP_GGML_TYPE_Q8_0, wsp_ggml_nelements(op->src[1]));
+                size = WSP_GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
+                size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+                return true;
+            default:
+                // WSP_GGML_ABORT("fatal error");
                 break;
+        }
+        return false;
+    }
+
+    bool compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * op) override {
+        switch (op->op) {
+            case WSP_GGML_OP_MUL_MAT:
+                forward_mul_mat(params, op);
+                return true;
+            case WSP_GGML_OP_MUL_MAT_ID:
+                forward_mul_mat_id(params, op);
+                return true;
             default:
-                WSP_GGML_ABORT("
+                // WSP_GGML_ABORT("fatal error");
+                break;
+        }
+        return false;
     }
-    }

-
+    void forward_mul_mat(wsp_ggml_compute_params * params, wsp_ggml_tensor * op) {
+        const wsp_ggml_tensor * src0 = op->src[0];
+        const wsp_ggml_tensor * src1 = op->src[1];
+        wsp_ggml_tensor * dst = op;
+
+        WSP_GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        WSP_GGML_ASSERT(ne0 == ne01);
+        WSP_GGML_ASSERT(ne1 == ne11);
+        WSP_GGML_ASSERT(ne2 == ne12);
+        WSP_GGML_ASSERT(ne3 == ne13);
+
+        // dst cannot be transposed or permuted
+        WSP_GGML_ASSERT(nb0 == sizeof(float));
+        WSP_GGML_ASSERT(nb0 <= nb1);
+        WSP_GGML_ASSERT(nb1 <= nb2);
+        WSP_GGML_ASSERT(nb2 <= nb3);
+
+        WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+
+        WSP_GGML_ASSERT(wsp_ggml_n_dims(op->src[0]) == 2);
+        // WSP_GGML_ASSERT(wsp_ggml_n_dims(op->src[1]) == 2);
+
+        char * wdata = static_cast<char *>(params->wdata);
+        const size_t nbw1 = wsp_ggml_row_size(WSP_GGML_TYPE_Q8_0, ne10);
+
+        assert(params->wsize >= nbw1 * ne11);
+
+        const wsp_ggml_from_float_t from_float = wsp_ggml_get_type_traits_cpu(WSP_GGML_TYPE_Q8_0)->from_float;
+
+        int64_t i11_processed = 0;
+        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+            wsp_quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
+                                  INTER_SIZE);
+        }
+        i11_processed = ne11 - ne11 % 4;
+        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+        }
+
+        wsp_ggml_barrier(params->threadpool);
+
+        const void * src1_wdata = params->wdata;
+        const size_t src1_col_stride = wsp_ggml_row_size(WSP_GGML_TYPE_Q8_0, ne10);
+        int64_t src0_start = (ith * ne01) / nth;
+        int64_t src0_end = ((ith + 1) * ne01) / nth;
+        src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+        src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+        if (src0_start >= src0_end) {
+            return;
+        }
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
+                                                 (const char *) src0->data + src0_start * nb01,
+                                                 (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        }
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                                                 (const char *) src0->data + src0_start * nb01,
+                                                 (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                                                 src0_end - src0_start);
+        }
+    }
+
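In forward_mul_mat above, each thread takes a slice of the ne01 output columns and then rounds both ends of its half-open range up to a multiple of NB_COLS, so an interleaved group of columns is never split between two threads. A small arithmetic sketch with example sizes (the numbers are made up, not taken from the package):

#include <cstdint>
#include <cstdio>

// Round v up to the next multiple of nb_cols (same expression as in the diff).
static int64_t round_up(int64_t v, int64_t nb_cols) {
    return (v % nb_cols) ? v + nb_cols - (v % nb_cols) : v;
}

int main() {
    const int64_t ne01 = 96, nb_cols = 8, nth = 5; // example sizes
    for (int ith = 0; ith < nth; ith++) {
        int64_t start = round_up((ith * ne01) / nth, nb_cols);
        int64_t end   = round_up(((ith + 1) * ne01) / nth, nb_cols);
        if (start >= end) {
            std::printf("thread %d: no work\n", ith);
        } else {
            std::printf("thread %d: columns [%lld, %lld)\n", ith, (long long) start, (long long) end);
        }
    }
    // prints: [0, 24), [24, 40), [40, 64), [64, 80), [80, 96) -- disjoint, NB_COLS-aligned
    return 0;
}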
+    void forward_mul_mat_id(wsp_ggml_compute_params * params, wsp_ggml_tensor * op) {
+        const wsp_ggml_tensor * src0 = op->src[0];
+        const wsp_ggml_tensor * src1 = op->src[1];
+        const wsp_ggml_tensor * ids = op->src[2];
+        wsp_ggml_tensor * dst = op;
+
+        WSP_GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const wsp_ggml_from_float_t from_float = wsp_ggml_get_type_traits_cpu(WSP_GGML_TYPE_Q8_0)->from_float;
+
+        // we don't support permuted src0 or src1
+        WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(src0->type));
+        WSP_GGML_ASSERT(nb10 == wsp_ggml_type_size(src1->type));
+
+        // dst cannot be transposed or permuted
+        WSP_GGML_ASSERT(nb0 == sizeof(float));
+        WSP_GGML_ASSERT(nb0 <= nb1);
+        WSP_GGML_ASSERT(nb1 <= nb2);
+        WSP_GGML_ASSERT(nb2 <= nb3);
+
+        WSP_GGML_ASSERT(ne03 == 1);
+        WSP_GGML_ASSERT(ne13 == 1);
+        WSP_GGML_ASSERT(ne3 == 1);
+
+        WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+
+        // row groups
+        const int n_ids = ids->ne[0]; // n_expert_used
+        const int n_as = ne02;        // n_expert
+
+        const size_t nbw1 = wsp_ggml_row_size(WSP_GGML_TYPE_Q8_0, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        struct mmid_row_mapping {
+            int32_t i1;
+            int32_t i2;
+        };
+
+        WSP_GGML_ASSERT(params->wsize >= (WSP_GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
+                                          n_as * ne12 * sizeof(mmid_row_mapping)));
+
+        auto wdata = (char *) params->wdata;
+        auto wdata_src1_end = (char *) wdata + WSP_GGML_PAD(nbw3, sizeof(int64_t));
+        int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
+
+        // src1: float32 => block_q8_0
+        for (int64_t i12 = 0; i12 < ne12; ++i12) {
+            for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
+                           (void *) (wdata + i12 * nbw2 + i11 * nbw1),
+                           ne10);
+            }
+        }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
+
+        if (ith == 0) {
+            // initialize matrix_row_counts
+            memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
+
+            // group rows by src0 matrix
+            for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+                for (int32_t id = 0; id < n_ids; ++id) {
+                    const int32_t i02 =
+                        *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+
+                    WSP_GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                    MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
+                    matrix_row_counts[i02] += 1;
+                }
+            }
+        }
+
+        wsp_ggml_barrier(params->threadpool);
+
+        // compute each matrix multiplication in sequence
+        for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+            const int64_t cne1 = matrix_row_counts[cur_a];
+
+            if (cne1 == 0) {
+                continue;
+            }
+
+            auto src0_cur = (const char *) src0->data + cur_a*nb02;
+
+            //const int64_t nr0 = ne01; // src0 rows
+            const int64_t nr1 = cne1; // src1 rows
+
+            int64_t src0_cur_start = (ith * ne01) / nth;
+            int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
+            src0_cur_start =
+                (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+            src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+
+            if (src0_cur_start >= src0_cur_end) return;
+
+            for (int ir1 = 0; ir1 < nr1; ir1++) {
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
+                const int id = row_mapping.i1; // selected expert index
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = row_mapping.i2; // row index in src1
+
+                const int64_t i1 = id;  // selected expert index
+                const int64_t i2 = i12; // row
+
+                auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+
+                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
+                    ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
+                    ne01, src0_cur + src0_cur_start * nb01,
+                    src1_col, 1, src0_cur_end - src0_cur_start);
+            }
+        }
+#undef MMID_MATRIX_ROW
+    }
+
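forward_mul_mat_id above first walks the ids tensor and buckets every (expert slot, token) pair under the expert it routes to, then runs one gemv per recorded pair. A compact sketch of that grouping step with made-up shapes and routing (none of these values come from the package):

#include <cstdio>
#include <vector>

struct row_mapping { int i1; int i2; }; // (expert slot within the token, token index)

int main() {
    const int n_expert = 4, n_expert_used = 2, n_tokens = 3;
    // ids[token][slot] = expert chosen for that slot (made-up routing table)
    const int ids[n_tokens][n_expert_used] = { {0, 2}, {2, 3}, {0, 3} };

    std::vector<std::vector<row_mapping>> rows(n_expert); // stands in for matrix_rows + matrix_row_counts
    for (int token = 0; token < n_tokens; token++) {
        for (int slot = 0; slot < n_expert_used; slot++) {
            rows[ids[token][slot]].push_back({slot, token});
        }
    }

    for (int e = 0; e < n_expert; e++) {
        std::printf("expert %d:", e);
        for (const auto & r : rows[e]) {
            std::printf(" (slot=%d, token=%d)", r.i1, r.i2); // one gemv per pair in the real code
        }
        std::printf("\n");
    }
    // expert 0: (0,0) (0,2); expert 1: none; expert 2: (1,0) (0,1); expert 3: (1,1) (1,2)
    return 0;
}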
+    int repack(struct wsp_ggml_tensor * t, const void * data, size_t data_size) override {
+        WSP_GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, wsp_ggml_type_name(t->type),
+                           (int) NB_COLS, (int) INTER_SIZE);
+        return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+    }
+};
+
+// instance for Q4
+static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
+static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
+static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
+
+// instance for IQ4
+static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
+
+}  // namespace ggml::cpu::aarch64
+
+static const ggml::cpu::tensor_traits * wsp_ggml_aarch64_get_optimal_repack_type(const struct wsp_ggml_tensor * cur) {
     if (cur->type == WSP_GGML_TYPE_Q4_0) {
-
-
-
+        if (wsp_ggml_cpu_has_avx2() || (wsp_ggml_cpu_has_sve() && wsp_ggml_cpu_has_matmul_int8() && wsp_ggml_cpu_get_sve_cnt() == QK8_0)) {
+            if (cur->ne[1] % 8 == 0) {
+                return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
+            }
         }
         if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_matmul_int8()) {
-
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
+            }
         }
-        if (wsp_ggml_cpu_has_neon()) {
-
+        if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
+            }
+        }
+    } else if (cur->type == WSP_GGML_TYPE_IQ4_NL) {
+        if (wsp_ggml_cpu_has_neon() && wsp_ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
+            }
         }
     }

-    return
+    return nullptr;
+}
+
+static void wsp_ggml_backend_cpu_aarch64_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
+    tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(wsp_ggml_aarch64_get_optimal_repack_type(tensor));
+
+    WSP_GGML_UNUSED(buffer);
+}
+
+static void wsp_ggml_backend_cpu_aarch64_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor,
+                                                           const void * data, size_t offset, size_t size) {
+    WSP_GGML_ASSERT(offset == 0);
+    WSP_GGML_ASSERT(size == wsp_ggml_nbytes(tensor));
+
+    auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
+    auto OK = tensor_traits->repack(tensor, data, size);
+
+    WSP_GGML_ASSERT(OK == 0);
+    WSP_GGML_UNUSED(buffer);
+}
+
+static const char * wsp_ggml_backend_cpu_aarch64_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
+    return "CPU_AARCH64";
+
+    WSP_GGML_UNUSED(buft);
+}
+
+static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
+    wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_cpu_buffer_type(), size);
+
+    if (buffer == nullptr) {
+        return nullptr;
+    }
+
+    buffer->buft = buft;
+    buffer->iface.init_tensor = wsp_ggml_backend_cpu_aarch64_buffer_init_tensor;
+    buffer->iface.set_tensor = wsp_ggml_backend_cpu_aarch64_buffer_set_tensor;
+    buffer->iface.get_tensor = nullptr;
+    buffer->iface.cpy_tensor = nullptr;
+    return buffer;
+}
+
+static size_t wsp_ggml_backend_cpu_aarch64_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    WSP_GGML_UNUSED(buft);
+}
+
+namespace ggml::cpu::aarch64 {
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+    bool supports_op(wsp_ggml_backend_dev_t, const struct wsp_ggml_tensor * op) override {
+        if (    op->op == WSP_GGML_OP_MUL_MAT &&
+                op->src[0]->buffer &&
+                (wsp_ggml_n_dims(op->src[0]) == 2) &&
+                op->src[0]->buffer->buft == wsp_ggml_backend_cpu_aarch64_buffer_type() &&
+                wsp_ggml_aarch64_get_optimal_repack_type(op->src[0])
+                ) {
+            if (op->src[1]->buffer && !wsp_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[1]->type == WSP_GGML_TYPE_F32) {
+                return true;
+            }
+            //if (op->src[1]->type == WSP_GGML_TYPE_Q8_0) {
+            //    return true;
+            //}
+            // may be possible if Q8_0 packed...
+        } else if (op->op == WSP_GGML_OP_MUL_MAT_ID
+                && op->src[0]->buffer
+                && (wsp_ggml_n_dims(op->src[0]) == 3)
+                && op->src[0]->buffer->buft == wsp_ggml_backend_cpu_aarch64_buffer_type()
+                && wsp_ggml_aarch64_get_optimal_repack_type(op->src[0])
+                ) {
+            if (op->src[1]->buffer && !wsp_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[1]->type == WSP_GGML_TYPE_F32) {
+                return true;
+            }
+            //if (op->src[1]->type == WSP_GGML_TYPE_Q8_0) {
+            //    return true;
+            //}
+        }
+        return false;
+    }
+
+    ggml::cpu::tensor_traits * get_tensor_traits(const struct wsp_ggml_tensor * op) override {
+        if (op->op == WSP_GGML_OP_MUL_MAT || op->op == WSP_GGML_OP_MUL_MAT_ID) {
+            if (op->src[0]->buffer && op->src[0]->buffer->buft == wsp_ggml_backend_cpu_aarch64_buffer_type()) {
+                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+            }
+        }
+        return nullptr;
+    }
+};
+}  // namespace ggml::cpu::aarch64
+
+wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_aarch64_buffer_type(void) {
+    static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_cpu_buffer_type_aarch64 = {
+        /* .iface    = */ {
+            /* .get_name         = */ wsp_ggml_backend_cpu_aarch64_buffer_type_get_name,
+            /* .alloc_buffer     = */ wsp_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ wsp_ggml_backend_cpu_aarch64_buffer_type_get_alignment,
+            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ nullptr, // defaults to wsp_ggml_nbytes
+            /* .is_host          = */ nullptr,
+        },
+        /* .device  = */ wsp_ggml_backend_reg_dev_get(wsp_ggml_backend_cpu_reg(), 0),
+        /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
+    };
+
+    return &wsp_ggml_backend_cpu_buffer_type_aarch64;
 }