llama_cpp 0.2.1 → 0.3.0

@@ -13,6 +13,10 @@
13
13
  #include "ggml-cuda.h"
14
14
  #include "ggml.h"
15
15
 
16
+ #if defined(_MSC_VER)
17
+ #pragma warning(disable: 4244 4267) // possible loss of data
18
+ #endif
19
+
16
20
  static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
17
21
 
18
22
  #define CUDA_CHECK(err) \
@@ -46,7 +50,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
46
50
  } while (0)
47
51
  #endif // CUDART_VERSION >= 11
48
52
 
49
- typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
53
+ #ifdef GGML_CUDA_DMMV_F16
54
+ typedef half dfloat; // dequantize float
55
+ typedef half2 dfloat2;
56
+ #else
57
+ typedef float dfloat; // dequantize float
58
+ typedef float2 dfloat2;
59
+ #endif //GGML_CUDA_DMMV_F16
60
+
61
+ typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
50
62
  typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
51
63
  typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
52
64
  typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
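Note: with the new dfloat/dfloat2 typedefs, one dequantize call fills both values of a quant pair through a single dfloat2 reference, and the same kernel body compiles to half2 intrinsics when GGML_CUDA_DMMV_F16 is defined, or to plain float math otherwise. A minimal sketch of the consumption pattern (illustrative only; tmp, y0 and y1 stand for the accumulator and the two matching y values used further down in dequantize_mul_mat_vec):

    dfloat2 v;
    dequantize_kernel(vx, ib, iqs, v);      // fills both quants of one pair
    #ifdef GGML_CUDA_DMMV_F16
    tmp += __hmul2(v, {y0, y1});            // half2 accumulator: two products per instruction
    #else
    tmp += v.x * y0;                        // float accumulator
    tmp += v.y * y1;
    #endif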
@@ -105,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
105
117
 
106
118
  //================================= k-quants
107
119
 
120
+ #ifdef GGML_QKK_64
121
+ #define QK_K 64
122
+ #define K_SCALE_SIZE 4
123
+ #else
108
124
  #define QK_K 256
125
+ #define K_SCALE_SIZE 12
126
+ #endif
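Note: GGML_QKK_64 shrinks the k-quant super-block from 256 to 64 weights (presumably for models whose row sizes are not a multiple of 256), and every k-quant kernel in this file therefore grows a second `#if QK_K == 256 ... #else ... #endif` path. A trivial sketch of what the constant controls (illustrative only):

    // number of k-quant super-blocks the kernels walk per matrix row
    static int num_blocks_per_row(int ncols) {   // ncols must be a multiple of QK_K
        return ncols / QK_K;                     // 4096 cols: 16 blocks at QK_K == 256, 64 at QK_K == 64
    }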
109
127
 
110
128
  typedef struct {
111
129
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
@@ -116,13 +134,25 @@ typedef struct {
116
134
  static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
117
135
 
118
136
  typedef struct {
119
- uint8_t hmask[QK_K/8];
120
- uint8_t qs[QK_K/4]; // nibbles / quants
121
- uint8_t scales[3*QK_K/64];
122
- half d;
137
+ uint8_t hmask[QK_K/8]; // quants - high bit
138
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
139
+ #ifdef GGML_QKK_64
140
+ uint8_t scales[2]; // scales, quantized with 8 bits
141
+ #else
142
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
143
+ #endif
144
+ half d; // super-block scale
123
145
  } block_q3_K;
124
- static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
146
+ //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
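Note: the size check above is commented out because one formula no longer covers both layouts. Recomputing from the struct as written (editorial check, not part of the diff): with QK_K == 256 the block is 32 + 64 + 12 + 2 = 110 bytes and the formula agrees; with GGML_QKK_64 it is 8 + 16 + 2 + 2 = 28 bytes while the formula would give 30, since scales[2] is smaller than K_SCALE_SIZE.

    #if QK_K == 256
    static_assert(sizeof(block_q3_K) == 110, "editorial check of the 256-weight q3_K layout");
    #endif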
125
147
 
148
+ #ifdef GGML_QKK_64
149
+ typedef struct {
150
+ half d[2]; // super-block scales/mins
151
+ uint8_t scales[2]; // 4-bit block scales/mins
152
+ uint8_t qs[QK_K/2]; // 4--bit quants
153
+ } block_q4_K;
154
+ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
155
+ #else
126
156
  typedef struct {
127
157
  half d; // super-block scale for quantized scales
128
158
  half dmin; // super-block scale for quantized mins
@@ -130,15 +160,26 @@ typedef struct {
130
160
  uint8_t qs[QK_K/2]; // 4--bit quants
131
161
  } block_q4_K;
132
162
  static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
163
+ #endif
133
164
 
165
+ #ifdef GGML_QKK_64
166
+ typedef struct {
167
+ half d; // super-block scale
168
+ int8_t scales[QK_K/16]; // block scales
169
+ uint8_t qh[QK_K/8]; // quants, high bit
170
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
171
+ } block_q5_K;
172
+ static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
173
+ #else
134
174
  typedef struct {
135
- half d; // super-block scale for quantized scales
136
- half dmin; // super-block scale for quantized mins
137
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
175
+ half d; // super-block scale for quantized scales
176
+ half dmin; // super-block scale for quantized mins
177
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
138
178
  uint8_t qh[QK_K/8]; // quants, high bit
139
179
  uint8_t qs[QK_K/2]; // quants, low 4 bits
140
180
  } block_q5_K;
141
- static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
181
+ static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
182
+ #endif
142
183
 
143
184
  typedef struct {
144
185
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
@@ -167,6 +208,12 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
167
208
  #define GGML_CUDA_DMMV_Y 1
168
209
  #endif
169
210
 
211
+ #ifndef K_QUANTS_PER_ITERATION
212
+ #define K_QUANTS_PER_ITERATION 2
213
+ #else
214
+ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
215
+ #endif
216
+
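Note: K_QUANTS_PER_ITERATION (1 or 2) sets how many quant values each lane of the 32-wide warp handles per inner step in the new k-quant mat-vec kernels below; with 2, pairs of lanes interleave over super-blocks. The index arithmetic they all share looks like this (sketch of the QK_K == 256 case, taken from the kernels further down):

    const int tid = threadIdx.x / K_QUANTS_PER_ITERATION;   // 0..31 (K == 1) or 0..15 (K == 2)
    const int ix  = threadIdx.x % K_QUANTS_PER_ITERATION;   // always 0, or 0/1
    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
        // each thread accumulates a partial dot product over super-block i of its row
    }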
170
217
  static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
171
218
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
172
219
 
@@ -224,82 +271,106 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
224
271
  }
225
272
  }
226
273
 
227
- static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
274
+ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
228
275
  const block_q4_0 * x = (const block_q4_0 *) vx;
229
276
 
230
- const float d = x[ib].d;
277
+ const dfloat d = x[ib].d;
231
278
 
232
- const uint8_t vui = x[ib].qs[iqs];
279
+ const int vui = x[ib].qs[iqs];
233
280
 
234
- const int8_t vi0 = vui & 0xF;
235
- const int8_t vi1 = vui >> 4;
281
+ v.x = vui & 0xF;
282
+ v.y = vui >> 4;
236
283
 
237
- v0 = (vi0 - 8)*d;
238
- v1 = (vi1 - 8)*d;
284
+ #ifdef GGML_CUDA_DMMV_F16
285
+ v = __hsub2(v, {8.0f, 8.0f});
286
+ v = __hmul2(v, {d, d});
287
+ #else
288
+ v.x = (v.x - 8.0f) * d;
289
+ v.y = (v.y - 8.0f) * d;
290
+ #endif // GGML_CUDA_DMMV_F16
239
291
  }
240
292
 
241
- static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
293
+ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
242
294
  const block_q4_1 * x = (const block_q4_1 *) vx;
243
295
 
244
- const float d = x[ib].d;
245
- const float m = x[ib].m;
296
+ const dfloat d = x[ib].d;
297
+ const dfloat m = x[ib].m;
246
298
 
247
- const uint8_t vui = x[ib].qs[iqs];
299
+ const int vui = x[ib].qs[iqs];
248
300
 
249
- const int8_t vi0 = vui & 0xF;
250
- const int8_t vi1 = vui >> 4;
301
+ v.x = vui & 0xF;
302
+ v.y = vui >> 4;
251
303
 
252
- v0 = vi0*d + m;
253
- v1 = vi1*d + m;
304
+ #ifdef GGML_CUDA_DMMV_F16
305
+ v = __hmul2(v, {d, d});
306
+ v = __hadd2(v, {m, m});
307
+ #else
308
+ v.x = (v.x * d) + m;
309
+ v.y = (v.y * d) + m;
310
+ #endif // GGML_CUDA_DMMV_F16
254
311
  }
255
312
 
256
- static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
313
+ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
257
314
  const block_q5_0 * x = (const block_q5_0 *) vx;
258
315
 
259
- const float d = x[ib].d;
316
+ const dfloat d = x[ib].d;
260
317
 
261
318
  uint32_t qh;
262
319
  memcpy(&qh, x[ib].qh, sizeof(qh));
263
320
 
264
- const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
265
- const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
321
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
322
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
266
323
 
267
- const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
268
- const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
324
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
325
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
269
326
 
270
- v0 = x0*d;
271
- v1 = x1*d;
327
+ #ifdef GGML_CUDA_DMMV_F16
328
+ v = __hsub2(v, {16.0f, 16.0f});
329
+ v = __hmul2(v, {d, d});
330
+ #else
331
+ v.x = (v.x - 16.0f) * d;
332
+ v.y = (v.y - 16.0f) * d;
333
+ #endif // GGML_CUDA_DMMV_F16
272
334
  }
273
335
 
274
- static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
336
+ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
275
337
  const block_q5_1 * x = (const block_q5_1 *) vx;
276
338
 
277
- const float d = x[ib].d;
278
- const float m = x[ib].m;
339
+ const dfloat d = x[ib].d;
340
+ const dfloat m = x[ib].m;
279
341
 
280
342
  uint32_t qh;
281
343
  memcpy(&qh, x[ib].qh, sizeof(qh));
282
344
 
283
- const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
284
- const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
345
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
346
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
285
347
 
286
- const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
287
- const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
348
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
349
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
288
350
 
289
- v0 = x0*d + m;
290
- v1 = x1*d + m;
351
+ #ifdef GGML_CUDA_DMMV_F16
352
+ v = __hmul2(v, {d, d});
353
+ v = __hadd2(v, {m, m});
354
+ #else
355
+ v.x = (v.x * d) + m;
356
+ v.y = (v.y * d) + m;
357
+ #endif // GGML_CUDA_DMMV_F16
291
358
  }
292
359
 
293
- static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
360
+ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
294
361
  const block_q8_0 * x = (const block_q8_0 *) vx;
295
362
 
296
- const float d = x[ib].d;
363
+ const dfloat d = x[ib].d;
297
364
 
298
- const int8_t vi0 = x[ib].qs[iqs + 0];
299
- const int8_t vi1 = x[ib].qs[iqs + 1];
365
+ v.x = x[ib].qs[iqs + 0];
366
+ v.y = x[ib].qs[iqs + 1];
300
367
 
301
- v0 = vi0*d;
302
- v1 = vi1*d;
368
+ #ifdef GGML_CUDA_DMMV_F16
369
+ v = __hmul2(v, {d, d});
370
+ #else
371
+ v.x *= d;
372
+ v.y *= d;
373
+ #endif // GGML_CUDA_DMMV_F16
303
374
  }
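Note: the reworked dequantizers above compute exactly the values the removed float-only versions did; only the arithmetic type changes. For q4_0, for example, both branches reduce to the scalar reference below (illustrative, not part of the diff):

    // v0 = ((q & 0xF) - 8) * d,  v1 = ((q >> 4) - 8) * d,  with q = x[ib].qs[iqs]
    static void dequantize_q4_0_ref(uint8_t q, float d, float & v0, float & v1) {
        v0 = (float)((q & 0xF) - 8) * d;
        v1 = (float)((q >> 4)  - 8) * d;
    }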
304
375
 
305
376
  //================================== k-quants
@@ -307,13 +378,14 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
307
378
  static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
308
379
 
309
380
  const int i = blockIdx.x;
381
+ const block_q2_K * x = (const block_q2_K *) vx;
382
+
310
383
  const int tid = threadIdx.x;
384
+ #if QK_K == 256
311
385
  const int n = tid/32;
312
386
  const int l = tid - 32*n;
313
387
  const int is = 8*n + l/16;
314
388
 
315
- const block_q2_K * x = (const block_q2_K *) vx;
316
-
317
389
  const uint8_t q = x[i].qs[32*n + l];
318
390
  float * y = yy + i*QK_K + 128*n;
319
391
 
@@ -323,52 +395,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
323
395
  y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
324
396
  y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
325
397
  y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
326
-
327
- }
328
-
329
- static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
330
-
331
- const block_q2_K * x = (const block_q2_K *) vx;
332
-
333
- // if n is 0, we want to do the lower 128, else the upper 128,
334
- // covering y[l+0], y[l+32], y[l+64], y[l+96] and
335
- // y[l+16], y[l+48], y[l+80], y[l+112]
336
- int n = iqs/128; // 0 or 1
337
- int r = iqs - 128*n; // 0...120 in steps of 8
338
- int l = r/8; // 0...15 in steps of 1
339
-
340
- const float * y = yy + 128*n + l;
341
- const uint8_t * q = x[ib].qs + 32*n + l;
342
- const uint8_t * s = x[ib].scales + 8*n;
343
-
344
- const float dall = x[ib].d;
345
- const float dmin = x[ib].dmin;
346
-
347
- float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
348
- + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
349
- + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
350
- + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
351
- + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
352
- + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
353
- + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
354
- + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
355
-
356
- result = sum;
398
+ #else
399
+ const int is = tid/16; // 0 or 1
400
+ const int il = tid%16; // 0...15
401
+ const uint8_t q = x[i].qs[il] >> (2*is);
402
+ float * y = yy + i*QK_K + 16*is + il;
403
+ float dall = x[i].d;
404
+ float dmin = x[i].dmin;
405
+ y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
406
+ y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
407
+ #endif
357
408
 
358
409
  }
359
410
 
360
411
  static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
361
412
 
362
- int r = threadIdx.x/4;
363
- int i = blockIdx.x;
364
- int tid = r/2;
365
- int is0 = r%2;
366
- int l0 = 16*is0 + 4*(threadIdx.x%4);
367
- int n = tid / 4;
368
- int j = tid - 4*n;
369
-
413
+ const int i = blockIdx.x;
370
414
  const block_q3_K * x = (const block_q3_K *) vx;
371
415
 
416
+ #if QK_K == 256
417
+ const int r = threadIdx.x/4;
418
+ const int tid = r/2;
419
+ const int is0 = r%2;
420
+ const int l0 = 16*is0 + 4*(threadIdx.x%4);
421
+ const int n = tid / 4;
422
+ const int j = tid - 4*n;
423
+
372
424
  uint8_t m = 1 << (4*n + j);
373
425
  int is = 8*n + 2*j + is0;
374
426
  int shift = 2*j;
@@ -385,54 +437,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
385
437
  const uint8_t * hm = x[i].hmask;
386
438
 
387
439
  for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
440
+ #else
441
+ const int tid = threadIdx.x;
442
+ const int is = tid/16; // 0 or 1
443
+ const int il = tid%16; // 0...15
444
+ const int im = il/8; // 0...1
445
+ const int in = il%8; // 0...7
388
446
 
389
- }
390
-
391
- static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
392
-
393
- const block_q3_K * x = (const block_q3_K *) vx;
394
-
395
- const uint32_t kmask1 = 0x03030303;
396
- const uint32_t kmask2 = 0x0f0f0f0f;
397
-
398
- uint32_t aux[3];
399
- uint32_t utmp[4];
400
-
401
- // if n is 0, we want to do the lower 128, else the upper 128,
402
- // covering y[l+0], y[l+32], y[l+64], y[l+96] and
403
- // y[l+16], y[l+48], y[l+80], y[l+112]
404
- int n = iqs/128; // 0 or 1
405
- int r = iqs - 128*n; // 0...120 in steps of 8
406
- int l = r/8; // 0...15 in steps of 1
407
-
408
- const float * y = yy + 128*n + l;
409
- const uint8_t * q = x[ib].qs + 32*n + l;
410
- const uint8_t * hm = x[ib].hmask + l;
411
- const int8_t * s = (const int8_t *)utmp + 8*n;
412
-
413
- memcpy(aux, x[ib].scales, 12);
414
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
415
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
416
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
417
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
418
-
419
- const float dall = x[ib].d;
420
-
421
- const uint8_t m = 1 << (4*n);
447
+ float * y = yy + i*QK_K + 16*is + il;
422
448
 
423
- float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
424
- + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
425
- + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
426
- + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
427
- + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
428
- + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
429
- + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
430
- + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
449
+ const uint8_t q = x[i].qs[il] >> (2*is);
450
+ const uint8_t h = x[i].hmask[in] >> (2*is + im);
451
+ const float d = (float)x[i].d;
431
452
 
432
- result = sum * dall;
453
+ if (is == 0) {
454
+ y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
455
+ y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
456
+ } else {
457
+ y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
458
+ y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
459
+ }
460
+ #endif
433
461
 
434
462
  }
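Note: both branches above implement the same q3_K expansion: a 2-bit quant is widened to a signed value in [-4, 3] using the separate high-bit mask, then multiplied by its block scale and the super-block scale d. Sketch (illustrative only):

    // q3 = ((qs >> shift) & 3) - (high_bit_set ? 0 : 4)
    static __device__ int expand_q3(uint8_t qs, int shift, bool high_bit_set) {
        return (int)((qs >> shift) & 3) - (high_bit_set ? 0 : 4);
    }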
435
463
 
464
+ #if QK_K == 256
436
465
  static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
437
466
  if (j < 4) {
438
467
  d = q[j] & 63; m = q[j + 4] & 63;
@@ -441,19 +470,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
441
470
  m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
442
471
  }
443
472
  }
473
+ #endif
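Note: get_scale_min_k4 is only needed for the 256-weight layout; it extracts the j-th 6-bit (scale, min) pair from the 12 packed scale bytes shared by q4_K and q5_K. Typical use, as in the q4_K/q5_K code elsewhere in this file:

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[i].scales, sc, m);   // pair covering the first 32 values
    const float d1 = dall * sc;                     // effective scale
    const float m1 = dmin * m;                      // effective min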
444
474
 
445
475
  static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
446
476
  const block_q4_K * x = (const block_q4_K *) vx;
447
477
 
448
478
  const int i = blockIdx.x;
449
479
 
450
- //// assume 64 threads - this is very slightly better than the one below
451
- //const int tid = threadIdx.x;
452
- //const int il = tid/16;
453
- //const int ir = tid%16;
454
- //const int is = 2*il;
455
- //const int n = 2;
456
-
480
+ #if QK_K == 256
457
481
  // assume 32 threads
458
482
  const int tid = threadIdx.x;
459
483
  const int il = tid/8;
@@ -477,38 +501,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
477
501
  y[l + 0] = d1 * (q[l] & 0xF) - m1;
478
502
  y[l +32] = d2 * (q[l] >> 4) - m2;
479
503
  }
480
- }
481
-
482
- static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
483
-
484
- const block_q4_K * x = (const block_q4_K *) vx;
485
-
486
- // iqs is in 0...248 in steps of 8 =>
487
- const int j = iqs / 64; // j is in 0...3
488
- const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
489
- const int is = 2*j; // is is in 0...6 in steps of 2
490
-
491
- const float * y = yy + 64*j + ir;
492
- const uint8_t * q = x[ib].qs + 32*j + ir;
493
-
494
- const float dall = x[ib].d;
495
- const float dmin = x[ib].dmin;
496
-
497
- uint8_t sc, m;
498
- get_scale_min_k4(is + 0, x[ib].scales, sc, m);
499
- const float d1 = dall * sc;
500
- const float m1 = dmin * m;
501
- get_scale_min_k4(is + 1, x[ib].scales, sc, m);
502
- const float d2 = dall * sc;
503
- const float m2 = dmin * m;
504
-
505
- float sum = 0;
506
- for (int k = 0; k < 4; ++k) {
507
- sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
508
- sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
509
- }
510
- result = sum;
511
-
504
+ #else
505
+ const int tid = threadIdx.x;
506
+ const uint8_t * q = x[i].qs;
507
+ float * y = yy + i*QK_K;
508
+ const float d = (float)x[i].d[0];
509
+ const float m = (float)x[i].d[1];
510
+ y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
511
+ y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
512
+ #endif
512
513
  }
513
514
 
514
515
  static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -516,6 +517,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
516
517
 
517
518
  const int i = blockIdx.x;
518
519
 
520
+ #if QK_K == 256
519
521
  // assume 64 threads - this is very slightly better than the one below
520
522
  const int tid = threadIdx.x;
521
523
  const int il = tid/16; // il is in 0...3
@@ -542,49 +544,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
542
544
  hm <<= 1;
543
545
  y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
544
546
  y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
545
- }
546
-
547
- static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
548
-
549
- const block_q5_K * x = (const block_q5_K *) vx;
550
-
551
- // iqs is in 0...248 in steps of 8 =>
552
- const int j = iqs / 64; // j is in 0...3
553
- const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
554
- const int is = 2*j; // is is in 0...6 in steps of 2
555
-
556
- const float * y = yy + 64*j + ir;
557
- const uint8_t * ql = x[ib].qs + 32*j + ir;
558
- const uint8_t * qh = x[ib].qh + ir;
559
-
560
- const float dall = x[ib].d;
561
- const float dmin = x[ib].dmin;
562
-
563
- uint8_t sc, m;
564
- get_scale_min_k4(is + 0, x[ib].scales, sc, m);
565
- const float d1 = dall * sc;
566
- const float m1 = dmin * m;
567
- get_scale_min_k4(is + 1, x[ib].scales, sc, m);
568
- const float d2 = dall * sc;
569
- const float m2 = dmin * m;
570
-
571
- uint8_t hm = 1 << is;
572
- float sum = 0;
573
- for (int k = 0; k < 4; ++k) {
574
- sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
575
- }
576
- hm <<= 1;
577
- for (int k = 0; k < 4; ++k) {
578
- sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
579
- }
580
- result = sum;
581
-
547
+ #else
548
+ const int tid = threadIdx.x;
549
+ const uint8_t q = x[i].qs[tid];
550
+ const int im = tid/8; // 0...3
551
+ const int in = tid%8; // 0...7
552
+ const int is = tid/16; // 0 or 1
553
+ const uint8_t h = x[i].qh[in] >> im;
554
+ const float d = x[i].d;
555
+ float * y = yy + i*QK_K + tid;
556
+ y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
557
+ y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
558
+ #endif
582
559
  }
583
560
 
584
561
  static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
585
562
  const block_q6_K * x = (const block_q6_K *) vx;
586
563
 
587
564
  const int i = blockIdx.x;
565
+ #if QK_K == 256
588
566
 
589
567
  // assume 64 threads - this is very slightly better than the one below
590
568
  const int tid = threadIdx.x;
@@ -604,40 +582,566 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
604
582
  y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
605
583
  y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
606
584
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
585
+ #else
586
+
587
+ // assume 32 threads
588
+ const int tid = threadIdx.x;
589
+ const int ip = tid/16; // 0 or 1
590
+ const int il = tid - 16*ip; // 0...15
591
+
592
+ float * y = yy + i*QK_K + 16*ip + il;
593
+
594
+ const float d = x[i].d;
595
+
596
+ const uint8_t ql = x[i].ql[16*ip + il];
597
+ const uint8_t qh = x[i].qh[il] >> (2*ip);
598
+ const int8_t * sc = x[i].scales;
599
+
600
+ y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
601
+ y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
602
+ #endif
607
603
  }
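Note: every fused dequantize + mat-vec kernel that follows ends with the same warp-level reduction: each of the 32 lanes holds a partial sum and an XOR-shuffle butterfly folds them together in five steps. The idiom as a standalone helper (sketch, not part of the diff):

    static __device__ float warp_reduce_sum(float tmp) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);   // add the value held by lane ^ mask
        }
        return tmp;   // every lane ends up with the full sum; lane 0 then writes dst[row]
    }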
608
604
 
609
- static __device__ void vec_dot_q6_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
605
+ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
610
606
 
611
- const block_q6_K * x = (const block_q6_K *) vx;
607
+ static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
608
+
609
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
610
+ if (row > nrows) return;
611
+
612
+ const int num_blocks_per_row = ncols / QK_K;
613
+ const int ib0 = row*num_blocks_per_row;
614
+
615
+ const block_q2_K * x = (const block_q2_K *)vx + ib0;
616
+
617
+ float tmp = 0; // partial sum for thread in warp
618
+
619
+ #if QK_K == 256
620
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
621
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
622
+
623
+ const int step = 16/K_QUANTS_PER_ITERATION;
624
+
625
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
626
+ const int in = tid - step*im; // 0...15 or 0...7
627
+
628
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
629
+ const int q_offset = 32*im + l0;
630
+ const int s_offset = 8*im;
631
+ const int y_offset = 128*im + l0;
632
+
633
+ uint32_t aux[4];
634
+ const uint8_t * d = (const uint8_t *)aux;
635
+ const uint8_t * m = (const uint8_t *)(aux + 2);
636
+
637
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
638
+
639
+ const float * y = yy + i * QK_K + y_offset;
640
+ const uint8_t * q = x[i].qs + q_offset;
641
+
642
+ const float dall = x[i].d;
643
+ const float dmin = x[i].dmin;
644
+
645
+ const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
646
+ aux[0] = a[0] & 0x0f0f0f0f;
647
+ aux[1] = a[1] & 0x0f0f0f0f;
648
+ aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
649
+ aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
650
+
651
+ float sum1 = 0, sum2 = 0;
652
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
653
+ sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
654
+ + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
655
+ + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
656
+ + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
657
+ + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
658
+ + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
659
+ + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
660
+ +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
661
+ sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
662
+ + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
663
+
664
+ }
665
+ tmp += dall * sum1 - dmin * sum2;
666
+
667
+ }
668
+ #else
669
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
670
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
671
+ const int offset = tid * K_QUANTS_PER_ITERATION;
672
+
673
+ uint32_t uaux[2];
674
+ const uint8_t * d = (const uint8_t *)uaux;
675
+
676
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
677
+
678
+ const float * y = yy + i * QK_K + offset;
679
+ const uint8_t * q = x[i].qs + offset;
680
+ const uint32_t * s = (const uint32_t *)x[i].scales;
681
+
682
+ uaux[0] = s[0] & 0x0f0f0f0f;
683
+ uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
684
+
685
+ const half2 * dh = (const half2 *)&x[i].d;
686
+
687
+ const float2 dall = __half22float2(dh[0]);
688
+
689
+ float sum1 = 0, sum2 = 0;
690
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
691
+ const uint8_t ql = q[l];
692
+ sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
693
+ + y[l+16] * d[1] * ((ql >> 2) & 3)
694
+ + y[l+32] * d[2] * ((ql >> 4) & 3)
695
+ + y[l+48] * d[3] * ((ql >> 6) & 3);
696
+ sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
697
+ }
698
+ tmp += dall.x * sum1 - dall.y * sum2;
699
+ }
700
+ #endif
701
+
702
+ // sum up partial sums and write back result
703
+ __syncthreads();
704
+ #pragma unroll
705
+ for (int mask = 16; mask > 0; mask >>= 1) {
706
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
707
+ }
708
+
709
+ if (threadIdx.x == 0) {
710
+ dst[row] = tmp;
711
+ }
712
+ }
713
+
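Note on the kernel above: the eight packed q2_K scale bytes are split into eight 4-bit multipliers and eight 4-bit mins with two 32-bit loads and a nibble mask instead of byte-by-byte unpacking (annotated excerpt of the lines already shown):

    uint32_t aux[4];
    const uint8_t * d = (const uint8_t *)aux;          // d[0..7]: multipliers (low nibbles)
    const uint8_t * m = (const uint8_t *)(aux + 2);    // m[0..7]: mins (high nibbles)
    const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
    aux[0] = a[0] & 0x0f0f0f0f;          aux[1] = a[1] & 0x0f0f0f0f;
    aux[2] = (a[0] >> 4) & 0x0f0f0f0f;   aux[3] = (a[1] >> 4) & 0x0f0f0f0f;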
714
+ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
715
+
716
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
717
+ if (row > nrows) return;
718
+
719
+ const int num_blocks_per_row = ncols / QK_K;
720
+ const int ib0 = row*num_blocks_per_row;
721
+
722
+ const block_q3_K * x = (const block_q3_K *)vx + ib0;
723
+
724
+ float tmp = 0; // partial sum for thread in warp
725
+
726
+ #if QK_K == 256
727
+
728
+ const uint16_t kmask1 = 0x0303;
729
+ const uint16_t kmask2 = 0x0f0f;
730
+
731
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
732
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
733
+
734
+ const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
735
+ const int step = 16/K_QUANTS_PER_ITERATION;
736
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
737
+ const int in = tid - step*im; // 0....15 or 0...7
738
+
739
+ const uint8_t m = 1 << (4*im);
740
+
741
+ const int l0 = n*in; // 0...15 or 0...14 in steps of 2
742
+ const int q_offset = 32*im + l0;
743
+ const int y_offset = 128*im + l0;
744
+
745
+ uint16_t utmp[4];
746
+ const int8_t * s = (const int8_t *)utmp;
747
+
748
+ const uint16_t s_shift = 4*im;
749
+
750
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
751
+
752
+ const float * y = yy + i * QK_K + y_offset;
753
+ const uint8_t * q = x[i].qs + q_offset;
754
+ const uint8_t * h = x[i].hmask + l0;
755
+
756
+ const uint16_t * a = (const uint16_t *)x[i].scales;
757
+ utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
758
+ utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
759
+ utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
760
+ utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
761
+
762
+ const float d = x[i].d;
763
+
764
+ float sum = 0;
765
+ for (int l = 0; l < n; ++l) {
766
+ sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
767
+ + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
768
+ + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
769
+ + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
770
+ sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
771
+ + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
772
+ + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
773
+ + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
774
+ }
775
+ tmp += d * sum;
776
+
777
+ }
778
+ #else
779
+
780
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
781
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
782
+ const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
783
+ const int in = offset/8; // 0 or 1
784
+ const int im = offset%8; // 0...7
785
+
786
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
787
+
788
+ const float * y = yy + i * QK_K + offset;
789
+ const uint8_t * q = x[i].qs + offset;
790
+ const uint8_t * s = x[i].scales;
791
+
792
+ const float dall = (float)x[i].d;
793
+
794
+ float sum = 0;
795
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
796
+ const uint8_t hl = x[i].hmask[im+l] >> in;
797
+ const uint8_t ql = q[l];
798
+ sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
799
+ + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
800
+ + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
801
+ + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
802
+ }
803
+ tmp += sum;
804
+ }
805
+ #endif
806
+
807
+ // sum up partial sums and write back result
808
+ __syncthreads();
809
+ #pragma unroll
810
+ for (int mask = 16; mask > 0; mask >>= 1) {
811
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
812
+ }
813
+
814
+ if (threadIdx.x == 0) {
815
+ dst[row] = tmp;
816
+ }
817
+ }
818
+
819
+ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
820
+
821
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
822
+ if (row > nrows) return;
823
+ const int num_blocks_per_row = ncols / QK_K;
824
+ const int ib0 = row*num_blocks_per_row;
825
+
826
+ const block_q4_K * x = (const block_q4_K *)vx + ib0;
827
+
828
+ #if QK_K == 256
829
+ const uint16_t kmask1 = 0x3f3f;
830
+ const uint16_t kmask2 = 0x0f0f;
831
+ const uint16_t kmask3 = 0xc0c0;
832
+
833
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
834
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
835
+
836
+ const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
837
+
838
+ const int il = tid/step; // 0...3
839
+ const int ir = tid - step*il; // 0...7 or 0...3
840
+ const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
841
+
842
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
843
+ const int in = il%2;
844
+
845
+ const int l0 = n*(2*ir + in);
846
+ const int q_offset = 32*im + l0;
847
+ const int y_offset = 64*im + l0;
848
+
849
+ uint16_t aux[4];
850
+ const uint8_t * sc = (const uint8_t *)aux;
851
+
852
+ float tmp = 0; // partial sum for thread in warp
612
853
 
613
- const int ip = iqs / 128; // 0 or 1
614
- const int il = (iqs - 128*ip)/8; // 0...15
615
- const int is = 8*ip;
854
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
616
855
 
617
- const float * y = yy + 128*ip + il;
856
+ const uint8_t * q1 = x[i].qs + q_offset;
857
+ const uint8_t * q2 = q1 + 64;
858
+ const float * y1 = yy + i*QK_K + y_offset;
859
+ const float * y2 = y1 + 128;
618
860
 
619
- const float d = x[ib].d;
861
+ const float dall = x[i].d;
862
+ const float dmin = x[i].dmin;
620
863
 
621
- const uint8_t * ql = x[ib].ql + 64*ip + il;
622
- const uint8_t * qh = x[ib].qh + 32*ip + il;
623
- const int8_t * sc = x[ib].scales + is;
864
+ const uint16_t * a = (const uint16_t *)x[i].scales;
865
+ aux[0] = a[im+0] & kmask1;
866
+ aux[1] = a[im+2] & kmask1;
867
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
868
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
624
869
 
625
- result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
626
- + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
627
- + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
628
- + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
629
- + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
630
- + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
631
- + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
632
- + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
870
+ float4 s = {0.f, 0.f, 0.f, 0.f};
871
+ float smin = 0;
872
+ for (int l = 0; l < n; ++l) {
873
+ s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
874
+ s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
875
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
876
+ }
877
+ tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
878
+
879
+ }
880
+ #else
881
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
882
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
883
+
884
+ const int step = tid * K_QUANTS_PER_ITERATION;
885
+
886
+ uint16_t aux16[2];
887
+ const uint8_t * s = (const uint8_t *)aux16;
888
+
889
+ float tmp = 0;
890
+
891
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
892
+ const uint8_t * q = x[i].qs + step;
893
+ const float * y = yy + i*QK_K + step;
894
+ const uint16_t * a = (const uint16_t *)x[i].scales;
895
+ aux16[0] = a[0] & 0x0f0f;
896
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
897
+ const float d = (float)x[i].d[0];
898
+ const float m = (float)x[i].d[1];
899
+ float sum = 0.f;
900
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
901
+ sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
902
+ + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
903
+ + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
904
+ + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
905
+ }
906
+ tmp += sum;
907
+ }
908
+
909
+ #endif
910
+
911
+ // sum up partial sums and write back result
912
+ __syncthreads();
913
+ #pragma unroll
914
+ for (int mask = 16; mask > 0; mask >>= 1) {
915
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
916
+ }
633
917
 
918
+ if (tid == 0) {
919
+ dst[row] = tmp;
920
+ }
634
921
  }
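Note: the aux[]/sc[] dance in the q4_k kernel above (and the q5_k kernel below) is a warp-friendly version of get_scale_min_k4: two 16-bit loads plus the kmask constants produce packed bytes in which sc[0], sc[1], sc[4], sc[5] are 6-bit scales and sc[2], sc[3], sc[6], sc[7] the matching 6-bit mins (annotated excerpt; the group comments are this editor's reading of the code):

    uint16_t aux[4];
    const uint8_t * sc = (const uint8_t *)aux;
    const uint16_t * a = (const uint16_t *)x[i].scales;                // 12 bytes seen as uint16_t[6]
    aux[0] = a[im+0] & kmask1;                                         // scales for the groups read via y1
    aux[1] = a[im+2] & kmask1;                                         // mins for the same groups
    aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);    // scales for the groups read via y2
    aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);    // mins for those groups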
635
922
 
636
- static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
923
+ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
924
+
925
+ const int row = blockIdx.x;
926
+ const int num_blocks_per_row = ncols / QK_K;
927
+ const int ib0 = row*num_blocks_per_row;
928
+
929
+ const block_q5_K * x = (const block_q5_K *)vx + ib0;
930
+
931
+ float tmp = 0; // partial sum for thread in warp
932
+
933
+ #if QK_K == 256
934
+ const uint16_t kmask1 = 0x3f3f;
935
+ const uint16_t kmask2 = 0x0f0f;
936
+ const uint16_t kmask3 = 0xc0c0;
937
+
938
+ const int tid = threadIdx.x/2; // 0...15
939
+ const int ix = threadIdx.x%2;
940
+
941
+ const int il = tid/4; // 0...3
942
+ const int ir = tid - 4*il;// 0...3
943
+ const int n = 2;
944
+
945
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
946
+ const int in = il%2;
947
+
948
+ const int l0 = n*(2*ir + in);
949
+ const int q_offset = 32*im + l0;
950
+ const int y_offset = 64*im + l0;
951
+
952
+ const uint8_t hm1 = 1 << (2*im);
953
+ const uint8_t hm2 = hm1 << 4;
954
+
955
+ uint16_t aux[4];
956
+ const uint8_t * sc = (const uint8_t *)aux;
957
+
958
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
959
+
960
+ const uint8_t * ql1 = x[i].qs + q_offset;
961
+ const uint8_t * ql2 = ql1 + 64;
962
+ const uint8_t * qh = x[i].qh + l0;
963
+ const float * y1 = yy + i*QK_K + y_offset;
964
+ const float * y2 = y1 + 128;
965
+
966
+ const float dall = x[i].d;
967
+ const float dmin = x[i].dmin;
968
+
969
+ const uint16_t * a = (const uint16_t *)x[i].scales;
970
+ aux[0] = a[im+0] & kmask1;
971
+ aux[1] = a[im+2] & kmask1;
972
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
973
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
974
+
975
+ float4 sum = {0.f, 0.f, 0.f, 0.f};
976
+ float smin = 0;
977
+ for (int l = 0; l < n; ++l) {
978
+ sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
979
+ + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
980
+ sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
981
+ + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
982
+ sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
983
+ + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
984
+ sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
985
+ + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
986
+ smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
987
+ + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
988
+ }
989
+ tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
990
+ }
991
+
992
+ #else
993
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
994
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
995
+ const int step = tid * K_QUANTS_PER_ITERATION;
996
+ const int im = step/8;
997
+ const int in = step%8;
998
+
999
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
1000
+ const uint8_t * q = x[i].qs + step;
1001
+ const int8_t * s = x[i].scales;
1002
+ const float * y = yy + i*QK_K + step;
1003
+ const float d = x[i].d;
1004
+ float sum = 0.f;
1005
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
1006
+ const uint8_t h = x[i].qh[in+j] >> im;
1007
+ sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
1008
+ + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
1009
+ + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
1010
+ + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
1011
+ }
1012
+ tmp += sum;
1013
+ }
1014
+ #endif
1015
+
1016
+ // sum up partial sums and write back result
1017
+ __syncthreads();
1018
+ #pragma unroll
1019
+ for (int mask = 16; mask > 0; mask >>= 1) {
1020
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
1021
+ }
1022
+
1023
+ if (threadIdx.x == 0) {
1024
+ dst[row] = tmp;
1025
+ }
1026
+ }
1027
+
1028
+ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
1029
+
1030
+ static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
1031
+
1032
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
1033
+ if (row > nrows) return;
1034
+
1035
+ const int num_blocks_per_row = ncols / QK_K;
1036
+ const int ib0 = row*num_blocks_per_row;
1037
+
1038
+ const block_q6_K * x = (const block_q6_K *)vx + ib0;
1039
+
1040
+ #if QK_K == 256
1041
+
1042
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
1043
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
1044
+
1045
+ const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
1046
+
1047
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
1048
+ const int in = tid - step*im; // 0...15 or 0...7
1049
+
1050
+ #if K_QUANTS_PER_ITERATION == 1
1051
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
1052
+ const int is = 0;
1053
+ #else
1054
+ const int l0 = 4 * in; // 0, 4, 8, ..., 28
1055
+ const int is = in / 4;
1056
+ #endif
1057
+ const int ql_offset = 64*im + l0;
1058
+ const int qh_offset = 32*im + l0;
1059
+ const int s_offset = 8*im + is;
1060
+ const int y_offset = 128*im + l0;
1061
+
1062
+ float tmp = 0; // partial sum for thread in warp
1063
+
1064
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
1065
+
1066
+ const float * y = yy + i * QK_K + y_offset;
1067
+ const uint8_t * ql = x[i].ql + ql_offset;
1068
+ const uint8_t * qh = x[i].qh + qh_offset;
1069
+ const int8_t * s = x[i].scales + s_offset;
1070
+
1071
+ const float d = x[i].d;
1072
+
1073
+ #if K_QUANTS_PER_ITERATION == 1
1074
+ float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
1075
+ + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
1076
+ + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
1077
+ + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
1078
+ + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
1079
+ + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
1080
+ + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
1081
+ +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
1082
+ tmp += sum;
1083
+ #else
1084
+ float sum = 0;
1085
+ for (int l = 0; l < 4; ++l) {
1086
+ sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
1087
+ + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
1088
+ + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
1089
+ + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
1090
+ }
1091
+ tmp += sum;
1092
+ #endif
1093
+
1094
+ }
1095
+
1096
+ #else
1097
+
1098
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
1099
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
1100
+
1101
+ const int step = tid * K_QUANTS_PER_ITERATION;
1102
+
1103
+ float tmp = 0; // partial sum for thread in warp
1104
+
1105
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
1106
+
1107
+ const float * y = yy + i * QK_K + step;
1108
+ const uint8_t * ql = x[i].ql + step;
1109
+ const uint8_t * qh = x[i].qh + step;
1110
+ const int8_t * s = x[i].scales;
1111
+
1112
+ const float d = x[i+0].d;
1113
+
1114
+ float sum = 0;
1115
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
1116
+ sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
1117
+ + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
1118
+ + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
1119
+ + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
1120
+ }
1121
+ tmp += sum;
1122
+
1123
+ }
1124
+
1125
+ #endif
1126
+
1127
+ // sum up partial sums and write back result
1128
+ __syncthreads();
1129
+ #pragma unroll
1130
+ for (int mask = 16; mask > 0; mask >>= 1) {
1131
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
1132
+ }
1133
+
1134
+ if (tid == 0) {
1135
+ dst[row] = tmp;
1136
+ }
1137
+ }
1138
+
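Note: both branches above rebuild a q6_K weight the same way: 4 low bits from ql and 2 high bits from qh form an unsigned 6-bit value that is re-centred by 32 before scaling. Sketch (illustrative only):

    // w = (int8_t)(ql4 | (qh2 << 4)) - 32, a signed value in [-32, 31]
    static __device__ int expand_q6(uint8_t ql4 /* 0..15 */, uint8_t qh2 /* 0..3 */) {
        return (int)(int8_t)(ql4 | (qh2 << 4)) - 32;
    }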
1139
+ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
637
1140
  const half * x = (const half *) vx;
638
1141
 
639
- v0 = __half2float(x[ib + iqs + 0]);
640
- v1 = __half2float(x[ib + iqs + 1]);
1142
+ // automatic half -> float type cast if dfloat == float
1143
+ v.x = x[ib + iqs + 0];
1144
+ v.y = x[ib + iqs + 1];
641
1145
  }
642
1146
 
643
1147
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -654,13 +1158,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
654
1158
  const int y_offset = qr == 1 ? 1 : qk/2;
655
1159
 
656
1160
  // dequantize
657
- float & v0 = y[iybs + iqs + 0];
658
- float & v1 = y[iybs + iqs + y_offset];
659
- dequantize_kernel(vx, ib, iqs, v0, v1);
1161
+ dfloat2 v;
1162
+ dequantize_kernel(vx, ib, iqs, v);
1163
+
1164
+ y[iybs + iqs + 0] = v.x;
1165
+ y[iybs + iqs + y_offset] = v.y;
660
1166
  }
661
1167
 
662
1168
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
663
- static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
1169
+ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
664
1170
  // qk = quantized weights per x block
665
1171
  // qr = number of quantized weights per data value in x block
666
1172
  const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -675,7 +1181,12 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
675
1181
  const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
676
1182
  const int y_offset = qr == 1 ? 1 : qk/2;
677
1183
 
678
- float tmp = 0.0f; // partial sum for thread in warp
1184
+ // partial sum for each thread
1185
+ #ifdef GGML_CUDA_DMMV_F16
1186
+ half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
1187
+ #else
1188
+ float tmp = 0.0f;
1189
+ #endif // GGML_CUDA_DMMV_F16
679
1190
 
680
1191
  for (int i = 0; i < ncols; i += iter_stride) {
681
1192
  const int col = i + vals_per_iter*tid;
@@ -689,14 +1200,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
689
1200
  // process 2 vals per j iter
690
1201
 
691
1202
  // dequantize
692
- float v0, v1;
693
- dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
694
1203
  // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
1204
+ dfloat2 v;
1205
+ dequantize_kernel(vx, ib, iqs + j/qr, v);
695
1206
 
696
1207
  // matrix multiplication
697
- tmp += v0 * y[iybs + iqs + j/qr + 0];
698
- tmp += v1 * y[iybs + iqs + j/qr + y_offset];
699
1208
  // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
1209
+ #ifdef GGML_CUDA_DMMV_F16
1210
+ tmp += __hmul2(v, {
1211
+ y[iybs + iqs + j/qr + 0],
1212
+ y[iybs + iqs + j/qr + y_offset]
1213
+ });
1214
+ #else
1215
+ tmp += v.x * y[iybs + iqs + j/qr + 0];
1216
+ tmp += v.y * y[iybs + iqs + j/qr + y_offset];
1217
+ #endif // GGML_CUDA_DMMV_F16
700
1218
  }
701
1219
  }
702
1220
 
@@ -708,47 +1226,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
708
1226
  }
709
1227
 
710
1228
  if (tid == 0) {
1229
+ #ifdef GGML_CUDA_DMMV_F16
1230
+ dst[row] = tmp.x + tmp.y;
1231
+ #else
711
1232
  dst[row] = tmp;
712
- }
713
- }
714
-
715
- template <int n_thread, dot_kernel_k_t dot_kernel>
716
- static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
717
- const int row = blockIdx.y*blockDim.y + threadIdx.y;
718
-
719
- if (row >= nrows) {
720
- return;
721
- }
722
-
723
- const int tid = threadIdx.x;
724
-
725
- const int iter_stride = QK_K;
726
- const int vals_per_iter = iter_stride / n_thread;
727
- const int num_blocks_per_row = ncols / QK_K;
728
- const int ib0 = row*num_blocks_per_row;
729
-
730
- float tmp = 0; // partial sum for thread in warp
731
-
732
- for (int i = 0; i < ncols; i += iter_stride) {
733
- const int col = i + vals_per_iter*tid;
734
- const int ib = ib0 + col/QK_K; // x block index
735
- const int iqs = col%QK_K; // x quant index
736
- const int iybs = col - col%QK_K; // y block start index
737
-
738
- float v;
739
- dot_kernel(vx, ib, iqs, y + iybs, v);
740
- tmp += v;
741
- }
742
-
743
- // sum up partial sums and write back result
744
- __syncthreads();
745
- #pragma unroll
746
- for (int mask = 16; mask > 0; mask >>= 1) {
747
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
748
- }
749
-
750
- if (tid == 0) {
751
- dst[row] = tmp;
1233
+ #endif // GGML_CUDA_DMMV_F16
752
1234
  }
753
1235
  }
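Note: the net effect of the GGML_CUDA_DMMV_F16 path in this template is that the per-thread accumulator becomes a half2, so each inner step does two multiplies in one instruction and the two running sums are only folded together when the result is written, trading a little accumulation precision for speed on GPUs with fast fp16 arithmetic (annotated summary of the code above):

    half2 tmp = {0.0f, 0.0f};                        // two running sums per thread
    // per pair:      tmp += __hmul2(v, {y[... + 0], y[... + y_offset]});
    // after the warp reduction, lane 0 writes:
    //     dst[row] = tmp.x + tmp.y;                 // fold the two fp16 sums into one float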
754
1236
 
@@ -1020,12 +1502,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
1020
1502
 
1021
1503
  static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1022
1504
  const int nb = k / QK_K;
1505
+ #if QK_K == 256
1023
1506
  dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
1507
+ #else
1508
+ dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
1509
+ #endif
1024
1510
  }
1025
1511
 
1026
1512
  static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1027
1513
  const int nb = k / QK_K;
1514
+ #if QK_K == 256
1028
1515
  dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
1516
+ #else
1517
+ dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
1518
+ #endif
1029
1519
  }
1030
1520
 
1031
1521
  static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1035,15 +1525,23 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
1035
1525
 
1036
1526
  static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1037
1527
  const int nb = k / QK_K;
1528
+ #if QK_K == 256
1038
1529
  dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
1530
+ #else
1531
+ dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
1532
+ #endif
1039
1533
  }
1040
1534
 
1041
1535
  static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1042
1536
  const int nb = k / QK_K;
1537
+ #if QK_K == 256
1043
1538
  dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
1539
+ #else
1540
+ dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
1541
+ #endif
1044
1542
  }
1045
1543
 
1046
- static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1544
+ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1047
1545
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1048
1546
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1049
1547
  const dim3 block_nums(1, block_num_y, 1);
@@ -1052,7 +1550,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, f
1052
1550
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
1053
1551
  }
1054
1552
 
1055
- static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1553
+ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1056
1554
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1057
1555
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1058
1556
  const dim3 block_nums(1, block_num_y, 1);
@@ -1061,7 +1559,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, f
1061
1559
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
1062
1560
  }
1063
1561
 
1064
- static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1562
+ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1065
1563
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1066
1564
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1067
1565
  const dim3 block_nums(1, block_num_y, 1);
@@ -1070,7 +1568,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, f
1070
1568
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
1071
1569
  }
1072
1570
 
1073
- static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1571
+ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1074
1572
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1075
1573
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1076
1574
  const dim3 block_nums(1, block_num_y, 1);
@@ -1079,7 +1577,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, f
1079
1577
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
1080
1578
  }
1081
1579
 
1082
- static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1580
+ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1083
1581
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1084
1582
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1085
1583
  const dim3 block_nums(1, block_num_y, 1);
@@ -1090,47 +1588,44 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
 
  static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
+ const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
  const int block_num_y = (nrows + ny - 1) / ny;
  const dim3 block_nums(1, block_num_y, 1);
  const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
 
  static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
  const int block_num_y = (nrows + ny - 1) / ny;
  const dim3 block_nums(1, block_num_y, 1);
  const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
 
  static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
  const int block_num_y = (nrows + ny - 1) / ny;
  const dim3 block_nums(1, block_num_y, 1);
  const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
 
  static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
- const int block_num_y = (nrows + ny - 1) / ny;
- const dim3 block_nums(1, block_num_y, 1);
- const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ const dim3 block_dims(32, 1, 1);
+ dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
  }
 
  static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
  const int block_num_y = (nrows + ny - 1) / ny;
  const dim3 block_nums(1, block_num_y, 1);
  const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
 
  static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
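Note: in the k-quant launchers above, the number of rows handled per block is now derived from `K_QUANTS_PER_ITERATION` (`ny = 2 / K_QUANTS_PER_ITERATION`), except for q2_K, which keeps `ny = 2`, and q5_K, which is launched with one 32-thread block per row. A host-side sketch of the launch-geometry arithmetic only, with a hypothetical row count:

    // Launch-geometry arithmetic (host-side sketch, hypothetical nrows value).
    #include <cstdio>

    #ifndef K_QUANTS_PER_ITERATION
    #define K_QUANTS_PER_ITERATION 2   // typical build-time value: 1 or 2
    #endif

    int main() {
        const int nrows = 4096;                      // hypothetical matrix rows
        const int ny = 2 / K_QUANTS_PER_ITERATION;   // rows handled per block
        const int block_num_y = (nrows + ny - 1) / ny;
        // grid is (1, block_num_y, 1); block is (32, ny, 1) -> 32*ny threads per block
        std::printf("ny=%d blocks_y=%d threads/block=%d\n", ny, block_num_y, 32 * ny);
        return 0;
    }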
@@ -1138,7 +1633,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
  dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  }
 
- static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
  const dim3 block_nums(1, block_num_y, 1);
@@ -1306,19 +1801,13 @@ static void * g_scratch_buffer = nullptr;
  static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
  static size_t g_scratch_offset = 0;
 
- #define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
- #define GGML_CUDA_MAX_EVENTS 64
-
  static int g_device_count = -1;
  static int g_main_device = 0;
  static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 
  static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
- static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
-
- static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
- static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr };
+ static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
  void ggml_init_cublas() {
  static bool initialized = false;
@@ -1342,15 +1831,8 @@ void ggml_init_cublas() {
  for (int id = 0; id < g_device_count; ++id) {
  CUDA_CHECK(cudaSetDevice(id));
 
- // create streams
- for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
- CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking));
- CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking));
- }
- // create events
- for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
- CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
- }
+ // create main stream
+ CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));
 
  // create cublas handle
  CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
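Note: initialization now creates a single non-blocking main stream per device instead of `GGML_CUDA_MAX_STREAMS` compute streams plus separate copy streams and events; work issued to one stream is implicitly ordered, so the explicit memcpy events are no longer needed. A standalone sketch of that pattern, with simplified error handling and a hypothetical `MAX_DEVICES` bound in place of `GGML_CUDA_MAX_DEVICES`:

    // One non-blocking stream per visible device (sketch, simplified error handling).
    #include <cuda_runtime.h>
    #include <cstdio>

    #define MAX_DEVICES 16  // hypothetical upper bound for this sketch

    static cudaStream_t g_streams_main[MAX_DEVICES] = { nullptr };

    int main() {
        int device_count = 0;
        if (cudaGetDeviceCount(&device_count) != cudaSuccess) return 1;
        if (device_count > MAX_DEVICES) device_count = MAX_DEVICES;

        for (int id = 0; id < device_count; ++id) {
            cudaSetDevice(id);
            // non-blocking: does not synchronize with the legacy default stream
            cudaStreamCreateWithFlags(&g_streams_main[id], cudaStreamNonBlocking);
        }
        std::printf("created %d main streams\n", device_count);
        return 0;
    }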
@@ -1566,21 +2048,40 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
  const int64_t ne00 = src0->ne[0];
  const int64_t nrows = i01_high - i01_low;
 
+ // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+ #ifdef GGML_CUDA_DMMV_F16
+ size_t ash;
+ dfloat * src1_dfloat = nullptr; // dfloat == half
+
+ bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+ src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+ src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+ if (src1_convert_f16) {
+ src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+ ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+ ne00, 1, sizeof(float), 0, 0,
+ ne00, 1, sizeof(half), 0, 0, cudaStream_main);
+ }
+ #else
+ dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+ #endif // GGML_CUDA_DMMV_F16
+
  switch (src0->type) {
  case GGML_TYPE_Q4_0:
- dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q4_1:
- dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q5_0:
- dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q5_1:
- dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q8_0:
- dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q2_K:
  dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
@@ -1598,7 +2099,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
  dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_F16:
- convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  default:
  GGML_ASSERT(false);
@@ -1606,6 +2107,12 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
  }
  CUDA_CHECK(cudaGetLastError());
 
+ #ifdef GGML_CUDA_DMMV_F16
+ if (src1_convert_f16) {
+ ggml_cuda_pool_free(src1_dfloat, ash);
+ }
+ #endif // GGML_CUDA_DMMV_F16
+
  (void) src1;
  (void) dst;
  (void) src0_ddf_i;
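Note: with `GGML_CUDA_DMMV_F16` defined, the hunks above stage `src1` as f16 in a pooled temporary (`ggml_cuda_pool_malloc`), run the mat-vec kernels against it, and return the buffer to the pool afterwards; without the define, `src1_dfloat` simply aliases the f32 buffer. A simplified sketch of that convert-use-free lifecycle, using plain `cudaMalloc`/`cudaFree` and a trivial conversion kernel instead of ggml's pool and `ggml_cpy_f32_f16_cuda`:

    // f32 -> f16 staging for a column vector (sketch; not ggml's pooled buffers).
    #include <cuda_fp16.h>
    #include <cuda_runtime.h>

    __global__ void cpy_f32_f16(const float * x, half * y, int k) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < k) {
            y[i] = __float2half(x[i]);
        }
    }

    void run_with_half_src1(const float * src1_f32, int ne00, cudaStream_t stream) {
        half * src1_half = nullptr;
        cudaMalloc((void **) &src1_half, ne00 * sizeof(half));

        const int block = 256;
        cpy_f32_f16<<<(ne00 + block - 1) / block, block, 0, stream>>>(src1_f32, src1_half, ne00);

        // ... half-precision mat-vec kernels would be launched on `stream` here ...

        cudaStreamSynchronize(stream);  // ensure the launches above have finished
        cudaFree(src1_half);            // then release the temporary
    }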
@@ -1817,6 +2324,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
  size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
+ // if multiple GPUs are used they need to wait for the main GPU to finish
+ if (split && g_device_count > 1) {
+ CUDA_CHECK(cudaSetDevice(g_main_device));
+ CUDA_CHECK(cudaDeviceSynchronize());
+ }
+
  for (int id = 0; id < g_device_count; ++id) {
  if (!split && id != g_main_device) {
  continue;
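Note: for tensors split across several GPUs, the secondary devices copy `src1` out of the main device's buffer, so the main device is now synchronized once before the per-device loop instead of relying on the removed per-copy events. A minimal sketch of that producer-sync pattern:

    // Sync the producer GPU before other GPUs copy from its buffers (sketch).
    #include <cuda_runtime.h>

    void sync_main_device_if_split(bool split, int device_count, int main_device) {
        if (split && device_count > 1) {
            cudaSetDevice(main_device);
            cudaDeviceSynchronize();  // pending writes on the main device become visible
        }
    }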
@@ -1915,9 +2428,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  }
  const int64_t i11 = i13*ne12 + i12;
 
- cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
- cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
- cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[id];
 
  // for split tensors the data begins at i0 == i0_offset_low
  char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
@@ -1945,14 +2456,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  if (src1->backend == GGML_BACKEND_CPU) {
  GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
  int64_t nrows1 = flatten_rows ? nrows0 : ne11;
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
  } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
  if (id != g_main_device) {
  GGML_ASSERT(!flatten_rows);
  float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
  src1_ddf_i_source += i11*src1_stride;
  CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
- cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
+ cudaMemcpyDeviceToDevice, cudaStream_main));
  }
  } else if (src1_on_device && !src1_is_contiguous) {
  GGML_ASSERT(!split);
@@ -1961,7 +2472,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  GGML_ASSERT(false);
  }
  }
- CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
 
  if (!src0_on_device || !src0_is_contiguous) {
  if (src0_is_f32) {
@@ -1977,9 +2487,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  CUDA_CHECK(cudaGetLastError());
  }
 
- // wait with main stream until src1 memcpy is done
- CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
-
  // do the computation
  op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
 
@@ -2017,8 +2524,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
  // wait until each device is finished, then free their buffers
  for (int id = 0; id < g_device_count; ++id) {
+ if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
+ continue;
+ }
+
  CUDA_CHECK(cudaSetDevice(id));
  CUDA_CHECK(cudaDeviceSynchronize());
+
  if (src0_asq[id] > 0) {
  ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
  }
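Note: the cleanup loop above now skips devices on which no temporary buffers were allocated for this op, avoiding a needless `cudaSetDevice`/`cudaDeviceSynchronize` round-trip on idle GPUs. A reduced sketch of the same skip-then-free pattern, with plain `cudaFree` standing in for ggml's buffer pool:

    // Skip idle devices in the free loop (sketch; not ggml's pool).
    #include <cstddef>
    #include <cuda_runtime.h>

    void free_per_device(void * bufs[], size_t sizes[], int device_count) {
        for (int id = 0; id < device_count; ++id) {
            if (sizes[id] == 0) {
                continue;  // nothing was allocated on this device for this op
            }
            cudaSetDevice(id);
            cudaDeviceSynchronize();  // wait for this device's work before freeing
            cudaFree(bufs[id]);
            bufs[id]  = nullptr;
            sizes[id] = 0;
        }
    }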
@@ -2084,7 +2596,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
  const int64_t ne02 = src0->ne[2];
 
  CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
  struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
  void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2096,8 +2608,6 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
  float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
  ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
-
- CUDA_CHECK(cudaDeviceSynchronize());
  }
 
  void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -2115,7 +2625,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
  const int64_t nb02 = src0->nb[2];
 
  CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
  struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
  void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2130,8 +2640,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
  const int channel_stride_x = nb02 / sizeof(half);
 
  ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
-
- CUDA_CHECK(cudaDeviceSynchronize());
  }
 
  void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
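Note: `ggml_cuda_mul_mat_vec_p021` and `ggml_cuda_mul_mat_vec_nc` above (and `ggml_cuda_cpy` further down) no longer end with `cudaDeviceSynchronize`; the launches stay queued on the device's main stream, later work submitted to the same stream is ordered behind them, and the host only blocks when it actually reads the result. A minimal sketch of that pattern with a hypothetical `fill` kernel:

    // Async launch on a main stream; synchronize only when the host needs the data (sketch).
    #include <cuda_runtime.h>

    __global__ void fill(float * dst, int n, float v) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) dst[i] = v;
    }

    void produce_async(float * d_dst, int n, cudaStream_t stream_main) {
        fill<<<(n + 255) / 256, 256, 0, stream_main>>>(d_dst, n, 1.0f);
        // no cudaDeviceSynchronize here: later launches on stream_main run after this one
    }

    void read_back(const float * d_dst, float * h_dst, int n, cudaStream_t stream_main) {
        cudaMemcpyAsync(h_dst, d_dst, n * sizeof(float), cudaMemcpyDeviceToHost, stream_main);
        cudaStreamSynchronize(stream_main);  // block only when the host consumes the result
    }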
@@ -2187,7 +2695,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
  const int64_t nb12 = src1->nb[2];
 
  CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
 
  const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
  const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -2205,8 +2713,6 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
  GGML_ASSERT(false);
  }
 
- CUDA_CHECK(cudaDeviceSynchronize());
-
  (void) dst;
  }
 
@@ -2313,6 +2819,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 
  tensor->backend = GGML_BACKEND_GPU;
  struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
+ memset(extra, 0, sizeof(*extra));
 
  const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
  tensor->op == GGML_OP_VIEW;
@@ -2395,7 +2902,7 @@ void ggml_cuda_free_scratch() {
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
  ggml_cuda_func_t func;
  const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
- || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
+ || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
  || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
 
  switch (tensor->op) {