llama_cpp 0.2.2 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,11 +21,19 @@
21
21
 
22
22
  #define CL_DMMV_BLOCK_SIZE 32
23
23
 
24
+ #ifndef K_QUANTS_PER_ITERATION
25
+ #define K_QUANTS_PER_ITERATION 1
26
+ #else
27
+ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
28
+ #endif
29
+
24
30
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
25
31
  static std::string program_source = MULTILINE_QUOTE(
26
32
 
27
33
  typedef char int8_t;
28
34
  typedef uchar uint8_t;
35
+ typedef short int16_t;
36
+ typedef ushort uint16_t;
29
37
  typedef int int32_t;
30
38
  typedef uint uint32_t;
31
39
 
@@ -175,7 +183,9 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
175
183
  *v0 = vload_half(0, &x[ib + 0]);
176
184
  *v1 = vload_half(0, &x[ib + 1]);
177
185
  }
186
+ );
178
187
 
188
+ static std::string k_quants_source = MULTILINE_QUOTE(
179
189
  inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
180
190
  {
181
191
  if (j < 4)
@@ -199,7 +209,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
199
209
  const int is = 8 * n + l / 16;
200
210
 
201
211
  const uint8_t q = x[i].qs[32 * n + l];
202
- __global float *y = yy + i * 256 + 128 * n;
212
+ __global float *y = yy + i * QK_K + 128 * n;
203
213
 
204
214
  const float dall = vload_half(0, &x[i].d);
205
215
  const float dmin = vload_half(0, &x[i].dmin);
@@ -231,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
231
241
  float d_all = vload_half(0, &x[i].d);
232
242
  float dl = d_all * (us - 32);
233
243
 
234
- __global float *y = yy + i * 256 + 128 * n + 32 * j;
244
+ __global float *y = yy + i * QK_K + 128 * n + 32 * j;
235
245
  const __global uint8_t *q = x[i].qs + 32 * n;
236
246
  const __global uint8_t *hm = x[i].hmask;
237
247
 
@@ -248,7 +258,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
248
258
  const int is = 2 * il;
249
259
  const int n = 4;
250
260
 
251
- __global float *y = yy + i * 256 + 64 * il + n * ir;
261
+ __global float *y = yy + i * QK_K + 64 * il + n * ir;
252
262
 
253
263
  const float dall = vload_half(0, &x[i].d);
254
264
  const float dmin = vload_half(0, &x[i].dmin);
@@ -277,7 +287,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
277
287
  const int ir = tid % 16;
278
288
  const int is = 2 * il;
279
289
 
280
- __global float *y = yy + i * 256 + 64 * il + 2 * ir;
290
+ __global float *y = yy + i * QK_K + 64 * il + 2 * ir;
281
291
 
282
292
  const float dall = vload_half(0, &x[i].d);
283
293
  const float dmin = vload_half(0, &x[i].dmin);
@@ -309,7 +319,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
309
319
  const int il = tid - 32 * ip;
310
320
  const int is = 8 * ip + il / 16;
311
321
 
312
- __global float *y = yy + i * 256 + 128 * ip + il;
322
+ __global float *y = yy + i * QK_K + 128 * ip + il;
313
323
 
314
324
  const float d = vload_half(0, &x[i].d);
315
325
 
@@ -323,161 +333,383 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
323
333
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
324
334
  }
325
335
 
336
+ __kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
326
337
 
327
- void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
338
+ const int row = get_group_id(0);
328
339
 
329
- int n = iqs / 128;
330
- int r = iqs - 128 * n;
331
- int l = r / 8;
340
+ const int num_blocks_per_row = ncols / QK_K;
341
+ const int ib0 = row*num_blocks_per_row;
332
342
 
333
- __global const float *y = yy + 128 * n + l;
334
- __global const uint8_t *q = x[ib].qs + 32 * n + l;
335
- __global const uint8_t *s = x[ib].scales + 8 * n;
343
+ __global const struct block_q2_K * x = xx + ib0;
336
344
 
337
- const float dall = vload_half(0, &x[ib].d);
338
- const float dmin = vload_half(0, &x[ib].dmin);
345
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
346
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1
339
347
 
340
- float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
341
- + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
342
- + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
343
- + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
344
- + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
345
- + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
346
- + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
347
- + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
348
+ const int step = 16/K_QUANTS_PER_ITERATION;
348
349
 
349
- *result = sum;
350
- }
350
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
351
+ const int in = tid - step*im; // 0...15 or 0...7
351
352
 
352
- void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
353
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
354
+ const int q_offset = 32*im + l0;
355
+ const int s_offset = 8*im;
356
+ const int y_offset = 128*im + l0;
353
357
 
