llama_cpp 0.2.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,11 @@
15
15
 
16
16
  #include "ggml.h"
17
17
 
18
- #define CL_DMMV_BLOCK_SIZE 32;
18
+ #if defined(_MSC_VER)
19
+ #pragma warning(disable: 4244 4267) // possible loss of data
20
+ #endif
21
+
22
+ #define CL_DMMV_BLOCK_SIZE 32
19
23
 
20
24
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
21
25
  static std::string program_source = MULTILINE_QUOTE(
@@ -59,6 +63,46 @@ struct __attribute__ ((packed)) block_q8_0
59
63
  int8_t qs[QK8_0];
60
64
  };
61
65
 
66
+ struct __attribute__((packed)) block_q2_K
67
+ {
68
+ uint8_t scales[16];
69
+ uint8_t qs[64];
70
+ half d;
71
+ half dmin;
72
+ };
73
+
74
+ struct __attribute__((packed)) block_q3_K
75
+ {
76
+ uint8_t hmask[32];
77
+ uint8_t qs[64];
78
+ uint8_t scales[12];
79
+ half d;
80
+ };
81
+
82
+ struct __attribute__((packed)) block_q4_K
83
+ {
84
+ half d;
85
+ half dmin;
86
+ uint8_t scales[12];
87
+ uint8_t qs[128];
88
+ };
89
+
90
+ struct __attribute__((packed)) block_q5_K
91
+ {
92
+ half d;
93
+ half dmin;
94
+ uint8_t scales[12];
95
+ uint8_t qh[32];
96
+ uint8_t qs[128];
97
+ };
98
+
99
+ struct __attribute__((packed)) block_q6_K
100
+ {
101
+ uint8_t ql[128];
102
+ uint8_t qh[64];
103
+ int8_t scales[16];
104
+ half d;
105
+ };
62
106
 
63
107
  __kernel void convert_fp16_to_fp32(__global half* x, __global float* y) {
64
108
  const uint i = get_global_id(0);
@@ -131,8 +175,314 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
131
175
  *v0 = vload_half(0, &x[ib + 0]);
132
176
  *v1 = vload_half(0, &x[ib + 1]);
133
177
  }
