llama_cpp 0.0.3 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +36 -0
- data/README.md +5 -4
- data/ext/llama_cpp/extconf.rb +38 -0
- data/ext/llama_cpp/llama_cpp.cpp +118 -2
- data/ext/llama_cpp/src/ggml.c +1740 -658
- data/ext/llama_cpp/src/ggml.h +84 -16
- data/ext/llama_cpp/src/llama.cpp +1108 -756
- data/ext/llama_cpp/src/llama.h +37 -1
- data/ext/llama_cpp/src/llama_util.h +396 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +3 -3
- data/sig/llama_cpp.rbs +6 -0
- metadata +3 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -1,4 +1,4 @@
-// Defines CLOCK_MONOTONIC
+// Defines CLOCK_MONOTONIC on Linux
 #define _GNU_SOURCE
 
 #include "ggml.h"
@@ -26,14 +26,9 @@
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 
-#if defined
+#if defined(_WIN32)
 
-#if !defined(__MINGW32__)
-#include <Windows.h>
-#else
-// ref: https://github.com/ggerganov/whisper.cpp/issues/168
 #include <windows.h>
-#endif
 
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
@@ -55,6 +50,7 @@ typedef HANDLE pthread_t;
 
 typedef DWORD thread_ret_t;
 static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+(void) unused;
 HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
 if (handle == NULL)
 {
@@ -66,6 +62,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
 }
 
 static int pthread_join(pthread_t thread, void* unused) {
+(void) unused;
 return (int) WaitForSingleObject(thread, INFINITE);
 }
 
@@ -97,17 +94,6 @@ typedef void* thread_ret_t;
 #define static_assert(cond, msg) _Static_assert(cond, msg)
 #endif
 
-#define GGML_MLOCK_SUPPORT 0
-
-#ifdef __has_include
-#if __has_include(<sys/mman.h>)
-#undef GGML_MLOCK_SUPPORT
-#define GGML_MLOCK_SUPPORT 1
-#include <sys/mman.h>
-#endif
-#endif
-
-
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
@@ -128,6 +114,23 @@ typedef void* thread_ret_t;
 #define GGML_MEM_ALIGN 16
 #endif
 
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
+#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
+#else
+inline static void* ggml_aligned_malloc(size_t size) {
+void* aligned_memory = NULL;
+int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+if (result != 0) {
+// Handle allocation failure
+return NULL;
+}
+return aligned_memory;
+}
+#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
+#define GGML_ALIGNED_FREE(ptr) free(ptr)
+#endif
+
 #define UNUSED(x) (void)(x)
 #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
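The added GGML_ALIGNED_MALLOC / GGML_ALIGNED_FREE pair maps to _aligned_malloc/_aligned_free on MSVC and MinGW, and to a posix_memalign wrapper paired with plain free elsewhere. A minimal sketch of how such a pair is used, assuming only the macros above are in scope (the 1 MiB size is an arbitrary example, not from the diff):

    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        // allocate a scratch buffer aligned to GGML_MEM_ALIGN (16 bytes)
        void * buf = GGML_ALIGNED_MALLOC(1024*1024);
        if (buf == NULL) {
            fprintf(stderr, "allocation failed\n");
            return 1;
        }
        // ... use buf as tensor/scratch memory ...
        GGML_ALIGNED_FREE(buf);   // must match the allocating macro on every platform
        return 0;
    }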
@@ -242,12 +245,12 @@ static inline float fp32_from_bits(uint32_t w) {
 }
 
 static inline uint32_t fp32_to_bits(float f) {
-
-
-
-
-
-
+union {
+float as_value;
+uint32_t as_bits;
+} fp32;
+fp32.as_value = f;
+return fp32.as_bits;
 }
 
 static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
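The rewritten fp32_to_bits reinterprets the float's bytes through a union rather than a pointer cast. A self-contained sketch of the same type-punning idiom (illustrative, not taken from the file):

    #include <stdint.h>
    #include <stdio.h>

    static inline uint32_t f32_to_bits(float f) {
        union { float v; uint32_t b; } u;
        u.v = f;
        return u.b;          // reinterpret the same 4 bytes as an integer
    }

    int main(void) {
        printf("0x%08x\n", f32_to_bits(1.0f));  // prints 0x3f800000
        return 0;
    }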
@@ -424,8 +427,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#define QK 32
-
 // AVX routines provided by GH user Const-me
 // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
 #if __AVX2__ || __AVX512F__
@@ -497,37 +498,113 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 }
 #endif
 
-
-
-
+#if __ARM_NEON
+
+#if !defined(__aarch64__)
+
+inline static uint16_t vaddvq_u8(uint8x16_t v) {
+return
+(uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
+(uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
+(uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
+(uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
+(uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
+(uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
+(uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
+(uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
+}
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+return
+(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static uint32_t vaddvq_u16(uint16x8_t v) {
+return
+(uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
+(uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
+(uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
+(uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+float vminvq_f32(float32x4_t v) {
+return
+MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+float vmaxvq_f32(float32x4_t v) {
+return
+MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+return vget_low_s8(vcombine_s8(a, b));
+}
+
+int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+return vget_high_s8(vcombine_s8(a, b));
+}
+
+uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+return vget_low_u8(vcombine_u8(a, b));
+}
+
+uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+return vget_high_u8(vcombine_u8(a, b));
+}
+
+#endif
+#endif
+
+
+#define QK4_0 32
 typedef struct {
-float d;
-uint8_t qs[
+float d; // delta
+uint8_t qs[QK4_0 / 2]; // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) +
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
-
-// blocks of QK elements
-// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+#define QK4_1 32
 typedef struct {
-float d;
-float m;
-uint8_t qs[
+float d; // delta
+float m; // min
+uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 +
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+float d; // delta
+int8_t qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
-assert(k %
-const int nb = k /
+assert(k % QK4_0 == 0);
+const int nb = k / QK4_0;
 
-uint8_t pp[
+uint8_t pp[QK4_0/2];
 
 for (int i = 0; i < nb; i++) {
 float amax = 0.0f; // absolute max
 
-for (int l = 0; l <
-const float v = x[i*
+for (int l = 0; l < QK4_0; l++) {
+const float v = x[i*QK4_0 + l];
 amax = MAX(amax, fabsf(v));
 }
 
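With the new per-format constants, the block sizes pinned by the static_asserts work out to 20 bytes per 32 weights for Q4_0 (4-byte scale plus 16 bytes of nibbles), 24 bytes for Q4_1 (scale, min, nibbles) and 36 bytes for Q8_0 (scale plus 32 signed bytes). A standalone size check that mirrors those layouts (struct definitions repeated here only for illustration):

    #include <stdint.h>
    #include <stdio.h>

    #define QK4_0 32
    #define QK4_1 32
    #define QK8_0 32

    typedef struct { float d;          uint8_t qs[QK4_0 / 2]; } block_q4_0;  // 4 + 16 = 20 bytes per 32 weights
    typedef struct { float d; float m; uint8_t qs[QK4_1 / 2]; } block_q4_1;  // 8 + 16 = 24 bytes per 32 weights
    typedef struct { float d;          int8_t  qs[QK8_0];     } block_q8_0;  // 4 + 32 = 36 bytes per 32 weights

    int main(void) {
        printf("q4_0=%zu q4_1=%zu q8_0=%zu\n",
               sizeof(block_q4_0), sizeof(block_q4_1), sizeof(block_q8_0));  // prints: q4_0=20 q4_1=24 q8_0=36
        return 0;
    }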
@@ -536,9 +613,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 
 y[i].d = d;
 
-for (int l = 0; l <
-const float v0 = x[i*
-const float v1 = x[i*
+for (int l = 0; l < QK4_0; l += 2) {
+const float v0 = x[i*QK4_0 + l + 0]*id;
+const float v1 = x[i*QK4_0 + l + 1]*id;
 
 const uint8_t vi0 = (int8_t)roundf(v0) + 8;
 const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -554,8 +631,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 }
 
 static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
-assert(k %
-const int nb = k /
+assert(k % QK4_0 == 0);
+const int nb = k / QK4_0;
 
 block_q4_0 * restrict y = vy;
 
@@ -610,10 +687,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
 for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
 
-
-const float amax = MAX(
-MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
-MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
+const float amax = vmaxvq_f32(amaxv[0]);
 
 const float d = amax / ((1 << 3) - 1);
 const float id = d ? 1.0f/d : 0.0f;
@@ -808,19 +882,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 }
 
 static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
-assert(k %
-const int nb = k /
+assert(k % QK4_1 == 0);
+const int nb = k / QK4_1;
 
 block_q4_1 * restrict y = vy;
 
-uint8_t pp[
+uint8_t pp[QK4_1/2];
 
 for (int i = 0; i < nb; i++) {
 float min = FLT_MAX;
 float max = -FLT_MAX;
 
-for (int l = 0; l <
-const float v = x[i*
+for (int l = 0; l < QK4_1; l++) {
+const float v = x[i*QK4_1 + l];
 if (v < min) min = v;
 if (v > max) max = v;
 }
@@ -831,9 +905,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
 y[i].d = d;
 y[i].m = min;
 
-for (int l = 0; l <
-const float v0 = (x[i*
-const float v1 = (x[i*
+for (int l = 0; l < QK4_1; l += 2) {
+const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
+const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
 
 const uint8_t vi0 = roundf(v0);
 const uint8_t vi1 = roundf(v1);
@@ -849,9 +923,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
 }
 
 static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
-assert(k %
+assert(k % QK4_1 == 0);
 
-const int nb = k /
+const int nb = k / QK4_1;
 
 block_q4_1 * restrict y = vy;
 
@@ -935,7 +1009,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 float32x4_t minv[8];
 float32x4_t maxv[8];
 
-for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*
+for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
 
 for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
 for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -958,7 +1032,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 
 for (int l = 0; l < 8; l++) {
 const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
-const
+const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
+const int32x4_t vi = vcvtq_s32_f32(vf);
 
 y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
 y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
@@ -970,9 +1045,160 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 #endif
 }
 
+// reference implementation for deterministic creation of model files
+static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+assert(k % QK8_0 == 0);
+const int nb = k / QK8_0;
+
+for (int i = 0; i < nb; i++) {
+float amax = 0.0f; // absolute max
+
+for (int l = 0; l < QK8_0; l++) {
+const float v = x[i*QK8_0 + l];
+amax = MAX(amax, fabsf(v));
+}
+
+const float d = amax / ((1 << 7) - 1);
+const float id = d ? 1.0f/d : 0.0f;
+
+y[i].d = d;
+
+for (int l = 0; l < QK8_0; ++l) {
+const float v = x[i*QK8_0 + l]*id;
+y[i].qs[l] = roundf(v);
+}
+}
+}
+
+static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+assert(k % QK8_0 == 0);
+const int nb = k / QK8_0;
+
+block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+for (int i = 0; i < nb; i++) {
+float32x4_t srcv [8];
+float32x4_t asrcv[8];
+float32x4_t amaxv[8];
+
+for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
+for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+const float amax = vmaxvq_f32(amaxv[0]);
+
+const float d = amax / ((1 << 7) - 1);
+const float id = d ? 1.0f/d : 0.0f;
+
+y[i].d = d;
+
+for (int l = 0; l < 8; l++) {
+const float32x4_t v = vmulq_n_f32(srcv[l], id);
+const int32x4_t vi = vcvtnq_s32_f32(v);
+
+y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+}
+}
+#elif defined(__AVX2__) || defined(__AVX__)
+for (int i = 0; i < nb; i++) {
+// Load elements into 4 AVX vectors
+__m256 v0 = _mm256_loadu_ps( x );
+__m256 v1 = _mm256_loadu_ps( x + 8 );
+__m256 v2 = _mm256_loadu_ps( x + 16 );
+__m256 v3 = _mm256_loadu_ps( x + 24 );
+x += 32;
+
+// Compute max(abs(e)) for the block
+const __m256 signBit = _mm256_set1_ps( -0.0f );
+__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+const float maxScalar = _mm_cvtss_f32( max4 );
+
+// Quantize these floats
+const float d = maxScalar / 127.f;
+y[i].d = d;
+const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+const __m256 mul = _mm256_set1_ps( id );
+
+// Apply the multiplier
+v0 = _mm256_mul_ps( v0, mul );
+v1 = _mm256_mul_ps( v1, mul );
+v2 = _mm256_mul_ps( v2, mul );
+v3 = _mm256_mul_ps( v3, mul );
+
+// Round to nearest integer
+v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+// Convert floats to integers
+__m256i i0 = _mm256_cvtps_epi32( v0 );
+__m256i i1 = _mm256_cvtps_epi32( v1 );
+__m256i i2 = _mm256_cvtps_epi32( v2 );
+__m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+// Convert int32 to int16
+i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+// Convert int16 to int8
+i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+// We got our precious signed bytes, but the order is now wrong
+// These AVX2 pack instructions process 16-byte pieces independently
+// The following instruction is fixing the order
+const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+_mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+// Since we don't have in AVX some necessary functions,
+// we split the registers in half and call AVX2 analogs from SSE
+__m128i ni0 = _mm256_castsi256_si128( i0 );
+__m128i ni1 = _mm256_extractf128_si256( i0, 1);
+__m128i ni2 = _mm256_castsi256_si128( i1 );
+__m128i ni3 = _mm256_extractf128_si256( i1, 1);
+__m128i ni4 = _mm256_castsi256_si128( i2 );
+__m128i ni5 = _mm256_extractf128_si256( i2, 1);
+__m128i ni6 = _mm256_castsi256_si128( i3 );
+__m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+// Convert int32 to int16
+ni0 = _mm_packs_epi32( ni0, ni1 );
+ni2 = _mm_packs_epi32( ni2, ni3 );
+ni4 = _mm_packs_epi32( ni4, ni5 );
+ni6 = _mm_packs_epi32( ni6, ni7 );
+// Convert int16 to int8
+ni0 = _mm_packs_epi16( ni0, ni2 );
+ni4 = _mm_packs_epi16( ni4, ni6 );
+
+_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+}
+#else
+// scalar
+quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
-assert(k %
-const int nb = k /
+assert(k % QK4_0 == 0);
+const int nb = k / QK4_0;
 
 const block_q4_0 * restrict x = vx;
 
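The new Q8_0 format stores one scale d = amax/127 plus 32 signed bytes per block, so dequantization is just d * qs[l]. A scalar sketch of the quantize/dequantize round trip, mirroring the reference routine above but standalone and illustrative only:

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    #define QK8_0 32

    int main(void) {
        float x[QK8_0], y[QK8_0];
        int8_t qs[QK8_0];
        for (int l = 0; l < QK8_0; l++) x[l] = sinf(l * 0.1f);            // arbitrary test data

        float amax = 0.0f;                                                // absolute max of the block
        for (int l = 0; l < QK8_0; l++) amax = fmaxf(amax, fabsf(x[l]));

        const float d  = amax / 127.0f;                                   // scale (delta)
        const float id = d ? 1.0f/d : 0.0f;
        for (int l = 0; l < QK8_0; l++) qs[l] = (int8_t) roundf(x[l]*id); // quantize
        for (int l = 0; l < QK8_0; l++) y[l]  = d * qs[l];                // dequantize

        printf("x[3]=%f y[3]=%f\n", x[3], y[3]);                          // values agree to within ~d/2
        return 0;
    }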
@@ -983,7 +1209,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
 const uint8_t * restrict pp = x[i].qs;
 
-for (int l = 0; l <
+for (int l = 0; l < QK4_0; l += 32) {
 // Load 32x4-bit integers into 32x8-bit integers
 __m256i vx8 = bytesFromNibbles(pp+l/2);
 
@@ -1005,7 +1231,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 // Scale and store
 for (int j = 0; j < 4; j++) {
 const __m256 result = _mm256_mul_ps(vf[j], d_v);
-_mm256_storeu_ps(y + i *
+_mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
 }
 }
 }
@@ -1015,7 +1241,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
 const uint8_t * restrict pp = x[i].qs;
 
-for (int l = 0; l <
+for (int l = 0; l < QK4_0; l += 16) {
 // Load 16x4-bit integers into 8x8-bit integers
 const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1054,10 +1280,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 const float32x4_t r3 = vmulq_f32(vf_3, vd);
 
 // Store
-vst1q_f32(y + i*
-vst1q_f32(y + i*
-vst1q_f32(y + i*
-vst1q_f32(y + i*
+vst1q_f32(y + i*QK4_0 + l + 0, r0);
+vst1q_f32(y + i*QK4_0 + l + 4, r1);
+vst1q_f32(y + i*QK4_0 + l + 8, r2);
+vst1q_f32(y + i*QK4_0 + l + 12, r3);
 }
 }
 #else
@@ -1067,7 +1293,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
 const uint8_t * restrict pp = x[i].qs;
 
-for (int l = 0; l <
+for (int l = 0; l < QK4_0; l += 2) {
 const uint8_t vi = pp[l/2];
 
 const int8_t vi0 = vi & 0xf;
@@ -1078,19 +1304,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
 //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
 
-y[i*
-y[i*
+y[i*QK4_0 + l + 0] = v0;
+y[i*QK4_0 + l + 1] = v1;
 
-assert(!isnan(y[i*
-assert(!isnan(y[i*
+assert(!isnan(y[i*QK4_0 + l + 0]));
+assert(!isnan(y[i*QK4_0 + l + 1]));
 }
 }
 #endif
 }
 
 static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
-assert(k %
-const int nb = k /
+assert(k % QK4_1 == 0);
+const int nb = k / QK4_1;
 
 const block_q4_1 * restrict x = vx;
 
@@ -1101,7 +1327,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
 const uint8_t * restrict pp = x[i].qs;
 
-for (int l = 0; l <
+for (int l = 0; l < QK4_1; l += 32) {
 // Load 32x4-bit integers into 32x8-bit integers
 __m256i vx8 = bytesFromNibbles(pp+l/2);
 
@@ -1120,7 +1346,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 // Scale, add m and store
 for (int j = 0; j < 4; j++) {
 const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
-_mm256_storeu_ps(y + i *
+_mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
 }
 }
 }
@@ -1131,7 +1357,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
 const uint8_t * restrict pp = x[i].qs;
 
-for (int l = 0; l <
+for (int l = 0; l < QK4_1; l += 16) {
 // Load 16x4-bit integers into 8x8-bit integers
 const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1162,10 +1388,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
 
 // Store
-vst1q_f32(y + i*
-vst1q_f32(y + i*
-vst1q_f32(y + i*
-vst1q_f32(y + i*
+vst1q_f32(y + i*QK4_1 + l + 0, r0);
+vst1q_f32(y + i*QK4_1 + l + 4, r1);
+vst1q_f32(y + i*QK4_1 + l + 8, r2);
+vst1q_f32(y + i*QK4_1 + l + 12, r3);
 }
 }
 #else
@@ -1175,7 +1401,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
 const uint8_t * restrict pp = x[i].qs;
 
-for (int l = 0; l <
+for (int l = 0; l < QK4_1; l += 2) {
 const uint8_t vi = pp[l/2];
 
 const int8_t vi0 = vi & 0xf;
@@ -1184,16 +1410,44 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 const float v0 = vi0*d + m;
 const float v1 = vi1*d + m;
 
-y[i*
-y[i*
+y[i*QK4_1 + l + 0] = v0;
+y[i*QK4_1 + l + 1] = v1;
 
-assert(!isnan(y[i*
-assert(!isnan(y[i*
+assert(!isnan(y[i*QK4_1 + l + 0]));
+assert(!isnan(y[i*QK4_1 + l + 1]));
 }
 }
 #endif
 }
 
+static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+[GGML_TYPE_Q4_0] = {
+.dequantize_row_q = dequantize_row_q4_0,
+.quantize_row_q = quantize_row_q4_0,
+.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+.quantize_row_q_dot = quantize_row_q8_0,
+.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
+},
+[GGML_TYPE_Q4_1] = {
+.dequantize_row_q = dequantize_row_q4_1,
+.quantize_row_q = quantize_row_q4_1,
+.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+.quantize_row_q_dot = quantize_row_q4_1,
+.vec_dot_q = ggml_vec_dot_q4_1,
+},
+// TODO: GGML_TYPE_Q8_0
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+GGML_ASSERT(i < GGML_TYPE_COUNT);
+return quantize_fns[i];
+}
+
+
 //
 // simd mappings
 //
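The new quantize_fns table groups the per-type quantize/dequantize/dot routines so callers can dispatch by tensor type, and ggml_internal_get_quantize_fn exposes it for tests. A usage sketch, assuming quantize_fns_t and the getter are declared through ggml.h in this version (the roundtrip_q4_0 helper and its scratch buffer are hypothetical):

    #include "ggml.h"   // assumed to declare quantize_fns_t and ggml_internal_get_quantize_fn()

    // Quantize one row of k floats with the table's Q4_0 routine, then dequantize it back.
    // scratch must hold (k/QK4_0) block_q4_0 blocks.
    static void roundtrip_q4_0(const float * src, float * dst, void * scratch, int k) {
        quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
        fns.quantize_row_q(src, scratch, k);      // float -> Q4_0 blocks
        fns.dequantize_row_q(scratch, dst, k);    // Q4_0 blocks -> float
    }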
@@ -1226,15 +1480,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
 #define GGML_F32x4_ADD vaddq_f32
 #define GGML_F32x4_MUL vmulq_f32
-#
-#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#else
-#define GGML_F32x4_REDUCE_ONE(x) \
-(vgetq_lane_f32(x, 0) + \
-vgetq_lane_f32(x, 1) + \
-vgetq_lane_f32(x, 2) + \
-vgetq_lane_f32(x, 3))
-#endif
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
 #define GGML_F32x4_REDUCE(res, x) \
 { \
 for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1758,34 +2004,188 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
 *s = sumf;
 }
 
-#if __AVX512F__ &&
-static inline
+#if __AVX512F__ && QK4_0 == 32
+static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
+// The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// | :. =_ () [] <> () Zz Yy|
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+// |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
+// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+//
+// Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
+// We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
+// Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
+// Bytes 40..63 are masked when loading the data, so they are zeroed out.
+#ifdef __AVX512VBMI__
+const __m512i byte_perm = _mm512_set_epi8(
+39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
+31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
+19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
+11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
+);
+const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
+// After applying VPERMB, `permuted` looks like this:
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+#else
+const __m512i word_perm = _mm512_set_epi16(
+19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
+9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
+);
+const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
+// This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
+// VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
+// Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
+#endif
+
+// Shift every odd-numbered 16-bit group to the right by 4 bits.
+const __mmask32 shift_mask = 0xaaaaaaaa;
+const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
+// After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+
+// Now we just need to zero out the higher nibble in each byte, and we're done.
+const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
+return _mm512_and_si512( low_nibble_mask, shifted );
+// The final result looks like this:
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+// | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
+// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+}
+
+static inline __m512 dot_q4_0_twoblocks_avx512(
 __m512 acc,
 const block_q4_0 * restrict x,
 const block_q4_0 * restrict y,
 int i
 ) {
-//
-
-
-
-
-
-
-
-
-
-
-//
-
-
-//
-
+// A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
+// can potentially be unaddressable, so we make sure to mask them out before the load, even though
+// we don't use them at all. This might hurt the performance slightly, since the compiler is forced
+// to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
+const __mmask8 load_mask = 0x1f;
+const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
+const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
+
+// We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
+// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+// blocks_0_float
+// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+// | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
+// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+// blocks_1_float
+// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+// | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
+// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
+const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
+// We absolutely shouldn't touch the floats marked with `xx`: they contain some
+// random data, which might very well underflow. At least on Intel, this leads
+// to a huge penalty that can't be ignored (easily 100x or more) unless you
+// compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
+// (and ggml can't assume that you do)...
+const __mmask16 scale_mul_mask = 0x21;
+#ifdef __clang__
+// ...however, clang decides to optimize the multiplication mask away:
+// https://godbolt.org/z/P8PqdsfvW
+// gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
+__m512i scales;
+__asm__(
+"vmulps %1, %2, %0%{%3%}"
+: "=v" ( scales )
+: "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
+);
+#else
+const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
+#endif
+const __m512i scale_perm = _mm512_set_epi32(
+5, 5, 5, 5, 5, 5, 5, 5,
+0, 0, 0, 0, 0, 0, 0, 0
+);
+const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
+// After VMULPS and VPERMPS, `permuted_scales` looks like this:
+// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+// | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
+// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+
+const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
+const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
+
+// Now we want to compute dot products of 4-element byte vectors and store them in
+// 32-bit integers. That is (only one 4-element vector is shown for clarity):
+// +----+----+----+----+
+// ... | 03 | 02 | 01 | 00 |
+// +----+----+----+----+
+// bytes_0
+// +----+----+----+----+
+// ... | D | C | B | A |
+// +----+----+----+----+
+// bytes_1
+// +----+----+----+----+
+// ... | H | G | F | E |
+// +----+----+----+----+
+// final_res_int
+// +----+----+----+----+
+// ... | A*E+B*F+C*G+D*H |
+// +----+----+----+----+
+const __m512i plus_8 = _mm512_set1_epi8( 8 );
+const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
+
+#ifdef __AVX512VNNI__
+// We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
+// the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
+// from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
+// we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
+// which means we only need 2 instructions.
+const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
+const __m512i minus_8 = _mm512_set1_epi8( -8 );
+const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
+const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
+#else
+// As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
+// It has the same catch as VPDPBUSDS: the left operand should be unsigned.
+// This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
+// ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
+const __m512i one = _mm512_set1_epi16( 1 );
+const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
+const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
+const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
+const __m512i final_res_int = _mm512_madd_epi16( diff, one );
+#endif
 
-//
-__m512
-
-return _mm512_fmadd_ps( d, p, acc );
+// Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
+const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
+return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
 }
 #endif
 
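The VNNI comment above relies on the identity (a - 8)*(b - 8) = a*(b - 8) - 8*b + 64, which lets an unsigned-by-signed multiply (what VPDPBUSDS provides) stand in for the signed-by-signed product of two Q4_0 quants; the 4 * 64 accumulator initializer covers the four byte pairs folded into each 32-bit lane. A scalar check of the per-pair identity (illustrative only):

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        // a and b range over raw 4-bit quants (0..15); the stored value is quant - 8
        for (uint8_t a = 0; a < 16; a++) {
            for (uint8_t b = 0; b < 16; b++) {
                int lhs = ((int)a - 8) * ((int)b - 8);
                int rhs = (int)a * ((int)b - 8) - 8*(int)b + 64;
                assert(lhs == rhs);
            }
        }
        return 0;
    }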
@@ -1826,9 +2226,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 }
 
 static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-const int nb = n /
+const int nb = n / QK4_0;
 
-assert(n %
+assert(n % QK4_0 == 0);
 assert(nb % 2 == 0);
 
 const block_q4_0 * restrict x = vx;
@@ -1857,55 +2257,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 // 4-bit -> 8-bit
 const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
 const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
-
 const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
 const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
 
 const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
 const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
-
 const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
 
 // sub 8
 const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
 const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
-
 const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
 const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
 
 const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
 const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
-
 const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
 const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-// dot product into
+// dot product into int32x4_t
 int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
 int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
 
 p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
 p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
 
-
-
-sum0 += x0->d * y0->d * vaddvq_s32(p_0);
-sum1 += x1->d * y1->d * vaddvq_s32(p_1);
+sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+sum1 += x1->d*y1->d*vaddvq_s32(p_1);
 #else
-
-sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
-#endif
-#else
-const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
 const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-
 const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
 const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
 
 const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
 const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-
 const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
 const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
 
@@ -1918,14 +2306,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
 const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
 
-
-
-sum0 += x0->d * y0->d * vaddvq_s16(p_0);
-sum1 += x1->d * y1->d * vaddvq_s16(p_1);
-#else
-sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
-sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
-#endif
+sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+sum1 += x1->d*y1->d*vaddvq_s16(p_1);
 #endif
 }
 
@@ -1935,25 +2317,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 __m512 acc0 = _mm512_setzero_ps();
 __m512 acc1 = _mm512_setzero_ps();
 
-const int superblock_size =
+const int superblock_size = 16;
+
 const int superblock_count = nb / superblock_size;
 
 for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
 int i = superblock_ix * superblock_size;
 
-acc0 =
-acc1 =
-acc0 =
-acc1 =
-acc0 =
-acc1 =
-acc0 =
-acc1 =
+acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
+acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
+acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
+acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
+acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
+acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
+acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
+acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
 }
 
 // Remainders
-for (int i = superblock_count * superblock_size; i < nb;
-acc0 =
+for (int i = superblock_count * superblock_size; i < nb; i += 2) {
+acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
 }
 
 // Horizontal sum of all lanes of the accumulator
@@ -1962,7 +2345,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 // Initialize accumulator with zeros
 __m256 acc = _mm256_setzero_ps();
 
-/* Prepare the constants we will need during execution */
+/* Prepare the constants we will need during execution */
 const __m256i lowMask = _mm256_set1_epi8( 0xF );
 const __m256i offset_8 = _mm256_set1_epi16( 8 );
 
@@ -1972,61 +2355,59 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 
 // Main loop
 for (int i = 0; i < nb; i+=UNROLL_COUNT) {
-
-// This loop will be unrolled by the compiler
+// This loop will be unrolled by the compiler
 for (int u=0;u<UNROLL_COUNT;u++) {
-/* Compute combined scale for the block */
-const __m256 scale = _mm256_mul_ps(
-_mm256_broadcast_ss( &x[i+u].d ),
-_mm256_broadcast_ss( &y[i+u].d ) );
-
-/* get input from x
-Input: 32 Nibbles (16 bytes) at *x[i+u]
-Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
-
-/* Load 16 bytes from memory */
-const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
-/* Expand bytes into uint16_t values */
-const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
+/* Compute combined scale for the block */
+const __m256 scale = _mm256_mul_ps(
+_mm256_broadcast_ss( &x[i+u].d ),
+_mm256_broadcast_ss( &y[i+u].d ) );
+
+/* get input from x
+Input: 32 Nibbles (16 bytes) at *x[i+u]
+Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
+
+/* Load 16 bytes from memory */
+const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
+/* Expand bytes into uint16_t values */
+const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
 /* Unpack values into individual bytes */
 __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
 const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
-__m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
+__m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
 /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
-x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
+x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
+x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
 
-/* get input from y
-Input: 32 Nibbles (16 bytes) at *y[i+u]
-Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
+/* get input from y
+Input: 32 Nibbles (16 bytes) at *y[i+u]
+Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
 
-/* Load 16 bytes from memory */
-const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
-/* Expand bytes into uint16_t values */
-const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
+/* Load 16 bytes from memory */
+const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
+/* Expand bytes into uint16_t values */
+const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
 /* Unpack values into individual bytes */
-const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
-__m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
-__m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
+const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
+__m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
+__m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
 /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
-y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
+y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
+y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
 
-/* Compute products of int16_t integers, add pairwise, store as int32_t */
-__m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
-__m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
+/* Compute products of int16_t integers, add pairwise, store as int32_t */
+__m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
+__m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
 
-/* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
-__m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
+/* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
+__m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
 
-/* Convert to vectore of 8 int32_t to 8 floats */
-__m256 q = _mm256_cvtepi32_ps( xy_q );
+/* Convert to vectore of 8 int32_t to 8 floats */
+__m256 q = _mm256_cvtepi32_ps( xy_q );
 
-/* Multiply q with scale and accumulate */
-acc = _mm256_fmadd_ps( scale, q, acc );
+/* Multiply q with scale and accumulate */
+acc = _mm256_fmadd_ps( scale, q, acc );
 }
-
-}
+}
 
 // Return horizontal sum of the acc vector
 __m128 res = _mm256_extractf128_ps( acc, 1 );
@@ -2087,18 +2468,18 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 float sum1 = 0.0f;
 
 for (int i = 0; i < nb; i += 2) {
-const block_q4_0 * restrict x0 = &
-const block_q4_0 * restrict y0 = &
-const block_q4_0 * restrict x1 = &
-const block_q4_0 * restrict y1 = &
+const block_q4_0 * restrict x0 = &x[i + 0];
+const block_q4_0 * restrict y0 = &y[i + 0];
+const block_q4_0 * restrict x1 = &x[i + 1];
+const block_q4_0 * restrict y1 = &y[i + 1];
 
 const v128_t m4b = wasm_u8x16_splat(0xf);
 const v128_t s8b = wasm_i8x16_splat(0x8);
 
-const v128_t v0_0 = wasm_v128_load(x0
-const v128_t v0_1 = wasm_v128_load(y0
-const v128_t v1_0 = wasm_v128_load(x1
-const v128_t v1_1 = wasm_v128_load(y1
+const v128_t v0_0 = wasm_v128_load(x0->qs);
+const v128_t v0_1 = wasm_v128_load(y0->qs);
+const v128_t v1_0 = wasm_v128_load(x1->qs);
+const v128_t v1_1 = wasm_v128_load(y1->qs);
 
 // 4-bit -> 8-bit
 const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@@ -2170,18 +2551,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 const uint8_t * restrict p0 = x[i].qs;
 const uint8_t * restrict p1 = y[i].qs;
 
-
+int sumi = 0;
+for (int j = 0; j < QK4_0/2; j++) {
 const uint8_t v0 = p0[j];
 const uint8_t v1 = p1[j];
 
-const
-const
+const int i0 = (v0 & 0xf) - 8;
+const int i1 = (v0 >> 4) - 8;
 
-const
-const
+const int i2 = (v1 & 0xf) - 8;
+const int i3 = (v1 >> 4) - 8;
 
-
+sumi += i0*i2 + i1*i3;
 }
+sumf += d0 * d1 * sumi;
 }
 #endif
 
@@ -2189,7 +2572,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 }
 
 static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-const int nb = n /
+const int nb = n / QK4_1;
 
 const block_q4_1 * restrict x = vx;
 const block_q4_1 * restrict y = vy;
@@ -2266,46 +2649,81 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|
2266
2649
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2267
2650
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2268
2651
|
|
2269
|
-
sumf = _mm_cvtss_f32( res ) + acc_offset *
|
2652
|
+
sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
|
2270
2653
|
#elif defined(__ARM_NEON)
|
2271
2654
|
float sum00 = 0.0f;
|
2272
2655
|
float sum01 = 0.0f;
|
2273
2656
|
float sum10 = 0.0f;
|
2274
2657
|
float sum11 = 0.0f;
|
2275
2658
|
|
2276
|
-
for (int i = 0; i < nb;
|
2659
|
+
for (int i = 0; i < nb; i += 2) {
|
2277
2660
|
const block_q4_1 * restrict x0 = &x[i + 0];
|
2278
2661
|
const block_q4_1 * restrict y0 = &y[i + 0];
|
2662
|
+
const block_q4_1 * restrict x1 = &x[i + 1];
|
2663
|
+
const block_q4_1 * restrict y1 = &y[i + 1];
|
2279
2664
|
|
2280
2665
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2281
2666
|
|
2282
2667
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2283
2668
|
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
|
2669
|
+
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2670
|
+
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
|
2284
2671
|
|
2285
|
-
//
|
2672
|
+
// 4-bit -> 8-bit
|
2286
2673
|
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
|
2287
2674
|
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
|
2288
|
-
|
2289
2675
|
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
|
2290
2676
|
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
|
2291
2677
|
|
2292
|
-
|
2678
|
+
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
|
2679
|
+
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
|
2680
|
+
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
|
2681
|
+
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
|
2682
|
+
|
2683
|
+
sum00 += x0->m*y0->m;
|
2684
|
+
sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
|
2685
|
+
sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
|
2686
|
+
|
2687
|
+
sum00 += x1->m*y1->m;
|
2688
|
+
sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
|
2689
|
+
sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
|
2690
|
+
|
2691
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2692
|
+
// dot product into uint32x4_t
|
2693
|
+
uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
|
2694
|
+
uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
|
2695
|
+
|
2696
|
+
p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
|
2697
|
+
p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
|
2698
|
+
|
2699
|
+
sum11 += x0->d*y0->d*vaddvq_u32(p_0);
|
2700
|
+
sum11 += x1->d*y1->d*vaddvq_u32(p_1);
|
2701
|
+
#else
|
2293
2702
|
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
|
2294
2703
|
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
|
2295
|
-
|
2296
2704
|
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
|
2297
2705
|
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
|
2298
2706
|
|
2299
|
-
const uint16x8_t
|
2300
|
-
const uint16x8_t
|
2707
|
+
const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
|
2708
|
+
const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
|
2709
|
+
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
|
2710
|
+
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
|
2301
2711
|
|
2302
|
-
|
2303
|
-
|
2304
|
-
|
2305
|
-
|
2712
|
+
const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
|
2713
|
+
const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
|
2714
|
+
|
2715
|
+
const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
|
2716
|
+
const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
|
2717
|
+
|
2718
|
+
const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
|
2719
|
+
const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
|
2720
|
+
|
2721
|
+
sum11 += x0->d*y0->d*vaddvq_u16(p_0);
|
2722
|
+
sum11 += x1->d*y1->d*vaddvq_u16(p_1);
|
2723
|
+
#endif
|
2306
2724
|
}
|
2307
2725
|
|
2308
|
-
sumf =
|
2726
|
+
sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
|
2309
2727
|
#else
|
2310
2728
|
// scalar
|
2311
2729
|
for (int i = 0; i < nb; i++) {
|
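
The four per-block accumulators in the Q4_1 NEON path above (sum00, sum01, sum10, sum11) follow from expanding the product of two Q4_1 blocks, where each element is reconstructed as x_i = d_x*q_{x,i} + m_x with q in [0, 15]:

$$ \sum_{i=0}^{QK-1} x_i y_i \;=\; QK\,m_x m_y \;+\; m_y d_x \sum_i q_{x,i} \;+\; m_x d_y \sum_i q_{y,i} \;+\; d_x d_y \sum_i q_{x,i}\,q_{y,i} $$

These four terms are what sum00, sum01, sum10 and sum11 collect per block, and they are recombined at the end as sumf = QK4_1*sum00 + sum01 + sum10 + sum11.
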
@@ -2318,7 +2736,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|
2318
2736
|
const uint8_t * restrict p0 = x[i].qs;
|
2319
2737
|
const uint8_t * restrict p1 = y[i].qs;
|
2320
2738
|
|
2321
|
-
for (int j = 0; j <
|
2739
|
+
for (int j = 0; j < QK4_1/2; j++) {
|
2322
2740
|
const uint8_t v0 = p0[j];
|
2323
2741
|
const uint8_t v1 = p1[j];
|
2324
2742
|
|
@@ -2336,21 +2754,224 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|
2336
2754
|
*s = sumf;
|
2337
2755
|
}
|
2338
2756
|
|
2339
|
-
|
2340
|
-
|
2341
|
-
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
2342
|
-
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
2757
|
+
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2758
|
+
const int nb = n / QK8_0;
|
2343
2759
|
|
2344
|
-
|
2760
|
+
assert(n % QK8_0 == 0);
|
2761
|
+
assert(nb % 2 == 0);
|
2345
2762
|
|
2346
|
-
|
2347
|
-
|
2348
|
-
}
|
2763
|
+
const block_q4_0 * restrict x = vx;
|
2764
|
+
const block_q8_0 * restrict y = vy;
|
2349
2765
|
|
2350
|
-
|
2351
|
-
const int np = (n & ~(GGML_F16_STEP - 1));
|
2766
|
+
float sumf = 0.0;
|
2352
2767
|
|
2353
|
-
|
2768
|
+
#if defined(__ARM_NEON)
|
2769
|
+
float sum0 = 0.0f;
|
2770
|
+
float sum1 = 0.0f;
|
2771
|
+
|
2772
|
+
for (int i = 0; i < nb; i += 2) {
|
2773
|
+
const block_q4_0 * restrict x0 = &x[i + 0];
|
2774
|
+
const block_q4_0 * restrict x1 = &x[i + 1];
|
2775
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2776
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
2777
|
+
|
2778
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2779
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
2780
|
+
|
2781
|
+
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2782
|
+
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2783
|
+
|
2784
|
+
// 4-bit -> 8-bit
|
2785
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2786
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2787
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2788
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2789
|
+
|
2790
|
+
// sub 8
|
2791
|
+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
2792
|
+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
2793
|
+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2794
|
+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
2795
|
+
|
2796
|
+
// load y
|
2797
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2798
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2799
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2800
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2801
|
+
|
2802
|
+
// interleave
|
2803
|
+
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2804
|
+
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2805
|
+
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2806
|
+
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2807
|
+
|
2808
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2809
|
+
// dot product into int32x4_t
|
2810
|
+
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
|
2811
|
+
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
|
2812
|
+
|
2813
|
+
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
|
2814
|
+
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
|
2815
|
+
|
2816
|
+
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
|
2817
|
+
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
|
2818
|
+
#else
|
2819
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
2820
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
2821
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
|
2822
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
|
2823
|
+
|
2824
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
|
2825
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
|
2826
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
2827
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
2828
|
+
|
2829
|
+
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
|
2830
|
+
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
|
2831
|
+
|
2832
|
+
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
|
2833
|
+
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
|
2834
|
+
|
2835
|
+
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
|
2836
|
+
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
|
2837
|
+
|
2838
|
+
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
|
2839
|
+
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
|
2840
|
+
#endif
|
2841
|
+
}
|
2842
|
+
|
2843
|
+
sumf = sum0 + sum1;
|
2844
|
+
#elif defined(__AVX2__)
|
2845
|
+
// Initialize accumulator with zeros
|
2846
|
+
__m256 acc = _mm256_setzero_ps();
|
2847
|
+
|
2848
|
+
// Main loop
|
2849
|
+
for (int i = 0; i < nb; ++i) {
|
2850
|
+
/* Compute combined scale for the block */
|
2851
|
+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
2852
|
+
|
2853
|
+
__m256i bx = bytesFromNibbles(x[i].qs);
|
2854
|
+
|
2855
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2856
|
+
const __m256i off = _mm256_set1_epi8( 8 );
|
2857
|
+
bx = _mm256_sub_epi8( bx, off );
|
2858
|
+
|
2859
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2860
|
+
|
2861
|
+
// Get absolute values of x vectors
|
2862
|
+
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
2863
|
+
|
2864
|
+
// Sign the values of the y vectors
|
2865
|
+
const __m256i sy = _mm256_sign_epi8(by, bx);
|
2866
|
+
|
2867
|
+
// Perform multiplication and create 16-bit values
|
2868
|
+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
2869
|
+
|
2870
|
+
const __m256i ones = _mm256_set1_epi16(1);
|
2871
|
+
__m256i xy_q = _mm256_madd_epi16(ones, dot);
|
2872
|
+
|
2873
|
+
/* Convert the vector of 8 int32_t to 8 floats */
|
2874
|
+
__m256 q = _mm256_cvtepi32_ps( xy_q );
|
2875
|
+
|
2876
|
+
/* Multiply q with scale and accumulate */
|
2877
|
+
acc = _mm256_fmadd_ps( d, q, acc );
|
2878
|
+
}
|
2879
|
+
|
2880
|
+
// Return horizontal sum of the acc vector
|
2881
|
+
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2882
|
+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2883
|
+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2884
|
+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2885
|
+
|
2886
|
+
sumf = _mm_cvtss_f32( res );
|
2887
|
+
#elif defined(__AVX__)
|
2888
|
+
// Initialize accumulator with zeros
|
2889
|
+
__m256 acc = _mm256_setzero_ps();
|
2890
|
+
|
2891
|
+
// Main loop
|
2892
|
+
for (int i = 0; i < nb; ++i) {
|
2893
|
+
// Compute combined scale for the block
|
2894
|
+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
2895
|
+
|
2896
|
+
__m128i i32[2];
|
2897
|
+
for (int j = 0; j < 2; ++j) {
|
2898
|
+
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
|
2899
|
+
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
|
2900
|
+
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
|
2901
|
+
|
2902
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2903
|
+
const __m128i off = _mm_set1_epi8( 8 );
|
2904
|
+
bx = _mm_sub_epi8( bx, off );
|
2905
|
+
|
2906
|
+
// Get absolute values of x vectors
|
2907
|
+
const __m128i ax = _mm_sign_epi8(bx, bx);
|
2908
|
+
|
2909
|
+
// Sign the values of the y vectors
|
2910
|
+
const __m128i sy = _mm_sign_epi8(by, bx);
|
2911
|
+
|
2912
|
+
// Perform multiplication and create 16-bit values
|
2913
|
+
const __m128i dot = _mm_maddubs_epi16(ax, sy);
|
2914
|
+
|
2915
|
+
const __m128i ones = _mm_set1_epi16(1);
|
2916
|
+
i32[j] = _mm_madd_epi16(ones, dot);
|
2917
|
+
}
|
2918
|
+
|
2919
|
+
// Convert int32_t to float
|
2920
|
+
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
|
2921
|
+
// Apply the scale, and accumulate
|
2922
|
+
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
|
2923
|
+
}
|
2924
|
+
|
2925
|
+
// Return horizontal sum of the acc vector
|
2926
|
+
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2927
|
+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2928
|
+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2929
|
+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2930
|
+
|
2931
|
+
sumf = _mm_cvtss_f32( res );
|
2932
|
+
#else
|
2933
|
+
// scalar
|
2934
|
+
for (int i = 0; i < nb; i++) {
|
2935
|
+
const float d0 = x[i].d;
|
2936
|
+
const float d1 = y[i].d;
|
2937
|
+
|
2938
|
+
const uint8_t * restrict p0 = x[i].qs;
|
2939
|
+
const int8_t * restrict p1 = y[i].qs;
|
2940
|
+
|
2941
|
+
int sumi = 0;
|
2942
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
2943
|
+
const uint8_t v0 = p0[j];
|
2944
|
+
|
2945
|
+
const int i0 = (int8_t) (v0 & 0xf) - 8;
|
2946
|
+
const int i1 = (int8_t) (v0 >> 4) - 8;
|
2947
|
+
|
2948
|
+
const int i2 = p1[2*j + 0];
|
2949
|
+
const int i3 = p1[2*j + 1];
|
2950
|
+
|
2951
|
+
sumi += i0*i2 + i1*i3;
|
2952
|
+
}
|
2953
|
+
sumf += d0*d1*sumi;
|
2954
|
+
}
|
2955
|
+
#endif
|
2956
|
+
|
2957
|
+
*s = sumf;
|
2958
|
+
}
|
2959
|
+
|
2960
|
+
// compute GGML_VEC_DOT_UNROLL dot products at once
|
2961
|
+
// xs - x row stride in bytes
|
2962
|
+
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
2963
|
+
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
|
2964
|
+
|
2965
|
+
ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
|
2966
|
+
|
2967
|
+
for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
|
2968
|
+
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
|
2969
|
+
}
|
2970
|
+
|
2971
|
+
#if defined(GGML_SIMD)
|
2972
|
+
const int np = (n & ~(GGML_F16_STEP - 1));
|
2973
|
+
|
2974
|
+
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
|
2354
2975
|
|
2355
2976
|
GGML_F16_VEC ax[GGML_F16_ARR];
|
2356
2977
|
GGML_F16_VEC ay[GGML_F16_ARR];
|
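
For orientation, the new ggml_vec_dot_q4_0_q8_0 kernel above assumes block layouts along these lines, inferred from the field accesses in the diff (QK4_0 and QK8_0 are both 32 here, but treat the definitions in ggml.c as authoritative rather than this sketch):

    #include <stdint.h>

    #define QK4_0 32
    #define QK8_0 32

    /* 32 weights quantized to 4 bits, one shared scale */
    typedef struct {
        float   d;              /* scale */
        uint8_t qs[QK4_0 / 2];  /* two 4-bit quants per byte */
    } block_q4_0_sketch;

    /* 32 values quantized to 8 bits, one shared scale */
    typedef struct {
        float  d;               /* scale */
        int8_t qs[QK8_0];       /* signed 8-bit quants */
    } block_q8_0_sketch;
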
@@ -2578,29 +3199,41 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
|
|
2578
3199
|
//
|
2579
3200
|
|
2580
3201
|
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
|
2581
|
-
|
2582
|
-
|
2583
|
-
|
2584
|
-
|
2585
|
-
|
2586
|
-
1,
|
2587
|
-
1,
|
3202
|
+
[GGML_TYPE_F32] = 1,
|
3203
|
+
[GGML_TYPE_F16] = 1,
|
3204
|
+
[GGML_TYPE_Q4_0] = QK4_0,
|
3205
|
+
[GGML_TYPE_Q4_1] = QK4_1,
|
3206
|
+
[GGML_TYPE_Q8_0] = QK8_0,
|
3207
|
+
[GGML_TYPE_I8] = 1,
|
3208
|
+
[GGML_TYPE_I16] = 1,
|
3209
|
+
[GGML_TYPE_I32] = 1,
|
2588
3210
|
};
|
2589
|
-
|
2590
|
-
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
|
3211
|
+
static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
|
2591
3212
|
|
2592
3213
|
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
|
2593
|
-
sizeof(
|
2594
|
-
sizeof(
|
2595
|
-
sizeof(
|
2596
|
-
sizeof(
|
2597
|
-
sizeof(
|
2598
|
-
sizeof(
|
2599
|
-
sizeof(
|
3214
|
+
[GGML_TYPE_F32] = sizeof(float),
|
3215
|
+
[GGML_TYPE_F16] = sizeof(ggml_fp16_t),
|
3216
|
+
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
|
3217
|
+
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
|
3218
|
+
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
|
3219
|
+
[GGML_TYPE_I8] = sizeof(int8_t),
|
3220
|
+
[GGML_TYPE_I16] = sizeof(int16_t),
|
3221
|
+
[GGML_TYPE_I32] = sizeof(int32_t),
|
2600
3222
|
};
|
2601
|
-
|
2602
|
-
|
2603
|
-
|
3223
|
+
static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
|
3224
|
+
|
3225
|
+
|
3226
|
+
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
3227
|
+
[GGML_TYPE_F32] = "f32",
|
3228
|
+
[GGML_TYPE_F16] = "f16",
|
3229
|
+
[GGML_TYPE_Q4_0] = "q4_0",
|
3230
|
+
[GGML_TYPE_Q4_1] = "q4_1",
|
3231
|
+
[GGML_TYPE_Q8_0] = "q8_0",
|
3232
|
+
[GGML_TYPE_I8] = "i8",
|
3233
|
+
[GGML_TYPE_I16] = "i16",
|
3234
|
+
[GGML_TYPE_I32] = "i32",
|
3235
|
+
};
|
3236
|
+
static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
|
2604
3237
|
|
2605
3238
|
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
2606
3239
|
"NONE",
|
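
The tables above pair C99 designated initializers with a static_assert on GGML_TYPE_COUNT, so adding a new type without updating every table becomes a compile-time error rather than a silent zero entry. The same idiom, reduced to a standalone example (not ggml's actual enum):

    #include <assert.h>

    enum color { COLOR_RED, COLOR_GREEN, COLOR_BLUE, COLOR_COUNT };

    static const char * COLOR_NAME[COLOR_COUNT] = {
        [COLOR_RED]   = "red",
        [COLOR_GREEN] = "green",
        [COLOR_BLUE]  = "blue",
    };

    /* Fires at compile time as soon as the enum grows, forcing the table to be revisited. */
    static_assert(COLOR_COUNT == 3, "COLOR_NAME is outdated");
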
@@ -2629,6 +3262,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|
2629
3262
|
|
2630
3263
|
"SCALE",
|
2631
3264
|
"CPY",
|
3265
|
+
"CONT",
|
2632
3266
|
"RESHAPE",
|
2633
3267
|
"VIEW",
|
2634
3268
|
"PERMUTE",
|
@@ -2642,9 +3276,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
|
2642
3276
|
|
2643
3277
|
"FLASH_ATTN",
|
2644
3278
|
"FLASH_FF",
|
3279
|
+
|
3280
|
+
"MAP_UNARY",
|
3281
|
+
"MAP_BINARY",
|
2645
3282
|
};
|
2646
3283
|
|
2647
|
-
static_assert(GGML_OP_COUNT ==
|
3284
|
+
static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
|
2648
3285
|
|
2649
3286
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
2650
3287
|
"none",
|
@@ -2673,6 +3310,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
2673
3310
|
|
2674
3311
|
"x*v",
|
2675
3312
|
"x-\\>y",
|
3313
|
+
"cont(x)",
|
2676
3314
|
"reshape(x)",
|
2677
3315
|
"view(x)",
|
2678
3316
|
"permute(x)",
|
@@ -2686,24 +3324,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|
2686
3324
|
|
2687
3325
|
"flash_attn(x)",
|
2688
3326
|
"flash_ff(x)",
|
2689
|
-
};
|
2690
|
-
|
2691
|
-
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
|
2692
|
-
|
2693
|
-
//
|
2694
|
-
// ggml object
|
2695
|
-
//
|
2696
|
-
|
2697
|
-
struct ggml_object {
|
2698
|
-
size_t offs;
|
2699
|
-
size_t size;
|
2700
3327
|
|
2701
|
-
|
2702
|
-
|
2703
|
-
char padding[8];
|
3328
|
+
"f(x)",
|
3329
|
+
"f(x,y)",
|
2704
3330
|
};
|
2705
3331
|
|
2706
|
-
|
3332
|
+
static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
|
2707
3333
|
|
2708
3334
|
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
2709
3335
|
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
@@ -2716,7 +3342,6 @@ struct ggml_context {
|
|
2716
3342
|
size_t mem_size;
|
2717
3343
|
void * mem_buffer;
|
2718
3344
|
bool mem_buffer_owned;
|
2719
|
-
bool mem_buffer_mlocked;
|
2720
3345
|
bool no_alloc;
|
2721
3346
|
|
2722
3347
|
int n_objects;
|
@@ -2834,6 +3459,11 @@ float ggml_type_sizef(enum ggml_type type) {
|
|
2834
3459
|
return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
|
2835
3460
|
}
|
2836
3461
|
|
3462
|
+
const char * ggml_type_name(enum ggml_type type) {
|
3463
|
+
return GGML_TYPE_NAME[type];
|
3464
|
+
}
|
3465
|
+
|
3466
|
+
|
2837
3467
|
size_t ggml_element_size(const struct ggml_tensor * tensor) {
|
2838
3468
|
return GGML_TYPE_SIZE[tensor->type];
|
2839
3469
|
}
|
@@ -2999,11 +3629,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|
2999
3629
|
return NULL;
|
3000
3630
|
}
|
3001
3631
|
|
3632
|
+
const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
|
3633
|
+
|
3002
3634
|
*ctx = (struct ggml_context) {
|
3003
|
-
/*.mem_size =*/
|
3004
|
-
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer :
|
3635
|
+
/*.mem_size =*/ mem_size,
|
3636
|
+
/*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
|
3005
3637
|
/*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
|
3006
|
-
/*.mem_buffer_mlocked =*/ false,
|
3007
3638
|
/*.no_alloc =*/ params.no_alloc,
|
3008
3639
|
/*.n_objects =*/ 0,
|
3009
3640
|
/*.objects_begin =*/ NULL,
|
@@ -3012,7 +3643,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|
3012
3643
|
/*.scratch_save =*/ { 0, 0, NULL, },
|
3013
3644
|
};
|
3014
3645
|
|
3015
|
-
GGML_ASSERT(ctx->mem_buffer != NULL);
|
3646
|
+
GGML_ASSERT(ctx->mem_buffer != NULL);
|
3016
3647
|
|
3017
3648
|
ggml_assert_aligned(ctx->mem_buffer);
|
3018
3649
|
|
@@ -3036,16 +3667,8 @@ void ggml_free(struct ggml_context * ctx) {
|
|
3036
3667
|
GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
|
3037
3668
|
__func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
|
3038
3669
|
|
3039
|
-
#if GGML_MLOCK_SUPPORT
|
3040
|
-
if (ctx->mem_buffer_mlocked) {
|
3041
|
-
if (munlock(ctx->mem_buffer, ctx->mem_size)) {
|
3042
|
-
fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
|
3043
|
-
}
|
3044
|
-
}
|
3045
|
-
#endif
|
3046
|
-
|
3047
3670
|
if (ctx->mem_buffer_owned) {
|
3048
|
-
|
3671
|
+
GGML_ALIGNED_FREE(ctx->mem_buffer);
|
3049
3672
|
}
|
3050
3673
|
|
3051
3674
|
found = true;
|
@@ -3072,48 +3695,6 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
|
|
3072
3695
|
return result;
|
3073
3696
|
}
|
3074
3697
|
|
3075
|
-
#ifdef __APPLE__
|
3076
|
-
#define MLOCK_SUGGESTION \
|
3077
|
-
"Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
|
3078
|
-
"decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
|
3079
|
-
#else
|
3080
|
-
#define MLOCK_SUGGESTION \
|
3081
|
-
"Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
|
3082
|
-
#endif
|
3083
|
-
|
3084
|
-
bool ggml_mlock_supported(void) {
|
3085
|
-
return GGML_MLOCK_SUPPORT;
|
3086
|
-
}
|
3087
|
-
|
3088
|
-
bool ggml_mlock(
|
3089
|
-
struct ggml_context * ctx,
|
3090
|
-
const void *opt_extra_addr,
|
3091
|
-
size_t opt_extra_len,
|
3092
|
-
char **err_p) {
|
3093
|
-
// TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
|
3094
|
-
#if GGML_MLOCK_SUPPORT
|
3095
|
-
if (ctx->mem_buffer_mlocked) {
|
3096
|
-
return true;
|
3097
|
-
}
|
3098
|
-
if (mlock(ctx->mem_buffer, ctx->mem_size) ||
|
3099
|
-
(opt_extra_len &&
|
3100
|
-
mlock(opt_extra_addr, opt_extra_len))) {
|
3101
|
-
if ((*err_p = malloc(1024))) {
|
3102
|
-
snprintf(*err_p, 1024,
|
3103
|
-
"failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
|
3104
|
-
ctx->mem_size + opt_extra_len,
|
3105
|
-
strerror(errno));
|
3106
|
-
}
|
3107
|
-
return false;
|
3108
|
-
}
|
3109
|
-
ctx->mem_buffer_mlocked = true;
|
3110
|
-
return true;
|
3111
|
-
#else // GGML_MLOCK_SUPPORT
|
3112
|
-
*err_p = strdup("can't mlock because it's not supported on this system");
|
3113
|
-
return false;
|
3114
|
-
#endif // GGML_MLOCK_SUPPORT
|
3115
|
-
}
|
3116
|
-
|
3117
3698
|
////////////////////////////////////////////////////////////////////////////////
|
3118
3699
|
|
3119
3700
|
struct ggml_tensor * ggml_new_tensor_impl(
|
@@ -3325,14 +3906,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
|
|
3325
3906
|
char * const data = tensor->data;
|
3326
3907
|
|
3327
3908
|
switch (tensor->type) {
|
3328
|
-
case GGML_TYPE_Q4_0:
|
3329
|
-
{
|
3330
|
-
GGML_ASSERT(false);
|
3331
|
-
} break;
|
3332
|
-
case GGML_TYPE_Q4_1:
|
3333
|
-
{
|
3334
|
-
GGML_ASSERT(false);
|
3335
|
-
} break;
|
3336
3909
|
case GGML_TYPE_I8:
|
3337
3910
|
{
|
3338
3911
|
assert(tensor->nb[0] == sizeof(int8_t));
|
@@ -3368,7 +3941,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
|
|
3368
3941
|
ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
3369
3942
|
}
|
3370
3943
|
} break;
|
3371
|
-
|
3944
|
+
default:
|
3372
3945
|
{
|
3373
3946
|
GGML_ASSERT(false);
|
3374
3947
|
} break;
|
@@ -3385,14 +3958,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3385
3958
|
char * const data = tensor->data;
|
3386
3959
|
|
3387
3960
|
switch (tensor->type) {
|
3388
|
-
case GGML_TYPE_Q4_0:
|
3389
|
-
{
|
3390
|
-
GGML_ASSERT(false);
|
3391
|
-
} break;
|
3392
|
-
case GGML_TYPE_Q4_1:
|
3393
|
-
{
|
3394
|
-
GGML_ASSERT(false);
|
3395
|
-
} break;
|
3396
3961
|
case GGML_TYPE_I8:
|
3397
3962
|
{
|
3398
3963
|
assert(tensor->nb[0] == sizeof(int8_t));
|
@@ -3428,7 +3993,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3428
3993
|
ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
3429
3994
|
}
|
3430
3995
|
} break;
|
3431
|
-
|
3996
|
+
default:
|
3432
3997
|
{
|
3433
3998
|
GGML_ASSERT(false);
|
3434
3999
|
} break;
|
@@ -3439,14 +4004,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3439
4004
|
|
3440
4005
|
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
3441
4006
|
switch (tensor->type) {
|
3442
|
-
case GGML_TYPE_Q4_0:
|
3443
|
-
{
|
3444
|
-
GGML_ASSERT(false);
|
3445
|
-
} break;
|
3446
|
-
case GGML_TYPE_Q4_1:
|
3447
|
-
{
|
3448
|
-
GGML_ASSERT(false);
|
3449
|
-
} break;
|
3450
4007
|
case GGML_TYPE_I8:
|
3451
4008
|
{
|
3452
4009
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3472,7 +4029,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3472
4029
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3473
4030
|
return ((float *)(tensor->data))[i];
|
3474
4031
|
} break;
|
3475
|
-
|
4032
|
+
default:
|
3476
4033
|
{
|
3477
4034
|
GGML_ASSERT(false);
|
3478
4035
|
} break;
|
@@ -3483,14 +4040,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3483
4040
|
|
3484
4041
|
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
3485
4042
|
switch (tensor->type) {
|
3486
|
-
case GGML_TYPE_Q4_0:
|
3487
|
-
{
|
3488
|
-
GGML_ASSERT(false);
|
3489
|
-
} break;
|
3490
|
-
case GGML_TYPE_Q4_1:
|
3491
|
-
{
|
3492
|
-
GGML_ASSERT(false);
|
3493
|
-
} break;
|
3494
4043
|
case GGML_TYPE_I8:
|
3495
4044
|
{
|
3496
4045
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3516,7 +4065,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
|
3516
4065
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3517
4066
|
((float *)(tensor->data))[i] = value;
|
3518
4067
|
} break;
|
3519
|
-
|
4068
|
+
default:
|
3520
4069
|
{
|
3521
4070
|
GGML_ASSERT(false);
|
3522
4071
|
} break;
|
@@ -3525,14 +4074,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
|
3525
4074
|
|
3526
4075
|
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
3527
4076
|
switch (tensor->type) {
|
3528
|
-
case GGML_TYPE_Q4_0:
|
3529
|
-
{
|
3530
|
-
GGML_ASSERT(false);
|
3531
|
-
} break;
|
3532
|
-
case GGML_TYPE_Q4_1:
|
3533
|
-
{
|
3534
|
-
GGML_ASSERT(false);
|
3535
|
-
} break;
|
3536
4077
|
case GGML_TYPE_I8:
|
3537
4078
|
{
|
3538
4079
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3558,7 +4099,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3558
4099
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3559
4100
|
return ((float *)(tensor->data))[i];
|
3560
4101
|
} break;
|
3561
|
-
|
4102
|
+
default:
|
3562
4103
|
{
|
3563
4104
|
GGML_ASSERT(false);
|
3564
4105
|
} break;
|
@@ -3569,14 +4110,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3569
4110
|
|
3570
4111
|
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
|
3571
4112
|
switch (tensor->type) {
|
3572
|
-
case GGML_TYPE_Q4_0:
|
3573
|
-
{
|
3574
|
-
GGML_ASSERT(false);
|
3575
|
-
} break;
|
3576
|
-
case GGML_TYPE_Q4_1:
|
3577
|
-
{
|
3578
|
-
GGML_ASSERT(false);
|
3579
|
-
} break;
|
3580
4113
|
case GGML_TYPE_I8:
|
3581
4114
|
{
|
3582
4115
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3602,7 +4135,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
|
|
3602
4135
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3603
4136
|
((float *)(tensor->data))[i] = value;
|
3604
4137
|
} break;
|
3605
|
-
|
4138
|
+
default:
|
3606
4139
|
{
|
3607
4140
|
GGML_ASSERT(false);
|
3608
4141
|
} break;
|
@@ -4388,26 +4921,22 @@ struct ggml_tensor * ggml_cpy_inplace(
|
|
4388
4921
|
return ggml_cpy_impl(ctx, a, b, true);
|
4389
4922
|
}
|
4390
4923
|
|
4391
|
-
//
|
4924
|
+
// ggml_cont
|
4392
4925
|
|
4393
|
-
struct ggml_tensor *
|
4926
|
+
struct ggml_tensor * ggml_cont_impl(
|
4394
4927
|
struct ggml_context * ctx,
|
4395
|
-
struct ggml_tensor
|
4396
|
-
|
4397
|
-
GGML_ASSERT(ggml_is_contiguous(a));
|
4398
|
-
GGML_ASSERT(ggml_is_contiguous(b));
|
4399
|
-
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
|
4400
|
-
|
4928
|
+
struct ggml_tensor * a,
|
4929
|
+
bool inplace) {
|
4401
4930
|
bool is_node = false;
|
4402
4931
|
|
4403
|
-
if (
|
4932
|
+
if (!inplace && a->grad) {
|
4404
4933
|
GGML_ASSERT(false); // TODO: implement backward
|
4405
4934
|
is_node = true;
|
4406
4935
|
}
|
4407
4936
|
|
4408
|
-
struct ggml_tensor * result =
|
4937
|
+
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
4409
4938
|
|
4410
|
-
result->op =
|
4939
|
+
result->op = GGML_OP_CONT;
|
4411
4940
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
4412
4941
|
result->src0 = a;
|
4413
4942
|
result->src1 = NULL;
|
@@ -4415,12 +4944,51 @@ struct ggml_tensor * ggml_reshape(
|
|
4415
4944
|
return result;
|
4416
4945
|
}
|
4417
4946
|
|
4418
|
-
struct ggml_tensor *
|
4947
|
+
struct ggml_tensor * ggml_cont(
|
4419
4948
|
struct ggml_context * ctx,
|
4420
|
-
struct ggml_tensor
|
4421
|
-
|
4422
|
-
|
4423
|
-
|
4949
|
+
struct ggml_tensor * a) {
|
4950
|
+
return ggml_cont_impl(ctx, a, false);
|
4951
|
+
}
|
4952
|
+
|
4953
|
+
struct ggml_tensor * ggml_cont_inplace(
|
4954
|
+
struct ggml_context * ctx,
|
4955
|
+
struct ggml_tensor * a) {
|
4956
|
+
return ggml_cont_impl(ctx, a, true);
|
4957
|
+
}
|
4958
|
+
|
4959
|
+
// ggml_reshape
|
4960
|
+
|
4961
|
+
struct ggml_tensor * ggml_reshape(
|
4962
|
+
struct ggml_context * ctx,
|
4963
|
+
struct ggml_tensor * a,
|
4964
|
+
struct ggml_tensor * b) {
|
4965
|
+
GGML_ASSERT(ggml_is_contiguous(a));
|
4966
|
+
GGML_ASSERT(ggml_is_contiguous(b));
|
4967
|
+
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
|
4968
|
+
|
4969
|
+
bool is_node = false;
|
4970
|
+
|
4971
|
+
if (a->grad || b->grad) {
|
4972
|
+
GGML_ASSERT(false); // TODO: implement backward
|
4973
|
+
is_node = true;
|
4974
|
+
}
|
4975
|
+
|
4976
|
+
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
|
4977
|
+
|
4978
|
+
result->op = GGML_OP_RESHAPE;
|
4979
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
4980
|
+
result->src0 = a;
|
4981
|
+
result->src1 = NULL;
|
4982
|
+
|
4983
|
+
return result;
|
4984
|
+
}
|
4985
|
+
|
4986
|
+
struct ggml_tensor * ggml_reshape_2d(
|
4987
|
+
struct ggml_context * ctx,
|
4988
|
+
struct ggml_tensor * a,
|
4989
|
+
int64_t ne0,
|
4990
|
+
int64_t ne1) {
|
4991
|
+
GGML_ASSERT(ggml_is_contiguous(a));
|
4424
4992
|
GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
|
4425
4993
|
|
4426
4994
|
bool is_node = false;
|
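
A typical use of the new ggml_cont op introduced above is to materialize a contiguous copy of a permuted view before an op that needs contiguous rows. A sketch only, with error handling and context setup omitted:

    #include "ggml.h"

    /* Swap the first two axes of `a`, then force a contiguous layout. */
    static struct ggml_tensor * transpose_contiguous(struct ggml_context * ctx,
                                                     struct ggml_tensor  * a) {
        struct ggml_tensor * t = ggml_permute(ctx, a, 1, 0, 2, 3);
        return ggml_cont(ctx, t);
    }
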
@@ -4866,6 +5434,90 @@ struct ggml_tensor * ggml_flash_ff(
|
|
4866
5434
|
return result;
|
4867
5435
|
}
|
4868
5436
|
|
5437
|
+
// ggml_map_unary
|
5438
|
+
|
5439
|
+
struct ggml_tensor * ggml_map_unary_impl_f32(
|
5440
|
+
struct ggml_context * ctx,
|
5441
|
+
struct ggml_tensor * a,
|
5442
|
+
const ggml_unary_op_f32_t fun,
|
5443
|
+
bool inplace) {
|
5444
|
+
bool is_node = false;
|
5445
|
+
|
5446
|
+
if (!inplace && a->grad) {
|
5447
|
+
is_node = true;
|
5448
|
+
}
|
5449
|
+
|
5450
|
+
struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
|
5451
|
+
*((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
|
5452
|
+
struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
5453
|
+
|
5454
|
+
result->op = GGML_OP_MAP_UNARY;
|
5455
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
5456
|
+
result->src0 = a;
|
5457
|
+
result->opt[0] = addr_tensor;
|
5458
|
+
|
5459
|
+
return result;
|
5460
|
+
}
|
5461
|
+
|
5462
|
+
struct ggml_tensor * ggml_map_unary_f32(
|
5463
|
+
struct ggml_context * ctx,
|
5464
|
+
struct ggml_tensor * a,
|
5465
|
+
const ggml_unary_op_f32_t fun) {
|
5466
|
+
return ggml_map_unary_impl_f32(ctx, a, fun, false);
|
5467
|
+
}
|
5468
|
+
|
5469
|
+
struct ggml_tensor * ggml_map_unary_inplace_f32(
|
5470
|
+
struct ggml_context * ctx,
|
5471
|
+
struct ggml_tensor * a,
|
5472
|
+
const ggml_unary_op_f32_t fun) {
|
5473
|
+
return ggml_map_unary_impl_f32(ctx, a, fun, true);
|
5474
|
+
}
|
5475
|
+
|
5476
|
+
// ggml_map_binary
|
5477
|
+
|
5478
|
+
struct ggml_tensor * ggml_map_binary_impl_f32(
|
5479
|
+
struct ggml_context * ctx,
|
5480
|
+
struct ggml_tensor * a,
|
5481
|
+
struct ggml_tensor * b,
|
5482
|
+
const ggml_binary_op_f32_t fun,
|
5483
|
+
bool inplace) {
|
5484
|
+
GGML_ASSERT(ggml_are_same_shape(a, b));
|
5485
|
+
|
5486
|
+
bool is_node = false;
|
5487
|
+
|
5488
|
+
if (!inplace && (a->grad || b->grad)) {
|
5489
|
+
is_node = true;
|
5490
|
+
}
|
5491
|
+
|
5492
|
+
struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
|
5493
|
+
*((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
|
5494
|
+
struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
5495
|
+
|
5496
|
+
result->op = GGML_OP_MAP_BINARY;
|
5497
|
+
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
5498
|
+
result->src0 = a;
|
5499
|
+
result->src1 = b;
|
5500
|
+
result->opt[0] = addr_tensor;
|
5501
|
+
|
5502
|
+
return result;
|
5503
|
+
}
|
5504
|
+
|
5505
|
+
struct ggml_tensor * ggml_map_binary_f32(
|
5506
|
+
struct ggml_context * ctx,
|
5507
|
+
struct ggml_tensor * a,
|
5508
|
+
struct ggml_tensor * b,
|
5509
|
+
const ggml_binary_op_f32_t fun) {
|
5510
|
+
return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
|
5511
|
+
}
|
5512
|
+
|
5513
|
+
struct ggml_tensor * ggml_map_binary_inplace_f32(
|
5514
|
+
struct ggml_context * ctx,
|
5515
|
+
struct ggml_tensor * a,
|
5516
|
+
struct ggml_tensor * b,
|
5517
|
+
const ggml_binary_op_f32_t fun) {
|
5518
|
+
return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
|
5519
|
+
}
|
5520
|
+
|
4869
5521
|
////////////////////////////////////////////////////////////////////////////////
|
4870
5522
|
|
4871
5523
|
void ggml_set_param(
|
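
The map ops introduced above let callers run an arbitrary element-wise f32 function as a graph node. Assuming the ggml_unary_op_f32_t callback follows the usual (count, dst, src) convention of ggml's f32 vector helpers (check ggml.h for the exact typedef), a user-defined clamp could look like this:

    #include "ggml.h"

    /* Callback in the (n, dst, src) shape used by ggml's f32 vector routines. */
    static void clamp_unit_f32(const int n, float * dst, const float * src) {
        for (int i = 0; i < n; ++i) {
            float v = src[i];
            dst[i] = v < -1.0f ? -1.0f : (v > 1.0f ? 1.0f : v);
        }
    }

    /* Build a node that clamps every element of `a` to [-1, 1]. */
    static struct ggml_tensor * clamp_node(struct ggml_context * ctx, struct ggml_tensor * a) {
        return ggml_map_unary_f32(ctx, a, clamp_unit_f32);
    }
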
@@ -4930,6 +5582,105 @@ static void ggml_compute_forward_dup_f16(
|
|
4930
5582
|
|
4931
5583
|
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
4932
5584
|
|
5585
|
+
if (ggml_is_contiguous(dst)) {
|
5586
|
+
if (src0->nb[0] == sizeof(ggml_fp16_t)) {
|
5587
|
+
if (dst->type == GGML_TYPE_F16) {
|
5588
|
+
size_t id = 0;
|
5589
|
+
const size_t rs = ne00*nb00;
|
5590
|
+
|
5591
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5592
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5593
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5594
|
+
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5595
|
+
char * dst_ptr = (char *) dst->data + id*rs;
|
5596
|
+
|
5597
|
+
memcpy(dst_ptr, src0_ptr, rs);
|
5598
|
+
|
5599
|
+
id++;
|
5600
|
+
}
|
5601
|
+
}
|
5602
|
+
}
|
5603
|
+
} else if (dst->type == GGML_TYPE_F32) {
|
5604
|
+
size_t id = 0;
|
5605
|
+
float * dst_ptr = (float *) dst->data;
|
5606
|
+
|
5607
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5608
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5609
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5610
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5611
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5612
|
+
|
5613
|
+
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
5614
|
+
id++;
|
5615
|
+
}
|
5616
|
+
}
|
5617
|
+
}
|
5618
|
+
}
|
5619
|
+
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
|
5620
|
+
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
5621
|
+
size_t id = 0;
|
5622
|
+
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
5623
|
+
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
5624
|
+
float * src0_f32 = (float *) params->wdata;
|
5625
|
+
|
5626
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5627
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5628
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5629
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5630
|
+
// convert to f32 and quantize
|
5631
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5632
|
+
src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
|
5633
|
+
}
|
5634
|
+
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
5635
|
+
id += dst_row_size;
|
5636
|
+
}
|
5637
|
+
}
|
5638
|
+
}
|
5639
|
+
} else {
|
5640
|
+
GGML_ASSERT(false); // TODO: implement
|
5641
|
+
}
|
5642
|
+
} else {
|
5643
|
+
//printf("%s: this is not optimal - fix me\n", __func__);
|
5644
|
+
|
5645
|
+
if (dst->type == GGML_TYPE_F32) {
|
5646
|
+
size_t id = 0;
|
5647
|
+
float * dst_ptr = (float *) dst->data;
|
5648
|
+
|
5649
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5650
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5651
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5652
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5653
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5654
|
+
|
5655
|
+
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
5656
|
+
id++;
|
5657
|
+
}
|
5658
|
+
}
|
5659
|
+
}
|
5660
|
+
}
|
5661
|
+
} else if (dst->type == GGML_TYPE_F16) {
|
5662
|
+
size_t id = 0;
|
5663
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
5664
|
+
|
5665
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5666
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5667
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5668
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5669
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5670
|
+
|
5671
|
+
dst_ptr[id] = *src0_ptr;
|
5672
|
+
id++;
|
5673
|
+
}
|
5674
|
+
}
|
5675
|
+
}
|
5676
|
+
}
|
5677
|
+
} else {
|
5678
|
+
GGML_ASSERT(false); // TODO: implement
|
5679
|
+
}
|
5680
|
+
}
|
5681
|
+
return;
|
5682
|
+
}
|
5683
|
+
|
4933
5684
|
// dst counters
|
4934
5685
|
int64_t i10 = 0;
|
4935
5686
|
int64_t i11 = 0;
|
@@ -5024,6 +5775,120 @@ static void ggml_compute_forward_dup_f32(
|
|
5024
5775
|
return;
|
5025
5776
|
}
|
5026
5777
|
|
5778
|
+
if (src0->type == dst->type &&
|
5779
|
+
src0->ne[0] == dst->ne[0] &&
|
5780
|
+
src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
|
5781
|
+
// copy by rows
|
5782
|
+
const size_t rs = ne00*nb00;
|
5783
|
+
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5784
|
+
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5785
|
+
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
5786
|
+
memcpy(
|
5787
|
+
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
5788
|
+
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
5789
|
+
rs);
|
5790
|
+
}
|
5791
|
+
}
|
5792
|
+
}
|
5793
|
+
return;
|
5794
|
+
}
|
5795
|
+
|
5796
|
+
if (ggml_is_contiguous(dst)) {
|
5797
|
+
// TODO: simplify
|
5798
|
+
if (src0->nb[0] == sizeof(float)) {
|
5799
|
+
if (dst->type == GGML_TYPE_F32) {
|
5800
|
+
size_t id = 0;
|
5801
|
+
const size_t rs = ne00*nb00;
|
5802
|
+
|
5803
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5804
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5805
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5806
|
+
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5807
|
+
char * dst_ptr = (char *) dst->data + id*rs;
|
5808
|
+
|
5809
|
+
memcpy(dst_ptr, src0_ptr, rs);
|
5810
|
+
|
5811
|
+
id++;
|
5812
|
+
}
|
5813
|
+
}
|
5814
|
+
}
|
5815
|
+
} else if (dst->type == GGML_TYPE_F16) {
|
5816
|
+
size_t id = 0;
|
5817
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
5818
|
+
|
5819
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5820
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5821
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5822
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5823
|
+
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5824
|
+
|
5825
|
+
dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
|
5826
|
+
id++;
|
5827
|
+
}
|
5828
|
+
}
|
5829
|
+
}
|
5830
|
+
}
|
5831
|
+
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
|
5832
|
+
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
5833
|
+
size_t id = 0;
|
5834
|
+
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
5835
|
+
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
5836
|
+
|
5837
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5838
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5839
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5840
|
+
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5841
|
+
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
|
5842
|
+
id += dst_row_size;
|
5843
|
+
}
|
5844
|
+
}
|
5845
|
+
}
|
5846
|
+
} else {
|
5847
|
+
GGML_ASSERT(false); // TODO: implement
|
5848
|
+
}
|
5849
|
+
} else {
|
5850
|
+
//printf("%s: this is not optimal - fix me\n", __func__);
|
5851
|
+
|
5852
|
+
if (dst->type == GGML_TYPE_F32) {
|
5853
|
+
size_t id = 0;
|
5854
|
+
float * dst_ptr = (float *) dst->data;
|
5855
|
+
|
5856
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5857
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5858
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5859
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5860
|
+
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5861
|
+
|
5862
|
+
dst_ptr[id] = *src0_ptr;
|
5863
|
+
id++;
|
5864
|
+
}
|
5865
|
+
}
|
5866
|
+
}
|
5867
|
+
}
|
5868
|
+
} else if (dst->type == GGML_TYPE_F16) {
|
5869
|
+
size_t id = 0;
|
5870
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
5871
|
+
|
5872
|
+
for (int i03 = 0; i03 < ne03; i03++) {
|
5873
|
+
for (int i02 = 0; i02 < ne02; i02++) {
|
5874
|
+
for (int i01 = 0; i01 < ne01; i01++) {
|
5875
|
+
for (int i00 = 0; i00 < ne00; i00++) {
|
5876
|
+
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5877
|
+
|
5878
|
+
dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
|
5879
|
+
id++;
|
5880
|
+
}
|
5881
|
+
}
|
5882
|
+
}
|
5883
|
+
}
|
5884
|
+
} else {
|
5885
|
+
GGML_ASSERT(false); // TODO: implement
|
5886
|
+
}
|
5887
|
+
}
|
5888
|
+
|
5889
|
+
return;
|
5890
|
+
}
|
5891
|
+
|
5027
5892
|
// dst counters
|
5028
5893
|
int64_t i10 = 0;
|
5029
5894
|
int64_t i11 = 0;
|
@@ -5100,12 +5965,7 @@ static void ggml_compute_forward_dup(
|
|
5100
5965
|
{
|
5101
5966
|
ggml_compute_forward_dup_f32(params, src0, dst);
|
5102
5967
|
} break;
|
5103
|
-
|
5104
|
-
case GGML_TYPE_Q4_1:
|
5105
|
-
case GGML_TYPE_I8:
|
5106
|
-
case GGML_TYPE_I16:
|
5107
|
-
case GGML_TYPE_I32:
|
5108
|
-
case GGML_TYPE_COUNT:
|
5968
|
+
default:
|
5109
5969
|
{
|
5110
5970
|
GGML_ASSERT(false);
|
5111
5971
|
} break;
|
@@ -5144,14 +6004,18 @@ static void ggml_compute_forward_add_f32(
|
|
5144
6004
|
GGML_ASSERT(nb00 == sizeof(float));
|
5145
6005
|
|
5146
6006
|
if (nb10 == sizeof(float)) {
|
5147
|
-
|
5148
|
-
|
5149
|
-
|
5150
|
-
|
6007
|
+
for (int j = ith; j < n; j += nth) {
|
6008
|
+
#ifdef GGML_USE_ACCELERATE
|
6009
|
+
vDSP_vadd(
|
6010
|
+
(float *) ((char *) src0->data + j*nb01), 1,
|
6011
|
+
(float *) ((char *) src1->data + j*nb11), 1,
|
6012
|
+
(float *) ((char *) dst->data + j*nb1), 1, nc);
|
6013
|
+
#else
|
5151
6014
|
ggml_vec_add_f32(nc,
|
5152
6015
|
(float *) ((char *) dst->data + j*nb1),
|
5153
6016
|
(float *) ((char *) src0->data + j*nb01),
|
5154
6017
|
(float *) ((char *) src1->data + j*nb11));
|
6018
|
+
#endif
|
5155
6019
|
}
|
5156
6020
|
} else {
|
5157
6021
|
// src1 is not contiguous
|
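
The `for (int j = ith; j < n; j += nth)` loops above are how these kernels split rows across threads: worker ith of nth takes every nth row, so the whole range is covered exactly once with no extra bookkeeping. A stripped-down sketch of that interleaved partitioning, outside of ggml's task machinery:

    #include <stdio.h>

    /* Worker `ith` of `nth` handles rows ith, ith+nth, ith+2*nth, ... */
    static void process_rows(int ith, int nth, int n_rows) {
        for (int j = ith; j < n_rows; j += nth) {
            printf("worker %d handles row %d\n", ith, j);
        }
    }

    int main(void) {
        for (int w = 0; w < 4; ++w) {
            process_rows(w, 4, 10);  /* 4 workers, 10 rows */
        }
        return 0;
    }
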
@@ -5167,6 +6031,212 @@ static void ggml_compute_forward_add_f32(
|
|
5167
6031
|
}
|
5168
6032
|
}
|
5169
6033
|
|
6034
|
+
static void ggml_compute_forward_add_f16_f32(
|
6035
|
+
const struct ggml_compute_params * params,
|
6036
|
+
const struct ggml_tensor * src0,
|
6037
|
+
const struct ggml_tensor * src1,
|
6038
|
+
struct ggml_tensor * dst) {
|
6039
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
6040
|
+
|
6041
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
6042
|
+
return;
|
6043
|
+
}
|
6044
|
+
|
6045
|
+
const int ith = params->ith;
|
6046
|
+
const int nth = params->nth;
|
6047
|
+
|
6048
|
+
const int n = ggml_nrows(src0);
|
6049
|
+
const int nc = src0->ne[0];
|
6050
|
+
|
6051
|
+
const size_t nb00 = src0->nb[0];
|
6052
|
+
const size_t nb01 = src0->nb[1];
|
6053
|
+
|
6054
|
+
const size_t nb10 = src1->nb[0];
|
6055
|
+
const size_t nb11 = src1->nb[1];
|
6056
|
+
|
6057
|
+
const size_t nb0 = dst->nb[0];
|
6058
|
+
const size_t nb1 = dst->nb[1];
|
6059
|
+
|
6060
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
6061
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
6062
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
6063
|
+
|
6064
|
+
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
6065
|
+
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
6066
|
+
|
6067
|
+
if (nb10 == sizeof(float)) {
|
6068
|
+
for (int j = ith; j < n; j += nth) {
|
6069
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
6070
|
+
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
6071
|
+
for (int i = 0; i < nc; i++) {
|
6072
|
+
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
|
6073
|
+
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
|
6074
|
+
}
|
6075
|
+
}
|
6076
|
+
}
|
6077
|
+
else {
|
6078
|
+
// src1 is not contiguous
|
6079
|
+
GGML_ASSERT(false);
|
6080
|
+
}
|
6081
|
+
}
|
6082
|
+
|
6083
|
+
static void ggml_compute_forward_add_f16_f16(
|
6084
|
+
const struct ggml_compute_params * params,
|
6085
|
+
const struct ggml_tensor * src0,
|
6086
|
+
const struct ggml_tensor * src1,
|
6087
|
+
struct ggml_tensor * dst) {
|
6088
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
6089
|
+
|
6090
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
6091
|
+
return;
|
6092
|
+
}
|
6093
|
+
|
6094
|
+
const int ith = params->ith;
|
6095
|
+
const int nth = params->nth;
|
6096
|
+
|
6097
|
+
const int n = ggml_nrows(src0);
|
6098
|
+
const int nc = src0->ne[0];
|
6099
|
+
|
6100
|
+
const size_t nb00 = src0->nb[0];
|
6101
|
+
const size_t nb01 = src0->nb[1];
|
6102
|
+
|
6103
|
+
const size_t nb10 = src1->nb[0];
|
6104
|
+
const size_t nb11 = src1->nb[1];
|
6105
|
+
|
6106
|
+
const size_t nb0 = dst->nb[0];
|
6107
|
+
const size_t nb1 = dst->nb[1];
|
6108
|
+
|
6109
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
6110
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F16);
|
6111
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
6112
|
+
|
6113
|
+
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
6114
|
+
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
6115
|
+
|
6116
|
+
if (nb10 == sizeof(ggml_fp16_t)) {
|
6117
|
+
for (int j = ith; j < n; j += nth) {
|
6118
|
+
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
6119
|
+
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
6120
|
+
for (int i = 0; i < nc; i++) {
|
6121
|
+
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
|
6122
|
+
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
|
6123
|
+
}
|
6124
|
+
}
|
6125
|
+
}
|
6126
|
+
else {
|
6127
|
+
// src1 is not contiguous
|
6128
|
+
GGML_ASSERT(false);
|
6129
|
+
}
|
6130
|
+
}
|
6131
|
+
|
6132
|
+
static void ggml_compute_forward_add_q_f32(
|
6133
|
+
const struct ggml_compute_params * params,
|
6134
|
+
const struct ggml_tensor * src0,
|
6135
|
+
const struct ggml_tensor * src1,
|
6136
|
+
struct ggml_tensor * dst) {
|
6137
|
+
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
6138
|
+
|
6139
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
6140
|
+
return;
|
6141
|
+
}
|
6142
|
+
|
6143
|
+
const int64_t ne00 = src0->ne[0];
|
6144
|
+
const int64_t ne01 = src0->ne[1];
|
6145
|
+
const int64_t ne02 = src0->ne[2];
|
6146
|
+
const int64_t ne03 = src0->ne[3];
|
6147
|
+
|
6148
|
+
//const int64_t ne10 = src1->ne[0];
|
6149
|
+
//const int64_t ne11 = src1->ne[1];
|
6150
|
+
const int64_t ne12 = src1->ne[2];
|
6151
|
+
const int64_t ne13 = src1->ne[3];
|
6152
|
+
|
6153
|
+
//const int64_t ne0 = dst->ne[0];
|
6154
|
+
//const int64_t ne1 = dst->ne[1];
|
6155
|
+
const int64_t ne2 = dst->ne[2];
|
6156
|
+
const int64_t ne3 = dst->ne[3];
|
6157
|
+
|
6158
|
+
const int nb00 = src0->nb[0];
|
6159
|
+
const int nb01 = src0->nb[1];
|
6160
|
+
const int nb02 = src0->nb[2];
|
6161
|
+
const int nb03 = src0->nb[3];
|
6162
|
+
|
6163
|
+
const int nb10 = src1->nb[0];
|
6164
|
+
const int nb11 = src1->nb[1];
|
6165
|
+
const int nb12 = src1->nb[2];
|
6166
|
+
const int nb13 = src1->nb[3];
|
6167
|
+
|
6168
|
+
const int nb0 = dst->nb[0];
|
6169
|
+
const int nb1 = dst->nb[1];
|
6170
|
+
const int nb2 = dst->nb[2];
|
6171
|
+
const int nb3 = dst->nb[3];
|
6172
|
+
|
6173
|
+
const int ith = params->ith;
|
6174
|
+
const int nth = params->nth;
|
6175
|
+
|
6176
|
+
GGML_ASSERT(ne02 == ne12);
|
6177
|
+
GGML_ASSERT(ne03 == ne13);
|
6178
|
+
GGML_ASSERT(ne2 == ne12);
|
6179
|
+
GGML_ASSERT(ne3 == ne13);
|
6180
|
+
|
6181
|
+
const enum ggml_type type = src0->type;
|
6182
|
+
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
6183
|
+
quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
|
6184
|
+
|
6185
|
+
// we don't support permuted src0 or src1
|
6186
|
+
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
|
6187
|
+
GGML_ASSERT(nb10 == sizeof(float));
|
6188
|
+
|
6189
|
+
// dst cannot be transposed or permuted
|
6190
|
+
GGML_ASSERT(nb0 <= nb1);
|
6191
|
+
GGML_ASSERT(nb1 <= nb2);
|
6192
|
+
GGML_ASSERT(nb2 <= nb3);
|
6193
|
+
|
6194
|
+
GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
|
6195
|
+
GGML_ASSERT(dst->type == src0->type);
|
6196
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
6197
|
+
|
6198
|
+
// total rows in src0
|
6199
|
+
const int nr = ne01*ne02*ne03;
|
6200
|
+
|
6201
|
+
// rows per thread
|
6202
|
+
const int dr = (nr + nth - 1)/nth;
|
6203
|
+
|
6204
|
+
// row range for this thread
|
6205
|
+
const int ir0 = dr*ith;
|
6206
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
6207
|
+
|
6208
|
+
float * wdata = (float*) params->wdata + ne00 * ith;
|
6209
|
+
|
6210
|
+
for (int ir = ir0; ir < ir1; ++ir) {
|
6211
|
+
// src0 indices
|
6212
|
+
const int i03 = ir/(ne02*ne01);
|
6213
|
+
const int i02 = (ir - i03*ne02*ne01)/ne01;
|
6214
|
+
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
6215
|
+
|
6216
|
+
// src1 and dst are same shape as src0 => same indices
|
6217
|
+
const int i13 = i03;
|
6218
|
+
const int i12 = i02;
|
6219
|
+
const int i11 = i01;
|
6220
|
+
|
6221
|
+
const int i3 = i03;
|
6222
|
+
const int i2 = i02;
|
6223
|
+
const int i1 = i01;
|
6224
|
+
|
6225
|
+
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
6226
|
+
float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
|
6227
|
+
void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));
|
6228
|
+
|
6229
|
+
assert(ne00 % 32 == 0);
|
6230
|
+
|
6231
|
+
// unquantize row from src0 to temp buffer
|
6232
|
+
dequantize_row_q(src0_row, wdata, ne00);
|
6233
|
+
// add src1
|
6234
|
+
ggml_vec_acc_f32(ne00, wdata, src1_row);
|
6235
|
+
// quantize row to dst
|
6236
|
+
quantize_row_q(wdata, dst_row, ne00);
|
6237
|
+
}
|
6238
|
+
}
|
6239
|
+
|
5170
6240
|
static void ggml_compute_forward_add(
|
5171
6241
|
const struct ggml_compute_params * params,
|
5172
6242
|
const struct ggml_tensor * src0,
|
@@ -5177,13 +6247,24 @@ static void ggml_compute_forward_add(
|
|
5177
6247
|
{
|
5178
6248
|
ggml_compute_forward_add_f32(params, src0, src1, dst);
|
5179
6249
|
} break;
|
6250
|
+
case GGML_TYPE_F16:
|
6251
|
+
{
|
6252
|
+
if (src1->type == GGML_TYPE_F16) {
|
6253
|
+
ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
|
6254
|
+
}
|
6255
|
+
else if (src1->type == GGML_TYPE_F32) {
|
6256
|
+
ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
|
6257
|
+
}
|
6258
|
+
else {
|
6259
|
+
GGML_ASSERT(false);
|
6260
|
+
}
|
6261
|
+
} break;
|
5180
6262
|
case GGML_TYPE_Q4_0:
|
5181
6263
|
case GGML_TYPE_Q4_1:
|
5182
|
-
|
5183
|
-
|
5184
|
-
|
5185
|
-
|
5186
|
-
case GGML_TYPE_COUNT:
|
6264
|
+
{
|
6265
|
+
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
6266
|
+
} break;
|
6267
|
+
default:
|
5187
6268
|
{
|
5188
6269
|
GGML_ASSERT(false);
|
5189
6270
|
} break;
|
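
The quantized add above works row by row: dequantize the Q4 row into a float scratch buffer, accumulate the f32 row from src1, then re-quantize into dst. The same three-step shape reduced to plain function pointers (the *_row_fn types here merely stand in for ggml's quantize_fns entries and are not its exact signatures):

    typedef void (*dequantize_row_fn)(const void * src, float * dst, int k);
    typedef void (*quantize_row_fn)(const float * src, void * dst, int k);

    /* dst_q = quantize(dequantize(src_q) + src_f), one row of k elements */
    static void add_row_quantized(dequantize_row_fn dequant, quantize_row_fn quant,
                                  const void * src_q, const float * src_f,
                                  void * dst_q, float * scratch, int k) {
        dequant(src_q, scratch, k);      /* quantized -> f32 scratch */
        for (int i = 0; i < k; ++i) {
            scratch[i] += src_f[i];      /* accumulate the f32 operand */
        }
        quant(scratch, dst_q, k);        /* f32 scratch -> quantized dst */
    }
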
@@ -5229,13 +6310,7 @@ static void ggml_compute_forward_sub(
|
|
5229
6310
|
{
|
5230
6311
|
ggml_compute_forward_sub_f32(params, src0, src1, dst);
|
5231
6312
|
} break;
|
5232
|
-
|
5233
|
-
case GGML_TYPE_Q4_1:
|
5234
|
-
case GGML_TYPE_I8:
|
5235
|
-
case GGML_TYPE_I16:
|
5236
|
-
case GGML_TYPE_I32:
|
5237
|
-
case GGML_TYPE_F16:
|
5238
|
-
case GGML_TYPE_COUNT:
|
6313
|
+
default:
|
5239
6314
|
{
|
5240
6315
|
GGML_ASSERT(false);
|
5241
6316
|
} break;
|
@@ -5281,13 +6356,7 @@ static void ggml_compute_forward_mul(
|
|
5281
6356
|
{
|
5282
6357
|
ggml_compute_forward_mul_f32(params, src0, src1, dst);
|
5283
6358
|
} break;
|
5284
|
-
|
5285
|
-
case GGML_TYPE_Q4_1:
|
5286
|
-
case GGML_TYPE_I8:
|
5287
|
-
case GGML_TYPE_I16:
|
5288
|
-
case GGML_TYPE_I32:
|
5289
|
-
case GGML_TYPE_F16:
|
5290
|
-
case GGML_TYPE_COUNT:
|
6359
|
+
default:
|
5291
6360
|
{
|
5292
6361
|
GGML_ASSERT(false);
|
5293
6362
|
} break;
|
@@ -5333,13 +6402,7 @@ static void ggml_compute_forward_div(
|
|
5333
6402
|
{
|
5334
6403
|
ggml_compute_forward_div_f32(params, src0, src1, dst);
|
5335
6404
|
} break;
|
5336
|
-
|
5337
|
-
case GGML_TYPE_Q4_1:
|
5338
|
-
case GGML_TYPE_I8:
|
5339
|
-
case GGML_TYPE_I16:
|
5340
|
-
case GGML_TYPE_I32:
|
5341
|
-
case GGML_TYPE_F16:
|
5342
|
-
case GGML_TYPE_COUNT:
|
6405
|
+
default:
|
5343
6406
|
{
|
5344
6407
|
GGML_ASSERT(false);
|
5345
6408
|
} break;
|
@@ -5381,13 +6444,7 @@ static void ggml_compute_forward_sqr(
|
|
5381
6444
|
{
|
5382
6445
|
ggml_compute_forward_sqr_f32(params, src0, dst);
|
5383
6446
|
} break;
|
5384
|
-
|
5385
|
-
case GGML_TYPE_Q4_1:
|
5386
|
-
case GGML_TYPE_I8:
|
5387
|
-
case GGML_TYPE_I16:
|
5388
|
-
case GGML_TYPE_I32:
|
5389
|
-
case GGML_TYPE_F16:
|
5390
|
-
case GGML_TYPE_COUNT:
|
6447
|
+
default:
|
5391
6448
|
{
|
5392
6449
|
GGML_ASSERT(false);
|
5393
6450
|
} break;
|
@@ -5429,13 +6486,7 @@ static void ggml_compute_forward_sqrt(
|
|
5429
6486
|
{
|
5430
6487
|
ggml_compute_forward_sqrt_f32(params, src0, dst);
|
5431
6488
|
} break;
|
5432
|
-
|
5433
|
-
case GGML_TYPE_Q4_1:
|
5434
|
-
case GGML_TYPE_I8:
|
5435
|
-
case GGML_TYPE_I16:
|
5436
|
-
case GGML_TYPE_I32:
|
5437
|
-
case GGML_TYPE_F16:
|
5438
|
-
case GGML_TYPE_COUNT:
|
6489
|
+
default:
|
5439
6490
|
{
|
5440
6491
|
GGML_ASSERT(false);
|
5441
6492
|
} break;
|
@@ -5485,16 +6536,10 @@ static void ggml_compute_forward_sum(
|
|
5485
6536
|
switch (src0->type) {
|
5486
6537
|
case GGML_TYPE_F32:
|
5487
6538
|
{
|
5488
|
-
ggml_compute_forward_sum_f32(params, src0, dst);
|
5489
|
-
} break;
|
5490
|
-
|
5491
|
-
|
5492
|
-
case GGML_TYPE_I8:
|
5493
|
-
case GGML_TYPE_I16:
|
5494
|
-
case GGML_TYPE_I32:
|
5495
|
-
case GGML_TYPE_F16:
|
5496
|
-
case GGML_TYPE_COUNT:
|
5497
|
-
{
|
6539
|
+
ggml_compute_forward_sum_f32(params, src0, dst);
|
6540
|
+
} break;
|
6541
|
+
default:
|
6542
|
+
{
|
5498
6543
|
GGML_ASSERT(false);
|
5499
6544
|
} break;
|
5500
6545
|
}
|
@@ -5564,13 +6609,7 @@ static void ggml_compute_forward_mean(
|
|
5564
6609
|
{
|
5565
6610
|
ggml_compute_forward_mean_f32(params, src0, dst);
|
5566
6611
|
} break;
|
5567
|
-
|
5568
|
-
case GGML_TYPE_Q4_1:
|
5569
|
-
case GGML_TYPE_I8:
|
5570
|
-
case GGML_TYPE_I16:
|
5571
|
-
case GGML_TYPE_I32:
|
5572
|
-
case GGML_TYPE_F16:
|
5573
|
-
case GGML_TYPE_COUNT:
|
6612
|
+
default:
|
5574
6613
|
{
|
5575
6614
|
GGML_ASSERT(false);
|
5576
6615
|
} break;
|
@@ -5628,13 +6667,7 @@ static void ggml_compute_forward_repeat(
|
|
5628
6667
|
{
|
5629
6668
|
ggml_compute_forward_repeat_f32(params, src0, dst);
|
5630
6669
|
} break;
|
5631
|
-
|
5632
|
-
case GGML_TYPE_Q4_1:
|
5633
|
-
case GGML_TYPE_I8:
|
5634
|
-
case GGML_TYPE_I16:
|
5635
|
-
case GGML_TYPE_I32:
|
5636
|
-
case GGML_TYPE_F16:
|
5637
|
-
case GGML_TYPE_COUNT:
|
6670
|
+
default:
|
5638
6671
|
{
|
5639
6672
|
GGML_ASSERT(false);
|
5640
6673
|
} break;
|
@@ -5676,13 +6709,7 @@ static void ggml_compute_forward_abs(
|
|
5676
6709
|
{
|
5677
6710
|
ggml_compute_forward_abs_f32(params, src0, dst);
|
5678
6711
|
} break;
|
5679
|
-
|
5680
|
-
case GGML_TYPE_Q4_1:
|
5681
|
-
case GGML_TYPE_I8:
|
5682
|
-
case GGML_TYPE_I16:
|
5683
|
-
case GGML_TYPE_I32:
|
5684
|
-
case GGML_TYPE_F16:
|
5685
|
-
case GGML_TYPE_COUNT:
|
6712
|
+
default:
|
5686
6713
|
{
|
5687
6714
|
GGML_ASSERT(false);
|
5688
6715
|
} break;
|
@@ -5724,13 +6751,7 @@ static void ggml_compute_forward_sgn(
|
|
5724
6751
|
{
|
5725
6752
|
ggml_compute_forward_sgn_f32(params, src0, dst);
|
5726
6753
|
} break;
|
5727
|
-
|
5728
|
-
case GGML_TYPE_Q4_1:
|
5729
|
-
case GGML_TYPE_I8:
|
5730
|
-
case GGML_TYPE_I16:
|
5731
|
-
case GGML_TYPE_I32:
|
5732
|
-
case GGML_TYPE_F16:
|
5733
|
-
case GGML_TYPE_COUNT:
|
6754
|
+
default:
|
5734
6755
|
{
|
5735
6756
|
GGML_ASSERT(false);
|
5736
6757
|
} break;
|
@@ -5772,13 +6793,7 @@ static void ggml_compute_forward_neg(
|
|
5772
6793
|
{
|
5773
6794
|
ggml_compute_forward_neg_f32(params, src0, dst);
|
5774
6795
|
} break;
|
5775
|
-
|
5776
|
-
case GGML_TYPE_Q4_1:
|
5777
|
-
case GGML_TYPE_I8:
|
5778
|
-
case GGML_TYPE_I16:
|
5779
|
-
case GGML_TYPE_I32:
|
5780
|
-
case GGML_TYPE_F16:
|
5781
|
-
case GGML_TYPE_COUNT:
|
6796
|
+
default:
|
5782
6797
|
{
|
5783
6798
|
GGML_ASSERT(false);
|
5784
6799
|
} break;
|
@@ -5820,13 +6835,7 @@ static void ggml_compute_forward_step(
|
|
5820
6835
|
{
|
5821
6836
|
ggml_compute_forward_step_f32(params, src0, dst);
|
5822
6837
|
} break;
|
5823
|
-
|
5824
|
-
case GGML_TYPE_Q4_1:
|
5825
|
-
case GGML_TYPE_I8:
|
5826
|
-
case GGML_TYPE_I16:
|
5827
|
-
case GGML_TYPE_I32:
|
5828
|
-
case GGML_TYPE_F16:
|
5829
|
-
case GGML_TYPE_COUNT:
|
6838
|
+
default:
|
5830
6839
|
{
|
5831
6840
|
GGML_ASSERT(false);
|
5832
6841
|
} break;
|
@@ -5868,13 +6877,7 @@ static void ggml_compute_forward_relu(
|
|
5868
6877
|
{
|
5869
6878
|
ggml_compute_forward_relu_f32(params, src0, dst);
|
5870
6879
|
} break;
|
5871
|
-
|
5872
|
-
case GGML_TYPE_Q4_1:
|
5873
|
-
case GGML_TYPE_I8:
|
5874
|
-
case GGML_TYPE_I16:
|
5875
|
-
case GGML_TYPE_I32:
|
5876
|
-
case GGML_TYPE_F16:
|
5877
|
-
case GGML_TYPE_COUNT:
|
6880
|
+
default:
|
5878
6881
|
{
|
5879
6882
|
GGML_ASSERT(false);
|
5880
6883
|
} break;
|
@@ -5933,13 +6936,7 @@ static void ggml_compute_forward_gelu(
|
|
5933
6936
|
{
|
5934
6937
|
ggml_compute_forward_gelu_f32(params, src0, dst);
|
5935
6938
|
} break;
|
5936
|
-
|
5937
|
-
case GGML_TYPE_Q4_1:
|
5938
|
-
case GGML_TYPE_I8:
|
5939
|
-
case GGML_TYPE_I16:
|
5940
|
-
case GGML_TYPE_I32:
|
5941
|
-
case GGML_TYPE_F16:
|
5942
|
-
case GGML_TYPE_COUNT:
|
6939
|
+
default:
|
5943
6940
|
{
|
5944
6941
|
GGML_ASSERT(false);
|
5945
6942
|
} break;
|
@@ -6000,13 +6997,7 @@ static void ggml_compute_forward_silu(
|
|
6000
6997
|
{
|
6001
6998
|
ggml_compute_forward_silu_f32(params, src0, dst);
|
6002
6999
|
} break;
|
6003
|
-
|
6004
|
-
case GGML_TYPE_Q4_1:
|
6005
|
-
case GGML_TYPE_I8:
|
6006
|
-
case GGML_TYPE_I16:
|
6007
|
-
case GGML_TYPE_I32:
|
6008
|
-
case GGML_TYPE_F16:
|
6009
|
-
case GGML_TYPE_COUNT:
|
7000
|
+
default:
|
6010
7001
|
{
|
6011
7002
|
GGML_ASSERT(false);
|
6012
7003
|
} break;
|
@@ -6086,13 +7077,7 @@ static void ggml_compute_forward_norm(
|
|
6086
7077
|
{
|
6087
7078
|
ggml_compute_forward_norm_f32(params, src0, dst);
|
6088
7079
|
} break;
|
6089
|
-
|
6090
|
-
case GGML_TYPE_Q4_1:
|
6091
|
-
case GGML_TYPE_I8:
|
6092
|
-
case GGML_TYPE_I16:
|
6093
|
-
case GGML_TYPE_I32:
|
6094
|
-
case GGML_TYPE_F16:
|
6095
|
-
case GGML_TYPE_COUNT:
|
7080
|
+
default:
|
6096
7081
|
{
|
6097
7082
|
GGML_ASSERT(false);
|
6098
7083
|
} break;
|
@@ -6166,13 +7151,7 @@ static void ggml_compute_forward_rms_norm(
|
|
6166
7151
|
{
|
6167
7152
|
ggml_compute_forward_rms_norm_f32(params, src0, dst);
|
6168
7153
|
} break;
|
6169
|
-
|
6170
|
-
case GGML_TYPE_Q4_1:
|
6171
|
-
case GGML_TYPE_I8:
|
6172
|
-
case GGML_TYPE_I16:
|
6173
|
-
case GGML_TYPE_I32:
|
6174
|
-
case GGML_TYPE_F16:
|
6175
|
-
case GGML_TYPE_COUNT:
|
7154
|
+
default:
|
6176
7155
|
{
|
6177
7156
|
GGML_ASSERT(false);
|
6178
7157
|
} break;
|
@@ -6304,7 +7283,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
6304
7283
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
6305
7284
|
ne11, ne01, ne10,
|
6306
7285
|
1.0f, y, ne10,
|
6307
|
-
x,
|
7286
|
+
x, ne00,
|
6308
7287
|
0.0f, d, ne01);
|
6309
7288
|
}
|
6310
7289
|
}
|
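The change above supplies the previously missing leading-dimension argument for the second operand of cblas_sgemm; the same fix appears in the f16 and quantized mul_mat paths further down. For reference, the full call with the BLAS parameter roles spelled out (the dimension comments assume the usual mul_mat shapes, where ne10 == ne00):

    // d (ne11 x ne01) = 1.0f * y (ne11 x ne10) * x^T + 0.0f * d
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                ne11, ne01, ne10,   // M, N, K
                1.0f, y, ne10,      // alpha, A, lda
                      x, ne00,      //        B, ldb  <- the argument added here
                0.0f, d, ne01);     // beta,  C, ldc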
@@ -6476,7 +7455,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
6476
7455
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
6477
7456
|
ne11, ne01, ne10,
|
6478
7457
|
1.0f, y, ne10,
|
6479
|
-
x,
|
7458
|
+
x, ne00,
|
6480
7459
|
0.0f, d, ne01);
|
6481
7460
|
}
|
6482
7461
|
}
|
@@ -6564,29 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
6564
7543
|
//}
|
6565
7544
|
}
|
6566
7545
|
|
6567
|
-
typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
|
6568
|
-
typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
|
6569
|
-
typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
|
6570
|
-
|
6571
|
-
typedef struct {
|
6572
|
-
dequantize_row_q_t dequantize_row_q;
|
6573
|
-
quantize_row_q_t quantize_row_q;
|
6574
|
-
vec_dot_q_t vec_dot_q;
|
6575
|
-
} quantize_fns_t;
|
6576
|
-
|
6577
|
-
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
6578
|
-
[GGML_TYPE_Q4_0] = {
|
6579
|
-
.dequantize_row_q = dequantize_row_q4_0,
|
6580
|
-
.quantize_row_q = quantize_row_q4_0,
|
6581
|
-
.vec_dot_q = ggml_vec_dot_q4_0,
|
6582
|
-
},
|
6583
|
-
[GGML_TYPE_Q4_1] = {
|
6584
|
-
.dequantize_row_q = dequantize_row_q4_1,
|
6585
|
-
.quantize_row_q = quantize_row_q4_1,
|
6586
|
-
.vec_dot_q = ggml_vec_dot_q4_1,
|
6587
|
-
},
|
6588
|
-
};
|
6589
|
-
|
6590
7546
|
static void ggml_compute_forward_mul_mat_q_f32(
|
6591
7547
|
const struct ggml_compute_params * params,
|
6592
7548
|
const struct ggml_tensor * src0,
|
@@ -6634,8 +7590,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
6634
7590
|
GGML_ASSERT(ne3 == ne13);
|
6635
7591
|
|
6636
7592
|
const enum ggml_type type = src0->type;
|
6637
|
-
quantize_row_q_t const
|
6638
|
-
vec_dot_q_t const vec_dot_q
|
7593
|
+
quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
|
7594
|
+
vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
|
6639
7595
|
|
6640
7596
|
// we don't support permuted src0 or src1
|
6641
7597
|
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
|
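The file-local quantize_fns_t table removed above has not disappeared: the quantized mul_mat path below now reads a quantize_row_q_dot entry from it, so src1 rows are converted to the dot-product format rather than to src0's own format. A rough sketch of what each per-type entry is expected to carry, inferred only from the accesses visible in these hunks (the actual definition now lives elsewhere in the file and may hold additional fields):

    // Inferred per-type function table entry; field names follow the accesses
    // quantize_fns[type].quantize_row_q_dot and quantize_fns[type].vec_dot_q.
    typedef struct {
        dequantize_row_q_t dequantize_row_q;
        quantize_row_q_t   quantize_row_q;
        quantize_row_q_t   quantize_row_q_dot; // quantize a src1 row for the dot product
        vec_dot_q_t        vec_dot_q;
    } quantize_fns_t;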
@@ -6691,7 +7647,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
6691
7647
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
6692
7648
|
ne11, ne01, ne10,
|
6693
7649
|
1.0f, y, ne10,
|
6694
|
-
x,
|
7650
|
+
x, ne00,
|
6695
7651
|
0.0f, d, ne01);
|
6696
7652
|
}
|
6697
7653
|
}
|
@@ -6704,12 +7660,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
6704
7660
|
|
6705
7661
|
if (params->type == GGML_TASK_INIT) {
|
6706
7662
|
char * wdata = params->wdata;
|
6707
|
-
const size_t row_size = ne10*GGML_TYPE_SIZE[
|
7663
|
+
const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
|
6708
7664
|
|
6709
7665
|
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
6710
7666
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
6711
7667
|
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
6712
|
-
|
7668
|
+
quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
6713
7669
|
wdata += row_size;
|
6714
7670
|
}
|
6715
7671
|
}
|
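With the Q8_0-based row_size above, the INIT phase now quantizes every f32 row of src1 into Q8_0 blocks inside the shared wdata scratch, so the per-row dot products consume Q8_0 src1 rows regardless of src0's quantization format. Condensed, with comments, the INIT loop added here does:

    // INIT: convert src1 (f32) row by row into Q8_0 blocks in the scratch buffer.
    char * wdata = params->wdata;
    const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];

    for (int64_t i11 = 0; i11 < ne11; ++i11) {        // (outer i13/i12 loops omitted)
        quantize_row_q_dot((float *) ((char *) src1->data + i11*nb11), wdata, ne10);
        wdata += row_size;                            // advance by one quantized row
    }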
@@ -6735,7 +7691,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
6735
7691
|
const int ir1 = MIN(ir0 + dr, nr);
|
6736
7692
|
|
6737
7693
|
void * wdata = params->wdata;
|
6738
|
-
const size_t row_size = ne00*GGML_TYPE_SIZE[
|
7694
|
+
const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
|
6739
7695
|
|
6740
7696
|
for (int ir = ir0; ir < ir1; ++ir) {
|
6741
7697
|
// src0 indices
|
@@ -6783,6 +7739,7 @@ static void ggml_compute_forward_mul_mat(
|
|
6783
7739
|
switch (src0->type) {
|
6784
7740
|
case GGML_TYPE_Q4_0:
|
6785
7741
|
case GGML_TYPE_Q4_1:
|
7742
|
+
case GGML_TYPE_Q8_0:
|
6786
7743
|
{
|
6787
7744
|
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
|
6788
7745
|
} break;
|
@@ -6794,10 +7751,7 @@ static void ggml_compute_forward_mul_mat(
|
|
6794
7751
|
{
|
6795
7752
|
ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
|
6796
7753
|
} break;
|
6797
|
-
|
6798
|
-
case GGML_TYPE_I16:
|
6799
|
-
case GGML_TYPE_I32:
|
6800
|
-
case GGML_TYPE_COUNT:
|
7754
|
+
default:
|
6801
7755
|
{
|
6802
7756
|
GGML_ASSERT(false);
|
6803
7757
|
} break;
|
@@ -6879,13 +7833,7 @@ static void ggml_compute_forward_scale(
|
|
6879
7833
|
{
|
6880
7834
|
ggml_compute_forward_scale_f32(params, src0, src1, dst);
|
6881
7835
|
} break;
|
6882
|
-
|
6883
|
-
case GGML_TYPE_Q4_1:
|
6884
|
-
case GGML_TYPE_I8:
|
6885
|
-
case GGML_TYPE_I16:
|
6886
|
-
case GGML_TYPE_I32:
|
6887
|
-
case GGML_TYPE_F16:
|
6888
|
-
case GGML_TYPE_COUNT:
|
7836
|
+
default:
|
6889
7837
|
{
|
6890
7838
|
GGML_ASSERT(false);
|
6891
7839
|
} break;
|
@@ -6901,6 +7849,15 @@ static void ggml_compute_forward_cpy(
|
|
6901
7849
|
ggml_compute_forward_dup(params, src0, dst);
|
6902
7850
|
}
|
6903
7851
|
|
7852
|
+
// ggml_compute_forward_cont
|
7853
|
+
|
7854
|
+
static void ggml_compute_forward_cont(
|
7855
|
+
const struct ggml_compute_params * params,
|
7856
|
+
const struct ggml_tensor * src0,
|
7857
|
+
struct ggml_tensor * dst) {
|
7858
|
+
ggml_compute_forward_dup(params, src0, dst);
|
7859
|
+
}
|
7860
|
+
|
6904
7861
|
// ggml_compute_forward_reshape
|
6905
7862
|
|
6906
7863
|
static void ggml_compute_forward_reshape(
|
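ggml_compute_forward_cont is intentionally just another name for dup: GGML_OP_CONT materialises a contiguous copy of a (possibly transposed or permuted) view. The backward-pass hunk further down shows the motivating use; a minimal usage sketch, assuming the matching ggml_cont()/ggml_transpose() builders are exposed in ggml.h and that ctx, a and b already exist:

    // Make a strided view contiguous before an op whose kernels require
    // contiguous rows (e.g. ggml_mul_mat, see the asserts earlier in this file).
    struct ggml_tensor * at  = ggml_transpose(ctx, a);  // view only, strided
    struct ggml_tensor * atc = ggml_cont(ctx, at);      // GGML_OP_CONT: contiguous copy
    struct ggml_tensor * c   = ggml_mul_mat(ctx, atc, b);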
@@ -7037,6 +7994,7 @@ static void ggml_compute_forward_get_rows(
|
|
7037
7994
|
switch (src0->type) {
|
7038
7995
|
case GGML_TYPE_Q4_0:
|
7039
7996
|
case GGML_TYPE_Q4_1:
|
7997
|
+
case GGML_TYPE_Q8_0:
|
7040
7998
|
{
|
7041
7999
|
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
7042
8000
|
} break;
|
@@ -7048,10 +8006,7 @@ static void ggml_compute_forward_get_rows(
|
|
7048
8006
|
{
|
7049
8007
|
ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
|
7050
8008
|
} break;
|
7051
|
-
|
7052
|
-
case GGML_TYPE_I16:
|
7053
|
-
case GGML_TYPE_I32:
|
7054
|
-
case GGML_TYPE_COUNT:
|
8009
|
+
default:
|
7055
8010
|
{
|
7056
8011
|
GGML_ASSERT(false);
|
7057
8012
|
} break;
|
@@ -7124,13 +8079,7 @@ static void ggml_compute_forward_diag_mask_inf(
|
|
7124
8079
|
{
|
7125
8080
|
ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
|
7126
8081
|
} break;
|
7127
|
-
|
7128
|
-
case GGML_TYPE_Q4_1:
|
7129
|
-
case GGML_TYPE_I8:
|
7130
|
-
case GGML_TYPE_I16:
|
7131
|
-
case GGML_TYPE_I32:
|
7132
|
-
case GGML_TYPE_F16:
|
7133
|
-
case GGML_TYPE_COUNT:
|
8082
|
+
default:
|
7134
8083
|
{
|
7135
8084
|
GGML_ASSERT(false);
|
7136
8085
|
} break;
|
@@ -7218,13 +8167,7 @@ static void ggml_compute_forward_soft_max(
|
|
7218
8167
|
{
|
7219
8168
|
ggml_compute_forward_soft_max_f32(params, src0, dst);
|
7220
8169
|
} break;
|
7221
|
-
|
7222
|
-
case GGML_TYPE_Q4_1:
|
7223
|
-
case GGML_TYPE_I8:
|
7224
|
-
case GGML_TYPE_I16:
|
7225
|
-
case GGML_TYPE_I32:
|
7226
|
-
case GGML_TYPE_F16:
|
7227
|
-
case GGML_TYPE_COUNT:
|
8170
|
+
default:
|
7228
8171
|
{
|
7229
8172
|
GGML_ASSERT(false);
|
7230
8173
|
} break;
|
@@ -7279,6 +8222,8 @@ static void ggml_compute_forward_rope_f32(
|
|
7279
8222
|
// row index used to determine which thread to use
|
7280
8223
|
int ir = 0;
|
7281
8224
|
|
8225
|
+
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
8226
|
+
|
7282
8227
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
7283
8228
|
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
7284
8229
|
const int p = (mode == 0 ? n_past + i2 : i2);
|
@@ -7286,11 +8231,13 @@ static void ggml_compute_forward_rope_f32(
|
|
7286
8231
|
if (ir++ < ir0) continue;
|
7287
8232
|
if (ir > ir1) break;
|
7288
8233
|
|
8234
|
+
float theta = (float)p;
|
8235
|
+
|
7289
8236
|
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
7290
|
-
const float
|
8237
|
+
const float cos_theta = cosf(theta);
|
8238
|
+
const float sin_theta = sinf(theta);
|
7291
8239
|
|
7292
|
-
|
7293
|
-
const float sin_theta = sinf(p*theta);
|
8240
|
+
theta *= theta_scale;
|
7294
8241
|
|
7295
8242
|
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
7296
8243
|
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
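The rope rewrite above (mirrored for f16 in the next hunk) turns the per-pair angle computation into a running value: theta starts at the position p and is multiplied once per pair by the precomputed theta_scale. That reproduces

    theta_k = p * 10000^(-2k/n_dims)    for pair index k = i0/2,

since theta_0 = p and theta_{k+1} = theta_k * 10000^(-2/n_dims) = theta_k * theta_scale, so each rotation angle now costs a single multiplication in the inner loop instead of re-deriving the exponent for every pair.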
@@ -7352,6 +8299,8 @@ static void ggml_compute_forward_rope_f16(
|
|
7352
8299
|
// row index used to determine which thread to use
|
7353
8300
|
int ir = 0;
|
7354
8301
|
|
8302
|
+
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
8303
|
+
|
7355
8304
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
7356
8305
|
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
7357
8306
|
const int p = (mode == 0 ? n_past + i2 : i2);
|
@@ -7359,20 +8308,22 @@ static void ggml_compute_forward_rope_f16(
|
|
7359
8308
|
if (ir++ < ir0) continue;
|
7360
8309
|
if (ir > ir1) break;
|
7361
8310
|
|
8311
|
+
float theta = (float)p;
|
8312
|
+
|
7362
8313
|
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
7363
|
-
const float
|
8314
|
+
const float cos_theta = cosf(theta);
|
8315
|
+
const float sin_theta = sinf(theta);
|
7364
8316
|
|
7365
|
-
|
7366
|
-
const float sin_theta = sinf(p*theta);
|
8317
|
+
theta *= theta_scale;
|
7367
8318
|
|
7368
8319
|
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
7369
8320
|
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
7370
8321
|
|
7371
|
-
const float x0 =
|
7372
|
-
const float x1 =
|
8322
|
+
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
8323
|
+
const float x1 = GGML_FP16_TO_FP32(src[1]);
|
7373
8324
|
|
7374
|
-
dst_data[0] =
|
7375
|
-
dst_data[1] =
|
8325
|
+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
8326
|
+
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
7376
8327
|
}
|
7377
8328
|
}
|
7378
8329
|
}
|
@@ -7393,12 +8344,7 @@ static void ggml_compute_forward_rope(
|
|
7393
8344
|
{
|
7394
8345
|
ggml_compute_forward_rope_f32(params, src0, src1, dst);
|
7395
8346
|
} break;
|
7396
|
-
|
7397
|
-
case GGML_TYPE_Q4_1:
|
7398
|
-
case GGML_TYPE_I8:
|
7399
|
-
case GGML_TYPE_I16:
|
7400
|
-
case GGML_TYPE_I32:
|
7401
|
-
case GGML_TYPE_COUNT:
|
8347
|
+
default:
|
7402
8348
|
{
|
7403
8349
|
GGML_ASSERT(false);
|
7404
8350
|
} break;
|
@@ -7661,12 +8607,7 @@ static void ggml_compute_forward_conv_1d_1s(
|
|
7661
8607
|
{
|
7662
8608
|
ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
|
7663
8609
|
} break;
|
7664
|
-
|
7665
|
-
case GGML_TYPE_Q4_1:
|
7666
|
-
case GGML_TYPE_I8:
|
7667
|
-
case GGML_TYPE_I16:
|
7668
|
-
case GGML_TYPE_I32:
|
7669
|
-
case GGML_TYPE_COUNT:
|
8610
|
+
default:
|
7670
8611
|
{
|
7671
8612
|
GGML_ASSERT(false);
|
7672
8613
|
} break;
|
@@ -7929,12 +8870,7 @@ static void ggml_compute_forward_conv_1d_2s(
|
|
7929
8870
|
{
|
7930
8871
|
ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
|
7931
8872
|
} break;
|
7932
|
-
|
7933
|
-
case GGML_TYPE_Q4_1:
|
7934
|
-
case GGML_TYPE_I8:
|
7935
|
-
case GGML_TYPE_I16:
|
7936
|
-
case GGML_TYPE_I32:
|
7937
|
-
case GGML_TYPE_COUNT:
|
8873
|
+
default:
|
7938
8874
|
{
|
7939
8875
|
GGML_ASSERT(false);
|
7940
8876
|
} break;
|
@@ -8414,12 +9350,7 @@ static void ggml_compute_forward_flash_attn(
|
|
8414
9350
|
{
|
8415
9351
|
ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
|
8416
9352
|
} break;
|
8417
|
-
|
8418
|
-
case GGML_TYPE_Q4_1:
|
8419
|
-
case GGML_TYPE_I8:
|
8420
|
-
case GGML_TYPE_I16:
|
8421
|
-
case GGML_TYPE_I32:
|
8422
|
-
case GGML_TYPE_COUNT:
|
9353
|
+
default:
|
8423
9354
|
{
|
8424
9355
|
GGML_ASSERT(false);
|
8425
9356
|
} break;
|
@@ -8625,12 +9556,100 @@ static void ggml_compute_forward_flash_ff(
|
|
8625
9556
|
{
|
8626
9557
|
GGML_ASSERT(false); // TODO
|
8627
9558
|
} break;
|
8628
|
-
|
8629
|
-
|
8630
|
-
|
8631
|
-
|
8632
|
-
|
8633
|
-
|
9559
|
+
default:
|
9560
|
+
{
|
9561
|
+
GGML_ASSERT(false);
|
9562
|
+
} break;
|
9563
|
+
}
|
9564
|
+
}
|
9565
|
+
|
9566
|
+
// ggml_compute_forward_map_unary
|
9567
|
+
|
9568
|
+
static void ggml_compute_forward_map_unary_f32(
|
9569
|
+
const struct ggml_compute_params * params,
|
9570
|
+
const struct ggml_tensor * src0,
|
9571
|
+
struct ggml_tensor * dst,
|
9572
|
+
const ggml_unary_op_f32_t fun) {
|
9573
|
+
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
9574
|
+
|
9575
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
9576
|
+
return;
|
9577
|
+
}
|
9578
|
+
|
9579
|
+
const int n = ggml_nrows(src0);
|
9580
|
+
const int nc = src0->ne[0];
|
9581
|
+
|
9582
|
+
assert( dst->nb[0] == sizeof(float));
|
9583
|
+
assert(src0->nb[0] == sizeof(float));
|
9584
|
+
|
9585
|
+
for (int i = 0; i < n; i++) {
|
9586
|
+
fun(nc,
|
9587
|
+
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
9588
|
+
(float *) ((char *) src0->data + i*(src0->nb[1])));
|
9589
|
+
}
|
9590
|
+
}
|
9591
|
+
|
9592
|
+
|
9593
|
+
static void ggml_compute_forward_map_unary(
|
9594
|
+
const struct ggml_compute_params * params,
|
9595
|
+
const struct ggml_tensor * src0,
|
9596
|
+
struct ggml_tensor * dst,
|
9597
|
+
const ggml_unary_op_f32_t fun) {
|
9598
|
+
switch (src0->type) {
|
9599
|
+
case GGML_TYPE_F32:
|
9600
|
+
{
|
9601
|
+
ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
|
9602
|
+
} break;
|
9603
|
+
default:
|
9604
|
+
{
|
9605
|
+
GGML_ASSERT(false);
|
9606
|
+
} break;
|
9607
|
+
}
|
9608
|
+
}
|
9609
|
+
|
9610
|
+
// ggml_compute_forward_map_binary
|
9611
|
+
|
9612
|
+
static void ggml_compute_forward_map_binary_f32(
|
9613
|
+
const struct ggml_compute_params * params,
|
9614
|
+
const struct ggml_tensor * src0,
|
9615
|
+
const struct ggml_tensor * src1,
|
9616
|
+
struct ggml_tensor * dst,
|
9617
|
+
const ggml_binary_op_f32_t fun) {
|
9618
|
+
assert(params->ith == 0);
|
9619
|
+
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
9620
|
+
|
9621
|
+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
9622
|
+
return;
|
9623
|
+
}
|
9624
|
+
|
9625
|
+
const int n = ggml_nrows(src0);
|
9626
|
+
const int nc = src0->ne[0];
|
9627
|
+
|
9628
|
+
assert( dst->nb[0] == sizeof(float));
|
9629
|
+
assert(src0->nb[0] == sizeof(float));
|
9630
|
+
assert(src1->nb[0] == sizeof(float));
|
9631
|
+
|
9632
|
+
for (int i = 0; i < n; i++) {
|
9633
|
+
fun(nc,
|
9634
|
+
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
9635
|
+
(float *) ((char *) src0->data + i*(src0->nb[1])),
|
9636
|
+
(float *) ((char *) src1->data + i*(src1->nb[1])));
|
9637
|
+
}
|
9638
|
+
}
|
9639
|
+
|
9640
|
+
|
9641
|
+
static void ggml_compute_forward_map_binary(
|
9642
|
+
const struct ggml_compute_params * params,
|
9643
|
+
const struct ggml_tensor * src0,
|
9644
|
+
const struct ggml_tensor * src1,
|
9645
|
+
struct ggml_tensor * dst,
|
9646
|
+
const ggml_binary_op_f32_t fun) {
|
9647
|
+
switch (src0->type) {
|
9648
|
+
case GGML_TYPE_F32:
|
9649
|
+
{
|
9650
|
+
ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
|
9651
|
+
} break;
|
9652
|
+
default:
|
8634
9653
|
{
|
8635
9654
|
GGML_ASSERT(false);
|
8636
9655
|
} break;
|
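The two new map ops are an escape hatch: an arbitrary f32 row function is applied to every row of the source tensor(s), and, as the dispatch hunk below shows, the function pointer itself is carried in tensor->opt[0]->data. A hypothetical callback matching the way fun() is invoked above (row length, destination row, source row):

    // Hypothetical per-row callback for GGML_OP_MAP_UNARY: clamp each element to [0, 1].
    static void clamp01_row(const int n, float * dst, const float * src) {
        for (int i = 0; i < n; ++i) {
            const float v = src[i];
            dst[i] = v < 0.0f ? 0.0f : (v > 1.0f ? 1.0f : v);
        }
    }

Building the node itself presumably goes through matching ggml_map_unary_f32() / ggml_map_binary_f32() constructors declared in ggml.h, which are not part of this file's hunks.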
@@ -8731,6 +9750,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
8731
9750
|
{
|
8732
9751
|
ggml_compute_forward_cpy(params, tensor->src0, tensor);
|
8733
9752
|
} break;
|
9753
|
+
case GGML_OP_CONT:
|
9754
|
+
{
|
9755
|
+
ggml_compute_forward_cont(params, tensor->src0, tensor);
|
9756
|
+
} break;
|
8734
9757
|
case GGML_OP_RESHAPE:
|
8735
9758
|
{
|
8736
9759
|
ggml_compute_forward_reshape(params, tensor->src0, tensor);
|
@@ -8782,6 +9805,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
8782
9805
|
{
|
8783
9806
|
ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
|
8784
9807
|
} break;
|
9808
|
+
case GGML_OP_MAP_UNARY:
|
9809
|
+
{
|
9810
|
+
const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
|
9811
|
+
ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
|
9812
|
+
}
|
9813
|
+
break;
|
9814
|
+
case GGML_OP_MAP_BINARY:
|
9815
|
+
{
|
9816
|
+
const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
|
9817
|
+
ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
|
9818
|
+
}
|
9819
|
+
break;
|
8785
9820
|
case GGML_OP_NONE:
|
8786
9821
|
{
|
8787
9822
|
// nop
|
@@ -8975,8 +10010,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
8975
10010
|
src1->grad =
|
8976
10011
|
ggml_add_impl(ctx,
|
8977
10012
|
src1->grad,
|
8978
|
-
|
8979
|
-
|
10013
|
+
ggml_mul_mat(ctx,
|
10014
|
+
ggml_cont(ctx, ggml_transpose(ctx, src0)),
|
10015
|
+
tensor->grad),
|
8980
10016
|
inplace);
|
8981
10017
|
}
|
8982
10018
|
} break;
|
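The src1 gradient of a matrix product needs the transposed src0, and the mul_mat kernels in this file assert that their operands are not permuted, so the strided view returned by ggml_transpose is first materialised with the new ggml_cont. Roughly, writing the forward op as y = W x with W = src0, the chain rule gives dL/dx = W^T dL/dy, which is the product that mul_mat(cont(transpose(src0)), tensor->grad) supplies here.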
@@ -8988,6 +10024,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
8988
10024
|
{
|
8989
10025
|
GGML_ASSERT(false); // TODO: not implemented
|
8990
10026
|
} break;
|
10027
|
+
case GGML_OP_CONT:
|
10028
|
+
{
|
10029
|
+
GGML_ASSERT(false); // TODO: not implemented
|
10030
|
+
} break;
|
8991
10031
|
case GGML_OP_RESHAPE:
|
8992
10032
|
{
|
8993
10033
|
GGML_ASSERT(false); // TODO: not implemented
|
@@ -9036,6 +10076,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
9036
10076
|
{
|
9037
10077
|
GGML_ASSERT(false); // not supported
|
9038
10078
|
} break;
|
10079
|
+
case GGML_OP_MAP_UNARY:
|
10080
|
+
case GGML_OP_MAP_BINARY:
|
10081
|
+
{
|
10082
|
+
GGML_ASSERT(false); // not supported
|
10083
|
+
} break;
|
9039
10084
|
case GGML_OP_NONE:
|
9040
10085
|
{
|
9041
10086
|
// nop
|
@@ -9126,7 +10171,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
|
|
9126
10171
|
struct ggml_cgraph result = {
|
9127
10172
|
/*.n_nodes =*/ 0,
|
9128
10173
|
/*.n_leafs =*/ 0,
|
9129
|
-
/*.n_threads =*/
|
10174
|
+
/*.n_threads =*/ GGML_DEFAULT_N_THREADS,
|
9130
10175
|
/*.work_size =*/ 0,
|
9131
10176
|
/*.work =*/ NULL,
|
9132
10177
|
/*.nodes =*/ { NULL },
|
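ggml_build_forward() now initialises the graph's thread count from GGML_DEFAULT_N_THREADS, so callers that want a different count simply override the field before computing. A minimal usage sketch; out and ctx are assumed to be an existing output tensor and context:

    // Assumed usage: build the forward graph, pick a thread count, compute it.
    struct ggml_cgraph gf = ggml_build_forward(out);
    gf.n_threads = 4;                  // override GGML_DEFAULT_N_THREADS if desired
    ggml_graph_compute(ctx, &gf);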
@@ -9354,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
9354
10399
|
struct ggml_tensor * node = cgraph->nodes[i];
|
9355
10400
|
|
9356
10401
|
switch (node->op) {
|
10402
|
+
case GGML_OP_CPY:
|
9357
10403
|
case GGML_OP_DUP:
|
9358
10404
|
{
|
9359
10405
|
node->n_tasks = 1;
|
10406
|
+
|
10407
|
+
size_t cur = 0;
|
10408
|
+
if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
|
10409
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
|
10410
|
+
}
|
10411
|
+
|
10412
|
+
work_size = MAX(work_size, cur);
|
9360
10413
|
} break;
|
9361
10414
|
case GGML_OP_ADD:
|
9362
10415
|
{
|
9363
10416
|
node->n_tasks = n_threads;
|
10417
|
+
|
10418
|
+
size_t cur = 0;
|
10419
|
+
|
10420
|
+
if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
|
10421
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
|
10422
|
+
}
|
10423
|
+
|
10424
|
+
work_size = MAX(work_size, cur);
|
9364
10425
|
} break;
|
9365
10426
|
case GGML_OP_SUB:
|
9366
10427
|
case GGML_OP_MUL:
|
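The planner in ggml_graph_compute() now reserves scratch space for quantized operands: a single-threaded CPY/DUP of a quantized tensor needs room to dequantize one f32 row, and an ADD with a quantized src0 needs one such row per worker thread, since each thread dequantizes the rows it owns. The sizing, restated as an illustrative sketch outside the switch:

    // Illustrative restatement of the scratch sizing added above.
    const size_t f32_row_cpy = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];                   // CPY/DUP, n_tasks == 1
    const size_t f32_row_add = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads; // ADD, one row per thread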
@@ -9429,7 +10490,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
9429
10490
|
} else
|
9430
10491
|
#endif
|
9431
10492
|
{
|
9432
|
-
cur = GGML_TYPE_SIZE[
|
10493
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
|
9433
10494
|
}
|
9434
10495
|
} else {
|
9435
10496
|
GGML_ASSERT(false);
|
@@ -9441,7 +10502,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
9441
10502
|
{
|
9442
10503
|
node->n_tasks = n_threads;
|
9443
10504
|
} break;
|
9444
|
-
case
|
10505
|
+
case GGML_OP_CONT:
|
9445
10506
|
case GGML_OP_RESHAPE:
|
9446
10507
|
case GGML_OP_VIEW:
|
9447
10508
|
case GGML_OP_PERMUTE:
|
@@ -9527,6 +10588,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
9527
10588
|
|
9528
10589
|
work_size = MAX(work_size, cur);
|
9529
10590
|
} break;
|
10591
|
+
case GGML_OP_MAP_UNARY:
|
10592
|
+
case GGML_OP_MAP_BINARY:
|
10593
|
+
{
|
10594
|
+
node->n_tasks = 1;
|
10595
|
+
} break;
|
9530
10596
|
case GGML_OP_NONE:
|
9531
10597
|
{
|
9532
10598
|
node->n_tasks = 1;
|
@@ -9745,8 +10811,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|
9745
10811
|
|
9746
10812
|
GGML_PRINT("=== GRAPH ===\n");
|
9747
10813
|
|
9748
|
-
GGML_PRINT_DEBUG("n_threads = %d\n",
|
9749
|
-
GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
|
10814
|
+
GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
|
10815
|
+
GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);
|
9750
10816
|
|
9751
10817
|
GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
|
9752
10818
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
@@ -10598,16 +11664,16 @@ enum ggml_opt_result ggml_opt(
|
|
10598
11664
|
////////////////////////////////////////////////////////////////////////////////
|
10599
11665
|
|
10600
11666
|
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
10601
|
-
assert(k %
|
10602
|
-
const int nb = k /
|
11667
|
+
assert(k % QK4_0 == 0);
|
11668
|
+
const int nb = k / QK4_0;
|
10603
11669
|
|
10604
11670
|
for (int j = 0; j < n; j += k) {
|
10605
|
-
block_q4_0 * restrict y = (block_q4_0 *)dst + j/
|
11671
|
+
block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;
|
10606
11672
|
|
10607
11673
|
quantize_row_q4_0_reference(src + j, y, k);
|
10608
11674
|
|
10609
11675
|
for (int i = 0; i < nb; i++) {
|
10610
|
-
for (int l = 0; l <
|
11676
|
+
for (int l = 0; l < QK4_0; l += 2) {
|
10611
11677
|
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
10612
11678
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
10613
11679
|
|
@@ -10617,20 +11683,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
|
|
10617
11683
|
}
|
10618
11684
|
}
|
10619
11685
|
|
10620
|
-
return (n/
|
11686
|
+
return (n/QK4_0*sizeof(block_q4_0));
|
10621
11687
|
}
|
10622
11688
|
|
10623
11689
|
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
|
10624
|
-
assert(k %
|
10625
|
-
const int nb = k /
|
11690
|
+
assert(k % QK4_1 == 0);
|
11691
|
+
const int nb = k / QK4_1;
|
10626
11692
|
|
10627
11693
|
for (int j = 0; j < n; j += k) {
|
10628
|
-
block_q4_1 * restrict y = (block_q4_1 *)dst + j/
|
11694
|
+
block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;
|
10629
11695
|
|
10630
11696
|
quantize_row_q4_1_reference(src + j, y, k);
|
10631
11697
|
|
10632
11698
|
for (int i = 0; i < nb; i++) {
|
10633
|
-
for (int l = 0; l <
|
11699
|
+
for (int l = 0; l < QK4_1; l += 2) {
|
10634
11700
|
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
10635
11701
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
10636
11702
|
|
@@ -10640,7 +11706,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
|
10640
11706
|
}
|
10641
11707
|
}
|
10642
11708
|
|
10643
|
-
return (n/
|
11709
|
+
return (n/QK4_1*sizeof(block_q4_1));
|
10644
11710
|
}
|
10645
11711
|
|
10646
11712
|
////////////////////////////////////////////////////////////////////////////////
|
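The quantize entry points now spell out per-format block constants (QK4_0, QK4_1) instead of a shared block size, so the histogram loop and the returned byte count are tied to the format being produced. As a worked example of the size formula, under the assumption that QK4_0 is 32 and that block_q4_0 packs one float scale plus QK4_0/2 nibble bytes (20 bytes per block):

    // Hypothetical numbers: quantizing n = 4096 floats would return
    //   n/QK4_0 * sizeof(block_q4_0) = 128 * 20 = 2560 bytes.
    const size_t bytes = (size_t) (n/QK4_0) * sizeof(block_q4_0);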
@@ -10669,6 +11735,22 @@ int ggml_cpu_has_avx512(void) {
|
|
10669
11735
|
#endif
|
10670
11736
|
}
|
10671
11737
|
|
11738
|
+
int ggml_cpu_has_avx512_vbmi(void) {
|
11739
|
+
#if defined(__AVX512VBMI__)
|
11740
|
+
return 1;
|
11741
|
+
#else
|
11742
|
+
return 0;
|
11743
|
+
#endif
|
11744
|
+
}
|
11745
|
+
|
11746
|
+
int ggml_cpu_has_avx512_vnni(void) {
|
11747
|
+
#if defined(__AVX512VNNI__)
|
11748
|
+
return 1;
|
11749
|
+
#else
|
11750
|
+
return 0;
|
11751
|
+
#endif
|
11752
|
+
}
|
11753
|
+
|
10672
11754
|
int ggml_cpu_has_fma(void) {
|
10673
11755
|
#if defined(__FMA__)
|
10674
11756
|
return 1;
|