354
- const uint32_t kmask1 = 0x03030303;
355
- const uint32_t kmask2 = 0x0f0f0f0f;
358
+ tmp[16 * ix + tid] = 0;
356
359
 
357
- uint32_t aux[3];
358
- uint32_t utmp[4];
360
+ uint32_t aux[4];
361
+ const uint8_t * d = (const uint8_t *)aux;
362
+ const uint8_t * m = (const uint8_t *)(aux + 2);
359
363
 
360
- int n = iqs/128;
361
- int r = iqs - 128*n;
362
- int l = r/8;
364
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
363
365
 
364
- __global const float * y = yy + 128*n + l;
365
- __global const uint8_t * q = x[ib].qs + 32*n + l;
366
- __global const uint8_t * hm = x[ib].hmask + l;
367
- const int8_t * s = (const int8_t *)utmp + 8*n;
366
+ __global const float * y = yy + i * QK_K + y_offset;
367
+ __global const uint8_t * q = x[i].qs + q_offset;
368
368
 
369
- aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
370
- aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
371
- aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
369
+ const float dall = vload_half(0, &x[i].d);
370
+ const float dmin = vload_half(0, &x[i].dmin);
372
371
 
373
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
374
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
375
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
376
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
372
+ __global const uint32_t * a = (__global const uint32_t *)(x[i].scales + s_offset);
373
+ aux[0] = a[0] & 0x0f0f0f0f;
374
+ aux[1] = a[1] & 0x0f0f0f0f;
375
+ aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
376
+ aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
377
377
 
378
- const float dall = vload_half(0, &x[ib].d);
379
- const uint8_t m = 1 << (4*n);
378
+ float sum1 = 0, sum2 = 0;
379
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
380
+ sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
381
+ + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
382
+ + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
383
+ + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
384
+ + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
385
+ + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
386
+ + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
387
+ +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
388
+ sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
389
+ + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
380
390
 
381
- float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
382
- + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
383
- + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
384
- + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
385
- + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
386
- + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
387
- + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
388
- + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
391
+ }
392
+ tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
389
393
 
390
- *result = sum * dall;
394
+ }
391
395
 
396
+ // sum up partial sums and write back result
397
+ barrier(CLK_LOCAL_MEM_FENCE);
398
+ for (int s=16; s>0; s>>=1) {
399
+ if (tid < s) {
400
+ tmp[tid] += tmp[tid + s];
401
+ }
402
+ barrier(CLK_LOCAL_MEM_FENCE);
403
+ }
404
+ if (tid == 0) {
405
+ dst[row] = tmp[0];
406
+ }
392
407
  }
393
408
 
394
- void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
409
+ __kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
410
+ const uint16_t kmask1 = 0x0303;
411
+ const uint16_t kmask2 = 0x0f0f;
412
+
413
+ const int row = get_group_id(0);
414
+
415
+ const int num_blocks_per_row = ncols / QK_K;
416
+ const int ib0 = row*num_blocks_per_row;
395
417
 
396
- const int j = iqs / 64; // j is in 0...3
397
- const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
398
- const int is = 2*j; // is is in 0...6 in steps of 2
418
+ __global const struct block_q3_K * x = xx + ib0;
399
419
 
400
- __global const float * y = yy + 64*j + ir;
401
- __global const uint8_t * q = x[ib].qs + 32*j + ir;
420
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
421
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1
402
422
 
403
- const float dall = vload_half(0, &x[ib].d);
404
- const float dmin = vload_half(0, &x[ib].dmin);
423
+ const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
424
+ const int step = 16/K_QUANTS_PER_ITERATION;
425
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
426
+ const int in = tid - step*im; // 0....15 or 0...7
405
427
 
406
- uint8_t sc, m;
407
- get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
408
- const float d1 = dall * sc;
409
- const float m1 = dmin * m;
410
- get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
411
- const float d2 = dall * sc;
412
- const float m2 = dmin * m;
428
+ const uint8_t m = 1 << (4*im);
429
+
430
+ const int l0 = n*in; // 0...15 or 0...14 in steps of 2
431
+ const int q_offset = 32*im + l0;
432
+ const int y_offset = 128*im + l0;
433
+
434
+ uint16_t utmp[4];
435
+ const int8_t * s = (const int8_t *)utmp;
436
+
437
+ const uint16_t s_shift = 4*im;
438
+
439
+ tmp[16 * ix + tid] = 0;
440
+
441
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
442
+
443
+ __global const float * y = yy + i * QK_K + y_offset;
444
+ __global const uint8_t * q = x[i].qs + q_offset;
445
+ __global const uint8_t * h = x[i].hmask + l0;
446
+
447
+ __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
448
+ utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
449
+ utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
450
+ utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
451
+ utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
452
+
453
+ const float d = vload_half(0, &x[i].d);
454
+
455
+ float sum = 0;
456
+ for (int l = 0; l < n; ++l) {
457
+ sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
458
+ + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
459
+ + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
460
+ + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
461
+ sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
462
+ + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
463
+ + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
464
+ + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
465
+ }
466
+ tmp[16 * ix + tid] += d * sum;
413
467
 
