llama_cpp 0.15.2 → 0.15.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +61 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +8 -16
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +99 -40
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +44 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +133 -81
- data/vendor/tmp/llama.cpp/ggml-metal.metal +91 -434
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1962 -2443
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +248 -108
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +375 -657
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +204 -225
- data/vendor/tmp/llama.cpp/ggml.c +498 -836
- data/vendor/tmp/llama.cpp/ggml.h +57 -30
- data/vendor/tmp/llama.cpp/llama.cpp +1477 -859
- data/vendor/tmp/llama.cpp/llama.h +21 -8
- metadata +3 -3
@@ -168,6 +168,53 @@ kernel void kernel_div(
|
|
168
168
|
}
|
169
169
|
}
|
170
170
|
|
171
|
+
template<typename T>
|
172
|
+
kernel void kernel_repeat(
|
173
|
+
device const char * src0,
|
174
|
+
device char * dst,
|
175
|
+
constant int64_t & ne00,
|
176
|
+
constant int64_t & ne01,
|
177
|
+
constant int64_t & ne02,
|
178
|
+
constant int64_t & ne03,
|
179
|
+
constant uint64_t & nb00,
|
180
|
+
constant uint64_t & nb01,
|
181
|
+
constant uint64_t & nb02,
|
182
|
+
constant uint64_t & nb03,
|
183
|
+
constant int64_t & ne0,
|
184
|
+
constant int64_t & ne1,
|
185
|
+
constant int64_t & ne2,
|
186
|
+
constant int64_t & ne3,
|
187
|
+
constant uint64_t & nb0,
|
188
|
+
constant uint64_t & nb1,
|
189
|
+
constant uint64_t & nb2,
|
190
|
+
constant uint64_t & nb3,
|
191
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
192
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
193
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
194
|
+
const int64_t i3 = tgpig.z;
|
195
|
+
const int64_t i2 = tgpig.y;
|
196
|
+
const int64_t i1 = tgpig.x;
|
197
|
+
|
198
|
+
const int64_t i03 = i3 % ne03;
|
199
|
+
const int64_t i02 = i2 % ne02;
|
200
|
+
const int64_t i01 = i1 % ne01;
|
201
|
+
|
202
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
203
|
+
device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
|
204
|
+
|
205
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
206
|
+
const int i00 = i0 % ne00;
|
207
|
+
*((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
|
208
|
+
}
|
209
|
+
}
|
210
|
+
|
211
|
+
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
|
212
|
+
|
213
|
+
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
|
214
|
+
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
|
215
|
+
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
216
|
+
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
217
|
+
|
171
218
|
// assumption: src1 is a row
|
172
219
|
// broadcast src1 into src0
|
173
220
|
kernel void kernel_add_row(
|
@@ -1640,6 +1687,7 @@ static void rope_yarn_corr_dims(
|
|
1640
1687
|
typedef void (rope_t)(
|
1641
1688
|
device const void * src0,
|
1642
1689
|
device const int32_t * src1,
|
1690
|
+
device const float * src2,
|
1643
1691
|
device float * dst,
|
1644
1692
|
constant int64_t & ne00,
|
1645
1693
|
constant int64_t & ne01,
|
@@ -1675,6 +1723,7 @@ template<typename T>
|
|
1675
1723
|
kernel void kernel_rope(
|
1676
1724
|
device const void * src0,
|
1677
1725
|
device const int32_t * src1,
|
1726
|
+
device const float * src2,
|
1678
1727
|
device float * dst,
|
1679
1728
|
constant int64_t & ne00,
|
1680
1729
|
constant int64_t & ne01,
|
@@ -1718,13 +1767,13 @@ kernel void kernel_rope(
|
|
1718
1767
|
|
1719
1768
|
const int64_t p = pos[i2];
|
1720
1769
|
|
1721
|
-
const float
|
1770
|
+
const float theta_base = (float)p;
|
1722
1771
|
const float inv_ndims = -1.f/n_dims;
|
1723
1772
|
|
1724
1773
|
if (!is_neox) {
|
1725
1774
|
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
1775
|
+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
1726
1776
|
|
1727
|
-
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
1728
1777
|
float cos_theta, sin_theta;
|
1729
1778
|
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1730
1779
|
|
@@ -1740,16 +1789,14 @@ kernel void kernel_rope(
|
|
1740
1789
|
} else {
|
1741
1790
|
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
|
1742
1791
|
if (ic < n_dims) {
|
1743
|
-
const int64_t
|
1792
|
+
const int64_t i0 = ic/2;
|
1744
1793
|
|
1745
|
-
|
1746
|
-
const float cur_rot = inv_ndims*ic - ib;
|
1794
|
+
const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
|
1747
1795
|
|
1748
|
-
const float theta =
|
1749
|
-
float cos_theta, sin_theta;
|
1750
|
-
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1796
|
+
const float theta = theta_base * pow(freq_base, inv_ndims*ic);
|
1751
1797
|
|
1752
|
-
|
1798
|
+
float cos_theta, sin_theta;
|
1799
|
+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1753
1800
|
|
1754
1801
|
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1755
1802
|
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
@@ -2204,11 +2251,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2204
2251
|
// pointer to the mask
|
2205
2252
|
device const half * mp = (device const half *) (mask + iq1*nb31);
|
2206
2253
|
|
2207
|
-
|
2208
|
-
simdgroup_float8x8 mscale(scale);
|
2209
|
-
|
2210
|
-
// prepare diagonal slope matrix
|
2211
|
-
simdgroup_float8x8 mslope(1.0f);
|
2254
|
+
float slope = 1.0f;
|
2212
2255
|
|
2213
2256
|
// ALiBi
|
2214
2257
|
if (max_bias > 0.0f) {
|
@@ -2217,7 +2260,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2217
2260
|
const float base = h < n_head_log2 ? m0 : m1;
|
2218
2261
|
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2219
2262
|
|
2220
|
-
|
2263
|
+
slope = pow(base, exph);
|
2221
2264
|
}
|
2222
2265
|
|
2223
2266
|
// loop over the KV cache
|
@@ -2242,18 +2285,20 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2242
2285
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
2243
2286
|
}
|
2244
2287
|
|
2288
|
+
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2289
|
+
|
2290
|
+
const short tx = tiisg%4;
|
2291
|
+
const short ty = tiisg/4;
|
2292
|
+
|
2245
2293
|
if (mask != q) {
|
2246
2294
|
// mqk = mqk*scale + mask*slope
|
2247
|
-
|
2248
|
-
|
2249
|
-
simdgroup_multiply(mm, mslope, mm);
|
2250
|
-
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
2295
|
+
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
2296
|
+
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
2251
2297
|
} else {
|
2252
2298
|
// mqk = mqk*scale
|
2253
|
-
|
2299
|
+
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
2300
|
+
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
2254
2301
|
}
|
2255
|
-
|
2256
|
-
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2257
2302
|
}
|
2258
2303
|
}
|
2259
2304
|
|
@@ -2416,7 +2461,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
|
|
2416
2461
|
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
|
2417
2462
|
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
|
2418
2463
|
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
|
2419
|
-
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
|
2464
|
+
//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
|
2420
2465
|
|
2421
2466
|
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
|
2422
2467
|
kernel void kernel_flash_attn_ext_vec_f16(
|
@@ -2694,7 +2739,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2694
2739
|
}
|
2695
2740
|
|
2696
2741
|
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
2697
|
-
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
2742
|
+
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
2698
2743
|
|
2699
2744
|
kernel void kernel_cpy_f16_f16(
|
2700
2745
|
device const half * src0,
|
@@ -2816,8 +2861,7 @@ kernel void kernel_cpy_f32_f16(
|
|
2816
2861
|
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
2817
2862
|
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
2818
2863
|
|
2819
|
-
|
2820
|
-
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
|
2864
|
+
dst_data[i00] = src[0];
|
2821
2865
|
}
|
2822
2866
|
}
|
2823
2867
|
|
@@ -3318,31 +3362,30 @@ kernel void kernel_concat(
|
|
3318
3362
|
constant uint64_t & nb1,
|
3319
3363
|
constant uint64_t & nb2,
|
3320
3364
|
constant uint64_t & nb3,
|
3365
|
+
constant int32_t & dim,
|
3321
3366
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
3322
3367
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
3323
3368
|
uint3 ntg[[threads_per_threadgroup]]) {
|
3324
3369
|
|
3325
|
-
const int64_t
|
3326
|
-
const int64_t
|
3327
|
-
const int64_t
|
3370
|
+
const int64_t i3 = tgpig.z;
|
3371
|
+
const int64_t i2 = tgpig.y;
|
3372
|
+
const int64_t i1 = tgpig.x;
|
3328
3373
|
|
3329
|
-
|
3330
|
-
|
3331
|
-
const int64_t i11 = i01 % ne11;
|
3374
|
+
int64_t o[4] = {0, 0, 0, 0};
|
3375
|
+
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
|
3332
3376
|
|
3333
|
-
device const
|
3334
|
-
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
|
3335
|
-
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
|
3377
|
+
device const float * x;
|
3336
3378
|
|
3337
3379
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
3338
|
-
if (
|
3339
|
-
(
|
3340
|
-
src0_ptr += ntg.x*nb00;
|
3380
|
+
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
3381
|
+
x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
|
3341
3382
|
} else {
|
3342
|
-
(
|
3343
|
-
src1_ptr += ntg.x*nb10;
|
3383
|
+
x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
|
3344
3384
|
}
|
3345
|
-
|
3385
|
+
|
3386
|
+
device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
3387
|
+
|
3388
|
+
*y = *x;
|
3346
3389
|
}
|
3347
3390
|
}
|
3348
3391
|
|
@@ -3385,7 +3428,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3385
3428
|
|
3386
3429
|
const int step = sizeof(block_q2_K) * nb;
|
3387
3430
|
|
3388
|
-
#if QK_K == 256
|
3389
3431
|
const int ix = tiisg/8; // 0...3
|
3390
3432
|
const int it = tiisg%8; // 0...7
|
3391
3433
|
const int iq = it/4; // 0 or 1
|
@@ -3437,57 +3479,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3437
3479
|
|
3438
3480
|
y4 += 4 * QK_K;
|
3439
3481
|
}
|
3440
|
-
#else
|
3441
|
-
const int ix = tiisg/2; // 0...15
|
3442
|
-
const int it = tiisg%2; // 0...1
|
3443
|
-
|
3444
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3445
|
-
|
3446
|
-
for (int ib = ix; ib < nb; ib += 16) {
|
3447
|
-
|
3448
|
-
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
3449
|
-
for (int i = 0; i < 8; ++i) {
|
3450
|
-
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
3451
|
-
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
|
3452
|
-
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
|
3453
|
-
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
|
3454
|
-
}
|
3455
|
-
|
3456
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
3457
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3458
|
-
device const half * dh = &x[ib].d;
|
3459
|
-
|
3460
|
-
for (int row = 0; row < N_DST; row++) {
|
3461
|
-
|
3462
|
-
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
3463
|
-
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
3464
|
-
for (int i = 0; i < 8; i += 2) {
|
3465
|
-
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
3466
|
-
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
3467
|
-
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
3468
|
-
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
3469
|
-
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
3470
|
-
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
3471
|
-
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
3472
|
-
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
3473
|
-
}
|
3474
|
-
|
3475
|
-
float dall = dh[0];
|
3476
|
-
float dmin = dh[1];
|
3477
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
3478
|
-
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
|
3479
|
-
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
|
3480
|
-
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
|
3481
|
-
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
|
3482
|
-
|
3483
|
-
qs += step/2;
|
3484
|
-
sc += step;
|
3485
|
-
dh += step/2;
|
3486
|
-
}
|
3487
|
-
|
3488
|
-
y4 += 16 * QK_K;
|
3489
|
-
}
|
3490
|
-
#endif
|
3491
3482
|
|
3492
3483
|
for (int row = 0; row < N_DST; ++row) {
|
3493
3484
|
all_sum = simd_sum(sumf[row]);
|
@@ -3525,7 +3516,6 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
3525
3516
|
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3526
3517
|
}
|
3527
3518
|
|
3528
|
-
#if QK_K == 256
|
3529
3519
|
void kernel_mul_mv_q3_K_f32_impl(
|
3530
3520
|
device const void * src0,
|
3531
3521
|
device const float * src1,
|
@@ -3684,84 +3674,6 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
3684
3674
|
}
|
3685
3675
|
}
|
3686
3676
|
}
|
3687
|
-
#else
|
3688
|
-
void kernel_mul_mv_q3_K_f32_impl(
|
3689
|
-
device const void * src0,
|
3690
|
-
device const float * src1,
|
3691
|
-
device float * dst,
|
3692
|
-
constant int64_t & ne00,
|
3693
|
-
constant int64_t & ne01,
|
3694
|
-
constant int64_t & ne02,
|
3695
|
-
constant int64_t & ne10,
|
3696
|
-
constant int64_t & ne12,
|
3697
|
-
constant int64_t & ne0,
|
3698
|
-
constant int64_t & ne1,
|
3699
|
-
constant uint & r2,
|
3700
|
-
constant uint & r3,
|
3701
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3702
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3703
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3704
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3705
|
-
|
3706
|
-
const int nb = ne00/QK_K;
|
3707
|
-
|
3708
|
-
const int64_t r0 = tgpig.x;
|
3709
|
-
const int64_t r1 = tgpig.y;
|
3710
|
-
const int64_t im = tgpig.z;
|
3711
|
-
|
3712
|
-
const int row = 2 * r0 + sgitg;
|
3713
|
-
|
3714
|
-
const uint i12 = im%ne12;
|
3715
|
-
const uint i13 = im/ne12;
|
3716
|
-
|
3717
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3718
|
-
|
3719
|
-
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
3720
|
-
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3721
|
-
|
3722
|
-
const int ix = tiisg/4;
|
3723
|
-
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
3724
|
-
const int iq = il/8; // 0, 0, 1, 1
|
3725
|
-
const int in = il%8; // 0, 4, 0, 4
|
3726
|
-
|
3727
|
-
float2 sum = {0.f, 0.f};
|
3728
|
-
|
3729
|
-
for (int i = ix; i < nb; i += 8) {
|
3730
|
-
|
3731
|
-
const float d_all = (float)(x[i].d);
|
3732
|
-
|
3733
|
-
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
3734
|
-
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
3735
|
-
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
3736
|
-
device const float * y = yy + i * QK_K + il;
|
3737
|
-
|
3738
|
-
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
3739
|
-
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
3740
|
-
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
3741
|
-
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
3742
|
-
|
3743
|
-
for (int l = 0; l < 4; l += 2) {
|
3744
|
-
const uint16_t hm = h[l/2] >> iq;
|
3745
|
-
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
3746
|
-
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
3747
|
-
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
3748
|
-
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
3749
|
-
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
3750
|
-
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
3751
|
-
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
3752
|
-
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
3753
|
-
}
|
3754
|
-
|
3755
|
-
}
|
3756
|
-
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
3757
|
-
|
3758
|
-
const float tot = simd_sum(sumf);
|
3759
|
-
if (tiisg == 0) {
|
3760
|
-
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
3761
|
-
}
|
3762
|
-
|
3763
|
-
}
|
3764
|
-
#endif
|
3765
3677
|
|
3766
3678
|
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
3767
3679
|
kernel void kernel_mul_mv_q3_K_f32(
|
@@ -3791,7 +3703,6 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
3791
3703
|
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3792
3704
|
}
|
3793
3705
|
|
3794
|
-
#if QK_K == 256
|
3795
3706
|
void kernel_mul_mv_q4_K_f32_impl(
|
3796
3707
|
device const void * src0,
|
3797
3708
|
device const float * src1,
|
@@ -3905,103 +3816,6 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3905
3816
|
}
|
3906
3817
|
}
|
3907
3818
|
}
|
3908
|
-
#else
|
3909
|
-
void kernel_mul_mv_q4_K_f32_impl(
|
3910
|
-
device const void * src0,
|
3911
|
-
device const float * src1,
|
3912
|
-
device float * dst,
|
3913
|
-
constant int64_t & ne00,
|
3914
|
-
constant int64_t & ne01,
|
3915
|
-
constant int64_t & ne02,
|
3916
|
-
constant int64_t & ne10,
|
3917
|
-
constant int64_t & ne12,
|
3918
|
-
constant int64_t & ne0,
|
3919
|
-
constant int64_t & ne1,
|
3920
|
-
constant uint & r2,
|
3921
|
-
constant uint & r3,
|
3922
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3923
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3924
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3925
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3926
|
-
|
3927
|
-
const int ix = tiisg/4; // 0...7
|
3928
|
-
const int it = tiisg%4; // 0...3
|
3929
|
-
|
3930
|
-
const int nb = ne00/QK_K;
|
3931
|
-
const int r0 = tgpig.x;
|
3932
|
-
const int r1 = tgpig.y;
|
3933
|
-
const int im = tgpig.z;
|
3934
|
-
const int first_row = r0 * N_DST;
|
3935
|
-
const int ib_row = first_row * nb;
|
3936
|
-
|
3937
|
-
const uint i12 = im%ne12;
|
3938
|
-
const uint i13 = im/ne12;
|
3939
|
-
|
3940
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3941
|
-
|
3942
|
-
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
3943
|
-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3944
|
-
|
3945
|
-
float yl[8];
|
3946
|
-
float yh[8];
|
3947
|
-
float sumf[N_DST]={0.f}, all_sum;
|
3948
|
-
|
3949
|
-
const int step = sizeof(block_q4_K) * nb / 2;
|
3950
|
-
|
3951
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3952
|
-
|
3953
|
-
uint16_t sc16[4];
|
3954
|
-
|
3955
|
-
for (int ib = ix; ib < nb; ib += 8) {
|
3956
|
-
|
3957
|
-
float2 sumy = {0.f, 0.f};
|
3958
|
-
for (int i = 0; i < 8; ++i) {
|
3959
|
-
yl[i] = y4[i+ 0]; sumy[0] += yl[i];
|
3960
|
-
yh[i] = y4[i+32]; sumy[1] += yh[i];
|
3961
|
-
}
|
3962
|
-
|
3963
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
|
3964
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3965
|
-
device const half * dh = x[ib].d;
|
3966
|
-
|
3967
|
-
for (int row = 0; row < N_DST; row++) {
|
3968
|
-
|
3969
|
-
sc16[0] = sc[0] & 0x000f;
|
3970
|
-
sc16[1] = sc[0] & 0x0f00;
|
3971
|
-
sc16[2] = sc[0] & 0x00f0;
|
3972
|
-
sc16[3] = sc[0] & 0xf000;
|
3973
|
-
|
3974
|
-
float2 acc1 = {0.f, 0.f};
|
3975
|
-
float2 acc2 = {0.f, 0.f};
|
3976
|
-
for (int i = 0; i < 8; i += 2) {
|
3977
|
-
acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
|
3978
|
-
acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
|
3979
|
-
acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
|
3980
|
-
acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
|
3981
|
-
}
|
3982
|
-
|
3983
|
-
float dall = dh[0];
|
3984
|
-
float dmin = dh[1];
|
3985
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
|
3986
|
-
(acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
|
3987
|
-
dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
|
3988
|
-
|
3989
|
-
qs += step;
|
3990
|
-
sc += step;
|
3991
|
-
dh += step;
|
3992
|
-
}
|
3993
|
-
|
3994
|
-
y4 += 8 * QK_K;
|
3995
|
-
}
|
3996
|
-
|
3997
|
-
for (int row = 0; row < N_DST; ++row) {
|
3998
|
-
all_sum = simd_sum(sumf[row]);
|
3999
|
-
if (tiisg == 0) {
|
4000
|
-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4001
|
-
}
|
4002
|
-
}
|
4003
|
-
}
|
4004
|
-
#endif
|
4005
3819
|
|
4006
3820
|
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
4007
3821
|
kernel void kernel_mul_mv_q4_K_f32(
|
@@ -4069,8 +3883,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4069
3883
|
|
4070
3884
|
const int step = sizeof(block_q5_K) * nb;
|
4071
3885
|
|
4072
|
-
#if QK_K == 256
|
4073
|
-
#
|
4074
3886
|
float yl[16], yh[16];
|
4075
3887
|
|
4076
3888
|
const uint16_t kmask1 = 0x3f3f;
|
@@ -4153,54 +3965,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4153
3965
|
y1 += 4 * QK_K;
|
4154
3966
|
|
4155
3967
|
}
|
4156
|
-
#else
|
4157
|
-
float yl[8], yh[8];
|
4158
|
-
|
4159
|
-
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
4160
|
-
const int ix = tiisg%8;
|
4161
|
-
const int iq = il/8; // 0, 0, 1, 1
|
4162
|
-
const int in = il%8; // 0, 4, 0, 4
|
4163
|
-
|
4164
|
-
device const float * y = yy + ix*QK_K + il;
|
4165
|
-
|
4166
|
-
for (int i = ix; i < nb; i += 8) {
|
4167
|
-
|
4168
|
-
for (int l = 0; l < 4; ++l) {
|
4169
|
-
yl[l+0] = y[l+ 0];
|
4170
|
-
yl[l+4] = y[l+16];
|
4171
|
-
yh[l+0] = y[l+32];
|
4172
|
-
yh[l+4] = y[l+48];
|
4173
|
-
}
|
4174
|
-
|
4175
|
-
device const half * dh = &x[i].d;
|
4176
|
-
device const uint8_t * q = x[i].qs + il;
|
4177
|
-
device const uint8_t * h = x[i].qh + in;
|
4178
|
-
device const int8_t * s = x[i].scales;
|
4179
|
-
|
4180
|
-
for (int row = 0; row < 2; ++row) {
|
4181
|
-
|
4182
|
-
const float d = dh[0];
|
4183
|
-
|
4184
|
-
float2 acc = {0.f, 0.f};
|
4185
|
-
for (int l = 0; l < 4; ++l) {
|
4186
|
-
const uint8_t hl = h[l] >> iq;
|
4187
|
-
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
4188
|
-
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
4189
|
-
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
4190
|
-
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
|
4191
|
-
}
|
4192
|
-
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
|
4193
|
-
|
4194
|
-
q += step;
|
4195
|
-
h += step;
|
4196
|
-
s += step;
|
4197
|
-
dh += step/2;
|
4198
|
-
|
4199
|
-
}
|
4200
|
-
|
4201
|
-
y += 8 * QK_K;
|
4202
|
-
}
|
4203
|
-
#endif
|
4204
3968
|
|
4205
3969
|
for (int row = 0; row < 2; ++row) {
|
4206
3970
|
const float tot = simd_sum(sumf[row]);
|
@@ -4279,7 +4043,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4279
4043
|
|
4280
4044
|
float sumf = 0;
|
4281
4045
|
|
4282
|
-
#if QK_K == 256
|
4283
4046
|
const int tid = tiisg/2;
|
4284
4047
|
const int ix = tiisg%2;
|
4285
4048
|
const int ip = tid/8; // 0 or 1
|
@@ -4315,30 +4078,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4315
4078
|
|
4316
4079
|
}
|
4317
4080
|
|
4318
|
-
#else
|
4319
|
-
const int ix = tiisg/4;
|
4320
|
-
const int il = 4*(tiisg%4);
|
4321
|
-
|
4322
|
-
for (int i = ix; i < nb; i += 8) {
|
4323
|
-
device const float * y = yy + i * QK_K + il;
|
4324
|
-
device const uint8_t * ql = x[i].ql + il;
|
4325
|
-
device const uint8_t * qh = x[i].qh + il;
|
4326
|
-
device const int8_t * s = x[i].scales;
|
4327
|
-
|
4328
|
-
const float d = x[i].d;
|
4329
|
-
|
4330
|
-
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
4331
|
-
for (int l = 0; l < 4; ++l) {
|
4332
|
-
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
4333
|
-
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
4334
|
-
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
|
4335
|
-
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
4336
|
-
}
|
4337
|
-
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
4338
|
-
}
|
4339
|
-
|
4340
|
-
#endif
|
4341
|
-
|
4342
4081
|
const float tot = simd_sum(sumf);
|
4343
4082
|
if (tiisg == 0) {
|
4344
4083
|
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
@@ -5172,9 +4911,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5172
4911
|
|
5173
4912
|
device const float * y4 = y + 32 * ix;
|
5174
4913
|
|
5175
|
-
#if QK_K != 64
|
5176
4914
|
iq1m_scale_t scale;
|
5177
|
-
#endif
|
5178
4915
|
|
5179
4916
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5180
4917
|
|
@@ -5195,10 +4932,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5195
4932
|
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
5196
4933
|
|
5197
4934
|
for (int row = 0; row < N_DST; row++) {
|
5198
|
-
|
5199
|
-
#if QK_K != 64
|
5200
4935
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5201
|
-
#endif
|
5202
4936
|
|
5203
4937
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
5204
4938
|
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
@@ -5214,14 +4948,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5214
4948
|
}
|
5215
4949
|
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5216
4950
|
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5217
|
-
|
5218
|
-
const float d = (float) *((device const half *)(sc - 1));
|
5219
|
-
sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
|
5220
|
-
(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
|
5221
|
-
#else
|
4951
|
+
|
5222
4952
|
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
5223
4953
|
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
5224
|
-
#endif
|
5225
4954
|
|
5226
4955
|
sc += nb*sizeof(block_iq1_m)/2;
|
5227
4956
|
qs += nb*sizeof(block_iq1_m);
|
@@ -5333,7 +5062,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5333
5062
|
}
|
5334
5063
|
}
|
5335
5064
|
|
5336
|
-
#if QK_K != 64
|
5337
5065
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
5338
5066
|
device const void * src0,
|
5339
5067
|
device const float * src1,
|
@@ -5428,7 +5156,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5428
5156
|
}
|
5429
5157
|
}
|
5430
5158
|
}
|
5431
|
-
#endif
|
5432
5159
|
|
5433
5160
|
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
5434
5161
|
kernel void kernel_mul_mv_iq1_s_f32(
|
@@ -5541,11 +5268,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
5541
5268
|
uint tiisg[[thread_index_in_simdgroup]],
|
5542
5269
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5543
5270
|
|
5544
|
-
#if QK_K == 64
|
5545
|
-
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5546
|
-
#else
|
5547
5271
|
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5548
|
-
#endif
|
5549
5272
|
}
|
5550
5273
|
|
5551
5274
|
//============================= templates and their specializations =============================
|
@@ -5671,10 +5394,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
|
|
5671
5394
|
float dl, ml;
|
5672
5395
|
uint8_t sc = xb->scales[il];
|
5673
5396
|
|
5674
|
-
#if QK_K == 256
|
5675
5397
|
q = q + 32*(il/8) + 16*(il&1);
|
5676
5398
|
il = (il/2)%4;
|
5677
|
-
|
5399
|
+
|
5678
5400
|
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5679
5401
|
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5680
5402
|
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
@@ -5690,7 +5412,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5690
5412
|
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
5691
5413
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5692
5414
|
|
5693
|
-
#if QK_K == 256
|
5694
5415
|
q = q + 32 * (il/8) + 16 * (il&1);
|
5695
5416
|
h = h + 16 * (il&1);
|
5696
5417
|
uint8_t m = 1 << (il/2);
|
@@ -5711,17 +5432,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5711
5432
|
for (int i = 0; i < 16; ++i) {
|
5712
5433
|
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
5713
5434
|
}
|
5714
|
-
#else
|
5715
|
-
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
5716
|
-
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
5717
|
-
float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
|
5718
|
-
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5719
|
-
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5720
|
-
uint8_t m = 1<<(il*2);
|
5721
|
-
for (int i = 0; i < 16; ++i) {
|
5722
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
|
5723
|
-
}
|
5724
|
-
#endif
|
5725
5435
|
}
|
5726
5436
|
|
5727
5437
|
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
@@ -5733,7 +5443,6 @@ template <typename type4x4>
|
|
5733
5443
|
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
5734
5444
|
device const uchar * q = xb->qs;
|
5735
5445
|
|
5736
|
-
#if QK_K == 256
|
5737
5446
|
short is = (il/4) * 2;
|
5738
5447
|
q = q + (il/4) * 32 + 16 * (il&1);
|
5739
5448
|
il = il & 3;
|
@@ -5742,16 +5451,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
5742
5451
|
const float min = xb->dmin;
|
5743
5452
|
const float dl = d * sc[0];
|
5744
5453
|
const float ml = min * sc[1];
|
5745
|
-
|
5746
|
-
(void) get_scale_min_k4_just2;
|
5747
|
-
|
5748
|
-
q = q + 16 * (il&1);
|
5749
|
-
device const uint8_t * s = xb->scales;
|
5750
|
-
device const half2 * dh = (device const half2 *)xb->d;
|
5751
|
-
const float2 d = (float2)dh[0];
|
5752
|
-
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
5753
|
-
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
5754
|
-
#endif
|
5454
|
+
|
5755
5455
|
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5756
5456
|
for (int i = 0; i < 16; ++i) {
|
5757
5457
|
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
@@ -5763,7 +5463,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5763
5463
|
device const uint8_t * q = xb->qs;
|
5764
5464
|
device const uint8_t * qh = xb->qh;
|
5765
5465
|
|
5766
|
-
#if QK_K == 256
|
5767
5466
|
short is = (il/4) * 2;
|
5768
5467
|
q = q + 32 * (il/4) + 16 * (il&1);
|
5769
5468
|
qh = qh + 16 * (il&1);
|
@@ -5780,17 +5479,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5780
5479
|
for (int i = 0; i < 16; ++i) {
|
5781
5480
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
5782
5481
|
}
|
5783
|
-
#else
|
5784
|
-
q = q + 16 * (il&1);
|
5785
|
-
device const int8_t * s = xb->scales;
|
5786
|
-
const float dl = xb->d * s[il];
|
5787
|
-
uint8_t m = 1<<(il*2);
|
5788
|
-
const float coef = il<2 ? 1.f : 1.f/16.f;
|
5789
|
-
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5790
|
-
for (int i = 0; i < 16; ++i) {
|
5791
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
|
5792
|
-
}
|
5793
|
-
#endif
|
5794
5482
|
}
|
5795
5483
|
|
5796
5484
|
template <typename type4x4>
|
@@ -5800,15 +5488,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
5800
5488
|
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
5801
5489
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5802
5490
|
|
5803
|
-
#if QK_K == 256
|
5804
5491
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
5805
5492
|
qh = qh + 32*(il/8) + 16*(il&1);
|
5806
5493
|
float sc = scales[(il%2) + 2 * ((il/2))];
|
5807
5494
|
il = (il/2) & 3;
|
5808
|
-
|
5809
|
-
ql = ql + 16 * (il&1);
|
5810
|
-
float sc = scales[il];
|
5811
|
-
#endif
|
5495
|
+
|
5812
5496
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5813
5497
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
5814
5498
|
const float coef = il>1 ? 1.f/16.f : 1.f;
|
@@ -5965,20 +5649,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
|
|
5965
5649
|
const int ib32 = il/2;
|
5966
5650
|
il = il%2;
|
5967
5651
|
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
5968
|
-
|
5969
|
-
const float d = xb->d;
|
5970
|
-
#else
|
5652
|
+
|
5971
5653
|
iq1m_scale_t scale;
|
5972
5654
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5973
5655
|
const float d = scale.f16;
|
5974
|
-
|
5656
|
+
|
5975
5657
|
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
5976
5658
|
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
5977
|
-
|
5978
|
-
const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
|
5979
|
-
#else
|
5659
|
+
|
5980
5660
|
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
5981
|
-
#endif
|
5982
5661
|
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5983
5662
|
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5984
5663
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
@@ -6008,9 +5687,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|
6008
5687
|
|
6009
5688
|
template <typename type4x4>
|
6010
5689
|
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
6011
|
-
#if QK_K == 64
|
6012
|
-
dequantize_iq4_nl(xb, il, reg);
|
6013
|
-
#else
|
6014
5690
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
6015
5691
|
const int ib32 = il/2;
|
6016
5692
|
il = il%2;
|
@@ -6027,7 +5703,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|
6027
5703
|
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
6028
5704
|
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
6029
5705
|
}
|
6030
|
-
#endif
|
6031
5706
|
}
|
6032
5707
|
|
6033
5708
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
@@ -6532,11 +6207,7 @@ kernel void kernel_mul_mm_id(
|
|
6532
6207
|
sgitg);
|
6533
6208
|
}
|
6534
6209
|
|
6535
|
-
#if QK_K == 256
|
6536
6210
|
#define QK_NL 16
|
6537
|
-
#else
|
6538
|
-
#define QK_NL 4
|
6539
|
-
#endif
|
6540
6211
|
|
6541
6212
|
//
|
6542
6213
|
// get rows
|
@@ -6576,11 +6247,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
|
|
6576
6247
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6577
6248
|
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6578
6249
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6579
|
-
#if QK_K == 64
|
6580
|
-
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6581
|
-
#else
|
6582
6250
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6583
|
-
#endif
|
6584
6251
|
|
6585
6252
|
//
|
6586
6253
|
// matrix-matrix multiplication
|
@@ -6608,11 +6275,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
|
|
6608
6275
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6609
6276
|
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6610
6277
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6611
|
-
#if QK_K == 64
|
6612
|
-
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
6613
|
-
#else
|
6614
6278
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6615
|
-
#endif
|
6616
6279
|
|
6617
6280
|
//
|
6618
6281
|
// indirect matrix-matrix multiplication
|
@@ -6640,11 +6303,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
|
|
6640
6303
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6641
6304
|
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6642
6305
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6643
|
-
#if QK_K == 64
|
6644
|
-
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6645
|
-
#else
|
6646
6306
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6647
|
-
#endif
|
6648
6307
|
|
6649
6308
|
//
|
6650
6309
|
// matrix-vector multiplication
|
@@ -6853,7 +6512,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
|
|
6853
6512
|
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
|
6854
6513
|
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
6855
6514
|
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
6856
|
-
#if QK_K != 64
|
6857
6515
|
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
6858
|
-
#endif
|
6859
6516
|
|