llama_cpp 0.2.1 → 0.2.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -15,7 +15,11 @@
15
15
 
16
16
  #include "ggml.h"
17
17
 
18
- #define CL_DMMV_BLOCK_SIZE 32;
18
+ #if defined(_MSC_VER)
19
+ #pragma warning(disable: 4244 4267) // possible loss of data
20
+ #endif
21
+
22
+ #define CL_DMMV_BLOCK_SIZE 32
19
23
 
20
24
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
21
25
  static std::string program_source = MULTILINE_QUOTE(
@@ -59,6 +63,46 @@ struct __attribute__ ((packed)) block_q8_0
59
63
  int8_t qs[QK8_0];
60
64
  };
61
65
 
66
+ struct __attribute__((packed)) block_q2_K
67
+ {
68
+ uint8_t scales[16];
69
+ uint8_t qs[64];
70
+ half d;
71
+ half dmin;
72
+ };
73
+
74
+ struct __attribute__((packed)) block_q3_K
75
+ {
76
+ uint8_t hmask[32];
77
+ uint8_t qs[64];
78
+ uint8_t scales[12];
79
+ half d;
80
+ };
81
+
82
+ struct __attribute__((packed)) block_q4_K
83
+ {
84
+ half d;
85
+ half dmin;
86
+ uint8_t scales[12];
87
+ uint8_t qs[128];
88
+ };
89
+
90
+ struct __attribute__((packed)) block_q5_K
91
+ {
92
+ half d;
93
+ half dmin;
94
+ uint8_t scales[12];
95
+ uint8_t qh[32];
96
+ uint8_t qs[128];
97
+ };
98
+
99
+ struct __attribute__((packed)) block_q6_K
100
+ {
101
+ uint8_t ql[128];
102
+ uint8_t qh[64];
103
+ int8_t scales[16];
104
+ half d;
105
+ };
62
106
 
63
107
  __kernel void convert_fp16_to_fp32(__global half* x, __global float* y) {
64
108
  const uint i = get_global_id(0);
@@ -131,8 +175,314 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
131
175
  *v0 = vload_half(0, &x[ib + 0]);
132
176
  *v1 = vload_half(0, &x[ib + 1]);
133
177
  }
