llama_cpp 0.0.3 → 0.0.5
- checksums.yaml +4 -4
- data/CHANGELOG.md +36 -0
- data/README.md +5 -4
- data/ext/llama_cpp/extconf.rb +38 -0
- data/ext/llama_cpp/llama_cpp.cpp +118 -2
- data/ext/llama_cpp/src/ggml.c +1740 -658
- data/ext/llama_cpp/src/ggml.h +84 -16
- data/ext/llama_cpp/src/llama.cpp +1108 -756
- data/ext/llama_cpp/src/llama.h +37 -1
- data/ext/llama_cpp/src/llama_util.h +396 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +3 -3
- data/sig/llama_cpp.rbs +6 -0
- metadata +3 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -1,4 +1,4 @@
-// Defines CLOCK_MONOTONIC
+// Defines CLOCK_MONOTONIC on Linux
 #define _GNU_SOURCE
 
 #include "ggml.h"
@@ -26,14 +26,9 @@
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 
-#if defined
+#if defined(_WIN32)
 
-#if !defined(__MINGW32__)
-#include <Windows.h>
-#else
-// ref: https://github.com/ggerganov/whisper.cpp/issues/168
 #include <windows.h>
-#endif
 
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
@@ -55,6 +50,7 @@ typedef HANDLE pthread_t;
 
 typedef DWORD thread_ret_t;
 static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+    (void) unused;
     HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
     if (handle == NULL)
     {
@@ -66,6 +62,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
 }
 
 static int pthread_join(pthread_t thread, void* unused) {
+    (void) unused;
     return (int) WaitForSingleObject(thread, INFINITE);
 }
 
@@ -97,17 +94,6 @@ typedef void* thread_ret_t;
 #define static_assert(cond, msg) _Static_assert(cond, msg)
 #endif
 
-#define GGML_MLOCK_SUPPORT 0
-
-#ifdef __has_include
-#if __has_include(<sys/mman.h>)
-#undef GGML_MLOCK_SUPPORT
-#define GGML_MLOCK_SUPPORT 1
-#include <sys/mman.h>
-#endif
-#endif
-
-
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
@@ -128,6 +114,23 @@ typedef void* thread_ret_t;
 #define GGML_MEM_ALIGN 16
 #endif
 
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#define GGML_ALIGNED_MALLOC(size)  _aligned_malloc(size, GGML_MEM_ALIGN)
+#define GGML_ALIGNED_FREE(ptr)     _aligned_free(ptr)
+#else
+inline static void* ggml_aligned_malloc(size_t size) {
+    void* aligned_memory = NULL;
+    int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+    if (result != 0) {
+        // Handle allocation failure
+        return NULL;
+    }
+    return aligned_memory;
+}
+#define GGML_ALIGNED_MALLOC(size)  ggml_aligned_malloc(size)
+#define GGML_ALIGNED_FREE(ptr)     free(ptr)
+#endif
+
 #define UNUSED(x) (void)(x)
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
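For context, the new GGML_ALIGNED_MALLOC / GGML_ALIGNED_FREE pair above wraps _aligned_malloc/_aligned_free on MSVC and MinGW, and posix_memalign/free everywhere else. A minimal standalone sketch of the non-Windows path, assuming a POSIX system (the function name and constant below are illustrative, not part of the diff):

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define MEM_ALIGN 16

// Allocate `size` bytes aligned to MEM_ALIGN with posix_memalign; return NULL on failure.
static void *aligned_malloc_sketch(size_t size) {
    void *ptr = NULL;
    if (posix_memalign(&ptr, MEM_ALIGN, size) != 0) {
        return NULL;
    }
    return ptr;
}

int main(void) {
    float *buf = aligned_malloc_sketch(1024 * sizeof(float));
    if (buf != NULL) {
        printf("16-byte aligned: %s\n", ((uintptr_t)buf % MEM_ALIGN) == 0 ? "yes" : "no");
        free(buf); // memory from posix_memalign is released with plain free()
    }
    return 0;
}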
@@ -242,12 +245,12 @@ static inline float fp32_from_bits(uint32_t w) {
 }
 
 static inline uint32_t fp32_to_bits(float f) {
-
-
-
-
-
-
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
 }
 
 static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
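The rewritten fp32_to_bits above reinterprets a float's bit pattern through a union. An equivalent pattern that is also common in C (a sketch, not taken from the diff) copies the bytes with memcpy, which compilers fold into a single register move:

#include <stdint.h>
#include <string.h>

// Reinterpret the bits of a float as a 32-bit unsigned integer via memcpy;
// this avoids the strict-aliasing questions a pointer cast would raise.
static inline uint32_t fp32_to_bits_memcpy(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    return bits;
}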
@@ -424,8 +427,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#define QK 32
-
 // AVX routines provided by GH user Const-me
 // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
 #if __AVX2__ || __AVX512F__
@@ -497,37 +498,113 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 }
 #endif
 
-
-
-
+#if __ARM_NEON
+
+#if !defined(__aarch64__)
+
+inline static uint16_t vaddvq_u8(uint8x16_t v) {
+    return
+        (uint16_t)vgetq_lane_u8(v, 0)  + (uint16_t)vgetq_lane_u8(v, 1)  +
+        (uint16_t)vgetq_lane_u8(v, 2)  + (uint16_t)vgetq_lane_u8(v, 3)  +
+        (uint16_t)vgetq_lane_u8(v, 4)  + (uint16_t)vgetq_lane_u8(v, 5)  +
+        (uint16_t)vgetq_lane_u8(v, 6)  + (uint16_t)vgetq_lane_u8(v, 7)  +
+        (uint16_t)vgetq_lane_u8(v, 8)  + (uint16_t)vgetq_lane_u8(v, 9)  +
+        (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
+        (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
+        (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
+}
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+    return
+        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static uint32_t vaddvq_u16(uint16x8_t v) {
+    return
+        (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
+        (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
+        (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
+        (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+float vminvq_f32(float32x4_t v) {
+    return
+        MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+    return vget_low_s8(vcombine_s8(a, b));
+}
+
+int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+    return vget_high_s8(vcombine_s8(a, b));
+}
+
+uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    return vget_low_u8(vcombine_u8(a, b));
+}
+
+uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    return vget_high_u8(vcombine_u8(a, b));
+}
+
+#endif
+#endif
+
+
+#define QK4_0 32
 typedef struct {
-    float d;
-    uint8_t qs[QK / 2];
+    float   d;             // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) +
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
-
-// blocks of QK elements
-// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+#define QK4_1 32
 typedef struct {
-    float d;
-    float m;
-    uint8_t qs[QK / 2];
+    float   d;             // delta
+    float   m;             // min
+    uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 +
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    float  d;         // delta
+    int8_t qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_0/2];
 
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+        for (int l = 0; l < QK4_0; l++) {
+            const float v = x[i*QK4_0 + l];
             amax = MAX(amax, fabsf(v));
         }
 
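The QK4_0/QK4_1/QK8_0 constants above replace the old single QK constant, and block_q8_0 is a new 8-bit block type. As a rough picture of what one q4_0 block stores, here is a per-block form of the reference quantizer shown above (a simplified standalone sketch for illustration, not the ggml code itself):

#include <math.h>
#include <stdint.h>

#define QK4_0 32

typedef struct {
    float   d;              // delta (scale)
    uint8_t qs[QK4_0 / 2];  // 32 4-bit quants packed two per byte
} block_q4_0;

// Quantize one block of 32 floats: scale by the absolute max, shift into [0..15],
// and pack two nibbles per byte.
static void quantize_block_q4_0(const float x[QK4_0], block_q4_0 *y) {
    float amax = 0.0f;
    for (int l = 0; l < QK4_0; l++) {
        amax = fmaxf(amax, fabsf(x[l]));
    }
    const float d  = amax / ((1 << 3) - 1);  // map amax to the largest 4-bit magnitude, 7
    const float id = d ? 1.0f / d : 0.0f;
    y->d = d;
    for (int l = 0; l < QK4_0; l += 2) {
        const uint8_t vi0 = (uint8_t)((int8_t)roundf(x[l + 0] * id) + 8);
        const uint8_t vi1 = (uint8_t)((int8_t)roundf(x[l + 1] * id) + 8);
        y->qs[l / 2] = vi0 | (vi1 << 4);
    }
}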
@@ -536,9 +613,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 
         y[i].d = d;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = x[i*QK + l + 0]*id;
-            const float v1 = x[i*QK + l + 1]*id;
+        for (int l = 0; l < QK4_0; l += 2) {
+            const float v0 = x[i*QK4_0 + l + 0]*id;
+            const float v1 = x[i*QK4_0 + l + 1]*id;
 
             const uint8_t vi0 = (int8_t)roundf(v0) + 8;
             const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -554,8 +631,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 }
 
 static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     block_q4_0 * restrict y = vy;
 
@@ -610,10 +687,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
         for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
         for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
 
-
-        const float amax = MAX(
-            MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
-            MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
+        const float amax = vmaxvq_f32(amaxv[0]);
 
         const float d = amax / ((1 << 3) - 1);
         const float id = d ? 1.0f/d : 0.0f;
@@ -808,19 +882,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 }
 
 static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_1/2];
 
     for (int i = 0; i < nb; i++) {
         float min = FLT_MAX;
         float max = -FLT_MAX;
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+        for (int l = 0; l < QK4_1; l++) {
+            const float v = x[i*QK4_1 + l];
             if (v < min) min = v;
             if (v > max) max = v;
         }
@@ -831,9 +905,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
         y[i].d = d;
         y[i].m = min;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = (x[i*QK + l + 0] - min)*id;
-            const float v1 = (x[i*QK + l + 1] - min)*id;
+        for (int l = 0; l < QK4_1; l += 2) {
+            const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
 
             const uint8_t vi0 = roundf(v0);
             const uint8_t vi1 = roundf(v1);
@@ -849,9 +923,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
 }
 
 static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
+    assert(k % QK4_1 == 0);
 
-    const int nb = k / QK;
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
@@ -935,7 +1009,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
         float32x4_t minv[8];
         float32x4_t maxv[8];
 
-        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
+        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
 
         for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
         for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -958,7 +1032,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 
         for (int l = 0; l < 8; l++) {
             const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
-            const int32x4_t vi = vcvtq_s32_f32(v);
+            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
+            const int32x4_t vi = vcvtq_s32_f32(vf);
 
             y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
             y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
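The change just above adds 0.5f before vcvtq_s32_f32 in quantize_row_q4_1. vcvtq_s32_f32 truncates toward zero, and the values being converted, (x - min)*id, are non-negative, so adding 0.5 first yields round-to-nearest. A tiny scalar illustration of the same idea (not from the diff):

#include <stdio.h>

int main(void) {
    const float v = 6.7f;
    printf("truncate: %d\n", (int)v);          // 6 - what a plain float-to-int conversion gives
    printf("round:    %d\n", (int)(v + 0.5f)); // 7 - round to nearest for non-negative values
    return 0;
}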
@@ -970,9 +1045,160 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 #endif
 }
 
+// reference implementation for deterministic creation of model files
+static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK8_0; l++) {
+            const float v = x[i*QK8_0 + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < QK8_0; ++l) {
+            const float v = x[i*QK8_0 + l]*id;
+            y[i].qs[l] = roundf(v);
+        }
+    }
+}
+
+static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#else
+    // scalar
+    quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     const block_q4_0 * restrict x = vx;
 
@@ -983,7 +1209,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 32) {
+        for (int l = 0; l < QK4_0; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
             __m256i vx8 = bytesFromNibbles(pp+l/2);
 
@@ -1005,7 +1231,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
             // Scale and store
             for (int j = 0; j < 4; j++) {
                 const __m256 result = _mm256_mul_ps(vf[j], d_v);
-                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+                _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
             }
         }
     }
@@ -1015,7 +1241,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 16) {
+        for (int l = 0; l < QK4_0; l += 16) {
             // Load 16x4-bit integers into 8x8-bit integers
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1054,10 +1280,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
             const float32x4_t r3 = vmulq_f32(vf_3, vd);
 
             // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            vst1q_f32(y + i*QK4_0 + l +  0, r0);
+            vst1q_f32(y + i*QK4_0 + l +  4, r1);
+            vst1q_f32(y + i*QK4_0 + l +  8, r2);
+            vst1q_f32(y + i*QK4_0 + l + 12, r3);
         }
     }
 #else
@@ -1067,7 +1293,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 2) {
+        for (int l = 0; l < QK4_0; l += 2) {
             const uint8_t vi = pp[l/2];
 
             const int8_t vi0 = vi & 0xf;
@@ -1078,19 +1304,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
             //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i*QK4_0 + l + 0] = v0;
+            y[i*QK4_0 + l + 1] = v1;
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isnan(y[i*QK4_0 + l + 0]));
+            assert(!isnan(y[i*QK4_0 + l + 1]));
         }
     }
 #endif
 }
 
 static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     const block_q4_1 * restrict x = vx;
 
@@ -1101,7 +1327,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 32) {
+        for (int l = 0; l < QK4_1; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
             __m256i vx8 = bytesFromNibbles(pp+l/2);
 
@@ -1120,7 +1346,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             // Scale, add m and store
             for (int j = 0; j < 4; j++) {
                 const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
-                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+                _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
             }
         }
     }
@@ -1131,7 +1357,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 16) {
+        for (int l = 0; l < QK4_1; l += 16) {
             // Load 16x4-bit integers into 8x8-bit integers
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1162,10 +1388,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
 
             // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            vst1q_f32(y + i*QK4_1 + l +  0, r0);
+            vst1q_f32(y + i*QK4_1 + l +  4, r1);
+            vst1q_f32(y + i*QK4_1 + l +  8, r2);
+            vst1q_f32(y + i*QK4_1 + l + 12, r3);
         }
     }
 #else
@@ -1175,7 +1401,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 2) {
+        for (int l = 0; l < QK4_1; l += 2) {
             const uint8_t vi = pp[l/2];
 
             const int8_t vi0 = vi & 0xf;
@@ -1184,16 +1410,44 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const float v0 = vi0*d + m;
             const float v1 = vi1*d + m;
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i*QK4_1 + l + 0] = v0;
+            y[i*QK4_1 + l + 1] = v1;
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isnan(y[i*QK4_1 + l + 0]));
+            assert(!isnan(y[i*QK4_1 + l + 1]));
         }
     }
 #endif
 }
 
+static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .quantize_row_q_dot       = quantize_row_q4_1,
+        .vec_dot_q                = ggml_vec_dot_q4_1,
+    },
+    // TODO: GGML_TYPE_Q8_0
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
+
 //
 // simd mappings
 //
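quantize_fns above is a per-type table of function pointers filled in with C99 designated initializers and indexed by the ggml type id. A minimal sketch of the same dispatch pattern (the type names, fields and operations below are simplified stand-ins, not the real ggml definitions):

#include <stdio.h>

typedef void (*row_op_t)(const float *x, void *y, int k);

// hypothetical row operations standing in for quantize_row_q4_0 / quantize_row_q4_1
static void op_a(const float *x, void *y, int k) { (void)x; (void)y; printf("op_a on %d values\n", k); }
static void op_b(const float *x, void *y, int k) { (void)x; (void)y; printf("op_b on %d values\n", k); }

enum my_type { MY_TYPE_A = 0, MY_TYPE_B = 1, MY_TYPE_COUNT };

typedef struct {
    row_op_t quantize_row;
} my_fns_t;

// table indexed with designated initializers, like quantize_fns in the diff
static const my_fns_t my_fns[MY_TYPE_COUNT] = {
    [MY_TYPE_A] = { .quantize_row = op_a },
    [MY_TYPE_B] = { .quantize_row = op_b },
};

int main(void) {
    my_fns[MY_TYPE_B].quantize_row(NULL, NULL, 32); // dispatch by type id
    return 0;
}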
@@ -1226,15 +1480,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
 #define GGML_F32x4_ADD vaddq_f32
 #define GGML_F32x4_MUL vmulq_f32
-#
-    #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#else
-    #define GGML_F32x4_REDUCE_ONE(x) \
-    (vgetq_lane_f32(x, 0) + \
-     vgetq_lane_f32(x, 1) + \
-     vgetq_lane_f32(x, 2) + \
-     vgetq_lane_f32(x, 3))
-#endif
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
 #define GGML_F32x4_REDUCE(res, x) \
 { \
     for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1758,34 +2004,188 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
-#if __AVX512F__ && QK == 32
-static inline 
+#if __AVX512F__ && QK4_0 == 32
+static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
+    // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |                                                          :. =_ () [] <> ()                Zz Yy|
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa                         |
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    //
+    // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
+    // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
+    // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
+    // Bytes 40..63 are masked when loading the data, so they are zeroed out.
+#ifdef __AVX512VBMI__
+    const __m512i byte_perm = _mm512_set_epi8(
+        39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
+        31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
+        19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
+        11, 10, 11, 10,  9,  8,  9,  8,  7,  6,  7,  6,  5,  4,  5,  4
+    );
+    const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
+    // After applying VPERMB, `permuted` looks like this:
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+#else
+    const __m512i word_perm = _mm512_set_epi16(
+        19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
+         9,  9,  8,  8,  7,  7,  6,  6,  5,  5,  4,  4,  3,  3,  2,  2
+    );
+    const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
+    // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
+    // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
+    // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
+#endif
+
+    // Shift every odd-numbered 16-bit group to the right by 4 bits.
+    const __mmask32 shift_mask = 0xaaaaaaaa;
+    const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
+    // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+
+    // Now we just need to zero out the higher nibble in each byte, and we're done.
+    const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
+    return _mm512_and_si512( low_nibble_mask, shifted );
+    // The final result looks like this:
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | :  =  .  _  (  [  )  ]  <  (  >  )  Z  Y  z  y  X  W  x  w  V  U  v  u  T  S  t  s  R  Q  r  q|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | P  O  p  o  N  M  n  m  L  K  l  k  J  I  j  i  H  G  h  g  F  E  f  e  D  C  d  c  B  A  b  a|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+}
+
+static inline __m512 dot_q4_0_twoblocks_avx512(
     __m512 acc,
     const block_q4_0 * restrict x,
     const block_q4_0 * restrict y,
     int i
 ) {
-    //
-
-
-
-
-
-
-
-
-
-
-    //
-
-
-    //
-
+    // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
+    // can potentially be unaddressable, so we make sure to mask them out before the load, even though
+    // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
+    // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
+    const __mmask8 load_mask = 0x1f;
+    const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
+    const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
+
+    // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // blocks_0_float
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // |    |    |    |    |    |    | xx | xx | xx | xx | B  | xx | xx | xx | xx | A  |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // blocks_1_float
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // |    |    |    |    |    |    | xx | xx | xx | xx | D  | xx | xx | xx | xx | C  |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
+    const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
+    // We absolutely shouldn't touch the floats marked with `xx`: they contain some
+    // random data, which might very well underflow. At least on Intel, this leads
+    // to a huge penalty that can't be ignored (easily 100x or more) unless you
+    // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
+    // (and ggml can't assume that you do)...
+    const __mmask16 scale_mul_mask = 0x21;
+#ifdef __clang__
+    // ...however, clang decides to optimize the multiplication mask away:
+    // https://godbolt.org/z/P8PqdsfvW
+    // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
+    __m512i scales;
+    __asm__(
+        "vmulps %1, %2, %0%{%3%}"
+        : "=v" ( scales )
+        : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
+    );
+#else
+    const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
+#endif
+    const __m512i scale_perm = _mm512_set_epi32(
+        5, 5, 5, 5, 5, 5, 5, 5,
+        0, 0, 0, 0, 0, 0, 0, 0
+    );
+    const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
+    // After VMULPS and VPERMPS, `permuted_scales` looks like this:
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+
+    const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
+    const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
+
+    // Now we want to compute dot products of 4-element byte vectors and store them in
+    // 32-bit integers. That is (only one 4-element vector is shown for clarity):
+    // +----+----+----+----+
+    // ... | 03 | 02 | 01 | 00 |
+    // +----+----+----+----+
+    // bytes_0
+    // +----+----+----+----+
+    // ... | D  | C  | B  | A  |
+    // +----+----+----+----+
+    // bytes_1
+    // +----+----+----+----+
+    // ... | H  | G  | F  | E  |
+    // +----+----+----+----+
+    // final_res_int
+    // +----+----+----+----+
+    // ... |  A*E+B*F+C*G+D*H  |
+    // +----+----+----+----+
+    const __m512i plus_8 = _mm512_set1_epi8( 8 );
+    const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
+
+#ifdef __AVX512VNNI__
+    // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
+    // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
+    // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
+    // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
+    // which means we only need 2 instructions.
+    const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
+    const __m512i minus_8 = _mm512_set1_epi8( -8 );
+    const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
+    const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
+#else
+    // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
+    // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
+    // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
+    // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
+    const __m512i one = _mm512_set1_epi16( 1 );
+    const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
+    const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
+    const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
+    const __m512i final_res_int = _mm512_madd_epi16( diff, one );
+#endif
 
-    //
-    __m512
-
-    return _mm512_fmadd_ps( d, p, acc );
+    // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
+    const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
+    return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
 }
 #endif
 
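For intuition, dot_q4_0_twoblocks_avx512 above computes, per pair of q4_0 blocks, the integer dot product of the nibbles recentered by -8, scaled by the product of the two deltas. A scalar sketch of the same quantity for a single block pair (mirroring the scalar fallback that appears further down in this diff; the local type and constant are restated here so the snippet stands alone):

#include <stdint.h>

#define QK4_0 32

typedef struct {
    float   d;             // delta
    uint8_t qs[QK4_0 / 2]; // nibbles
} block_q4_0;

// Unpack nibbles, recenter by -8, accumulate the integer dot product,
// then scale by both block deltas.
static float dot_q4_0_pair_scalar(const block_q4_0 *x, const block_q4_0 *y) {
    int sumi = 0;
    for (int j = 0; j < QK4_0 / 2; j++) {
        const int x0 = (x->qs[j] & 0xf) - 8, x1 = (x->qs[j] >> 4) - 8;
        const int y0 = (y->qs[j] & 0xf) - 8, y1 = (y->qs[j] >> 4) - 8;
        sumi += x0 * y0 + x1 * y1;
    }
    return x->d * y->d * (float)sumi;
}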
@@ -1826,9 +2226,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 }
 
 static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+    const int nb = n / QK4_0;
 
-    assert(n % QK == 0);
+    assert(n % QK4_0 == 0);
     assert(nb % 2 == 0);
 
     const block_q4_0 * restrict x = vx;
@@ -1857,55 +2257,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         // 4-bit -> 8-bit
         const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
         const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
-
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
         const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
 
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
         const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
-
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
         const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
 
         // sub 8
         const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
         const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
-
         const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
         const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
 
         const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
         const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
-
         const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
         const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into 
+        // dot product into int32x4_t
         int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
         int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
 
         p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
         p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
 
-
-
-        sum0 += x0->d * y0->d * vaddvq_s32(p_0);
-        sum1 += x1->d * y1->d * vaddvq_s32(p_1);
+        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
 #else
-
-        sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
-#endif
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
 
         const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
         const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-
         const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
         const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
@@ -1918,14 +2306,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
         const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
 
-
-
-        sum0 += x0->d * y0->d * vaddvq_s16(p_0);
-        sum1 += x1->d * y1->d * vaddvq_s16(p_1);
-#else
-        sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
-        sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
-#endif
+        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
 #endif
     }
 
@@ -1935,25 +2317,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     __m512 acc0 = _mm512_setzero_ps();
     __m512 acc1 = _mm512_setzero_ps();
 
-    const int superblock_size =
+    const int superblock_size = 16;
+
     const int superblock_count = nb / superblock_size;
 
     for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
         int i = superblock_ix * superblock_size;
 
-        acc0 =
-        acc1 =
-        acc0 =
-        acc1 =
-        acc0 =
-        acc1 =
-        acc0 =
-        acc1 =
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
+        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
+        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
+        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
+        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
     }
 
     // Remainders
-    for (int i = superblock_count * superblock_size; i < nb;
-        acc0 =
+    for (int i = superblock_count * superblock_size; i < nb; i += 2) {
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
     }
 
     // Horizontal sum of all lanes of the accumulator
@@ -1962,7 +2345,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
-    /* Prepare the constants we will need during execution */
+    /* Prepare the constants we will need during execution */
     const __m256i lowMask = _mm256_set1_epi8( 0xF );
     const __m256i offset_8 = _mm256_set1_epi16( 8 );
 
@@ -1972,61 +2355,59 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 
     // Main loop
     for (int i = 0; i < nb; i+=UNROLL_COUNT) {
-
-        // This loop will be unrolled by the compiler
+        // This loop will be unrolled by the compiler
         for (int u=0;u<UNROLL_COUNT;u++) {
-            /* Compute combined scale for the block */
-            const __m256 scale = _mm256_mul_ps(
-                _mm256_broadcast_ss( &x[i+u].d ),
-                _mm256_broadcast_ss( &y[i+u].d ) );
-
-            /* get input from x
-               Input: 32 Nibbles (16 bytes) at *x[i+u]
-               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
+            /* Compute combined scale for the block */
+            const __m256 scale = _mm256_mul_ps(
+                _mm256_broadcast_ss( &x[i+u].d ),
+                _mm256_broadcast_ss( &y[i+u].d ) );
+
+            /* get input from x
+               Input: 32 Nibbles (16 bytes) at *x[i+u]
+               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
+
+            /* Load 16 bytes from memory */
+            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
+            /* Expand bytes into uint16_t values */
+            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
             /* Unpack values into individual bytes */
             __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
             const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
-            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
+            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
             /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
-            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
+            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
+            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
 
-            /* get input from y
-               Input: 32 Nibbles (16 bytes) at *y[i+u]
-               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
+            /* get input from y
+               Input: 32 Nibbles (16 bytes) at *y[i+u]
+               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
 
-            /* Load 16 bytes from memory */
-            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
+            /* Load 16 bytes from memory */
+            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
+            /* Expand bytes into uint16_t values */
+            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
             /* Unpack values into individual bytes */
-            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
-            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
-            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
+            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
+            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
+            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
             /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
-            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
+            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
+            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
 
-            /* Compute products of int16_t integers, add pairwise, store as int32_t */
-            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
-            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
+            /* Compute products of int16_t integers, add pairwise, store as int32_t */
+            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
+            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
 
-            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
-            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
+            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
+            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
 
-            /* Convert to vectore of 8 int32_t to 8 floats */
-            __m256 q = _mm256_cvtepi32_ps( xy_q );
+            /* Convert to vectore of 8 int32_t to 8 floats */
+            __m256 q = _mm256_cvtepi32_ps( xy_q );
 
-            /* Multiply q with scale and accumulate */
-            acc = _mm256_fmadd_ps( scale, q, acc );
+            /* Multiply q with scale and accumulate */
+            acc = _mm256_fmadd_ps( scale, q, acc );
         }
-
-    }
+    }
 
     // Return horizontal sum of the acc vector
     __m128 res = _mm256_extractf128_ps( acc, 1 );
@@ -2087,18 +2468,18 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     float sum1 = 0.0f;
 
     for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &
-        const block_q4_0 * restrict y0 = &
-        const block_q4_0 * restrict x1 = &
-        const block_q4_0 * restrict y1 = &
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict y0 = &y[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q4_0 * restrict y1 = &y[i + 1];
 
         const v128_t m4b = wasm_u8x16_splat(0xf);
         const v128_t s8b = wasm_i8x16_splat(0x8);
 
-        const v128_t v0_0 = wasm_v128_load(x0
-        const v128_t v0_1 = wasm_v128_load(y0
-        const v128_t v1_0 = wasm_v128_load(x1
-        const v128_t v1_1 = wasm_v128_load(y1
+        const v128_t v0_0 = wasm_v128_load(x0->qs);
+        const v128_t v0_1 = wasm_v128_load(y0->qs);
+        const v128_t v1_0 = wasm_v128_load(x1->qs);
+        const v128_t v1_1 = wasm_v128_load(y1->qs);
 
         // 4-bit -> 8-bit
         const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@@ -2170,18 +2551,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const uint8_t * restrict p0 = x[i].qs;
         const uint8_t * restrict p1 = y[i].qs;
 
-        for (int j = 0; j < QK/2; j++) {
+        int sumi = 0;
+        for (int j = 0; j < QK4_0/2; j++) {
             const uint8_t v0 = p0[j];
             const uint8_t v1 = p1[j];
 
-            const
-            const
+            const int i0 = (v0 & 0xf) - 8;
+            const int i1 = (v0 >> 4) - 8;
 
-            const
-            const
+            const int i2 = (v1 & 0xf) - 8;
+            const int i3 = (v1 >> 4) - 8;
 
-
+            sumi += i0*i2 + i1*i3;
         }
+        sumf += d0 * d1 * sumi;
     }
 #endif
 
@@ -2189,7 +2572,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 }
 
 static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+    const int nb = n / QK4_1;
 
     const block_q4_1 * restrict x = vx;
     const block_q4_1 * restrict y = vy;
@@ -2266,46 +2649,81 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|
2266
2649
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2267
2650
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2268
2651
|
|
2269
|
-
sumf = _mm_cvtss_f32( res ) + acc_offset *
|
2652
|
+
sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
|
2270
2653
|
#elif defined(__ARM_NEON)
|
2271
2654
|
float sum00 = 0.0f;
|
2272
2655
|
float sum01 = 0.0f;
|
2273
2656
|
float sum10 = 0.0f;
|
2274
2657
|
float sum11 = 0.0f;
|
2275
2658
|
|
2276
|
-
for (int i = 0; i < nb;
|
2659
|
+
for (int i = 0; i < nb; i += 2) {
|
2277
2660
|
const block_q4_1 * restrict x0 = &x[i + 0];
|
2278
2661
|
const block_q4_1 * restrict y0 = &y[i + 0];
|
2662
|
+
const block_q4_1 * restrict x1 = &x[i + 1];
|
2663
|
+
const block_q4_1 * restrict y1 = &y[i + 1];
|
2279
2664
|
|
2280
2665
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2281
2666
|
|
2282
2667
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2283
2668
|
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
|
2669
|
+
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2670
|
+
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
|
2284
2671
|
|
2285
|
-
//
|
2672
|
+
// 4-bit -> 8-bit
|
2286
2673
|
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
|
2287
2674
|
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
|
2288
|
-
|
2289
2675
|
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
|
2290
2676
|
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
|
2291
2677
|
|
2292
|
-
|
2678
|
+
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
|
2679
|
+
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
|
2680
|
+
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
|
2681
|
+
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
|
2682
|
+
|
2683
|
+
sum00 += x0->m*y0->m;
|
2684
|
+
sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
|
2685
|
+
sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
|
2686
|
+
|
2687
|
+
sum00 += x1->m*y1->m;
|
2688
|
+
sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
|
2689
|
+
sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
|
2690
|
+
|
2691
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2692
|
+
// dot product into int32x4_t
|
2693
|
+
uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
|
2694
|
+
uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
|
2695
|
+
|
2696
|
+
p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
|
2697
|
+
p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
|
2698
|
+
|
2699
|
+
sum11 += x0->d*y0->d*vaddvq_u32(p_0);
|
2700
|
+
sum11 += x1->d*y1->d*vaddvq_u32(p_1);
|
2701
|
+
#else
|
2293
2702
|
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
|
2294
2703
|
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
|
2295
|
-
|
2296
2704
|
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
|
2297
2705
|
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
|
2298
2706
|
|
2299
|
-
const uint16x8_t
|
2300
|
-
const uint16x8_t
|
2707
|
+
const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
|
2708
|
+
const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
|
2709
|
+
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
|
2710
|
+
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
|
2301
2711
|
|
2302
|
-
|
2303
|
-
|
2304
|
-
|
2305
|
-
|
2712
|
+
const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
|
2713
|
+
const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
|
2714
|
+
|
2715
|
+
const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
|
2716
|
+
const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
|
2717
|
+
|
2718
|
+
const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
|
2719
|
+
const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
|
2720
|
+
|
2721
|
+
sum11 += x0->d*y0->d*vaddvq_u16(p_0);
|
2722
|
+
sum11 += x1->d*y1->d*vaddvq_u16(p_1);
|
2723
|
+
#endif
|
2306
2724
|
}
|
2307
2725
|
|
2308
|
-
sumf =
|
2726
|
+
sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
|
2309
2727
|
#else
|
2310
2728
|
// scalar
|
2311
2729
|
for (int i = 0; i < nb; i++) {
|
@@ -2318,7 +2736,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|
2318
2736
|
const uint8_t * restrict p0 = x[i].qs;
|
2319
2737
|
const uint8_t * restrict p1 = y[i].qs;
|
2320
2738
|
|
2321
|
-
for (int j = 0; j <
|
2739
|
+
for (int j = 0; j < QK4_1/2; j++) {
|
2322
2740
|
const uint8_t v0 = p0[j];
|
2323
2741
|
const uint8_t v1 = p1[j];
|
2324
2742
|
|
@@ -2336,21 +2754,224 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|
2336
2754
|
*s = sumf;
|
2337
2755
|
}
|
2338
2756
|
|
2339
|
-
|
2340
|
-
|
2341
|
-
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
2342
|
-
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
2757
|
+
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2758
|
+
const int nb = n / QK8_0;
|
2343
2759
|
|
2344
|
-
|
2760
|
+
assert(n % QK8_0 == 0);
|
2761
|
+
assert(nb % 2 == 0);
|
2345
2762
|
|
2346
|
-
|
2347
|
-
|
2348
|
-
}
|
2763
|
+
const block_q4_0 * restrict x = vx;
|
2764
|
+
const block_q8_0 * restrict y = vy;
|
2349
2765
|
|
2350
|
-
|
2351
|
-
const int np = (n & ~(GGML_F16_STEP - 1));
|
2766
|
+
float sumf = 0.0;
|
2352
2767
|
|
2353
|
-
|
2768
|
+
#if defined(__ARM_NEON)
|
2769
|
+
float sum0 = 0.0f;
|
2770
|
+
float sum1 = 0.0f;
|
2771
|
+
|
2772
|
+
for (int i = 0; i < nb; i += 2) {
|
2773
|
+
const block_q4_0 * restrict x0 = &x[i + 0];
|
2774
|
+
const block_q4_0 * restrict x1 = &x[i + 1];
|
2775
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2776
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
2777
|
+
|
2778
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2779
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
2780
|
+
|
2781
|
+
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2782
|
+
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2783
|
+
|
2784
|
+
// 4-bit -> 8-bit
|
2785
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2786
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2787
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2788
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2789
|
+
|
2790
|
+
// sub 8
|
2791
|
+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
2792
|
+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
2793
|
+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2794
|
+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
2795
|
+
|
2796
|
+
// load y
|
2797
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2798
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2799
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2800
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2801
|
+
|
2802
|
+
// interleave
|
2803
|
+
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2804
|
+
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2805
|
+
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2806
|
+
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2807
|
+
|
2808
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2809
|
+
// dot product into int32x4_t
|
2810
|
+
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
|
2811
|
+
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
|
2812
|
+
|
2813
|
+
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
|
2814
|
+
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
|
2815
|
+
|
2816
|
+
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
|
2817
|
+
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
|
2818
|
+
#else
|
2819
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
2820
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
2821
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
|
2822
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
|
2823
|
+
|
2824
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
|
2825
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
|
2826
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
2827
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
2828
|
+
|
2829
|
+
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
|
2830
|
+
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
|
2831
|
+
|
2832
|
+
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
|
2833
|
+
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
|
2834
|
+
|
2835
|
+
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
|
2836
|
+
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
|
2837
|
+
|
2838
|
+
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
|
2839
|
+
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
|
2840
|
+
#endif
|
2841
|
+
}
|
2842
|
+
|
2843
|
+
sumf = sum0 + sum1;
|
2844
|
+
#elif defined(__AVX2__)
|
2845
|
+
// Initialize accumulator with zeros
|
2846
|
+
__m256 acc = _mm256_setzero_ps();
|
2847
|
+
|
2848
|
+
// Main loop
|
2849
|
+
for (int i = 0; i < nb; ++i) {
|
2850
|
+
/* Compute combined scale for the block */
|
2851
|
+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
2852
|
+
|
2853
|
+
__m256i bx = bytesFromNibbles(x[i].qs);
|
2854
|
+
|
2855
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2856
|
+
const __m256i off = _mm256_set1_epi8( 8 );
|
2857
|
+
bx = _mm256_sub_epi8( bx, off );
|
2858
|
+
|
2859
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2860
|
+
|
2861
|
+
// Get absolute values of x vectors
|
2862
|
+
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
2863
|
+
|
2864
|
+
// Sign the values of the y vectors
|
2865
|
+
const __m256i sy = _mm256_sign_epi8(by, bx);
|
2866
|
+
|
2867
|
+
// Perform multiplication and create 16-bit values
|
2868
|
+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
2869
|
+
|
2870
|
+
const __m256i ones = _mm256_set1_epi16(1);
|
2871
|
+
__m256i xy_q = _mm256_madd_epi16(ones, dot);
|
2872
|
+
|
2873
|
+
/* Convert the vector of 8 int32_t to 8 floats */
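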
|
2874
|
+
__m256 q = _mm256_cvtepi32_ps( xy_q );
|
2875
|
+
|
2876
|
+
/* Multiply q with scale and accumulate */
|
2877
|
+
acc = _mm256_fmadd_ps( d, q, acc );
|
2878
|
+
}
|
2879
|
+
|
2880
|
+
// Return horizontal sum of the acc vector
|
2881
|
+
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2882
|
+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2883
|
+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2884
|
+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2885
|
+
|
2886
|
+
sumf = _mm_cvtss_f32( res );
|
2887
|
+
#elif defined(__AVX__)
|
2888
|
+
// Initialize accumulator with zeros
|
2889
|
+
__m256 acc = _mm256_setzero_ps();
|
2890
|
+
|
2891
|
+
// Main loop
|
2892
|
+
for (int i = 0; i < nb; ++i) {
|
2893
|
+
// Compute combined scale for the block
|
2894
|
+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
2895
|
+
|
2896
|
+
__m128i i32[2];
|
2897
|
+
for (int j = 0; j < 2; ++j) {
|
2898
|
+
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
|
2899
|
+
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
|
2900
|
+
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
|
2901
|
+
|
2902
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2903
|
+
const __m128i off = _mm_set1_epi8( 8 );
|
2904
|
+
bx = _mm_sub_epi8( bx, off );
|
2905
|
+
|
2906
|
+
// Get absolute values of x vectors
|
2907
|
+
const __m128i ax = _mm_sign_epi8(bx, bx);
|
2908
|
+
|
2909
|
+
// Sign the values of the y vectors
|
2910
|
+
const __m128i sy = _mm_sign_epi8(by, bx);
|
2911
|
+
|
2912
|
+
// Perform multiplication and create 16-bit values
|
2913
|
+
const __m128i dot = _mm_maddubs_epi16(ax, sy);
|
2914
|
+
|
2915
|
+
const __m128i ones = _mm_set1_epi16(1);
|
2916
|
+
i32[j] = _mm_madd_epi16(ones, dot);
|
2917
|
+
}
|
2918
|
+
|
2919
|
+
// Convert int32_t to float
|
2920
|
+
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
|
2921
|
+
// Apply the scale, and accumulate
|
2922
|
+
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
|
2923
|
+
}
|
2924
|
+
|
2925
|
+
// Return horizontal sum of the acc vector
|
2926
|
+
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2927
|
+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2928
|
+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2929
|
+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2930
|
+
|
2931
|
+
sumf = _mm_cvtss_f32( res );
|
2932
|
+
#else
|
2933
|
+
// scalar
|
2934
|
+
for (int i = 0; i < nb; i++) {
|
2935
|
+
const float d0 = x[i].d;
|
2936
|
+
const float d1 = y[i].d;
|
2937
|
+
|
2938
|
+
const uint8_t * restrict p0 = x[i].qs;
|
2939
|
+
const int8_t * restrict p1 = y[i].qs;
|
2940
|
+
|
2941
|
+
int sumi = 0;
|
2942
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
2943
|
+
const uint8_t v0 = p0[j];
|
2944
|
+
|
2945
|
+
const int i0 = (int8_t) (v0 & 0xf) - 8;
|
2946
|
+
const int i1 = (int8_t) (v0 >> 4) - 8;
|
2947
|
+
|
2948
|
+
const int i2 = p1[2*j + 0];
|
2949
|
+
const int i3 = p1[2*j + 1];
|
2950
|
+
|
2951
|
+
sumi += i0*i2 + i1*i3;
|
2952
|
+
}
|
2953
|
+
sumf += d0*d1*sumi;
|
2954
|
+
}
|
2955
|
+
#endif
|
2956
|
+
|
2957
|
+
*s = sumf;
|
2958
|
+
}
|
2959
|
+
|
2960
|
+
// compute GGML_VEC_DOT_UNROLL dot products at once
|
2961
|
+
// xs - x row stride in bytes
|
2962
|
+
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
2963
|
+
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
2964
|
+
|
2965
|
+
ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
|
2966
|
+
|
2967
|
+
for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
|
2968
|
+
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
|
2969
|
+
}
|
2970
|
+
|
2971
|
+
#if defined(GGML_SIMD)
|
2972
|
+
const int np = (n & ~(GGML_F16_STEP - 1));
|
2973
|
+
|
2974
|
+
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
|
2354
2975
|
|
2355
2976
|
GGML_F16_VEC ax[GGML_F16_ARR];
|
2356
2977
|
GGML_F16_VEC ay[GGML_F16_ARR];
|
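In the AVX2 path of ggml_vec_dot_q4_0_q8_0 above, _mm256_maddubs_epi16 multiplies an unsigned operand by a signed one, so the code first takes |x| with _mm256_sign_epi8(bx, bx) and moves x's sign onto y with _mm256_sign_epi8(by, bx). The products are unchanged because x*y == |x| * sign(x)*y. A scalar check of that identity (the helper name is hypothetical):

    #include <assert.h>
    #include <stdlib.h>

    // Scalar equivalent of the sign/maddubs trick: multiply |x| by y carrying x's sign.
    static int signed_mul_via_abs(int x, int y) {
        int ax = abs(x);                               // like _mm256_sign_epi8(bx, bx)
        int sy = (x < 0) ? -y : (x == 0 ? 0 : y);      // like _mm256_sign_epi8(by, bx)
        return ax * sy;
    }

    int main(void) {
        for (int x = -8; x <= 7; ++x)
            for (int y = -128; y <= 127; ++y)
                assert(signed_mul_via_abs(x, y) == x * y);
        return 0;
    }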
@@ -2578,29 +3199,41 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
|
|
2578
3199
|
//
|
2579
3200
|
|
2580
3201
|
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
|
2581
|
-
|
2582
|
-
|
2583
|
-
|
2584
|
-
|
2585
|
-
|
2586
|
-
1,
|
2587
|
-
1,
|
3202
|
+
[GGML_TYPE_F32] = 1,
|
3203
|
+
[GGML_TYPE_F16] = 1,
|
3204
|
+
[GGML_TYPE_Q4_0] = QK4_0,
|
3205
|
+
[GGML_TYPE_Q4_1] = QK4_1,
|
3206
|
+
[GGML_TYPE_Q8_0] = QK8_0,
|
3207
|
+
[GGML_TYPE_I8] = 1,
|
3208
|
+
[GGML_TYPE_I16] = 1,
|
3209
|
+
[GGML_TYPE_I32] = 1,
|
2588
3210
|
};
|
2589
|
-
|
2590
|
-
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
|
3211
|
+
static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
|
2591
3212
|
|
2592
3213
|
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
|
2593
|
-
sizeof(
|
2594
|
-
sizeof(
|
2595
|
-
sizeof(
|
2596
|
-
sizeof(
|
2597
|
-
sizeof(
|
2598
|
-
sizeof(
|
2599
|
-
sizeof(
|
3214
|
+
[GGML_TYPE_F32] = sizeof(float),
|
3215
|
+
[GGML_TYPE_F16] = sizeof(ggml_fp16_t),
|
3216
|
+
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
|
3217
|
+
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
|
3218
|
+
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
|
3219
|
+
[GGML_TYPE_I8] = sizeof(int8_t),
|
3220
|
+
[GGML_TYPE_I16] = sizeof(int16_t),
|
3221
|
+
[GGML_TYPE_I32] = sizeof(int32_t),
|
2600
3222
|
};
|
2601
|
-
|
2602
|
-
|
2603
|
-
|
3223
|
+
static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
|
3224
|
+
|
3225
|
+
|
3226
|
+
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
3227
|
+
[GGML_TYPE_F32] = "f32",
|
3228
|
+
[GGML_TYPE_F16] = "f16",
|
3229
|
+
[GGML_TYPE_Q4_0] = "q4_0",
|
3230
|
+
[GGML_TYPE_Q4_1] = "q4_1",
|
3231
|
+
[GGML_TYPE_Q8_0] = "q8_0",
|
3232
|
+
[GGML_TYPE_I8] = "i8",
|
3233
|
+
[GGML_TYPE_I16] = "i16",
|
3234
|
+
[GGML_TYPE_I32] = "i32",
|
3235
|
+
};
|
3236
|
+
static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
|
2604
3237
|
|
2605
3238
|
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
2606
3239
|
"NONE",
|
@@ -2629,6 +3262,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|
2629
3262
|
|
2630
3263
|
"SCALE",
|
2631
3264
|
"CPY",
|
3265
|
+
"CONT",
|
2632
3266
|
"RESHAPE",
|
2633
3267
|
"VIEW",
|
2634
3268
|
"PERMUTE",
|
@@ -2642,9 +3276,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|
2642
3276
|
|
2643
3277
|
"FLASH_ATTN",
|
2644
3278
|
"FLASH_FF",
|
3279
|
+
|
3280
|
+
"MAP_UNARY",
|
3281
|
+
"MAP_BINARY",
|
2645
3282
|
};
|
2646
3283
|
|
2647
|
-
static_assert(GGML_OP_COUNT ==
|
3284
|
+
static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
|
2648
3285
|
|
2649
3286
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
2650
3287
|
"none",
|
@@ -2673,6 +3310,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
2673
3310
|
|
2674
3311
|
"x*v",
|
2675
3312
|
"x-\\>y",
|
3313
|
+
"cont(x)",
|
2676
3314
|
"reshape(x)",
|
2677
3315
|
"view(x)",
|
2678
3316
|
"permute(x)",
|
@@ -2686,24 +3324,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
2686
3324
|
|
2687
3325
|
"flash_attn(x)",
|
2688
3326
|
"flash_ff(x)",
|
2689
|
-
};
|
2690
|
-
|
2691
|
-
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
|
2692
|
-
|
2693
|
-
//
|
2694
|
-
// ggml object
|
2695
|
-
//
|
2696
|
-
|
2697
|
-
struct ggml_object {
|
2698
|
-
size_t offs;
|
2699
|
-
size_t size;
|
2700
3327
|
|
2701
|
-
|
2702
|
-
|
2703
|
-
char padding[8];
|
3328
|
+
"f(x)",
|
3329
|
+
"f(x,y)",
|
2704
3330
|
};
|
2705
3331
|
|
2706
|
-
|
3332
|
+
static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
|
2707
3333
|
|
2708
3334
|
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
2709
3335
|
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
@@ -2716,7 +3342,6 @@ struct ggml_context {
|
|
2716
3342
|
size_t mem_size;
|
2717
3343
|
void * mem_buffer;
|
2718
3344
|
bool mem_buffer_owned;
|
2719
|
-
bool mem_buffer_mlocked;
|
2720
3345
|
bool no_alloc;
|
2721
3346
|
|
2722
3347
|
int n_objects;
|
@@ -2834,6 +3459,11 @@ float ggml_type_sizef(enum ggml_type type) {
|
|
2834
3459
|
return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
|
2835
3460
|
}
|
2836
3461
|
|
3462
|
+
const char * ggml_type_name(enum ggml_type type) {
|
3463
|
+
return GGML_TYPE_NAME[type];
|
3464
|
+
}
|
3465
|
+
|
3466
|
+
|
2837
3467
|
size_t ggml_element_size(const struct ggml_tensor * tensor) {
|
2838
3468
|
return GGML_TYPE_SIZE[tensor->type];
|
2839
3469
|
}
|
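With the tables above, the storage cost per element follows from GGML_TYPE_SIZE[type] / GGML_BLCK_SIZE[type], which is exactly what ggml_type_sizef returns. A worked example for Q4_0, assuming the block_q4_0 layout sketched earlier (sizeof(block_q4_0) == 4 + 16 == 20 bytes on typical platforms):

    #include <stdio.h>

    int main(void) {
        // Values assumed from the tables above; the block size in bytes is an assumption.
        const float type_size = 20.0f;   // sizeof(block_q4_0): 4-byte scale + 16 bytes of quants
        const float blck_size = 32.0f;   // QK4_0 weights per block
        printf("q4_0: %.3f bytes per weight\n", type_size / blck_size);  // prints 0.625
        return 0;
    }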
@@ -2999,11 +3629,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|
2999
3629
|
return NULL;
|
3000
3630
|
}
|
3001
3631
|
|
3632
|
+
const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
|
3633
|
+
|
3002
3634
|
*ctx = (struct ggml_context) {
|
3003
|
-
/*.mem_size =*/
|
3004
|
-
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer :
|
3635
|
+
/*.mem_size =*/ mem_size,
|
3636
|
+
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
|
3005
3637
|
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
|
3006
|
-
/*.mem_buffer_mlocked =*/ false,
|
3007
3638
|
/*.no_alloc =*/ params.no_alloc,
|
3008
3639
|
/*.n_objects =*/ 0,
|
3009
3640
|
/*.objects_begin =*/ NULL,
|
@@ -3012,7 +3643,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|
3012
3643
|
/*.scratch_save =*/ { 0, 0, NULL, },
|
3013
3644
|
};
|
3014
3645
|
|
3015
|
-
GGML_ASSERT(ctx->mem_buffer != NULL);
|
3646
|
+
GGML_ASSERT(ctx->mem_buffer != NULL);
|
3016
3647
|
|
3017
3648
|
ggml_assert_aligned(ctx->mem_buffer);
|
3018
3649
|
|
@@ -3036,16 +3667,8 @@ void ggml_free(struct ggml_context * ctx) {
|
|
3036
3667
|
GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
|
3037
3668
|
__func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
|
3038
3669
|
|
3039
|
-
#if GGML_MLOCK_SUPPORT
|
3040
|
-
if (ctx->mem_buffer_mlocked) {
|
3041
|
-
if (munlock(ctx->mem_buffer, ctx->mem_size)) {
|
3042
|
-
fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
|
3043
|
-
}
|
3044
|
-
}
|
3045
|
-
#endif
|
3046
|
-
|
3047
3670
|
if (ctx->mem_buffer_owned) {
|
3048
|
-
|
3671
|
+
GGML_ALIGNED_FREE(ctx->mem_buffer);
|
3049
3672
|
}
|
3050
3673
|
|
3051
3674
|
found = true;
|
@@ -3072,48 +3695,6 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
|
|
3072
3695
|
return result;
|
3073
3696
|
}
|
3074
3697
|
|
3075
|
-
#ifdef __APPLE__
|
3076
|
-
#define MLOCK_SUGGESTION \
|
3077
|
-
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
3078
|
-
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
|
3079
|
-
#else
|
3080
|
-
#define MLOCK_SUGGESTION \
|
3081
|
-
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
|
3082
|
-
#endif
|
3083
|
-
|
3084
|
-
bool ggml_mlock_supported(void) {
|
3085
|
-
return GGML_MLOCK_SUPPORT;
|
3086
|
-
}
|
3087
|
-
|
3088
|
-
bool ggml_mlock(
|
3089
|
-
struct ggml_context * ctx,
|
3090
|
-
const void *opt_extra_addr,
|
3091
|
-
size_t opt_extra_len,
|
3092
|
-
char **err_p) {
|
3093
|
-
// TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
|
3094
|
-
#if GGML_MLOCK_SUPPORT
|
3095
|
-
if (ctx->mem_buffer_mlocked) {
|
3096
|
-
return true;
|
3097
|
-
}
|
3098
|
-
if (mlock(ctx->mem_buffer, ctx->mem_size) ||
|
3099
|
-
(opt_extra_len &&
|
3100
|
-
mlock(opt_extra_addr, opt_extra_len))) {
|
3101
|
-
if ((*err_p = malloc(1024))) {
|
3102
|
-
snprintf(*err_p, 1024,
|
3103
|
-
"failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
|
3104
|
-
ctx->mem_size + opt_extra_len,
|
3105
|
-
strerror(errno));
|
3106
|
-
}
|
3107
|
-
return false;
|
3108
|
-
}
|
3109
|
-
ctx->mem_buffer_mlocked = true;
|
3110
|
-
return true;
|
3111
|
-
#else // GGML_MLOCK_SUPPORT
|
3112
|
-
*err_p = strdup("can't mlock because it's not supported on this system");
|
3113
|
-
return false;
|
3114
|
-
#endif // GGML_MLOCK_SUPPORT
|
3115
|
-
}
|
3116
|
-
|
3117
3698
|
////////////////////////////////////////////////////////////////////////////////
|
3118
3699
|
|
3119
3700
|
struct ggml_tensor * ggml_new_tensor_impl(
|
@@ -3325,14 +3906,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
|
|
3325
3906
|
char * const data = tensor->data;
|
3326
3907
|
|
3327
3908
|
switch (tensor->type) {
|
3328
|
-
case GGML_TYPE_Q4_0:
|
3329
|
-
{
|
3330
|
-
GGML_ASSERT(false);
|
3331
|
-
} break;
|
3332
|
-
case GGML_TYPE_Q4_1:
|
3333
|
-
{
|
3334
|
-
GGML_ASSERT(false);
|
3335
|
-
} break;
|
3336
3909
|
case GGML_TYPE_I8:
|
3337
3910
|
{
|
3338
3911
|
assert(tensor->nb[0] == sizeof(int8_t));
|
@@ -3368,7 +3941,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
|
|
3368
3941
|
ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
3369
3942
|
}
|
3370
3943
|
} break;
|
3371
|
-
|
3944
|
+
default:
|
3372
3945
|
{
|
3373
3946
|
GGML_ASSERT(false);
|
3374
3947
|
} break;
|
@@ -3385,14 +3958,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3385
3958
|
char * const data = tensor->data;
|
3386
3959
|
|
3387
3960
|
switch (tensor->type) {
|
3388
|
-
case GGML_TYPE_Q4_0:
|
3389
|
-
{
|
3390
|
-
GGML_ASSERT(false);
|
3391
|
-
} break;
|
3392
|
-
case GGML_TYPE_Q4_1:
|
3393
|
-
{
|
3394
|
-
GGML_ASSERT(false);
|
3395
|
-
} break;
|
3396
3961
|
case GGML_TYPE_I8:
|
3397
3962
|
{
|
3398
3963
|
assert(tensor->nb[0] == sizeof(int8_t));
|
@@ -3428,7 +3993,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3428
3993
|
ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
3429
3994
|
}
|
3430
3995
|
} break;
|
3431
|
-
|
3996
|
+
default:
|
3432
3997
|
{
|
3433
3998
|
GGML_ASSERT(false);
|
3434
3999
|
} break;
|
@@ -3439,14 +4004,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3439
4004
|
|
3440
4005
|
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
3441
4006
|
switch (tensor->type) {
|
3442
|
-
case GGML_TYPE_Q4_0:
|
3443
|
-
{
|
3444
|
-
GGML_ASSERT(false);
|
3445
|
-
} break;
|
3446
|
-
case GGML_TYPE_Q4_1:
|
3447
|
-
{
|
3448
|
-
GGML_ASSERT(false);
|
3449
|
-
} break;
|
3450
4007
|
case GGML_TYPE_I8:
|
3451
4008
|
{
|
3452
4009
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3472,7 +4029,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3472
4029
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3473
4030
|
return ((float *)(tensor->data))[i];
|
3474
4031
|
} break;
|
3475
|
-
|
4032
|
+
default:
|
3476
4033
|
{
|
3477
4034
|
GGML_ASSERT(false);
|
3478
4035
|
} break;
|
@@ -3483,14 +4040,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3483
4040
|
|
3484
4041
|
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
3485
4042
|
switch (tensor->type) {
|
3486
|
-
case GGML_TYPE_Q4_0:
|
3487
|
-
{
|
3488
|
-
GGML_ASSERT(false);
|
3489
|
-
} break;
|
3490
|
-
case GGML_TYPE_Q4_1:
|
3491
|
-
{
|
3492
|
-
GGML_ASSERT(false);
|
3493
|
-
} break;
|
3494
4043
|
case GGML_TYPE_I8:
|
3495
4044
|
{
|
3496
4045
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3516,7 +4065,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
|
3516
4065
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3517
4066
|
((float *)(tensor->data))[i] = value;
|
3518
4067
|
} break;
|
3519
|
-
|
4068
|
+
default:
|
3520
4069
|
{
|
3521
4070
|
GGML_ASSERT(false);
|
3522
4071
|
} break;
|
@@ -3525,14 +4074,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
|
3525
4074
|
|
3526
4075
|
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
3527
4076
|
switch (tensor->type) {
|
3528
|
-
case GGML_TYPE_Q4_0:
|
3529
|
-
{
|
3530
|
-
GGML_ASSERT(false);
|
3531
|
-
} break;
|
3532
|
-
case GGML_TYPE_Q4_1:
|
3533
|
-
{
|
3534
|
-
GGML_ASSERT(false);
|
3535
|
-
} break;
|
3536
4077
|
case GGML_TYPE_I8:
|
3537
4078
|
{
|
3538
4079
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3558,7 +4099,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3558
4099
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3559
4100
|
return ((float *)(tensor->data))[i];
|
3560
4101
|
} break;
|
3561
|
-
|
4102
|
+
default:
|
3562
4103
|
{
|
3563
4104
|
GGML_ASSERT(false);
|
3564
4105
|
} break;
|
@@ -3569,14 +4110,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3569
4110
|
|
3570
4111
|
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
|
3571
4112
|
switch (tensor->type) {
|
3572
|
-
case GGML_TYPE_Q4_0:
|
3573
|
-
{
|
3574
|
-
GGML_ASSERT(false);
|
3575
|
-
} break;
|
3576
|
-
case GGML_TYPE_Q4_1:
|
3577
|
-
{
|
3578
|
-
GGML_ASSERT(false);
|
3579
|
-
} break;
|
3580
4113
|
case GGML_TYPE_I8:
|
3581
4114
|
{
|
3582
4115
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3602,7 +4135,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
|
|
3602
4135
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3603
4136
|
((float *)(tensor->data))[i] = value;
|
3604
4137
|
} break;
|
3605
|
-
|
4138
|
+
default:
|
3606
4139
|
{
|
3607
4140
|
GGML_ASSERT(false);
|
3608
4141
|
} break;
|
@@ -4388,26 +4921,22 @@ struct ggml_tensor * ggml_cpy_inplace(
|
|
4388
4921
|
return ggml_cpy_impl(ctx, a, b, true);
|
4389
4922
|
}
|
4390
4923
|
|
4391
|
-
//
|
4924
|
+
// ggml_cont
|
4392
4925
|
|
4393
|
-
struct ggml_tensor *
|
4926
|
+
struct ggml_tensor * ggml_cont_impl(
|
4394
4927
|
struct ggml_context * ctx,
|
4395
|
-
struct ggml_tensor
|
4396
|
-
|
4397
|
-
GGML_ASSERT(ggml_is_contiguous(a));
|
4398
|
-
GGML_ASSERT(ggml_is_contiguous(b));
|
4399
|
-
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
|
4400
|
-
|
4928
|
+
struct ggml_tensor * a,
|
4929
|
+
bool inplace) {
|
4401
4930
|
bool is_node = false;
|
4402
4931
|
|
4403
|
-
if (
|
4932
|
+
if (!inplace && a->grad) {
|
4404
4933
|
GGML_ASSERT(false); // TODO: implement backward
|
4405
4934
|
is_node = true;
|
4406
4935
|
}
|
4407
4936
|
|
4408
|
-
struct ggml_tensor * result =
|
4937
|
+
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
4409
4938
|
|
4410
|
-
result->op =
|
4939
|
+
result->op = GGML_OP_CONT;
|
4411
4940
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
4412
4941
|
result->src0 = a;
|
4413
4942
|
result->src1 = NULL;
|
@@ -4415,12 +4944,51 @@ struct ggml_tensor * ggml_reshape(
|
|
4415
4944
|
return result;
|
4416
4945
|
}
|
4417
4946
|
|
4418
|
-
struct ggml_tensor *
|
4947
|
+
struct ggml_tensor * ggml_cont(
|
4419
4948
|
struct ggml_context * ctx,
|
4420
|
-
struct ggml_tensor
|
4421
|
-
|
4422
|
-
|
4423
|
-
|
4949
|
+
struct ggml_tensor * a) {
|
4950
|
+
return ggml_cont_impl(ctx, a, false);
|
4951
|
+
}
|
4952
|
+
|
4953
|
+
struct ggml_tensor * ggml_cont_inplace(
|
4954
|
+
struct ggml_context * ctx,
|
4955
|
+
struct ggml_tensor * a) {
|
4956
|
+
return ggml_cont_impl(ctx, a, true);
|
4957
|
+
}
|
4958
|
+
|
4959
|
+
// ggml_reshape
|
4960
|
+
|
4961
|
+
struct ggml_tensor * ggml_reshape(
|
4962
|
+
struct ggml_context * ctx,
|
4963
|
+
struct ggml_tensor * a,
|
4964
|
+
struct ggml_tensor * b) {
|
4965
|
+
GGML_ASSERT(ggml_is_contiguous(a));
|
4966
|
+
GGML_ASSERT(ggml_is_contiguous(b));
|
4967
|
+
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
|
4968
|
+
|
4969
|
+
bool is_node = false;
|
4970
|
+
|
4971
|
+
if (a->grad || b->grad) {
|
4972
|
+
GGML_ASSERT(false); // TODO: implement backward
|
4973
|
+
is_node = true;
|
4974
|
+
}
|
4975
|
+
|
4976
|
+
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
|
4977
|
+
|
4978
|
+
result->op = GGML_OP_RESHAPE;
|
4979
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
4980
|
+
result->src0 = a;
|
4981
|
+
result->src1 = NULL;
|
4982
|
+
|
4983
|
+
return result;
|
4984
|
+
}
|
4985
|
+
|
4986
|
+
struct ggml_tensor * ggml_reshape_2d(
|
4987
|
+
struct ggml_context * ctx,
|
4988
|
+
struct ggml_tensor * a,
|
4989
|
+
int64_t ne0,
|
4990
|
+
int64_t ne1) {
|
4991
|
+
GGML_ASSERT(ggml_is_contiguous(a));
|
4424
4992
|
GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
|
4425
4993
|
|
4426
4994
|
bool is_node = false;
|
@@ -4866,6 +5434,90 @@ struct ggml_tensor * ggml_flash_ff(
|
|
4866
5434
|
return result;
|
4867
5435
|
}
|
4868
5436
|
|
5437
|
+
// ggml_map_unary
|
5438
|
+
|
5439
|
+
struct ggml_tensor * ggml_map_unary_impl_f32(
|
5440
|
+
struct ggml_context * ctx,
|
5441
|
+
struct ggml_tensor * a,
|
5442
|
+
const ggml_unary_op_f32_t fun,
|
5443
|
+
bool inplace) {
|
5444
|
+
bool is_node = false;
|
5445
|
+
|
5446
|
+
if (!inplace && a->grad) {
|
5447
|
+
is_node = true;
|
5448
|
+
}
|
5449
|
+
|
5450
|
+
struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
|
5451
|
+
*((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
|
5452
|
+
struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
5453
|
+
|
5454
|
+
result->op = GGML_OP_MAP_UNARY;
|
5455
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
5456
|
+
result->src0 = a;
|
5457
|
+
result->opt[0] = addr_tensor;
|
5458
|
+
|
5459
|
+
return result;
|
5460
|
+
}
|
5461
|
+
|
5462
|
+
struct ggml_tensor * ggml_map_unary_f32(
|
5463
|
+
struct ggml_context * ctx,
|
5464
|
+
struct ggml_tensor * a,
|
5465
|
+
const ggml_unary_op_f32_t fun) {
|
5466
|
+
return ggml_map_unary_impl_f32(ctx, a, fun, false);
|
5467
|
+
}
|
5468
|
+
|
5469
|
+
struct ggml_tensor * ggml_map_unary_inplace_f32(
|
5470
|
+
struct ggml_context * ctx,
|
5471
|
+
struct ggml_tensor * a,
|
5472
|
+
const ggml_unary_op_f32_t fun) {
|
5473
|
+
return ggml_map_unary_impl_f32(ctx, a, fun, true);
|
5474
|
+
}
|
5475
|
+
|
5476
|
+
// ggml_map_binary
|
5477
|
+
|
5478
|
+
struct ggml_tensor * ggml_map_binary_impl_f32(
|
5479
|
+
struct ggml_context * ctx,
|
5480
|
+
struct ggml_tensor * a,
|
5481
|
+
struct ggml_tensor * b,
|
5482
|
+
const ggml_binary_op_f32_t fun,
|
5483
|
+
bool inplace) {
|
5484
|
+
GGML_ASSERT(ggml_are_same_shape(a, b));
|
5485
|
+
|
5486
|
+
bool is_node = false;
|
5487
|
+
|
5488
|
+
if (!inplace && (a->grad || b->grad)) {
|
5489
|
+
is_node = true;
|
5490
|
+
}
|
5491
|
+
|
5492
|
+
struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
|
5493
|
+
*((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
|
5494
|
+
struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
5495
|
+
|
5496
|
+
result->op = GGML_OP_MAP_BINARY;
|
5497
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
5498
|
+
result->src0 = a;
|
5499
|
+
result->src1 = b;
|
5500
|
+
result->opt[0] = addr_tensor;
|
5501
|
+
|
5502
|
+
return result;
|
5503
|
+
}
|
5504
|
+
|
5505
|
+
struct ggml_tensor * ggml_map_binary_f32(
|
5506
|
+
struct ggml_context * ctx,
|
5507
|
+
struct ggml_tensor * a,
|
5508
|
+
struct ggml_tensor * b,
|
5509
|
+
const ggml_binary_op_f32_t fun) {
|
5510
|
+
return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
|
5511
|
+
}
|
5512
|
+
|
5513
|
+
struct ggml_tensor * ggml_map_binary_inplace_f32(
|
5514
|
+
struct ggml_context * ctx,
|
5515
|
+
struct ggml_tensor * a,
|
5516
|
+
struct ggml_tensor * b,
|
5517
|
+
const ggml_binary_op_f32_t fun) {
|
5518
|
+
return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
|
5519
|
+
}
|
5520
|
+
|
4869
5521
|
////////////////////////////////////////////////////////////////////////////////
|
4870
5522
|
|
4871
5523
|
void ggml_set_param(
|
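The new GGML_OP_MAP_UNARY / GGML_OP_MAP_BINARY ops above let callers run an arbitrary element-wise f32 function: the function pointer is stashed in a small I32 tensor (opt[0]) and invoked during the forward pass. A usage sketch; the ggml_unary_op_f32_t signature is assumed from ggml.h of this version rather than shown in this diff, and the example function is hypothetical:

    #include <math.h>

    // Signature assumed from ggml.h of this version:
    typedef void (*ggml_unary_op_f32_t)(const int n, float * dst, const float * src);

    // Example element-wise function a caller could register.
    static void my_exp_f32(const int n, float * dst, const float * src) {
        for (int i = 0; i < n; ++i) {
            dst[i] = expf(src[i]);
        }
    }

    // Graph construction (hypothetical caller code, needs a valid ggml context and f32 tensor a):
    //   struct ggml_tensor * b = ggml_map_unary_f32(ctx, a, my_exp_f32);
    // When the graph is computed, b holds expf applied to every element of a.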
@@ -4930,6 +5582,105 @@ static void ggml_compute_forward_dup_f16(
|
|
4930
5582
|
|
4931
5583
|
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
4932
5584
|
|
5585
|
+
if (ggml_is_contiguous(dst)) {
|
5586
|
+
if (src0->nb[0] == sizeof(ggml_fp16_t)) {
|
5587
|
+
if (dst->type == GGML_TYPE_F16) {
|
5588
|
+
size_t id = 0;
|
5589
|
+
const size_t rs = ne00*nb00;
|
5590
|
+
|
5591
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5592
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5593
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5594
|
+
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5595
|
+
char * dst_ptr = (char *) dst->data + id*rs;
|
5596
|
+
|
5597
|
+
memcpy(dst_ptr, src0_ptr, rs);
|
5598
|
+
|
5599
|
+
id++;
|
5600
|
+
}
|
5601
|
+
}
|
5602
|
+
}
|
5603
|
+
} else if (dst->type == GGML_TYPE_F32) {
|
5604
|
+
size_t id = 0;
|
5605
|
+
float * dst_ptr = (float *) dst->data;
|
5606
|
+
|
5607
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5608
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5609
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5610
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5611
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5612
|
+
|
5613
|
+
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
5614
|
+
id++;
|
5615
|
+
}
|
5616
|
+
}
|
5617
|
+
}
|
5618
|
+
}
|
5619
|
+
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
|
5620
|
+
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
5621
|
+
size_t id = 0;
|
5622
|
+
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
5623
|
+
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
5624
|
+
float * src0_f32 = (float *) params->wdata;
|
5625
|
+
|
5626
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5627
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5628
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5629
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5630
|
+
// convert to f32 and quantize
|
5631
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5632
|
+
src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
|
5633
|
+
}
|
5634
|
+
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
5635
|
+
id += dst_row_size;
|
5636
|
+
}
|
5637
|
+
}
|
5638
|
+
}
|
5639
|
+
} else {
|
5640
|
+
GGML_ASSERT(false); // TODO: implement
|
5641
|
+
}
|
5642
|
+
} else {
|
5643
|
+
//printf("%s: this is not optimal - fix me\n", __func__);
|
5644
|
+
|
5645
|
+
if (dst->type == GGML_TYPE_F32) {
|
5646
|
+
size_t id = 0;
|
5647
|
+
float * dst_ptr = (float *) dst->data;
|
5648
|
+
|
5649
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5650
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5651
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5652
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5653
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5654
|
+
|
5655
|
+
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
5656
|
+
id++;
|
5657
|
+
}
|
5658
|
+
}
|
5659
|
+
}
|
5660
|
+
}
|
5661
|
+
} else if (dst->type == GGML_TYPE_F16) {
|
5662
|
+
size_t id = 0;
|
5663
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
5664
|
+
|
5665
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5666
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5667
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5668
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5669
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5670
|
+
|
5671
|
+
dst_ptr[id] = *src0_ptr;
|
5672
|
+
id++;
|
5673
|
+
}
|
5674
|
+
}
|
5675
|
+
}
|
5676
|
+
}
|
5677
|
+
} else {
|
5678
|
+
GGML_ASSERT(false); // TODO: implement
|
5679
|
+
}
|
5680
|
+
}
|
5681
|
+
return;
|
5682
|
+
}
|
5683
|
+
|
4933
5684
|
// dst counters
|
4934
5685
|
int64_t i10 = 0;
|
4935
5686
|
int64_t i11 = 0;
|
@@ -5024,6 +5775,120 @@ static void ggml_compute_forward_dup_f32(
|
|
5024
5775
|
return;
|
5025
5776
|
}
|
5026
5777
|
|
5778
|
+
if (src0->type == dst->type &&
|
5779
|
+
src0->ne[0] == dst->ne[0] &&
|
5780
|
+
src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
|
5781
|
+
// copy by rows
|
5782
|
+
const size_t rs = ne00*nb00;
|
5783
|
+
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5784
|
+
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5785
|
+
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
5786
|
+
memcpy(
|
5787
|
+
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
5788
|
+
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
5789
|
+
rs);
|
5790
|
+
}
|
5791
|
+
}
|
5792
|
+
}
|
5793
|
+
return;
|
5794
|
+
}
|
5795
|
+
|
5796
|
+
if (ggml_is_contiguous(dst)) {
|
5797
|
+
// TODO: simplify
|
5798
|
+
if (src0->nb[0] == sizeof(float)) {
|
5799
|
+
if (dst->type == GGML_TYPE_F32) {
|
5800
|
+
size_t id = 0;
|
5801
|
+
const size_t rs = ne00*nb00;
|
5802
|
+
|
5803
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5804
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5805
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5806
|
+
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5807
|
+
char * dst_ptr = (char *) dst->data + id*rs;
|
5808
|
+
|
5809
|
+
memcpy(dst_ptr, src0_ptr, rs);
|
5810
|
+
|
5811
|
+
id++;
|
5812
|
+
}
|
5813
|
+
}
|
5814
|
+
}
|
5815
|
+
} else if (dst->type == GGML_TYPE_F16) {
|
5816
|
+
size_t id = 0;
|
5817
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
5818
|
+
|
5819
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5820
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5821
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5822
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5823
|
+
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5824
|
+
|
5825
|
+
dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
|
5826
|
+
id++;
|
5827
|
+
}
|
5828
|
+
}
|
5829
|
+
}
|
5830
|
+
}
|
5831
|
+
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
|
5832
|
+
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
5833
|
+
size_t id = 0;
|
5834
|
+
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
5835
|
+
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
5836
|
+
|
5837
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5838
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5839
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5840
|
+
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5841
|
+
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
|
5842
|
+
id += dst_row_size;
|
5843
|
+
}
|
5844
|
+
}
|
5845
|
+
}
|
5846
|
+
} else {
|
5847
|
+
GGML_ASSERT(false); // TODO: implement
|
5848
|
+
}
|
5849
|
+
} else {
|
5850
|
+
//printf("%s: this is not optimal - fix me\n", __func__);
|
5851
|
+
|
5852
|
+
if (dst->type == GGML_TYPE_F32) {
|
5853
|
+
size_t id = 0;
|
5854
|
+
float * dst_ptr = (float *) dst->data;
|
5855
|
+
|
5856
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5857
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5858
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5859
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5860
|
+
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5861
|
+
|
5862
|
+
dst_ptr[id] = *src0_ptr;
|
5863
|
+
id++;
|
5864
|
+
}
|
5865
|
+
}
|
5866
|
+
}
|
5867
|
+
}
|
5868
|
+
} else if (dst->type == GGML_TYPE_F16) {
|
5869
|
+
size_t id = 0;
|
5870
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
5871
|
+
|
5872
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5873
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5874
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5875
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5876
|
+
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5877
|
+
|
5878
|
+
dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
|
5879
|
+
id++;
|
5880
|
+
}
|
5881
|
+
}
|
5882
|
+
}
|
5883
|
+
}
|
5884
|
+
} else {
|
5885
|
+
GGML_ASSERT(false); // TODO: implement
|
5886
|
+
}
|
5887
|
+
}
|
5888
|
+
|
5889
|
+
return;
|
5890
|
+
}
|
5891
|
+
|
5027
5892
|
// dst counters
|
5028
5893
|
int64_t i10 = 0;
|
5029
5894
|
int64_t i11 = 0;
|
@@ -5100,12 +5965,7 @@ static void ggml_compute_forward_dup(
|
|
5100
5965
|
{
|
5101
5966
|
ggml_compute_forward_dup_f32(params, src0, dst);
|
5102
5967
|
} break;
|
5103
|
-
|
5104
|
-
case GGML_TYPE_Q4_1:
|
5105
|
-
case GGML_TYPE_I8:
|
5106
|
-
case GGML_TYPE_I16:
|
5107
|
-
case GGML_TYPE_I32:
|
5108
|
-
case GGML_TYPE_COUNT:
|
5968
|
+
default:
|
5109
5969
|
{
|
5110
5970
|
GGML_ASSERT(false);
|
5111
5971
|
} break;
|
@@ -5144,14 +6004,18 @@ static void ggml_compute_forward_add_f32(
|
|
5144
6004
|
GGML_ASSERT(nb00 == sizeof(float));
|
5145
6005
|
|
5146
6006
|
if (nb10 == sizeof(float)) {
|
5147
|
-
|
5148
|
-
|
5149
|
-
|
5150
|
-
|
6007
|
+
for (int j = ith; j < n; j += nth) {
|
6008
|
+
#ifdef GGML_USE_ACCELERATE
|
6009
|
+
vDSP_vadd(
|
6010
|
+
(float *) ((char *) src0->data + j*nb01), 1,
|
6011
|
+
(float *) ((char *) src1->data + j*nb11), 1,
|
6012
|
+
(float *) ((char *) dst->data + j*nb1), 1, nc);
|
6013
|
+
#else
|
5151
6014
|
ggml_vec_add_f32(nc,
|
5152
6015
|
(float *) ((char *) dst->data + j*nb1),
|
5153
6016
|
(float *) ((char *) src0->data + j*nb01),
|
5154
6017
|
(float *) ((char *) src1->data + j*nb11));
|
6018
|
+
#endif
|
5155
6019
|
}
|
5156
6020
|
} else {
|
5157
6021
|
// src1 is not contiguous
|
@@ -5167,6 +6031,212 @@ static void ggml_compute_forward_add_f32(
|
|
5167
6031
|
}
|
5168
6032
|
}
|
5169
6033
|
|
6034
|
+
static void ggml_compute_forward_add_f16_f32(
|
6035
|
+
const struct ggml_compute_params * params,
|
6036
|
+
const struct ggml_tensor * src0,
|
6037
|
+
const struct ggml_tensor * src1,
|
6038
|
+
struct ggml_tensor * dst) {
|
6039
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
6040
|
+
|
6041
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
6042
|
+
return;
|
6043
|
+
}
|
6044
|
+
|
6045
|
+
const int ith = params->ith;
|
6046
|
+
const int nth = params->nth;
|
6047
|
+
|
6048
|
+
const int n = ggml_nrows(src0);
|
6049
|
+
const int nc = src0->ne[0];
|
6050
|
+
|
6051
|
+
const size_t nb00 = src0->nb[0];
|
6052
|
+
const size_t nb01 = src0->nb[1];
|
6053
|
+
|
6054
|
+
const size_t nb10 = src1->nb[0];
|
6055
|
+
const size_t nb11 = src1->nb[1];
|
6056
|
+
|
6057
|
+
const size_t nb0 = dst->nb[0];
|
6058
|
+
const size_t nb1 = dst->nb[1];
|
6059
|
+
|
6060
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
6061
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
6062
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
6063
|
+
|
6064
|
+
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
6065
|
+
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
6066
|
+
|
6067
|
+
if (nb10 == sizeof(float)) {
|
6068
|
+
for (int j = ith; j < n; j += nth) {
|
6069
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
6070
|
+
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
6071
|
+
for (int i = 0; i < nc; i++) {
|
6072
|
+
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
|
6073
|
+
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
|
6074
|
+
}
|
6075
|
+
}
|
6076
|
+
}
|
6077
|
+
else {
|
6078
|
+
// src1 is not contiguous
|
6079
|
+
GGML_ASSERT(false);
|
6080
|
+
}
|
6081
|
+
}
|
6082
|
+
|
6083
|
+
static void ggml_compute_forward_add_f16_f16(
|
6084
|
+
const struct ggml_compute_params * params,
|
6085
|
+
const struct ggml_tensor * src0,
|
6086
|
+
const struct ggml_tensor * src1,
|
6087
|
+
struct ggml_tensor * dst) {
|
6088
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
6089
|
+
|
6090
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
6091
|
+
return;
|
6092
|
+
}
|
6093
|
+
|
6094
|
+
const int ith = params->ith;
|
6095
|
+
const int nth = params->nth;
|
6096
|
+
|
6097
|
+
const int n = ggml_nrows(src0);
|
6098
|
+
const int nc = src0->ne[0];
|
6099
|
+
|
6100
|
+
const size_t nb00 = src0->nb[0];
|
6101
|
+
const size_t nb01 = src0->nb[1];
|
6102
|
+
|
6103
|
+
const size_t nb10 = src1->nb[0];
|
6104
|
+
const size_t nb11 = src1->nb[1];
|
6105
|
+
|
6106
|
+
const size_t nb0 = dst->nb[0];
|
6107
|
+
const size_t nb1 = dst->nb[1];
|
6108
|
+
|
6109
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
6110
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F16);
|
6111
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
6112
|
+
|
6113
|
+
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
6114
|
+
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
6115
|
+
|
6116
|
+
if (nb10 == sizeof(ggml_fp16_t)) {
|
6117
|
+
for (int j = ith; j < n; j += nth) {
|
6118
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
6119
|
+
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
6120
|
+
for (int i = 0; i < nc; i++) {
|
6121
|
+
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
|
6122
|
+
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
|
6123
|
+
}
|
6124
|
+
}
|
6125
|
+
}
|
6126
|
+
else {
|
6127
|
+
// src1 is not contiguous
|
6128
|
+
GGML_ASSERT(false);
|
6129
|
+
}
|
6130
|
+
}
|
6131
|
+
|
6132
|
+
static void ggml_compute_forward_add_q_f32(
|
6133
|
+
const struct ggml_compute_params * params,
|
6134
|
+
const struct ggml_tensor * src0,
|
6135
|
+
const struct ggml_tensor * src1,
|
6136
|
+
struct ggml_tensor * dst) {
|
6137
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
6138
|
+
|
6139
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
6140
|
+
return;
|
6141
|
+
}
|
6142
|
+
|
6143
|
+
const int64_t ne00 = src0->ne[0];
|
6144
|
+
const int64_t ne01 = src0->ne[1];
|
6145
|
+
const int64_t ne02 = src0->ne[2];
|
6146
|
+
const int64_t ne03 = src0->ne[3];
|
6147
|
+
|
6148
|
+
//const int64_t ne10 = src1->ne[0];
|
6149
|
+
//const int64_t ne11 = src1->ne[1];
|
6150
|
+
const int64_t ne12 = src1->ne[2];
|
6151
|
+
const int64_t ne13 = src1->ne[3];
|
6152
|
+
|
6153
|
+
//const int64_t ne0 = dst->ne[0];
|
6154
|
+
//const int64_t ne1 = dst->ne[1];
|
6155
|
+
const int64_t ne2 = dst->ne[2];
|
6156
|
+
const int64_t ne3 = dst->ne[3];
|
6157
|
+
|
6158
|
+
const int nb00 = src0->nb[0];
|
6159
|
+
const int nb01 = src0->nb[1];
|
6160
|
+
const int nb02 = src0->nb[2];
|
6161
|
+
const int nb03 = src0->nb[3];
|
6162
|
+
|
6163
|
+
const int nb10 = src1->nb[0];
|
6164
|
+
const int nb11 = src1->nb[1];
|
6165
|
+
const int nb12 = src1->nb[2];
|
6166
|
+
const int nb13 = src1->nb[3];
|
6167
|
+
|
6168
|
+
const int nb0 = dst->nb[0];
|
6169
|
+
const int nb1 = dst->nb[1];
|
6170
|
+
const int nb2 = dst->nb[2];
|
6171
|
+
const int nb3 = dst->nb[3];
|
6172
|
+
|
6173
|
+
const int ith = params->ith;
|
6174
|
+
const int nth = params->nth;
|
6175
|
+
|
6176
|
+
GGML_ASSERT(ne02 == ne12);
|
6177
|
+
GGML_ASSERT(ne03 == ne13);
|
6178
|
+
GGML_ASSERT(ne2 == ne12);
|
6179
|
+
GGML_ASSERT(ne3 == ne13);
|
6180
|
+
|
6181
|
+
const enum ggml_type type = src0->type;
|
6182
|
+
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
6183
|
+
quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
|
6184
|
+
|
6185
|
+
// we don't support permuted src0 or src1
|
6186
|
+
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
|
6187
|
+
GGML_ASSERT(nb10 == sizeof(float));
|
6188
|
+
|
6189
|
+
// dst cannot be transposed or permuted
|
6190
|
+
GGML_ASSERT(nb0 <= nb1);
|
6191
|
+
GGML_ASSERT(nb1 <= nb2);
|
6192
|
+
GGML_ASSERT(nb2 <= nb3);
|
6193
|
+
|
6194
|
+
GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
|
6195
|
+
GGML_ASSERT(dst->type == src0->type);
|
6196
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
6197
|
+
|
6198
|
+
// total rows in src0
|
6199
|
+
const int nr = ne01*ne02*ne03;
|
6200
|
+
|
6201
|
+
// rows per thread
|
6202
|
+
const int dr = (nr + nth - 1)/nth;
|
6203
|
+
|
6204
|
+
// row range for this thread
|
6205
|
+
const int ir0 = dr*ith;
|
6206
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
6207
|
+
|
6208
|
+
float * wdata = (float*) params->wdata + ne00 * ith;
|
6209
|
+
|
6210
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
6211
|
+
// src0 indices
|
6212
|
+
const int i03 = ir/(ne02*ne01);
|
6213
|
+
const int i02 = (ir - i03*ne02*ne01)/ne01;
|
6214
|
+
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
6215
|
+
|
6216
|
+
// src1 and dst are same shape as src0 => same indices
|
6217
|
+
const int i13 = i03;
|
6218
|
+
const int i12 = i02;
|
6219
|
+
const int i11 = i01;
|
6220
|
+
|
6221
|
+
const int i3 = i03;
|
6222
|
+
const int i2 = i02;
|
6223
|
+
const int i1 = i01;
|
6224
|
+
|
6225
|
+
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
6226
|
+
float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
|
6227
|
+
void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));
|
6228
|
+
|
6229
|
+
assert(ne00 % 32 == 0);
|
6230
|
+
|
6231
|
+
// dequantize row from src0 to temp buffer
|
6232
|
+
dequantize_row_q(src0_row, wdata, ne00);
|
6233
|
+
// add src1
|
6234
|
+
ggml_vec_acc_f32(ne00, wdata, src1_row);
|
6235
|
+
// quantize row to dst
|
6236
|
+
quantize_row_q(wdata, dst_row, ne00);
|
6237
|
+
}
|
6238
|
+
}
|
6239
|
+
|
5170
6240
|
static void ggml_compute_forward_add(
|
5171
6241
|
const struct ggml_compute_params * params,
|
5172
6242
|
const struct ggml_tensor * src0,
|
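ggml_compute_forward_add_q_f32 above splits the nr = ne01*ne02*ne03 rows across nth threads with dr = (nr + nth - 1)/nth and gives thread ith the half-open range [dr*ith, MIN(dr*ith + dr, nr)). A small sketch of that partitioning with example values, assuming MIN is the usual minimum macro:

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr  = 10;  // total rows (example value)
        const int nth = 4;   // number of threads (example value)
        const int dr  = (nr + nth - 1) / nth;   // rows per thread, rounded up
        for (int ith = 0; ith < nth; ++ith) {
            const int ir0 = dr * ith;
            const int ir1 = MIN(ir0 + dr, nr);
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);  // [0,3) [3,6) [6,9) [9,10)
        }
        return 0;
    }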
@@ -5177,13 +6247,24 @@ static void ggml_compute_forward_add(
|
|
5177
6247
|
{
|
5178
6248
|
ggml_compute_forward_add_f32(params, src0, src1, dst);
|
5179
6249
|
} break;
|
6250
|
+
case GGML_TYPE_F16:
|
6251
|
+
{
|
6252
|
+
if (src1->type == GGML_TYPE_F16) {
|
6253
|
+
ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
|
6254
|
+
}
|
6255
|
+
else if (src1->type == GGML_TYPE_F32) {
|
6256
|
+
ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
|
6257
|
+
}
|
6258
|
+
else {
|
6259
|
+
GGML_ASSERT(false);
|
6260
|
+
}
|
6261
|
+
} break;
|
5180
6262
|
case GGML_TYPE_Q4_0:
|
5181
6263
|
case GGML_TYPE_Q4_1:
|
5182
|
-
|
5183
|
-
|
5184
|
-
|
5185
|
-
|
5186
|
-
case GGML_TYPE_COUNT:
|
6264
|
+
{
|
6265
|
+
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
6266
|
+
} break;
|
6267
|
+
default:
|
5187
6268
|
{
|
5188
6269
|
GGML_ASSERT(false);
|
5189
6270
|
} break;
|
@@ -5229,13 +6310,7 @@ static void ggml_compute_forward_sub(
|
|
5229
6310
|
{
|
5230
6311
|
ggml_compute_forward_sub_f32(params, src0, src1, dst);
|
5231
6312
|
} break;
|
5232
|
-
|
5233
|
-
case GGML_TYPE_Q4_1:
|
5234
|
-
case GGML_TYPE_I8:
|
5235
|
-
case GGML_TYPE_I16:
|
5236
|
-
case GGML_TYPE_I32:
|
5237
|
-
case GGML_TYPE_F16:
|
5238
|
-
case GGML_TYPE_COUNT:
|
6313
|
+
default:
|
5239
6314
|
{
|
5240
6315
|
                 GGML_ASSERT(false);
             } break;
@@ -5281,13 +6356,7 @@ static void ggml_compute_forward_mul(
             {
                 ggml_compute_forward_mul_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5333,13 +6402,7 @@ static void ggml_compute_forward_div(
             {
                 ggml_compute_forward_div_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5381,13 +6444,7 @@ static void ggml_compute_forward_sqr(
             {
                 ggml_compute_forward_sqr_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5429,13 +6486,7 @@ static void ggml_compute_forward_sqrt(
             {
                 ggml_compute_forward_sqrt_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5485,16 +6536,10 @@ static void ggml_compute_forward_sum(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_sum_f32(params, src0, dst);
-            } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
-            {
+                ggml_compute_forward_sum_f32(params, src0, dst);
+            } break;
+        default:
+            {
                 GGML_ASSERT(false);
             } break;
     }
@@ -5564,13 +6609,7 @@ static void ggml_compute_forward_mean(
             {
                 ggml_compute_forward_mean_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5628,13 +6667,7 @@ static void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5676,13 +6709,7 @@ static void ggml_compute_forward_abs(
             {
                 ggml_compute_forward_abs_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5724,13 +6751,7 @@ static void ggml_compute_forward_sgn(
             {
                 ggml_compute_forward_sgn_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5772,13 +6793,7 @@ static void ggml_compute_forward_neg(
             {
                 ggml_compute_forward_neg_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5820,13 +6835,7 @@ static void ggml_compute_forward_step(
             {
                 ggml_compute_forward_step_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5868,13 +6877,7 @@ static void ggml_compute_forward_relu(
             {
                 ggml_compute_forward_relu_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5933,13 +6936,7 @@ static void ggml_compute_forward_gelu(
             {
                 ggml_compute_forward_gelu_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6000,13 +6997,7 @@ static void ggml_compute_forward_silu(
             {
                 ggml_compute_forward_silu_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6086,13 +7077,7 @@ static void ggml_compute_forward_norm(
             {
                 ggml_compute_forward_norm_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6166,13 +7151,7 @@ static void ggml_compute_forward_rms_norm(
             {
                 ggml_compute_forward_rms_norm_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
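The hunks above all make the same change: the exhaustive list of unsupported type labels in each dispatch switch is collapsed into a single `default:`. A minimal sketch of the trade-off, using a made-up enum rather than ggml's own types:

```c
// Made-up enum, not ggml's: with an explicit case list, -Wswitch flags any
// newly added value; with `default:`, new types (such as the Q8_0 entry
// introduced later in this diff) fall through to the assert without every
// dispatch switch having to be edited.
enum example_type { TYPE_F32, TYPE_F16, TYPE_Q4_0, TYPE_COUNT };

static int is_supported(enum example_type t) {
    switch (t) {
        case TYPE_F32:
            return 1;
        default:            // everything else, including values added later
            return 0;
    }
}
```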
@@ -6304,7 +7283,7 @@ static void ggml_compute_forward_mul_mat_f32(
         cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                 ne11, ne01, ne10,
                 1.0f,    y, ne10,
-                         x, ne10,
+                         x, ne00,
                 0.0f,    d, ne01);
     }
 }
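The one-token change above (repeated in the f16 and quantized mat-mul paths below) passes `ne00`, the row stride of `x`, as the leading dimension for the transposed operand. A self-contained sketch of the same row-major `cblas_sgemm` pattern with small hand-checked matrices; the shape names in the comments mirror the diff, but the program itself is illustrative and not code from the gem:

```c
// d[ne11 x ne01] = y[ne11 x ne10] * x[ne01 x ne00]^T, with ne10 == ne00.
// With CblasTrans, ldb must be the allocated row length of x (ne00).
#include <cblas.h>
#include <stdio.h>

int main(void) {
    // y is 2 x 3 (M=2, K=3), x is 4 x 3 (N=4, K=3); d = y * x^T is 2 x 4
    float y[2*3] = {1, 2, 3,
                    4, 5, 6};
    float x[4*3] = {1, 0, 0,
                    0, 1, 0,
                    0, 0, 1,
                    1, 1, 1};
    float d[2*4] = {0};

    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                2 /*M=ne11*/, 4 /*N=ne01*/, 3 /*K=ne10*/,
                1.0f, y, 3 /*lda=ne10*/,
                      x, 3 /*ldb=ne00, row stride of x*/,
                0.0f, d, 4 /*ldc=ne01*/);

    for (int i = 0; i < 2*4; ++i) printf("%g ", d[i]);
    printf("\n"); // expected: 1 2 3 6 4 5 6 15
    return 0;
}
```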
@@ -6476,7 +7455,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                 ne11, ne01, ne10,
                 1.0f,    y, ne10,
-                         x, ne10,
+                         x, ne00,
                 0.0f,    d, ne01);
     }
 }
@@ -6564,29 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
-typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
-typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
-
-typedef struct {
-    dequantize_row_q_t dequantize_row_q;
-    quantize_row_q_t   quantize_row_q;
-    vec_dot_q_t        vec_dot_q;
-} quantize_fns_t;
-
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q = dequantize_row_q4_0,
-        .quantize_row_q   = quantize_row_q4_0,
-        .vec_dot_q        = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q = dequantize_row_q4_1,
-        .quantize_row_q   = quantize_row_q4_1,
-        .vec_dot_q        = ggml_vec_dot_q4_1,
-    },
-};
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -6634,8 +7590,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(ne3 == ne13);
 
     const enum ggml_type type = src0->type;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
-    vec_dot_q_t      const vec_dot_q      = quantize_fns[type].vec_dot_q;
+    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
+    vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -6691,7 +7647,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
         cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                 ne11, ne01, ne10,
                 1.0f,    y, ne10,
-                         x, ne10,
+                         x, ne00,
                 0.0f,    d, ne01);
     }
 }
@@ -6704,12 +7660,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
     if (params->type == GGML_TASK_INIT) {
         char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+        const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
 
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
                     wdata += row_size;
                 }
             }
@@ -6735,7 +7691,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+    const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
 
     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
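Both `row_size` computations above now size the scratch rows for Q8_0 rather than for src0's own quantization type, because src1 rows are quantized to Q8_0 before the dot products. A hedged arithmetic sketch, assuming a Q8_0 block of one float scale plus 32 int8 quants (the real sizes come from `GGML_TYPE_SIZE[]` and `GGML_BLCK_SIZE[]` in ggml.c):

```c
// Hypothetical block layout used only to illustrate the row_size arithmetic.
#include <stdint.h>
#include <stdio.h>

#define QK8_0 32

typedef struct {
    float  d;          // scale
    int8_t qs[QK8_0];  // quantized values
} block_q8_0;

int main(void) {
    const int ne10 = 4096; // row length of src1, assumed a multiple of QK8_0
    const size_t row_size = ne10 * sizeof(block_q8_0) / QK8_0;
    printf("bytes per quantized row: %zu\n", row_size); // 4096/32 * 36 = 4608
    return 0;
}
```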
@@ -6783,6 +7739,7 @@ static void ggml_compute_forward_mul_mat(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -6794,10 +7751,7 @@ static void ggml_compute_forward_mul_mat(
             {
                 ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6879,13 +7833,7 @@ static void ggml_compute_forward_scale(
             {
                 ggml_compute_forward_scale_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6901,6 +7849,15 @@ static void ggml_compute_forward_cpy(
     ggml_compute_forward_dup(params, src0, dst);
 }
 
+// ggml_compute_forward_cont
+
+static void ggml_compute_forward_cont(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, src0, dst);
+}
+
 // ggml_compute_forward_reshape
 
 static void ggml_compute_forward_reshape(
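The new `GGML_OP_CONT` kernel simply reuses the dup path to materialize a contiguous copy of a strided view. A usage sketch under the assumption that the public `ggml_cont()` wrapper added in this release builds such a node; it is illustrative, not code from the gem:

```c
// Sketch: a transposed tensor is only a strided view; ggml_cont() copies it
// into contiguous memory via GGML_OP_CONT.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = { .mem_size = 16*1024*1024, .mem_buffer = NULL };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    struct ggml_tensor * at = ggml_transpose(ctx, a); // strides swapped, not contiguous
    struct ggml_tensor * ac = ggml_cont(ctx, at);     // assumed wrapper for GGML_OP_CONT

    struct ggml_cgraph gf = ggml_build_forward(ac);
    ggml_graph_compute(ctx, &gf);

    printf("transposed view contiguous: %d\n", ggml_is_contiguous(at)); // 0
    printf("after ggml_cont:            %d\n", ggml_is_contiguous(ac)); // 1

    ggml_free(ctx);
    return 0;
}
```

The backward pass for MUL_MAT further down uses exactly this pattern to make the transposed src0 contiguous before multiplying.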
@@ -7037,6 +7994,7 @@ static void ggml_compute_forward_get_rows(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -7048,10 +8006,7 @@ static void ggml_compute_forward_get_rows(
             {
                 ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7124,13 +8079,7 @@ static void ggml_compute_forward_diag_mask_inf(
             {
                 ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7218,13 +8167,7 @@ static void ggml_compute_forward_soft_max(
             {
                 ggml_compute_forward_soft_max_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7279,6 +8222,8 @@ static void ggml_compute_forward_rope_f32(
     // row index used to determine which thread to use
     int ir = 0;
 
+    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
             const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7286,11 +8231,13 @@ static void ggml_compute_forward_rope_f32(
                 if (ir++ < ir0) continue;
                 if (ir > ir1) break;
 
+                float theta = (float)p;
+
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
+                    const float cos_theta = cosf(theta);
+                    const float sin_theta = sinf(theta);
 
-                    const float cos_theta = cosf(p*theta);
-                    const float sin_theta = sinf(p*theta);
+                    theta *= theta_scale;
 
                     const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                     float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
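The rewritten loop replaces one `powf` per element pair with a single running multiply: theta starts at the position `p` and is multiplied by `theta_scale = 10000^(-2/n_dims)` each iteration, so after k steps it equals `p * 10000^(-i0/n_dims)`, the value the old code computed directly. A small standalone check of that equivalence (illustrative only):

```c
// Verifies that the incremental update theta *= theta_scale matches the
// closed-form p * 10000^(-i0/n_dims) used before this change.
#include <math.h>
#include <stdio.h>

int main(void) {
    const int   n_dims = 128;
    const int   p      = 42;   // token position
    const float theta_scale = powf(10000.0f, -2.0f/n_dims);

    float theta   = (float) p;
    float max_err = 0.0f;
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float reference = p * powf(10000.0f, ((float) -i0)/n_dims);
        const float err = fabsf(theta - reference);
        if (err > max_err) max_err = err;
        theta *= theta_scale; // one multiply instead of a powf per pair
    }
    printf("max abs drift over %d pairs: %g\n", n_dims/2, max_err);
    return 0;
}
```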
@@ -7352,6 +8299,8 @@ static void ggml_compute_forward_rope_f16(
     // row index used to determine which thread to use
     int ir = 0;
 
+    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
             const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7359,20 +8308,22 @@ static void ggml_compute_forward_rope_f16(
                 if (ir++ < ir0) continue;
                 if (ir > ir1) break;
 
+                float theta = (float)p;
+
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
+                    const float cos_theta = cosf(theta);
+                    const float sin_theta = sinf(theta);
 
-                    const float cos_theta = cosf(p*theta);
-                    const float sin_theta = sinf(p*theta);
+                    theta *= theta_scale;
 
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                     ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                    const float x0 = ggml_fp16_to_fp32(src[0]);
-                    const float x1 = ggml_fp16_to_fp32(src[1]);
+                    const float x0 = GGML_FP16_TO_FP32(src[0]);
+                    const float x1 = GGML_FP16_TO_FP32(src[1]);
 
-                    dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
-                    dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
+                    dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                    dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             }
         }
@@ -7393,12 +8344,7 @@ static void ggml_compute_forward_rope(
             {
                 ggml_compute_forward_rope_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7661,12 +8607,7 @@ static void ggml_compute_forward_conv_1d_1s(
             {
                 ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7929,12 +8870,7 @@ static void ggml_compute_forward_conv_1d_2s(
             {
                 ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8414,12 +9350,7 @@ static void ggml_compute_forward_flash_attn(
             {
                 ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8625,12 +9556,100 @@ static void ggml_compute_forward_flash_ff(
             {
                 GGML_ASSERT(false); // TODO
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        default:
             {
                 GGML_ASSERT(false);
             } break;
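The two new map ops let user code plug an arbitrary per-row float function into the graph; `ggml_compute_forward_map_unary_f32()` above calls it once per row with `nc` elements. A hedged usage sketch, assuming the `ggml_map_unary_f32()` wrapper that this release adds to the public header alongside these kernels:

```c
#include <math.h>
#include <stdio.h>
#include "ggml.h"

// matches the ggml_unary_op_f32_t shape: (n, dst_row, src_row)
static void my_softplus(const int n, float * dst, const float * src) {
    for (int i = 0; i < n; ++i) dst[i] = logf(1.0f + expf(src[i]));
}

int main(void) {
    struct ggml_init_params ip = { .mem_size = 16*1024*1024, .mem_buffer = NULL };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    float * xd = (float *) x->data;
    xd[0] = -1.0f; xd[1] = 0.0f; xd[2] = 1.0f; xd[3] = 2.0f;

    // assumed wrapper name for the GGML_OP_MAP_UNARY node added in this diff
    struct ggml_tensor * y = ggml_map_unary_f32(ctx, x, my_softplus);

    struct ggml_cgraph gf = ggml_build_forward(y);
    ggml_graph_compute(ctx, &gf);

    for (int i = 0; i < 4; ++i) printf("%f\n", ((float *) y->data)[i]);
    ggml_free(ctx);
    return 0;
}
```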
@@ -8731,6 +9750,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_cpy(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_CONT:
+            {
+                ggml_compute_forward_cont(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_RESHAPE:
             {
                 ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8782,6 +9805,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_MAP_UNARY:
+            {
+                const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_BINARY:
+            {
+                const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -8975,8 +10010,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     src1->grad =
                         ggml_add_impl(ctx,
                                 src1->grad,
-                                // TODO: fix transpose, the node will break the graph connections
-                                ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+                                ggml_mul_mat(ctx,
+                                    ggml_cont(ctx, ggml_transpose(ctx, src0)),
+                                    tensor->grad),
                                 inplace);
                 }
             } break;
@@ -8988,6 +10024,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CONT:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_RESHAPE:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -9036,6 +10076,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_MAP_UNARY:
+        case GGML_OP_MAP_BINARY:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -9126,7 +10171,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
     struct ggml_cgraph result = {
         /*.n_nodes   =*/ 0,
         /*.n_leafs   =*/ 0,
-        /*.n_threads =*/ 8,
+        /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
         /*.work_size =*/ 0,
         /*.work      =*/ NULL,
         /*.nodes     =*/ { NULL },
@@ -9354,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         struct ggml_tensor * node = cgraph->nodes[i];
 
         switch (node->op) {
+            case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
                     node->n_tasks = 1;
+
+                    size_t cur = 0;
+                    if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
+                    }
+
+                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_ADD:
                 {
                     node->n_tasks = n_threads;
+
+                    size_t cur = 0;
+
+                    if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
+                    }
+
+                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_SUB:
             case GGML_OP_MUL:
|
|
9429
10490
|
} else
|
9430
10491
|
#endif
|
9431
10492
|
{
|
9432
|
-
cur = GGML_TYPE_SIZE[
|
10493
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
|
9433
10494
|
}
|
9434
10495
|
} else {
|
9435
10496
|
GGML_ASSERT(false);
|
@@ -9441,7 +10502,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 {
                     node->n_tasks = n_threads;
                 } break;
-            case GGML_OP_CPY:
+            case GGML_OP_CONT:
             case GGML_OP_RESHAPE:
             case GGML_OP_VIEW:
             case GGML_OP_PERMUTE:
@@ -9527,6 +10588,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                     work_size = MAX(work_size, cur);
                 } break;
+            case GGML_OP_MAP_UNARY:
+            case GGML_OP_MAP_BINARY:
+                {
+                    node->n_tasks = 1;
+                } break;
             case GGML_OP_NONE:
                 {
                     node->n_tasks = 1;
@@ -9745,8 +10811,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
 
     GGML_PRINT("=== GRAPH ===\n");
 
-    GGML_PRINT_DEBUG("n_threads = %d\n",cgraph->n_threads);
-    GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+    GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
+    GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);
 
     GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
     for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
10598
11664
|
////////////////////////////////////////////////////////////////////////////////
|
10599
11665
|
|
10600
11666
|
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
10601
|
-
assert(k %
|
10602
|
-
const int nb = k /
|
11667
|
+
assert(k % QK4_0 == 0);
|
11668
|
+
const int nb = k / QK4_0;
|
10603
11669
|
|
10604
11670
|
for (int j = 0; j < n; j += k) {
|
10605
|
-
block_q4_0 * restrict y = (block_q4_0 *)dst + j/
|
11671
|
+
block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;
|
10606
11672
|
|
10607
11673
|
quantize_row_q4_0_reference(src + j, y, k);
|
10608
11674
|
|
10609
11675
|
for (int i = 0; i < nb; i++) {
|
10610
|
-
for (int l = 0; l <
|
11676
|
+
for (int l = 0; l < QK4_0; l += 2) {
|
10611
11677
|
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
10612
11678
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
10613
11679
|
|
@@ -10617,20 +11683,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
         }
     }
 
-    return (n/QK*sizeof(block_q4_0));
+    return (n/QK4_0*sizeof(block_q4_0));
 }
 
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     for (int j = 0; j < n; j += k) {
-        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
+        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;
 
         quantize_row_q4_1_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK; l += 2) {
+            for (int l = 0; l < QK4_1; l += 2) {
                 const uint8_t vi0 = y[i].qs[l/2] & 0xF;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
@@ -10640,7 +11706,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
         }
     }
 
-    return (n/QK*sizeof(block_q4_1));
+    return (n/QK4_1*sizeof(block_q4_1));
 }
 
 ////////////////////////////////////////////////////////////////////////////////
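Both quantize helpers return the number of bytes written, now expressed with per-format block constants instead of the shared `QK`. A hedged sketch of that arithmetic, assuming the common Q4_0 layout of one float scale plus 16 packed nibble bytes per 32 values:

```c
// Assumed block layout, used only to illustrate the size returned by
// ggml_quantize_q4_0(): n/QK4_0 blocks of sizeof(block_q4_0) bytes each.
#include <stdint.h>
#include <stdio.h>

#define QK4_0 32

typedef struct {
    float   d;              // scale
    uint8_t qs[QK4_0 / 2];  // 32 x 4-bit values, two per byte
} block_q4_0;

int main(void) {
    const int n = 4096; // number of floats to quantize (multiple of QK4_0)
    const size_t bytes = (size_t) n / QK4_0 * sizeof(block_q4_0);
    printf("%d floats -> %zu bytes (%.2f bits per weight)\n",
           n, bytes, 8.0 * bytes / n); // 4096 -> 2560 bytes, 5.00 bits/weight
    return 0;
}
```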
@@ -10669,6 +11735,22 @@ int ggml_cpu_has_avx512(void) {
 #endif
 }
 
+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;