llama_cpp 0.3.0 → 0.3.2

@@ -21,11 +21,19 @@

  #define CL_DMMV_BLOCK_SIZE 32

+ #ifndef K_QUANTS_PER_ITERATION
+ #define K_QUANTS_PER_ITERATION 1
+ #else
+ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+ #endif
+
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
  static std::string program_source = MULTILINE_QUOTE(

  typedef char int8_t;
  typedef uchar uint8_t;
+ typedef short int16_t;
+ typedef ushort uint16_t;
  typedef int int32_t;
  typedef uint uint32_t;

@@ -175,7 +183,9 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
  *v0 = vload_half(0, &x[ib + 0]);
  *v1 = vload_half(0, &x[ib + 1]);
  }
+ );

+ static std::string k_quants_source = MULTILINE_QUOTE(
  inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
  {
  if (j < 4)
@@ -199,7 +209,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
  const int is = 8 * n + l / 16;

  const uint8_t q = x[i].qs[32 * n + l];
- __global float *y = yy + i * 256 + 128 * n;
+ __global float *y = yy + i * QK_K + 128 * n;

  const float dall = vload_half(0, &x[i].d);
  const float dmin = vload_half(0, &x[i].dmin);
@@ -231,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
  float d_all = vload_half(0, &x[i].d);
  float dl = d_all * (us - 32);

- __global float *y = yy + i * 256 + 128 * n + 32 * j;
+ __global float *y = yy + i * QK_K + 128 * n + 32 * j;
  const __global uint8_t *q = x[i].qs + 32 * n;
  const __global uint8_t *hm = x[i].hmask;

@@ -248,7 +258,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
  const int is = 2 * il;
  const int n = 4;

- __global float *y = yy + i * 256 + 64 * il + n * ir;
+ __global float *y = yy + i * QK_K + 64 * il + n * ir;

  const float dall = vload_half(0, &x[i].d);
  const float dmin = vload_half(0, &x[i].dmin);
@@ -277,7 +287,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
  const int ir = tid % 16;
  const int is = 2 * il;

- __global float *y = yy + i * 256 + 64 * il + 2 * ir;
+ __global float *y = yy + i * QK_K + 64 * il + 2 * ir;

  const float dall = vload_half(0, &x[i].d);
  const float dmin = vload_half(0, &x[i].dmin);
@@ -309,7 +319,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
  const int il = tid - 32 * ip;
  const int is = 8 * ip + il / 16;

- __global float *y = yy + i * 256 + 128 * ip + il;
+ __global float *y = yy + i * QK_K + 128 * ip + il;

  const float d = vload_half(0, &x[i].d);

@@ -323,161 +333,387 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
  }

+ __kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {

- void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+ const int row = get_group_id(0);

- int n = iqs / 128;
- int r = iqs - 128 * n;
- int l = r / 8;
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;

- __global const float *y = yy + 128 * n + l;
- __global const uint8_t *q = x[ib].qs + 32 * n + l;
- __global const uint8_t *s = x[ib].scales + 8 * n;
+ __global const struct block_q2_K * x = xx + ib0;

- const float dall = vload_half(0, &x[ib].d);
- const float dmin = vload_half(0, &x[ib].dmin);
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1

- float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
- + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
- + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
- + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
- + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
- + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
- + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
- + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
+ const int step = 16/K_QUANTS_PER_ITERATION;

- *result = sum;
- }
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0...15 or 0...7

- void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
+ const int q_offset = 32*im + l0;
+ const int s_offset = 8*im;
+ const int y_offset = 128*im + l0;

- const uint32_t kmask1 = 0x03030303;
- const uint32_t kmask2 = 0x0f0f0f0f;
+ tmp[16 * ix + tid] = 0;

- uint32_t aux[3];
- uint32_t utmp[4];
+ uint32_t aux[4];
+ const uint8_t * d = (const uint8_t *)aux;
+ const uint8_t * m = (const uint8_t *)(aux + 2);

- int n = iqs/128;
- int r = iqs - 128*n;
- int l = r/8;
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

- __global const float * y = yy + 128*n + l;
- __global const uint8_t * q = x[ib].qs + 32*n + l;
- __global const uint8_t * hm = x[ib].hmask + l;
- const int8_t * s = (const int8_t *)utmp + 8*n;
+ __global const float * y = yy + i * QK_K + y_offset;
+ __global const uint8_t * q = x[i].qs + q_offset;

- aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
- aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
- aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
+ const float dall = vload_half(0, &x[i].d);
+ const float dmin = vload_half(0, &x[i].dmin);

- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+ __global const uint32_t * a = (__global const uint32_t *)(x[i].scales + s_offset);
+ aux[0] = a[0] & 0x0f0f0f0f;
+ aux[1] = a[1] & 0x0f0f0f0f;
+ aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+ aux[3] = (a[1] >> 4) & 0x0f0f0f0f;

- const float dall = vload_half(0, &x[ib].d);
- const uint8_t m = 1 << (4*n);
+ float sum1 = 0, sum2 = 0;
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+ sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+ + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+ + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+ + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+ + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+ +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+ sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+ + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];