414
- float sum = 0;
415
- for (int k = 0; k < 4; ++k) {
416
- sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
417
- sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
418
468
  }
419
469
 
420
- *result = sum;
470
+ // sum up partial sums and write back result
471
+ barrier(CLK_LOCAL_MEM_FENCE);
472
+ for (int s=16; s>0; s>>=1) {
473
+ if (tid < s) {
474
+ tmp[tid] += tmp[tid + s];
475
+ }
476
+ barrier(CLK_LOCAL_MEM_FENCE);
477
+ }
478
+ if (tid == 0) {
479
+ dst[row] = tmp[0];
480
+ }
421
481
  }
422
482
 
423
- void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
483
+ __kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
424
484
 
425
- const int j = iqs / 64;
426
- const int ir = (iqs - 64*j)/2;
427
- const int is = 2*j;
485
+ //to rename it later, just to test now
486
+ const uint16_t kmask1 = 0x3f3f;
487
+ const uint16_t kmask2 = 0x0f0f;
488
+ const uint16_t kmask3 = 0xc0c0;
428
489
 
429
- __global const float * y = yy + 64*j + ir;
430
- __global const uint8_t * ql = x[ib].qs + 32*j + ir;
431
- __global const uint8_t * qh = x[ib].qh + ir;
490
+ const int row = get_group_id(0);
491
+ const int num_blocks_per_row = ncols / QK_K;
492
+ const int ib0 = row*num_blocks_per_row;
432
493
 
433
- const float dall = vload_half(0, &x[ib].d);
434
- const float dmin = vload_half(0, &x[ib].dmin);
494
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
495
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;
435
496
 
436
- uint8_t sc, m;
437
- get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
438
- const float d1 = dall * sc;
439
- const float m1 = dmin * m;
440
- get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
441
- const float d2 = dall * sc;
442
- const float m2 = dmin * m;
497
+ const int step = 8/K_QUANTS_PER_ITERATION;
498
+
499
+ const int il = tid/step; // 0...3
500
+ const int ir = tid - step*il;// 0...3
501
+ const int n = 2*K_QUANTS_PER_ITERATION;
502
+
503
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
504
+ const int in = il%2;
505
+
506
+ const int l0 = n*(2*ir + in);
507
+ const int q_offset = 32*im + l0;
508
+ const int y_offset = 64*im + l0;
509
+
510
+ uint16_t aux[4];
511
+ const uint8_t * sc = (const uint8_t *)aux;
512
+
513
+ __global const struct block_q4_K * x = xx + ib0;
514
+
515
+ tmp[16 * ix + tid] = 0;
516
+
517
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
518
+
519
+ __global const uint8_t * q1 = x[i].qs + q_offset;
520
+ __global const uint8_t * q2 = q1 + 64;
521
+ __global const float * y1 = yy + i*QK_K + y_offset;
522
+ __global const float * y2 = y1 + 128;
523
+
524
+ const float dall = vload_half(0, &x[i].d);
525
+ const float dmin = vload_half(0, &x[i].dmin);
526
+
527
+ __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
528
+ aux[0] = a[im+0] & kmask1;
529
+ aux[1] = a[im+2] & kmask1;
530
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
531
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
532
+
533
+ float4 s = (float4)(0.f);
534
+ float smin = 0;
535
+ for (int l = 0; l < n; ++l) {
536
+ s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
537
+ s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
538
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
539
+ }
540
+ tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
443
541
 
444
- uint8_t hm = 1 << is;
445
- float sum = 0;
446
- for (int k = 0; k < 4; ++k) {
447
- sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
448
542
  }