178
+
179
+ inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
180
+ {
181
+ if (j < 4)
182
+ {
183
+ *d = q[j] & 63;
184
+ *m = q[j + 4] & 63;
185
+ }
186
+ else
187
+ {
188
+ *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
189
+ *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
190
+ }
191
+ }
192
+
193
+ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy)
194
+ {
195
+ const int i = get_group_id(0);
196
+ const int tid = get_local_id(0);
197
+ const int n = tid / 32;
198
+ const int l = tid - 32 * n;
199
+ const int is = 8 * n + l / 16;
200
+
201
+ const uint8_t q = x[i].qs[32 * n + l];
202
+ __global float *y = yy + i * 256 + 128 * n;
203
+
204
+ const float dall = vload_half(0, &x[i].d);
205
+ const float dmin = vload_half(0, &x[i].dmin);
206
+
207
+ y[l + 0] = dall * (x[i].scales[is + 0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is + 0] >> 4);
208
+ y[l + 32] = dall * (x[i].scales[is + 2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is + 2] >> 4);
209
+ y[l + 64] = dall * (x[i].scales[is + 4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is + 4] >> 4);
210
+ y[l + 96] = dall * (x[i].scales[is + 6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is + 6] >> 4);
211
+ }
212
+
213
+ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy)
214
+ {
215
+ int r = get_local_id(0) / 4;
216
+ int i = get_group_id(0);
217
+ int tid = r / 2;
218
+ int is0 = r % 2;
219
+ int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
220
+ int n = tid / 4;
221
+ int j = tid - 4 * n;
222
+
223
+ uint8_t m = 1 << (4 * n + j);
224
+ int is = 8 * n + 2 * j + is0;
225
+ int shift = 2 * j;
226
+
227
+ int8_t us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4)
228
+ : is < 8 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4)
229
+ : is < 12 ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4)
230
+ : (x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4);
231
+ float d_all = vload_half(0, &x[i].d);
232
+ float dl = d_all * (us - 32);
233
+
234
+ __global float *y = yy + i * 256 + 128 * n + 32 * j;
235
+ const __global uint8_t *q = x[i].qs + 32 * n;
236
+ const __global uint8_t *hm = x[i].hmask;
237
+
238
+ for (int l = l0; l < l0 + 4; ++l)
239
+ y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
240
+ }
241
+
242
+ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy)
243
+ {
244
+ const int i = get_group_id(0);
245
+ const int tid = get_local_id(0);
246
+ const int il = tid / 8;
247
+ const int ir = tid % 8;
248
+ const int is = 2 * il;
249
+ const int n = 4;
250
+
251
+ __global float *y = yy + i * 256 + 64 * il + n * ir;
252
+
253
+ const float dall = vload_half(0, &x[i].d);
254
+ const float dmin = vload_half(0, &x[i].dmin);
255
+
256
+ __global const uint8_t *q = x[i].qs + 32 * il + n * ir;
257
+
258
+ uint8_t sc, m;
259
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
260
+ float d1 = dall * sc;
261
+ float m1 = dmin * m;
262
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
263
+ float d2 = dall * sc;
264
+ float m2 = dmin * m;
265
+ for (int l = 0; l < n; ++l)
266
+ {
267
+ y[l + 0] = d1 * (q[l] & 0xF) - m1;
268
+ y[l + 32] = d2 * (q[l] >> 4) - m2;
269
+ }
270
+ }
271
+
272
+ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy)
273
+ {
274
+ const int i = get_group_id(0);
275
+ const int tid = get_local_id(0);
276
+ const int il = tid / 16;
277
+ const int ir = tid % 16;
278
+ const int is = 2 * il;
279
+
280
+ __global float *y = yy + i * 256 + 64 * il + 2 * ir;
281
+
282
+ const float dall = vload_half(0, &x[i].d);
283
+ const float dmin = vload_half(0, &x[i].dmin);
284
+
285
+ __global const uint8_t *ql = x[i].qs + 32 * il + 2 * ir;
286
+ __global const uint8_t *qh = x[i].qh + 2 * ir;
287
+
288
+ uint8_t sc, m;
289
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
290
+ const float d1 = dall * sc;
291
+ const float m1 = dmin * m;
292
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
293
+ const float d2 = dall * sc;
294
+ const float m2 = dmin * m;
295
+
296
+ uint8_t hm = 1 << (2 * il);
297
+ y[0] = d1 * ((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0)) - m1;
298
+ y[1] = d1 * ((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0)) - m1;
299
+ hm <<= 1;
300
+ y[32] = d2 * ((ql[0] >> 4) + (qh[0] & hm ? 16 : 0)) - m2;
301
+ y[33] = d2 * ((ql[1] >> 4) + (qh[1] & hm ? 16 : 0)) - m2;
302
+ }
303
+
304
+ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy)
305
+ {
306
+ const int i = get_group_id(0);
307
+ const int tid = get_local_id(0);
308
+ const int ip = tid / 32;
309
+ const int il = tid - 32 * ip;
310
+ const int is = 8 * ip + il / 16;
311
+
312
+ __global float *y = yy + i * 256 + 128 * ip + il;
313
+
314
+ const float d = vload_half(0, &x[i].d);
315
+
316
+ __global const uint8_t *ql = x[i].ql + 64 * ip + il;
317
+ const uint8_t qh = x[i].qh[32 * ip + il];
318
+ __global const int8_t *sc = x[i].scales + is;
319
+
320
+ y[0] = d * sc[0] * ((int8_t)((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
321
+ y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
322
+ y[64] = d * sc[4] * ((int8_t)((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
323
+ y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
324
+ }
325
+
326
+
327
+ void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
328
+
329
+ int n = iqs / 128;
330
+ int r = iqs - 128 * n;
331
+ int l = r / 8;
332
+
333
+ __global const float *y = yy + 128 * n + l;
334
+ __global const uint8_t *q = x[ib].qs + 32 * n + l;
335
+ __global const uint8_t *s = x[ib].scales + 8 * n;
336
+
337
+ const float dall = vload_half(0, &x[ib].d);
338
+ const float dmin = vload_half(0, &x[ib].dmin);
339
+
340
+ float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
341
+ + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
342
+ + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
343
+ + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
344
+ + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
345
+ + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
346
+ + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
347
+ + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
348
+
349
+ *result = sum;
350
+ }
351
+
352
+ void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
353
+
354
+ const uint32_t kmask1 = 0x03030303;
355
+ const uint32_t kmask2 = 0x0f0f0f0f;
356
+
357
+ uint32_t aux[3];
358
+ uint32_t utmp[4];
359
+
360
+ int n = iqs/128;
361
+ int r = iqs - 128*n;
362
+ int l = r/8;
363
+
364
+ __global const float * y = yy + 128*n + l;
365
+ __global const uint8_t * q = x[ib].qs + 32*n + l;
366
+ __global const uint8_t * hm = x[ib].hmask + l;
367
+ const int8_t * s = (const int8_t *)utmp + 8*n;
368
+
369
+ aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
370
+ aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
371
+ aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
372
+
373
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
374
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
375
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
376
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
377
+
378
+ const float dall = vload_half(0, &x[ib].d);
379
+ const uint8_t m = 1 << (4*n);
380
+
381
+ float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
382
+ + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
383
+ + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
384
+ + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
385
+ + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
386
+ + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
387
+ + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
388
+ + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
389
+
390
+ *result = sum * dall;
391
+
392
+ }
393
+
394
+ void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
395
+
396
+ const int j = iqs / 64; // j is in 0...3
397
+ const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
398
+ const int is = 2*j; // is is in 0...6 in steps of 2
399
+
400
+ __global const float * y = yy + 64*j + ir;
401
+ __global const uint8_t * q = x[ib].qs + 32*j + ir;
402
+
403
+ const float dall = vload_half(0, &x[ib].d);
404
+ const float dmin = vload_half(0, &x[ib].dmin);
405
+
406
+ uint8_t sc, m;
407
+ get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
408
+ const float d1 = dall * sc;
409
+ const float m1 = dmin * m;
410
+ get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
411
+ const float d2 = dall * sc;
412
+ const float m2 = dmin * m;
413
+
414
+ float sum = 0;
415
+ for (int k = 0; k < 4; ++k) {
416
+ sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
417
+ sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
418
+ }
419
+
420
+ *result = sum;
421
+ }
422
+
423
+ void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
424
+
425
+ const int j = iqs / 64;
426
+ const int ir = (iqs - 64*j)/2;
427
+ const int is = 2*j;
428
+
429
+ __global const float * y = yy + 64*j + ir;
430
+ __global const uint8_t * ql = x[ib].qs + 32*j + ir;
431
+ __global const uint8_t * qh = x[ib].qh + ir;
432
+
433
+ const float dall = vload_half(0, &x[ib].d);
434
+ const float dmin = vload_half(0, &x[ib].dmin);
435
+
436
+ uint8_t sc, m;
437
+ get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
438
+ const float d1 = dall * sc;
439
+ const float m1 = dmin * m;
440
+ get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
441
+ const float d2 = dall * sc;
442
+ const float m2 = dmin * m;
443
+
444
+ uint8_t hm = 1 << is;
445
+ float sum = 0;
446
+ for (int k = 0; k < 4; ++k) {
447
+ sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
448
+ }
449
+ hm <<= 1;
450
+ for (int k = 0; k < 4; ++k) {
451
+ sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
452
+ }
453
+ *result = sum;
454
+
455
+ }
456
+
457
+ void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
458
+
459
+
460
+ const int ip = iqs / 128; // 0 or 1
461
+ const int il = (iqs - 128*ip)/8; // 0...15
462
+ const int is = 8*ip;
463
+
464
+ __global const float * y = yy + 128*ip + il;
465
+
466
+ const float d = vload_half(0, &x[ib].d);
467
+
468
+ __global const uint8_t * ql = x[ib].ql + 64*ip + il;
469
+ __global const uint8_t * qh = x[ib].qh + 32*ip + il;
470
+ __global const int8_t * sc = x[ib].scales + is;
471
+
472
+ *result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
473
+ + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
474
+ + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
475
+ + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
476
+ + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
477
+ + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
478
+ + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
479
+ + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
480
+
481
+ }
482
+
134
483
  );