- float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
- + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
- + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
- + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
- + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
- + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
- + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
- + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
+ }
+ tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;

- *result = sum * dall;
+ }

+ // sum up partial sums and write back result
+ barrier(CLK_LOCAL_MEM_FENCE);
+ for (int s=16; s>0; s>>=1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+ if (tid == 0) {
+ dst[row] = tmp[0];
+ }
  }

- void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+ __kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+ const uint16_t kmask1 = 0x0303;
+ const uint16_t kmask2 = 0x0f0f;

- const int j = iqs / 64; // j is in 0...3
- const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
- const int is = 2*j; // is is in 0...6 in steps of 2
+ const int row = get_group_id(0);

- __global const float * y = yy + 64*j + ir;
- __global const uint8_t * q = x[ib].qs + 32*j + ir;
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;

- const float dall = vload_half(0, &x[ib].d);
- const float dmin = vload_half(0, &x[ib].dmin);
+ __global const struct block_q3_K * x = xx + ib0;

- uint8_t sc, m;
- get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
- const float d1 = dall * sc;
- const float m1 = dmin * m;
- get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
- const float d2 = dall * sc;
- const float m2 = dmin * m;
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1
+
+ const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
+ const int step = 16/K_QUANTS_PER_ITERATION;
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0....15 or 0...7
+
+ const uint8_t m = 1 << (4*im);
+
+ const int l0 = n*in; // 0...15 or 0...14 in steps of 2
+ const int q_offset = 32*im + l0;
+ const int y_offset = 128*im + l0;
+
+ uint16_t utmp[4];
+ const int8_t * s = (const int8_t *)utmp;
+
+ const uint16_t s_shift = 4*im;
+
+ tmp[16 * ix + tid] = 0;
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ __global const float * y = yy + i * QK_K + y_offset;
+ __global const uint8_t * q = x[i].qs + q_offset;
+ __global const uint8_t * h = x[i].hmask + l0;
+
+ __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+ utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+ utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+ utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+ utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+ const float d = vload_half(0, &x[i].d);
+
+ float sum = 0;
+ for (int l = 0; l < n; ++l) {
+ sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+ + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+ + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+ + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+ sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+ + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+ + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+ + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+ }
+ tmp[16 * ix + tid] += d * sum;

- float sum = 0;
- for (int k = 0; k < 4; ++k) {
- sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
- sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
  }

- *result = sum;
+ // sum up partial sums and write back result
+ barrier(CLK_LOCAL_MEM_FENCE);
+ for (int s=16; s>0; s>>=1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+ if (tid == 0) {
+ dst[row] = tmp[0];
+ }
  }