178
+
179
+ inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
180
+ {
181
+ if (j < 4)
182
+ {
183
+ *d = q[j] & 63;
184
+ *m = q[j + 4] & 63;
185
+ }
186
+ else
187
+ {
188
+ *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
189
+ *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
190
+ }
191
+ }
192
+
193
+ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy)
194
+ {
195
+ const int i = get_group_id(0);
196
+ const int tid = get_local_id(0);
197
+ const int n = tid / 32;
198
+ const int l = tid - 32 * n;
199
+ const int is = 8 * n + l / 16;
200
+
201
+ const uint8_t q = x[i].qs[32 * n + l];
202
+ __global float *y = yy + i * 256 + 128 * n;
203
+
204
+ const float dall = vload_half(0, &x[i].d);
205
+ const float dmin = vload_half(0, &x[i].dmin);
206
+
207
+ y[l + 0] = dall * (x[i].scales[is + 0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is + 0] >> 4);
208
+ y[l + 32] = dall * (x[i].scales[is + 2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is + 2] >> 4);
209
+ y[l + 64] = dall * (x[i].scales[is + 4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is + 4] >> 4);
210
+ y[l + 96] = dall * (x[i].scales[is + 6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is + 6] >> 4);
211
+ }
212
+
213
+ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy)
214
+ {
215
+ int r = get_local_id(0) / 4;
216
+ int i = get_group_id(0);
217
+ int tid = r / 2;
218
+ int is0 = r % 2;
219
+ int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
220
+ int n = tid / 4;
221
+ int j = tid - 4 * n;
222
+
223
+ uint8_t m = 1 << (4 * n + j);
224
+ int is = 8 * n + 2 * j + is0;
225
+ int shift = 2 * j;
226
+
227
+ int8_t us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4)
228
+ : is < 8 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4)
229
+ : is < 12 ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4)
230
+ : (x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4);
231
+ float d_all = vload_half(0, &x[i].d);
232
+ float dl = d_all * (us - 32);
233
+
234
+ __global float *y = yy + i * 256 + 128 * n + 32 * j;
235
+ const __global uint8_t *q = x[i].qs + 32 * n;
236
+ const __global uint8_t *hm = x[i].hmask;
237
+
238
+ for (int l = l0; l < l0 + 4; ++l)
239
+ y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
240
+ }
241
+
242
+ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy)
243
+ {
244
+ const int i = get_group_id(0);
245
+ const int tid = get_local_id(0);
246
+ const int il = tid / 8;
247
+ const int ir = tid % 8;
248
+ const int is = 2 * il;
249
+ const int n = 4;
250
+
251
+ __global float *y = yy + i * 256 + 64 * il + n * ir;
252
+
253
+ const float dall = vload_half(0, &x[i].d);
254
+ const float dmin = vload_half(0, &x[i].dmin);
255
+
256
+ __global const uint8_t *q = x[i].qs + 32 * il + n * ir;
257
+
258
+ uint8_t sc, m;
259
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
260
+ float d1 = dall * sc;
261
+ float m1 = dmin * m;
262
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
263
+ float d2 = dall * sc;
264
+ float m2 = dmin * m;
265
+ for (int l = 0; l < n; ++l)
266
+ {
267
+ y[l + 0] = d1 * (q[l] & 0xF) - m1;
268
+ y[l + 32] = d2 * (q[l] >> 4) - m2;
269
+ }
270
+ }
271
+
272
+ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy)
273
+ {
274
+ const int i = get_group_id(0);
275
+ const int tid = get_local_id(0);
276
+ const int il = tid / 16;
277
+ const int ir = tid % 16;
278
+ const int is = 2 * il;
279
+
280
+ __global float *y = yy + i * 256 + 64 * il + 2 * ir;
281
+
282
+ const float dall = vload_half(0, &x[i].d);
283
+ const float dmin = vload_half(0, &x[i].dmin);
284
+
285
+ __global const uint8_t *ql = x[i].qs + 32 * il + 2 * ir;
286
+ __global const uint8_t *qh = x[i].qh + 2 * ir;
287
+
288
+ uint8_t sc, m;
289
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
290
+ const float d1 = dall * sc;
291
+ const float m1 = dmin * m;
292
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
293
+ const float d2 = dall * sc;
294
+ const float m2 = dmin * m;
295
+
296
+ uint8_t hm = 1 << (2 * il);
297
+ y[0] = d1 * ((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0)) - m1;
298
+ y[1] = d1 * ((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0)) - m1;
299
+ hm <<= 1;
300
+ y[32] = d2 * ((ql[0] >> 4) + (qh[0] & hm ? 16 : 0)) - m2;
301
+ y[33] = d2 * ((ql[1] >> 4) + (qh[1] & hm ? 16 : 0)) - m2;
302
+ }
303
+
304
+ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy)
305
+ {
306
+ const int i = get_group_id(0);
307
+ const int tid = get_local_id(0);
308
+ const int ip = tid / 32;
309
+ const int il = tid - 32 * ip;
310
+ const int is = 8 * ip + il / 16;
311
+
312
+ __global float *y = yy + i * 256 + 128 * ip + il;
313
+
314
+ const float d = vload_half(0, &x[i].d);
315
+
316
+ __global const uint8_t *ql = x[i].ql + 64 * ip + il;
317
+ const uint8_t qh = x[i].qh[32 * ip + il];
318
+ __global const int8_t *sc = x[i].scales + is;
319
+
320
+ y[0] = d * sc[0] * ((int8_t)((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
321
+ y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
322
+ y[64] = d * sc[4] * ((int8_t)((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
323
+ y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
324
+ }
325
+
326
+
327
+ void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
328
+
329
+ int n = iqs / 128;
330
+ int r = iqs - 128 * n;
331
+ int l = r / 8;
332
+
333
+ __global const float *y = yy + 128 * n + l;
334
+ __global const uint8_t *q = x[ib].qs + 32 * n + l;
335
+ __global const uint8_t *s = x[ib].scales + 8 * n;
336
+
337
+ const float dall = vload_half(0, &x[ib].d);
338
+ const float dmin = vload_half(0, &x[ib].dmin);
339
+
340
+ float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
341
+ + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
342
+ + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
343
+ + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
344
+ + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
345
+ + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
346
+ + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
347
+ + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
348
+
349
+ *result = sum;
350
+ }
351
+
352
+ void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
353
+
354
+ const uint32_t kmask1 = 0x03030303;
355
+ const uint32_t kmask2 = 0x0f0f0f0f;
356
+
357
+ uint32_t aux[3];
358
+ uint32_t utmp[4];
359
+
360
+ int n = iqs/128;
361
+ int r = iqs - 128*n;
362
+ int l = r/8;
363
+
364
+ __global const float * y = yy + 128*n + l;
365
+ __global const uint8_t * q = x[ib].qs + 32*n + l;
366
+ __global const uint8_t * hm = x[ib].hmask + l;
367
+ const int8_t * s = (const int8_t *)utmp + 8*n;
368
+
369
+ aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
370
+ aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
371
+ aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
372
+
373
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
374
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
375
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
376
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
377
+
378
+ const float dall = vload_half(0, &x[ib].d);
379
+ const uint8_t m = 1 << (4*n);
380
+
381
+ float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
382
+ + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
383
+ + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
384
+ + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
385
+ + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
386
+ + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
387
+ + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
388
+ + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
389
+
390
+ *result = sum * dall;
391
+
392
+ }
393
+
394
+ void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
395
+
396
+ const int j = iqs / 64; // j is in 0...3
397
+ const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
398
+ const int is = 2*j; // is is in 0...6 in steps of 2
399
+
400
+ __global const float * y = yy + 64*j + ir;
401
+ __global const uint8_t * q = x[ib].qs + 32*j + ir;
402
+
403
+ const float dall = vload_half(0, &x[ib].d);
404
+ const float dmin = vload_half(0, &x[ib].dmin);
405
+
406
+ uint8_t sc, m;
407
+ get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
408
+ const float d1 = dall * sc;
409
+ const float m1 = dmin * m;
410
+ get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
411
+ const float d2 = dall * sc;
412
+ const float m2 = dmin * m;
413
+
414
+ float sum = 0;
415
+ for (int k = 0; k < 4; ++k) {
416
+ sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
417
+ sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
418
+ }
419
+
420
+ *result = sum;
421
+ }
422
+
423
+ void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
424
+
425
+ const int j = iqs / 64;
426
+ const int ir = (iqs - 64*j)/2;
427
+ const int is = 2*j;
428
+
429
+ __global const float * y = yy + 64*j + ir;
430
+ __global const uint8_t * ql = x[ib].qs + 32*j + ir;
431
+ __global const uint8_t * qh = x[ib].qh + ir;
432
+
433
+ const float dall = vload_half(0, &x[ib].d);
434
+ const float dmin = vload_half(0, &x[ib].dmin);
435
+
436
+ uint8_t sc, m;
437
+ get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
438
+ const float d1 = dall * sc;
439
+ const float m1 = dmin * m;
440
+ get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
441
+ const float d2 = dall * sc;
442
+ const float m2 = dmin * m;
443
+
444
+ uint8_t hm = 1 << is;
445
+ float sum = 0;
446
+ for (int k = 0; k < 4; ++k) {
447
+ sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
448
+ }
449
+ hm <<= 1;
450
+ for (int k = 0; k < 4; ++k) {
451
+ sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
452
+ }
453
+ *result = sum;
454
+
455
+ }
456
+
457
+ void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
458
+
459
+
460
+ const int ip = iqs / 128; // 0 or 1
461
+ const int il = (iqs - 128*ip)/8; // 0...15
462
+ const int is = 8*ip;
463
+
464
+ __global const float * y = yy + 128*ip + il;
465
+
466
+ const float d = vload_half(0, &x[ib].d);
467
+
468
+ __global const uint8_t * ql = x[ib].ql + 64*ip + il;
469
+ __global const uint8_t * qh = x[ib].qh + 32*ip + il;
470
+ __global const int8_t * sc = x[ib].scales + is;
471
+
472
+ *result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
473
+ + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
474
+ + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
475
+ + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
476
+ + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
477
+ + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
478
+ + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
479
+ + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
480
+
481
+ }
482
+
134
483
  );
