llama_cpp 0.15.2 → 0.15.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +61 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +8 -16
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +99 -40
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +44 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +133 -81
- data/vendor/tmp/llama.cpp/ggml-metal.metal +91 -434
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1962 -2443
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +248 -108
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +375 -657
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +204 -225
- data/vendor/tmp/llama.cpp/ggml.c +498 -836
- data/vendor/tmp/llama.cpp/ggml.h +57 -30
- data/vendor/tmp/llama.cpp/llama.cpp +1477 -859
- data/vendor/tmp/llama.cpp/llama.h +21 -8
- metadata +3 -3
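
For orientation only, a minimal sketch of how a consuming project might pick up this patch release. The Gemfile constraint and the LLaMACpp module/constant names below are assumptions about a typical setup, not something this diff specifies.

# Gemfile of a hypothetical consumer project
gem 'llama_cpp', '~> 0.15', '>= 0.15.4'

# after `bundle update llama_cpp`
require 'llama_cpp'
puts LLaMACpp::VERSION   # expected to print "0.15.4" (constant name assumed)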
data/vendor/tmp/llama.cpp/ggml-metal.metal:

@@ -168,6 +168,53 @@ kernel void kernel_div(
     }
 }
 
+template<typename T>
+kernel void kernel_repeat(
+        device const char * src0,
+        device       char * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 % ne03;
+    const int64_t i02 = i2 % ne02;
+    const int64_t i01 = i1 % ne01;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device       char * dst_ptr  = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i00 = i0 % ne00;
+        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
@@ -1640,6 +1687,7 @@ static void rope_yarn_corr_dims(
 typedef void (rope_t)(
         device const void * src0,
         device const int32_t * src1,
+        device const float * src2,
         device float * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -1675,6 +1723,7 @@ template<typename T>
 kernel void kernel_rope(
         device const void * src0,
         device const int32_t * src1,
+        device const float * src2,
         device float * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
@@ -1718,13 +1767,13 @@ kernel void kernel_rope(
 
     const int64_t p = pos[i2];
 
-    const float theta_0 = (float)p;
+    const float theta_base = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             float cos_theta, sin_theta;
             rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
@@ -1740,16 +1789,14 @@ kernel void kernel_rope(
     } else {
         for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
             if (ic < n_dims) {
-                const int64_t ib = 0;
+                const int64_t i0 = ic/2;
 
-
-                const float cur_rot = inv_ndims*ic - ib;
+                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
 
-                const float theta = theta_0 * pow(freq_base, cur_rot);
-                float cos_theta, sin_theta;
-                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
 
-                const int64_t i0 = ib*n_dims + ic/2;
+                float cos_theta, sin_theta;
+                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                 device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -2204,11 +2251,7 @@ kernel void kernel_flash_attn_ext_f16(
         // pointer to the mask
         device const half * mp = (device const half *) (mask + iq1*nb31);
 
-
-        simdgroup_float8x8 mscale(scale);
-
-        // prepare diagonal slope matrix
-        simdgroup_float8x8 mslope(1.0f);
+        float slope = 1.0f;
 
         // ALiBi
         if (max_bias > 0.0f) {
@@ -2217,7 +2260,7 @@ kernel void kernel_flash_attn_ext_f16(
             const float base = h < n_head_log2 ? m0 : m1;
             const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-            mslope = simdgroup_float8x8(pow(base, exph));
+            slope = pow(base, exph);
         }
 
         // loop over the KV cache
@@ -2242,18 +2285,20 @@ kernel void kernel_flash_attn_ext_f16(
                     simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                 }
 
+                simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+
+                const short tx = tiisg%4;
+                const short ty = tiisg/4;
+
                 if (mask != q) {
                     // mqk = mqk*scale + mask*slope
-                    simdgroup_half8x8 mm;
-                    simdgroup_load(mm, mp + 8*cc, nb31/sizeof(half), 0, false);
-                    simdgroup_multiply(mm, mslope, mm);
-                    simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+                    ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
+                    ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
                 } else {
                     // mqk = mqk*scale
-                    simdgroup_multiply(mqk, mscale, mqk);
+                    ss[8*cc + ty*TF + 2*tx + 0] *= scale;
+                    ss[8*cc + ty*TF + 2*tx + 1] *= scale;
                 }
-
-                simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
             }
         }
 
@@ -2416,7 +2461,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
 template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
 template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
 template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
 
 template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
@@ -2694,7 +2739,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 }
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
@@ -2816,8 +2861,7 @@ kernel void kernel_cpy_f32_f16(
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
 
-
-        dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
+        dst_data[i00] = src[0];
     }
 }
 
@@ -3318,31 +3362,30 @@ kernel void kernel_concat(
         constant uint64_t & nb1,
         constant uint64_t & nb2,
         constant uint64_t & nb3,
+        constant int32_t & dim,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
 
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
-    device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
+    device const float * x;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        if (i02 < ne02) {
-            ((device float *)dst_ptr)[0] = ((device const float *)src0_ptr)[0];
-            src0_ptr += ntg.x*nb00;
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
         } else {
-            ((device float *)dst_ptr)[0] = ((device const float *)src1_ptr)[0];
-            src1_ptr += ntg.x*nb10;
+            x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
         }
-        dst_ptr += ntg.x*nb0;
+
+        device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
     }
 }
 
@@ -3385,7 +3428,6 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     const int step = sizeof(block_q2_K) * nb;
 
-#if QK_K == 256
     const int ix = tiisg/8;  // 0...3
     const int it = tiisg%8;  // 0...7
     const int iq = it/4;     // 0 or 1
@@ -3437,57 +3479,6 @@ void kernel_mul_mv_q2_K_f32_impl(
 
         y4 += 4 * QK_K;
     }
-#else
-    const int ix = tiisg/2;  // 0...15
-    const int it = tiisg%2;  // 0...1
-
-    device const float * y4 = y + ix * QK_K + 8 * it;
-
-    for (int ib = ix; ib < nb; ib += 16) {
-
-        float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
-            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
-            yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
-            yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
-            yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
-        }
-
-        device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
-        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
-        device const half * dh = &x[ib].d;
-
-        for (int row = 0; row < N_DST; row++) {
-
-            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
-            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
-                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
-                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
-                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
-                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
-                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
-                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
-                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
-            }
-
-            float dall = dh[0];
-            float dmin = dh[1];
-            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
-                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
-                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
-                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
-                         dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
-
-            qs += step/2;
-            sc += step;
-            dh += step/2;
-        }
-
-        y4 += 16 * QK_K;
-    }
-#endif
 
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
@@ -3525,7 +3516,6 @@ kernel void kernel_mul_mv_q2_K_f32(
     kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
-#if QK_K == 256
 void kernel_mul_mv_q3_K_f32_impl(
         device const void * src0,
         device const float * src1,
@@ -3684,84 +3674,6 @@ void kernel_mul_mv_q3_K_f32_impl(
         }
     }
 }
-#else
-void kernel_mul_mv_q3_K_f32_impl(
-        device const void * src0,
-        device const float * src1,
-        device float * dst,
-        constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant int64_t & ne02,
-        constant int64_t & ne10,
-        constant int64_t & ne12,
-        constant int64_t & ne0,
-        constant int64_t & ne1,
-        constant uint & r2,
-        constant uint & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-    const int64_t im = tgpig.z;
-
-    const int row = 2 * r0 + sgitg;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-
-    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
-    device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
-
-    const int ix = tiisg/4;
-    const int il = 4 * (tiisg%4);// 0, 4, 8, 12
-    const int iq = il/8;         // 0, 0, 1, 1
-    const int in = il%8;         // 0, 4, 0, 4
-
-    float2 sum = {0.f, 0.f};
-
-    for (int i = ix; i < nb; i += 8) {
-
-        const float d_all = (float)(x[i].d);
-
-        device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
-        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
-        device const uint16_t * s = (device const uint16_t *)(x[i].scales);
-        device const float * y = yy + i * QK_K + il;
-
-        const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
-        const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
-        const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
-        const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
-
-        for (int l = 0; l < 4; l += 2) {
-            const uint16_t hm = h[l/2] >> iq;
-            sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
-                    + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
-                    + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
-                    + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
-            sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
-                    + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
-                    + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
-                    + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
-        }
-
-    }
-    const float sumf = sum[0] + sum[1] * 1.f/256.f;
-
-    const float tot = simd_sum(sumf);
-    if (tiisg == 0) {
-        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
-    }
-
-}
-#endif
 
 [[host_name("kernel_mul_mv_q3_K_f32")]]
 kernel void kernel_mul_mv_q3_K_f32(
@@ -3791,7 +3703,6 @@ kernel void kernel_mul_mv_q3_K_f32(
     kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
-#if QK_K == 256
 void kernel_mul_mv_q4_K_f32_impl(
         device const void * src0,
        device const float * src1,
@@ -3905,103 +3816,6 @@ void kernel_mul_mv_q4_K_f32_impl(
         }
     }
 }
-#else
-void kernel_mul_mv_q4_K_f32_impl(
-        device const void * src0,
-        device const float * src1,
-        device float * dst,
-        constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant int64_t & ne02,
-        constant int64_t & ne10,
-        constant int64_t & ne12,
-        constant int64_t & ne0,
-        constant int64_t & ne1,
-        constant uint & r2,
-        constant uint & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    const int ix = tiisg/4;  // 0...7
-    const int it = tiisg%4;  // 0...3
-
-    const int nb = ne00/QK_K;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-    const int first_row = r0 * N_DST;
-    const int ib_row = first_row * nb;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-
-    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
-
-    float yl[8];
-    float yh[8];
-    float sumf[N_DST]={0.f}, all_sum;
-
-    const int step = sizeof(block_q4_K) * nb / 2;
-
-    device const float * y4 = y + ix * QK_K + 8 * it;
-
-    uint16_t sc16[4];
-
-    for (int ib = ix; ib < nb; ib += 8) {
-
-        float2 sumy = {0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
-            yl[i] = y4[i+ 0]; sumy[0] += yl[i];
-            yh[i] = y4[i+32]; sumy[1] += yh[i];
-        }
-
-        device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
-        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
-        device const half * dh = x[ib].d;
-
-        for (int row = 0; row < N_DST; row++) {
-
-            sc16[0] = sc[0] & 0x000f;
-            sc16[1] = sc[0] & 0x0f00;
-            sc16[2] = sc[0] & 0x00f0;
-            sc16[3] = sc[0] & 0xf000;
-
-            float2 acc1 = {0.f, 0.f};
-            float2 acc2 = {0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
-                acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
-                acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
-                acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
-            }
-
-            float dall = dh[0];
-            float dmin = dh[1];
-            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
-                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
-                         dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
-
-            qs += step;
-            sc += step;
-            dh += step;
-        }
-
-        y4 += 8 * QK_K;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
-        }
-    }
-}
-#endif
 
 [[host_name("kernel_mul_mv_q4_K_f32")]]
 kernel void kernel_mul_mv_q4_K_f32(
@@ -4069,8 +3883,6 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     const int step = sizeof(block_q5_K) * nb;
 
-#if QK_K == 256
-#
     float yl[16], yh[16];
 
     const uint16_t kmask1 = 0x3f3f;
@@ -4153,54 +3965,6 @@ void kernel_mul_mv_q5_K_f32_impl(
         y1 += 4 * QK_K;
 
     }
-#else
-    float yl[8], yh[8];
-
-    const int il = 4 * (tiisg/8);  // 0, 4, 8, 12
-    const int ix = tiisg%8;
-    const int iq = il/8;  // 0, 0, 1, 1
-    const int in = il%8;  // 0, 4, 0, 4
-
-    device const float * y = yy + ix*QK_K + il;
-
-    for (int i = ix; i < nb; i += 8) {
-
-        for (int l = 0; l < 4; ++l) {
-            yl[l+0] = y[l+ 0];
-            yl[l+4] = y[l+16];
-            yh[l+0] = y[l+32];
-            yh[l+4] = y[l+48];
-        }
-
-        device const half * dh = &x[i].d;
-        device const uint8_t * q = x[i].qs + il;
-        device const uint8_t * h = x[i].qh + in;
-        device const int8_t * s = x[i].scales;
-
-        for (int row = 0; row < 2; ++row) {
-
-            const float d = dh[0];
-
-            float2 acc = {0.f, 0.f};
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t hl = h[l] >> iq;
-                acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
-                        + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
-                acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
-                        + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
-            }
-            sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
-
-            q += step;
-            h += step;
-            s += step;
-            dh += step/2;
-
-        }
-
-        y += 8 * QK_K;
-    }
-#endif
 
     for (int row = 0; row < 2; ++row) {
         const float tot = simd_sum(sumf[row]);
@@ -4279,7 +4043,6 @@ void kernel_mul_mv_q6_K_f32_impl(
 
     float sumf = 0;
 
-#if QK_K == 256
     const int tid = tiisg/2;
     const int ix = tiisg%2;
     const int ip = tid/8;   // 0 or 1
@@ -4315,30 +4078,6 @@ void kernel_mul_mv_q6_K_f32_impl(
 
     }
 
-#else
-    const int ix = tiisg/4;
-    const int il = 4*(tiisg%4);
-
-    for (int i = ix; i < nb; i += 8) {
-        device const float * y = yy + i * QK_K + il;
-        device const uint8_t * ql = x[i].ql + il;
-        device const uint8_t * qh = x[i].qh + il;
-        device const int8_t * s = x[i].scales;
-
-        const float d = x[i].d;
-
-        float4 sums = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < 4; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
-            sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
-        }
-        sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
-    }
-
-#endif
-
     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
         dst[r1*ne0 + im*ne0*ne1 + row] = tot;
@@ -5172,9 +4911,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device const float * y4 = y + 32 * ix;
 
-#if QK_K != 64
     iq1m_scale_t scale;
-#endif
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
 
@@ -5195,10 +4932,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
         device const uint16_t * sc = (device const uint16_t *)xr->scales;
 
         for (int row = 0; row < N_DST; row++) {
-
-#if QK_K != 64
             scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-#endif
 
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
             constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -5214,14 +4948,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
             }
             const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
             const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
-#if QK_K == 64
-            const float d = (float) *((device const half *)(sc - 1));
-            sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
-                              (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
-#else
+
             sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
                                              (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
-#endif
 
             sc += nb*sizeof(block_iq1_m)/2;
             qs += nb*sizeof(block_iq1_m);
@@ -5333,7 +5062,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     }
 }
 
-#if QK_K != 64
 void kernel_mul_mv_iq4_xs_f32_impl(
         device const void * src0,
         device const float * src1,
@@ -5428,7 +5156,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
         }
     }
 }
-#endif
 
 [[host_name("kernel_mul_mv_iq1_s_f32")]]
 kernel void kernel_mul_mv_iq1_s_f32(
@@ -5541,11 +5268,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-#if QK_K == 64
-    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-#else
     kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
-#endif
 }
 
 //============================= templates and their specializations =============================
@@ -5671,10 +5394,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
     float dl, ml;
     uint8_t sc = xb->scales[il];
 
-#if QK_K == 256
     q = q + 32*(il/8) + 16*(il&1);
     il = (il/2)%4;
-#endif
+
     half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
     uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
     dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
@@ -5690,7 +5412,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;
 
-#if QK_K == 256
     q = q + 32 * (il/8) + 16 * (il&1);
     h = h + 16 * (il&1);
     uint8_t m = 1 << (il/2);
@@ -5711,17 +5432,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
-#else
-    float kcoef = il&1 ? 1.f/16.f : 1.f;
-    uint16_t kmask = il&1 ? 0xF0 : 0x0F;
-    float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
-    float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    uint8_t m = 1<<(il*2);
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
-    }
-#endif
 }
 
 static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
@@ -5733,7 +5443,6 @@ template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
     device const uchar * q = xb->qs;
 
-#if QK_K == 256
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
     il = il & 3;
@@ -5742,16 +5451,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     const float min = xb->dmin;
     const float dl = d * sc[0];
     const float ml = min * sc[1];
-#else
-    (void) get_scale_min_k4_just2;
-
-    q = q + 16 * (il&1);
-    device const uint8_t * s = xb->scales;
-    device const half2 * dh = (device const half2 *)xb->d;
-    const float2 d = (float2)dh[0];
-    const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
-    const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
-#endif
+
     const ushort mask = il<2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
@@ -5763,7 +5463,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * q = xb->qs;
     device const uint8_t * qh = xb->qh;
 
-#if QK_K == 256
     short is = (il/4) * 2;
     q = q + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
@@ -5780,17 +5479,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
-#else
-    q = q + 16 * (il&1);
-    device const int8_t * s = xb->scales;
-    const float dl = xb->d * s[il];
-    uint8_t m = 1<<(il*2);
-    const float coef = il<2 ? 1.f : 1.f/16.f;
-    const ushort mask = il<2 ? 0x0F : 0xF0;
-    for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
-    }
-#endif
 }
 
 template <typename type4x4>
@@ -5800,15 +5488,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
 
-#if QK_K == 256
     ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
     qh = qh + 32*(il/8) + 16*(il&1);
     float sc = scales[(il%2) + 2 * ((il/2))];
     il = (il/2) & 3;
-#else
-    ql = ql + 16 * (il&1);
-    float sc = scales[il];
-#endif
+
     const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
     const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
     const float coef = il>1 ? 1.f/16.f : 1.f;
@@ -5965,20 +5649,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
     const int ib32 = il/2;
     il = il%2;
     device const uint16_t * sc = (device const uint16_t *)xb->scales;
-#if QK_K == 64
-    const float d = xb->d;
-#else
+
     iq1m_scale_t scale;
     scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
     const float d = scale.f16;
-#endif
+
     device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
     device const uint8_t * qh = xb->qh + 2*ib32 + il;
-#if QK_K == 64
-    const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
-#else
+
     const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
-#endif
     const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
     const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
     constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -6008,9 +5687,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
 
 template <typename type4x4>
 void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
-#if QK_K == 64
-    dequantize_iq4_nl(xb, il, reg);
-#else
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
     const int ib32 = il/2;
     il = il%2;
@@ -6027,7 +5703,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
         reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
         reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
     }
-#endif
 }
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6532,11 +6207,7 @@ kernel void kernel_mul_mm_id(
         sgitg);
 }
 
-#if QK_K == 256
 #define QK_NL 16
-#else
-#define QK_NL 4
-#endif
 
 //
 // get rows
@@ -6576,11 +6247,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
 template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
 template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
 template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
-#if QK_K == 64
-template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
-#else
 template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
-#endif
 
 //
 // matrix-matrix multiplication
@@ -6608,11 +6275,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
 template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
 template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
-#if QK_K == 64
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
-#else
 template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
-#endif
 
 //
 // indirect matrix-matrix multiplication
@@ -6640,11 +6303,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
 template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
 template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
 template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
-#if QK_K == 64
-template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
-#else
 template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
-#endif
 
 //
 // matrix-vector multiplication
@@ -6853,7 +6512,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
 template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
-#if QK_K != 64
 template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
-#endif
 
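
The largest single addition in the hunks above is the templated kernel_repeat (first hunk), whose broadcast rule is plain modulo indexing over the four tensor dimensions. Below is a tiny Ruby sketch of that index math, written against nested arrays purely for illustration; the method and variable names are ours, not part of the gem or of llama.cpp.

# Illustration only: the `i % ne` broadcast used by the new kernel_repeat above.
def repeat4d(src, ne3:, ne2:, ne1:, ne0:)
  ne03 = src.length            # source dims, analogous to ne03..ne00 in the kernel
  ne02 = src[0].length
  ne01 = src[0][0].length
  ne00 = src[0][0][0].length

  Array.new(ne3) do |i3|
    Array.new(ne2) do |i2|
      Array.new(ne1) do |i1|
        Array.new(ne0) do |i0|
          # same selection as the Metal kernel: wrap each index by the source extent
          src[i3 % ne03][i2 % ne02][i1 % ne01][i0 % ne00]
        end
      end
    end
  end
end

# e.g. a 1x1x1x4 row repeated into a 1x1x3x4 tensor:
row = [[[[1.0, 2.0, 3.0, 4.0]]]]
p repeat4d(row, ne3: 1, ne2: 1, ne1: 3, ne0: 4)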