llama_cpp 0.2.2 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +34 -0
- data/README.md +39 -6
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +3 -2
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +305 -133
- data/ext/llama_cpp/src/ggml-cuda.cu +367 -69
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +36 -30
- data/ext/llama_cpp/src/ggml-metal.metal +328 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +352 -175
- data/ext/llama_cpp/src/ggml.c +800 -303
- data/ext/llama_cpp/src/ggml.h +68 -5
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +262 -291
- data/ext/llama_cpp/src/llama.h +49 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +14 -17
- metadata +2 -3
- data/lib/llama_cpp/client.rb +0 -172
@@ -21,11 +21,19 @@
|
|
21
21
|
|
22
22
|
#define CL_DMMV_BLOCK_SIZE 32
|
23
23
|
|
24
|
+
#ifndef K_QUANTS_PER_ITERATION
|
25
|
+
#define K_QUANTS_PER_ITERATION 1
|
26
|
+
#else
|
27
|
+
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
|
28
|
+
#endif
|
29
|
+
|
24
30
|
#define MULTILINE_QUOTE(...) #__VA_ARGS__
|
25
31
|
static std::string program_source = MULTILINE_QUOTE(
|
26
32
|
|
27
33
|
typedef char int8_t;
|
28
34
|
typedef uchar uint8_t;
|
35
|
+
typedef short int16_t;
|
36
|
+
typedef ushort uint16_t;
|
29
37
|
typedef int int32_t;
|
30
38
|
typedef uint uint32_t;
|
31
39
|
|
@@ -175,7 +183,9 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
|
|
175
183
|
*v0 = vload_half(0, &x[ib + 0]);
|
176
184
|
*v1 = vload_half(0, &x[ib + 1]);
|
177
185
|
}
|
186
|
+
);
|
178
187
|
|
188
|
+
static std::string k_quants_source = MULTILINE_QUOTE(
|
179
189
|
inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
|
180
190
|
{
|
181
191
|
if (j < 4)
|
@@ -199,7 +209,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
|
|
199
209
|
const int is = 8 * n + l / 16;
|
200
210
|
|
201
211
|
const uint8_t q = x[i].qs[32 * n + l];
|
202
|
-
__global float *y = yy + i *
|
212
|
+
__global float *y = yy + i * QK_K + 128 * n;
|
203
213
|
|
204
214
|
const float dall = vload_half(0, &x[i].d);
|
205
215
|
const float dmin = vload_half(0, &x[i].dmin);
|
@@ -231,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
|
|
231
241
|
float d_all = vload_half(0, &x[i].d);
|
232
242
|
float dl = d_all * (us - 32);
|
233
243
|
|
234
|
-
__global float *y = yy + i *
|
244
|
+
__global float *y = yy + i * QK_K + 128 * n + 32 * j;
|
235
245
|
const __global uint8_t *q = x[i].qs + 32 * n;
|
236
246
|
const __global uint8_t *hm = x[i].hmask;
|
237
247
|
|
@@ -248,7 +258,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
|
|
248
258
|
const int is = 2 * il;
|
249
259
|
const int n = 4;
|
250
260
|
|
251
|
-
__global float *y = yy + i *
|
261
|
+
__global float *y = yy + i * QK_K + 64 * il + n * ir;
|
252
262
|
|
253
263
|
const float dall = vload_half(0, &x[i].d);
|
254
264
|
const float dmin = vload_half(0, &x[i].dmin);
|
@@ -277,7 +287,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
|
|
277
287
|
const int ir = tid % 16;
|
278
288
|
const int is = 2 * il;
|
279
289
|
|
280
|
-
__global float *y = yy + i *
|
290
|
+
__global float *y = yy + i * QK_K + 64 * il + 2 * ir;
|
281
291
|
|
282
292
|
const float dall = vload_half(0, &x[i].d);
|
283
293
|
const float dmin = vload_half(0, &x[i].dmin);
|
@@ -309,7 +319,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
|
|
309
319
|
const int il = tid - 32 * ip;
|
310
320
|
const int is = 8 * ip + il / 16;
|
311
321
|
|
312
|
-
__global float *y = yy + i *
|
322
|
+
__global float *y = yy + i * QK_K + 128 * ip + il;
|
313
323
|
|
314
324
|
const float d = vload_half(0, &x[i].d);
|
315
325
|
|
@@ -323,161 +333,383 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
|
|
323
333
|
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
|
324
334
|
}
|
325
335
|
|
336
|
+
__kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
|
326
337
|
|
327
|
-
|
338
|
+
const int row = get_group_id(0);
|
328
339
|
|
329
|
-
int
|
330
|
-
int
|
331
|
-
int l = r / 8;
|
340
|
+
const int num_blocks_per_row = ncols / QK_K;
|
341
|
+
const int ib0 = row*num_blocks_per_row;
|
332
342
|
|
333
|
-
__global const
|
334
|
-
__global const uint8_t *q = x[ib].qs + 32 * n + l;
|
335
|
-
__global const uint8_t *s = x[ib].scales + 8 * n;
|
343
|
+
__global const struct block_q2_K * x = xx + ib0;
|
336
344
|
|
337
|
-
const
|
338
|
-
const
|
345
|
+
const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
|
346
|
+
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
339
347
|
|
340
|
-
|
341
|
-
+ y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
|
342
|
-
+ y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
|
343
|
-
+ y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
|
344
|
-
+ y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
|
345
|
-
+ y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
|
346
|
-
+ y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
|
347
|
-
+ y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
|
348
|
+
const int step = 16/K_QUANTS_PER_ITERATION;
|
348
349
|
|
349
|
-
|
350
|
-
|
350
|
+
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
|
351
|
+
const int in = tid - step*im; // 0...15 or 0...7
|
351
352
|
|
352
|
-
|
353
|
+
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
|
354
|
+
const int q_offset = 32*im + l0;
|
355
|
+
const int s_offset = 8*im;
|
356
|
+
const int y_offset = 128*im + l0;
|
353
357
|
|
354
|
-
|
355
|
-
const uint32_t kmask2 = 0x0f0f0f0f;
|
358
|
+
tmp[16 * ix + tid] = 0;
|
356
359
|
|
357
|
-
uint32_t aux[
|
358
|
-
|
360
|
+
uint32_t aux[4];
|
361
|
+
const uint8_t * d = (const uint8_t *)aux;
|
362
|
+
const uint8_t * m = (const uint8_t *)(aux + 2);
|
359
363
|
|
360
|
-
int
|
361
|
-
int r = iqs - 128*n;
|
362
|
-
int l = r/8;
|
364
|
+
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
363
365
|
|
364
|
-
|
365
|
-
|
366
|
-
__global const uint8_t * hm = x[ib].hmask + l;
|
367
|
-
const int8_t * s = (const int8_t *)utmp + 8*n;
|
366
|
+
__global const float * y = yy + i * QK_K + y_offset;
|
367
|
+
__global const uint8_t * q = x[i].qs + q_offset;
|
368
368
|
|
369
|
-
|
370
|
-
|
371
|
-
aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
|
369
|
+
const float dall = vload_half(0, &x[i].d);
|
370
|
+
const float dmin = vload_half(0, &x[i].dmin);
|
372
371
|
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
372
|
+
__global const uint32_t * a = (__global const uint32_t *)(x[i].scales + s_offset);
|
373
|
+
aux[0] = a[0] & 0x0f0f0f0f;
|
374
|
+
aux[1] = a[1] & 0x0f0f0f0f;
|
375
|
+
aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
|
376
|
+
aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
|
377
377
|
|
378
|
-
|
379
|
-
|
378
|
+
float sum1 = 0, sum2 = 0;
|
379
|
+
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
|
380
|
+
sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
|
381
|
+
+ y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
|
382
|
+
+ y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
|
383
|
+
+ y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
|
384
|
+
+ y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
|
385
|
+
+ y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
|
386
|
+
+ y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
|
387
|
+
+y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
|
388
|
+
sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
|
389
|
+
+ y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
|
380
390
|
|
381
|
-
|
382
|
-
|
383
|
-
+ y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
|
384
|
-
+ y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
|
385
|
-
+ y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
|
386
|
-
+ y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
|
387
|
-
+ y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
|
388
|
-
+ y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
|
391
|
+
}
|
392
|
+
tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
|
389
393
|
|
390
|
-
|
394
|
+
}
|
391
395
|
|
396
|
+
// sum up partial sums and write back result
|
397
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
398
|
+
for (int s=16; s>0; s>>=1) {
|
399
|
+
if (tid < s) {
|
400
|
+
tmp[tid] += tmp[tid + s];
|
401
|
+
}
|
402
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
403
|
+
}
|
404
|
+
if (tid == 0) {
|
405
|
+
dst[row] = tmp[0];
|
406
|
+
}
|
392
407
|
}
|
393
408
|
|
394
|
-
void
|
409
|
+
__kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
|
410
|
+
const uint16_t kmask1 = 0x0303;
|
411
|
+
const uint16_t kmask2 = 0x0f0f;
|
412
|
+
|
413
|
+
const int row = get_group_id(0);
|
414
|
+
|
415
|
+
const int num_blocks_per_row = ncols / QK_K;
|
416
|
+
const int ib0 = row*num_blocks_per_row;
|
395
417
|
|
396
|
-
const
|
397
|
-
const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
|
398
|
-
const int is = 2*j; // is is in 0...6 in steps of 2
|
418
|
+
__global const struct block_q3_K * x = xx + ib0;
|
399
419
|
|
400
|
-
|
401
|
-
|
420
|
+
const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
421
|
+
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
402
422
|
|
403
|
-
const
|
404
|
-
const
|
423
|
+
const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
|
424
|
+
const int step = 16/K_QUANTS_PER_ITERATION;
|
425
|
+
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
|
426
|
+
const int in = tid - step*im; // 0....15 or 0...7
|
405
427
|
|
406
|
-
uint8_t
|
407
|
-
|
408
|
-
const
|
409
|
-
const
|
410
|
-
|
411
|
-
|
412
|
-
|
428
|
+
const uint8_t m = 1 << (4*im);
|
429
|
+
|
430
|
+
const int l0 = n*in; // 0...15 or 0...14 in steps of 2
|
431
|
+
const int q_offset = 32*im + l0;
|
432
|
+
const int y_offset = 128*im + l0;
|
433
|
+
|
434
|
+
uint16_t utmp[4];
|
435
|
+
const int8_t * s = (const int8_t *)utmp;
|
436
|
+
|
437
|
+
const uint16_t s_shift = 4*im;
|
438
|
+
|
439
|
+
tmp[16 * ix + tid] = 0;
|
440
|
+
|
441
|
+
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
442
|
+
|
443
|
+
__global const float * y = yy + i * QK_K + y_offset;
|
444
|
+
__global const uint8_t * q = x[i].qs + q_offset;
|
445
|
+
__global const uint8_t * h = x[i].hmask + l0;
|
446
|
+
|
447
|
+
__global const uint16_t * a = (__global const uint16_t *)x[i].scales;
|
448
|
+
utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
|
449
|
+
utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
|
450
|
+
utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
|
451
|
+
utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
|
452
|
+
|
453
|
+
const float d = vload_half(0, &x[i].d);
|
454
|
+
|
455
|
+
float sum = 0;
|
456
|
+
for (int l = 0; l < n; ++l) {
|
457
|
+
sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
|
458
|
+
+ y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
|
459
|
+
+ y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
|
460
|
+
+ y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
|
461
|
+
sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
|
462
|
+
+ y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
|
463
|
+
+ y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
|
464
|
+
+ y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
|
465
|
+
}
|
466
|
+
tmp[16 * ix + tid] += d * sum;
|
413
467
|
|
414
|
-
float sum = 0;
|
415
|
-
for (int k = 0; k < 4; ++k) {
|
416
|
-
sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
|
417
|
-
sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
|
418
468
|
}
|
419
469
|
|
420
|
-
|
470
|
+
// sum up partial sums and write back result
|
471
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
472
|
+
for (int s=16; s>0; s>>=1) {
|
473
|
+
if (tid < s) {
|
474
|
+
tmp[tid] += tmp[tid + s];
|
475
|
+
}
|
476
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
477
|
+
}
|
478
|
+
if (tid == 0) {
|
479
|
+
dst[row] = tmp[0];
|
480
|
+
}
|
421
481
|
}
|
422
482
|
|
423
|
-
void
|
483
|
+
__kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
|
424
484
|
|
425
|
-
|
426
|
-
const
|
427
|
-
const
|
485
|
+
//to rename it later, just to test now
|
486
|
+
const uint16_t kmask1 = 0x3f3f;
|
487
|
+
const uint16_t kmask2 = 0x0f0f;
|
488
|
+
const uint16_t kmask3 = 0xc0c0;
|
428
489
|
|
429
|
-
|
430
|
-
|
431
|
-
|
490
|
+
const int row = get_group_id(0);
|
491
|
+
const int num_blocks_per_row = ncols / QK_K;
|
492
|
+
const int ib0 = row*num_blocks_per_row;
|
432
493
|
|
433
|
-
const
|
434
|
-
const
|
494
|
+
const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
|
495
|
+
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;
|
435
496
|
|
436
|
-
|
437
|
-
|
438
|
-
const
|
439
|
-
const
|
440
|
-
|
441
|
-
|
442
|
-
const
|
497
|
+
const int step = 8/K_QUANTS_PER_ITERATION;
|
498
|
+
|
499
|
+
const int il = tid/step; // 0...3
|
500
|
+
const int ir = tid - step*il;// 0...3
|
501
|
+
const int n = 2*K_QUANTS_PER_ITERATION;
|
502
|
+
|
503
|
+
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
504
|
+
const int in = il%2;
|
505
|
+
|
506
|
+
const int l0 = n*(2*ir + in);
|
507
|
+
const int q_offset = 32*im + l0;
|
508
|
+
const int y_offset = 64*im + l0;
|
509
|
+
|
510
|
+
uint16_t aux[4];
|
511
|
+
const uint8_t * sc = (const uint8_t *)aux;
|
512
|
+
|
513
|
+
__global const struct block_q4_K * x = xx + ib0;
|
514
|
+
|
515
|
+
tmp[16 * ix + tid] = 0;
|
516
|
+
|
517
|
+
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
518
|
+
|
519
|
+
__global const uint8_t * q1 = x[i].qs + q_offset;
|
520
|
+
__global const uint8_t * q2 = q1 + 64;
|
521
|
+
__global const float * y1 = yy + i*QK_K + y_offset;
|
522
|
+
__global const float * y2 = y1 + 128;
|
523
|
+
|
524
|
+
const float dall = vload_half(0, &x[i].d);
|
525
|
+
const float dmin = vload_half(0, &x[i].dmin);
|
526
|
+
|
527
|
+
__global const uint16_t * a = (__global const uint16_t *)x[i].scales;
|
528
|
+
aux[0] = a[im+0] & kmask1;
|
529
|
+
aux[1] = a[im+2] & kmask1;
|
530
|
+
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
|
531
|
+
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
|
532
|
+
|
533
|
+
float4 s = (float4)(0.f);
|
534
|
+
float smin = 0;
|
535
|
+
for (int l = 0; l < n; ++l) {
|
536
|
+
s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
|
537
|
+
s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
|
538
|
+
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
|
539
|
+
}
|
540
|
+
tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
|
443
541
|
|
444
|
-
uint8_t hm = 1 << is;
|
445
|
-
float sum = 0;
|
446
|
-
for (int k = 0; k < 4; ++k) {
|
447
|
-
sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
|
448
542
|
}
|
449
|
-
|
450
|
-
|
451
|
-
|
543
|
+
|
544
|
+
// sum up partial sums and write back result
|
545
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
546
|
+
for (int s=16; s>0; s>>=1) {
|
547
|
+
if (tid < s) {
|
548
|
+
tmp[tid] += tmp[tid + s];
|
549
|
+
}
|
550
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
551
|
+
}
|
552
|
+
if (tid == 0) {
|
553
|
+
dst[row] = tmp[0];
|
554
|
+
}
|
555
|
+
}
|
556
|
+
|
557
|
+
__kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
|
558
|
+
|
559
|
+
const uint16_t kmask1 = 0x3f3f;
|
560
|
+
const uint16_t kmask2 = 0x0f0f;
|
561
|
+
const uint16_t kmask3 = 0xc0c0;
|
562
|
+
|
563
|
+
const int row = get_group_id(0);
|
564
|
+
const int num_blocks_per_row = ncols / QK_K;
|
565
|
+
const int ib0 = row*num_blocks_per_row;
|
566
|
+
|
567
|
+
const int tid = get_local_id(0)/2; // 0...15
|
568
|
+
const int ix = get_local_id(0)%2;
|
569
|
+
|
570
|
+
const int il = tid/4; // 0...3
|
571
|
+
const int ir = tid - 4*il;// 0...3
|
572
|
+
const int n = 2;
|
573
|
+
|
574
|
+
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
575
|
+
const int in = il%2;
|
576
|
+
|
577
|
+
const int l0 = n*(2*ir + in);
|
578
|
+
const int q_offset = 32*im + l0;
|
579
|
+
const int y_offset = 64*im + l0;
|
580
|
+
|
581
|
+
const uint8_t hm1 = 1 << (2*im);
|
582
|
+
const uint8_t hm2 = hm1 << 4;
|
583
|
+
|
584
|
+
uint16_t aux[4];
|
585
|
+
const uint8_t * sc = (const uint8_t *)aux;
|
586
|
+
|
587
|
+
__global const struct block_q5_K * x = xx + ib0;
|
588
|
+
|
589
|
+
tmp[16 * ix + tid] = 0;
|
590
|
+
|
591
|
+
for (int i = ix; i < num_blocks_per_row; i += 2) {
|
592
|
+
|
593
|
+
__global const uint8_t * ql1 = x[i].qs + q_offset;
|
594
|
+
__global const uint8_t * ql2 = ql1 + 64;
|
595
|
+
__global const uint8_t * qh = x[i].qh + l0;
|
596
|
+
__global const float * y1 = yy + i*QK_K + y_offset;
|
597
|
+
__global const float * y2 = y1 + 128;
|
598
|
+
|
599
|
+
const float dall = vload_half(0, &x[i].d);
|
600
|
+
const float dmin = vload_half(0, &x[i].dmin);
|
601
|
+
|
602
|
+
__global const uint16_t * a = (__global const uint16_t *)x[i].scales;
|
603
|
+
aux[0] = a[im+0] & kmask1;
|
604
|
+
aux[1] = a[im+2] & kmask1;
|
605
|
+
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
|
606
|
+
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
|
607
|
+
|
608
|
+
float4 sum = (float4)(0.f);
|
609
|
+
float smin = 0;
|
610
|
+
for (int l = 0; l < n; ++l) {
|
611
|
+
sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
|
612
|
+
+ y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
|
613
|
+
sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
|
614
|
+
+ y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
|
615
|
+
sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
|
616
|
+
+ y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
|
617
|
+
sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
|
618
|
+
+ y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
|
619
|
+
smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
|
620
|
+
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
|
621
|
+
}
|
622
|
+
tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
|
623
|
+
|
452
624
|
}
|
453
|
-
*result = sum;
|
454
625
|
|
626
|
+
// sum up partial sums and write back result
|
627
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
628
|
+
for (int s=16; s>0; s>>=1) {
|
629
|
+
if (tid < s) {
|
630
|
+
tmp[tid] += tmp[tid + s];
|
631
|
+
}
|
632
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
633
|
+
}
|
634
|
+
if (tid == 0) {
|
635
|
+
dst[row] = tmp[0];
|
636
|
+
}
|
455
637
|
}
|
456
638
|
|
457
|
-
void
|
639
|
+
__kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {
|
640
|
+
|
641
|
+
const int row = get_group_id(0);
|
458
642
|
|
643
|
+
const int num_blocks_per_row = ncols / QK_K;
|
644
|
+
const int ib0 = row*num_blocks_per_row;
|
459
645
|
|
460
|
-
const
|
461
|
-
const int il = (iqs - 128*ip)/8; // 0...15
|
462
|
-
const int is = 8*ip;
|
646
|
+
__global const struct block_q6_K * x = xx + ib0;
|
463
647
|
|
464
|
-
|
648
|
+
const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
649
|
+
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
|
465
650
|
|
466
|
-
const
|
651
|
+
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
|
467
652
|
|
468
|
-
|
469
|
-
|
470
|
-
|
653
|
+
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
|
654
|
+
const int in = tid - step*im; // 0...15 or 0...7
|
655
|
+
|
656
|
+
#if K_QUANTS_PER_ITERATION == 1
|
657
|
+
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
|
658
|
+
const int is = 0;
|
659
|
+
#else
|
660
|
+
const int l0 = 4 * in; // 0, 4, 8, ..., 28
|
661
|
+
const int is = in / 4;
|
662
|
+
#endif
|
663
|
+
const int ql_offset = 64*im + l0;
|
664
|
+
const int qh_offset = 32*im + l0;
|
665
|
+
const int s_offset = 8*im + is;
|
666
|
+
const int y_offset = 128*im + l0;
|
667
|
+
|
668
|
+
tmp[16 * ix + tid] = 0; // partial sum for thread in warp
|
669
|
+
|
670
|
+
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
671
|
+
|
672
|
+
__global const float * y = yy + i * QK_K + y_offset;
|
673
|
+
__global const uint8_t * ql = x[i].ql + ql_offset;
|
674
|
+
__global const uint8_t * qh = x[i].qh + qh_offset;
|
675
|
+
__global const int8_t * s = x[i].scales + s_offset;
|
676
|
+
|
677
|
+
const float d = vload_half(0, &x[i].d);
|
678
|
+
|
679
|
+
#if K_QUANTS_PER_ITERATION == 1
|
680
|
+
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
|
681
|
+
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
|
682
|
+
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
|
683
|
+
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
|
684
|
+
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
|
685
|
+
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
|
686
|
+
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
|
687
|
+
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
|
688
|
+
tmp[16 * ix + tid] += sum;
|
689
|
+
#else
|
690
|
+
float sum = 0;
|
691
|
+
for (int l = 0; l < 4; ++l) {
|
692
|
+
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
|
693
|
+
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
|
694
|
+
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
|
695
|
+
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
|
696
|
+
}
|
697
|
+
tmp[16 * ix + tid] += sum;
|
698
|
+
#endif
|
471
699
|
|
472
|
-
|
473
|
-
+ y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
|
474
|
-
+ y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
|
475
|
-
+ y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
|
476
|
-
+ y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
|
477
|
-
+ y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
|
478
|
-
+ y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
|
479
|
-
+ y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
|
700
|
+
}
|
480
701
|
|
702
|
+
// sum up partial sums and write back result
|
703
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
704
|
+
for (int s=16; s>0; s>>=1) {
|
705
|
+
if (tid < s) {
|
706
|
+
tmp[tid] += tmp[tid + s];
|
707
|
+
}
|
708
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
709
|
+
}
|
710
|
+
if (tid == 0) {
|
711
|
+
dst[row] = tmp[0];
|
712
|
+
}
|
481
713
|
}
|
482
714
|
|
483
715
|
);
|
@@ -549,44 +781,6 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
|
|
549
781
|
}
|
550
782
|
);
|
551
783
|
|
552
|
-
std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
|
553
|
-
__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
|
554
|
-
const int block_size = get_local_size(0);
|
555
|
-
const int row = get_group_id(0);
|
556
|
-
const int tid = get_local_id(0);
|
557
|
-
|
558
|
-
const int iter_stride = 256;
|
559
|
-
const int vals_per_iter = iter_stride / block_size;
|
560
|
-
const int num_blocks_per_row = ncols / 256;
|
561
|
-
const int ib0 = row*num_blocks_per_row;
|
562
|
-
|
563
|
-
tmp[tid] = 0;
|
564
|
-
|
565
|
-
for (int i = 0; i < ncols; i += iter_stride) {
|
566
|
-
const int col = i + vals_per_iter*tid;
|
567
|
-
const int ib = ib0 + col/256; // x block index
|
568
|
-
const int iqs = col%256; // x quant index
|
569
|
-
const int iybs = col - col%256; // y block start index
|
570
|
-
|
571
|
-
// dequantize
|
572
|
-
float v;
|
573
|
-
DOT_KERNEL(x, ib, iqs, y + iybs, &v);
|
574
|
-
tmp[tid] += v;
|
575
|
-
}
|
576
|
-
|
577
|
-
// sum up partial sums and write back result
|
578
|
-
barrier(CLK_LOCAL_MEM_FENCE);
|
579
|
-
for (int s=block_size/2; s>0; s>>=1) {
|
580
|
-
if (tid < s) {
|
581
|
-
tmp[tid] += tmp[tid + s];
|
582
|
-
}
|
583
|
-
barrier(CLK_LOCAL_MEM_FENCE);
|
584
|
-
}
|
585
|
-
if (tid == 0) {
|
586
|
-
dst[row] = tmp[0];
|
587
|
-
}
|
588
|
-
}
|
589
|
-
);
|
590
784
|
|
591
785
|
std::string mul_template = MULTILINE_QUOTE(
|
592
786
|
__kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
|
@@ -649,18 +843,6 @@ std::array<std::string, 2> mul_str_values = {
|
|
649
843
|
"mul_f32", "float"
|
650
844
|
};
|
651
845
|
|
652
|
-
std::array<std::string, 3> dmmv_k_str_keys = {
|
653
|
-
"KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
|
654
|
-
};
|
655
|
-
|
656
|
-
std::array<std::string, 15> dmmv_k_str_values = {
|
657
|
-
"dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
|
658
|
-
"dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
|
659
|
-
"dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
|
660
|
-
"dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
|
661
|
-
"dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
|
662
|
-
};
|
663
|
-
|
664
846
|
std::string& replace(std::string& s, const std::string& from, const std::string& to) {
|
665
847
|
size_t pos = 0;
|
666
848
|
while ((pos = s.find(from, pos)) != std::string::npos) {
|
@@ -673,6 +855,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
|
|
673
855
|
std::string generate_kernels() {
|
674
856
|
std::stringstream src;
|
675
857
|
src << program_source << '\n';
|
858
|
+
src << k_quants_source << '\n';
|
676
859
|
for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
|
677
860
|
std::string dequant_kernel = dequant_template;
|
678
861
|
std::string dmmv_kernel = dequant_mul_mat_vec_template;
|
@@ -690,13 +873,6 @@ std::string generate_kernels() {
|
|
690
873
|
}
|
691
874
|
src << mul_kernel << '\n';
|
692
875
|
}
|
693
|
-
for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
|
694
|
-
std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
|
695
|
-
for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
|
696
|
-
replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
|
697
|
-
}
|
698
|
-
src << dmmv_k_kernel << '\n';
|
699
|
-
}
|
700
876
|
|
701
877
|
return src.str();
|
702
878
|
}
|
@@ -729,10 +905,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
|
|
729
905
|
exit(1);
|
730
906
|
}
|
731
907
|
|
732
|
-
|
733
|
-
"-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1"
|
908
|
+
std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
|
909
|
+
"-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
|
910
|
+
"-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
|
734
911
|
|
735
|
-
err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
|
912
|
+
err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
|
736
913
|
if(err < 0) {
|
737
914
|
|
738
915
|
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
|