135
484
 
485
+
136
486
  std::string dequant_template = MULTILINE_QUOTE(
137
487
  __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
138
488
  const int i = get_group_id(0)*get_local_size(0) + get_local_id(0)*2;
@@ -160,7 +510,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
160
510
  std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
161
511
  __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
162
512
  const int block_size = get_local_size(0);
163
- const int row = get_global_id(0) / block_size;
513
+ const int row = get_group_id(0);
164
514
  const int tid = get_local_id(0);
165
515
 
166
516
  const uint qk = QUANT_K;
@@ -199,6 +549,45 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
199
549
  }
200
550
  );
201
551
 
552
+ std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
553
+ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
554
+ const int block_size = get_local_size(0);
555
+ const int row = get_group_id(0);
556
+ const int tid = get_local_id(0);
557
+
558
+ const int iter_stride = 256;
559
+ const int vals_per_iter = iter_stride / block_size;
560
+ const int num_blocks_per_row = ncols / 256;
561
+ const int ib0 = row*num_blocks_per_row;
562
+
563
+ tmp[tid] = 0;
564
+
565
+ for (int i = 0; i < ncols; i += iter_stride) {
566
+ const int col = i + vals_per_iter*tid;
567
+ const int ib = ib0 + col/256; // x block index
568
+ const int iqs = col%256; // x quant index
569
+ const int iybs = col - col%256; // y block start index
570
+
571
+ // dequantize
572
+ float v;
573
+ DOT_KERNEL(x, ib, iqs, y + iybs, &v);
574
+ tmp[tid] += v;
575
+ }
576
+
577
+ // sum up partial sums and write back result
578
+ barrier(CLK_LOCAL_MEM_FENCE);
579
+ for (int s=block_size/2; s>0; s>>=1) {
580
+ if (tid < s) {
581
+ tmp[tid] += tmp[tid + s];
582
+ }
583
+ barrier(CLK_LOCAL_MEM_FENCE);
584
+ }
585
+ if (tid == 0) {
586
+ dst[row] = tmp[0];
587
+ }
588
+ }
589
+ );
590
+
202
591
  std::string mul_template = MULTILINE_QUOTE(
203
592
  __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
204
593
  const int i = get_group_id(0)*get_local_size(0) + get_local_id(0);
@@ -260,6 +649,18 @@ std::array<std::string, 2> mul_str_values = {
260
649
  "mul_f32", "float"
261
650
  };
262
651
 
652
// Placeholders in dequant_mul_mat_vec_k_template, substituted in this order.
std::array<std::string, 3> dmmv_k_str_keys = {"KERNEL_NAME", "X_TYPE", "DOT_KERNEL"};

// One row of substitutions per K-quant type; each row lines up with
// dmmv_k_str_keys: kernel name, x element type, per-block dot helper.
std::array<std::string, 15> dmmv_k_str_values = {
    "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
    "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
    "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
    "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
    "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K"
};
663
+
263
664
  std::string& replace(std::string& s, const std::string& from, const std::string& to) {
264
665
  size_t pos = 0;
265
666
  while ((pos = s.find(from, pos)) != std::string::npos) {
@@ -289,6 +690,14 @@ std::string generate_kernels() {
289
690
  }
290
691
  src << mul_kernel << '\n';
291
692
  }
693
+ for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
694
+ std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
695
+ for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
696
+ replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
697
+ }
698
+ src << dmmv_k_kernel << '\n';
699
+ }
700
+
292
701
  return src.str();
