llama_cpp 0.2.1 → 0.3.0

This diff shows the contents of the two publicly released package versions as they appear in their public registry, and is provided for informational purposes only. In the listing below, every diff line is preceded by its line number(s): two numbers (old file, new file) for unchanged context lines, and a single number for lines that were added or removed.
@@ -13,6 +13,10 @@
13
13
  #include "ggml-cuda.h"
14
14
  #include "ggml.h"
15
15
 
16
+ #if defined(_MSC_VER)
17
+ #pragma warning(disable: 4244 4267) // possible loss of data
18
+ #endif
19
+
16
20
  static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
17
21
 
18
22
  #define CUDA_CHECK(err) \
@@ -46,7 +50,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
46
50
  } while (0)
47
51
  #endif // CUDART_VERSION >= 11
48
52
 
49
- typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
53
+ #ifdef GGML_CUDA_DMMV_F16
54
+ typedef half dfloat; // dequantize float
55
+ typedef half2 dfloat2;
56
+ #else
57
+ typedef float dfloat; // dequantize float
58
+ typedef float2 dfloat2;
59
+ #endif //GGML_CUDA_DMMV_F16
60
+
61
+ typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
50
62
  typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
51
63
  typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
52
64
  typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
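The new dfloat/dfloat2 aliases let a single set of dequantize kernels compile either to plain fp32 math or, when GGML_CUDA_DMMV_F16 is defined, to packed half2 math. A minimal sketch of the pattern the rewritten kernels below follow (hypothetical helper name; assumes CUDA's cuda_fp16.h intrinsics):

// Illustrative only: apply "(q - 8) * d" to both lanes of a dfloat2, the same
// way the reworked dequantize_q4_0 further down does.
static __device__ __forceinline__ void scale_shifted_pair(dfloat2 & v, const dfloat d) {
#ifdef GGML_CUDA_DMMV_F16
    v = __hsub2(v, {8.0f, 8.0f});   // one instruction covers both fp16 lanes
    v = __hmul2(v, {d, d});
#else
    v.x = (v.x - 8.0f) * d;         // fp32 fallback, lane by lane
    v.y = (v.y - 8.0f) * d;
#endif // GGML_CUDA_DMMV_F16
}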
@@ -105,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
105
117
 
106
118
  //================================= k-quants
107
119
 
120
+ #ifdef GGML_QKK_64
121
+ #define QK_K 64
122
+ #define K_SCALE_SIZE 4
123
+ #else
108
124
  #define QK_K 256
125
+ #define K_SCALE_SIZE 12
126
+ #endif
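As a quick size check (not part of the diff) against the block_q2_K definition and its static_assert just below: a super-block carries two fp16 scales, QK_K/16 scale bytes and QK_K/4 quant bytes, so

// Derived arithmetic only; mirrors the static_assert below.
static_assert(2*2 + 256/16 + 256/4 == 84, "QK_K == 256: 84 bytes per super-block, 2.625 bits per weight");
static_assert(2*2 +  64/16 +  64/4 == 24, "QK_K ==  64: 24 bytes per super-block, 3.0 bits per weight");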
109
127
 
110
128
  typedef struct {
111
129
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
@@ -116,13 +134,25 @@ typedef struct {
116
134
  static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
117
135
 
118
136
  typedef struct {
119
- uint8_t hmask[QK_K/8];
120
- uint8_t qs[QK_K/4]; // nibbles / quants
121
- uint8_t scales[3*QK_K/64];
122
- half d;
137
+ uint8_t hmask[QK_K/8]; // quants - high bit
138
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
139
+ #ifdef GGML_QKK_64
140
+ uint8_t scales[2]; // scales, quantized with 8 bits
141
+ #else
142
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
143
+ #endif
144
+ half d; // super-block scale
123
145
  } block_q3_K;
124
- static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
146
+ //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
125
147
 
148
+ #ifdef GGML_QKK_64
149
+ typedef struct {
150
+ half d[2]; // super-block scales/mins
151
+ uint8_t scales[2]; // 4-bit block scales/mins
152
+ uint8_t qs[QK_K/2]; // 4--bit quants
153
+ } block_q4_K;
154
+ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
155
+ #else
126
156
  typedef struct {
127
157
  half d; // super-block scale for quantized scales
128
158
  half dmin; // super-block scale for quantized mins
@@ -130,15 +160,26 @@ typedef struct {
130
160
  uint8_t qs[QK_K/2]; // 4--bit quants
131
161
  } block_q4_K;
132
162
  static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
163
+ #endif
133
164
 
165
+ #ifdef GGML_QKK_64
166
+ typedef struct {
167
+ half d; // super-block scale
168
+ int8_t scales[QK_K/16]; // block scales
169
+ uint8_t qh[QK_K/8]; // quants, high bit
170
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
171
+ } block_q5_K;
172
+ static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
173
+ #else
134
174
  typedef struct {
135
- half d; // super-block scale for quantized scales
136
- half dmin; // super-block scale for quantized mins
137
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
175
+ half d; // super-block scale for quantized scales
176
+ half dmin; // super-block scale for quantized mins
177
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
138
178
  uint8_t qh[QK_K/8]; // quants, high bit
139
179
  uint8_t qs[QK_K/2]; // quants, low 4 bits
140
180
  } block_q5_K;
141
- static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
181
+ static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
182
+ #endif
142
183
 
143
184
  typedef struct {
144
185
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
@@ -167,6 +208,12 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
167
208
  #define GGML_CUDA_DMMV_Y 1
168
209
  #endif
169
210
 
211
+ #ifndef K_QUANTS_PER_ITERATION
212
+ #define K_QUANTS_PER_ITERATION 2
213
+ #else
214
+ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
215
+ #endif
216
+
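K_QUANTS_PER_ITERATION controls how many quant bytes each lane consumes per inner-loop step in the new k-quant mul-mat-vec kernels further down; the default is 2, and the static_assert above restricts overrides to 1 or 2. Being an ordinary macro, it can presumably be set when this translation unit is compiled:

// Assumed, standard -D override (the build invocation is not part of this diff; file name assumed):
//   nvcc -DK_QUANTS_PER_ITERATION=1 -c ggml-cuda.cu
// Effect in the kernels below (QK_K == 256):
//   K_QUANTS_PER_ITERATION == 2 -> each lane reads 2 consecutive quant bytes per step, the warp strides 2 super-blocks
//   K_QUANTS_PER_ITERATION == 1 -> each lane reads 1 quant byte per step, the warp strides 1 super-block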
170
217
  static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
171
218
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
172
219
 
@@ -224,82 +271,106 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
224
271
  }
225
272
  }
226
273
 
227
- static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
274
+ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
228
275
  const block_q4_0 * x = (const block_q4_0 *) vx;
229
276
 
230
- const float d = x[ib].d;
277
+ const dfloat d = x[ib].d;
231
278
 
232
- const uint8_t vui = x[ib].qs[iqs];
279
+ const int vui = x[ib].qs[iqs];
233
280
 
234
- const int8_t vi0 = vui & 0xF;
235
- const int8_t vi1 = vui >> 4;
281
+ v.x = vui & 0xF;
282
+ v.y = vui >> 4;
236
283
 
237
- v0 = (vi0 - 8)*d;
238
- v1 = (vi1 - 8)*d;
284
+ #ifdef GGML_CUDA_DMMV_F16
285
+ v = __hsub2(v, {8.0f, 8.0f});
286
+ v = __hmul2(v, {d, d});
287
+ #else
288
+ v.x = (v.x - 8.0f) * d;
289
+ v.y = (v.y - 8.0f) * d;
290
+ #endif // GGML_CUDA_DMMV_F16
239
291
  }
240
292
 
241
- static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
293
+ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
242
294
  const block_q4_1 * x = (const block_q4_1 *) vx;
243
295
 
244
- const float d = x[ib].d;
245
- const float m = x[ib].m;
296
+ const dfloat d = x[ib].d;
297
+ const dfloat m = x[ib].m;
246
298
 
247
- const uint8_t vui = x[ib].qs[iqs];
299
+ const int vui = x[ib].qs[iqs];
248
300
 
249
- const int8_t vi0 = vui & 0xF;
250
- const int8_t vi1 = vui >> 4;
301
+ v.x = vui & 0xF;
302
+ v.y = vui >> 4;
251
303
 
252
- v0 = vi0*d + m;
253
- v1 = vi1*d + m;
304
+ #ifdef GGML_CUDA_DMMV_F16
305
+ v = __hmul2(v, {d, d});
306
+ v = __hadd2(v, {m, m});
307
+ #else
308
+ v.x = (v.x * d) + m;
309
+ v.y = (v.y * d) + m;
310
+ #endif // GGML_CUDA_DMMV_F16
254
311
  }
255
312
 
256
- static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
313
+ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
257
314
  const block_q5_0 * x = (const block_q5_0 *) vx;
258
315
 
259
- const float d = x[ib].d;
316
+ const dfloat d = x[ib].d;
260
317
 
261
318
  uint32_t qh;
262
319
  memcpy(&qh, x[ib].qh, sizeof(qh));
263
320
 
264
- const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
265
- const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
321
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
322
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
266
323
 
267
- const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
268
- const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
324
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
325
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
269
326
 
270
- v0 = x0*d;
271
- v1 = x1*d;
327
+ #ifdef GGML_CUDA_DMMV_F16
328
+ v = __hsub2(v, {16.0f, 16.0f});
329
+ v = __hmul2(v, {d, d});
330
+ #else
331
+ v.x = (v.x - 16.0f) * d;
332
+ v.y = (v.y - 16.0f) * d;
333
+ #endif // GGML_CUDA_DMMV_F16
272
334
  }
273
335
 
274
- static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
336
+ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
275
337
  const block_q5_1 * x = (const block_q5_1 *) vx;
276
338
 
277
- const float d = x[ib].d;
278
- const float m = x[ib].m;
339
+ const dfloat d = x[ib].d;
340
+ const dfloat m = x[ib].m;
279
341
 
280
342
  uint32_t qh;
281
343
  memcpy(&qh, x[ib].qh, sizeof(qh));
282
344
 
283
- const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
284
- const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
345
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
346
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
285
347
 
286
- const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
287
- const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
348
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
349
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
288
350
 
289
- v0 = x0*d + m;
290
- v1 = x1*d + m;
351
+ #ifdef GGML_CUDA_DMMV_F16
352
+ v = __hmul2(v, {d, d});
353
+ v = __hadd2(v, {m, m});
354
+ #else
355
+ v.x = (v.x * d) + m;
356
+ v.y = (v.y * d) + m;
357
+ #endif // GGML_CUDA_DMMV_F16
291
358
  }
292
359
 
293
- static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
360
+ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
294
361
  const block_q8_0 * x = (const block_q8_0 *) vx;
295
362
 
296
- const float d = x[ib].d;
363
+ const dfloat d = x[ib].d;
297
364
 
298
- const int8_t vi0 = x[ib].qs[iqs + 0];
299
- const int8_t vi1 = x[ib].qs[iqs + 1];
365
+ v.x = x[ib].qs[iqs + 0];
366
+ v.y = x[ib].qs[iqs + 1];
300
367
 
301
- v0 = vi0*d;
302
- v1 = vi1*d;
368
+ #ifdef GGML_CUDA_DMMV_F16
369
+ v = __hmul2(v, {d, d});
370
+ #else
371
+ v.x *= d;
372
+ v.y *= d;
373
+ #endif // GGML_CUDA_DMMV_F16
303
374
  }
304
375
 
305
376
  //================================== k-quants
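For reference, the math that dequantize_q4_0 above performs per byte of qs, written out as a host-side sketch for a whole 32-weight block (illustrative; the packed layout, an fp16 scale followed by 16 nibble bytes with QK4_0 == 32, is assumed from upstream ggml):

#include <cstdint>

// Low nibble fills the first half of the block, high nibble the second half,
// matching y_offset = qk/2 in the dequantize_block template further down.
static void dequantize_q4_0_block_ref(float d, const uint8_t qs[16], float y[32]) {
    for (int j = 0; j < 16; ++j) {
        const int vui = qs[j];
        y[j     ] = ((vui & 0xF) - 8) * d;
        y[j + 16] = ((vui >>  4) - 8) * d;
    }
}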
@@ -307,13 +378,14 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
307
378
  static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
308
379
 
309
380
  const int i = blockIdx.x;
381
+ const block_q2_K * x = (const block_q2_K *) vx;
382
+
310
383
  const int tid = threadIdx.x;
384
+ #if QK_K == 256
311
385
  const int n = tid/32;
312
386
  const int l = tid - 32*n;
313
387
  const int is = 8*n + l/16;
314
388
 
315
- const block_q2_K * x = (const block_q2_K *) vx;
316
-
317
389
  const uint8_t q = x[i].qs[32*n + l];
318
390
  float * y = yy + i*QK_K + 128*n;
319
391
 
@@ -323,52 +395,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
323
395
  y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
324
396
  y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
325
397
  y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
326
-
327
- }
328
-
329
- static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
330
-
331
- const block_q2_K * x = (const block_q2_K *) vx;
332
-
333
- // if n is 0, we want to do the lower 128, else the upper 128,
334
- // covering y[l+0], y[l+32], y[l+64], y[l+96] and
335
- // y[l+16], y[l+48], y[l+80], y[l+112]
336
- int n = iqs/128; // 0 or 1
337
- int r = iqs - 128*n; // 0...120 in steps of 8
338
- int l = r/8; // 0...15 in steps of 1
339
-
340
- const float * y = yy + 128*n + l;
341
- const uint8_t * q = x[ib].qs + 32*n + l;
342
- const uint8_t * s = x[ib].scales + 8*n;
343
-
344
- const float dall = x[ib].d;
345
- const float dmin = x[ib].dmin;
346
-
347
- float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
348
- + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
349
- + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
350
- + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
351
- + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
352
- + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
353
- + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
354
- + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
355
-
356
- result = sum;
398
+ #else
399
+ const int is = tid/16; // 0 or 1
400
+ const int il = tid%16; // 0...15
401
+ const uint8_t q = x[i].qs[il] >> (2*is);
402
+ float * y = yy + i*QK_K + 16*is + il;
403
+ float dall = x[i].d;
404
+ float dmin = x[i].dmin;
405
+ y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
406
+ y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
407
+ #endif
357
408
 
358
409
  }
359
410
 
360
411
  static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
361
412
 
362
- int r = threadIdx.x/4;
363
- int i = blockIdx.x;
364
- int tid = r/2;
365
- int is0 = r%2;
366
- int l0 = 16*is0 + 4*(threadIdx.x%4);
367
- int n = tid / 4;
368
- int j = tid - 4*n;
369
-
413
+ const int i = blockIdx.x;
370
414
  const block_q3_K * x = (const block_q3_K *) vx;
371
415
 
416
+ #if QK_K == 256
417
+ const int r = threadIdx.x/4;
418
+ const int tid = r/2;
419
+ const int is0 = r%2;
420
+ const int l0 = 16*is0 + 4*(threadIdx.x%4);
421
+ const int n = tid / 4;
422
+ const int j = tid - 4*n;
423
+
372
424
  uint8_t m = 1 << (4*n + j);
373
425
  int is = 8*n + 2*j + is0;
374
426
  int shift = 2*j;
@@ -385,54 +437,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
385
437
  const uint8_t * hm = x[i].hmask;
386
438
 
387
439
  for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
440
+ #else
441
+ const int tid = threadIdx.x;
442
+ const int is = tid/16; // 0 or 1
443
+ const int il = tid%16; // 0...15
444
+ const int im = il/8; // 0...1
445
+ const int in = il%8; // 0...7
388
446
 
389
- }
390
-
391
- static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
392
-
393
- const block_q3_K * x = (const block_q3_K *) vx;
394
-
395
- const uint32_t kmask1 = 0x03030303;
396
- const uint32_t kmask2 = 0x0f0f0f0f;
397
-
398
- uint32_t aux[3];
399
- uint32_t utmp[4];
400
-
401
- // if n is 0, we want to do the lower 128, else the upper 128,
402
- // covering y[l+0], y[l+32], y[l+64], y[l+96] and
403
- // y[l+16], y[l+48], y[l+80], y[l+112]
404
- int n = iqs/128; // 0 or 1
405
- int r = iqs - 128*n; // 0...120 in steps of 8
406
- int l = r/8; // 0...15 in steps of 1
407
-
408
- const float * y = yy + 128*n + l;
409
- const uint8_t * q = x[ib].qs + 32*n + l;
410
- const uint8_t * hm = x[ib].hmask + l;
411
- const int8_t * s = (const int8_t *)utmp + 8*n;
412
-
413
- memcpy(aux, x[ib].scales, 12);
414
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
415
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
416
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
417
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
418
-
419
- const float dall = x[ib].d;
420
-
421
- const uint8_t m = 1 << (4*n);
447
+ float * y = yy + i*QK_K + 16*is + il;
422
448
 
423
- float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
424
- + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
425
- + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
426
- + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
427
- + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
428
- + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
429
- + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
430
- + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
449
+ const uint8_t q = x[i].qs[il] >> (2*is);
450
+ const uint8_t h = x[i].hmask[in] >> (2*is + im);
451
+ const float d = (float)x[i].d;
431
452
 
432
- result = sum * dall;
453
+ if (is == 0) {
454
+ y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
455
+ y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
456
+ } else {
457
+ y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
458
+ y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
459
+ }
460
+ #endif
433
461
 
434
462
  }
435
463
 
464
+ #if QK_K == 256
436
465
  static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
437
466
  if (j < 4) {
438
467
  d = q[j] & 63; m = q[j + 4] & 63;
@@ -441,19 +470,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
441
470
  m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
442
471
  }
443
472
  }
473
+ #endif
444
474
 
445
475
  static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
446
476
  const block_q4_K * x = (const block_q4_K *) vx;
447
477
 
448
478
  const int i = blockIdx.x;
449
479
 
450
- //// assume 64 threads - this is very slightly better than the one below
451
- //const int tid = threadIdx.x;
452
- //const int il = tid/16;
453
- //const int ir = tid%16;
454
- //const int is = 2*il;
455
- //const int n = 2;
456
-
480
+ #if QK_K == 256
457
481
  // assume 32 threads
458
482
  const int tid = threadIdx.x;
459
483
  const int il = tid/8;
@@ -477,38 +501,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
477
501
  y[l + 0] = d1 * (q[l] & 0xF) - m1;
478
502
  y[l +32] = d2 * (q[l] >> 4) - m2;
479
503
  }
480
- }
481
-
482
- static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
483
-
484
- const block_q4_K * x = (const block_q4_K *) vx;
485
-
486
- // iqs is in 0...248 in steps of 8 =>
487
- const int j = iqs / 64; // j is in 0...3
488
- const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
489
- const int is = 2*j; // is is in 0...6 in steps of 2
490
-
491
- const float * y = yy + 64*j + ir;
492
- const uint8_t * q = x[ib].qs + 32*j + ir;
493
-
494
- const float dall = x[ib].d;
495
- const float dmin = x[ib].dmin;
496
-
497
- uint8_t sc, m;
498
- get_scale_min_k4(is + 0, x[ib].scales, sc, m);
499
- const float d1 = dall * sc;
500
- const float m1 = dmin * m;
501
- get_scale_min_k4(is + 1, x[ib].scales, sc, m);
502
- const float d2 = dall * sc;
503
- const float m2 = dmin * m;
504
-
505
- float sum = 0;
506
- for (int k = 0; k < 4; ++k) {
507
- sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
508
- sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
509
- }
510
- result = sum;
511
-
504
+ #else
505
+ const int tid = threadIdx.x;
506
+ const uint8_t * q = x[i].qs;
507
+ float * y = yy + i*QK_K;
508
+ const float d = (float)x[i].d[0];
509
+ const float m = (float)x[i].d[1];
510
+ y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
511
+ y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
512
+ #endif
512
513
  }
513
514
 
514
515
  static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -516,6 +517,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
516
517
 
517
518
  const int i = blockIdx.x;
518
519
 
520
+ #if QK_K == 256
519
521
  // assume 64 threads - this is very slightly better than the one below
520
522
  const int tid = threadIdx.x;
521
523
  const int il = tid/16; // il is in 0...3
@@ -542,49 +544,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
542
544
  hm <<= 1;
543
545
  y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
544
546
  y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
545
- }
546
-
547
- static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
548
-
549
- const block_q5_K * x = (const block_q5_K *) vx;
550
-
551
- // iqs is in 0...248 in steps of 8 =>
552
- const int j = iqs / 64; // j is in 0...3
553
- const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
554
- const int is = 2*j; // is is in 0...6 in steps of 2
555
-
556
- const float * y = yy + 64*j + ir;
557
- const uint8_t * ql = x[ib].qs + 32*j + ir;
558
- const uint8_t * qh = x[ib].qh + ir;
559
-
560
- const float dall = x[ib].d;
561
- const float dmin = x[ib].dmin;
562
-
563
- uint8_t sc, m;
564
- get_scale_min_k4(is + 0, x[ib].scales, sc, m);
565
- const float d1 = dall * sc;
566
- const float m1 = dmin * m;
567
- get_scale_min_k4(is + 1, x[ib].scales, sc, m);
568
- const float d2 = dall * sc;
569
- const float m2 = dmin * m;
570
-
571
- uint8_t hm = 1 << is;
572
- float sum = 0;
573
- for (int k = 0; k < 4; ++k) {
574
- sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
575
- }
576
- hm <<= 1;
577
- for (int k = 0; k < 4; ++k) {
578
- sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
579
- }
580
- result = sum;
581
-
547
+ #else
548
+ const int tid = threadIdx.x;
549
+ const uint8_t q = x[i].qs[tid];
550
+ const int im = tid/8; // 0...3
551
+ const int in = tid%8; // 0...7
552
+ const int is = tid/16; // 0 or 1
553
+ const uint8_t h = x[i].qh[in] >> im;
554
+ const float d = x[i].d;
555
+ float * y = yy + i*QK_K + tid;
556
+ y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
557
+ y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
558
+ #endif
582
559
  }
583
560
 
584
561
  static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
585
562
  const block_q6_K * x = (const block_q6_K *) vx;
586
563
 
587
564
  const int i = blockIdx.x;
565
+ #if QK_K == 256
588
566
 
589
567
  // assume 64 threads - this is very slightly better than the one below
590
568
  const int tid = threadIdx.x;
@@ -604,40 +582,566 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
604
582
  y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
605
583
  y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
606
584
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
585
+ #else
586
+
587
+ // assume 32 threads
588
+ const int tid = threadIdx.x;
589
+ const int ip = tid/16; // 0 or 1
590
+ const int il = tid - 16*ip; // 0...15
591
+
592
+ float * y = yy + i*QK_K + 16*ip + il;
593
+
594
+ const float d = x[i].d;
595
+
596
+ const uint8_t ql = x[i].ql[16*ip + il];
597
+ const uint8_t qh = x[i].qh[il] >> (2*ip);
598
+ const int8_t * sc = x[i].scales;
599
+
600
+ y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
601
+ y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
602
+ #endif
607
603
  }
608
604
 
609
- static __device__ void vec_dot_q6_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
605
+ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
610
606
 
611
- const block_q6_K * x = (const block_q6_K *) vx;
607
+ static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
608
+
609
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
610
+ if (row > nrows) return;
611
+
612
+ const int num_blocks_per_row = ncols / QK_K;
613
+ const int ib0 = row*num_blocks_per_row;
614
+
615
+ const block_q2_K * x = (const block_q2_K *)vx + ib0;
616
+
617
+ float tmp = 0; // partial sum for thread in warp
618
+
619
+ #if QK_K == 256
620
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
621
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
622
+
623
+ const int step = 16/K_QUANTS_PER_ITERATION;
624
+
625
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
626
+ const int in = tid - step*im; // 0...15 or 0...7
627
+
628
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
629
+ const int q_offset = 32*im + l0;
630
+ const int s_offset = 8*im;
631
+ const int y_offset = 128*im + l0;
632
+
633
+ uint32_t aux[4];
634
+ const uint8_t * d = (const uint8_t *)aux;
635
+ const uint8_t * m = (const uint8_t *)(aux + 2);
636
+
637
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
638
+
639
+ const float * y = yy + i * QK_K + y_offset;
640
+ const uint8_t * q = x[i].qs + q_offset;
641
+
642
+ const float dall = x[i].d;
643
+ const float dmin = x[i].dmin;
644
+
645
+ const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
646
+ aux[0] = a[0] & 0x0f0f0f0f;
647
+ aux[1] = a[1] & 0x0f0f0f0f;
648
+ aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
649
+ aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
650
+
651
+ float sum1 = 0, sum2 = 0;
652
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
653
+ sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
654
+ + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
655
+ + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
656
+ + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
657
+ + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
658
+ + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
659
+ + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
660
+ +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
661
+ sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
662
+ + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
663
+
664
+ }
665
+ tmp += dall * sum1 - dmin * sum2;
666
+
667
+ }
668
+ #else
669
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
670
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
671
+ const int offset = tid * K_QUANTS_PER_ITERATION;
672
+
673
+ uint32_t uaux[2];
674
+ const uint8_t * d = (const uint8_t *)uaux;
675
+
676
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
677
+
678
+ const float * y = yy + i * QK_K + offset;
679
+ const uint8_t * q = x[i].qs + offset;
680
+ const uint32_t * s = (const uint32_t *)x[i].scales;
681
+
682
+ uaux[0] = s[0] & 0x0f0f0f0f;
683
+ uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
684
+
685
+ const half2 * dh = (const half2 *)&x[i].d;
686
+
687
+ const float2 dall = __half22float2(dh[0]);
688
+
689
+ float sum1 = 0, sum2 = 0;
690
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
691
+ const uint8_t ql = q[l];
692
+ sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
693
+ + y[l+16] * d[1] * ((ql >> 2) & 3)
694
+ + y[l+32] * d[2] * ((ql >> 4) & 3)
695
+ + y[l+48] * d[3] * ((ql >> 6) & 3);
696
+ sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
697
+ }
698
+ tmp += dall.x * sum1 - dall.y * sum2;
699
+ }
700
+ #endif
701
+
702
+ // sum up partial sums and write back result
703
+ __syncthreads();
704
+ #pragma unroll
705
+ for (int mask = 16; mask > 0; mask >>= 1) {
706
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
707
+ }
708
+
709
+ if (threadIdx.x == 0) {
710
+ dst[row] = tmp;
711
+ }
712
+ }
713
+
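To make the thread-to-data mapping of dequantize_mul_mat_vec_q2_k above concrete, here is one lane worked through by hand (illustrative numbers only; QK_K == 256, K_QUANTS_PER_ITERATION == 2):

// threadIdx.x == 21:
//   tid = 21/2 = 10      ix = 21%2 = 1       step = 16/2 = 8
//   im  = 10/8 = 1       in = 10 - 8*1 = 2   l0   = 2*2  = 4
//   q_offset = 32*1 + 4 = 36    s_offset = 8*1 = 8    y_offset = 128*1 + 4 = 132
// This lane starts at super-block i = 1, strides by 2 super-blocks, and in each
// block it visits reads quant bytes 36..37 (plus 52..53 via the +16 offsets).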
714
+ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
715
+
716
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
717
+ if (row > nrows) return;
718
+
719
+ const int num_blocks_per_row = ncols / QK_K;
720
+ const int ib0 = row*num_blocks_per_row;
721
+
722
+ const block_q3_K * x = (const block_q3_K *)vx + ib0;
723
+
724
+ float tmp = 0; // partial sum for thread in warp
725
+
726
+ #if QK_K == 256
727
+
728
+ const uint16_t kmask1 = 0x0303;
729
+ const uint16_t kmask2 = 0x0f0f;
730
+
731
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
732
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
733
+
734
+ const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
735
+ const int step = 16/K_QUANTS_PER_ITERATION;
736
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
737
+ const int in = tid - step*im; // 0....15 or 0...7
738
+
739
+ const uint8_t m = 1 << (4*im);
740
+
741
+ const int l0 = n*in; // 0...15 or 0...14 in steps of 2
742
+ const int q_offset = 32*im + l0;
743
+ const int y_offset = 128*im + l0;
744
+
745
+ uint16_t utmp[4];
746
+ const int8_t * s = (const int8_t *)utmp;
747
+
748
+ const uint16_t s_shift = 4*im;
749
+
750
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
751
+
752
+ const float * y = yy + i * QK_K + y_offset;
753
+ const uint8_t * q = x[i].qs + q_offset;
754
+ const uint8_t * h = x[i].hmask + l0;
755
+
756
+ const uint16_t * a = (const uint16_t *)x[i].scales;
757
+ utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
758
+ utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
759
+ utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
760
+ utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
761
+
762
+ const float d = x[i].d;
763
+
764
+ float sum = 0;
765
+ for (int l = 0; l < n; ++l) {
766
+ sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
767
+ + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
768
+ + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
769
+ + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
770
+ sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
771
+ + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
772
+ + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
773
+ + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
774
+ }
775
+ tmp += d * sum;
776
+
777
+ }
778
+ #else
779
+
780
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
781
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
782
+ const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
783
+ const int in = offset/8; // 0 or 1
784
+ const int im = offset%8; // 0...7
785
+
786
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
787
+
788
+ const float * y = yy + i * QK_K + offset;
789
+ const uint8_t * q = x[i].qs + offset;
790
+ const uint8_t * s = x[i].scales;
791
+
792
+ const float dall = (float)x[i].d;
793
+
794
+ float sum = 0;
795
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
796
+ const uint8_t hl = x[i].hmask[im+l] >> in;
797
+ const uint8_t ql = q[l];
798
+ sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
799
+ + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
800
+ + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
801
+ + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
802
+ }
803
+ tmp += sum;
804
+ }
805
+ #endif
806
+
807
+ // sum up partial sums and write back result
808
+ __syncthreads();
809
+ #pragma unroll
810
+ for (int mask = 16; mask > 0; mask >>= 1) {
811
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
812
+ }
813
+
814
+ if (threadIdx.x == 0) {
815
+ dst[row] = tmp;
816
+ }
817
+ }
818
+
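Every new mul-mat-vec kernel ends with the same warp-level butterfly reduction. Distilled into a standalone helper (a sketch; the diff itself keeps the loop inline):

// After log2(32) = 5 exchange rounds every lane holds the sum of all 32 partial values,
// so only lane 0 needs to write dst[row], as the kernels above do.
static __device__ __forceinline__ float warp_reduce_sum(float v) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);   // swap with the lane whose id differs in one bit
    }
    return v;
}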
819
+ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
820
+
821
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
822
+ if (row > nrows) return;
823
+ const int num_blocks_per_row = ncols / QK_K;
824
+ const int ib0 = row*num_blocks_per_row;
825
+
826
+ const block_q4_K * x = (const block_q4_K *)vx + ib0;
827
+
828
+ #if QK_K == 256
829
+ const uint16_t kmask1 = 0x3f3f;
830
+ const uint16_t kmask2 = 0x0f0f;
831
+ const uint16_t kmask3 = 0xc0c0;
832
+
833
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
834
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
835
+
836
+ const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
837
+
838
+ const int il = tid/step; // 0...3
839
+ const int ir = tid - step*il; // 0...7 or 0...3
840
+ const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
841
+
842
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
843
+ const int in = il%2;
844
+
845
+ const int l0 = n*(2*ir + in);
846
+ const int q_offset = 32*im + l0;
847
+ const int y_offset = 64*im + l0;
848
+
849
+ uint16_t aux[4];
850
+ const uint8_t * sc = (const uint8_t *)aux;
851
+
852
+ float tmp = 0; // partial sum for thread in warp
612
853
 
613
- const int ip = iqs / 128; // 0 or 1
614
- const int il = (iqs - 128*ip)/8; // 0...15
615
- const int is = 8*ip;
854
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
616
855
 
617
- const float * y = yy + 128*ip + il;
856
+ const uint8_t * q1 = x[i].qs + q_offset;
857
+ const uint8_t * q2 = q1 + 64;
858
+ const float * y1 = yy + i*QK_K + y_offset;
859
+ const float * y2 = y1 + 128;
618
860
 
619
- const float d = x[ib].d;
861
+ const float dall = x[i].d;
862
+ const float dmin = x[i].dmin;
620
863
 
621
- const uint8_t * ql = x[ib].ql + 64*ip + il;
622
- const uint8_t * qh = x[ib].qh + 32*ip + il;
623
- const int8_t * sc = x[ib].scales + is;
864
+ const uint16_t * a = (const uint16_t *)x[i].scales;
865
+ aux[0] = a[im+0] & kmask1;
866
+ aux[1] = a[im+2] & kmask1;
867
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
868
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
624
869
 
625
- result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
626
- + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
627
- + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
628
- + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
629
- + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
630
- + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
631
- + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
632
- + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
870
+ float4 s = {0.f, 0.f, 0.f, 0.f};
871
+ float smin = 0;
872
+ for (int l = 0; l < n; ++l) {
873
+ s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
874
+ s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
875
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
876
+ }
877
+ tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
878
+
879
+ }
880
+ #else
881
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
882
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
883
+
884
+ const int step = tid * K_QUANTS_PER_ITERATION;
885
+
886
+ uint16_t aux16[2];
887
+ const uint8_t * s = (const uint8_t *)aux16;
888
+
889
+ float tmp = 0;
890
+
891
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
892
+ const uint8_t * q = x[i].qs + step;
893
+ const float * y = yy + i*QK_K + step;
894
+ const uint16_t * a = (const uint16_t *)x[i].scales;
895
+ aux16[0] = a[0] & 0x0f0f;
896
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
897
+ const float d = (float)x[i].d[0];
898
+ const float m = (float)x[i].d[1];
899
+ float sum = 0.f;
900
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
901
+ sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
902
+ + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
903
+ + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
904
+ + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
905
+ }
906
+ tmp += sum;
907
+ }
908
+
909
+ #endif
910
+
911
+ // sum up partial sums and write back result
912
+ __syncthreads();
913
+ #pragma unroll
914
+ for (int mask = 16; mask > 0; mask >>= 1) {
915
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
916
+ }
633
917
 
918
+ if (tid == 0) {
919
+ dst[row] = tmp;
920
+ }
634
921
  }
635
922
 
636
- static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
923
+ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
924
+
925
+ const int row = blockIdx.x;
926
+ const int num_blocks_per_row = ncols / QK_K;
927
+ const int ib0 = row*num_blocks_per_row;
928
+
929
+ const block_q5_K * x = (const block_q5_K *)vx + ib0;
930
+
931
+ float tmp = 0; // partial sum for thread in warp
932
+
933
+ #if QK_K == 256
934
+ const uint16_t kmask1 = 0x3f3f;
935
+ const uint16_t kmask2 = 0x0f0f;
936
+ const uint16_t kmask3 = 0xc0c0;
937
+
938
+ const int tid = threadIdx.x/2; // 0...15
939
+ const int ix = threadIdx.x%2;
940
+
941
+ const int il = tid/4; // 0...3
942
+ const int ir = tid - 4*il;// 0...3
943
+ const int n = 2;
944
+
945
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
946
+ const int in = il%2;
947
+
948
+ const int l0 = n*(2*ir + in);
949
+ const int q_offset = 32*im + l0;
950
+ const int y_offset = 64*im + l0;
951
+
952
+ const uint8_t hm1 = 1 << (2*im);
953
+ const uint8_t hm2 = hm1 << 4;
954
+
955
+ uint16_t aux[4];
956
+ const uint8_t * sc = (const uint8_t *)aux;
957
+
958
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
959
+
960
+ const uint8_t * ql1 = x[i].qs + q_offset;
961
+ const uint8_t * ql2 = ql1 + 64;
962
+ const uint8_t * qh = x[i].qh + l0;
963
+ const float * y1 = yy + i*QK_K + y_offset;
964
+ const float * y2 = y1 + 128;
965
+
966
+ const float dall = x[i].d;
967
+ const float dmin = x[i].dmin;
968
+
969
+ const uint16_t * a = (const uint16_t *)x[i].scales;
970
+ aux[0] = a[im+0] & kmask1;
971
+ aux[1] = a[im+2] & kmask1;
972
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
973
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
974
+
975
+ float4 sum = {0.f, 0.f, 0.f, 0.f};
976
+ float smin = 0;
977
+ for (int l = 0; l < n; ++l) {
978
+ sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
979
+ + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
980
+ sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
981
+ + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
982
+ sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
983
+ + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
984
+ sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
985
+ + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
986
+ smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
987
+ + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
988
+ }
989
+ tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
990
+ }
991
+
992
+ #else
993
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
994
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
995
+ const int step = tid * K_QUANTS_PER_ITERATION;
996
+ const int im = step/8;
997
+ const int in = step%8;
998
+
999
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
1000
+ const uint8_t * q = x[i].qs + step;
1001
+ const int8_t * s = x[i].scales;
1002
+ const float * y = yy + i*QK_K + step;
1003
+ const float d = x[i].d;
1004
+ float sum = 0.f;
1005
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
1006
+ const uint8_t h = x[i].qh[in+j] >> im;
1007
+ sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
1008
+ + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
1009
+ + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
1010
+ + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
1011
+ }
1012
+ tmp += sum;
1013
+ }
1014
+ #endif
1015
+
1016
+ // sum up partial sums and write back result
1017
+ __syncthreads();
1018
+ #pragma unroll
1019
+ for (int mask = 16; mask > 0; mask >>= 1) {
1020
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
1021
+ }
1022
+
1023
+ if (threadIdx.x == 0) {
1024
+ dst[row] = tmp;
1025
+ }
1026
+ }
1027
+
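In the q5_K path above, each 5-bit quant is reassembled from a packed nibble in qs plus one bit from qh. As a standalone scalar sketch (hypothetical helper; mirrors the (qh & hm ? 16 : 0) terms in the loop):

static __device__ __forceinline__ int q5_reassemble(uint8_t ql_nibble, uint8_t qh_byte, uint8_t hm) {
    // low 4 bits from ql, the 5th bit (worth +16) selected out of qh by the per-sub-block mask hm
    return (ql_nibble & 0xF) + ((qh_byte & hm) ? 16 : 0);
}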
1028
+ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
1029
+
1030
+ static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
1031
+
1032
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
1033
+ if (row > nrows) return;
1034
+
1035
+ const int num_blocks_per_row = ncols / QK_K;
1036
+ const int ib0 = row*num_blocks_per_row;
1037
+
1038
+ const block_q6_K * x = (const block_q6_K *)vx + ib0;
1039
+
1040
+ #if QK_K == 256
1041
+
1042
+ const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
1043
+ const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
1044
+
1045
+ const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
1046
+
1047
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
1048
+ const int in = tid - step*im; // 0...15 or 0...7
1049
+
1050
+ #if K_QUANTS_PER_ITERATION == 1
1051
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
1052
+ const int is = 0;
1053
+ #else
1054
+ const int l0 = 4 * in; // 0, 4, 8, ..., 28
1055
+ const int is = in / 4;
1056
+ #endif
1057
+ const int ql_offset = 64*im + l0;
1058
+ const int qh_offset = 32*im + l0;
1059
+ const int s_offset = 8*im + is;
1060
+ const int y_offset = 128*im + l0;
1061
+
1062
+ float tmp = 0; // partial sum for thread in warp
1063
+
1064
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
1065
+
1066
+ const float * y = yy + i * QK_K + y_offset;
1067
+ const uint8_t * ql = x[i].ql + ql_offset;
1068
+ const uint8_t * qh = x[i].qh + qh_offset;
1069
+ const int8_t * s = x[i].scales + s_offset;
1070
+
1071
+ const float d = x[i].d;
1072
+
1073
+ #if K_QUANTS_PER_ITERATION == 1
1074
+ float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
1075
+ + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
1076
+ + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
1077
+ + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
1078
+ + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
1079
+ + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
1080
+ + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
1081
+ +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
1082
+ tmp += sum;
1083
+ #else
1084
+ float sum = 0;
1085
+ for (int l = 0; l < 4; ++l) {
1086
+ sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
1087
+ + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
1088
+ + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
1089
+ + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
1090
+ }
1091
+ tmp += sum;
1092
+ #endif
1093
+
1094
+ }
1095
+
1096
+ #else
1097
+
1098
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
1099
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
1100
+
1101
+ const int step = tid * K_QUANTS_PER_ITERATION;
1102
+
1103
+ float tmp = 0; // partial sum for thread in warp
1104
+
1105
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
1106
+
1107
+ const float * y = yy + i * QK_K + step;
1108
+ const uint8_t * ql = x[i].ql + step;
1109
+ const uint8_t * qh = x[i].qh + step;
1110
+ const int8_t * s = x[i].scales;
1111
+
1112
+ const float d = x[i+0].d;
1113
+
1114
+ float sum = 0;
1115
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
1116
+ sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
1117
+ + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
1118
+ + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
1119
+ + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
1120
+ }
1121
+ tmp += sum;
1122
+
1123
+ }
1124
+
1125
+ #endif
1126
+
1127
+ // sum up partial sums and write back result
1128
+ __syncthreads();
1129
+ #pragma unroll
1130
+ for (int mask = 16; mask > 0; mask >>= 1) {
1131
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
1132
+ }
1133
+
1134
+ if (tid == 0) {
1135
+ dst[row] = tmp;
1136
+ }
1137
+ }
1138
+
1139
+ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
637
1140
  const half * x = (const half *) vx;
638
1141
 
639
- v0 = __half2float(x[ib + iqs + 0]);
640
- v1 = __half2float(x[ib + iqs + 1]);
1142
+ // automatic half -> float type cast if dfloat == float
1143
+ v.x = x[ib + iqs + 0];
1144
+ v.y = x[ib + iqs + 1];
641
1145
  }
642
1146
 
643
1147
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -654,13 +1158,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
654
1158
  const int y_offset = qr == 1 ? 1 : qk/2;
655
1159
 
656
1160
  // dequantize
657
- float & v0 = y[iybs + iqs + 0];
658
- float & v1 = y[iybs + iqs + y_offset];
659
- dequantize_kernel(vx, ib, iqs, v0, v1);
1161
+ dfloat2 v;
1162
+ dequantize_kernel(vx, ib, iqs, v);
1163
+
1164
+ y[iybs + iqs + 0] = v.x;
1165
+ y[iybs + iqs + y_offset] = v.y;
660
1166
  }
661
1167
 
662
1168
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
663
- static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
1169
+ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
664
1170
  // qk = quantized weights per x block
665
1171
  // qr = number of quantized weights per data value in x block
666
1172
  const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -675,7 +1181,12 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
675
1181
  const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
676
1182
  const int y_offset = qr == 1 ? 1 : qk/2;
677
1183
 
678
- float tmp = 0.0f; // partial sum for thread in warp
1184
+ // partial sum for each thread
1185
+ #ifdef GGML_CUDA_DMMV_F16
1186
+ half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
1187
+ #else
1188
+ float tmp = 0.0f;
1189
+ #endif // GGML_CUDA_DMMV_F16
679
1190
 
680
1191
  for (int i = 0; i < ncols; i += iter_stride) {
681
1192
  const int col = i + vals_per_iter*tid;
@@ -689,14 +1200,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
689
1200
  // process 2 vals per j iter
690
1201
 
691
1202
  // dequantize
692
- float v0, v1;
693
- dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
694
1203
  // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
1204
+ dfloat2 v;
1205
+ dequantize_kernel(vx, ib, iqs + j/qr, v);
695
1206
 
696
1207
  // matrix multiplication
697
- tmp += v0 * y[iybs + iqs + j/qr + 0];
698
- tmp += v1 * y[iybs + iqs + j/qr + y_offset];
699
1208
  // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
1209
+ #ifdef GGML_CUDA_DMMV_F16
1210
+ tmp += __hmul2(v, {
1211
+ y[iybs + iqs + j/qr + 0],
1212
+ y[iybs + iqs + j/qr + y_offset]
1213
+ });
1214
+ #else
1215
+ tmp += v.x * y[iybs + iqs + j/qr + 0];
1216
+ tmp += v.y * y[iybs + iqs + j/qr + y_offset];
1217
+ #endif // GGML_CUDA_DMMV_F16
700
1218
  }
701
1219
  }
702
1220
 
@@ -708,47 +1226,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
708
1226
  }
709
1227
 
710
1228
  if (tid == 0) {
1229
+ #ifdef GGML_CUDA_DMMV_F16
1230
+ dst[row] = tmp.x + tmp.y;
1231
+ #else
711
1232
  dst[row] = tmp;
712
- }
713
- }
714
-
715
- template <int n_thread, dot_kernel_k_t dot_kernel>
716
- static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
717
- const int row = blockIdx.y*blockDim.y + threadIdx.y;
718
-
719
- if (row >= nrows) {
720
- return;
721
- }
722
-
723
- const int tid = threadIdx.x;
724
-
725
- const int iter_stride = QK_K;
726
- const int vals_per_iter = iter_stride / n_thread;
727
- const int num_blocks_per_row = ncols / QK_K;
728
- const int ib0 = row*num_blocks_per_row;
729
-
730
- float tmp = 0; // partial sum for thread in warp
731
-
732
- for (int i = 0; i < ncols; i += iter_stride) {
733
- const int col = i + vals_per_iter*tid;
734
- const int ib = ib0 + col/QK_K; // x block index
735
- const int iqs = col%QK_K; // x quant index
736
- const int iybs = col - col%QK_K; // y block start index
737
-
738
- float v;
739
- dot_kernel(vx, ib, iqs, y + iybs, v);
740
- tmp += v;
741
- }
742
-
743
- // sum up partial sums and write back result
744
- __syncthreads();
745
- #pragma unroll
746
- for (int mask = 16; mask > 0; mask >>= 1) {
747
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
748
- }
749
-
750
- if (tid == 0) {
751
- dst[row] = tmp;
1233
+ #endif // GGML_CUDA_DMMV_F16
752
1234
  }
753
1235
  }
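With GGML_CUDA_DMMV_F16, the generic dequantize_mul_mat_vec above keeps two fp16 partial sums per thread in a half2 and converts to fp32 only at write-back (dst[row] = tmp.x + tmp.y). That final fold, as a self-contained helper (a sketch; assumes cuda_fp16.h):

static __device__ __forceinline__ float fold_half2_sum(half2 acc) {
    return __low2float(acc) + __high2float(acc);   // same result as the tmp.x + tmp.y write-back above
}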
754
1236
 
@@ -1020,12 +1502,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
1020
1502
 
1021
1503
  static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1022
1504
  const int nb = k / QK_K;
1505
+ #if QK_K == 256
1023
1506
  dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
1507
+ #else
1508
+ dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
1509
+ #endif
1024
1510
  }
1025
1511
 
1026
1512
  static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1027
1513
  const int nb = k / QK_K;
1514
+ #if QK_K == 256
1028
1515
  dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
1516
+ #else
1517
+ dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
1518
+ #endif
1029
1519
  }
1030
1520
 
1031
1521
  static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1035,15 +1525,23 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
1035
1525
 
1036
1526
  static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1037
1527
  const int nb = k / QK_K;
1528
+ #if QK_K == 256
1038
1529
  dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
1530
+ #else
1531
+ dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
1532
+ #endif
1039
1533
  }
1040
1534
 
1041
1535
  static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1042
1536
  const int nb = k / QK_K;
1537
+ #if QK_K == 256
1043
1538
  dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
1539
+ #else
1540
+ dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
1541
+ #endif
1044
1542
  }
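The dequantize_row_*_K_cuda launchers above now derive the block size from QK_K: one thread block still dequantizes one super-block, so the thread count scales with how many values each thread writes. A quick check (not part of the diff):

// QK_K == 256: 64 threads x 4 outputs each = 256 values per super-block
// QK_K ==  64: 32 threads x 2 outputs each =  64 values per super-block
static_assert(64*4 == 256 && 32*2 == 64, "one thread block covers exactly one super-block");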
1045
1543
 
1046
- static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1544
+ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1047
1545
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1048
1546
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1049
1547
  const dim3 block_nums(1, block_num_y, 1);
@@ -1052,7 +1550,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, f
1052
1550
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
1053
1551
  }
1054
1552
 
1055
- static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1553
+ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1056
1554
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1057
1555
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1058
1556
  const dim3 block_nums(1, block_num_y, 1);
@@ -1061,7 +1559,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, f
1061
1559
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
1062
1560
  }
1063
1561
 
1064
- static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1562
+ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1065
1563
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1066
1564
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1067
1565
  const dim3 block_nums(1, block_num_y, 1);
@@ -1070,7 +1568,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, f
1070
1568
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
1071
1569
  }
1072
1570
 
1073
- static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1571
+ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1074
1572
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1075
1573
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1076
1574
  const dim3 block_nums(1, block_num_y, 1);
@@ -1079,7 +1577,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, f
1079
1577
  <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
1080
1578
  }
1081
1579
 
1082
- static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1580
+ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
1083
1581
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
1084
1582
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
1085
1583
  const dim3 block_nums(1, block_num_y, 1);
@@ -1090,47 +1588,44 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f

  static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
+ const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
  const int block_num_y = (nrows + ny - 1) / ny;
  const dim3 block_nums(1, block_num_y, 1);
  const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }

  static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
  const int block_num_y = (nrows + ny - 1) / ny;
  const dim3 block_nums(1, block_num_y, 1);
  const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }

  static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
  const int block_num_y = (nrows + ny - 1) / ny;
  const dim3 block_nums(1, block_num_y, 1);
  const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }

  static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
- const int block_num_y = (nrows + ny - 1) / ny;
- const dim3 block_nums(1, block_num_y, 1);
- const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ const dim3 block_dims(32, 1, 1);
+ dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
  }

  static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % QK_K == 0);
- const int ny = 2;
+ const int ny = 2 / K_QUANTS_PER_ITERATION;
  const int block_num_y = (nrows + ny - 1) / ny;
  const dim3 block_nums(1, block_num_y, 1);
  const dim3 block_dims(32, ny, 1);
- dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+ dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }

  static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
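The k-quant launchers above now derive their geometry from K_QUANTS_PER_ITERATION (1 or 2 in ggml-cuda.cu): with two quants per iteration, ny becomes 1 and each block covers a single row with one 32-thread warp. A minimal host-side sketch of that arithmetic, using an illustrative row count rather than a real tensor size:

    // sketch: launch geometry used by the q3_K/q4_K/q6_K mul-mat-vec launchers
    // (K_QUANTS_PER_ITERATION and nrows below are illustrative values)
    #include <cstdio>
    #include <cuda_runtime.h>

    #define K_QUANTS_PER_ITERATION 2

    int main() {
        const int nrows = 4096;                         // hypothetical row count
        const int ny = 2 / K_QUANTS_PER_ITERATION;      // rows handled per block in y
        const int block_num_y = (nrows + ny - 1) / ny;  // ceil(nrows / ny)
        const dim3 block_nums(1, block_num_y, 1);
        const dim3 block_dims(32, ny, 1);               // one warp per row
        printf("grid = (%u, %u, %u), block = (%u, %u, %u)\n",
               block_nums.x, block_nums.y, block_nums.z,
               block_dims.x, block_dims.y, block_dims.z);
        return 0;
    }

Note that q5_K is the exception above: its new kernel is launched with one block per row directly (<<<nrows, block_dims, 0, stream>>>), so no block_num_y computation is needed for it.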
@@ -1138,7 +1633,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
  dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  }

- static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
  const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
  const dim3 block_nums(1, block_num_y, 1);
@@ -1306,19 +1801,13 @@ static void * g_scratch_buffer = nullptr;
  static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
  static size_t g_scratch_offset = 0;

- #define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
- #define GGML_CUDA_MAX_EVENTS 64
-
  static int g_device_count = -1;
  static int g_main_device = 0;
  static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};

  static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

- static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
-
- static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
- static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr };
+ static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };

  void ggml_init_cublas() {
  static bool initialized = false;
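The per-device stream and event arrays could be dropped because work enqueued on a single CUDA stream executes in issue order: a copy followed by a kernel on the same stream needs no cudaEventRecord/cudaStreamWaitEvent pair. A minimal sketch of that ordering guarantee (the kernel and buffer here are illustrative, not part of ggml):

    #include <cuda_runtime.h>

    __global__ void scale_kernel(float * x, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            x[i] *= 2.0f;   // only runs after the copy below has completed
        }
    }

    void copy_then_compute(const float * host_src, float * dev_buf, int n, cudaStream_t stream) {
        // both operations go on the same non-blocking stream; the stream serializes them,
        // so no explicit event synchronization is required
        cudaMemcpyAsync(dev_buf, host_src, n * sizeof(float), cudaMemcpyHostToDevice, stream);
        scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(dev_buf, n);
    }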
@@ -1342,15 +1831,8 @@ void ggml_init_cublas() {
  for (int id = 0; id < g_device_count; ++id) {
  CUDA_CHECK(cudaSetDevice(id));

- // create streams
- for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
- CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking));
- CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking));
- }
- // create events
- for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
- CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
- }
+ // create main stream
+ CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id], cudaStreamNonBlocking));

  // create cublas handle
  CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
@@ -1566,21 +2048,40 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
  const int64_t ne00 = src0->ne[0];
  const int64_t nrows = i01_high - i01_low;

+ // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+ #ifdef GGML_CUDA_DMMV_F16
+ size_t ash;
+ dfloat * src1_dfloat = nullptr; // dfloat == half
+
+ bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+ src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+ src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+ if (src1_convert_f16) {
+ src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+ ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+ ne00, 1, sizeof(float), 0, 0,
+ ne00, 1, sizeof(half), 0, 0, cudaStream_main);
+ }
+ #else
+ dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+ #endif // GGML_CUDA_DMMV_F16
+
  switch (src0->type) {
  case GGML_TYPE_Q4_0:
- dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q4_1:
- dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q5_0:
- dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q5_1:
- dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q8_0:
- dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_Q2_K:
  dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
@@ -1598,7 +2099,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
  dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  case GGML_TYPE_F16:
- convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
  break;
  default:
  GGML_ASSERT(false);
@@ -1606,6 +2107,12 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
  }
  CUDA_CHECK(cudaGetLastError());

+ #ifdef GGML_CUDA_DMMV_F16
+ if (src1_convert_f16) {
+ ggml_cuda_pool_free(src1_dfloat, ash);
+ }
+ #endif // GGML_CUDA_DMMV_F16
+
  (void) src1;
  (void) dst;
  (void) src0_ddf_i;
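Under GGML_CUDA_DMMV_F16 the src1 row is converted to half on the device before the dequantize-mul-mat-vec kernel runs, and the temporary buffer is returned to the pool afterwards. A stripped-down sketch of the same idea; the conversion kernel and the plain cudaMalloc here are stand-ins for the real ggml_cpy_f32_f16_cuda and ggml_cuda_pool_malloc calls:

    #include <cuda_fp16.h>
    #include <cuda_runtime.h>

    __global__ void f32_to_f16(const float * src, half * dst, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] = __float2half(src[i]);
        }
    }

    // convert ne00 floats of src1 to half so the mul-mat-vec kernel can use half2
    // intrinsics; the caller frees the buffer once the kernel has been launched
    half * convert_src1_to_f16(const float * src1_f32, int ne00, cudaStream_t stream) {
        half * src1_f16 = nullptr;
        cudaMalloc(&src1_f16, ne00 * sizeof(half));   // stand-in for the pooled allocation
        f32_to_f16<<<(ne00 + 255) / 256, 256, 0, stream>>>(src1_f32, src1_f16, ne00);
        return src1_f16;
    }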
@@ -1817,6 +2324,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
  size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};

+ // if multiple GPUs are used they need to wait for the main GPU to finish
+ if (split && g_device_count > 1) {
+ CUDA_CHECK(cudaSetDevice(g_main_device));
+ CUDA_CHECK(cudaDeviceSynchronize());
+ }
+
  for (int id = 0; id < g_device_count; ++id) {
  if (!split && id != g_main_device) {
  continue;
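The synchronization added at the top of ggml_cuda_op covers the split case: secondary GPUs copy src1 out of the main device's buffer, so the main GPU has to drain its pending work first. A minimal sketch of the pattern, with the globals passed in as parameters purely for illustration:

    #include <cuda_runtime.h>

    // before secondary GPUs read buffers owned by the main GPU, drain all of the
    // main GPU's outstanding work; a coarse cudaDeviceSynchronize is sufficient here
    void wait_for_main_device(int main_device, int device_count, bool split) {
        if (split && device_count > 1) {
            cudaSetDevice(main_device);
            cudaDeviceSynchronize();
        }
    }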
@@ -1915,9 +2428,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  }
  const int64_t i11 = i13*ne12 + i12;

- cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
- cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
- cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[id];

  // for split tensors the data begins at i0 == i0_offset_low
  char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
@@ -1945,14 +2456,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  if (src1->backend == GGML_BACKEND_CPU) {
  GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1));
  int64_t nrows1 = flatten_rows ? nrows0 : ne11;
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1));
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_main));
  } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
  if (id != g_main_device) {
  GGML_ASSERT(!flatten_rows);
  float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
  src1_ddf_i_source += i11*src1_stride;
  CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
- cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
+ cudaMemcpyDeviceToDevice, cudaStream_main));
  }
  } else if (src1_on_device && !src1_is_contiguous) {
  GGML_ASSERT(!split);
@@ -1961,7 +2472,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  GGML_ASSERT(false);
  }
  }
- CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));

  if (!src0_on_device || !src0_is_contiguous) {
  if (src0_is_f32) {
@@ -1977,9 +2487,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
  CUDA_CHECK(cudaGetLastError());
  }

- // wait with main stream until src1 memcpy is done
- CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
-
  // do the computation
  op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);

@@ -2017,8 +2524,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm

  // wait until each device is finished, then free their buffers
  for (int id = 0; id < g_device_count; ++id) {
+ if (src0_asq[id] == 0 && src0_asf[id] == 0 && src1_asf[id] == 0 && dst_asf[id] == 0) {
+ continue;
+ }
+
  CUDA_CHECK(cudaSetDevice(id));
  CUDA_CHECK(cudaDeviceSynchronize());
+
  if (src0_asq[id] > 0) {
  ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
  }
@@ -2084,7 +2596,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
  const int64_t ne02 = src0->ne[2];

  CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];

  struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
  void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2096,8 +2608,6 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
  float * dst_ddf = (float *) dst_extra->data_device[g_main_device];

  ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
-
- CUDA_CHECK(cudaDeviceSynchronize());
  }

  void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -2115,7 +2625,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
  const int64_t nb02 = src0->nb[2];

  CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];

  struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
  void * src0_ddq = src0_extra->data_device[g_main_device];
@@ -2130,8 +2640,6 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
  const int channel_stride_x = nb02 / sizeof(half);

  ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
-
- CUDA_CHECK(cudaDeviceSynchronize());
  }

  void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
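With the trailing cudaDeviceSynchronize calls removed, ggml_cuda_mul_mat_vec_p021 and ggml_cuda_mul_mat_vec_nc return while the kernel may still be in flight on the main stream; a consumer that needs the result on the host has to order its read on that stream and then synchronize it. A small sketch of the consumer side (function name and size are illustrative):

    #include <cuda_runtime.h>

    // read back a result that was produced asynchronously on `stream`;
    // the copy is ordered after the kernel because it uses the same stream,
    // and the final sync makes it visible to the host
    void read_result(float * host_dst, const float * dev_src, int n, cudaStream_t stream) {
        cudaMemcpyAsync(host_dst, dev_src, n * sizeof(float), cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);
    }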
@@ -2187,7 +2695,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
  const int64_t nb12 = src1->nb[2];

  CUDA_CHECK(cudaSetDevice(g_main_device));
- cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0];
+ cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];

  const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
  const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
@@ -2205,8 +2713,6 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
  GGML_ASSERT(false);
  }

- CUDA_CHECK(cudaDeviceSynchronize());
-
  (void) dst;
  }

@@ -2313,6 +2819,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {

  tensor->backend = GGML_BACKEND_GPU;
  struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
+ memset(extra, 0, sizeof(*extra));

  const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
  tensor->op == GGML_OP_VIEW;
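The memset matters because ggml_tensor_extra_gpu holds one device pointer per possible GPU and only the slots that are actually used get assigned; the remaining slots must read as nullptr rather than garbage. A purely illustrative alternative would be C++ value-initialization, assuming the struct stays an aggregate with no user-declared constructors:

    // value-initialization zeroes every member of the aggregate,
    // so unassigned data_device[] slots compare equal to nullptr
    struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu();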
@@ -2395,7 +2902,7 @@ void ggml_cuda_free_scratch() {
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
  ggml_cuda_func_t func;
  const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
- || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
+ || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
  || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);

  switch (tensor->op) {