llama_cpp 0.0.4 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/README.md +3 -2
- data/ext/llama_cpp/extconf.rb +12 -0
- data/ext/llama_cpp/llama_cpp.cpp +60 -0
- data/ext/llama_cpp/src/ggml.c +1108 -508
- data/ext/llama_cpp/src/ggml.h +10 -0
- data/ext/llama_cpp/src/llama.cpp +317 -47
- data/ext/llama_cpp/src/llama.h +12 -0
- data/ext/llama_cpp/src/llama_util.h +22 -15
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +3 -3
- data/sig/llama_cpp.rbs +3 -0
- metadata +2 -2
data/ext/llama_cpp/src/ggml.c
CHANGED
@@ -118,7 +118,16 @@ typedef void* thread_ret_t;
 #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
 #define GGML_ALIGNED_FREE(ptr)    _aligned_free(ptr)
 #else
-#define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
+inline static void* ggml_aligned_malloc(size_t size) {
+    void* aligned_memory = NULL;
+    int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+    if (result != 0) {
+        // Handle allocation failure
+        return NULL;
+    }
+    return aligned_memory;
+}
+#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
 #define GGML_ALIGNED_FREE(ptr)    free(ptr)
 #endif
@@ -418,8 +427,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#define QK 32
-
 // AVX routines provided by GH user Const-me
 // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
 #if __AVX2__ || __AVX512F__
@@ -531,68 +538,73 @@ inline static float vaddvq_f32(float32x4_t v) {
     return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
 }
 
-inline static float vminvq_f32(float32x4_t v) {
+float vminvq_f32(float32x4_t v) {
     return
         MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
             MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
 }
 
-inline static float vmaxvq_f32(float32x4_t v) {
+float vmaxvq_f32(float32x4_t v) {
     return
         MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
             MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
 }
 
-inline static int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
     return vget_low_s8(vcombine_s8(a, b));
 }
 
-inline static int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
     return vget_high_s8(vcombine_s8(a, b));
 }
 
-inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
     return vget_low_u8(vcombine_u8(a, b));
 }
 
-inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
     return vget_high_u8(vcombine_u8(a, b));
 }
 
 #endif
 #endif
 
-
-// blocks of QK elements
-// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
+
+#define QK4_0 32
 typedef struct {
-    float d;
-    uint8_t qs[QK / 2];
+    float   d;              // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
-
-// blocks of QK elements
-// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+#define QK4_1 32
 typedef struct {
-    float d;
-    float m;
-    uint8_t qs[QK / 2];
+    float   d;              // delta
+    float   m;              // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    float  d;         // delta
+    int8_t qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_0/2];
 
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+        for (int l = 0; l < QK4_0; l++) {
+            const float v = x[i*QK4_0 + l];
             amax = MAX(amax, fabsf(v));
         }
 
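
For orientation, the three block layouts above imply these storage costs (a sketch, not part of the package diff; it assumes 4-byte floats and the QK4_0/QK4_1/QK8_0 values defined above):

```c
// Bits per weight implied by the quantization block structs in ggml.c.
#include <stdio.h>

int main(void) {
    const int qk = 32;  // QK4_0 == QK4_1 == QK8_0 == 32
    const double q4_0 = (4.0 + qk/2.0) * 8.0 / qk;  // 20-byte block -> 5.0 bits/weight
    const double q4_1 = (8.0 + qk/2.0) * 8.0 / qk;  // 24-byte block -> 6.0 bits/weight
    const double q8_0 = (4.0 + qk)     * 8.0 / qk;  // 36-byte block -> 9.0 bits/weight
    printf("q4_0: %.1f, q4_1: %.1f, q8_0: %.1f bits/weight\n", q4_0, q4_1, q8_0);
    return 0;
}
```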
@@ -601,9 +613,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 
         y[i].d = d;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = x[i*QK + l + 0]*id;
-            const float v1 = x[i*QK + l + 1]*id;
+        for (int l = 0; l < QK4_0; l += 2) {
+            const float v0 = x[i*QK4_0 + l + 0]*id;
+            const float v1 = x[i*QK4_0 + l + 1]*id;
 
             const uint8_t vi0 = (int8_t)roundf(v0) + 8;
             const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -619,8 +631,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 }
 
 static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     block_q4_0 * restrict y = vy;
 
@@ -870,19 +882,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 }
 
 static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_1/2];
 
     for (int i = 0; i < nb; i++) {
         float min = FLT_MAX;
         float max = -FLT_MAX;
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+        for (int l = 0; l < QK4_1; l++) {
+            const float v = x[i*QK4_1 + l];
             if (v < min) min = v;
             if (v > max) max = v;
         }
@@ -893,9 +905,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
         y[i].d = d;
         y[i].m = min;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = (x[i*QK + l + 0] - min)*id;
-            const float v1 = (x[i*QK + l + 1] - min)*id;
+        for (int l = 0; l < QK4_1; l += 2) {
+            const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
 
             const uint8_t vi0 = roundf(v0);
             const uint8_t vi1 = roundf(v1);
@@ -911,9 +923,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
 }
 
 static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
+    assert(k % QK4_1 == 0);
 
-    const int nb = k / QK;
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
|
|
997
1009
|
float32x4_t minv[8];
|
998
1010
|
float32x4_t maxv[8];
|
999
1011
|
|
1000
|
-
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*
|
1012
|
+
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
|
1001
1013
|
|
1002
1014
|
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
|
1003
1015
|
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
|
@@ -1033,9 +1045,160 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 #endif
 }
 
+// reference implementation for deterministic creation of model files
+static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK8_0; l++) {
+            const float v = x[i*QK8_0 + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < QK8_0; ++l) {
+            const float v = x[i*QK8_0 + l]*id;
+            y[i].qs[l] = roundf(v);
+        }
+    }
+}
+
+static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#else
+    // scalar
+    quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     const block_q4_0 * restrict x = vx;
 
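
The new Q8_0 routines above implement a symmetric scheme: one scale d = amax/127 per 32 values, stored alongside signed 8-bit quants. A minimal standalone round-trip sketch (illustrative only, not from the diff; the test data is arbitrary and the block is handled without the block_q8_0 struct):

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    float x[32];
    for (int l = 0; l < 32; l++) x[l] = sinf((float) l);  // arbitrary test data

    float amax = 0.0f;  // absolute max, as in quantize_row_q8_0_reference
    for (int l = 0; l < 32; l++) amax = fmaxf(amax, fabsf(x[l]));

    const float d  = amax / 127.0f;  // per-block delta
    const float id = d ? 1.0f/d : 0.0f;

    int8_t qs[32];
    for (int l = 0; l < 32; l++) qs[l] = (int8_t) roundf(x[l]*id);

    float max_err = 0.0f;
    for (int l = 0; l < 32; l++) {
        const float y = qs[l]*d;  // dequantize
        max_err = fmaxf(max_err, fabsf(y - x[l]));
    }
    printf("max round-trip error: %g (bounded by d/2 = %g)\n", max_err, d/2);
    return 0;
}
```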
@@ -1046,7 +1209,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 32) {
+        for (int l = 0; l < QK4_0; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
             __m256i vx8 = bytesFromNibbles(pp+l/2);
 
|
|
1068
1231
|
// Scale and store
|
1069
1232
|
for (int j = 0; j < 4; j++) {
|
1070
1233
|
const __m256 result = _mm256_mul_ps(vf[j], d_v);
|
1071
|
-
_mm256_storeu_ps(y + i *
|
1234
|
+
_mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
|
1072
1235
|
}
|
1073
1236
|
}
|
1074
1237
|
}
|
@@ -1078,7 +1241,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 16) {
+        for (int l = 0; l < QK4_0; l += 16) {
             // Load 16x4-bit integers into 8x8-bit integers
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1117,10 +1280,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
             const float32x4_t r3 = vmulq_f32(vf_3, vd);
 
             // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            vst1q_f32(y + i*QK4_0 + l +  0, r0);
+            vst1q_f32(y + i*QK4_0 + l +  4, r1);
+            vst1q_f32(y + i*QK4_0 + l +  8, r2);
+            vst1q_f32(y + i*QK4_0 + l + 12, r3);
         }
     }
 #else
@@ -1130,7 +1293,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 2) {
+        for (int l = 0; l < QK4_0; l += 2) {
             const uint8_t vi = pp[l/2];
 
             const int8_t vi0 = vi & 0xf;
@@ -1141,19 +1304,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
             //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i*QK4_0 + l + 0] = v0;
+            y[i*QK4_0 + l + 1] = v1;
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isnan(y[i*QK4_0 + l + 0]));
+            assert(!isnan(y[i*QK4_0 + l + 1]));
         }
     }
 #endif
 }
 
 static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     const block_q4_1 * restrict x = vx;
 
@@ -1164,7 +1327,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 32) {
+        for (int l = 0; l < QK4_1; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
             __m256i vx8 = bytesFromNibbles(pp+l/2);
 
|
|
1183
1346
|
// Scale, add m and store
|
1184
1347
|
for (int j = 0; j < 4; j++) {
|
1185
1348
|
const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
|
1186
|
-
_mm256_storeu_ps(y + i *
|
1349
|
+
_mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
|
1187
1350
|
}
|
1188
1351
|
}
|
1189
1352
|
}
|
@@ -1194,7 +1357,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 16) {
+        for (int l = 0; l < QK4_1; l += 16) {
             // Load 16x4-bit integers into 8x8-bit integers
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1225,10 +1388,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
 
             // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            vst1q_f32(y + i*QK4_1 + l +  0, r0);
+            vst1q_f32(y + i*QK4_1 + l +  4, r1);
+            vst1q_f32(y + i*QK4_1 + l +  8, r2);
+            vst1q_f32(y + i*QK4_1 + l + 12, r3);
         }
     }
 #else
@@ -1238,7 +1401,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 2) {
+        for (int l = 0; l < QK4_1; l += 2) {
             const uint8_t vi = pp[l/2];
 
             const int8_t vi0 = vi & 0xf;
@@ -1247,16 +1410,44 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const float v0 = vi0*d + m;
             const float v1 = vi1*d + m;
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i*QK4_1 + l + 0] = v0;
+            y[i*QK4_1 + l + 1] = v1;
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isnan(y[i*QK4_1 + l + 0]));
+            assert(!isnan(y[i*QK4_1 + l + 1]));
         }
     }
 #endif
 }
 
+static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .quantize_row_q_dot       = quantize_row_q4_1,
+        .vec_dot_q                = ggml_vec_dot_q4_1,
+    },
+    // TODO: GGML_TYPE_Q8_0
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
+
 //
 // simd mappings
 //
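
The dispatch table above pairs each quantized type with the row quantizer to use for the other operand of a dot product: Q4_0 now quantizes activations to Q8_0 (quantize_row_q_dot = quantize_row_q8_0) and dots with ggml_vec_dot_q4_0_q8_0, while Q4_1 still dots against Q4_1. A hypothetical fragment showing how the table is consumed (the wrapper function is invented for illustration; it would have to live inside ggml.c, where the static table and block types are visible):

```c
// Invented illustration: round-trip 64 floats through Q4_0 via the dispatch table.
static void example_roundtrip_q4_0(void) {
    float src[64] = {0}, dst[64] = {0};
    uint8_t buf[2*sizeof(block_q4_0)];  // 64 values = 2 blocks of QK4_0 == 32

    quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
    fns.quantize_row_q  (src, buf, 64);  // f32 -> q4_0
    fns.dequantize_row_q(buf, dst, 64);  // q4_0 -> f32
}
```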
@@ -1813,34 +2004,188 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
-#if __AVX512F__ && QK == 32
-static inline __m512 dot_q4_0_oneblock_avx512(
+#if __AVX512F__ && QK4_0 == 32
+static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
+    // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    //  |                                                       :. =_ () [] <> () Zz Yy|
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa            |
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    //
+    // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
+    // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
+    // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
+    // Bytes 40..63 are masked when loading the data, so they are zeroed out.
+#ifdef __AVX512VBMI__
+    const __m512i byte_perm = _mm512_set_epi8(
+        39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
+        31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
+        19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
+        11, 10, 11, 10,  9,  8,  9,  8,  7,  6,  7,  6,  5,  4,  5,  4
+    );
+    const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
+    // After applying VPERMB, `permuted` looks like this:
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+#else
+    const __m512i word_perm = _mm512_set_epi16(
+        19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
+         9,  9,  8,  8,  7,  7,  6,  6,  5,  5,  4,  4,  3,  3,  2,  2
+    );
+    const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
+    // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
+    // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
+    // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
+#endif
+
+    // Shift every odd-numbered 16-bit group to the right by 4 bits.
+    const __mmask32 shift_mask = 0xaaaaaaaa;
+    const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
+    // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | : .= :. =_  ( )[ () []  < >( <> ()  Z zY Zz Yy  X xW Xx Ww  V vU Vv Uu  T tS Tt Ss  R rQ Rr Qq|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | P pO Pp Oo  N nM Nn Mm  L lK Ll Kk  J jI Jj Ii  H hG Hh Gg  F fE Ff Ee  D dC Dd Cc  B bA Bb Aa|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+
+    // Now we just need to zero out the higher nibble in each byte, and we're done.
+    const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
+    return _mm512_and_si512( low_nibble_mask, shifted );
+    // The final result looks like this:
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | :  =  .  _  (  [  )  ]  <  (  >  )  Z  Y  z  y  X  W  x  w  V  U  v  u  T  S  t  s  R  Q  r  q|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | P  O  p  o  N  M  n  m  L  K  l  k  J  I  j  i  H  G  h  g  F  E  f  e  D  C  d  c  B  A  b  a|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+}
+
+static inline __m512 dot_q4_0_twoblocks_avx512(
     __m512 acc,
     const block_q4_0 * restrict x,
     const block_q4_0 * restrict y,
     int i
 ) {
-    … (the body of the old one-block implementation; its removed lines are truncated in this extract)
+    // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
+    // can potentially be unaddressable, so we make sure to mask them out before the load, even though
+    // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
+    // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
+    const __mmask8 load_mask = 0x1f;
+    const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
+    const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
+
+    // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // blocks_0_float
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // |    |    |    |    |    |    | xx | xx | xx | xx |  B | xx | xx | xx | xx |  A |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // blocks_1_float
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // |    |    |    |    |    |    | xx | xx | xx | xx |  D | xx | xx | xx | xx |  C |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
+    const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
+    // We absolutely shouldn't touch the floats marked with `xx`: they contain some
+    // random data, which might very well underflow. At least on Intel, this leads
+    // to a huge penalty that can't be ignored (easily 100x or more) unless you
+    // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
+    // (and ggml can't assume that you do)...
+    const __mmask16 scale_mul_mask = 0x21;
+#ifdef __clang__
+    // ...however, clang decides to optimize the multiplication mask away:
+    // https://godbolt.org/z/P8PqdsfvW
+    // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
+    __m512i scales;
+    __asm__(
+        "vmulps %1, %2, %0%{%3%}"
+        : "=v" ( scales )
+        : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
+    );
+#else
+    const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
+#endif
+    const __m512i scale_perm = _mm512_set_epi32(
+        5, 5, 5, 5, 5, 5, 5, 5,
+        0, 0, 0, 0, 0, 0, 0, 0
+    );
+    const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
+    // After VMULPS and VPERMPS, `permuted_scales` looks like this:
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+
+    const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
+    const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
+
+    // Now we want to compute dot products of 4-element byte vectors and store them in
+    // 32-bit integers. That is (only one 4-element vector is shown for clarity):
+    // +----+----+----+----+
+    // ... | 03 | 02 | 01 | 00 |
+    // +----+----+----+----+
+    // bytes_0
+    // +----+----+----+----+
+    // ... |  D |  C |  B |  A |
+    // +----+----+----+----+
+    // bytes_1
+    // +----+----+----+----+
+    // ... |  H |  G |  F |  E |
+    // +----+----+----+----+
+    // final_res_int
+    // +----+----+----+----+
+    // ... |  A*E+B*F+C*G+D*H  |
+    // +----+----+----+----+
+    const __m512i plus_8 = _mm512_set1_epi8( 8 );
+    const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
+
+#ifdef __AVX512VNNI__
+    // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
+    // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
+    // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
+    // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
+    // which means we only need 2 instructions.
+    const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
+    const __m512i minus_8 = _mm512_set1_epi8( -8 );
+    const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
+    const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
+#else
+    // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
+    // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
+    // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
+    // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
+    const __m512i one = _mm512_set1_epi16( 1 );
+    const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
+    const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
+    const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
+    const __m512i final_res_int = _mm512_madd_epi16( diff, one );
+#endif
 
-    // …
-    __m512 …
-
-    return _mm512_fmadd_ps( d, p, acc );
+    // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
+    const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
+    return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
 }
 #endif
 
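
The VPDPBUSDS trick in the hunk above rests on a small identity: per byte pair, (a - 8)(b - 8) = a(b - 8) - 8b + 64, which lets the raw nibble a (0..15) stay on the unsigned side of the instruction. The first _mm512_dpbusds_epi32 call accumulates the -8b terms on top of dot_init, and since each 32-bit lane sums four byte pairs, the constant term starts at 4 * 64 — exactly the `_mm512_set1_epi32( 4 * 64 )` initializer.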
@@ -1881,9 +2226,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 }
 
 static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+    const int nb = n / QK4_0;
 
-    assert(n % QK == 0);
+    assert(n % QK4_0 == 0);
     assert(nb % 2 == 0);
 
     const block_q4_0 * restrict x = vx;
|
|
1972
2317
|
__m512 acc0 = _mm512_setzero_ps();
|
1973
2318
|
__m512 acc1 = _mm512_setzero_ps();
|
1974
2319
|
|
1975
|
-
const int superblock_size =
|
2320
|
+
const int superblock_size = 16;
|
2321
|
+
|
1976
2322
|
const int superblock_count = nb / superblock_size;
|
1977
2323
|
|
1978
2324
|
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
|
1979
2325
|
int i = superblock_ix * superblock_size;
|
1980
2326
|
|
1981
|
-
acc0 =
|
1982
|
-
acc1 =
|
1983
|
-
acc0 =
|
1984
|
-
acc1 =
|
1985
|
-
acc0 =
|
1986
|
-
acc1 =
|
1987
|
-
acc0 =
|
1988
|
-
acc1 =
|
2327
|
+
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
|
2328
|
+
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
|
2329
|
+
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
|
2330
|
+
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
|
2331
|
+
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
|
2332
|
+
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
|
2333
|
+
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
|
2334
|
+
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
|
1989
2335
|
}
|
1990
2336
|
|
1991
2337
|
// Remainders
|
1992
|
-
for (int i = superblock_count * superblock_size; i < nb;
|
1993
|
-
acc0 =
|
2338
|
+
for (int i = superblock_count * superblock_size; i < nb; i += 2) {
|
2339
|
+
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
|
1994
2340
|
}
|
1995
2341
|
|
1996
2342
|
// Horizontal sum of all lanes of the accumulator
|
@@ -2206,15 +2552,15 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const uint8_t * restrict p1 = y[i].qs;
 
         int sumi = 0;
-        for (int j = 0; j < QK/2; j++) {
+        for (int j = 0; j < QK4_0/2; j++) {
             const uint8_t v0 = p0[j];
             const uint8_t v1 = p1[j];
 
-            const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
-            const int8_t i1 = (int8_t) (v0 >> 4) - 8;
+            const int i0 = (v0 & 0xf) - 8;
+            const int i1 = (v0 >> 4) - 8;
 
-            const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
-            const int8_t i3 = (int8_t) (v1 >> 4) - 8;
+            const int i2 = (v1 & 0xf) - 8;
+            const int i3 = (v1 >> 4) - 8;
 
             sumi += i0*i2 + i1*i3;
         }
@@ -2226,7 +2572,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 }
 
 static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+    const int nb = n / QK4_1;
 
     const block_q4_1 * restrict x = vx;
     const block_q4_1 * restrict y = vy;
@@ -2303,7 +2649,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
-    sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
+    sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
 #elif defined(__ARM_NEON)
     float sum00 = 0.0f;
     float sum01 = 0.0f;
@@ -2335,12 +2681,12 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
         const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
 
         sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+        sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
+        sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
 
         sum00 += x1->m*y1->m;
-        sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+        sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
+        sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
@@ -2377,7 +2723,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
 #endif
     }
 
-    sumf = QK*sum00 + sum01 + sum10 + sum11;
+    sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
|
|
2390
2736
|
const uint8_t * restrict p0 = x[i].qs;
|
2391
2737
|
const uint8_t * restrict p1 = y[i].qs;
|
2392
2738
|
|
2393
|
-
for (int j = 0; j <
|
2739
|
+
for (int j = 0; j < QK4_1/2; j++) {
|
2394
2740
|
const uint8_t v0 = p0[j];
|
2395
2741
|
const uint8_t v1 = p1[j];
|
2396
2742
|
|
@@ -2408,78 +2754,281 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
-
-
-inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
-    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
-
-    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
-    … (the remainder of the old ggml_vec_dot_f16_unroll and the start of ggml_vec_mad_f32, truncated in this extract; the same code is re-added below, unchanged, after the new Q8_0 dot product)
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+    float sumf = 0.0;
+
+#if defined(__ARM_NEON)
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
+        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+
+        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
+        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
+
+        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+
+        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+
+        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
+#endif
+    }
+
+    sumf = sum0 + sum1;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m256i bx = bytesFromNibbles(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert to vectore of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
+
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
+
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        int sumi = 0;
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const int i0 = (int8_t) (v0 & 0xf) - 8;
+            const int i1 = (int8_t) (v0 >> 4) - 8;
+
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
+
+            sumi += i0*i2 + i1*i3;
+        }
+        sumf += d0*d1*sumi;
+    }
+#endif
+
+    *s = sumf;
+}
+
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+            }
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#endif
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        s[i] = sumf[i];
+    }
+}
+
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
     }
 
     // leftovers
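
In the AVX2 path of ggml_vec_dot_q4_0_q8_0 above, _mm256_maddubs_epi16 needs an unsigned left operand, so the code multiplies |x| by a sign-adjusted y instead of x by y. A scalar check of that identity (standalone sketch; the value ranges mirror Q4_0 nibbles after the -8 offset and Q8_0 quants):

```c
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    for (int x = -8; x <= 7; x++) {         // Q4_0 values after the -8 offset
        for (int y = -128; y <= 127; y++) { // Q8_0 values
            const int ax = abs(x);                        // _mm256_sign_epi8(bx, bx)
            const int sy = x > 0 ? y : (x < 0 ? -y : 0);  // _mm256_sign_epi8(by, bx)
            if (ax*sy != x*y) { printf("mismatch at %d,%d\n", x, y); return 1; }
        }
    }
    printf("ax*sy == x*y for all pairs\n");
    return 0;
}
```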
@@ -2652,24 +3201,26 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = 1,
     [GGML_TYPE_F16]  = 1,
-    [GGML_TYPE_Q4_0] = QK,
-    [GGML_TYPE_Q4_1] = QK,
+    [GGML_TYPE_Q4_0] = QK4_0,
+    [GGML_TYPE_Q4_1] = QK4_1,
+    [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
     [GGML_TYPE_F16]  = sizeof(ggml_fp16_t),
     [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
     [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+    [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -2677,11 +3228,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F16]  = "f16",
     [GGML_TYPE_Q4_0] = "q4_0",
     [GGML_TYPE_Q4_1] = "q4_1",
+    [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -3354,14 +3906,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
     char * const data = tensor->data;
 
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
|
|
3397
3941
|
ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
3398
3942
|
}
|
3399
3943
|
} break;
|
3400
|
-
|
3944
|
+
default:
|
3401
3945
|
{
|
3402
3946
|
GGML_ASSERT(false);
|
3403
3947
|
} break;
|
@@ -3414,14 +3958,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3414
3958
|
char * const data = tensor->data;
|
3415
3959
|
|
3416
3960
|
switch (tensor->type) {
|
3417
|
-
case GGML_TYPE_Q4_0:
|
3418
|
-
{
|
3419
|
-
GGML_ASSERT(false);
|
3420
|
-
} break;
|
3421
|
-
case GGML_TYPE_Q4_1:
|
3422
|
-
{
|
3423
|
-
GGML_ASSERT(false);
|
3424
|
-
} break;
|
3425
3961
|
case GGML_TYPE_I8:
|
3426
3962
|
{
|
3427
3963
|
assert(tensor->nb[0] == sizeof(int8_t));
|
@@ -3457,7 +3993,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3457
3993
|
ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
|
3458
3994
|
}
|
3459
3995
|
} break;
|
3460
|
-
|
3996
|
+
default:
|
3461
3997
|
{
|
3462
3998
|
GGML_ASSERT(false);
|
3463
3999
|
} break;
|
@@ -3468,14 +4004,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
|
|
3468
4004
|
|
3469
4005
|
int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
3470
4006
|
switch (tensor->type) {
|
3471
|
-
case GGML_TYPE_Q4_0:
|
3472
|
-
{
|
3473
|
-
GGML_ASSERT(false);
|
3474
|
-
} break;
|
3475
|
-
case GGML_TYPE_Q4_1:
|
3476
|
-
{
|
3477
|
-
GGML_ASSERT(false);
|
3478
|
-
} break;
|
3479
4007
|
case GGML_TYPE_I8:
|
3480
4008
|
{
|
3481
4009
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3501,7 +4029,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3501
4029
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3502
4030
|
return ((float *)(tensor->data))[i];
|
3503
4031
|
} break;
|
3504
|
-
|
4032
|
+
default:
|
3505
4033
|
{
|
3506
4034
|
GGML_ASSERT(false);
|
3507
4035
|
} break;
|
@@ -3512,14 +4040,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3512
4040
|
|
3513
4041
|
void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
3514
4042
|
switch (tensor->type) {
|
3515
|
-
case GGML_TYPE_Q4_0:
|
3516
|
-
{
|
3517
|
-
GGML_ASSERT(false);
|
3518
|
-
} break;
|
3519
|
-
case GGML_TYPE_Q4_1:
|
3520
|
-
{
|
3521
|
-
GGML_ASSERT(false);
|
3522
|
-
} break;
|
3523
4043
|
case GGML_TYPE_I8:
|
3524
4044
|
{
|
3525
4045
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3545,7 +4065,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
|
3545
4065
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3546
4066
|
((float *)(tensor->data))[i] = value;
|
3547
4067
|
} break;
|
3548
|
-
|
4068
|
+
default:
|
3549
4069
|
{
|
3550
4070
|
GGML_ASSERT(false);
|
3551
4071
|
} break;
|
@@ -3554,14 +4074,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
|
|
3554
4074
|
|
3555
4075
|
float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
3556
4076
|
switch (tensor->type) {
|
3557
|
-
case GGML_TYPE_Q4_0:
|
3558
|
-
{
|
3559
|
-
GGML_ASSERT(false);
|
3560
|
-
} break;
|
3561
|
-
case GGML_TYPE_Q4_1:
|
3562
|
-
{
|
3563
|
-
GGML_ASSERT(false);
|
3564
|
-
} break;
|
3565
4077
|
case GGML_TYPE_I8:
|
3566
4078
|
{
|
3567
4079
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3587,7 +4099,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3587
4099
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3588
4100
|
return ((float *)(tensor->data))[i];
|
3589
4101
|
} break;
|
3590
|
-
|
4102
|
+
default:
|
3591
4103
|
{
|
3592
4104
|
GGML_ASSERT(false);
|
3593
4105
|
} break;
|
@@ -3598,14 +4110,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
|
|
3598
4110
|
|
3599
4111
|
void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
|
3600
4112
|
switch (tensor->type) {
|
3601
|
-
case GGML_TYPE_Q4_0:
|
3602
|
-
{
|
3603
|
-
GGML_ASSERT(false);
|
3604
|
-
} break;
|
3605
|
-
case GGML_TYPE_Q4_1:
|
3606
|
-
{
|
3607
|
-
GGML_ASSERT(false);
|
3608
|
-
} break;
|
3609
4113
|
case GGML_TYPE_I8:
|
3610
4114
|
{
|
3611
4115
|
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
|
@@ -3631,7 +4135,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
|
|
3631
4135
|
GGML_ASSERT(tensor->nb[0] == sizeof(float));
|
3632
4136
|
((float *)(tensor->data))[i] = value;
|
3633
4137
|
} break;
|
3634
|
-
|
4138
|
+
default:
|
3635
4139
|
{
|
3636
4140
|
GGML_ASSERT(false);
|
3637
4141
|
} break;
|
@@ -5112,6 +5616,26 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         }
+    } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
+        quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+        size_t id = 0;
+        uint8_t * dst_ptr = (uint8_t *) dst->data;
+        size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+        float * src0_f32 = (float *) params->wdata;
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                    // convert to f32 and quantize
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                    }
+                    quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                    id += dst_row_size;
+                }
+            }
+        }
     } else {
         GGML_ASSERT(false); // TODO: implement
     }
@@ -5304,6 +5828,21 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         }
+    } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
+        quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+        size_t id = 0;
+        uint8_t * dst_ptr = (uint8_t *) dst->data;
+        size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+
+        for (int i03 = 0; i03 < ne03; i03++) {
+            for (int i02 = 0; i02 < ne02; i02++) {
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                    quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                    id += dst_row_size;
+                }
+            }
+        }
     } else {
        GGML_ASSERT(false); // TODO: implement
     }
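Both new dup branches walk src0 row by row and append quantized rows of dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]) bytes each; the f16 variant first widens the row into the f32 scratch buffer in params->wdata. The core of the f32 path, as a standalone sketch (the names dup_rows_q and quantize_row are ours; contiguous rows assumed):

    #include <stdint.h>
    #include <stddef.h>

    typedef void (*quantize_row_q_t)(const float * x, void * y, int k);

    // Quantize an f32 matrix row by row, as the new dup branch above does.
    static void dup_rows_q(const float * src, uint8_t * dst, int64_t nrows,
                           int64_t ne00, size_t dst_row_size, quantize_row_q_t quantize_row) {
        size_t id = 0;
        for (int64_t r = 0; r < nrows; ++r) {
            quantize_row(src + r*ne00, dst + id, (int)ne00);
            id += dst_row_size; // one quantized row appended per source row
        }
    }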
@@ -5426,12 +5965,7 @@ static void ggml_compute_forward_dup(
             {
                 ggml_compute_forward_dup_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5463,37 +5997,243 @@ static void ggml_compute_forward_add_f32(
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];

-    const size_t nb0 = dst->nb[0];
-    const size_t nb1 = dst->nb[1];
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+#ifdef GGML_USE_ACCELERATE
+            vDSP_vadd(
+                    (float *) ((char *) src0->data + j*nb01), 1,
+                    (float *) ((char *) src1->data + j*nb11), 1,
+                    (float *) ((char *) dst->data + j*nb1), 1, nc);
+#else
+            ggml_vec_add_f32(nc,
+                    (float *) ((char *) dst->data + j*nb1),
+                    (float *) ((char *) src0->data + j*nb01),
+                    (float *) ((char *) src1->data + j*nb11));
+#endif
+        }
+    } else {
+        // src1 is not contiguous
+        for (int j = ith; j < n; j += nth) {
+            float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
+            float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+
+                dst_ptr[i] = src0_ptr[i] + *src1_ptr;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(float)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int j = ith; j < n; j += nth) {
+            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+            for (int i = 0; i < nc; i++) {
+                ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ASSERT(false);
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    //const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    //const int64_t ne0 = dst->ne[0];
+    //const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float*) params->wdata + ne00 * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;

-    GGML_ASSERT( nb0 == sizeof(float));
-    GGML_ASSERT(nb00 == sizeof(float));
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;

-    if (nb10 == sizeof(float)) {
-        for (int j = ith; j < n; j += nth) {
-#ifdef GGML_USE_ACCELERATE
-            vDSP_vadd(
-                    (float *) ((char *) src0->data + j*nb01), 1,
-                    (float *) ((char *) src1->data + j*nb11), 1,
-                    (float *) ((char *) dst->data + j*nb1), 1, nc);
-#else
-            ggml_vec_add_f32(nc,
-                    (float *) ((char *) dst->data + j*nb1),
-                    (float *) ((char *) src0->data + j*nb01),
-                    (float *) ((char *) src1->data + j*nb11));
-#endif
-        }
-    } else {
-        // src1 is not contiguous
-        for (int j = ith; j < n; j += nth) {
-            float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
-            float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
-            for (int i = 0; i < nc; i++) {
-                float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));

-                dst_ptr[i] = src0_ptr[i] + *src1_ptr;
-            }
-        }
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne00);
     }
 }

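ggml_compute_forward_add_q_f32 never adds in the quantized domain: each Q4 row of src0 is expanded into the per-thread f32 scratch slice (wdata + ne00*ith), the f32 row of src1 is accumulated into it, and the result is re-quantized into dst. The per-row pattern in isolation (the wrapper name is ours; the typedefs match the ones this file uses):

    typedef void (*dequantize_row_q_t)(const void * x, float * y, int k);
    typedef void (*quantize_row_q_t)(const float * x, void * y, int k);

    // dst_row = quantize(dequantize(src0_row) + src1_row), one row at a time.
    static void add_row_q_f32(const void * src0_row, const float * src1_row,
                              void * dst_row, float * wdata, int ne00,
                              dequantize_row_q_t dequantize_row_q,
                              quantize_row_q_t quantize_row_q) {
        dequantize_row_q(src0_row, wdata, ne00); // Q4 -> f32 scratch
        for (int i = 0; i < ne00; ++i) {
            wdata[i] += src1_row[i];             // accumulate in f32
        }
        quantize_row_q(wdata, dst_row, ne00);    // f32 -> Q4 again
    }

Note that the round trip re-quantizes the sum, so repeated in-place adds accumulate quantization error; the kernel is only enabled for Q4_0/Q4_1 destinations.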
@@ -5507,13 +6247,24 @@ static void ggml_compute_forward_add(
         {
             ggml_compute_forward_add_f32(params, src0, src1, dst);
         } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
+            } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5559,13 +6310,7 @@ static void ggml_compute_forward_sub(
         {
             ggml_compute_forward_sub_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5611,13 +6356,7 @@ static void ggml_compute_forward_mul(
         {
             ggml_compute_forward_mul_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5663,13 +6402,7 @@ static void ggml_compute_forward_div(
         {
             ggml_compute_forward_div_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5711,13 +6444,7 @@ static void ggml_compute_forward_sqr(
         {
             ggml_compute_forward_sqr_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5759,13 +6486,7 @@ static void ggml_compute_forward_sqrt(
         {
             ggml_compute_forward_sqrt_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5817,13 +6538,7 @@ static void ggml_compute_forward_sum(
         {
             ggml_compute_forward_sum_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5894,13 +6609,7 @@ static void ggml_compute_forward_mean(
         {
             ggml_compute_forward_mean_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5958,13 +6667,7 @@ static void ggml_compute_forward_repeat(
         {
             ggml_compute_forward_repeat_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6006,13 +6709,7 @@ static void ggml_compute_forward_abs(
         {
             ggml_compute_forward_abs_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6054,13 +6751,7 @@ static void ggml_compute_forward_sgn(
         {
             ggml_compute_forward_sgn_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6102,13 +6793,7 @@ static void ggml_compute_forward_neg(
         {
             ggml_compute_forward_neg_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6150,13 +6835,7 @@ static void ggml_compute_forward_step(
         {
             ggml_compute_forward_step_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6198,13 +6877,7 @@ static void ggml_compute_forward_relu(
         {
             ggml_compute_forward_relu_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6263,13 +6936,7 @@ static void ggml_compute_forward_gelu(
         {
             ggml_compute_forward_gelu_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6330,13 +6997,7 @@ static void ggml_compute_forward_silu(
         {
             ggml_compute_forward_silu_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6416,13 +7077,7 @@ static void ggml_compute_forward_norm(
         {
             ggml_compute_forward_norm_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6496,13 +7151,7 @@ static void ggml_compute_forward_rms_norm(
         {
             ggml_compute_forward_rms_norm_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6894,27 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }

-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q = dequantize_row_q4_0,
-        .quantize_row_q = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q = dequantize_row_q4_1,
-        .quantize_row_q = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
     const struct ggml_compute_params * params,
     const struct ggml_tensor * src0,
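The quantize_fns table is not gone; it moves earlier in the file so the new ggml_compute_forward_add_q_f32 can reach it, and it gains the quantize_row_q_dot member referenced in the next hunk. Its shape, sketched from the fields visible in this diff (the member order is our assumption):

    typedef struct {
        dequantize_row_q_t dequantize_row_q;
        quantize_row_q_t   quantize_row_q;
        quantize_row_q_t   quantize_row_q_reference;
        quantize_row_q_t   quantize_row_q_dot; // new: quantizes src1 rows for the dot product
        vec_dot_q_t        vec_dot_q;
    } quantize_fns_t;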
@@ -6962,8 +7590,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(ne3 == ne13);

     const enum ggml_type type = src0->type;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
-    vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
+    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
+    vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;

     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7032,12 +7660,12 @@ static void ggml_compute_forward_mul_mat_q_f32(

     if (params->type == GGML_TASK_INIT) {
         char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+        const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];

         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
                     wdata += row_size;
                 }
             }
@@ -7063,7 +7691,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ir1 = MIN(ir0 + dr, nr);

     void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+    const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];

     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
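src1 rows are now always quantized to Q8_0 for the dot product, whatever quantized type src0 has, so the scratch row size is fixed by the Q8_0 layout rather than by src0's type. The arithmetic, assuming the usual Q8_0 block of one float delta plus QK8_0 int8 quants:

    #include <stdint.h>
    #include <stddef.h>

    #define QK8_0 32
    typedef struct { float d; int8_t qs[QK8_0]; } block_q8_0; // assumed layout

    // Q8_0 scratch bytes for one row of ne10 floats (ne10 % QK8_0 == 0):
    // (ne10/32) blocks of 36 bytes each.
    static size_t q8_0_row_size(int64_t ne10) {
        return (size_t)(ne10/QK8_0)*sizeof(block_q8_0);
    }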
@@ -7111,6 +7739,7 @@ static void ggml_compute_forward_mul_mat(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -7122,10 +7751,7 @@ static void ggml_compute_forward_mul_mat(
         {
             ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7207,13 +7833,7 @@ static void ggml_compute_forward_scale(
         {
             ggml_compute_forward_scale_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7374,6 +7994,7 @@ static void ggml_compute_forward_get_rows(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q8_0:
            {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -7385,10 +8006,7 @@ static void ggml_compute_forward_get_rows(
         {
             ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7461,13 +8079,7 @@ static void ggml_compute_forward_diag_mask_inf(
         {
             ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7555,13 +8167,7 @@ static void ggml_compute_forward_soft_max(
         {
             ggml_compute_forward_soft_max_f32(params, src0, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7713,11 +8319,11 @@ static void ggml_compute_forward_rope_f16(
                 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                 ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

-                const float x0 =
-                const float x1 =
+                const float x0 = GGML_FP16_TO_FP32(src[0]);
+                const float x1 = GGML_FP16_TO_FP32(src[1]);

-                dst_data[0] =
-                dst_data[1] =
+                dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
             }
         }
     }
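The f16 RoPE kernel now routes each pair through f32 before rotating: (x0, x1) becomes (x0*cos θ − x1*sin θ, x0*sin θ + x1*cos θ), with GGML_FP16_TO_FP32/GGML_FP32_TO_FP16 doing the conversions at the boundaries. The rotation itself, as a plain-float sketch (the helper name is ours):

    #include <math.h>

    // Rotate one (x0, x1) pair by theta, as the rope kernels above do.
    static void rope_rotate_pair(float theta, float x0, float x1, float * y0, float * y1) {
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);
        *y0 = x0*cos_theta - x1*sin_theta;
        *y1 = x0*sin_theta + x1*cos_theta;
    }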
@@ -7738,12 +8344,7 @@ static void ggml_compute_forward_rope(
         {
             ggml_compute_forward_rope_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8006,12 +8607,7 @@ static void ggml_compute_forward_conv_1d_1s(
         {
             ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8274,12 +8870,7 @@ static void ggml_compute_forward_conv_1d_2s(
         {
             ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8759,12 +9350,7 @@ static void ggml_compute_forward_flash_attn(
         {
             ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8970,12 +9556,7 @@ static void ggml_compute_forward_flash_ff(
         {
             GGML_ASSERT(false); // TODO
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -9019,13 +9600,7 @@ static void ggml_compute_forward_map_unary(
         {
             ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -9074,13 +9649,7 @@ static void ggml_compute_forward_map_binary(
         {
             ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
         } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -9830,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         struct ggml_tensor * node = cgraph->nodes[i];

         switch (node->op) {
+            case GGML_OP_CPY:
             case GGML_OP_DUP:
                 {
                     node->n_tasks = 1;
+
+                    size_t cur = 0;
+                    if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
+                    }
+
+                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_ADD:
                 {
                     node->n_tasks = n_threads;
+
+                    size_t cur = 0;
+
+                    if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
+                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
+                    }
+
+                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_SUB:
             case GGML_OP_MUL:
@@ -9905,7 +10490,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 } else
 #endif
                 {
-                    cur = GGML_TYPE_SIZE[
+                    cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
                 }
             } else {
                 GGML_ASSERT(false);
@@ -9917,7 +10502,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             {
                 node->n_tasks = n_threads;
             } break;
-            case GGML_OP_CPY:
             case GGML_OP_CONT:
             case GGML_OP_RESHAPE:
             case GGML_OP_VIEW:
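The planner now reserves scratch space up front: one f32 row for DUP/CPY into a quantized destination (those run single-task), and one f32 row per thread for ADD with a quantized src0, since every thread dequantizes into its own slice of wdata. Stated as code (the helpers are illustrative; GGML_TYPE_SIZE[GGML_TYPE_F32] is just sizeof(float)):

    #include <stdint.h>
    #include <stddef.h>

    // Scratch sizing mirrored from the hunks above.
    static size_t dup_q_work_size(int64_t ne0)                { return sizeof(float)*(size_t)ne0; }
    static size_t add_q_work_size(int64_t ne0, int n_threads) { return sizeof(float)*(size_t)ne0*(size_t)n_threads; }

The graph keeps the maximum over all nodes (work_size = MAX(work_size, cur)), so a single shared buffer serves every op.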
@@ -11080,16 +11664,16 @@ enum ggml_opt_result ggml_opt(
 ////////////////////////////////////////////////////////////////////////////////

 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;

     for (int j = 0; j < n; j += k) {
-        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
+        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;

         quantize_row_q4_0_reference(src + j, y, k);

         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK; l += 2) {
+            for (int l = 0; l < QK4_0; l += 2) {
                 const uint8_t vi0 = y[i].qs[l/2] & 0xF;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;

@@ -11099,20 +11683,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
         }
     }

-    return (n/QK*sizeof(block_q4_0));
+    return (n/QK4_0*sizeof(block_q4_0));
 }

 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;

     for (int j = 0; j < n; j += k) {
-        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
+        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;

         quantize_row_q4_1_reference(src + j, y, k);

         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK; l += 2) {
+            for (int l = 0; l < QK4_1; l += 2) {
                 const uint8_t vi0 = y[i].qs[l/2] & 0xF;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;

@@ -11122,7 +11706,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
         }
     }

-    return (n/QK*sizeof(block_q4_1));
+    return (n/QK4_1*sizeof(block_q4_1));
 }

 ////////////////////////////////////////////////////////////////////////////////
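Both quantize entry points also accumulate a histogram over the 16 possible 4-bit codes via the hist argument; each byte of qs packs two codes, low nibble first. The unpack step the loops above perform, in isolation (the helper name is ours):

    #include <stdint.h>

    // Split one packed byte into its two 4-bit quant codes (low nibble first),
    // exactly as the vi0/vi1 lines above do before the counts are accumulated.
    static inline void unpack_q4_pair(uint8_t byte, uint8_t * vi0, uint8_t * vi1) {
        *vi0 = byte & 0xF; // first element of the pair
        *vi1 = byte >> 4;  // second element of the pair
    }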
@@ -11151,6 +11735,22 @@ int ggml_cpu_has_avx512(void) {
 #endif
 }

+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
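Like the existing ggml_cpu_has_* helpers, the two new predicates report compile-time capability: they answer whether the binary was built with AVX512-VBMI/VNNI enabled, not what the host CPU supports at runtime. Typical use (hypothetical caller):

    #include <stdio.h>

    extern int ggml_cpu_has_avx512_vbmi(void);
    extern int ggml_cpu_has_avx512_vnni(void);

    void print_avx512_features(void) {
        printf("AVX512_VBMI = %d\n", ggml_cpu_has_avx512_vbmi());
        printf("AVX512_VNNI = %d\n", ggml_cpu_has_avx512_vnni());
    }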