293
702
  }
294
703
 
@@ -300,6 +709,8 @@ static cl_program program;
300
709
  static cl_kernel convert_row_f16_cl;
301
710
  static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl;
302
711
  static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl;
712
+ static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl, dequantize_block_q5_k_cl, dequantize_block_q6_k_cl;
713
+ static cl_kernel dequantize_mul_mat_vec_q2_K_cl, dequantize_mul_mat_vec_q3_K_cl, dequantize_mul_mat_vec_q4_K_cl, dequantize_mul_mat_vec_q5_K_cl, dequantize_mul_mat_vec_q6_K_cl;
303
714
  static cl_kernel mul_f32_cl;
304
715
  static bool fp16_support;
305
716
 
@@ -529,6 +940,12 @@ void ggml_cl_init(void) {
529
940
  CL_CHECK((dequantize_row_q5_0_cl = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
530
941
  CL_CHECK((dequantize_row_q5_1_cl = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
531
942
  CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
943
+ CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
944
+ CL_CHECK((dequantize_block_q2_k_cl = clCreateKernel(program, "dequantize_block_q2_K", &err), err));
945
+ CL_CHECK((dequantize_block_q3_k_cl = clCreateKernel(program, "dequantize_block_q3_K", &err), err));
946
+ CL_CHECK((dequantize_block_q4_k_cl = clCreateKernel(program, "dequantize_block_q4_K", &err), err));
947
+ CL_CHECK((dequantize_block_q5_k_cl = clCreateKernel(program, "dequantize_block_q5_K", &err), err));
948
+ CL_CHECK((dequantize_block_q6_k_cl = clCreateKernel(program, "dequantize_block_q6_K", &err), err));
532
949
 
533
950
  // dequant mul mat kernel
534
951
  CL_CHECK((dequantize_mul_mat_vec_q4_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_0", &err), err));
@@ -537,6 +954,11 @@ void ggml_cl_init(void) {
537
954
  CL_CHECK((dequantize_mul_mat_vec_q5_1_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_1", &err), err));
538
955
  CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err));
539
956
  CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));
957
+ CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err));
958
+ CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
959
+ CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K", &err), err));
960
+ CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
961
+ CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err));
540
962
 
541
963
  // mul kernel
542
964
  CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
@@ -554,6 +976,16 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
554
976
  return &dequantize_row_q5_1_cl;
555
977
  case GGML_TYPE_Q8_0:
556
978
  return &dequantize_row_q8_0_cl;
979
+ case GGML_TYPE_Q2_K:
980
+ return &dequantize_block_q2_k_cl;
981
+ case GGML_TYPE_Q3_K:
982
+ return &dequantize_block_q3_k_cl;
983
+ case GGML_TYPE_Q4_K:
984
+ return &dequantize_block_q4_k_cl;
985
+ case GGML_TYPE_Q5_K:
986
+ return &dequantize_block_q5_k_cl;
987
+ case GGML_TYPE_Q6_K:
988
+ return &dequantize_block_q6_k_cl;
557
989
  case GGML_TYPE_F16:
558
990
  return &convert_row_f16_cl;
559
991
  default:
@@ -561,6 +993,50 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
561
993
  }
562
994
  }
