llama_cpp 0.2.2 → 0.3.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +34 -0
- data/README.md +39 -6
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +3 -2
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +305 -133
- data/ext/llama_cpp/src/ggml-cuda.cu +367 -69
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +36 -30
- data/ext/llama_cpp/src/ggml-metal.metal +328 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +352 -175
- data/ext/llama_cpp/src/ggml.c +800 -303
- data/ext/llama_cpp/src/ggml.h +68 -5
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +262 -291
- data/ext/llama_cpp/src/llama.h +49 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +14 -17
- metadata +2 -3
- data/lib/llama_cpp/client.rb +0 -172
data/ext/llama_cpp/src/ggml-cuda.cu
@@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
 
 //================================= k-quants
 
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
 #define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
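
The new GGML_QKK_64 switch above shrinks the k-quant super-block from 256 to 64 weights. As a rough illustration, here is a standalone sketch (not part of the gem's diff; `my_half` is a 2-byte stand-in for the CUDA `half` type) of the q2_K layout implied by the struct and static_assert in these hunks and the bits-per-weight it works out to under each setting:

// Sketch only: not from the diff. Mirrors the q2_K layout implied by the
// struct and static_assert above to show the effect of GGML_QKK_64.
#include <cstdint>
#include <cstdio>

#ifdef GGML_QKK_64
#define QK_K 64
#define K_SCALE_SIZE 4
#else
#define QK_K 256
#define K_SCALE_SIZE 12
#endif

typedef uint16_t my_half; // stand-in for the 2-byte CUDA half type

typedef struct {
    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
    uint8_t qs[QK_K/4];      // quants, 2 bits each
    my_half d;               // super-block scale
    my_half dmin;            // super-block scale for the mins
} block_q2_K;

static_assert(sizeof(block_q2_K) == 2*sizeof(my_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");

int main() {
    // default build: 84 bytes per 256 weights (~2.63 bits/weight)
    // -DGGML_QKK_64:  24 bytes per  64 weights (3.0 bits/weight)
    printf("QK_K=%d  sizeof(block_q2_K)=%zu  %.2f bits/weight\n",
           QK_K, sizeof(block_q2_K), 8.0 * sizeof(block_q2_K) / QK_K);
    return 0;
}
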
@@ -128,13 +134,25 @@ typedef struct {
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
 typedef struct {
-    uint8_t hmask[QK_K/8];
-    uint8_t qs[QK_K/4];
-
-
+    uint8_t hmask[QK_K/8]; // quants - high bit
+    uint8_t qs[QK_K/4]; // quants - low 2 bits
+#ifdef GGML_QKK_64
+    uint8_t scales[2]; // scales, quantized with 8 bits
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d; // super-block scale
 } block_q3_K;
-static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 +
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
 
+#ifdef GGML_QKK_64
+typedef struct {
+    half d[2]; // super-block scales/mins
+    uint8_t scales[2]; // 4-bit block scales/mins
+    uint8_t qs[QK_K/2]; // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
 typedef struct {
     half d; // super-block scale for quantized scales
     half dmin; // super-block scale for quantized mins
@@ -142,15 +160,26 @@ typedef struct {
     uint8_t qs[QK_K/2]; // 4--bit quants
 } block_q4_K;
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
 
+#ifdef GGML_QKK_64
+typedef struct {
+    half d; // super-block scale
+    int8_t scales[QK_K/16]; // block scales
+    uint8_t qh[QK_K/8]; // quants, high bit
+    uint8_t qs[QK_K/2]; // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
 typedef struct {
-    half
-    half
-    uint8_t scales[
+    half d; // super-block scale for quantized scales
+    half dmin; // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8]; // quants, high bit
     uint8_t qs[QK_K/2]; // quants, low 4 bits
 } block_q5_K;
-static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) +
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
 
 typedef struct {
     uint8_t ql[QK_K/2]; // quants, lower 4 bits
@@ -194,6 +223,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] + y[i];
 }
 
+static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __hadd(x[i], __float2half(y[i]));
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
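
The add_f16_f32_f16 kernel added above performs an element-wise f16 += f32 in half precision. A hedged usage sketch (standalone, not part of the diff; error checking omitted, and it assumes a CUDA toolkit where __float2half/__half2float are host-callable, as in recent releases):

// Sketch only: exercises an element-wise f16 += f32 kernel with the same
// shape as add_f16_f32_f16 above.
#include <cstdio>
#include <cuda_fp16.h>

static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = __hadd(x[i], __float2half(y[i]));
}

int main() {
    const int k = 4;
    half  hx[k];
    float hy[k];
    for (int i = 0; i < k; ++i) { hx[i] = __float2half(1.0f + i); hy[i] = 0.5f; }

    half *dx, *ddst; float *dy;
    cudaMalloc(&dx,   k*sizeof(half));
    cudaMalloc(&dy,   k*sizeof(float));
    cudaMalloc(&ddst, k*sizeof(half));
    cudaMemcpy(dx, hx, k*sizeof(half),  cudaMemcpyHostToDevice);
    cudaMemcpy(dy, hy, k*sizeof(float), cudaMemcpyHostToDevice);

    const int block_size = 256;
    add_f16_f32_f16<<<(k + block_size - 1)/block_size, block_size>>>(dx, dy, ddst, k);
    cudaMemcpy(hx, ddst, k*sizeof(half), cudaMemcpyDeviceToHost);

    for (int i = 0; i < k; ++i) {
        printf("%g\n", __half2float(hx[i])); // expect 1.5 2.5 3.5 4.5
    }
    cudaFree(dx); cudaFree(dy); cudaFree(ddst);
    return 0;
}
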
@@ -349,13 +387,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
 static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
     const int i = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
     const int tid = threadIdx.x;
+#if QK_K == 256
     const int n = tid/32;
     const int l = tid - 32*n;
     const int is = 8*n + l/16;
 
-    const block_q2_K * x = (const block_q2_K *) vx;
-
     const uint8_t q = x[i].qs[32*n + l];
     float * y = yy + i*QK_K + 128*n;
 
@@ -365,21 +404,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+    const int is = tid/16; // 0 or 1
+    const int il = tid%16; // 0...15
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    float * y = yy + i*QK_K + 16*is + il;
+    float dall = x[i].d;
+    float dmin = x[i].dmin;
+    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
 
 }
 
 static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
 
-    int
-    int i = blockIdx.x;
-    int tid = r/2;
-    int is0 = r%2;
-    int l0 = 16*is0 + 4*(threadIdx.x%4);
-    int n = tid / 4;
-    int j = tid - 4*n;
-
+    const int i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
 
+#if QK_K == 256
+    const int r = threadIdx.x/4;
+    const int tid = r/2;
+    const int is0 = r%2;
+    const int l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int n = tid / 4;
+    const int j = tid - 4*n;
+
     uint8_t m = 1 << (4*n + j);
     int is = 8*n + 2*j + is0;
     int shift = 2*j;
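
Both branches above reconstruct a value as `dall * (scale & 0xF) * q - dmin * (scale >> 4)`: the low nibble of each scale byte scales the 2-bit quant, the high nibble scales the super-block minimum. A scalar reference sketch with made-up numbers (not from the diff):

// Sketch only: scalar reference for the q2_K reconstruction used above.
#include <cstdio>
#include <cstdint>

static float dequant_one_q2(float dall, float dmin, uint8_t scale, uint8_t q2) {
    return dall * (scale & 0xF) * q2 - dmin * (scale >> 4);
}

int main() {
    // hypothetical values, only to show the arithmetic
    const float dall = 0.25f, dmin = 0.10f;
    const uint8_t scale = 0x35; // low nibble 5 scales the quant, high nibble 3 scales the min
    for (int q = 0; q < 4; ++q) {
        printf("q=%d -> %g\n", q, dequant_one_q2(dall, dmin, scale, (uint8_t) q));
    }
    return 0; // prints roughly -0.3, 0.95, 2.2, 3.45
}
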
@@ -396,9 +446,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
     const uint8_t * hm = x[i].hmask;
 
     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+#else
+    const int tid = threadIdx.x;
+    const int is = tid/16; // 0 or 1
+    const int il = tid%16; // 0...15
+    const int im = il/8; // 0...1
+    const int in = il%8; // 0...7
+
+    float * y = yy + i*QK_K + 16*is + il;
+
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    const uint8_t h = x[i].hmask[in] >> (2*is + im);
+    const float d = (float)x[i].d;
+
+    if (is == 0) {
+        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    } else {
+        y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
+        y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
+    }
+#endif
 
 }
 
+#if QK_K == 256
 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
     if (j < 4) {
         d = q[j] & 63; m = q[j + 4] & 63;
@@ -407,19 +479,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
         m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
     }
 }
+#endif
 
 static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
     const int i = blockIdx.x;
 
-
-    //const int tid = threadIdx.x;
-    //const int il = tid/16;
-    //const int ir = tid%16;
-    //const int is = 2*il;
-    //const int n = 2;
-
+#if QK_K == 256
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il = tid/8;
@@ -443,6 +510,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
         y[l + 0] = d1 * (q[l] & 0xF) - m1;
         y[l +32] = d2 * (q[l] >> 4) - m2;
     }
+#else
+    const int tid = threadIdx.x;
+    const uint8_t * q = x[i].qs;
+    float * y = yy + i*QK_K;
+    const float d = (float)x[i].d[0];
+    const float m = (float)x[i].d[1];
+    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
+    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
+#endif
 }
 
 static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -450,6 +526,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 
     const int i = blockIdx.x;
 
+#if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
     const int il = tid/16; // il is in 0...3
@@ -476,12 +553,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
     hm <<= 1;
     y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+    const int tid = threadIdx.x;
+    const uint8_t q = x[i].qs[tid];
+    const int im = tid/8; // 0...3
+    const int in = tid%8; // 0...7
+    const int is = tid/16; // 0 or 1
+    const uint8_t h = x[i].qh[in] >> im;
+    const float d = x[i].d;
+    float * y = yy + i*QK_K + tid;
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
 }
 
 static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
     const int i = blockIdx.x;
+#if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
     const int tid = threadIdx.x;
@@ -501,6 +591,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
     y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
     y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
     y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int ip = tid/16; // 0 or 1
+    const int il = tid - 16*ip; // 0...15
+
+    float * y = yy + i*QK_K + 16*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t ql = x[i].ql[16*ip + il];
+    const uint8_t qh = x[i].qh[il] >> (2*ip);
+    const int8_t * sc = x[i].scales;
+
+    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
 }
 
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
@@ -515,6 +623,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
 
     const block_q2_K * x = (const block_q2_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
     const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
 
@@ -528,8 +639,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
 
-    float tmp = 0; // partial sum for thread in warp
-
     uint32_t aux[4];
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
@@ -565,6 +674,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += dall * sum1 - dmin * sum2;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;
+
+    uint32_t uaux[2];
+    const uint8_t * d = (const uint8_t *)uaux;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint32_t * s = (const uint32_t *)x[i].scales;
+
+        uaux[0] = s[0] & 0x0f0f0f0f;
+        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+        const half2 * dh = (const half2 *)&x[i].d;
+
+        const float2 dall = __half22float2(dh[0]);
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t ql = q[l];
+            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+                  + y[l+16] * d[1] * ((ql >> 2) & 3)
+                  + y[l+32] * d[2] * ((ql >> 4) & 3)
+                  + y[l+48] * d[3] * ((ql >> 6) & 3);
+            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+        }
+        tmp += dall.x * sum1 - dall.y * sum2;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -573,16 +715,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
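
The `__shfl_xor_sync` loop and the `if (threadIdx.x == 0)` write-back these mul-mat-vec kernels share is a standard butterfly reduction over the 32 lanes of a warp. A minimal standalone sketch of just that pattern (toy kernel, not from the diff):

// Sketch only: the butterfly warp reduction used before `dst[row] = tmp;`
// in the dequantize_mul_mat_vec_* kernels, isolated into a toy kernel.
#include <cstdio>

static __global__ void warp_sum(const float * x, float * out) {
    // one warp: each lane loads one value, then lanes exchange partial sums
    float tmp = x[threadIdx.x];
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }
    if (threadIdx.x == 0) {
        *out = tmp; // after the loop every lane holds the full sum; lane 0 writes it back
    }
}

int main() {
    float hx[32], hsum = 0.0f;
    for (int i = 0; i < 32; ++i) hx[i] = (float) i; // expected sum: 496
    float *dx, *dout;
    cudaMalloc(&dx, sizeof(hx));
    cudaMalloc(&dout, sizeof(float));
    cudaMemcpy(dx, hx, sizeof(hx), cudaMemcpyHostToDevice);
    warp_sum<<<1, 32>>>(dx, dout);
    cudaMemcpy(&hsum, dout, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %g\n", hsum); // 496
    cudaFree(dx); cudaFree(dout);
    return 0;
}
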
@@ -591,6 +730,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const block_q3_K * x = (const block_q3_K *)vx + ib0;
 
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
     const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
 
@@ -610,8 +756,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
 
     const uint16_t s_shift = 4*im;
 
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
         const float * y = yy + i * QK_K + y_offset;
@@ -640,6 +784,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
         tmp += d * sum;
 
     }
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
+    const int in = offset/8; // 0 or 1
+    const int im = offset%8; // 0...7
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint8_t * s = x[i].scales;
+
+        const float dall = (float)x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t hl = x[i].hmask[im+l] >> in;
+            const uint8_t ql = q[l];
+            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+                 + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+                 + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+        }
+        tmp += sum;
+    }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -648,22 +820,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
 
 static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
     if (row > nrows) return;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
     const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
 
@@ -683,8 +858,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q4_K * x = (const block_q4_K *)vx + ib0;
-
     float tmp = 0; // partial sum for thread in warp
 
     for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -713,6 +886,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
         tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
 
     }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    float tmp = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const float * y = yy + i*QK_K + step;
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].d[0];
+        const float m = (float)x[i].d[1];
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
+        }
+        tmp += sum;
+    }
+
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
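
The `aux16[0] = a[0] & 0x0f0f; aux16[1] = (a[0] >> 4) & 0x0f0f;` lines in the new branch unpack two packed scale/min bytes into four separate bytes with two AND operations. A host-side sketch of that bit trick (hypothetical values, little-endian byte order assumed, as in the kernel's pointer reinterpretation):

// Sketch only: the nibble-splitting trick used in the QK_K == 64 branch above.
#include <cstdio>
#include <cstdint>
#include <cstring>

int main() {
    const uint8_t scales[2] = { 0x35, 0xA7 }; // hypothetical packed scale/min bytes
    uint16_t a;
    memcpy(&a, scales, sizeof(a)); // a == 0xA735 on a little-endian machine

    uint16_t aux16[2];
    aux16[0] = a & 0x0f0f;        // low nibbles  -> 0x0705
    aux16[1] = (a >> 4) & 0x0f0f; // high nibbles -> 0x0A03

    const uint8_t * s = (const uint8_t *) aux16;
    printf("scales: %d %d  mins: %d %d\n", s[0], s[1], s[2], s[3]); // 5 7  3 10
    return 0;
}
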
@@ -728,15 +931,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
 
 static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
 
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-
-    //const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int row = blockIdx.x;
     const int num_blocks_per_row = ncols / QK_K;
     const int ib0 = row*num_blocks_per_row;
 
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
     const int tid = threadIdx.x/2; // 0...15
     const int ix = threadIdx.x%2;
 
@@ -757,10 +964,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-    const block_q5_K * x = (const block_q5_K *)vx + ib0;
-
-    float tmp = 0; // partial sum for thread in warp
-
     for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const uint8_t * ql1 = x[i].qs + q_offset;
@@ -793,8 +996,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
                   + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
         }
         tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+    }
 
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int im = step/8;
+    const int in = step%8;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const int8_t * s = x[i].scales;
+        const float * y = yy + i*QK_K + step;
+        const float d = x[i].d;
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            const uint8_t h = x[i].qh[in+j] >> im;
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
+        }
+        tmp += sum;
     }
+#endif
 
     // sum up partial sums and write back result
     __syncthreads();
@@ -803,7 +1029,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
@@ -820,6 +1046,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     const block_q6_K * x = (const block_q6_K *)vx + ib0;
 
+#if QK_K == 256
+
     const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
     const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
 
@@ -874,6 +1102,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
 
     }
 
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float * y = yy + i * QK_K + step;
+        const uint8_t * ql = x[i].ql + step;
+        const uint8_t * qh = x[i].qh + step;
+        const int8_t * s = x[i].scales;
+
+        const float d = x[i+0].d;
+
+        float sum = 0;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
+        }
+        tmp += sum;
+
+    }
+
+#endif
+
     // sum up partial sums and write back result
     __syncthreads();
 #pragma unroll
@@ -985,7 +1244,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
 }
 
 static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
-    const half * x = (half *) vx;
+    const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -1033,9 +1292,9 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
 
 static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int
+    const int row_stride_x, const int channel_stride_x) {
 
-    const half * x = (half *) vx;
+    const half * x = (const half *) vx;
 
     const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
     const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -1078,14 +1337,14 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
 }
 
 static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (float *) cxi;
+    const float * xi = (const float *) cxi;
     float * dsti = (float *) cdsti;
 
     *dsti = *xi;
 }
 
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (float *) cxi;
+    const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
 
     *dsti = __float2half(*xi);
@@ -1209,6 +1468,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
     add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -1252,12 +1516,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1267,12 +1539,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
 
 static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int nb = k / QK_K;
+#if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
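
The launchers above switch from 64 to 32 threads per block when QK_K is 64, because each block still dequantizes one super-block and each thread writes a fixed number of outputs. A tiny sketch of that arithmetic (illustrative only, not from the diff):

// Sketch only: why the <<<nb, 64>>> launches become <<<nb, 32>>> under GGML_QKK_64.
#include <cstdio>

// one CUDA block handles one super-block of QK_K quantized weights;
// each thread of dequantize_block_q2_K writes 4 outputs at QK_K == 256
// (y[l], y[l+32], y[l+64], y[l+96]) and 2 outputs at QK_K == 64.
static int threads_per_block(int qk_k, int values_per_thread) {
    return qk_k / values_per_thread;
}

int main() {
    printf("QK_K=256: %d threads/block\n", threads_per_block(256, 4)); // 64
    printf("QK_K=64:  %d threads/block\n", threads_per_block(64, 2));  // 32
    return 0;
}
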
@@ -1418,7 +1698,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
     const dim3 block_nums(1, nrows_x, nchannels_x);
     const dim3 block_dims(WARP_SIZE, 1, 1);
     mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x,
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
 }
 
 static void ggml_cpy_f32_f32_cuda(
@@ -1675,7 +1955,7 @@ inline void ggml_cuda_op_add(
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
     cudaStream_t & cudaStream_main){
 
-    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
@@ -1683,7 +1963,13 @@ inline void ggml_cuda_op_add(
     const int64_t i01_diff = i01_high - i01_low;
 
     // compute
-
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }
     CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
@@ -2281,8 +2567,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-
-
+    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
+    // Due to flatten_rows == true this does in practice not make a difference however.
+    // Better solution would be nice but right now that would require disproportionate changes.
+    GGML_ASSERT(
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
+        src1->type == GGML_TYPE_F32 &&
+        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2535,7 +2827,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     delete extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -2544,22 +2836,24 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
-            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
     struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
+    memset(extra, 0, sizeof(*extra));
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
-        tensor->op == GGML_OP_VIEW
+        tensor->op == GGML_OP_VIEW ||
+        force_inplace;
     const size_t size = ggml_nbytes(tensor);
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
+    if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
@@ -2598,11 +2892,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 }
 
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true);
+    ggml_cuda_assign_buffers_impl(tensor, true, false);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false);
+}
+
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false, true);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
@@ -2635,7 +2933,7 @@ void ggml_cuda_free_scratch() {
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
-        || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
+        || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
         || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
 
     switch (tensor->op) {