135
484
 
485
+
136
486
  std::string dequant_template = MULTILINE_QUOTE(
137
487
  __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
138
488
  const int i = get_group_id(0)*get_local_size(0) + get_local_id(0)*2;
@@ -160,7 +510,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
160
510
  std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
161
511
  __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
162
512
  const int block_size = get_local_size(0);
163
- const int row = get_global_id(0) / block_size;
513
+ const int row = get_group_id(0);
164
514
  const int tid = get_local_id(0);
165
515
 
166
516
  const uint qk = QUANT_K;
@@ -199,6 +549,45 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
199
549
  }
200
550
  );
201
551
 
552
+ std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
553
+ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
554
+ const int block_size = get_local_size(0);
555
+ const int row = get_group_id(0);
556
+ const int tid = get_local_id(0);
557
+
558
+ const int iter_stride = 256;
559
+ const int vals_per_iter = iter_stride / block_size;
560
+ const int num_blocks_per_row = ncols / 256;
561
+ const int ib0 = row*num_blocks_per_row;
562
+
563
+ tmp[tid] = 0;
564
+
565
+ for (int i = 0; i < ncols; i += iter_stride) {
566
+ const int col = i + vals_per_iter*tid;
567
+ const int ib = ib0 + col/256; // x block index
568
+ const int iqs = col%256; // x quant index
569
+ const int iybs = col - col%256; // y block start index
570
+
571
+ // dequantize
572
+ float v;
573
+ DOT_KERNEL(x, ib, iqs, y + iybs, &v);
574
+ tmp[tid] += v;
575
+ }
576
+
577
+ // sum up partial sums and write back result
578
+ barrier(CLK_LOCAL_MEM_FENCE);
579
+ for (int s=block_size/2; s>0; s>>=1) {
580
+ if (tid < s) {
581
+ tmp[tid] += tmp[tid + s];
582
+ }
583
+ barrier(CLK_LOCAL_MEM_FENCE);
584
+ }
585
+ if (tid == 0) {
586
+ dst[row] = tmp[0];
587
+ }
588
+ }
589
+ );
590
+
202
591
  std::string mul_template = MULTILINE_QUOTE(
203
592
  __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
204
593
  const int i = get_group_id(0)*get_local_size(0) + get_local_id(0);
@@ -260,6 +649,18 @@ std::array<std::string, 2> mul_str_values = {
260
649
  "mul_f32", "float"
261
650
  };
262
651
 
652
+ std::array<std::string, 3> dmmv_k_str_keys = {
653
+ "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
654
+ };
655
+
656
+ std::array<std::string, 15> dmmv_k_str_values = {
657
+ "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
658
+ "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
659
+ "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
660
+ "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
661
+ "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
662
+ };
663
+
263
664
  std::string& replace(std::string& s, const std::string& from, const std::string& to) {
264
665
  size_t pos = 0;
265
666
  while ((pos = s.find(from, pos)) != std::string::npos) {
@@ -289,6 +690,14 @@ std::string generate_kernels() {
289
690
  }
290
691
  src << mul_kernel << '\n';
291
692
  }
693
+ for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
694
+ std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
695
+ for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
696
+ replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
697
+ }
698
+ src << dmmv_k_kernel << '\n';
699
+ }
700
+
292
701
  return src.str();