449
- hm <<= 1;
450
- for (int k = 0; k < 4; ++k) {
451
- sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
543
+
544
+ // sum up partial sums and write back result
545
+ barrier(CLK_LOCAL_MEM_FENCE);
546
+ for (int s=16; s>0; s>>=1) {
547
+ if (tid < s) {
548
+ tmp[tid] += tmp[tid + s];
549
+ }
550
+ barrier(CLK_LOCAL_MEM_FENCE);
551
+ }
552
+ if (tid == 0) {
553
+ dst[row] = tmp[0];
554
+ }
555
+ }
556
+
557
+ __kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
558
+
559
+ const uint16_t kmask1 = 0x3f3f;
560
+ const uint16_t kmask2 = 0x0f0f;
561
+ const uint16_t kmask3 = 0xc0c0;
562
+
563
+ const int row = get_group_id(0);
564
+ const int num_blocks_per_row = ncols / QK_K;
565
+ const int ib0 = row*num_blocks_per_row;
566
+
567
+ const int tid = get_local_id(0)/2; // 0...15
568
+ const int ix = get_local_id(0)%2;
569
+
570
+ const int il = tid/4; // 0...3
571
+ const int ir = tid - 4*il;// 0...3
572
+ const int n = 2;
573
+
574
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
575
+ const int in = il%2;
576
+
577
+ const int l0 = n*(2*ir + in);
578
+ const int q_offset = 32*im + l0;
579
+ const int y_offset = 64*im + l0;
580
+
581
+ const uint8_t hm1 = 1 << (2*im);
582
+ const uint8_t hm2 = hm1 << 4;
583
+
584
+ uint16_t aux[4];
585
+ const uint8_t * sc = (const uint8_t *)aux;
586
+
587
+ __global const struct block_q5_K * x = xx + ib0;
588
+
589
+ tmp[16 * ix + tid] = 0;
590
+
591
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
592
+
593
+ __global const uint8_t * ql1 = x[i].qs + q_offset;
594
+ __global const uint8_t * ql2 = ql1 + 64;
595
+ __global const uint8_t * qh = x[i].qh + l0;
596
+ __global const float * y1 = yy + i*QK_K + y_offset;
597
+ __global const float * y2 = y1 + 128;
598
+
599
+ const float dall = vload_half(0, &x[i].d);
600
+ const float dmin = vload_half(0, &x[i].dmin);
601
+
602
+ __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
603
+ aux[0] = a[im+0] & kmask1;
604
+ aux[1] = a[im+2] & kmask1;
605
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
606
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
607
+
608
+ float4 sum = (float4)(0.f);
609
+ float smin = 0;
610
+ for (int l = 0; l < n; ++l) {
611
+ sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
612
+ + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
613
+ sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
614
+ + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
615
+ sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
616
+ + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
617
+ sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
618
+ + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
619
+ smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
620
+ + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
621
+ }
622
+ tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
623
+
452
624
  }
453
- *result = sum;
454
625
 
626
+ // sum up partial sums and write back result
627
+ barrier(CLK_LOCAL_MEM_FENCE);
628
+ for (int s=16; s>0; s>>=1) {
629
+ if (tid < s) {
630
+ tmp[tid] += tmp[tid + s];
631
+ }
632
+ barrier(CLK_LOCAL_MEM_FENCE);
633
+ }
634
+ if (tid == 0) {
635
+ dst[row] = tmp[0];
636
+ }
455
637
  }
456
638
 
457
- void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
639
+ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {
640
+
641
+ const int row = get_group_id(0);
458
642
 
643
+ const int num_blocks_per_row = ncols / QK_K;
644
+ const int ib0 = row*num_blocks_per_row;
459
645
 
460
- const int ip = iqs / 128; // 0 or 1
461
- const int il = (iqs - 128*ip)/8; // 0...15
462
- const int is = 8*ip;
646
+ __global const struct block_q6_K * x = xx + ib0;
463
647
 
464
- __global const float * y = yy + 128*ip + il;
648
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
649
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
465
650
 
466
- const float d = vload_half(0, &x[ib].d);
651
+ const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
467
652
 
468
- __global const uint8_t * ql = x[ib].ql + 64*ip + il;
469
- __global const uint8_t * qh = x[ib].qh + 32*ip + il;
470
- __global const int8_t * sc = x[ib].scales + is;
653
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
654
+ const int in = tid - step*im; // 0...15 or 0...7
655
+
656
+ #if K_QUANTS_PER_ITERATION == 1
657
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
658
+ const int is = 0;
659
+ #else
660
+ const int l0 = 4 * in; // 0, 4, 8, ..., 28
661
+ const int is = in / 4;
662
+ #endif
663
+ const int ql_offset = 64*im + l0;
664
+ const int qh_offset = 32*im + l0;
665
+ const int s_offset = 8*im + is;
666
+ const int y_offset = 128*im + l0;
667
+
668
+ tmp[16 * ix + tid] = 0; // partial sum for thread in warp
669
+
670
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
671
+
672
+ __global const float * y = yy + i * QK_K + y_offset;
673
+ __global const uint8_t * ql = x[i].ql + ql_offset;
674
+ __global const uint8_t * qh = x[i].qh + qh_offset;
675
+ __global const int8_t * s = x[i].scales + s_offset;
676
+
677
+ const float d = vload_half(0, &x[i].d);
678
+
679
+ #if K_QUANTS_PER_ITERATION == 1
680
+ float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
681
+ + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
682
+ + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
683
+ + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
684
+ + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
685
+ + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
686
+ + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
687
+ +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
688
+ tmp[16 * ix + tid] += sum;
689
+ #else
690
+ float sum = 0;
691
+ for (int l = 0; l < 4; ++l) {
692
+ sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
693
+ + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
694
+ + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
695
+ + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
696
+ }
697
+ tmp[16 * ix + tid] += sum;
698
+ #endif
471
699
 
472
- *result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
473
- + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
474
- + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
475
- + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
476
- + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
477
- + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
478
- + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
479
- + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
700
+ }
480
701
 