- void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+ __kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {

- const int j = iqs / 64;
- const int ir = (iqs - 64*j)/2;
- const int is = 2*j;
+ //to rename it later, just to test now
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;

- __global const float * y = yy + 64*j + ir;
- __global const uint8_t * ql = x[ib].qs + 32*j + ir;
- __global const uint8_t * qh = x[ib].qh + ir;
+ const int row = get_group_id(0);
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;

- const float dall = vload_half(0, &x[ib].d);
- const float dmin = vload_half(0, &x[ib].dmin);
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;

- uint8_t sc, m;
- get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
- const float d1 = dall * sc;
- const float m1 = dmin * m;
- get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
- const float d2 = dall * sc;
- const float m2 = dmin * m;
+ const int step = 8/K_QUANTS_PER_ITERATION;
+
+ const int il = tid/step; // 0...3
+ const int ir = tid - step*il;// 0...3
+ const int n = 2*K_QUANTS_PER_ITERATION;
+
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const int in = il%2;
+
+ const int l0 = n*(2*ir + in);
+ const int q_offset = 32*im + l0;
+ const int y_offset = 64*im + l0;
+
+ uint16_t aux[4];
+ const uint8_t * sc = (const uint8_t *)aux;
+
+ __global const struct block_q4_K * x = xx + ib0;
+
+ tmp[16 * ix + tid] = 0;
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ __global const uint8_t * q1 = x[i].qs + q_offset;
+ __global const uint8_t * q2 = q1 + 64;
+ __global const float * y1 = yy + i*QK_K + y_offset;
+ __global const float * y2 = y1 + 128;
+
+ const float dall = vload_half(0, &x[i].d);
+ const float dmin = vload_half(0, &x[i].dmin);
+
+ __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+ aux[0] = a[im+0] & kmask1;
+ aux[1] = a[im+2] & kmask1;
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+ float4 s = (float4)(0.f);
+ float smin = 0;
+ for (int l = 0; l < n; ++l) {
+ s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
+ s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+ }
+ tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;

- uint8_t hm = 1 << is;
- float sum = 0;
- for (int k = 0; k < 4; ++k) {
- sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
  }
- hm <<= 1;
- for (int k = 0; k < 4; ++k) {
- sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
+
+ // sum up partial sums and write back result
+ barrier(CLK_LOCAL_MEM_FENCE);
+ for (int s=16; s>0; s>>=1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+ if (tid == 0) {
+ dst[row] = tmp[0];
+ }
+ }
+
+ __kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+
+ const uint16_t kmask1 = 0x3f3f;
+ const uint16_t kmask2 = 0x0f0f;
+ const uint16_t kmask3 = 0xc0c0;
+
+ const int row = get_group_id(0);
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;
+
+ const int tid = get_local_id(0)/2; // 0...15
+ const int ix = get_local_id(0)%2;
+
+ const int il = tid/4; // 0...3
+ const int ir = tid - 4*il;// 0...3
+ const int n = 2;
+
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+ const int in = il%2;
+
+ const int l0 = n*(2*ir + in);
+ const int q_offset = 32*im + l0;
+ const int y_offset = 64*im + l0;
+
+ const uint8_t hm1 = 1 << (2*im);
+ const uint8_t hm2 = hm1 << 4;
+
+ uint16_t aux[4];
+ const uint8_t * sc = (const uint8_t *)aux;
+
+ __global const struct block_q5_K * x = xx + ib0;
+
+ tmp[16 * ix + tid] = 0;
+
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+ __global const uint8_t * ql1 = x[i].qs + q_offset;
+ __global const uint8_t * ql2 = ql1 + 64;
+ __global const uint8_t * qh = x[i].qh + l0;
+ __global const float * y1 = yy + i*QK_K + y_offset;
+ __global const float * y2 = y1 + 128;
+
+ const float dall = vload_half(0, &x[i].d);
+ const float dmin = vload_half(0, &x[i].dmin);
+
+ __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+ aux[0] = a[im+0] & kmask1;
+ aux[1] = a[im+2] & kmask1;
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+ float4 sum = (float4)(0.f);
+ float smin = 0;
+ for (int l = 0; l < n; ++l) {
+ sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+ sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+ sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+ sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+ smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+ }
+ tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+
  }
- *result = sum;

+ // sum up partial sums and write back result
+ barrier(CLK_LOCAL_MEM_FENCE);
+ for (int s=16; s>0; s>>=1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+ if (tid == 0) {
+ dst[row] = tmp[0];
+ }
  }

- void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
+ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {

+ const int row = get_group_id(0);

- const int ip = iqs / 128; // 0 or 1
- const int il = (iqs - 128*ip)/8; // 0...15
- const int is = 8*ip;
+ const int num_blocks_per_row = ncols / QK_K;
+ const int ib0 = row*num_blocks_per_row;

- __global const float * y = yy + 128*ip + il;
+ __global const struct block_q6_K * x = xx + ib0;

- const float d = vload_half(0, &x[ib].d);
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
+
+ const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
+ const int in = tid - step*im; // 0...15 or 0...7
+
+ \n#if K_QUANTS_PER_ITERATION == 1\n
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
+ const int is = 0;
+
+ \n#else\n

- __global const uint8_t * ql = x[ib].ql + 64*ip + il;
- __global const uint8_t * qh = x[ib].qh + 32*ip + il;
- __global const int8_t * sc = x[ib].scales + is;
+ const int l0 = 4 * in; // 0, 4, 8, ..., 28
+ const int is = in / 4;

- *result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
- + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
- + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
- + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
- + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
- + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
- + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
- + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
+ \n#endif\n

+ const int ql_offset = 64*im + l0;
+ const int qh_offset = 32*im + l0;
+ const int s_offset = 8*im + is;
+ const int y_offset = 128*im + l0;
+
+ tmp[16 * ix + tid] = 0; // partial sum for thread in warp
+
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+ __global const float * y = yy + i * QK_K + y_offset;
+ __global const uint8_t * ql = x[i].ql + ql_offset;
+ __global const uint8_t * qh = x[i].qh + qh_offset;
+ __global const int8_t * s = x[i].scales + s_offset;
+
+ const float d = vload_half(0, &x[i].d);
+
+ \n#if K_QUANTS_PER_ITERATION == 1\n
+ float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+ +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+ tmp[16 * ix + tid] += sum;
+ \n#else\n
+ float sum = 0;
+ for (int l = 0; l < 4; ++l) {
+ sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+ }
+ tmp[16 * ix + tid] += sum;
+ \n#endif\n
+
+ }
+
+ // sum up partial sums and write back result
+ barrier(CLK_LOCAL_MEM_FENCE);
+ for (int s=16; s>0; s>>=1) {
+ if (tid < s) {
+ tmp[tid] += tmp[tid + s];
+ }
+ barrier(CLK_LOCAL_MEM_FENCE);
+ }
+ if (tid == 0) {
+ dst[row] = tmp[0];
+ }
  }

  );