293
702
  }
294
703
 
@@ -300,6 +709,8 @@ static cl_program program;
300
709
  static cl_kernel convert_row_f16_cl;
301
710
  static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl;
302
711
  static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl;
712
+ static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl, dequantize_block_q5_k_cl, dequantize_block_q6_k_cl;
713
+ static cl_kernel dequantize_mul_mat_vec_q2_K_cl, dequantize_mul_mat_vec_q3_K_cl, dequantize_mul_mat_vec_q4_K_cl, dequantize_mul_mat_vec_q5_K_cl, dequantize_mul_mat_vec_q6_K_cl;
303
714
  static cl_kernel mul_f32_cl;
304
715
  static bool fp16_support;
305
716
 
@@ -529,6 +940,12 @@ void ggml_cl_init(void) {
529
940
  CL_CHECK((dequantize_row_q5_0_cl = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
530
941
  CL_CHECK((dequantize_row_q5_1_cl = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
531
942
  CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
943
+ CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
944
+ CL_CHECK((dequantize_block_q2_k_cl = clCreateKernel(program, "dequantize_block_q2_K", &err), err));
945
+ CL_CHECK((dequantize_block_q3_k_cl = clCreateKernel(program, "dequantize_block_q3_K", &err), err));
946
+ CL_CHECK((dequantize_block_q4_k_cl = clCreateKernel(program, "dequantize_block_q4_K", &err), err));
947
+ CL_CHECK((dequantize_block_q5_k_cl = clCreateKernel(program, "dequantize_block_q5_K", &err), err));
948
+ CL_CHECK((dequantize_block_q6_k_cl = clCreateKernel(program, "dequantize_block_q6_K", &err), err));
532
949
 
533
950
  // dequant mul mat kernel
534
951
  CL_CHECK((dequantize_mul_mat_vec_q4_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_0", &err), err));
@@ -537,6 +954,11 @@ void ggml_cl_init(void) {
537
954
  CL_CHECK((dequantize_mul_mat_vec_q5_1_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_1", &err), err));
538
955
  CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err));
539
956
  CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));
957
+ CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err));
958
+ CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
959
+ CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K", &err), err));
960
+ CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
961
+ CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err));
540
962
 
541
963
  // mul kernel
542
964
  CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
@@ -554,6 +976,16 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
554
976
  return &dequantize_row_q5_1_cl;
555
977
  case GGML_TYPE_Q8_0:
556
978
  return &dequantize_row_q8_0_cl;
979
+ case GGML_TYPE_Q2_K:
980
+ return &dequantize_block_q2_k_cl;
981
+ case GGML_TYPE_Q3_K:
982
+ return &dequantize_block_q3_k_cl;
983
+ case GGML_TYPE_Q4_K:
984
+ return &dequantize_block_q4_k_cl;
985
+ case GGML_TYPE_Q5_K:
986
+ return &dequantize_block_q5_k_cl;
987
+ case GGML_TYPE_Q6_K:
988
+ return &dequantize_block_q6_k_cl;
557
989
  case GGML_TYPE_F16:
558
990
  return &convert_row_f16_cl;
559
991
  default:
@@ -561,6 +993,50 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
561
993
  }
562
994
  }