563
995
 
996
+ static size_t ggml_cl_global_denom(ggml_type type) {
997
+ switch (type) {
998
+ case GGML_TYPE_Q4_0:
999
+ case GGML_TYPE_Q4_1:
1000
+ case GGML_TYPE_Q5_0:
1001
+ case GGML_TYPE_Q5_1:
1002
+ case GGML_TYPE_Q8_0:
1003
+ return 1;
1004
+ case GGML_TYPE_Q2_K:
1005
+ case GGML_TYPE_Q3_K:
1006
+ return 4;
1007
+ case GGML_TYPE_Q4_K:
1008
+ return 8;
1009
+ case GGML_TYPE_Q5_K:
1010
+ case GGML_TYPE_Q6_K:
1011
+ return 4;
1012
+ case GGML_TYPE_F16:
1013
+ default:
1014
+ return 1;
1015
+ }
1016
+ }
1017
+
1018
+ static size_t ggml_cl_local_size(ggml_type type) {
1019
+ switch (type) {
1020
+ case GGML_TYPE_Q4_0:
1021
+ case GGML_TYPE_Q4_1:
1022
+ case GGML_TYPE_Q5_0:
1023
+ case GGML_TYPE_Q5_1:
1024
+ case GGML_TYPE_Q8_0:
1025
+ return 0;
1026
+ case GGML_TYPE_Q2_K:
1027
+ case GGML_TYPE_Q3_K:
1028
+ return 64;
1029
+ case GGML_TYPE_Q4_K:
1030
+ return 32;
1031
+ case GGML_TYPE_Q5_K:
1032
+ case GGML_TYPE_Q6_K:
1033
+ return 64;
1034
+ case GGML_TYPE_F16:
1035
+ default:
1036
+ return 0;
1037
+ }
1038
+ }
1039
+
564
1040
  static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) {
565
1041
  switch (type) {
566
1042
  case GGML_TYPE_Q4_0:
@@ -575,6 +1051,16 @@ static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) {
575
1051
  return &dequantize_mul_mat_vec_q8_0_cl;
576
1052
  case GGML_TYPE_F16:
577
1053
  return &convert_mul_mat_vec_f16_cl;
1054
+ case GGML_TYPE_Q2_K:
1055
+ return &dequantize_mul_mat_vec_q2_K_cl;
1056
+ case GGML_TYPE_Q3_K:
1057
+ return &dequantize_mul_mat_vec_q3_K_cl;
1058
+ case GGML_TYPE_Q4_K:
1059
+ return &dequantize_mul_mat_vec_q4_K_cl;
1060
+ case GGML_TYPE_Q5_K:
1061
+ return &dequantize_mul_mat_vec_q5_K_cl;
1062
+ case GGML_TYPE_Q6_K:
1063
+ return &dequantize_mul_mat_vec_q6_K_cl;
578
1064
  default:
579
1065
  return nullptr;
580
1066
  }
@@ -1017,6 +1503,9 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
1017
1503
  cl_kernel* dmmv = ggml_get_dequantize_mul_mat_vec_cl(type);
1018
1504
  GGML_ASSERT(to_fp32_cl != nullptr);
1019
1505
 
1506
+ const size_t global_denom = ggml_cl_global_denom(type);
1507
+ const size_t local = ggml_cl_local_size(type);
1508
+
1020
1509
  size_t ev_idx = 0;
1021
1510
  std::vector<cl_event> events;
1022
1511
 
@@ -1049,10 +1538,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
1049
1538
  CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
1050
1539
  } else { // general dequantization kernel + CLBlast matrix matrix multiplication
1051
1540
  // convert src0 to fp32 on device
1052
- const size_t global = x_ne;
1541
+ const size_t global = x_ne / global_denom;
1053
1542
  CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
1054
1543
  CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
1055
- CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
1544
+ CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
1056
1545
 
1057
1546
  // copy src1 to device
1058
1547
  CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));