@@ -549,44 +785,6 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
  }
  );

- std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
- __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
- const int block_size = get_local_size(0);
- const int row = get_group_id(0);
- const int tid = get_local_id(0);
-
- const int iter_stride = 256;
- const int vals_per_iter = iter_stride / block_size;
- const int num_blocks_per_row = ncols / 256;
- const int ib0 = row*num_blocks_per_row;
-
- tmp[tid] = 0;
-
- for (int i = 0; i < ncols; i += iter_stride) {
- const int col = i + vals_per_iter*tid;
- const int ib = ib0 + col/256; // x block index
- const int iqs = col%256; // x quant index
- const int iybs = col - col%256; // y block start index
-
- // dequantize
- float v;
- DOT_KERNEL(x, ib, iqs, y + iybs, &v);
- tmp[tid] += v;
- }
-
- // sum up partial sums and write back result
- barrier(CLK_LOCAL_MEM_FENCE);
- for (int s=block_size/2; s>0; s>>=1) {
- if (tid < s) {
- tmp[tid] += tmp[tid + s];
- }
- barrier(CLK_LOCAL_MEM_FENCE);
- }
- if (tid == 0) {
- dst[row] = tmp[0];
- }
- }
- );

  std::string mul_template = MULTILINE_QUOTE(
  __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
@@ -649,18 +847,6 @@ std::array<std::string, 2> mul_str_values = {
  "mul_f32", "float"
  };

- std::array<std::string, 3> dmmv_k_str_keys = {
- "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
- };
-
- std::array<std::string, 15> dmmv_k_str_values = {
- "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
- "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
- "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
- "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
- "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
- };
-
  std::string& replace(std::string& s, const std::string& from, const std::string& to) {
  size_t pos = 0;
  while ((pos = s.find(from, pos)) != std::string::npos) {
@@ -673,6 +859,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
  std::string generate_kernels() {
  std::stringstream src;
  src << program_source << '\n';
+ src << k_quants_source << '\n';
  for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
  std::string dequant_kernel = dequant_template;
  std::string dmmv_kernel = dequant_mul_mat_vec_template;
@@ -690,13 +877,6 @@ std::string generate_kernels() {
  }
  src << mul_kernel << '\n';
  }
- for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
- std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
- for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
- replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
- }
- src << dmmv_k_kernel << '\n';
- }

  return src.str();
  }
@@ -729,10 +909,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
  exit(1);
  }

- const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
- "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1";
+ std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
+ "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
+ "-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);

- err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
+ err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
  if(err < 0) {

  clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
@@ -1199,7 +1380,7 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
  const int64_t ne00 = src0->ne[0];
  const int64_t ne01 = src0->ne[1];
  const int64_t ne02 = src0->ne[2];
- const int64_t ne03 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
  const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
  const int64_t ne10 = src1->ne[0];
  const int64_t ne11 = src1->ne[1];