llama_cpp 0.0.5 → 0.0.6
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/ext/llama_cpp/extconf.rb +15 -1
- data/ext/llama_cpp/llama_cpp.cpp +46 -0
- data/ext/llama_cpp/src/ggml-cuda.h +12 -0
- data/ext/llama_cpp/src/ggml.c +1343 -800
- data/ext/llama_cpp/src/ggml.h +12 -2
- data/ext/llama_cpp/src/llama.cpp +60 -16
- data/ext/llama_cpp/src/llama.h +5 -1
- data/ext/llama_cpp/src/llama_util.h +0 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +10 -1
- metadata +3 -2
data/ext/llama_cpp/src/ggml.c (CHANGED)
```diff
@@ -19,6 +19,7 @@
 #include <inttypes.h>
 #include <stdio.h>
 #include <float.h>
+#include <limits.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
@@ -142,10 +143,49 @@ inline static void* ggml_aligned_malloc(size_t size) {
     } \
 } while (0)
 
-#
+#if defined(GGML_USE_ACCELERATE)
 #include <Accelerate/Accelerate.h>
-#elif GGML_USE_OPENBLAS
+#elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
+#elif defined(GGML_USE_CUBLAS)
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#include "ggml-cuda.h"
+
+#define CUDA_CHECK(err) \
+    do { \
+        cudaError_t err_ = (err); \
+        if (err_ != cudaSuccess) { \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+                cudaGetErrorString(err_)); \
+            exit(1); \
+        } \
+    } while (0)
+
+#define CUBLAS_CHECK(err) \
+    do { \
+        cublasStatus_t err_ = (err); \
+        if (err_ != CUBLAS_STATUS_SUCCESS) { \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+            exit(1); \
+        } \
+    } while (0)
+
+static cublasHandle_t cublasH = NULL;
+static cudaStream_t cudaStream = NULL;
+static void init_cublas(void) {
+    if (cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+
+        CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+    }
+}
 #endif
 
 #undef MIN
```
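The new `GGML_USE_CUBLAS` branch wraps every CUDA and cuBLAS call in `CUDA_CHECK`/`CUBLAS_CHECK`, which abort with file and line information on the first failure, and lazily creates a shared cuBLAS handle bound to a non-blocking stream. As a rough illustration of the same check-every-call pattern (not part of the diff; assumes the CUDA toolkit is available and uses an arbitrary buffer size):

```c
// Minimal sketch of the error-checking pattern used above (illustrative only).
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

#define CUDA_CHECK(err) \
    do { \
        cudaError_t err_ = (err); \
        if (err_ != cudaSuccess) { \
            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                cudaGetErrorString(err_)); \
            exit(1); \
        } \
    } while (0)

int main(void) {
    float * d_buf = NULL;
    // Every runtime call is wrapped so a failure aborts with file/line info.
    CUDA_CHECK(cudaMalloc((void **) &d_buf, 1024 * sizeof(float)));
    CUDA_CHECK(cudaMemset(d_buf, 0, 1024 * sizeof(float)));
    CUDA_CHECK(cudaFree(d_buf));
    return 0;
}
```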
```diff
@@ -427,12 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-
-//
+#if __AVX__ || __AVX2__ || __AVX512F__
+// Unpack 16 4-bit fields into 16 bytes
+// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8( 0xF );
+    __m128i high = _mm_andnot_si128( lowMask, bytes );
+    __m128i low = _mm_and_si128( lowMask, bytes );
+    high = _mm_slli_epi16( high, 4 );
+    bytes = _mm_or_si128( low, high );
+    return bytes;
+}
+
 #if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     // Load 16 bytes from memory
     __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -463,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
 }
-#
-static inline __m128i bytesFromNibbles( const uint8_t* rsi )
-{
-    // Load 8 bytes from memory
-    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
-
-    // Expand bytes into uint16_t values
-    __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
-    // Unpack values into individual bytes
-    const __m128i lowMask = _mm_set1_epi8( 0xF );
-    __m128i high = _mm_andnot_si128( lowMask, bytes );
-    __m128i low = _mm_and_si128( lowMask, bytes );
-    high = _mm_slli_epi16( high, 4 );
-    bytes = _mm_or_si128( low, high );
-    return bytes;
-}
-
+#else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -497,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2);
 }
 #endif
+#endif // __AVX__ || __AVX2__ || __AVX512F__
 
 #if __ARM_NEON
 
```
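The regularized helpers `bytes_from_nibbles_16`/`bytes_from_nibbles_32` widen packed 4-bit quants into one byte per value before the SIMD math. A scalar sketch of the same unpacking, for reference only (the function name and test values below are made up for illustration):

```c
// Scalar reference for what bytes_from_nibbles_16 computes:
// 8 packed bytes -> 16 separate values, each in [0, 15].
#include <stdint.h>
#include <stdio.h>

static void bytes_from_nibbles_16_scalar(const uint8_t * src, uint8_t * dst) {
    for (int j = 0; j < 8; ++j) {
        dst[2*j + 0] = src[j] & 0x0F; // low nibble  -> even output byte
        dst[2*j + 1] = src[j] >> 4;   // high nibble -> odd output byte
    }
}

int main(void) {
    const uint8_t packed[8] = { 0x21, 0x43, 0x65, 0x87, 0xA9, 0xCB, 0xED, 0x0F };
    uint8_t unpacked[16];
    bytes_from_nibbles_16_scalar(packed, unpacked);
    for (int j = 0; j < 16; ++j) {
        printf("%d ", unpacked[j]); // prints: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 0
    }
    printf("\n");
    return 0;
}
```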
```diff
@@ -514,6 +556,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+    return
+        (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
+        (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
+        (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
+        (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
+        (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
+        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -583,7 +637,22 @@ typedef struct {
     float m;               // min
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float)
+static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK4_2 16
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qs[QK4_2 / 2]; // nibbles / quants
+} block_q4_2;
+static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+#define QK4_3 16
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qs[QK4_3 / 2]; // nibbles / quants
+} block_q4_3;
+static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
 
 #define QK8_0 32
 typedef struct {
```
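`block_q4_2` and `block_q4_3` keep 16 quantized weights per block: Q4_2 stores one fp16 scale plus 8 packed bytes, while Q4_3 adds an fp16 minimum for an affine mapping. A quick size check, mirroring the `static_assert`s above and treating `ggml_fp16_t` as a plain 16-bit type for the arithmetic (illustrative only, not part of the diff):

```c
// Back-of-the-envelope storage cost of the new block formats.
#include <stdint.h>
#include <stdio.h>

typedef uint16_t ggml_fp16_t; // stand-in for ggml's half type

#define QK4_2 16
#define QK4_3 16

int main(void) {
    const size_t q4_2_bytes = sizeof(ggml_fp16_t) + QK4_2 / 2;     // 2 + 8 = 10 bytes
    const size_t q4_3_bytes = 2 * sizeof(ggml_fp16_t) + QK4_3 / 2; // 4 + 8 = 12 bytes

    // 16 weights per block -> bits per weight
    printf("q4_2: %zu bytes/block, %.1f bits/weight\n", q4_2_bytes, 8.0 * q4_2_bytes / QK4_2); // 5.0
    printf("q4_3: %zu bytes/block, %.1f bits/weight\n", q4_3_bytes, 8.0 * q4_3_bytes / QK4_3); // 6.0
    return 0;
}
```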
```diff
@@ -1045,6 +1114,173 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 #endif
 }
 
+// reference implementation for deterministic creation of model files
+static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
+    assert(k % QK4_2 == 0);
+
+    const int nb = k / QK4_2;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK4_2; l++) {
+            const float v = x[i*QK4_2 + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 3) - 1);
+
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const float v0 = x[i*QK4_2 + l + 0]*id;
+            const float v1 = x[i*QK4_2 + l + 1]*id;
+
+            const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
+            const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+    }
+}
+
+static inline int nearest_int(float fval) {
+    assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
+        const float * restrict candidates, int8_t * restrict L) {
+    assert (nmin >= INT8_MIN);
+    assert (nmax <= INT8_MAX);
+    float amax = 0;
+    for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
+    if (!amax) { // all zero
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return 1.f;
+    }
+    float best = 0, bestScale = 0;
+    for (int si=0; si<nCandidates; ++si) {
+        float iscale = candidates[si]/amax;
+        float sumlxP = 0; int suml2P = 0;
+        float sumlxM = 0; int suml2M = 0;
+        for (int i=0; i<n; ++i) {
+            int l = nearest_int(iscale*X[i]);
+            int lp = MAX(nmin, MIN(nmax, +l));
+            int lm = MAX(nmin, MIN(nmax, -l));
+            sumlxP += X[i]*lp; suml2P += lp*lp;
+            sumlxM += X[i]*lm; suml2M += lm*lm;
+        }
+        float sumlxP2 = sumlxP*sumlxP;
+        float sumlxM2 = sumlxM*sumlxM;
+        if (sumlxP2*suml2M > sumlxM2*suml2P) {
+            if (sumlxP2 > best*suml2P) {
+                best = sumlxP2/suml2P; bestScale = iscale;
+            }
+        } else {
+            if (sumlxM2 > best*suml2M) {
+                best = sumlxM2/suml2M; bestScale = -iscale;
+            }
+        }
+    }
+    float sumlx = 0; int suml2 = 0;
+    for (int i=0; i<n; ++i) {
+        int l = nearest_int(bestScale*X[i]);
+        l = MAX(nmin, MIN(nmax, l));
+        sumlx += X[i]*l; suml2 += l*l;
+        L[i] = l;
+    }
+    float scale = sumlx/suml2;
+    return scale;
+}
+
+static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
+#define CANDIDATE_COUNT 8
+    static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
+    assert(k % QK4_2 == 0);
+
+    int8_t L[QK4_2];
+
+    const int nb = k / QK4_2;
+
+    for (int i = 0; i < nb; i++) {
+        float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
+        y[i].d = GGML_FP32_TO_FP16(scale);
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
+            const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+
+        x += QK4_2;
+    }
+}
+
+static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_2 == 0);
+
+    block_q4_2 * restrict y = vy;
+
+    //quantize_row_q4_2_reference(x, y, k);
+    // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
+    quantize_row_q4_2_rmse(x, y, k);
+}
+
+static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int l = 0; l < QK4_3; l++) {
+            const float v = x[i*QK4_3 + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        for (int l = 0; l < QK4_3; l += 2) {
+            const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
+
+            const uint8_t vi0 = (int) (v0 + 0.5f);
+            const uint8_t vi1 = (int) (v1 + 0.5f);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+    }
+}
+
+static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_3 == 0);
+
+    block_q4_3 * restrict y = vy;
+
+    quantize_row_q4_3_reference(x, y, k);
+}
+
 // reference implementation for deterministic creation of model files
 static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
     assert(k % QK8_0 == 0);
```
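`quantize_row_q4_2_reference` scales each block so the largest magnitude maps to 7, stores the nibbles shifted by 8, and the later dequantization undoes both steps; the `_rmse` variant only searches a handful of candidate scales for a better least-squares fit to the same format. A self-contained scalar round-trip of the reference scheme, ignoring the fp16 rounding of the scale (toy input values, illustrative only, not part of the diff):

```c
// Worked scalar example of the q4_2 reference scheme: quantize one
// 16-value block and dequantize it again.
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK4_2 16

int main(void) {
    float x[QK4_2], out[QK4_2];
    uint8_t qs[QK4_2 / 2];

    for (int l = 0; l < QK4_2; ++l) x[l] = 0.1f * (l - 7); // toy input in [-0.7, 0.8]

    float amax = 0.0f;
    for (int l = 0; l < QK4_2; ++l) amax = fmaxf(amax, fabsf(x[l]));

    const float d  = amax / ((1 << 3) - 1); // largest |x| maps to 7 (stored shifted by 8)
    const float id = d ? 1.0f / d : 0.0f;

    for (int l = 0; l < QK4_2; l += 2) {
        const uint8_t vi0 = (uint8_t)(x[l + 0] * id + 8.5f);
        const uint8_t vi1 = (uint8_t)(x[l + 1] * id + 8.5f);
        qs[l / 2] = vi0 | (vi1 << 4);
    }

    // dequantize: nibble - 8, times the shared scale
    for (int l = 0; l < QK4_2; l += 2) {
        out[l + 0] = ((qs[l / 2] & 0x0F) - 8) * d;
        out[l + 1] = ((qs[l / 2] >> 4)   - 8) * d;
    }

    for (int l = 0; l < QK4_2; ++l) {
        printf("%+.3f -> %+.3f\n", x[l], out[l]);
    }
    return 0;
}
```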
```diff
@@ -1064,7 +1300,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
         y[i].d = d;
 
         for (int l = 0; l < QK8_0; ++l) {
-            const float
+            const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
         }
     }
@@ -1211,7 +1447,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
     for (int l = 0; l < QK4_0; l += 32) {
         // Load 32x4-bit integers into 32x8-bit integers
-        __m256i vx8 =
+        __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
         // Subtract 8 from the integers
         vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1329,7 +1565,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
     for (int l = 0; l < QK4_1; l += 32) {
         // Load 32x4-bit integers into 32x8-bit integers
-        __m256i vx8 =
+        __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
         // Convert to 16-bit int
         const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1420,8 +1656,69 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 #endif
 }
 
-static void
+static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_2 == 0);
+    const int nb = k / QK4_2;
+
+    const block_q4_2 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = (vi0 - 8)*d;
+            const float v1 = (vi1 - 8)*d;
+
+            y[i*QK4_2 + l + 0] = v0;
+            y[i*QK4_2 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_2 + l + 0]));
+            assert(!isnan(y[i*QK4_2 + l + 1]));
+        }
+    }
+}
+
+static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    const block_q4_3 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_3; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = vi0*d + m;
+            const float v1 = vi1*d + m;
+
+            y[i*QK4_3 + l + 0] = v0;
+            y[i*QK4_3 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_3 + l + 0]));
+            assert(!isnan(y[i*QK4_3 + l + 1]));
+        }
+    }
+}
+
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = {
```
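`dequantize_row_q4_3` applies the affine map `v = vi*d + m` per nibble, so nibble 0 recovers the block minimum and nibble 15 the maximum. A tiny numeric check of that mapping (made-up `min`/`max`, fp16 rounding ignored, not part of the diff):

```c
// Small check of the q4_3 affine mapping: d = (max - min)/15, m = min.
#include <stdio.h>

int main(void) {
    const float min = -1.25f, max = 2.75f;
    const float d = (max - min) / ((1 << 4) - 1); // 4.0 / 15
    const float m = min;

    for (int vi = 0; vi <= 15; vi += 5) {
        printf("nibble %2d -> %+.4f\n", vi, vi * d + m); // -1.2500, +0.0833, +1.4167, +2.7500
    }
    return 0;
}
```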
```diff
@@ -1435,10 +1732,30 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .dequantize_row_q = dequantize_row_q4_1,
         .quantize_row_q = quantize_row_q4_1,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot =
-        .vec_dot_q =
+        .quantize_row_q_dot = quantize_row_q8_0,
+        .vec_dot_q = ggml_vec_dot_q4_1_q8_0,
+    },
+    [GGML_TYPE_Q4_2] = {
+        .dequantize_row_q = dequantize_row_q4_2,
+        .quantize_row_q = quantize_row_q4_2,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
+        .quantize_row_q_dot = quantize_row_q8_0,
+        .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
+    },
+    [GGML_TYPE_Q4_3] = {
+        .dequantize_row_q = dequantize_row_q4_3,
+        .quantize_row_q = quantize_row_q4_3,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
+        .quantize_row_q_dot = quantize_row_q8_0,
+        .vec_dot_q = ggml_vec_dot_q4_3_q8_0,
+    },
+    [GGML_TYPE_Q8_0] = {
+        .dequantize_row_q = NULL, // TODO
+        .quantize_row_q = quantize_row_q8_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
+        .quantize_row_q_dot = quantize_row_q8_0,
+        .vec_dot_q = NULL, // TODO
     },
-    // TODO: GGML_TYPE_Q8_0
 };
 
 // For internal test use
```
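The `quantize_fns` table gains entries for Q4_2, Q4_3 and Q8_0, so the per-row quantize/dequantize/dot routines can be selected by indexing with the tensor type. A stripped-down analog of that dispatch pattern, with hypothetical names rather than ggml's real types (illustrative only):

```c
// Self-contained analog of a per-type function-pointer table.
#include <stdio.h>

typedef void (*row_fn_t)(const float * x, void * y, int k);

typedef struct {
    row_fn_t quantize_row;
} fns_t;

enum my_type { MY_TYPE_A = 0, MY_TYPE_B = 1, MY_TYPE_COUNT };

static void quantize_row_a(const float * x, void * y, int k) { (void)x; (void)y; printf("A: %d values\n", k); }
static void quantize_row_b(const float * x, void * y, int k) { (void)x; (void)y; printf("B: %d values\n", k); }

static const fns_t fns[MY_TYPE_COUNT] = {
    [MY_TYPE_A] = { .quantize_row = quantize_row_a },
    [MY_TYPE_B] = { .quantize_row = quantize_row_b },
};

int main(void) {
    float row[16] = { 0 };
    fns[MY_TYPE_B].quantize_row(row, NULL, 16); // dispatch by type tag
    return 0;
}
```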
@@ -2004,191 +2321,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
|
2004
2321
|
*s = sumf;
|
2005
2322
|
}
|
2006
2323
|
|
2007
|
-
#if __AVX512F__ && QK4_0 == 32
|
2008
|
-
static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
|
2009
|
-
// The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
|
2010
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2011
|
-
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
|
2012
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2013
|
-
// | :. =_ () [] <> () Zz Yy|
|
2014
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2015
|
-
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
|
2016
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2017
|
-
// |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
|
2018
|
-
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
2019
|
-
//
|
2020
|
-
// Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
|
2021
|
-
// We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
|
2022
|
-
// Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
|
2023
|
-
// Bytes 40..63 are masked when loading the data, so they are zeroed out.
|
2024
|
-
#ifdef __AVX512VBMI__
|
2025
|
-
const __m512i byte_perm = _mm512_set_epi8(
|
2026
|
-
39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
|
2027
|
-
31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
|
2028
|
-
19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
|
2029
|
-
11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
|
2030
|
-
);
|
2031
|
-
const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
|
2032
|
-
// After applying VPERMB, `permuted` looks like this:
|
2033
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2034
|
-
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
|
2035
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2036
|
-
// |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
|
2037
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2038
|
-
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
|
2039
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2040
|
-
// |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
|
2041
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2042
|
-
#else
|
2043
|
-
const __m512i word_perm = _mm512_set_epi16(
|
2044
|
-
19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
|
2045
|
-
9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
|
2046
|
-
);
|
2047
|
-
const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
|
2048
|
-
// This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
|
2049
|
-
// VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
|
2050
|
-
// Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
|
2051
|
-
#endif
|
2052
|
-
|
2053
|
-
// Shift every odd-numbered 16-bit group to the right by 4 bits.
|
2054
|
-
const __mmask32 shift_mask = 0xaaaaaaaa;
|
2055
|
-
const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
|
2056
|
-
// After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
|
2057
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2058
|
-
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
|
2059
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2060
|
-
// | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
|
2061
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2062
|
-
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
|
2063
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2064
|
-
// | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
|
2065
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2066
|
-
|
2067
|
-
// Now we just need to zero out the higher nibble in each byte, and we're done.
|
2068
|
-
const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
|
2069
|
-
return _mm512_and_si512( low_nibble_mask, shifted );
|
2070
|
-
// The final result looks like this:
|
2071
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2072
|
-
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
|
2073
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2074
|
-
// | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
|
2075
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2076
|
-
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
|
2077
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2078
|
-
// | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
|
2079
|
-
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
|
2080
|
-
}
|
2081
|
-
|
2082
|
-
static inline __m512 dot_q4_0_twoblocks_avx512(
|
2083
|
-
__m512 acc,
|
2084
|
-
const block_q4_0 * restrict x,
|
2085
|
-
const block_q4_0 * restrict y,
|
2086
|
-
int i
|
2087
|
-
) {
|
2088
|
-
// A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
|
2089
|
-
// can potentially be unaddressable, so we make sure to mask them out before the load, even though
|
2090
|
-
// we don't use them at all. This might hurt the performance slightly, since the compiler is forced
|
2091
|
-
// to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
|
2092
|
-
const __mmask8 load_mask = 0x1f;
|
2093
|
-
const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
|
2094
|
-
const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
|
2095
|
-
|
2096
|
-
// We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
|
2097
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2098
|
-
// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
|
2099
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2100
|
-
// blocks_0_float
|
2101
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2102
|
-
// | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
|
2103
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2104
|
-
// blocks_1_float
|
2105
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2106
|
-
// | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
|
2107
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2108
|
-
const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
|
2109
|
-
const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
|
2110
|
-
// We absolutely shouldn't touch the floats marked with `xx`: they contain some
|
2111
|
-
// random data, which might very well underflow. At least on Intel, this leads
|
2112
|
-
// to a huge penalty that can't be ignored (easily 100x or more) unless you
|
2113
|
-
// compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
|
2114
|
-
// (and ggml can't assume that you do)...
|
2115
|
-
const __mmask16 scale_mul_mask = 0x21;
|
2116
|
-
#ifdef __clang__
|
2117
|
-
// ...however, clang decides to optimize the multiplication mask away:
|
2118
|
-
// https://godbolt.org/z/P8PqdsfvW
|
2119
|
-
// gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
|
2120
|
-
__m512i scales;
|
2121
|
-
__asm__(
|
2122
|
-
"vmulps %1, %2, %0%{%3%}"
|
2123
|
-
: "=v" ( scales )
|
2124
|
-
: "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
|
2125
|
-
);
|
2126
|
-
#else
|
2127
|
-
const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
|
2128
|
-
#endif
|
2129
|
-
const __m512i scale_perm = _mm512_set_epi32(
|
2130
|
-
5, 5, 5, 5, 5, 5, 5, 5,
|
2131
|
-
0, 0, 0, 0, 0, 0, 0, 0
|
2132
|
-
);
|
2133
|
-
const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
|
2134
|
-
// After VMULPS and VPERMPS, `permuted_scales` looks like this:
|
2135
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2136
|
-
// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
|
2137
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2138
|
-
// | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
|
2139
|
-
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|
2140
|
-
|
2141
|
-
const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
|
2142
|
-
const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
|
2143
|
-
|
2144
|
-
// Now we want to compute dot products of 4-element byte vectors and store them in
|
2145
|
-
// 32-bit integers. That is (only one 4-element vector is shown for clarity):
|
2146
|
-
// +----+----+----+----+
|
2147
|
-
// ... | 03 | 02 | 01 | 00 |
|
2148
|
-
// +----+----+----+----+
|
2149
|
-
// bytes_0
|
2150
|
-
// +----+----+----+----+
|
2151
|
-
// ... | D | C | B | A |
|
2152
|
-
// +----+----+----+----+
|
2153
|
-
// bytes_1
|
2154
|
-
// +----+----+----+----+
|
2155
|
-
// ... | H | G | F | E |
|
2156
|
-
// +----+----+----+----+
|
2157
|
-
// final_res_int
|
2158
|
-
// +----+----+----+----+
|
2159
|
-
// ... | A*E+B*F+C*G+D*H |
|
2160
|
-
// +----+----+----+----+
|
2161
|
-
const __m512i plus_8 = _mm512_set1_epi8( 8 );
|
2162
|
-
const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
|
2163
|
-
|
2164
|
-
#ifdef __AVX512VNNI__
|
2165
|
-
// We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
|
2166
|
-
// the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
|
2167
|
-
// from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
|
2168
|
-
// we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
|
2169
|
-
// which means we only need 2 instructions.
|
2170
|
-
const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
|
2171
|
-
const __m512i minus_8 = _mm512_set1_epi8( -8 );
|
2172
|
-
const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
|
2173
|
-
const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
|
2174
|
-
#else
|
2175
|
-
// As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
|
2176
|
-
// It has the same catch as VPDPBUSDS: the left operand should be unsigned.
|
2177
|
-
// This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
|
2178
|
-
// ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
|
2179
|
-
const __m512i one = _mm512_set1_epi16( 1 );
|
2180
|
-
const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
|
2181
|
-
const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
|
2182
|
-
const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
|
2183
|
-
const __m512i final_res_int = _mm512_madd_epi16( diff, one );
|
2184
|
-
#endif
|
2185
|
-
|
2186
|
-
// Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
|
2187
|
-
const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
|
2188
|
-
return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
|
2189
|
-
}
|
2190
|
-
#endif
|
2191
|
-
|
2192
2324
|
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
2193
2325
|
ggml_float sumf = 0.0;
|
2194
2326
|
|
@@ -2225,67 +2357,64 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|
2225
2357
|
*s = sumf;
|
2226
2358
|
}
|
2227
2359
|
|
2228
|
-
static void
|
2229
|
-
const int nb = n /
|
2360
|
+
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2361
|
+
const int nb = n / QK8_0;
|
2230
2362
|
|
2231
|
-
assert(n %
|
2363
|
+
assert(n % QK8_0 == 0);
|
2232
2364
|
assert(nb % 2 == 0);
|
2233
2365
|
|
2234
2366
|
const block_q4_0 * restrict x = vx;
|
2235
|
-
const
|
2367
|
+
const block_q8_0 * restrict y = vy;
|
2236
2368
|
|
2237
2369
|
float sumf = 0.0;
|
2238
2370
|
|
2239
2371
|
#if defined(__ARM_NEON)
|
2240
|
-
|
2241
|
-
|
2372
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2373
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2242
2374
|
|
2243
2375
|
for (int i = 0; i < nb; i += 2) {
|
2244
2376
|
const block_q4_0 * restrict x0 = &x[i + 0];
|
2245
|
-
const block_q4_0 * restrict y0 = &y[i + 0];
|
2246
2377
|
const block_q4_0 * restrict x1 = &x[i + 1];
|
2247
|
-
const
|
2378
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2379
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
2248
2380
|
|
2249
|
-
const uint8x16_t m4b
|
2250
|
-
const int8x16_t s8b
|
2381
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2382
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
2251
2383
|
|
2252
2384
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2253
|
-
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
|
2254
2385
|
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2255
|
-
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
|
2256
2386
|
|
2257
2387
|
// 4-bit -> 8-bit
|
2258
|
-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
|
2259
|
-
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
|
2388
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2260
2389
|
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2261
|
-
const int8x16_t
|
2262
|
-
|
2263
|
-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
|
2264
|
-
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
|
2390
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2265
2391
|
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2266
|
-
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
|
2267
2392
|
|
2268
2393
|
// sub 8
|
2269
2394
|
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
2270
|
-
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
|
2271
2395
|
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
2272
|
-
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
|
2273
|
-
|
2274
2396
|
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2275
|
-
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
|
2276
2397
|
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
2277
|
-
|
2398
|
+
|
2399
|
+
// load y
|
2400
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2401
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2402
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2403
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2404
|
+
|
2405
|
+
// interleave
|
2406
|
+
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2407
|
+
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2408
|
+
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2409
|
+
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2278
2410
|
|
2279
2411
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2280
2412
|
// dot product into int32x4_t
|
2281
|
-
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
|
2282
|
-
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
|
2413
|
+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
|
2414
|
+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
|
2283
2415
|
|
2284
|
-
|
2285
|
-
|
2286
|
-
|
2287
|
-
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
|
2288
|
-
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
|
2416
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
2417
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
2289
2418
|
#else
|
2290
2419
|
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
2291
2420
|
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
@@ -2297,116 +2426,51 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2297
2426
|
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
2298
2427
|
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
2299
2428
|
|
2300
|
-
const
|
2301
|
-
const
|
2302
|
-
|
2303
|
-
const
|
2304
|
-
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
|
2429
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2430
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2431
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2432
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
2305
2433
|
|
2306
|
-
|
2307
|
-
|
2308
|
-
|
2309
|
-
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
|
2310
|
-
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
|
2434
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
|
2435
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
|
2311
2436
|
#endif
|
2312
2437
|
}
|
2313
2438
|
|
2314
|
-
sumf =
|
2315
|
-
#elif defined(
|
2439
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2440
|
+
#elif defined(__AVX2__)
|
2316
2441
|
// Initialize accumulator with zeros
|
2317
|
-
|
2318
|
-
__m512 acc1 = _mm512_setzero_ps();
|
2442
|
+
__m256 acc = _mm256_setzero_ps();
|
2319
2443
|
|
2320
|
-
|
2444
|
+
// Main loop
|
2445
|
+
for (int i = 0; i < nb; ++i) {
|
2446
|
+
/* Compute combined scale for the block */
|
2447
|
+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
2321
2448
|
|
2322
|
-
|
2449
|
+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
2323
2450
|
|
2324
|
-
|
2325
|
-
|
2451
|
+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2452
|
+
const __m256i off = _mm256_set1_epi8( 8 );
|
2453
|
+
bx = _mm256_sub_epi8( bx, off );
|
2326
2454
|
|
2327
|
-
|
2328
|
-
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
|
2329
|
-
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
|
2330
|
-
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
|
2331
|
-
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
|
2332
|
-
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
|
2333
|
-
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
|
2334
|
-
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
|
2335
|
-
}
|
2455
|
+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2336
2456
|
|
2337
|
-
|
2338
|
-
|
2339
|
-
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
|
2340
|
-
}
|
2457
|
+
// Get absolute values of x vectors
|
2458
|
+
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
2341
2459
|
|
2342
|
-
|
2343
|
-
|
2344
|
-
#elif defined(__AVX2__)
|
2345
|
-
// Initialize accumulator with zeros
|
2346
|
-
__m256 acc = _mm256_setzero_ps();
|
2460
|
+
// Sign the values of the y vectors
|
2461
|
+
const __m256i sy = _mm256_sign_epi8(by, bx);
|
2347
2462
|
|
2348
|
-
|
2349
|
-
|
2350
|
-
const __m256i offset_8 = _mm256_set1_epi16( 8 );
|
2463
|
+
// Perform multiplication and create 16-bit values
|
2464
|
+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
2351
2465
|
|
2352
|
-
|
2353
|
-
|
2354
|
-
assert(nb % UNROLL_COUNT == 0);
|
2466
|
+
const __m256i ones = _mm256_set1_epi16(1);
|
2467
|
+
__m256i xy_q = _mm256_madd_epi16(ones, dot);
|
2355
2468
|
|
2356
|
-
|
2357
|
-
|
2358
|
-
|
2359
|
-
|
2360
|
-
|
2361
|
-
const __m256 scale = _mm256_mul_ps(
|
2362
|
-
_mm256_broadcast_ss( &x[i+u].d ),
|
2363
|
-
_mm256_broadcast_ss( &y[i+u].d ) );
|
2364
|
-
|
2365
|
-
/* get input from x
|
2366
|
-
Input: 32 Nibbles (16 bytes) at *x[i+u]
|
2367
|
-
Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
|
2368
|
-
|
2369
|
-
/* Load 16 bytes from memory */
|
2370
|
-
const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
|
2371
|
-
/* Expand bytes into uint16_t values */
|
2372
|
-
const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
|
2373
|
-
/* Unpack values into individual bytes */
|
2374
|
-
__m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
|
2375
|
-
const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
|
2376
|
-
__m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
|
2377
|
-
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
|
2378
|
-
x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
|
2379
|
-
x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
|
2380
|
-
|
2381
|
-
/* get input from y
|
2382
|
-
Input: 32 Nibbles (16 bytes) at *y[i+u]
|
2383
|
-
Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
|
2384
|
-
|
2385
|
-
/* Load 16 bytes from memory */
|
2386
|
-
const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
|
2387
|
-
/* Expand bytes into uint16_t values */
|
2388
|
-
const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
|
2389
|
-
/* Unpack values into individual bytes */
|
2390
|
-
const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
|
2391
|
-
__m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
|
2392
|
-
__m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
|
2393
|
-
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
|
2394
|
-
y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
|
2395
|
-
y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
|
2396
|
-
|
2397
|
-
/* Compute products of int16_t integers, add pairwise, store as int32_t */
|
2398
|
-
__m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
|
2399
|
-
__m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
|
2400
|
-
|
2401
|
-
/* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
|
2402
|
-
__m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
|
2403
|
-
|
2404
|
-
/* Convert to vectore of 8 int32_t to 8 floats */
|
2405
|
-
__m256 q = _mm256_cvtepi32_ps( xy_q );
|
2406
|
-
|
2407
|
-
/* Multiply q with scale and accumulate */
|
2408
|
-
acc = _mm256_fmadd_ps( scale, q, acc );
|
2409
|
-
}
|
2469
|
+
/* Convert to vectore of 8 int32_t to 8 floats */
|
2470
|
+
__m256 q = _mm256_cvtepi32_ps( xy_q );
|
2471
|
+
|
2472
|
+
/* Multiply q with scale and accumulate */
|
2473
|
+
acc = _mm256_fmadd_ps( d, q, acc );
|
2410
2474
|
}
|
2411
2475
|
|
2412
2476
|
// Return horizontal sum of the acc vector
|
@@ -2428,13 +2492,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2428
2492
|
__m128i i32[2];
|
2429
2493
|
for (int j = 0; j < 2; ++j) {
|
2430
2494
|
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
|
2431
|
-
__m128i bx =
|
2432
|
-
__m128i by =
|
2495
|
+
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
|
2496
|
+
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
|
2433
2497
|
|
2434
2498
|
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2435
2499
|
const __m128i off = _mm_set1_epi8( 8 );
|
2436
2500
|
bx = _mm_sub_epi8( bx, off );
|
2437
|
-
by = _mm_sub_epi8( by, off );
|
2438
2501
|
|
2439
2502
|
// Get absolute values of x vectors
|
2440
2503
|
const __m128i ax = _mm_sign_epi8(bx, bx);
|
@@ -2462,86 +2525,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2462
2525
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2463
2526
|
|
2464
2527
|
sumf = _mm_cvtss_f32( res );
|
2465
|
-
#elif defined(__wasm_simd128__)
|
2466
|
-
// wasm simd
|
2467
|
-
float sum0 = 0.0f;
|
2468
|
-
float sum1 = 0.0f;
|
2469
|
-
|
2470
|
-
for (int i = 0; i < nb; i += 2) {
|
2471
|
-
const block_q4_0 * restrict x0 = &x[i + 0];
|
2472
|
-
const block_q4_0 * restrict y0 = &y[i + 0];
|
2473
|
-
const block_q4_0 * restrict x1 = &x[i + 1];
|
2474
|
-
const block_q4_0 * restrict y1 = &y[i + 1];
|
2475
|
-
|
2476
|
-
const v128_t m4b = wasm_u8x16_splat(0xf);
|
2477
|
-
const v128_t s8b = wasm_i8x16_splat(0x8);
|
2478
|
-
|
2479
|
-
const v128_t v0_0 = wasm_v128_load(x0->qs);
|
2480
|
-
const v128_t v0_1 = wasm_v128_load(y0->qs);
|
2481
|
-
const v128_t v1_0 = wasm_v128_load(x1->qs);
|
2482
|
-
const v128_t v1_1 = wasm_v128_load(y1->qs);
|
2483
|
-
|
2484
|
-
// 4-bit -> 8-bit
|
2485
|
-
const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
|
2486
|
-
const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
|
2487
|
-
|
2488
|
-
const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
|
2489
|
-
const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
|
2490
|
-
|
2491
|
-
const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
|
2492
|
-
const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
|
2493
|
-
|
2494
|
-
const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
|
2495
|
-
const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
|
2496
|
-
|
2497
|
-
// sub 8
|
2498
|
-
const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
|
2499
|
-
const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
|
2500
|
-
|
2501
|
-
const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
|
2502
|
-
const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
|
2503
|
-
|
2504
|
-
const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
|
2505
|
-
const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
|
2506
|
-
|
2507
|
-
const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
|
2508
|
-
const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
|
2509
|
-
|
2510
|
-
// dot product into int16x8_t
|
2511
|
-
const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
|
2512
|
-
const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
|
2513
|
-
|
2514
|
-
const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
|
2515
|
-
const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
|
2516
|
-
|
2517
|
-
const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
|
2518
|
-
const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
|
2519
|
-
|
2520
|
-
const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
|
2521
|
-
const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
|
2522
|
-
|
2523
|
-
const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
|
2524
|
-
const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
|
2525
|
-
|
2526
|
-
const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
|
2527
|
-
const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
|
2528
|
-
|
2529
|
-
const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
|
2530
|
-
const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
|
2531
|
-
|
2532
|
-
sum0 += x0->d * y0->d * (
|
2533
|
-
wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
|
2534
|
-
wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
|
2535
|
-
wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
|
2536
|
-
wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
|
2537
|
-
sum1 += x1->d * y1->d * (
|
2538
|
-
wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
|
2539
|
-
wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
|
2540
|
-
wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
|
2541
|
-
wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
|
2542
|
-
}
|
2543
|
-
|
2544
|
-
sumf = sum0 + sum1;
|
2545
2528
|
#else
|
2546
2529
|
// scalar
|
2547
2530
|
for (int i = 0; i < nb; i++) {
|
@@ -2549,202 +2532,187 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
|
2549
2532
|
const float d1 = y[i].d;
|
2550
2533
|
|
2551
2534
|
const uint8_t * restrict p0 = x[i].qs;
|
2552
|
-
const
|
2535
|
+
const int8_t * restrict p1 = y[i].qs;
|
2553
2536
|
|
2554
2537
|
int sumi = 0;
|
2555
|
-
for (int j = 0; j <
|
2538
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
2556
2539
|
const uint8_t v0 = p0[j];
|
2557
|
-
const uint8_t v1 = p1[j];
|
2558
2540
|
|
2559
|
-
const int i0 = (v0 & 0xf) - 8;
|
2560
|
-
const int i1 = (v0 >> 4) - 8;
|
2541
|
+
const int i0 = (int8_t) (v0 & 0xf) - 8;
|
2542
|
+
const int i1 = (int8_t) (v0 >> 4) - 8;
|
2561
2543
|
|
2562
|
-
const int i2 =
|
2563
|
-
const int i3 =
|
2544
|
+
const int i2 = p1[2*j + 0];
|
2545
|
+
const int i3 = p1[2*j + 1];
|
2564
2546
|
|
2565
2547
|
sumi += i0*i2 + i1*i3;
|
2566
2548
|
}
|
2567
|
-
sumf += d0
|
2549
|
+
sumf += d0*d1*sumi;
|
2568
2550
|
}
|
2569
2551
|
#endif
|
2570
2552
|
|
2571
2553
|
*s = sumf;
|
2572
2554
|
}
|
2573
2555
|
|
2574
|
-
static void
|
2575
|
-
const int nb = n /
|
2556
|
+
static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2557
|
+
const int nb = n / QK8_0;
|
2558
|
+
|
2559
|
+
assert(n % QK8_0 == 0);
|
2560
|
+
assert(nb % 2 == 0);
|
2576
2561
|
|
2577
2562
|
const block_q4_1 * restrict x = vx;
|
2578
|
-
const
|
2563
|
+
const block_q8_0 * restrict y = vy;
|
2579
2564
|
|
2580
2565
|
float sumf = 0.0;
|
2581
2566
|
|
2582
|
-
|
2583
|
-
|
2584
|
-
|
2585
|
-
|
2586
|
-
float acc_offset = 0.0f;
|
2587
|
-
|
2588
|
-
// Main loop
|
2589
|
-
for (int i = 0; i < nb; ++i) {
|
2590
|
-
const float * d0 = &x[i].d;
|
2591
|
-
const float * d1 = &y[i].d;
|
2592
|
-
|
2593
|
-
const float * m0 = &x[i].m;
|
2594
|
-
const float * m1 = &y[i].m;
|
2595
|
-
|
2596
|
-
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
2597
|
-
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
2598
|
-
const __m256 m0v = _mm256_broadcast_ss( m0 );
|
2599
|
-
const __m256 m1v = _mm256_broadcast_ss( m1 );
|
2600
|
-
|
2601
|
-
// Compute combined scale for the block
|
2602
|
-
const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
|
2603
|
-
|
2604
|
-
// Compute cross scales for the block
|
2605
|
-
const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
|
2606
|
-
const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
|
2607
|
-
const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
|
2608
|
-
|
2609
|
-
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
2610
|
-
__m256i bx = bytesFromNibbles( x[i].qs );
|
2611
|
-
__m256i by = bytesFromNibbles( y[i].qs );
|
2612
|
-
|
2613
|
-
// Now we have a vector with bytes in [ 0 .. 15 ] interval.
|
2614
|
-
|
2615
|
-
// Sign-extend first 16 signed bytes into int16_t
|
2616
|
-
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
|
2617
|
-
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
2618
|
-
// Compute products of int16_t integers, add pairwise
|
2619
|
-
__m256i i32 = _mm256_madd_epi16( x16, y16 );
|
2620
|
-
|
2621
|
-
// Sign-extend last 16 signed bytes into int16_t vectors
|
2622
|
-
__m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
|
2623
|
-
__m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
2624
|
-
// Accumulate products of int16_t integers
|
2625
|
-
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
|
2626
|
-
|
2627
|
-
// compute sums of unsigned bytes in bx, by in blocks of 8.
|
2628
|
-
// This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
|
2629
|
-
// which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
|
2630
|
-
// so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
|
2631
|
-
__m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
|
2632
|
-
__m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
|
2633
|
-
__m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
|
2634
|
-
__m256 sums = _mm256_cvtepi32_ps( sumsi );
|
2635
|
-
|
2636
|
-
// Convert int32_t to float
|
2637
|
-
__m256 p = _mm256_cvtepi32_ps( i32 );
|
2638
|
-
// Apply the scale, and accumulate
|
2639
|
-
// acc += d0*d1*x*y + d0*m1*x + d1*m0*y
|
2640
|
-
acc = _mm256_fmadd_ps( scale_01, p, acc );
|
2641
|
-
acc = _mm256_fmadd_ps( cross_scales, sums, acc );
|
2642
|
-
// acc_offset += m0*m1 (for each entry in the block)
|
2643
|
-
acc_offset += (*m0)*(*m1);
|
2644
|
-
}
|
2645
|
-
|
2646
|
-
// Return horizontal sum of the acc vector
|
2647
|
-
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2648
|
-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2649
|
-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2650
|
-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2651
|
-
|
2652
|
-
sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
|
2653
|
-
#elif defined(__ARM_NEON)
|
2654
|
-
float sum00 = 0.0f;
|
2655
|
-
float sum01 = 0.0f;
|
2656
|
-
float sum10 = 0.0f;
|
2657
|
-
float sum11 = 0.0f;
|
2567
|
+
// TODO: add AVX / WASM SIMD / etc
|
2568
|
+
#if defined(__ARM_NEON)
|
2569
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2570
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2658
2571
|
|
2659
2572
|
for (int i = 0; i < nb; i += 2) {
|
2660
2573
|
const block_q4_1 * restrict x0 = &x[i + 0];
|
2661
|
-
const block_q4_1 * restrict y0 = &y[i + 0];
|
2662
2574
|
const block_q4_1 * restrict x1 = &x[i + 1];
|
2663
|
-
const
|
2575
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2576
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
2664
2577
|
|
2665
2578
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2666
2579
|
|
2667
2580
|
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
2668
|
-
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
|
2669
2581
|
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
2670
|
-
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
|
2671
2582
|
|
2672
2583
|
// 4-bit -> 8-bit
|
2673
|
-
const
|
2674
|
-
const
|
2675
|
-
const
|
2676
|
-
const
|
2584
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2585
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2586
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2587
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2588
|
+
|
2589
|
+
// load y
|
2590
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2591
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2592
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2593
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2677
2594
|
|
2678
|
-
|
2679
|
-
const
|
2680
|
-
const
|
2681
|
-
const
|
2595
|
+
// interleave
|
2596
|
+
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2597
|
+
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2598
|
+
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2599
|
+
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2682
2600
|
|
2683
|
-
|
2684
|
-
|
2685
|
-
|
2601
|
+
const int16x8_t s0i = vaddq_s16(
|
2602
|
+
vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
|
2603
|
+
vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
|
2686
2604
|
|
2687
|
-
|
2688
|
-
|
2689
|
-
|
2605
|
+
const int16x8_t s1i = vaddq_s16(
|
2606
|
+
vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
|
2607
|
+
vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
|
2608
|
+
|
2609
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
|
2610
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
|
2690
2611
|
|
2691
2612
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2692
2613
|
// dot product into int32x4_t
|
2693
|
-
|
2694
|
-
|
2695
|
-
|
2696
|
-
p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
|
2697
|
-
p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
|
2614
|
+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
|
2615
|
+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
|
2698
2616
|
|
2699
|
-
|
2700
|
-
|
2617
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
2618
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
2701
2619
|
#else
|
2702
|
-
const
|
2703
|
-
const
|
2704
|
-
const
|
2705
|
-
const
|
2620
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
|
2621
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
|
2622
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
|
2623
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
|
2624
|
+
|
2625
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
|
2626
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
|
2627
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
|
2628
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
|
2629
|
+
|
2630
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2631
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2632
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2633
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
2634
|
+
|
2635
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
|
2636
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
|
2637
|
+
#endif
|
2638
|
+
}
|
2706
2639
|
|
2707
|
-
|
2708
|
-
|
2709
|
-
|
2710
|
-
|
2640
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2641
|
+
#elif defined(__AVX2__)
|
2642
|
+
// Initialize accumulator with zeros
|
2643
|
+
__m256 acc = _mm256_setzero_ps();
|
2711
2644
|
|
2712
|
-
|
2713
|
-
|
2645
|
+
// Main loop
|
2646
|
+
for (int i = 0; i < nb; ++i) {
|
2647
|
+
const float * d0 = &x[i].d;
|
2648
|
+
const float * d1 = &y[i].d;
|
2649
|
+
const float * m0 = &x[i].m;
|
2650
|
+
|
2651
|
+
const __m256 d0v = _mm256_broadcast_ss( d0 );
|
2652
|
+
const __m256 d1v = _mm256_broadcast_ss( d1 );
|
2653
|
+
const __m256 m0v = _mm256_broadcast_ss( m0 );
|
2714
2654
|
|
2715
|
-
|
2716
|
-
const
|
2655
|
+
// Compute combined scales
|
2656
|
+
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
|
2657
|
+
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
|
2717
2658
|
|
2718
|
-
|
2719
|
-
const
|
2659
|
+
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
2660
|
+
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
|
2661
|
+
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
|
2720
2662
|
|
2721
|
-
|
2722
|
-
|
2723
|
-
|
2663
|
+
// Get absolute values of x vectors
|
2664
|
+
const __m256i ax = _mm256_sign_epi8( bx, bx );
|
2665
|
+
|
2666
|
+
// Sign the values of the y vectors
|
2667
|
+
const __m256i sy = _mm256_sign_epi8( by, bx );
|
2668
|
+
|
2669
|
+
// Perform multiplication and create 16-bit values
|
2670
|
+
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
|
2671
|
+
const __m256i ones = _mm256_set1_epi16( 1 );
|
2672
|
+
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
|
2673
|
+
|
2674
|
+
// Convert to vector of 8 int32_t to 8 floats
|
2675
|
+
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
|
2676
|
+
|
2677
|
+
// Accumulate d0*d1*x*y
|
2678
|
+
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
2679
|
+
|
2680
|
+
// Compute sum of y values
|
2681
|
+
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
2682
|
+
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
2683
|
+
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
|
2684
|
+
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
|
2685
|
+
|
2686
|
+
// Accumulate d1*m0*y
|
2687
|
+
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
|
2724
2688
|
}
|
2725
2689
|
|
2726
|
-
|
2690
|
+
// Return horizontal sum of the acc vector
|
2691
|
+
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2692
|
+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2693
|
+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2694
|
+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2695
|
+
|
2696
|
+
sumf = _mm_cvtss_f32( res );
|
2727
2697
|
#else
|
2728
2698
|
// scalar
|
2729
2699
|
for (int i = 0; i < nb; i++) {
|
2730
2700
|
const float d0 = x[i].d;
|
2731
|
-
const float d1 = y[i].d;
|
2732
|
-
|
2733
2701
|
const float m0 = x[i].m;
|
2734
|
-
const float
|
2702
|
+
const float d1 = y[i].d;
|
2735
2703
|
|
2736
2704
|
const uint8_t * restrict p0 = x[i].qs;
|
2737
|
-
const
|
2705
|
+
const int8_t * restrict p1 = y[i].qs;
|
2738
2706
|
|
2739
|
-
|
2707
|
+
// TODO: this is very slow ..
|
2708
|
+
for (int j = 0; j < QK8_0/2; j++) {
|
2740
2709
|
const uint8_t v0 = p0[j];
|
2741
|
-
const uint8_t v1 = p1[j];
|
2742
2710
|
|
2743
2711
|
const float f0 = d0*(v0 & 0xf) + m0;
|
2744
2712
|
const float f1 = d0*(v0 >> 4) + m0;
|
2745
2713
|
|
2746
|
-
const float f2 = d1*
|
2747
|
-
const float f3 = d1*
|
2714
|
+
const float f2 = d1*p1[2*j + 0];
|
2715
|
+
const float f3 = d1*p1[2*j + 1];
|
2748
2716
|
|
2749
2717
|
sumf += f0*f2 + f1*f3;
|
2750
2718
|
}
|
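
Note on the q4_1 x q8_0 kernels above: both the NEON and the AVX2 paths rely on the same factorization. A q4_1 weight dequantizes to d0*q + m0 and a q8_0 activation to d1*p, so one block contributes d0*d1*sum(q*p) + m0*d1*sum(p). The scalar sketch below is not part of the diff; it assumes <stdint.h> and the nibble/int8 pairing used by the scalar loop above, and the function name is made up for illustration.

// Illustrative helper (not in ggml.c): one q4_1 block (scale d0, min m0,
// packed nibbles xq) against one q8_0 block (scale d1, int8 quants yq).
// Low nibble pairs with yq[2*j], high nibble with yq[2*j + 1].
static float dot_q4_1_q8_0_block(float d0, float m0, const uint8_t * xq,
                                 float d1, const int8_t * yq, int qk) {
    int sum_qp = 0; // sum of quant * quant products
    int sum_p  = 0; // sum of the int8 y quants
    for (int j = 0; j < qk/2; j++) {
        const int q_lo = xq[j] & 0xf;
        const int q_hi = xq[j] >> 4;
        sum_qp += q_lo*yq[2*j + 0] + q_hi*yq[2*j + 1];
        sum_p  += yq[2*j + 0] + yq[2*j + 1];
    }
    // (d0*q + m0)*(d1*p) summed over the block == d0*d1*sum(q*p) + m0*d1*sum(p)
    return d0*d1*(float)sum_qp + m0*d1*(float)sum_p;
}

In the SIMD paths the two terms are accumulated separately: the vdotq_s32/maddubs products feed the d0*d1 term, while the s0i/ysum reductions of the y quants feed the m0*d1 term.
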
@@ -2754,32 +2722,36 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
|
2754
2722
|
*s = sumf;
|
2755
2723
|
}
|
2756
2724
|
|
2757
|
-
static void
|
2725
|
+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2758
2726
|
const int nb = n / QK8_0;
|
2759
2727
|
|
2760
2728
|
assert(n % QK8_0 == 0);
|
2761
2729
|
assert(nb % 2 == 0);
|
2730
|
+
assert(QK8_0 == 2*QK4_2);
|
2762
2731
|
|
2763
|
-
const
|
2732
|
+
const block_q4_2 * restrict x = vx;
|
2764
2733
|
const block_q8_0 * restrict y = vy;
|
2765
2734
|
|
2766
2735
|
float sumf = 0.0;
|
2767
2736
|
|
2768
2737
|
#if defined(__ARM_NEON)
|
2769
|
-
|
2770
|
-
|
2738
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2739
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2771
2740
|
|
2772
2741
|
for (int i = 0; i < nb; i += 2) {
|
2773
|
-
const
|
2774
|
-
const
|
2742
|
+
const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
|
2743
|
+
const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
|
2744
|
+
const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
|
2745
|
+
const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
|
2746
|
+
|
2775
2747
|
const block_q8_0 * restrict y0 = &y[i + 0];
|
2776
2748
|
const block_q8_0 * restrict y1 = &y[i + 1];
|
2777
2749
|
|
2778
2750
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2779
2751
|
const int8x16_t s8b = vdupq_n_s8(0x8);
|
2780
2752
|
|
2781
|
-
const uint8x16_t v0_0 =
|
2782
|
-
const uint8x16_t v0_1 =
|
2753
|
+
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
2754
|
+
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
2783
2755
|
|
2784
2756
|
// 4-bit -> 8-bit
|
2785
2757
|
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
@@ -2793,77 +2765,78 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
2793
2765
|
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
2794
2766
|
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
2795
2767
|
|
2768
|
+
// interleave
|
2769
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
|
2770
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
|
2771
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
|
2772
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
|
2773
|
+
|
2796
2774
|
// load y
|
2797
2775
|
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2798
2776
|
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2799
2777
|
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2800
2778
|
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2801
2779
|
|
2802
|
-
// interleave
|
2803
|
-
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
|
2804
|
-
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
|
2805
|
-
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
|
2806
|
-
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
|
2807
|
-
|
2808
2780
|
#if defined(__ARM_FEATURE_DOTPROD)
|
2809
|
-
|
2810
|
-
|
2811
|
-
|
2812
|
-
|
2813
|
-
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
|
2814
|
-
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
|
2781
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
2782
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
|
2783
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
2815
2784
|
|
2816
|
-
|
2817
|
-
|
2785
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
2786
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
|
2787
|
+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
2818
2788
|
#else
|
2819
|
-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (
|
2820
|
-
const int16x8_t pl0h = vmull_s8(vget_high_s8(
|
2821
|
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (
|
2822
|
-
const int16x8_t ph0h = vmull_s8(vget_high_s8(
|
2823
|
-
|
2824
|
-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (
|
2825
|
-
const int16x8_t pl1h = vmull_s8(vget_high_s8(
|
2826
|
-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (
|
2827
|
-
const int16x8_t ph1h = vmull_s8(vget_high_s8(
|
2828
|
-
|
2829
|
-
const
|
2830
|
-
const
|
2831
|
-
|
2832
|
-
const
|
2833
|
-
|
2834
|
-
|
2835
|
-
|
2836
|
-
|
2837
|
-
|
2838
|
-
|
2839
|
-
|
2789
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2790
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2791
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2792
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2793
|
+
|
2794
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
2795
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
2796
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
2797
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
2798
|
+
|
2799
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2800
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2801
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2802
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
2803
|
+
|
2804
|
+
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
|
2805
|
+
vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
|
2806
|
+
vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
|
2807
|
+
|
2808
|
+
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
|
2809
|
+
vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
|
2810
|
+
vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
|
2840
2811
|
#endif
|
2841
2812
|
}
|
2842
2813
|
|
2843
|
-
sumf =
|
2814
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2844
2815
|
#elif defined(__AVX2__)
|
2845
2816
|
// Initialize accumulator with zeros
|
2846
2817
|
__m256 acc = _mm256_setzero_ps();
|
2847
2818
|
|
2848
2819
|
// Main loop
|
2849
|
-
for (int i = 0; i < nb; ++
|
2820
|
+
for (int i = 0; i < nb; i++) {
|
2850
2821
|
/* Compute combined scale for the block */
|
2851
|
-
const
|
2822
|
+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
|
2823
|
+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
|
2824
|
+
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
|
2852
2825
|
|
2853
|
-
|
2826
|
+
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
|
2827
|
+
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
|
2828
|
+
__m256i bx = _mm256_set_m128i(bx1, bx0);
|
2854
2829
|
|
2855
2830
|
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
2856
|
-
const __m256i off = _mm256_set1_epi8(
|
2857
|
-
bx = _mm256_sub_epi8(
|
2831
|
+
const __m256i off = _mm256_set1_epi8(8);
|
2832
|
+
bx = _mm256_sub_epi8(bx, off);
|
2858
2833
|
|
2859
2834
|
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
2860
2835
|
|
2861
2836
|
// Get absolute values of x vectors
|
2862
2837
|
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
2863
|
-
|
2864
2838
|
// Sign the values of the y vectors
|
2865
2839
|
const __m256i sy = _mm256_sign_epi8(by, bx);
|
2866
|
-
|
2867
2840
|
// Perform multiplication and create 16-bit values
|
2868
2841
|
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
2869
2842
|
|
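
The AVX2 kernels in this file obtain a signed int8 dot product from _mm256_maddubs_epi16, which multiplies an unsigned byte vector by a signed one, by taking |x| with _mm256_sign_epi8(bx, bx) and moving the sign of x onto y with _mm256_sign_epi8(by, bx). The function below is a scalar model of one lane, for illustration only (not from the diff).

// Scalar model of one SIMD lane: |x| times sign-adjusted y equals x*y.
// The x values here come from 4-bit quants, so the int8 edge case x == -128,
// whose absolute value is not representable, never occurs.
static int signed_dot_lane(int x, int y) {
    const int ax = x < 0 ? -x : x;                  // _mm256_sign_epi8(bx, bx)
    const int sy = x < 0 ? -y : (x == 0 ? 0 : y);   // _mm256_sign_epi8(by, bx)
    return ax*sy;                                   // == x*y in every case
}

The 16-bit products from maddubs are then widened to int32 with _mm256_madd_epi16 against a vector of ones before the conversion to float and the fused multiply-add with the combined scale.
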
@@ -2871,92 +2844,208 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
|
2871
2844
|
__m256i xy_q = _mm256_madd_epi16(ones, dot);
|
2872
2845
|
|
2873
2846
|
/* Convert to vector of 8 int32_t to 8 floats */
|
2874
|
-
__m256 q = _mm256_cvtepi32_ps(
|
2847
|
+
__m256 q = _mm256_cvtepi32_ps(xy_q);
|
2875
2848
|
|
2876
2849
|
/* Multiply q with scale and accumulate */
|
2877
|
-
acc = _mm256_fmadd_ps(
|
2850
|
+
acc = _mm256_fmadd_ps(d, q, acc);
|
2878
2851
|
}
|
2879
2852
|
|
2880
2853
|
// Return horizontal sum of the acc vector
|
2881
|
-
__m128 res = _mm256_extractf128_ps(
|
2882
|
-
res = _mm_add_ps(
|
2883
|
-
res = _mm_add_ps(
|
2884
|
-
res = _mm_add_ss(
|
2854
|
+
__m128 res = _mm256_extractf128_ps(acc, 1);
|
2855
|
+
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
|
2856
|
+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
|
2857
|
+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
2885
2858
|
|
2886
|
-
sumf = _mm_cvtss_f32(
|
2887
|
-
#
|
2888
|
-
//
|
2889
|
-
|
2859
|
+
sumf = _mm_cvtss_f32(res);
|
2860
|
+
#else
|
2861
|
+
// scalar
|
2862
|
+
for (int i = 0; i < nb; i++) {
|
2863
|
+
const uint8_t * restrict x0 = x[2*i + 0].qs;
|
2864
|
+
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
2865
|
+
const int8_t * restrict y0 = y[i].qs;
|
2890
2866
|
|
2891
|
-
|
2892
|
-
|
2893
|
-
// Compute combined scale for the block
|
2894
|
-
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
|
2867
|
+
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
|
2868
|
+
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
|
2895
2869
|
|
2896
|
-
|
2897
|
-
|
2898
|
-
|
2899
|
-
|
2900
|
-
|
2870
|
+
int sumi_0 = 0;
|
2871
|
+
int sumi_1 = 0;
|
2872
|
+
|
2873
|
+
for (int j = 0; j < QK8_0/4; j++) {
|
2874
|
+
const uint8_t v0 = x0[j];
|
2875
|
+
const uint8_t v1 = x1[j];
|
2876
|
+
|
2877
|
+
const int i0_0 = (int8_t) (v0 & 0xf) - 8;
|
2878
|
+
const int i1_0 = (int8_t) (v0 >> 4) - 8;
|
2879
|
+
|
2880
|
+
const int i0_1 = (int8_t) (v1 & 0xf) - 8;
|
2881
|
+
const int i1_1 = (int8_t) (v1 >> 4) - 8;
|
2882
|
+
|
2883
|
+
const int i2_0 = y0[2*j + 0];
|
2884
|
+
const int i3_0 = y0[2*j + 1];
|
2885
|
+
|
2886
|
+
const int i2_1 = y0[2*(j + QK8_0/4) + 0];
|
2887
|
+
const int i3_1 = y0[2*(j + QK8_0/4) + 1];
|
2888
|
+
|
2889
|
+
sumi_0 += i0_0*i2_0 + i1_0*i3_0;
|
2890
|
+
sumi_1 += i0_1*i2_1 + i1_1*i3_1;
|
2891
|
+
}
|
2892
|
+
|
2893
|
+
sumf += (d0 * y[i].d) * sumi_0;
|
2894
|
+
sumf += (d1 * y[i].d) * sumi_1;
|
2895
|
+
}
|
2896
|
+
#endif
|
2897
|
+
|
2898
|
+
*s = sumf;
|
2899
|
+
}
|
2900
|
+
|
2901
|
+
static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
2902
|
+
const int nb = n / QK8_0;
|
2903
|
+
|
2904
|
+
assert(n % QK8_0 == 0);
|
2905
|
+
assert(nb % 2 == 0);
|
2906
|
+
assert(QK8_0 == 2*QK4_2);
|
2907
|
+
|
2908
|
+
const block_q4_3 * restrict x = vx;
|
2909
|
+
const block_q8_0 * restrict y = vy;
|
2910
|
+
|
2911
|
+
float sumf = 0.0;
|
2912
|
+
|
2913
|
+
#if defined(__ARM_NEON)
|
2914
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
2915
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
2916
|
+
|
2917
|
+
for (int i = 0; i < nb; i += 2) {
|
2918
|
+
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
|
2919
|
+
const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
|
2920
|
+
const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
|
2921
|
+
const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
|
2922
|
+
|
2923
|
+
const block_q8_0 * restrict y0 = &y[i + 0];
|
2924
|
+
const block_q8_0 * restrict y1 = &y[i + 1];
|
2925
|
+
|
2926
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2927
|
+
|
2928
|
+
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
|
2929
|
+
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
|
2930
|
+
const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
|
2931
|
+
const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
|
2932
|
+
|
2933
|
+
const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
|
2934
|
+
const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
|
2935
|
+
const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
|
2936
|
+
const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
|
2937
|
+
|
2938
|
+
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
|
2939
|
+
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
|
2940
|
+
|
2941
|
+
// 4-bit -> 8-bit
|
2942
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
2943
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
2944
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
2945
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
2901
2946
|
|
2902
|
-
|
2903
|
-
|
2904
|
-
|
2947
|
+
// interleave
|
2948
|
+
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
|
2949
|
+
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
|
2950
|
+
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
|
2951
|
+
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
|
2905
2952
|
|
2906
|
-
|
2907
|
-
|
2953
|
+
// load y
|
2954
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
2955
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
2956
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
2957
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
2908
2958
|
|
2909
|
-
|
2910
|
-
|
2959
|
+
const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
|
2960
|
+
const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
|
2911
2961
|
|
2912
|
-
|
2913
|
-
|
2962
|
+
const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
|
2963
|
+
const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
|
2914
2964
|
|
2915
|
-
|
2916
|
-
|
2917
|
-
|
2965
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
|
2966
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
|
2967
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
|
2968
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
|
2918
2969
|
|
2919
|
-
|
2920
|
-
|
2921
|
-
|
2922
|
-
|
2970
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
2971
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
|
2972
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
|
2973
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
|
2974
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
|
2975
|
+
#else
|
2976
|
+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
2977
|
+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
2978
|
+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
2979
|
+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
2980
|
+
|
2981
|
+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
2982
|
+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
2983
|
+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
2984
|
+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
2985
|
+
|
2986
|
+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
2987
|
+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
2988
|
+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
|
2989
|
+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
|
2990
|
+
|
2991
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
|
2992
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
|
2993
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
|
2994
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
|
2995
|
+
#endif
|
2923
2996
|
}
|
2924
2997
|
|
2925
|
-
|
2926
|
-
__m128 res = _mm256_extractf128_ps( acc, 1 );
|
2927
|
-
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
2928
|
-
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
2929
|
-
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
2930
|
-
|
2931
|
-
sumf = _mm_cvtss_f32( res );
|
2998
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
2932
2999
|
#else
|
2933
3000
|
// scalar
|
2934
3001
|
for (int i = 0; i < nb; i++) {
|
2935
|
-
const
|
2936
|
-
const
|
3002
|
+
const uint8_t * restrict x0 = x[2*i + 0].qs;
|
3003
|
+
const uint8_t * restrict x1 = x[2*i + 1].qs;
|
3004
|
+
const int8_t * restrict y0 = y[i].qs;
|
2937
3005
|
|
2938
|
-
const
|
2939
|
-
const
|
3006
|
+
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
|
3007
|
+
const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
|
3008
|
+
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
|
3009
|
+
const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
|
2940
3010
|
|
2941
|
-
int
|
2942
|
-
|
2943
|
-
const uint8_t v0 = p0[j];
|
3011
|
+
int sy_0 = 0;
|
3012
|
+
int sy_1 = 0;
|
2944
3013
|
|
2945
|
-
|
2946
|
-
|
3014
|
+
int sxy_0 = 0;
|
3015
|
+
int sxy_1 = 0;
|
2947
3016
|
|
2948
|
-
|
2949
|
-
const
|
3017
|
+
for (int j = 0; j < QK8_0/4; j++) {
|
3018
|
+
const uint8_t v0 = x0[j];
|
3019
|
+
const uint8_t v1 = x1[j];
|
2950
3020
|
|
2951
|
-
|
3021
|
+
const int x0_0 = v0 & 0xf;
|
3022
|
+
const int x1_0 = v0 >> 4;
|
3023
|
+
|
3024
|
+
const int x0_1 = v1 & 0xf;
|
3025
|
+
const int x1_1 = v1 >> 4;
|
3026
|
+
|
3027
|
+
const int y0_0 = y0[2*j + 0];
|
3028
|
+
const int y1_0 = y0[2*j + 1];
|
3029
|
+
|
3030
|
+
const int y0_1 = y0[2*(j + QK8_0/4) + 0];
|
3031
|
+
const int y1_1 = y0[2*(j + QK8_0/4) + 1];
|
3032
|
+
|
3033
|
+
sy_0 += y0_0 + y1_0;
|
3034
|
+
sy_1 += y0_1 + y1_1;
|
3035
|
+
|
3036
|
+
sxy_0 += x0_0*y0_0 + x1_0*y1_0;
|
3037
|
+
sxy_1 += x0_1*y0_1 + x1_1*y1_1;
|
2952
3038
|
}
|
2953
|
-
|
3039
|
+
|
3040
|
+
sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
|
3041
|
+
sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
|
2954
3042
|
}
|
2955
3043
|
#endif
|
2956
3044
|
|
2957
3045
|
*s = sumf;
|
2958
3046
|
}
|
2959
3047
|
|
3048
|
+
|
2960
3049
|
// compute GGML_VEC_DOT_UNROLL dot products at once
|
2961
3050
|
// xs - x row stride in bytes
|
2962
3051
|
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
|
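
In the new q4_2 and q4_3 kernels above, each q8_0 block of activations is matched against two consecutive 16-quant sub-blocks x[2*i + 0] and x[2*i + 1] (hence the assert(QK8_0 == 2*QK4_2)): sub-block 0 covers y quants 0..15 and sub-block 1 covers y quants 16..31, which is where the 2*(j + QK8_0/4) indices in the scalar loops come from. The sketch below is illustrative only; it assumes QK8_0 == 32, takes the fp16 scales as already-converted floats, and uses an invented function name.

// Illustrative sketch of the q4_2 pairing (not in ggml.c).
static float dot_q4_2_pair_q8_0(const uint8_t * x0q, float d0,   // x[2*i + 0]
                                const uint8_t * x1q, float d1,   // x[2*i + 1]
                                const int8_t  * yq,  float yd) { // y[i]
    int sumi_0 = 0;
    int sumi_1 = 0;
    for (int j = 0; j < 8; j++) { // QK8_0/4 packed bytes per 16-quant sub-block
        sumi_0 += ((x0q[j] & 0xf) - 8)*yq[2*j +  0] + ((x0q[j] >> 4) - 8)*yq[2*j +  1];
        sumi_1 += ((x1q[j] & 0xf) - 8)*yq[2*j + 16] + ((x1q[j] >> 4) - 8)*yq[2*j + 17];
    }
    return (d0*sumi_0 + d1*sumi_1)*yd;
}

q4_3 follows the same layout but carries a per-sub-block minimum m as well, which adds an m*sum(y quants)*yd term per sub-block, exactly as in the scalar tail of ggml_vec_dot_q4_3_q8_0 above.
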
@@ -3203,24 +3292,28 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
|
|
3203
3292
|
[GGML_TYPE_F16] = 1,
|
3204
3293
|
[GGML_TYPE_Q4_0] = QK4_0,
|
3205
3294
|
[GGML_TYPE_Q4_1] = QK4_1,
|
3295
|
+
[GGML_TYPE_Q4_2] = QK4_2,
|
3296
|
+
[GGML_TYPE_Q4_3] = QK4_3,
|
3206
3297
|
[GGML_TYPE_Q8_0] = QK8_0,
|
3207
3298
|
[GGML_TYPE_I8] = 1,
|
3208
3299
|
[GGML_TYPE_I16] = 1,
|
3209
3300
|
[GGML_TYPE_I32] = 1,
|
3210
3301
|
};
|
3211
|
-
static_assert(GGML_TYPE_COUNT ==
|
3302
|
+
static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
|
3212
3303
|
|
3213
3304
|
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
|
3214
3305
|
[GGML_TYPE_F32] = sizeof(float),
|
3215
3306
|
[GGML_TYPE_F16] = sizeof(ggml_fp16_t),
|
3216
3307
|
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
|
3217
3308
|
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
|
3309
|
+
[GGML_TYPE_Q4_2] = sizeof(block_q4_2),
|
3310
|
+
[GGML_TYPE_Q4_3] = sizeof(block_q4_3),
|
3218
3311
|
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
|
3219
3312
|
[GGML_TYPE_I8] = sizeof(int8_t),
|
3220
3313
|
[GGML_TYPE_I16] = sizeof(int16_t),
|
3221
3314
|
[GGML_TYPE_I32] = sizeof(int32_t),
|
3222
3315
|
};
|
3223
|
-
static_assert(GGML_TYPE_COUNT ==
|
3316
|
+
static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
|
3224
3317
|
|
3225
3318
|
|
3226
3319
|
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
@@ -3228,12 +3321,28 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
|
|
3228
3321
|
[GGML_TYPE_F16] = "f16",
|
3229
3322
|
[GGML_TYPE_Q4_0] = "q4_0",
|
3230
3323
|
[GGML_TYPE_Q4_1] = "q4_1",
|
3324
|
+
[GGML_TYPE_Q4_2] = "q4_2",
|
3325
|
+
[GGML_TYPE_Q4_3] = "q4_3",
|
3231
3326
|
[GGML_TYPE_Q8_0] = "q8_0",
|
3232
3327
|
[GGML_TYPE_I8] = "i8",
|
3233
3328
|
[GGML_TYPE_I16] = "i16",
|
3234
3329
|
[GGML_TYPE_I32] = "i32",
|
3235
3330
|
};
|
3236
|
-
static_assert(GGML_TYPE_COUNT ==
|
3331
|
+
static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
|
3332
|
+
|
3333
|
+
static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
|
3334
|
+
[GGML_TYPE_F32] = false,
|
3335
|
+
[GGML_TYPE_F16] = false,
|
3336
|
+
[GGML_TYPE_Q4_0] = true,
|
3337
|
+
[GGML_TYPE_Q4_1] = true,
|
3338
|
+
[GGML_TYPE_Q4_2] = true,
|
3339
|
+
[GGML_TYPE_Q4_3] = true,
|
3340
|
+
[GGML_TYPE_Q8_0] = true,
|
3341
|
+
[GGML_TYPE_I8] = false,
|
3342
|
+
[GGML_TYPE_I16] = false,
|
3343
|
+
[GGML_TYPE_I32] = false,
|
3344
|
+
};
|
3345
|
+
static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
|
3237
3346
|
|
3238
3347
|
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
3239
3348
|
"NONE",
|
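
The table hunks above all follow the same pattern: a designated-initializer array indexed by the ggml_type enum, guarded by a static_assert on GGML_TYPE_COUNT, so adding Q4_2/Q4_3 without bumping every table fails at compile time instead of silently reading a zeroed entry. A stand-alone illustration of the idiom (the names below are invented, not ggml's):

#include <assert.h>   // static_assert (C11)

enum color { COLOR_RED, COLOR_GREEN, COLOR_BLUE, COLOR_COUNT };

static const char * COLOR_NAME[COLOR_COUNT] = {
    [COLOR_RED]   = "red",
    [COLOR_GREEN] = "green",
    [COLOR_BLUE]  = "blue",
};
// Fires as soon as someone grows the enum without revisiting this table;
// an entry that is merely forgotten would otherwise default to NULL/0.
static_assert(COLOR_COUNT == 3, "COLOR_NAME is outdated");
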
@@ -3495,6 +3604,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
|
|
3495
3604
|
(t0->ne[3] == t1->ne[3]);
|
3496
3605
|
}
|
3497
3606
|
|
3607
|
+
bool ggml_is_quantized(enum ggml_type type) {
|
3608
|
+
return GGML_IS_QUANTIZED[type];
|
3609
|
+
}
|
3610
|
+
|
3498
3611
|
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
|
3499
3612
|
return tensor->nb[0] > tensor->nb[1];
|
3500
3613
|
}
|
@@ -3605,6 +3718,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|
3605
3718
|
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
3606
3719
|
}
|
3607
3720
|
|
3721
|
+
// initialize cuBLAS
|
3722
|
+
#if defined(GGML_USE_CUBLAS)
|
3723
|
+
init_cublas();
|
3724
|
+
#endif
|
3725
|
+
|
3608
3726
|
is_first_call = false;
|
3609
3727
|
}
|
3610
3728
|
|
@@ -5535,7 +5653,6 @@ static void ggml_compute_forward_dup_f16(
|
|
5535
5653
|
const struct ggml_compute_params * params,
|
5536
5654
|
const struct ggml_tensor * src0,
|
5537
5655
|
struct ggml_tensor * dst) {
|
5538
|
-
GGML_ASSERT(params->ith == 0);
|
5539
5656
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
5540
5657
|
|
5541
5658
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -5547,6 +5664,11 @@ static void ggml_compute_forward_dup_f16(
|
|
5547
5664
|
const int64_t ne02 = src0->ne[2];
|
5548
5665
|
const int64_t ne03 = src0->ne[3];
|
5549
5666
|
|
5667
|
+
const int64_t ne0 = dst->ne[0];
|
5668
|
+
const int64_t ne1 = dst->ne[1];
|
5669
|
+
const int64_t ne2 = dst->ne[2];
|
5670
|
+
const int64_t ne3 = dst->ne[3];
|
5671
|
+
|
5550
5672
|
const size_t nb00 = src0->nb[0];
|
5551
5673
|
const size_t nb01 = src0->nb[1];
|
5552
5674
|
const size_t nb02 = src0->nb[2];
|
@@ -5557,19 +5679,40 @@ static void ggml_compute_forward_dup_f16(
|
|
5557
5679
|
const size_t nb2 = dst->nb[2];
|
5558
5680
|
const size_t nb3 = dst->nb[3];
|
5559
5681
|
|
5682
|
+
const int ith = params->ith; // thread index
|
5683
|
+
const int nth = params->nth; // number of threads
|
5684
|
+
|
5560
5685
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
5561
|
-
|
5686
|
+
// parallelize by elements
|
5687
|
+
const int ne = ggml_nelements(dst);
|
5688
|
+
const int dr = (ne + nth - 1) / nth;
|
5689
|
+
const int ie0 = dr * ith;
|
5690
|
+
const int ie1 = MIN(ie0 + dr, ne);
|
5691
|
+
|
5692
|
+
memcpy(
|
5693
|
+
((char *) dst->data + ie0*nb0),
|
5694
|
+
((char *) src0->data + ie0*nb00),
|
5695
|
+
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
5696
|
+
|
5562
5697
|
return;
|
5563
5698
|
}
|
5564
5699
|
|
5700
|
+
// parallelize by rows
|
5701
|
+
const int nr = ne01;
|
5702
|
+
// number of rows per thread
|
5703
|
+
const int dr = (nr + nth - 1) / nth;
|
5704
|
+
// row range for this thread
|
5705
|
+
const int ir0 = dr * ith;
|
5706
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
5707
|
+
|
5565
5708
|
if (src0->type == dst->type &&
|
5566
|
-
|
5567
|
-
|
5709
|
+
ne00 == ne0 &&
|
5710
|
+
nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
|
5568
5711
|
// copy by rows
|
5569
5712
|
const size_t rs = ne00*nb00;
|
5570
5713
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5571
5714
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5572
|
-
for (int64_t i01 =
|
5715
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5573
5716
|
memcpy(
|
5574
5717
|
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
5575
5718
|
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
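
ggml_compute_forward_dup_f16 (and the f32 variant later in this diff) is now multi-threaded: the GGML_ASSERT(params->ith == 0) is dropped and each thread claims a contiguous range of rows, or of raw elements in the fully contiguous fast path. The partitioning is the usual ceiling-division split; a minimal model, not from the diff and with a hypothetical helper name:

// Rows [*ir0, *ir1) go to thread `ith` of `nth`; the last thread may get a
// short (possibly empty) range. E.g. nr = 10, nth = 4 -> [0,3) [3,6) [6,9) [9,10).
static void row_range(int nr, int ith, int nth, int * ir0, int * ir1) {
    const int dr = (nr + nth - 1)/nth;          // rows per thread, rounded up
    *ir0 = dr*ith;
    *ir1 = *ir0 + dr < nr ? *ir0 + dr : nr;     // MIN(ir0 + dr, nr)
}
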
@@ -5583,21 +5726,21 @@ static void ggml_compute_forward_dup_f16(
|
|
5583
5726
|
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
|
5584
5727
|
|
5585
5728
|
if (ggml_is_contiguous(dst)) {
|
5586
|
-
if (
|
5729
|
+
if (nb00 == sizeof(ggml_fp16_t)) {
|
5587
5730
|
if (dst->type == GGML_TYPE_F16) {
|
5588
5731
|
size_t id = 0;
|
5589
|
-
const size_t rs = ne00*nb00;
|
5732
|
+
const size_t rs = ne00 * nb00;
|
5733
|
+
char * dst_ptr = (char *) dst->data;
|
5590
5734
|
|
5591
5735
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5592
5736
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5593
|
-
|
5737
|
+
id += rs * ir0;
|
5738
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5594
5739
|
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5595
|
-
|
5596
|
-
|
5597
|
-
memcpy(dst_ptr, src0_ptr, rs);
|
5598
|
-
|
5599
|
-
id++;
|
5740
|
+
memcpy(dst_ptr + id, src0_ptr, rs);
|
5741
|
+
id += rs;
|
5600
5742
|
}
|
5743
|
+
id += rs * (ne01 - ir1);
|
5601
5744
|
}
|
5602
5745
|
}
|
5603
5746
|
} else if (dst->type == GGML_TYPE_F32) {
|
@@ -5606,34 +5749,39 @@ static void ggml_compute_forward_dup_f16(
|
|
5606
5749
|
|
5607
5750
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5608
5751
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5609
|
-
|
5752
|
+
id += ne00 * ir0;
|
5753
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5754
|
+
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5610
5755
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5611
|
-
|
5612
|
-
|
5613
|
-
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
5756
|
+
dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
|
5614
5757
|
id++;
|
5615
5758
|
}
|
5616
5759
|
}
|
5760
|
+
id += ne00 * (ne01 - ir1);
|
5617
5761
|
}
|
5618
5762
|
}
|
5619
|
-
} else if (dst->type
|
5763
|
+
} else if (ggml_is_quantized(dst->type)) {
|
5620
5764
|
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
5765
|
+
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
5766
|
+
|
5621
5767
|
size_t id = 0;
|
5622
|
-
|
5623
|
-
|
5624
|
-
float * src0_f32 = (float *) params->wdata;
|
5768
|
+
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
5769
|
+
char * dst_ptr = (char *) dst->data;
|
5625
5770
|
|
5626
5771
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5627
5772
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5628
|
-
|
5773
|
+
id += rs * ir0;
|
5774
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5629
5775
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5630
|
-
|
5776
|
+
|
5631
5777
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5632
5778
|
src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
|
5633
5779
|
}
|
5780
|
+
|
5634
5781
|
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
5635
|
-
id +=
|
5782
|
+
id += rs;
|
5636
5783
|
}
|
5784
|
+
id += rs * (ne01 - ir1);
|
5637
5785
|
}
|
5638
5786
|
}
|
5639
5787
|
} else {
|
@@ -5648,7 +5796,8 @@ static void ggml_compute_forward_dup_f16(
|
|
5648
5796
|
|
5649
5797
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5650
5798
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5651
|
-
|
5799
|
+
id += ne00 * ir0;
|
5800
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5652
5801
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5653
5802
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5654
5803
|
|
@@ -5656,6 +5805,7 @@ static void ggml_compute_forward_dup_f16(
|
|
5656
5805
|
id++;
|
5657
5806
|
}
|
5658
5807
|
}
|
5808
|
+
id += ne00 * (ne01 - ir1);
|
5659
5809
|
}
|
5660
5810
|
}
|
5661
5811
|
} else if (dst->type == GGML_TYPE_F16) {
|
@@ -5664,7 +5814,8 @@ static void ggml_compute_forward_dup_f16(
|
|
5664
5814
|
|
5665
5815
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5666
5816
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5667
|
-
|
5817
|
+
id += ne00 * ir0;
|
5818
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5668
5819
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5669
5820
|
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5670
5821
|
|
@@ -5672,6 +5823,7 @@ static void ggml_compute_forward_dup_f16(
|
|
5672
5823
|
id++;
|
5673
5824
|
}
|
5674
5825
|
}
|
5826
|
+
id += ne00 * (ne01 - ir1);
|
5675
5827
|
}
|
5676
5828
|
}
|
5677
5829
|
} else {
|
@@ -5690,7 +5842,20 @@ static void ggml_compute_forward_dup_f16(
|
|
5690
5842
|
if (dst->type == GGML_TYPE_F16) {
|
5691
5843
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5692
5844
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5693
|
-
|
5845
|
+
i10 += ne00 * ir0;
|
5846
|
+
while (i10 >= ne0) {
|
5847
|
+
i10 -= ne0;
|
5848
|
+
if (++i11 == ne1) {
|
5849
|
+
i11 = 0;
|
5850
|
+
if (++i12 == ne2) {
|
5851
|
+
i12 = 0;
|
5852
|
+
if (++i13 == ne3) {
|
5853
|
+
i13 = 0;
|
5854
|
+
}
|
5855
|
+
}
|
5856
|
+
}
|
5857
|
+
}
|
5858
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5694
5859
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5695
5860
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5696
5861
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
@@ -5711,25 +5876,51 @@ static void ggml_compute_forward_dup_f16(
|
|
5711
5876
|
}
|
5712
5877
|
}
|
5713
5878
|
}
|
5879
|
+
i10 += ne00 * (ne01 - ir1);
|
5880
|
+
while (i10 >= ne0) {
|
5881
|
+
i10 -= ne0;
|
5882
|
+
if (++i11 == ne1) {
|
5883
|
+
i11 = 0;
|
5884
|
+
if (++i12 == ne2) {
|
5885
|
+
i12 = 0;
|
5886
|
+
if (++i13 == ne3) {
|
5887
|
+
i13 = 0;
|
5888
|
+
}
|
5889
|
+
}
|
5890
|
+
}
|
5891
|
+
}
|
5714
5892
|
}
|
5715
5893
|
}
|
5716
5894
|
} else if (dst->type == GGML_TYPE_F32) {
|
5717
5895
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5718
5896
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5719
|
-
|
5897
|
+
i10 += ne00 * ir0;
|
5898
|
+
while (i10 >= ne0) {
|
5899
|
+
i10 -= ne0;
|
5900
|
+
if (++i11 == ne1) {
|
5901
|
+
i11 = 0;
|
5902
|
+
if (++i12 == ne2) {
|
5903
|
+
i12 = 0;
|
5904
|
+
if (++i13 == ne3) {
|
5905
|
+
i13 = 0;
|
5906
|
+
}
|
5907
|
+
}
|
5908
|
+
}
|
5909
|
+
}
|
5910
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5720
5911
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5721
5912
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5722
5913
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
5723
5914
|
|
5724
5915
|
*(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
5725
5916
|
|
5726
|
-
if (++i10 ==
|
5917
|
+
if (++i10 == ne0) {
|
5727
5918
|
i10 = 0;
|
5728
|
-
if (++i11 ==
|
5919
|
+
if (++i11 == ne1) {
|
5729
5920
|
i11 = 0;
|
5730
|
-
if (++i12 ==
|
5921
|
+
if (++i12 == ne2) {
|
5731
5922
|
i12 = 0;
|
5732
|
-
if (++i13 ==
|
5923
|
+
if (++i13 == ne3) {
|
5733
5924
|
i13 = 0;
|
5734
5925
|
}
|
5735
5926
|
}
|
@@ -5737,6 +5928,19 @@ static void ggml_compute_forward_dup_f16(
|
|
5737
5928
|
}
|
5738
5929
|
}
|
5739
5930
|
}
|
5931
|
+
i10 += ne00 * (ne01 - ir1);
|
5932
|
+
while (i10 >= ne0) {
|
5933
|
+
i10 -= ne0;
|
5934
|
+
if (++i11 == ne1) {
|
5935
|
+
i11 = 0;
|
5936
|
+
if (++i12 == ne2) {
|
5937
|
+
i12 = 0;
|
5938
|
+
if (++i13 == ne3) {
|
5939
|
+
i13 = 0;
|
5940
|
+
}
|
5941
|
+
}
|
5942
|
+
}
|
5943
|
+
}
|
5740
5944
|
}
|
5741
5945
|
}
|
5742
5946
|
} else {
|
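
Because the threaded copy above starts each thread partway through the flattened destination, the non-contiguous branch advances the dst counters i10..i13 by ne00*ir0 elements before its row loop (and by ne00*(ne01 - ir1) after it) and then normalizes them; that is what the repeated while (i10 >= ne0) blocks do. The carry step in isolation, as a sketch assuming <stdint.h> and an invented function name:

// Propagate an oversized innermost counter into the higher dimensions,
// mirroring the while-loops added in the hunks above.
static void carry_dst_counters(int64_t * i10, int64_t * i11, int64_t * i12, int64_t * i13,
                               int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
    while (*i10 >= ne0) {
        *i10 -= ne0;
        if (++*i11 == ne1) {
            *i11 = 0;
            if (++*i12 == ne2) {
                *i12 = 0;
                if (++*i13 == ne3) {
                    *i13 = 0;
                }
            }
        }
    }
}
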
@@ -5748,7 +5952,6 @@ static void ggml_compute_forward_dup_f32(
|
|
5748
5952
|
const struct ggml_compute_params * params,
|
5749
5953
|
const struct ggml_tensor * src0,
|
5750
5954
|
struct ggml_tensor * dst) {
|
5751
|
-
GGML_ASSERT(params->ith == 0);
|
5752
5955
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
5753
5956
|
|
5754
5957
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
@@ -5760,6 +5963,11 @@ static void ggml_compute_forward_dup_f32(
|
|
5760
5963
|
const int64_t ne02 = src0->ne[2];
|
5761
5964
|
const int64_t ne03 = src0->ne[3];
|
5762
5965
|
|
5966
|
+
const int64_t ne0 = dst->ne[0];
|
5967
|
+
const int64_t ne1 = dst->ne[1];
|
5968
|
+
const int64_t ne2 = dst->ne[2];
|
5969
|
+
const int64_t ne3 = dst->ne[3];
|
5970
|
+
|
5763
5971
|
const size_t nb00 = src0->nb[0];
|
5764
5972
|
const size_t nb01 = src0->nb[1];
|
5765
5973
|
const size_t nb02 = src0->nb[2];
|
@@ -5770,19 +5978,40 @@ static void ggml_compute_forward_dup_f32(
|
|
5770
5978
|
const size_t nb2 = dst->nb[2];
|
5771
5979
|
const size_t nb3 = dst->nb[3];
|
5772
5980
|
|
5981
|
+
const int ith = params->ith; // thread index
|
5982
|
+
const int nth = params->nth; // number of threads
|
5983
|
+
|
5773
5984
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
5774
|
-
|
5985
|
+
// parallelize by elements
|
5986
|
+
const int ne = ggml_nelements(dst);
|
5987
|
+
const int dr = (ne + nth - 1) / nth;
|
5988
|
+
const int ie0 = dr * ith;
|
5989
|
+
const int ie1 = MIN(ie0 + dr, ne);
|
5990
|
+
|
5991
|
+
memcpy(
|
5992
|
+
((char *) dst->data + ie0*nb0),
|
5993
|
+
((char *) src0->data + ie0*nb00),
|
5994
|
+
(ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
|
5995
|
+
|
5775
5996
|
return;
|
5776
5997
|
}
|
5777
5998
|
|
5999
|
+
// parallelize by rows
|
6000
|
+
const int nr = ne01;
|
6001
|
+
// number of rows per thread
|
6002
|
+
const int dr = (nr + nth - 1) / nth;
|
6003
|
+
// row range for this thread
|
6004
|
+
const int ir0 = dr * ith;
|
6005
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
6006
|
+
|
5778
6007
|
if (src0->type == dst->type &&
|
5779
|
-
|
5780
|
-
|
6008
|
+
ne00 == ne0 &&
|
6009
|
+
nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
|
5781
6010
|
// copy by rows
|
5782
6011
|
const size_t rs = ne00*nb00;
|
5783
6012
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5784
6013
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5785
|
-
for (int64_t i01 =
|
6014
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5786
6015
|
memcpy(
|
5787
6016
|
((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
|
5788
6017
|
((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
|
@@ -5795,21 +6024,21 @@ static void ggml_compute_forward_dup_f32(
|
|
5795
6024
|
|
5796
6025
|
if (ggml_is_contiguous(dst)) {
|
5797
6026
|
// TODO: simplify
|
5798
|
-
if (
|
6027
|
+
if (nb00 == sizeof(float)) {
|
5799
6028
|
if (dst->type == GGML_TYPE_F32) {
|
5800
6029
|
size_t id = 0;
|
5801
|
-
const size_t rs = ne00*nb00;
|
6030
|
+
const size_t rs = ne00 * nb00;
|
6031
|
+
char * dst_ptr = (char *) dst->data;
|
5802
6032
|
|
5803
6033
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5804
6034
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5805
|
-
|
6035
|
+
id += rs * ir0;
|
6036
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5806
6037
|
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
5807
|
-
|
5808
|
-
|
5809
|
-
memcpy(dst_ptr, src0_ptr, rs);
|
5810
|
-
|
5811
|
-
id++;
|
6038
|
+
memcpy(dst_ptr + id, src0_ptr, rs);
|
6039
|
+
id += rs;
|
5812
6040
|
}
|
6041
|
+
id += rs * (ne01 - ir1);
|
5813
6042
|
}
|
5814
6043
|
}
|
5815
6044
|
} else if (dst->type == GGML_TYPE_F16) {
|
@@ -5818,7 +6047,8 @@ static void ggml_compute_forward_dup_f32(
|
|
5818
6047
|
|
5819
6048
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5820
6049
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5821
|
-
|
6050
|
+
id += ne00 * ir0;
|
6051
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5822
6052
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5823
6053
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5824
6054
|
|
@@ -5826,21 +6056,25 @@ static void ggml_compute_forward_dup_f32(
|
|
5826
6056
|
id++;
|
5827
6057
|
}
|
5828
6058
|
}
|
6059
|
+
id += ne00 * (ne01 - ir1);
|
5829
6060
|
}
|
5830
6061
|
}
|
5831
|
-
} else if (dst->type
|
6062
|
+
} else if (ggml_is_quantized(dst->type)) {
|
5832
6063
|
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
6064
|
+
|
5833
6065
|
size_t id = 0;
|
5834
|
-
|
5835
|
-
|
6066
|
+
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
6067
|
+
char * dst_ptr = (char *) dst->data;
|
5836
6068
|
|
5837
6069
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5838
6070
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5839
|
-
|
6071
|
+
id += rs * ir0;
|
6072
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5840
6073
|
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
5841
6074
|
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
|
5842
|
-
id +=
|
6075
|
+
id += rs;
|
5843
6076
|
}
|
6077
|
+
id += rs * (ne01 - ir1);
|
5844
6078
|
}
|
5845
6079
|
}
|
5846
6080
|
} else {
|
@@ -5855,7 +6089,8 @@ static void ggml_compute_forward_dup_f32(
|
|
5855
6089
|
|
5856
6090
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5857
6091
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5858
|
-
|
6092
|
+
id += ne00 * ir0;
|
6093
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5859
6094
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5860
6095
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5861
6096
|
|
@@ -5863,6 +6098,7 @@ static void ggml_compute_forward_dup_f32(
|
|
5863
6098
|
id++;
|
5864
6099
|
}
|
5865
6100
|
}
|
6101
|
+
id += ne00 * (ne01 - ir1);
|
5866
6102
|
}
|
5867
6103
|
}
|
5868
6104
|
} else if (dst->type == GGML_TYPE_F16) {
|
@@ -5871,7 +6107,8 @@ static void ggml_compute_forward_dup_f32(
|
|
5871
6107
|
|
5872
6108
|
for (int i03 = 0; i03 < ne03; i03++) {
|
5873
6109
|
for (int i02 = 0; i02 < ne02; i02++) {
|
5874
|
-
|
6110
|
+
id += ne00 * ir0;
|
6111
|
+
for (int i01 = ir0; i01 < ir1; i01++) {
|
5875
6112
|
for (int i00 = 0; i00 < ne00; i00++) {
|
5876
6113
|
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5877
6114
|
|
@@ -5879,6 +6116,7 @@ static void ggml_compute_forward_dup_f32(
|
|
5879
6116
|
id++;
|
5880
6117
|
}
|
5881
6118
|
}
|
6119
|
+
id += ne00 * (ne01 - ir1);
|
5882
6120
|
}
|
5883
6121
|
}
|
5884
6122
|
} else {
|
@@ -5890,6 +6128,7 @@ static void ggml_compute_forward_dup_f32(
|
|
5890
6128
|
}
|
5891
6129
|
|
5892
6130
|
// dst counters
|
6131
|
+
|
5893
6132
|
int64_t i10 = 0;
|
5894
6133
|
int64_t i11 = 0;
|
5895
6134
|
int64_t i12 = 0;
|
@@ -5898,20 +6137,33 @@ static void ggml_compute_forward_dup_f32(
|
|
5898
6137
|
if (dst->type == GGML_TYPE_F32) {
|
5899
6138
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5900
6139
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5901
|
-
|
6140
|
+
i10 += ne00 * ir0;
|
6141
|
+
while (i10 >= ne0) {
|
6142
|
+
i10 -= ne0;
|
6143
|
+
if (++i11 == ne1) {
|
6144
|
+
i11 = 0;
|
6145
|
+
if (++i12 == ne2) {
|
6146
|
+
i12 = 0;
|
6147
|
+
if (++i13 == ne3) {
|
6148
|
+
i13 = 0;
|
6149
|
+
}
|
6150
|
+
}
|
6151
|
+
}
|
6152
|
+
}
|
6153
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5902
6154
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5903
6155
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5904
6156
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
5905
6157
|
|
5906
6158
|
memcpy(dst_ptr, src0_ptr, sizeof(float));
|
5907
6159
|
|
5908
|
-
if (++i10 ==
|
6160
|
+
if (++i10 == ne0) {
|
5909
6161
|
i10 = 0;
|
5910
|
-
if (++i11 ==
|
6162
|
+
if (++i11 == ne1) {
|
5911
6163
|
i11 = 0;
|
5912
|
-
if (++i12 ==
|
6164
|
+
if (++i12 == ne2) {
|
5913
6165
|
i12 = 0;
|
5914
|
-
if (++i13 ==
|
6166
|
+
if (++i13 == ne3) {
|
5915
6167
|
i13 = 0;
|
5916
6168
|
}
|
5917
6169
|
}
|
@@ -5919,25 +6171,51 @@ static void ggml_compute_forward_dup_f32(
|
|
5919
6171
|
}
|
5920
6172
|
}
|
5921
6173
|
}
|
6174
|
+
i10 += ne00 * (ne01 - ir1);
|
6175
|
+
while (i10 >= ne0) {
|
6176
|
+
i10 -= ne0;
|
6177
|
+
if (++i11 == ne1) {
|
6178
|
+
i11 = 0;
|
6179
|
+
if (++i12 == ne2) {
|
6180
|
+
i12 = 0;
|
6181
|
+
if (++i13 == ne3) {
|
6182
|
+
i13 = 0;
|
6183
|
+
}
|
6184
|
+
}
|
6185
|
+
}
|
6186
|
+
}
|
5922
6187
|
}
|
5923
6188
|
}
|
5924
6189
|
} else if (dst->type == GGML_TYPE_F16) {
|
5925
6190
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
5926
6191
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
5927
|
-
|
6192
|
+
i10 += ne00 * ir0;
|
6193
|
+
while (i10 >= ne0) {
|
6194
|
+
i10 -= ne0;
|
6195
|
+
if (++i11 == ne1) {
|
6196
|
+
i11 = 0;
|
6197
|
+
if (++i12 == ne2) {
|
6198
|
+
i12 = 0;
|
6199
|
+
if (++i13 == ne3) {
|
6200
|
+
i13 = 0;
|
6201
|
+
}
|
6202
|
+
}
|
6203
|
+
}
|
6204
|
+
}
|
6205
|
+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
|
5928
6206
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
5929
6207
|
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
5930
6208
|
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
5931
6209
|
|
5932
6210
|
*(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
|
5933
6211
|
|
5934
|
-
if (++i10 ==
|
6212
|
+
if (++i10 == ne0) {
|
5935
6213
|
i10 = 0;
|
5936
|
-
if (++i11 ==
|
6214
|
+
if (++i11 == ne1) {
|
5937
6215
|
i11 = 0;
|
5938
|
-
if (++i12 ==
|
6216
|
+
if (++i12 == ne2) {
|
5939
6217
|
i12 = 0;
|
5940
|
-
if (++i13 ==
|
6218
|
+
if (++i13 == ne3) {
|
5941
6219
|
i13 = 0;
|
5942
6220
|
}
|
5943
6221
|
}
|
@@ -5945,6 +6223,19 @@ static void ggml_compute_forward_dup_f32(
|
|
5945
6223
|
}
|
5946
6224
|
}
|
5947
6225
|
}
|
6226
|
+
i10 += ne00 * (ne01 - ir1);
|
6227
|
+
while (i10 >= ne0) {
|
6228
|
+
i10 -= ne0;
|
6229
|
+
if (++i11 == ne1) {
|
6230
|
+
i11 = 0;
|
6231
|
+
if (++i12 == ne2) {
|
6232
|
+
i12 = 0;
|
6233
|
+
if (++i13 == ne3) {
|
6234
|
+
i13 = 0;
|
6235
|
+
}
|
6236
|
+
}
|
6237
|
+
}
|
6238
|
+
}
|
5948
6239
|
}
|
5949
6240
|
}
|
5950
6241
|
} else {
|
@@ -6191,7 +6482,7 @@ static void ggml_compute_forward_add_q_f32(
|
|
6191
6482
|
GGML_ASSERT(nb1 <= nb2);
|
6192
6483
|
GGML_ASSERT(nb2 <= nb3);
|
6193
6484
|
|
6194
|
-
GGML_ASSERT(src0->type
|
6485
|
+
GGML_ASSERT(ggml_is_quantized(src0->type));
|
6195
6486
|
GGML_ASSERT(dst->type == src0->type);
|
6196
6487
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
6197
6488
|
|
@@ -6205,7 +6496,7 @@ static void ggml_compute_forward_add_q_f32(
|
|
6205
6496
|
const int ir0 = dr*ith;
|
6206
6497
|
const int ir1 = MIN(ir0 + dr, nr);
|
6207
6498
|
|
6208
|
-
float * wdata = (float*) params->wdata + ne00 * ith;
|
6499
|
+
float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
|
6209
6500
|
|
6210
6501
|
for (int ir = ir0; ir < ir1; ++ir) {
|
6211
6502
|
// src0 indices
|
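
The wdata change above gives each thread ne00 floats of dequantization scratch plus CACHE_LINE_SIZE_F32 floats of padding, so neighbouring threads never write into the same cache line (false sharing); the quantized dup path earlier in this diff uses the same offset. A sketch of the layout, assuming a 64-byte cache line (16 floats of padding) and an invented helper name:

#include <stddef.h>

// Per-thread scratch row: regions are (ne00 + pad) floats apart, so the gap
// between two threads' rows spans a full cache line (illustration only).
static float * thread_scratch(float * wdata, int ne00, int ith) {
    const int pad = 16;                              // 64 bytes / sizeof(float)
    return wdata + (size_t)(ne00 + pad)*(size_t)ith;
}
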
@@ -6261,6 +6552,8 @@ static void ggml_compute_forward_add(
|
|
6261
6552
|
} break;
|
6262
6553
|
case GGML_TYPE_Q4_0:
|
6263
6554
|
case GGML_TYPE_Q4_1:
|
6555
|
+
case GGML_TYPE_Q4_2:
|
6556
|
+
case GGML_TYPE_Q4_3:
|
6264
6557
|
{
|
6265
6558
|
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
6266
6559
|
} break;
|
@@ -7161,7 +7454,7 @@ static void ggml_compute_forward_rms_norm(
|
|
7161
7454
|
|
7162
7455
|
// ggml_compute_forward_mul_mat
|
7163
7456
|
|
7164
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
7457
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
7165
7458
|
// helper function to determine if it is better to use BLAS or not
|
7166
7459
|
// for large matrices, BLAS is faster
|
7167
7460
|
static bool ggml_compute_forward_mul_mat_use_blas(
|
@@ -7201,7 +7494,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7201
7494
|
const int64_t ne02 = src0->ne[2];
|
7202
7495
|
const int64_t ne03 = src0->ne[3];
|
7203
7496
|
|
7204
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
7497
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
7205
7498
|
const int64_t ne10 = src1->ne[0];
|
7206
7499
|
#endif
|
7207
7500
|
const int64_t ne11 = src1->ne[1];
|
@@ -7258,7 +7551,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7258
7551
|
// nb01 >= nb00 - src0 is not transposed
|
7259
7552
|
// compute by src0 rows
|
7260
7553
|
|
7261
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
7554
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
7262
7555
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
7263
7556
|
if (params->ith != 0) {
|
7264
7557
|
return;
|
@@ -7272,6 +7565,21 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7272
7565
|
return;
|
7273
7566
|
}
|
7274
7567
|
|
7568
|
+
#if defined(GGML_USE_CUBLAS)
|
7569
|
+
float *d_X = NULL;
|
7570
|
+
float *d_Y = NULL;
|
7571
|
+
float *d_D = NULL;
|
7572
|
+
const float alpha = 1.0f;
|
7573
|
+
const float beta = 0.0f;
|
7574
|
+
const int x_ne = ne01 * ne10;
|
7575
|
+
const int y_ne = ne11 * ne10;
|
7576
|
+
const int d_ne = ne11 * ne01;
|
7577
|
+
|
7578
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
7579
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
7580
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
7581
|
+
#endif
|
7582
|
+
|
7275
7583
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
7276
7584
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
7277
7585
|
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
@@ -7279,15 +7587,37 @@ static void ggml_compute_forward_mul_mat_f32(
|
|
7279
7587
|
|
7280
7588
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
7281
7589
|
|
7590
|
+
#if defined(GGML_USE_CUBLAS)
|
7591
|
+
// copy data to device
|
7592
|
+
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
7593
|
+
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
7594
|
+
|
7595
|
+
// compute
|
7596
|
+
CUBLAS_CHECK(
|
7597
|
+
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
7598
|
+
ne01, ne11, ne10,
|
7599
|
+
&alpha, d_X, ne00,
|
7600
|
+
d_Y, ne10,
|
7601
|
+
&beta, d_D, ne01));
|
7602
|
+
|
7603
|
+
// copy data to host
|
7604
|
+
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
7605
|
+
#else
|
7282
7606
|
// zT = y * xT
|
7283
7607
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
7284
7608
|
ne11, ne01, ne10,
|
7285
7609
|
1.0f, y, ne10,
|
7286
7610
|
x, ne00,
|
7287
7611
|
0.0f, d, ne01);
|
7612
|
+
#endif
|
7288
7613
|
}
|
7289
7614
|
}
|
7290
|
-
|
7615
|
+
#if defined(GGML_USE_CUBLAS)
|
7616
|
+
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
7617
|
+
CUDA_CHECK(cudaFree(d_X));
|
7618
|
+
CUDA_CHECK(cudaFree(d_Y));
|
7619
|
+
CUDA_CHECK(cudaFree(d_D));
|
7620
|
+
#endif
|
7291
7621
|
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
7292
7622
|
|
7293
7623
|
return;
|
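
Both back ends in ggml_compute_forward_mul_mat_f32 compute the same slice product: the row-major cblas_sgemm call and the column-major cublasSgemm(OP_T, OP_N, ne01, ne11, ne10, ...) call each leave dst as an ne11 x ne01 row-major matrix of row-by-row dot products. A scalar reference of that contract, illustrative only, assuming contiguous rows and ne00 == ne10 as the BLAS branch does:

// d[i11][i01] = sum_k x[i01][k] * y[i11][k] for one (i02, i03) slice.
static void mul_mat_f32_ref(const float * x,  // src0 slice, ne01 rows of ne10 floats
                            const float * y,  // src1 slice, ne11 rows of ne10 floats
                            float       * d,  // dst  slice, ne11 rows of ne01 floats
                            int ne01, int ne11, int ne10) {
    for (int i11 = 0; i11 < ne11; i11++) {
        for (int i01 = 0; i01 < ne01; i01++) {
            float sum = 0.0f;
            for (int k = 0; k < ne10; k++) {
                sum += x[i01*ne10 + k]*y[i11*ne10 + k];
            }
            d[i11*ne01 + i01] = sum;
        }
    }
}
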
@@ -7417,7 +7747,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7417
7747
|
// nb01 >= nb00 - src0 is not transposed
|
7418
7748
|
// compute by src0 rows
|
7419
7749
|
|
7420
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
7750
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
7421
7751
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
7422
7752
|
GGML_ASSERT(nb10 == sizeof(float));
|
7423
7753
|
|
@@ -7433,10 +7763,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7433
7763
|
return;
|
7434
7764
|
}
|
7435
7765
|
|
7436
|
-
|
7766
|
+
#if defined(GGML_USE_CUBLAS)
|
7767
|
+
ggml_fp16_t * const wdata = params->wdata;
|
7437
7768
|
|
7769
|
+
float *d_X = NULL;
|
7770
|
+
float *d_Y = NULL;
|
7771
|
+
float *d_D = NULL;
|
7772
|
+
const float alpha = 1.0f;
|
7773
|
+
const float beta = 0.0f;
|
7774
|
+
const int x_ne = ne01 * ne10;
|
7775
|
+
const int y_ne = ne11 * ne10;
|
7776
|
+
const int d_ne = ne11 * ne01;
|
7777
|
+
|
7778
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
|
7779
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
7780
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
7781
|
+
#else
|
7782
|
+
float * const wdata = params->wdata;
|
7783
|
+
#endif
|
7438
7784
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
7439
7785
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
7786
|
+
#if defined(GGML_USE_CUBLAS)
|
7787
|
+
// with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
|
7788
|
+
{
|
7789
|
+
size_t id = 0;
|
7790
|
+
for (int64_t i01 = 0; i01 < ne11; ++i01) {
|
7791
|
+
for (int64_t i00 = 0; i00 < ne10; ++i00) {
|
7792
|
+
wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
|
7793
|
+
}
|
7794
|
+
}
|
7795
|
+
}
|
7796
|
+
#else
|
7440
7797
|
{
|
7441
7798
|
size_t id = 0;
|
7442
7799
|
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
@@ -7445,7 +7802,31 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7445
7802
|
}
|
7446
7803
|
}
|
7447
7804
|
}
|
7805
|
+
#endif
|
7806
|
+
|
7807
|
+
#if defined(GGML_USE_CUBLAS)
|
7808
|
+
const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
7809
|
+
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
|
7810
|
+
|
7811
|
+
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
7448
7812
|
|
7813
|
+
// copy data to device
|
7814
|
+
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
7815
|
+
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
7816
|
+
|
7817
|
+
// compute
|
7818
|
+
CUBLAS_CHECK(
|
7819
|
+
cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
7820
|
+
ne01, ne11, ne10,
|
7821
|
+
&alpha, d_X, CUDA_R_16F, ne00,
|
7822
|
+
d_Y, CUDA_R_16F, ne10,
|
7823
|
+
&beta, d_D, CUDA_R_32F, ne01,
|
7824
|
+
CUBLAS_COMPUTE_32F,
|
7825
|
+
CUBLAS_GEMM_DEFAULT));
|
7826
|
+
|
7827
|
+
// copy data to host
|
7828
|
+
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
7829
|
+
#else
|
7449
7830
|
const float * x = wdata;
|
7450
7831
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
7451
7832
|
|
@@ -7457,9 +7838,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
7457
7838
|
1.0f, y, ne10,
|
7458
7839
|
x, ne00,
|
7459
7840
|
0.0f, d, ne01);
|
7841
|
+
#endif
|
7460
7842
|
}
|
7461
7843
|
}
|
7462
7844
|
|
7845
|
+
#if defined(GGML_USE_CUBLAS)
|
7846
|
+
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
7847
|
+
CUDA_CHECK(cudaFree(d_X));
|
7848
|
+
CUDA_CHECK(cudaFree(d_Y));
|
7849
|
+
CUDA_CHECK(cudaFree(d_D));
|
7850
|
+
#endif
|
7463
7851
|
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
|
7464
7852
|
|
7465
7853
|
return;
|
@@ -7611,7 +7999,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
7611
7999
|
// nb01 >= nb00 - src0 is not transposed
|
7612
8000
|
// compute by src0 rows
|
7613
8001
|
|
7614
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
8002
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
7615
8003
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
7616
8004
|
if (params->ith != 0) {
|
7617
8005
|
return;
|
@@ -7625,11 +8013,55 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
7625
8013
|
return;
|
7626
8014
|
}
|
7627
8015
|
|
8016
|
+
#if defined(GGML_USE_CUBLAS)
|
8017
|
+
float *d_X = NULL;
|
8018
|
+
float *d_Y = NULL;
|
8019
|
+
float *d_D = NULL;
|
8020
|
+
float *d_Q = NULL;
|
8021
|
+
const float alpha = 1.0f;
|
8022
|
+
const float beta = 0.0f;
|
8023
|
+
const int x_ne = ne01 * ne10;
|
8024
|
+
const int y_ne = ne11 * ne10;
|
8025
|
+
const int d_ne = ne11 * ne01;
|
8026
|
+
|
8027
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
8028
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
8029
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
8030
|
+
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
|
8031
|
+
|
8032
|
+
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
|
8033
|
+
if (type == GGML_TYPE_Q4_0) {
|
8034
|
+
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
|
8035
|
+
}
|
8036
|
+
else if (type == GGML_TYPE_Q4_1) {
|
8037
|
+
dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
|
8038
|
+
}
|
8039
|
+
else if (type == GGML_TYPE_Q4_2) {
|
8040
|
+
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
|
8041
|
+
}
|
8042
|
+
else {
|
8043
|
+
GGML_ASSERT(false);
|
8044
|
+
}
|
8045
|
+
#else
|
7628
8046
|
float * const wdata = params->wdata;
|
7629
8047
|
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
8048
|
+
#endif
|
7630
8049
|
|
7631
8050
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
7632
8051
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
8052
|
+
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
8053
|
+
|
8054
|
+
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
8055
|
+
|
8056
|
+
#if defined(GGML_USE_CUBLAS)
|
8057
|
+
// copy and dequantize on device
|
8058
|
+
CUDA_CHECK(
|
8059
|
+
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
|
8060
|
+
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
|
8061
|
+
|
8062
|
+
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
|
8063
|
+
CUDA_CHECK(cudaGetLastError());
|
8064
|
+
#else
|
7633
8065
|
{
|
7634
8066
|
size_t id = 0;
|
7635
8067
|
for (int64_t i01 = 0; i01 < ne01; ++i01) {
|
@@ -7637,21 +8069,42 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|
7637
8069
|
id += ne00;
|
7638
8070
|
}
|
7639
8071
|
}
|
7640
|
-
|
7641
8072
|
const float * x = wdata;
|
7642
|
-
|
8073
|
+
#endif
|
7643
8074
|
|
7644
|
-
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
7645
8075
|
|
8076
|
+
#if defined(GGML_USE_CUBLAS)
|
8077
|
+
// copy data to device
|
8078
|
+
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
8079
|
+
|
8080
|
+
// compute
|
8081
|
+
CUBLAS_CHECK(
|
8082
|
+
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
8083
|
+
ne01, ne11, ne10,
|
8084
|
+
&alpha, d_X, ne00,
|
8085
|
+
d_Y, ne10,
|
8086
|
+
&beta, d_D, ne01));
|
8087
|
+
|
8088
|
+
// copy data to host
|
8089
|
+
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
8090
|
+
#else
|
7646
8091
|
// zT = y * xT
|
7647
8092
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
7648
8093
|
ne11, ne01, ne10,
|
7649
8094
|
1.0f, y, ne10,
|
7650
8095
|
x, ne00,
|
7651
8096
|
0.0f, d, ne01);
|
8097
|
+
#endif
|
7652
8098
|
}
|
7653
8099
|
}
|
7654
8100
|
|
8101
|
+
#if defined(GGML_USE_CUBLAS)
|
8102
|
+
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
8103
|
+
CUDA_CHECK(cudaFree(d_X));
|
8104
|
+
CUDA_CHECK(cudaFree(d_Y));
|
8105
|
+
CUDA_CHECK(cudaFree(d_D));
|
8106
|
+
CUDA_CHECK(cudaFree(d_Q));
|
8107
|
+
#endif
|
7655
8108
|
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
7656
8109
|
|
7657
8110
|
return;
|
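For quantized src0 the new branch keeps the weights compressed on the way to the GPU: the raw Q4 blocks are uploaded into d_Q (a copy that is a fraction of the size of the dequantized fp32 data) and expanded to fp32 into d_X by the per-type dequantize_row_q*_cuda kernel from ggml-cuda.h, so the cublasSgemm above can consume them without a host-side dequantize pass. For reference, a plain CPU analogue of that dequantization step, assuming the Q4_0 block layout of this version (one float delta followed by QK/2 bytes of packed nibbles; the struct and function names below are illustrative):

    #include <stdint.h>

    #define QK 32

    typedef struct {
        float   d;          // per-block scaling factor
        uint8_t qs[QK / 2]; // 32 weights packed as 16 bytes of 4-bit values
    } block_q4_0_ref;

    // expand k quantized values (k must be a multiple of QK) into floats
    static void dequantize_row_q4_0_ref(const block_q4_0_ref * x, float * y, int k) {
        const int nb = k / QK;
        for (int i = 0; i < nb; i++) {
            const float d = x[i].d;
            for (int l = 0; l < QK; l += 2) {
                const uint8_t vi = x[i].qs[l / 2];
                // low nibble first, then high nibble; the zero point is 8
                y[i*QK + l + 0] = ((int8_t)(vi & 0xF) - 8) * d;
                y[i*QK + l + 1] = ((int8_t)(vi >> 4)  - 8) * d;
            }
        }
    }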
@@ -7739,6 +8192,8 @@ static void ggml_compute_forward_mul_mat(
|
|
7739
8192
|
switch (src0->type) {
|
7740
8193
|
case GGML_TYPE_Q4_0:
|
7741
8194
|
case GGML_TYPE_Q4_1:
|
8195
|
+
case GGML_TYPE_Q4_2:
|
8196
|
+
case GGML_TYPE_Q4_3:
|
7742
8197
|
case GGML_TYPE_Q8_0:
|
7743
8198
|
{
|
7744
8199
|
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
|
@@ -7756,34 +8211,6 @@ static void ggml_compute_forward_mul_mat(
|
|
7756
8211
|
GGML_ASSERT(false);
|
7757
8212
|
} break;
|
7758
8213
|
}
|
7759
|
-
|
7760
|
-
#if 0
|
7761
|
-
if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
|
7762
|
-
static int first = 8;
|
7763
|
-
printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
|
7764
|
-
printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
|
7765
|
-
printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
|
7766
|
-
if (first) {
|
7767
|
-
--first;
|
7768
|
-
} else {
|
7769
|
-
for (int k = 0; k < dst->ne[1]; ++k) {
|
7770
|
-
for (int j = 0; j < dst->ne[0]/16; ++j) {
|
7771
|
-
for (int i = 0; i < 16; ++i) {
|
7772
|
-
printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
|
7773
|
-
}
|
7774
|
-
printf("\n");
|
7775
|
-
}
|
7776
|
-
printf("\n");
|
7777
|
-
}
|
7778
|
-
printf("\n");
|
7779
|
-
exit(0);
|
7780
|
-
}
|
7781
|
-
} else {
|
7782
|
-
printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
|
7783
|
-
printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
|
7784
|
-
printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
|
7785
|
-
}
|
7786
|
-
#endif
|
7787
8214
|
}
|
7788
8215
|
|
7789
8216
|
// ggml_compute_forward_scale
|
@@ -7994,6 +8421,8 @@ static void ggml_compute_forward_get_rows(
|
|
7994
8421
|
switch (src0->type) {
|
7995
8422
|
case GGML_TYPE_Q4_0:
|
7996
8423
|
case GGML_TYPE_Q4_1:
|
8424
|
+
case GGML_TYPE_Q4_2:
|
8425
|
+
case GGML_TYPE_Q4_3:
|
7997
8426
|
case GGML_TYPE_Q8_0:
|
7998
8427
|
{
|
7999
8428
|
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
@@ -8224,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(
|
|
8224
8653
|
|
8225
8654
|
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
8226
8655
|
|
8656
|
+
const bool is_neox = mode & 2;
|
8657
|
+
|
8227
8658
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
8228
|
-
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
8229
|
-
const int p = (mode == 0 ? n_past + i2 : i2);
|
8659
|
+
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
8660
|
+
const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
8230
8661
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
8231
8662
|
if (ir++ < ir0) continue;
|
8232
8663
|
if (ir > ir1) break;
|
@@ -8239,14 +8670,25 @@ static void ggml_compute_forward_rope_f32(
|
|
8239
8670
|
|
8240
8671
|
theta *= theta_scale;
|
8241
8672
|
|
8242
|
-
|
8243
|
-
|
8673
|
+
if (!is_neox) {
|
8674
|
+
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
8675
|
+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
8244
8676
|
|
8245
|
-
|
8246
|
-
|
8677
|
+
const float x0 = src[0];
|
8678
|
+
const float x1 = src[1];
|
8247
8679
|
|
8248
|
-
|
8249
|
-
|
8680
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
8681
|
+
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
8682
|
+
} else {
|
8683
|
+
const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
8684
|
+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
8685
|
+
|
8686
|
+
const float x0 = src[0];
|
8687
|
+
const float x1 = src[n_dims/2];
|
8688
|
+
|
8689
|
+
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
8690
|
+
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
8691
|
+
}
|
8250
8692
|
}
|
8251
8693
|
}
|
8252
8694
|
}
|
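The rewritten loops split RoPE's mode argument into bit flags: bit 0 reproduces the old mode == 0 / mode != 0 distinction for skipping the first n_past positions, and bit 1 (is_neox) selects the GPT-NeoX layout, which rotates element i against element i + n_dims/2 instead of the interleaved neighbours (2*i, 2*i + 1). Both layouts apply the same 2-D rotation per pair; a small illustrative helper:

    #include <math.h>

    // rotate one (x0, x1) pair by theta -- the core step shared by both RoPE layouts
    static void rope_rotate_pair(float theta, float x0, float x1, float * out0, float * out1) {
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);
        *out0 = x0*cos_theta - x1*sin_theta;
        *out1 = x0*sin_theta + x1*cos_theta;
    }

    // classic layout pairs neighbours:       (row[i0],   row[i0 + 1])          for even i0
    // neox layout pairs across the half dim: (row[i0/2], row[i0/2 + n_dims/2])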
@@ -8301,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(
|
|
8301
8743
|
|
8302
8744
|
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
8303
8745
|
|
8746
|
+
const bool is_neox = mode & 2;
|
8747
|
+
|
8304
8748
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
8305
|
-
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
8306
|
-
const int p = (mode == 0 ? n_past + i2 : i2);
|
8749
|
+
for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
8750
|
+
const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
8307
8751
|
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
8308
8752
|
if (ir++ < ir0) continue;
|
8309
8753
|
if (ir > ir1) break;
|
@@ -8316,14 +8760,25 @@ static void ggml_compute_forward_rope_f16(
|
|
8316
8760
|
|
8317
8761
|
theta *= theta_scale;
|
8318
8762
|
|
8319
|
-
|
8320
|
-
|
8763
|
+
if (!is_neox) {
|
8764
|
+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
8765
|
+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
8321
8766
|
|
8322
|
-
|
8323
|
-
|
8767
|
+
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
8768
|
+
const float x1 = GGML_FP16_TO_FP32(src[1]);
|
8324
8769
|
|
8325
|
-
|
8326
|
-
|
8770
|
+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
8771
|
+
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
8772
|
+
} else {
|
8773
|
+
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
8774
|
+
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
|
8775
|
+
|
8776
|
+
const float x0 = GGML_FP16_TO_FP32(src[0]);
|
8777
|
+
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
|
8778
|
+
|
8779
|
+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
8780
|
+
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
8781
|
+
}
|
8327
8782
|
}
|
8328
8783
|
}
|
8329
8784
|
}
|
@@ -10402,11 +10857,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10402
10857
|
case GGML_OP_CPY:
|
10403
10858
|
case GGML_OP_DUP:
|
10404
10859
|
{
|
10405
|
-
node->n_tasks =
|
10860
|
+
node->n_tasks = n_threads;
|
10406
10861
|
|
10407
10862
|
size_t cur = 0;
|
10408
|
-
if (node->type
|
10409
|
-
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
|
10863
|
+
if (ggml_is_quantized(node->type)) {
|
10864
|
+
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
|
10410
10865
|
}
|
10411
10866
|
|
10412
10867
|
work_size = MAX(work_size, cur);
|
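With the switch to ggml_is_quantized() the CPY/DUP scratch estimate covers every quantized type, including the new Q4_2/Q4_3, and it is now scaled by n_threads so that each thread dequantizes into its own fp32 row. For a row of ne[0] = 4096 values on 8 threads that is 4 bytes * 4096 * 8 = 131072 bytes of work buffer, versus the 4 * 4096 = 16384 bytes reserved by the removed line.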
@@ -10417,7 +10872,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10417
10872
|
|
10418
10873
|
size_t cur = 0;
|
10419
10874
|
|
10420
|
-
if (node->src0->type
|
10875
|
+
if (ggml_is_quantized(node->src0->type)) {
|
10421
10876
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
|
10422
10877
|
}
|
10423
10878
|
|
@@ -10466,7 +10921,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10466
10921
|
size_t cur = 0;
|
10467
10922
|
|
10468
10923
|
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
10469
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
10924
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
10470
10925
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10471
10926
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
10472
10927
|
// the threads are still spinning
|
@@ -10482,8 +10937,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
10482
10937
|
#endif
|
10483
10938
|
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
10484
10939
|
cur = 0;
|
10485
|
-
} else if (
|
10486
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
10940
|
+
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
10941
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
10487
10942
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
10488
10943
|
node->n_tasks = 1;
|
10489
10944
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
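When the BLAS path is taken for a quantized mul_mat, the work buffer has to hold a full fp32 copy of src0 (ne[0] * ne[1] floats), so a 4096 x 4096 quantized weight matrix, for example, contributes 4 * 4096 * 4096 bytes = 64 MiB to the graph's work_size.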
@@ -11709,6 +12164,86 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
|
11709
12164
|
return (n/QK4_1*sizeof(block_q4_1));
|
11710
12165
|
}
|
11711
12166
|
|
12167
|
+
size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12168
|
+
assert(k % QK4_2 == 0);
|
12169
|
+
const int nb = k / QK4_2;
|
12170
|
+
|
12171
|
+
for (int j = 0; j < n; j += k) {
|
12172
|
+
block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
|
12173
|
+
|
12174
|
+
//quantize_row_q4_2_reference(src + j, y, k);
|
12175
|
+
quantize_row_q4_2_rmse(src + j, y, k);
|
12176
|
+
|
12177
|
+
for (int i = 0; i < nb; i++) {
|
12178
|
+
for (int l = 0; l < QK4_2; l += 2) {
|
12179
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
12180
|
+
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12181
|
+
|
12182
|
+
hist[vi0]++;
|
12183
|
+
hist[vi1]++;
|
12184
|
+
}
|
12185
|
+
}
|
12186
|
+
}
|
12187
|
+
|
12188
|
+
return (n/QK4_2*sizeof(block_q4_2));
|
12189
|
+
}
|
12190
|
+
|
12191
|
+
size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
|
12192
|
+
assert(k % QK4_3 == 0);
|
12193
|
+
const int nb = k / QK4_3;
|
12194
|
+
|
12195
|
+
for (int j = 0; j < n; j += k) {
|
12196
|
+
block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
|
12197
|
+
|
12198
|
+
quantize_row_q4_3_reference(src + j, y, k);
|
12199
|
+
|
12200
|
+
for (int i = 0; i < nb; i++) {
|
12201
|
+
for (int l = 0; l < QK4_3; l += 2) {
|
12202
|
+
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
12203
|
+
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
12204
|
+
|
12205
|
+
hist[vi0]++;
|
12206
|
+
hist[vi1]++;
|
12207
|
+
}
|
12208
|
+
}
|
12209
|
+
}
|
12210
|
+
|
12211
|
+
return (n/QK4_3*sizeof(block_q4_3));
|
12212
|
+
}
|
12213
|
+
|
12214
|
+
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
|
12215
|
+
size_t result = 0;
|
12216
|
+
switch (type) {
|
12217
|
+
case GGML_TYPE_Q4_0:
|
12218
|
+
{
|
12219
|
+
GGML_ASSERT(start % QK4_0 == 0);
|
12220
|
+
block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
|
12221
|
+
result = ggml_quantize_q4_0(src + start, block, n, n, hist);
|
12222
|
+
} break;
|
12223
|
+
case GGML_TYPE_Q4_1:
|
12224
|
+
{
|
12225
|
+
GGML_ASSERT(start % QK4_1 == 0);
|
12226
|
+
block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
|
12227
|
+
result = ggml_quantize_q4_1(src + start, block, n, n, hist);
|
12228
|
+
} break;
|
12229
|
+
case GGML_TYPE_Q4_2:
|
12230
|
+
{
|
12231
|
+
GGML_ASSERT(start % QK4_2 == 0);
|
12232
|
+
block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
|
12233
|
+
result = ggml_quantize_q4_2(src + start, block, n, n, hist);
|
12234
|
+
} break;
|
12235
|
+
case GGML_TYPE_Q4_3:
|
12236
|
+
{
|
12237
|
+
GGML_ASSERT(start % QK4_3 == 0);
|
12238
|
+
block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
|
12239
|
+
result = ggml_quantize_q4_3(src + start, block, n, n, hist);
|
12240
|
+
} break;
|
12241
|
+
default:
|
12242
|
+
assert(false);
|
12243
|
+
}
|
12244
|
+
return result;
|
12245
|
+
}
|
12246
|
+
|
11712
12247
|
////////////////////////////////////////////////////////////////////////////////
|
11713
12248
|
|
11714
12249
|
int ggml_cpu_has_avx(void) {
|
@@ -11800,7 +12335,15 @@ int ggml_cpu_has_wasm_simd(void) {
|
|
11800
12335
|
}
|
11801
12336
|
|
11802
12337
|
int ggml_cpu_has_blas(void) {
|
11803
|
-
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
12338
|
+
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
|
12339
|
+
return 1;
|
12340
|
+
#else
|
12341
|
+
return 0;
|
12342
|
+
#endif
|
12343
|
+
}
|
12344
|
+
|
12345
|
+
int ggml_cpu_has_cublas(void) {
|
12346
|
+
#if defined(GGML_USE_CUBLAS)
|
11804
12347
|
return 1;
|
11805
12348
|
#else
|
11806
12349
|
return 0;
|