702
+ // sum up partial sums and write back result
703
+ barrier(CLK_LOCAL_MEM_FENCE);
704
+ for (int s=16; s>0; s>>=1) {
705
+ if (tid < s) {
706
+ tmp[tid] += tmp[tid + s];
707
+ }
708
+ barrier(CLK_LOCAL_MEM_FENCE);
709
+ }
710
+ if (tid == 0) {
711
+ dst[row] = tmp[0];
712
+ }
481
713
  }
482
714
 
483
715
  );
@@ -549,44 +781,6 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
549
781
  }
550
782
  );
551
783
 
552
- std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
553
- __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
554
- const int block_size = get_local_size(0);
555
- const int row = get_group_id(0);
556
- const int tid = get_local_id(0);
557
-
558
- const int iter_stride = 256;
559
- const int vals_per_iter = iter_stride / block_size;
560
- const int num_blocks_per_row = ncols / 256;
561
- const int ib0 = row*num_blocks_per_row;
562
-
563
- tmp[tid] = 0;
564
-
565
- for (int i = 0; i < ncols; i += iter_stride) {
566
- const int col = i + vals_per_iter*tid;
567
- const int ib = ib0 + col/256; // x block index
568
- const int iqs = col%256; // x quant index
569
- const int iybs = col - col%256; // y block start index
570
-
571
- // dequantize
572
- float v;
573
- DOT_KERNEL(x, ib, iqs, y + iybs, &v);
574
- tmp[tid] += v;
575
- }
576
-
577
- // sum up partial sums and write back result
578
- barrier(CLK_LOCAL_MEM_FENCE);
579
- for (int s=block_size/2; s>0; s>>=1) {
580
- if (tid < s) {
581
- tmp[tid] += tmp[tid + s];
582
- }
583
- barrier(CLK_LOCAL_MEM_FENCE);
584
- }
585
- if (tid == 0) {
586
- dst[row] = tmp[0];
587
- }
588
- }
589
- );
590
784
 
591
785
  std::string mul_template = MULTILINE_QUOTE(
592
786
  __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
@@ -649,18 +843,6 @@ std::array<std::string, 2> mul_str_values = {
649
843
  "mul_f32", "float"
650
844
  };
651
845
 
652
- std::array<std::string, 3> dmmv_k_str_keys = {
653
- "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
654
- };
655
-
656
- std::array<std::string, 15> dmmv_k_str_values = {
657
- "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
658
- "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
659
- "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
660
- "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
661
- "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
662
- };
663
-
664
846
  std::string& replace(std::string& s, const std::string& from, const std::string& to) {
665
847
  size_t pos = 0;
666
848
  while ((pos = s.find(from, pos)) != std::string::npos) {
@@ -673,6 +855,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
673
855
  std::string generate_kernels() {
674
856
  std::stringstream src;
675
857
  src << program_source << '\n';
858
+ src << k_quants_source << '\n';
676
859
  for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
677
860
  std::string dequant_kernel = dequant_template;
678
861
  std::string dmmv_kernel = dequant_mul_mat_vec_template;
@@ -690,13 +873,6 @@ std::string generate_kernels() {
690
873
  }
691
874
  src << mul_kernel << '\n';
692
875
  }
693
- for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
694
- std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
695
- for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
696
- replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
697
- }
698
- src << dmmv_k_kernel << '\n';
699
- }
700
876
 
701
877
  return src.str();
702
878
  }
@@ -729,10 +905,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
729
905
  exit(1);
730
906
  }
731
907
 
732
- const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
733
- "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1";
908
+ std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
909
+ "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
910
+ "-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
734
911
 
735
- err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
912
+ err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
736
913
  if(err < 0) {
737
914
 
738
915
  clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);