563
995
 
996
+ static size_t ggml_cl_global_denom(ggml_type type) {
997
+ switch (type) {
998
+ case GGML_TYPE_Q4_0:
999
+ case GGML_TYPE_Q4_1:
1000
+ case GGML_TYPE_Q5_0:
1001
+ case GGML_TYPE_Q5_1:
1002
+ case GGML_TYPE_Q8_0:
1003
+ return 1;
1004
+ case GGML_TYPE_Q2_K:
1005
+ case GGML_TYPE_Q3_K:
1006
+ return 4;
1007
+ case GGML_TYPE_Q4_K:
1008
+ return 8;
1009
+ case GGML_TYPE_Q5_K:
1010
+ case GGML_TYPE_Q6_K:
1011
+ return 4;
1012
+ case GGML_TYPE_F16:
1013
+ default:
1014
+ return 1;
1015
+ }
1016
+ }
1017
+
1018
+ static size_t ggml_cl_local_size(ggml_type type) {
1019
+ switch (type) {
1020
+ case GGML_TYPE_Q4_0:
1021
+ case GGML_TYPE_Q4_1:
1022
+ case GGML_TYPE_Q5_0:
1023
+ case GGML_TYPE_Q5_1:
1024
+ case GGML_TYPE_Q8_0:
1025
+ return 0;
1026
+ case GGML_TYPE_Q2_K:
1027
+ case GGML_TYPE_Q3_K:
1028
+ return 64;
1029
+ case GGML_TYPE_Q4_K:
1030
+ return 32;
1031
+ case GGML_TYPE_Q5_K:
1032
+ case GGML_TYPE_Q6_K:
1033
+ return 64;
1034
+ case GGML_TYPE_F16:
1035
+ default:
1036
+ return 0;
1037
+ }
1038
+ }
1039
+
564
1040
  static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) {
565
1041
  switch (type) {
566
1042
  case GGML_TYPE_Q4_0:
@@ -575,6 +1051,16 @@ static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) {
575
1051
  return &dequantize_mul_mat_vec_q8_0_cl;
576
1052
  case GGML_TYPE_F16:
577
1053
  return &convert_mul_mat_vec_f16_cl;
1054
+ case GGML_TYPE_Q2_K:
1055
+ return &dequantize_mul_mat_vec_q2_K_cl;
1056
+ case GGML_TYPE_Q3_K:
1057
+ return &dequantize_mul_mat_vec_q3_K_cl;
1058
+ case GGML_TYPE_Q4_K:
1059
+ return &dequantize_mul_mat_vec_q4_K_cl;
1060
+ case GGML_TYPE_Q5_K:
1061
+ return &dequantize_mul_mat_vec_q5_K_cl;
1062
+ case GGML_TYPE_Q6_K:
1063
+ return &dequantize_mul_mat_vec_q6_K_cl;
578
1064
  default:
579
1065
  return nullptr;
580
1066
  }
@@ -1017,6 +1503,9 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
1017
1503
  cl_kernel* dmmv = ggml_get_dequantize_mul_mat_vec_cl(type);
1018
1504
  GGML_ASSERT(to_fp32_cl != nullptr);
1019
1505
 
1506
+ const size_t global_denom = ggml_cl_global_denom(type);
1507
+ const size_t local = ggml_cl_local_size(type);
1508
+
1020
1509
  size_t ev_idx = 0;
1021
1510
  std::vector<cl_event> events;
1022
1511
 
@@ -1049,10 +1538,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
1049
1538
  CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
1050
1539
  } else { // general dequantization kernel + CLBlast matrix matrix multiplication
1051
1540
  // convert src0 to fp32 on device
1052
- const size_t global = x_ne;
1541
+ const size_t global = x_ne / global_denom;
1053
1542
  CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
1054
1543
  CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
1055
- CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
1544
+ CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
1056
1545
 
1057
1546
  // copy src1 to device
1058
1547
  CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));