llama_cpp 0.15.1 → 0.15.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +9 -20
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +87 -37
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +47 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +13 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +177 -190
- data/vendor/tmp/llama.cpp/ggml-metal.metal +97 -505
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +3660 -2057
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1155 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +60 -639
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +203 -224
- data/vendor/tmp/llama.cpp/ggml.c +1168 -1470
- data/vendor/tmp/llama.cpp/ggml.h +67 -44
- data/vendor/tmp/llama.cpp/llama.cpp +1371 -944
- data/vendor/tmp/llama.cpp/llama.h +13 -3
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -2169
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -12
- data/vendor/tmp/llama.cpp/unicode.cpp +89 -111
- data/vendor/tmp/llama.cpp/unicode.h +44 -12
- metadata +5 -3
@@ -229,6 +229,13 @@ kernel void kernel_relu(
|
|
229
229
|
dst[tpig] = max(0.0f, src0[tpig]);
|
230
230
|
}
|
231
231
|
|
232
|
+
kernel void kernel_sigmoid(
|
233
|
+
device const float * src0,
|
234
|
+
device float * dst,
|
235
|
+
uint tpig[[thread_position_in_grid]]) {
|
236
|
+
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
237
|
+
}
|
238
|
+
|
232
239
|
kernel void kernel_tanh(
|
233
240
|
device const float * src0,
|
234
241
|
device float * dst,
|
@@ -356,7 +363,6 @@ template<typename T>
|
|
356
363
|
kernel void kernel_soft_max(
|
357
364
|
device const char * src0,
|
358
365
|
device const char * src1,
|
359
|
-
device const char * src2,
|
360
366
|
device char * dst,
|
361
367
|
constant int64_t & ne00,
|
362
368
|
constant int64_t & ne01,
|
@@ -378,10 +384,9 @@ kernel void kernel_soft_max(
|
|
378
384
|
|
379
385
|
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
380
386
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
381
|
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
382
387
|
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
383
388
|
|
384
|
-
float slope =
|
389
|
+
float slope = 1.0f;
|
385
390
|
|
386
391
|
// ALiBi
|
387
392
|
if (max_bias > 0.0f) {
|
@@ -397,7 +402,7 @@ kernel void kernel_soft_max(
|
|
397
402
|
float lmax = -INFINITY;
|
398
403
|
|
399
404
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
400
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ?
|
405
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
401
406
|
}
|
402
407
|
|
403
408
|
// find the max value in the block
|
@@ -422,7 +427,7 @@ kernel void kernel_soft_max(
|
|
422
427
|
// parallel sum
|
423
428
|
float lsum = 0.0f;
|
424
429
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
425
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ?
|
430
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
426
431
|
lsum += exp_psrc0;
|
427
432
|
pdst[i00] = exp_psrc0;
|
428
433
|
}
|
@@ -461,7 +466,6 @@ template<typename T>
|
|
461
466
|
kernel void kernel_soft_max_4(
|
462
467
|
device const char * src0,
|
463
468
|
device const char * src1,
|
464
|
-
device const char * src2,
|
465
469
|
device char * dst,
|
466
470
|
constant int64_t & ne00,
|
467
471
|
constant int64_t & ne01,
|
@@ -483,10 +487,9 @@ kernel void kernel_soft_max_4(
|
|
483
487
|
|
484
488
|
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
485
489
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
486
|
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
487
490
|
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
488
491
|
|
489
|
-
float slope =
|
492
|
+
float slope = 1.0f;
|
490
493
|
|
491
494
|
if (max_bias > 0.0f) {
|
492
495
|
const int64_t h = i02;
|
@@ -501,7 +504,7 @@ kernel void kernel_soft_max_4(
|
|
501
504
|
float4 lmax4 = -INFINITY;
|
502
505
|
|
503
506
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
504
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ?
|
507
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
505
508
|
}
|
506
509
|
|
507
510
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
@@ -527,7 +530,7 @@ kernel void kernel_soft_max_4(
|
|
527
530
|
// parallel sum
|
528
531
|
float4 lsum4 = 0.0f;
|
529
532
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
530
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ?
|
533
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
531
534
|
lsum4 += exp_psrc4;
|
532
535
|
pdst4[i00] = exp_psrc4;
|
533
536
|
}
|
@@ -1595,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1595
1598
|
}
|
1596
1599
|
}
|
1597
1600
|
|
1598
|
-
kernel void kernel_alibi_f32(
|
1599
|
-
device const float * src0,
|
1600
|
-
device float * dst,
|
1601
|
-
constant int64_t & ne00,
|
1602
|
-
constant int64_t & ne01,
|
1603
|
-
constant int64_t & ne02,
|
1604
|
-
constant int64_t & ne03,
|
1605
|
-
constant uint64_t & nb00,
|
1606
|
-
constant uint64_t & nb01,
|
1607
|
-
constant uint64_t & nb02,
|
1608
|
-
constant uint64_t & nb03,
|
1609
|
-
constant int64_t & ne0,
|
1610
|
-
constant int64_t & ne1,
|
1611
|
-
constant int64_t & ne2,
|
1612
|
-
constant int64_t & ne3,
|
1613
|
-
constant uint64_t & nb0,
|
1614
|
-
constant uint64_t & nb1,
|
1615
|
-
constant uint64_t & nb2,
|
1616
|
-
constant uint64_t & nb3,
|
1617
|
-
constant float & m0,
|
1618
|
-
constant float & m1,
|
1619
|
-
constant int & n_heads_log2_floor,
|
1620
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
1621
|
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
1622
|
-
uint3 ntg[[threads_per_threadgroup]]) {
|
1623
|
-
const int64_t i03 = tgpig[2];
|
1624
|
-
const int64_t i02 = tgpig[1];
|
1625
|
-
const int64_t i01 = tgpig[0];
|
1626
|
-
|
1627
|
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1628
|
-
|
1629
|
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
1630
|
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1631
|
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1632
|
-
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1633
|
-
|
1634
|
-
const int64_t k = i3*ne3 + i2;
|
1635
|
-
|
1636
|
-
float m_k;
|
1637
|
-
if (k < n_heads_log2_floor) {
|
1638
|
-
m_k = pow(m0, k + 1);
|
1639
|
-
} else {
|
1640
|
-
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
1641
|
-
}
|
1642
|
-
|
1643
|
-
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
1644
|
-
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
1645
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
1646
|
-
const float src_v = *(device float *)(src_row + i00*nb00);
|
1647
|
-
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
1648
|
-
*dst_v = i00 * m_k + src_v;
|
1649
|
-
}
|
1650
|
-
}
|
1651
|
-
|
1652
1601
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
1653
1602
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
1654
1603
|
return 1.0f - min(1.0f, max(0.0f, y));
|
@@ -1691,6 +1640,7 @@ static void rope_yarn_corr_dims(
|
|
1691
1640
|
typedef void (rope_t)(
|
1692
1641
|
device const void * src0,
|
1693
1642
|
device const int32_t * src1,
|
1643
|
+
device const float * src2,
|
1694
1644
|
device float * dst,
|
1695
1645
|
constant int64_t & ne00,
|
1696
1646
|
constant int64_t & ne01,
|
@@ -1726,6 +1676,7 @@ template<typename T>
|
|
1726
1676
|
kernel void kernel_rope(
|
1727
1677
|
device const void * src0,
|
1728
1678
|
device const int32_t * src1,
|
1679
|
+
device const float * src2,
|
1729
1680
|
device float * dst,
|
1730
1681
|
constant int64_t & ne00,
|
1731
1682
|
constant int64_t & ne01,
|
@@ -1795,8 +1746,10 @@ kernel void kernel_rope(
|
|
1795
1746
|
|
1796
1747
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
1797
1748
|
const float cur_rot = inv_ndims*ic - ib;
|
1749
|
+
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
|
1750
|
+
|
1751
|
+
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
|
1798
1752
|
|
1799
|
-
const float theta = theta_0 * pow(freq_base, cur_rot);
|
1800
1753
|
float cos_theta, sin_theta;
|
1801
1754
|
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1802
1755
|
|
@@ -1903,7 +1856,10 @@ kernel void kernel_upscale_f32(
|
|
1903
1856
|
constant uint64_t & nb1,
|
1904
1857
|
constant uint64_t & nb2,
|
1905
1858
|
constant uint64_t & nb3,
|
1906
|
-
constant
|
1859
|
+
constant float & sf0,
|
1860
|
+
constant float & sf1,
|
1861
|
+
constant float & sf2,
|
1862
|
+
constant float & sf3,
|
1907
1863
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1908
1864
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
1909
1865
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -1912,15 +1868,17 @@ kernel void kernel_upscale_f32(
|
|
1912
1868
|
const int64_t i2 = tgpig.y;
|
1913
1869
|
const int64_t i1 = tgpig.x;
|
1914
1870
|
|
1915
|
-
const int64_t i03 = i3;
|
1916
|
-
const int64_t i02 = i2;
|
1917
|
-
const int64_t i01 = i1/
|
1918
|
-
|
1919
|
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1920
|
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1871
|
+
const int64_t i03 = i3/sf3;
|
1872
|
+
const int64_t i02 = i2/sf2;
|
1873
|
+
const int64_t i01 = i1/sf1;
|
1921
1874
|
|
1922
1875
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1923
|
-
|
1876
|
+
const int64_t i00 = i0/sf0;
|
1877
|
+
|
1878
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1879
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1880
|
+
|
1881
|
+
dst_ptr[0] = src0_ptr[0];
|
1924
1882
|
}
|
1925
1883
|
}
|
1926
1884
|
|
@@ -2100,29 +2058,29 @@ typedef void (flash_attn_ext_f16_t)(
|
|
2100
2058
|
device const char * v,
|
2101
2059
|
device const char * mask,
|
2102
2060
|
device float * dst,
|
2103
|
-
constant int64_t & ne00,
|
2104
2061
|
constant int64_t & ne01,
|
2105
2062
|
constant int64_t & ne02,
|
2106
2063
|
constant int64_t & ne03,
|
2107
|
-
constant uint64_t & nb00,
|
2108
2064
|
constant uint64_t & nb01,
|
2109
2065
|
constant uint64_t & nb02,
|
2110
2066
|
constant uint64_t & nb03,
|
2111
|
-
constant int64_t & ne10,
|
2112
2067
|
constant int64_t & ne11,
|
2113
2068
|
constant int64_t & ne12,
|
2114
2069
|
constant int64_t & ne13,
|
2115
|
-
constant uint64_t & nb10,
|
2116
2070
|
constant uint64_t & nb11,
|
2117
2071
|
constant uint64_t & nb12,
|
2118
2072
|
constant uint64_t & nb13,
|
2119
|
-
constant
|
2073
|
+
constant uint64_t & nb21,
|
2074
|
+
constant uint64_t & nb22,
|
2075
|
+
constant uint64_t & nb23,
|
2120
2076
|
constant uint64_t & nb31,
|
2121
|
-
constant int64_t & ne0,
|
2122
2077
|
constant int64_t & ne1,
|
2123
2078
|
constant int64_t & ne2,
|
2124
|
-
constant int64_t & ne3,
|
2125
2079
|
constant float & scale,
|
2080
|
+
constant float & max_bias,
|
2081
|
+
constant float & m0,
|
2082
|
+
constant float & m1,
|
2083
|
+
constant uint32_t & n_head_log2,
|
2126
2084
|
threadgroup half * shared,
|
2127
2085
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2128
2086
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2138,29 +2096,29 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2138
2096
|
device const char * v,
|
2139
2097
|
device const char * mask,
|
2140
2098
|
device float * dst,
|
2141
|
-
constant int64_t & ne00,
|
2142
2099
|
constant int64_t & ne01,
|
2143
2100
|
constant int64_t & ne02,
|
2144
2101
|
constant int64_t & ne03,
|
2145
|
-
constant uint64_t & nb00,
|
2146
2102
|
constant uint64_t & nb01,
|
2147
2103
|
constant uint64_t & nb02,
|
2148
2104
|
constant uint64_t & nb03,
|
2149
|
-
constant int64_t & ne10,
|
2150
2105
|
constant int64_t & ne11,
|
2151
2106
|
constant int64_t & ne12,
|
2152
2107
|
constant int64_t & ne13,
|
2153
|
-
constant uint64_t & nb10,
|
2154
2108
|
constant uint64_t & nb11,
|
2155
2109
|
constant uint64_t & nb12,
|
2156
2110
|
constant uint64_t & nb13,
|
2157
|
-
constant
|
2111
|
+
constant uint64_t & nb21,
|
2112
|
+
constant uint64_t & nb22,
|
2113
|
+
constant uint64_t & nb23,
|
2158
2114
|
constant uint64_t & nb31,
|
2159
|
-
constant int64_t & ne0,
|
2160
2115
|
constant int64_t & ne1,
|
2161
2116
|
constant int64_t & ne2,
|
2162
|
-
constant int64_t & ne3,
|
2163
2117
|
constant float & scale,
|
2118
|
+
constant float & max_bias,
|
2119
|
+
constant float & m0,
|
2120
|
+
constant float & m1,
|
2121
|
+
constant uint32_t & n_head_log2,
|
2164
2122
|
threadgroup half * shared [[threadgroup(0)]],
|
2165
2123
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2166
2124
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2225,10 +2183,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2225
2183
|
const short ne22 = ne12;
|
2226
2184
|
const short ne23 = ne13;
|
2227
2185
|
|
2228
|
-
const uint nb21 = nb11;
|
2229
|
-
const uint nb22 = nb12;
|
2230
|
-
const uint nb23 = nb13;
|
2231
|
-
|
2232
2186
|
// broadcast
|
2233
2187
|
const short rk2 = ne02/ne12;
|
2234
2188
|
const short rk3 = ne03/ne13;
|
@@ -2254,8 +2208,17 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2254
2208
|
// pointer to the mask
|
2255
2209
|
device const half * mp = (device const half *) (mask + iq1*nb31);
|
2256
2210
|
|
2257
|
-
|
2258
|
-
|
2211
|
+
float slope = 1.0f;
|
2212
|
+
|
2213
|
+
// ALiBi
|
2214
|
+
if (max_bias > 0.0f) {
|
2215
|
+
const uint32_t h = iq2;
|
2216
|
+
|
2217
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
2218
|
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2219
|
+
|
2220
|
+
slope = pow(base, exph);
|
2221
|
+
}
|
2259
2222
|
|
2260
2223
|
// loop over the KV cache
|
2261
2224
|
// each simdgroup handles blocks of Q rows and C columns
|
@@ -2279,12 +2242,20 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2279
2242
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
2280
2243
|
}
|
2281
2244
|
|
2282
|
-
// mqk = mqk*scale + mask
|
2283
|
-
simdgroup_half8x8 mm;
|
2284
|
-
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
2285
|
-
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
2286
|
-
|
2287
2245
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2246
|
+
|
2247
|
+
const short tx = tiisg%4;
|
2248
|
+
const short ty = tiisg/4;
|
2249
|
+
|
2250
|
+
if (mask != q) {
|
2251
|
+
// mqk = mqk*scale + mask*slope
|
2252
|
+
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
2253
|
+
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
2254
|
+
} else {
|
2255
|
+
// mqk = mqk*scale
|
2256
|
+
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
2257
|
+
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
2258
|
+
}
|
2288
2259
|
}
|
2289
2260
|
}
|
2290
2261
|
|
@@ -2456,29 +2427,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2456
2427
|
device const char * v,
|
2457
2428
|
device const char * mask,
|
2458
2429
|
device float * dst,
|
2459
|
-
constant int64_t & ne00,
|
2460
2430
|
constant int64_t & ne01,
|
2461
2431
|
constant int64_t & ne02,
|
2462
2432
|
constant int64_t & ne03,
|
2463
|
-
constant uint64_t & nb00,
|
2464
2433
|
constant uint64_t & nb01,
|
2465
2434
|
constant uint64_t & nb02,
|
2466
2435
|
constant uint64_t & nb03,
|
2467
|
-
constant int64_t & ne10,
|
2468
2436
|
constant int64_t & ne11,
|
2469
2437
|
constant int64_t & ne12,
|
2470
2438
|
constant int64_t & ne13,
|
2471
|
-
constant uint64_t & nb10,
|
2472
2439
|
constant uint64_t & nb11,
|
2473
2440
|
constant uint64_t & nb12,
|
2474
2441
|
constant uint64_t & nb13,
|
2475
|
-
constant
|
2442
|
+
constant uint64_t & nb21,
|
2443
|
+
constant uint64_t & nb22,
|
2444
|
+
constant uint64_t & nb23,
|
2476
2445
|
constant uint64_t & nb31,
|
2477
|
-
constant int64_t & ne0,
|
2478
2446
|
constant int64_t & ne1,
|
2479
2447
|
constant int64_t & ne2,
|
2480
|
-
constant int64_t & ne3,
|
2481
2448
|
constant float & scale,
|
2449
|
+
constant float & max_bias,
|
2450
|
+
constant float & m0,
|
2451
|
+
constant float & m1,
|
2452
|
+
constant uint32_t & n_head_log2,
|
2482
2453
|
threadgroup half * shared [[threadgroup(0)]],
|
2483
2454
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2484
2455
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2497,6 +2468,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2497
2468
|
|
2498
2469
|
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
2499
2470
|
|
2471
|
+
float slope = 1.0f;
|
2472
|
+
|
2473
|
+
// ALiBi
|
2474
|
+
if (max_bias > 0.0f) {
|
2475
|
+
const uint32_t h = iq2;
|
2476
|
+
|
2477
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
2478
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2479
|
+
|
2480
|
+
slope = pow(base, exp);
|
2481
|
+
}
|
2482
|
+
|
2500
2483
|
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
2501
2484
|
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
2502
2485
|
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
@@ -2537,10 +2520,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2537
2520
|
const short ne22 = ne12;
|
2538
2521
|
const short ne23 = ne13;
|
2539
2522
|
|
2540
|
-
const uint nb21 = nb11;
|
2541
|
-
const uint nb22 = nb12;
|
2542
|
-
const uint nb23 = nb13;
|
2543
|
-
|
2544
2523
|
// broadcast
|
2545
2524
|
const short rk2 = ne02/ne12;
|
2546
2525
|
const short rk3 = ne03/ne13;
|
@@ -2603,10 +2582,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2603
2582
|
mqk += simd_shuffle_down(mqk, 2);
|
2604
2583
|
mqk += simd_shuffle_down(mqk, 1);
|
2605
2584
|
|
2606
|
-
// mqk = mqk*scale + mask
|
2585
|
+
// mqk = mqk*scale + mask*slope
|
2607
2586
|
if (tiisg == 0) {
|
2608
|
-
|
2609
|
-
mqk = mqk*scale + mm;
|
2587
|
+
mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
|
2610
2588
|
|
2611
2589
|
ss4[cc] = mqk;
|
2612
2590
|
}
|
@@ -3408,7 +3386,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3408
3386
|
|
3409
3387
|
const int step = sizeof(block_q2_K) * nb;
|
3410
3388
|
|
3411
|
-
#if QK_K == 256
|
3412
3389
|
const int ix = tiisg/8; // 0...3
|
3413
3390
|
const int it = tiisg%8; // 0...7
|
3414
3391
|
const int iq = it/4; // 0 or 1
|
@@ -3460,57 +3437,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3460
3437
|
|
3461
3438
|
y4 += 4 * QK_K;
|
3462
3439
|
}
|
3463
|
-
#else
|
3464
|
-
const int ix = tiisg/2; // 0...15
|
3465
|
-
const int it = tiisg%2; // 0...1
|
3466
|
-
|
3467
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3468
|
-
|
3469
|
-
for (int ib = ix; ib < nb; ib += 16) {
|
3470
|
-
|
3471
|
-
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
3472
|
-
for (int i = 0; i < 8; ++i) {
|
3473
|
-
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
3474
|
-
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
|
3475
|
-
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
|
3476
|
-
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
|
3477
|
-
}
|
3478
|
-
|
3479
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
3480
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3481
|
-
device const half * dh = &x[ib].d;
|
3482
|
-
|
3483
|
-
for (int row = 0; row < N_DST; row++) {
|
3484
|
-
|
3485
|
-
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
3486
|
-
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
3487
|
-
for (int i = 0; i < 8; i += 2) {
|
3488
|
-
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
3489
|
-
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
3490
|
-
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
3491
|
-
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
3492
|
-
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
3493
|
-
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
3494
|
-
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
3495
|
-
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
3496
|
-
}
|
3497
|
-
|
3498
|
-
float dall = dh[0];
|
3499
|
-
float dmin = dh[1];
|
3500
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
3501
|
-
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
|
3502
|
-
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
|
3503
|
-
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
|
3504
|
-
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
|
3505
|
-
|
3506
|
-
qs += step/2;
|
3507
|
-
sc += step;
|
3508
|
-
dh += step/2;
|
3509
|
-
}
|
3510
|
-
|
3511
|
-
y4 += 16 * QK_K;
|
3512
|
-
}
|
3513
|
-
#endif
|
3514
3440
|
|
3515
3441
|
for (int row = 0; row < N_DST; ++row) {
|
3516
3442
|
all_sum = simd_sum(sumf[row]);
|
@@ -3548,7 +3474,6 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
3548
3474
|
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3549
3475
|
}
|
3550
3476
|
|
3551
|
-
#if QK_K == 256
|
3552
3477
|
void kernel_mul_mv_q3_K_f32_impl(
|
3553
3478
|
device const void * src0,
|
3554
3479
|
device const float * src1,
|
@@ -3707,84 +3632,6 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
3707
3632
|
}
|
3708
3633
|
}
|
3709
3634
|
}
|
3710
|
-
#else
|
3711
|
-
void kernel_mul_mv_q3_K_f32_impl(
|
3712
|
-
device const void * src0,
|
3713
|
-
device const float * src1,
|
3714
|
-
device float * dst,
|
3715
|
-
constant int64_t & ne00,
|
3716
|
-
constant int64_t & ne01,
|
3717
|
-
constant int64_t & ne02,
|
3718
|
-
constant int64_t & ne10,
|
3719
|
-
constant int64_t & ne12,
|
3720
|
-
constant int64_t & ne0,
|
3721
|
-
constant int64_t & ne1,
|
3722
|
-
constant uint & r2,
|
3723
|
-
constant uint & r3,
|
3724
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3725
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3726
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3727
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3728
|
-
|
3729
|
-
const int nb = ne00/QK_K;
|
3730
|
-
|
3731
|
-
const int64_t r0 = tgpig.x;
|
3732
|
-
const int64_t r1 = tgpig.y;
|
3733
|
-
const int64_t im = tgpig.z;
|
3734
|
-
|
3735
|
-
const int row = 2 * r0 + sgitg;
|
3736
|
-
|
3737
|
-
const uint i12 = im%ne12;
|
3738
|
-
const uint i13 = im/ne12;
|
3739
|
-
|
3740
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3741
|
-
|
3742
|
-
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
3743
|
-
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3744
|
-
|
3745
|
-
const int ix = tiisg/4;
|
3746
|
-
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
3747
|
-
const int iq = il/8; // 0, 0, 1, 1
|
3748
|
-
const int in = il%8; // 0, 4, 0, 4
|
3749
|
-
|
3750
|
-
float2 sum = {0.f, 0.f};
|
3751
|
-
|
3752
|
-
for (int i = ix; i < nb; i += 8) {
|
3753
|
-
|
3754
|
-
const float d_all = (float)(x[i].d);
|
3755
|
-
|
3756
|
-
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
3757
|
-
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
3758
|
-
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
3759
|
-
device const float * y = yy + i * QK_K + il;
|
3760
|
-
|
3761
|
-
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
3762
|
-
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
3763
|
-
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
3764
|
-
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
3765
|
-
|
3766
|
-
for (int l = 0; l < 4; l += 2) {
|
3767
|
-
const uint16_t hm = h[l/2] >> iq;
|
3768
|
-
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
3769
|
-
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
3770
|
-
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
3771
|
-
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
3772
|
-
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
3773
|
-
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
3774
|
-
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
3775
|
-
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
3776
|
-
}
|
3777
|
-
|
3778
|
-
}
|
3779
|
-
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
3780
|
-
|
3781
|
-
const float tot = simd_sum(sumf);
|
3782
|
-
if (tiisg == 0) {
|
3783
|
-
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
3784
|
-
}
|
3785
|
-
|
3786
|
-
}
|
3787
|
-
#endif
|
3788
3635
|
|
3789
3636
|
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
3790
3637
|
kernel void kernel_mul_mv_q3_K_f32(
|
@@ -3814,7 +3661,6 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
3814
3661
|
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3815
3662
|
}
|
3816
3663
|
|
3817
|
-
#if QK_K == 256
|
3818
3664
|
void kernel_mul_mv_q4_K_f32_impl(
|
3819
3665
|
device const void * src0,
|
3820
3666
|
device const float * src1,
|
@@ -3928,103 +3774,6 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3928
3774
|
}
|
3929
3775
|
}
|
3930
3776
|
}
|
3931
|
-
#else
|
3932
|
-
void kernel_mul_mv_q4_K_f32_impl(
|
3933
|
-
device const void * src0,
|
3934
|
-
device const float * src1,
|
3935
|
-
device float * dst,
|
3936
|
-
constant int64_t & ne00,
|
3937
|
-
constant int64_t & ne01,
|
3938
|
-
constant int64_t & ne02,
|
3939
|
-
constant int64_t & ne10,
|
3940
|
-
constant int64_t & ne12,
|
3941
|
-
constant int64_t & ne0,
|
3942
|
-
constant int64_t & ne1,
|
3943
|
-
constant uint & r2,
|
3944
|
-
constant uint & r3,
|
3945
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3946
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3947
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3948
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3949
|
-
|
3950
|
-
const int ix = tiisg/4; // 0...7
|
3951
|
-
const int it = tiisg%4; // 0...3
|
3952
|
-
|
3953
|
-
const int nb = ne00/QK_K;
|
3954
|
-
const int r0 = tgpig.x;
|
3955
|
-
const int r1 = tgpig.y;
|
3956
|
-
const int im = tgpig.z;
|
3957
|
-
const int first_row = r0 * N_DST;
|
3958
|
-
const int ib_row = first_row * nb;
|
3959
|
-
|
3960
|
-
const uint i12 = im%ne12;
|
3961
|
-
const uint i13 = im/ne12;
|
3962
|
-
|
3963
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3964
|
-
|
3965
|
-
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
3966
|
-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3967
|
-
|
3968
|
-
float yl[8];
|
3969
|
-
float yh[8];
|
3970
|
-
float sumf[N_DST]={0.f}, all_sum;
|
3971
|
-
|
3972
|
-
const int step = sizeof(block_q4_K) * nb / 2;
|
3973
|
-
|
3974
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3975
|
-
|
3976
|
-
uint16_t sc16[4];
|
3977
|
-
|
3978
|
-
for (int ib = ix; ib < nb; ib += 8) {
|
3979
|
-
|
3980
|
-
float2 sumy = {0.f, 0.f};
|
3981
|
-
for (int i = 0; i < 8; ++i) {
|
3982
|
-
yl[i] = y4[i+ 0]; sumy[0] += yl[i];
|
3983
|
-
yh[i] = y4[i+32]; sumy[1] += yh[i];
|
3984
|
-
}
|
3985
|
-
|
3986
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
|
3987
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3988
|
-
device const half * dh = x[ib].d;
|
3989
|
-
|
3990
|
-
for (int row = 0; row < N_DST; row++) {
|
3991
|
-
|
3992
|
-
sc16[0] = sc[0] & 0x000f;
|
3993
|
-
sc16[1] = sc[0] & 0x0f00;
|
3994
|
-
sc16[2] = sc[0] & 0x00f0;
|
3995
|
-
sc16[3] = sc[0] & 0xf000;
|
3996
|
-
|
3997
|
-
float2 acc1 = {0.f, 0.f};
|
3998
|
-
float2 acc2 = {0.f, 0.f};
|
3999
|
-
for (int i = 0; i < 8; i += 2) {
|
4000
|
-
acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
|
4001
|
-
acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
|
4002
|
-
acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
|
4003
|
-
acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
|
4004
|
-
}
|
4005
|
-
|
4006
|
-
float dall = dh[0];
|
4007
|
-
float dmin = dh[1];
|
4008
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
|
4009
|
-
(acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
|
4010
|
-
dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
|
4011
|
-
|
4012
|
-
qs += step;
|
4013
|
-
sc += step;
|
4014
|
-
dh += step;
|
4015
|
-
}
|
4016
|
-
|
4017
|
-
y4 += 8 * QK_K;
|
4018
|
-
}
|
4019
|
-
|
4020
|
-
for (int row = 0; row < N_DST; ++row) {
|
4021
|
-
all_sum = simd_sum(sumf[row]);
|
4022
|
-
if (tiisg == 0) {
|
4023
|
-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4024
|
-
}
|
4025
|
-
}
|
4026
|
-
}
|
4027
|
-
#endif
|
4028
3777
|
|
4029
3778
|
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
4030
3779
|
kernel void kernel_mul_mv_q4_K_f32(
|
@@ -4092,8 +3841,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4092
3841
|
|
4093
3842
|
const int step = sizeof(block_q5_K) * nb;
|
4094
3843
|
|
4095
|
-
#if QK_K == 256
|
4096
|
-
#
|
4097
3844
|
float yl[16], yh[16];
|
4098
3845
|
|
4099
3846
|
const uint16_t kmask1 = 0x3f3f;
|
@@ -4176,54 +3923,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4176
3923
|
y1 += 4 * QK_K;
|
4177
3924
|
|
4178
3925
|
}
|
4179
|
-
#else
|
4180
|
-
float yl[8], yh[8];
|
4181
|
-
|
4182
|
-
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
4183
|
-
const int ix = tiisg%8;
|
4184
|
-
const int iq = il/8; // 0, 0, 1, 1
|
4185
|
-
const int in = il%8; // 0, 4, 0, 4
|
4186
|
-
|
4187
|
-
device const float * y = yy + ix*QK_K + il;
|
4188
|
-
|
4189
|
-
for (int i = ix; i < nb; i += 8) {
|
4190
|
-
|
4191
|
-
for (int l = 0; l < 4; ++l) {
|
4192
|
-
yl[l+0] = y[l+ 0];
|
4193
|
-
yl[l+4] = y[l+16];
|
4194
|
-
yh[l+0] = y[l+32];
|
4195
|
-
yh[l+4] = y[l+48];
|
4196
|
-
}
|
4197
|
-
|
4198
|
-
device const half * dh = &x[i].d;
|
4199
|
-
device const uint8_t * q = x[i].qs + il;
|
4200
|
-
device const uint8_t * h = x[i].qh + in;
|
4201
|
-
device const int8_t * s = x[i].scales;
|
4202
|
-
|
4203
|
-
for (int row = 0; row < 2; ++row) {
|
4204
|
-
|
4205
|
-
const float d = dh[0];
|
4206
|
-
|
4207
|
-
float2 acc = {0.f, 0.f};
|
4208
|
-
for (int l = 0; l < 4; ++l) {
|
4209
|
-
const uint8_t hl = h[l] >> iq;
|
4210
|
-
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
4211
|
-
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
4212
|
-
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
4213
|
-
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
|
4214
|
-
}
|
4215
|
-
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
|
4216
|
-
|
4217
|
-
q += step;
|
4218
|
-
h += step;
|
4219
|
-
s += step;
|
4220
|
-
dh += step/2;
|
4221
|
-
|
4222
|
-
}
|
4223
|
-
|
4224
|
-
y += 8 * QK_K;
|
4225
|
-
}
|
4226
|
-
#endif
|
4227
3926
|
|
4228
3927
|
for (int row = 0; row < 2; ++row) {
|
4229
3928
|
const float tot = simd_sum(sumf[row]);
|
@@ -4302,7 +4001,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4302
4001
|
|
4303
4002
|
float sumf = 0;
|
4304
4003
|
|
4305
|
-
#if QK_K == 256
|
4306
4004
|
const int tid = tiisg/2;
|
4307
4005
|
const int ix = tiisg%2;
|
4308
4006
|
const int ip = tid/8; // 0 or 1
|
@@ -4338,30 +4036,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4338
4036
|
|
4339
4037
|
}
|
4340
4038
|
|
4341
|
-
#else
|
4342
|
-
const int ix = tiisg/4;
|
4343
|
-
const int il = 4*(tiisg%4);
|
4344
|
-
|
4345
|
-
for (int i = ix; i < nb; i += 8) {
|
4346
|
-
device const float * y = yy + i * QK_K + il;
|
4347
|
-
device const uint8_t * ql = x[i].ql + il;
|
4348
|
-
device const uint8_t * qh = x[i].qh + il;
|
4349
|
-
device const int8_t * s = x[i].scales;
|
4350
|
-
|
4351
|
-
const float d = x[i].d;
|
4352
|
-
|
4353
|
-
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
4354
|
-
for (int l = 0; l < 4; ++l) {
|
4355
|
-
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
4356
|
-
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
4357
|
-
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
|
4358
|
-
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
4359
|
-
}
|
4360
|
-
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
4361
|
-
}
|
4362
|
-
|
4363
|
-
#endif
|
4364
|
-
|
4365
4039
|
const float tot = simd_sum(sumf);
|
4366
4040
|
if (tiisg == 0) {
|
4367
4041
|
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
@@ -5195,9 +4869,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5195
4869
|
|
5196
4870
|
device const float * y4 = y + 32 * ix;
|
5197
4871
|
|
5198
|
-
#if QK_K != 64
|
5199
4872
|
iq1m_scale_t scale;
|
5200
|
-
#endif
|
5201
4873
|
|
5202
4874
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5203
4875
|
|
@@ -5218,10 +4890,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5218
4890
|
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
5219
4891
|
|
5220
4892
|
for (int row = 0; row < N_DST; row++) {
|
5221
|
-
|
5222
|
-
#if QK_K != 64
|
5223
4893
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5224
|
-
#endif
|
5225
4894
|
|
5226
4895
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
5227
4896
|
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
@@ -5237,14 +4906,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5237
4906
|
}
|
5238
4907
|
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5239
4908
|
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5240
|
-
|
5241
|
-
const float d = (float) *((device const half *)(sc - 1));
|
5242
|
-
sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
|
5243
|
-
(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
|
5244
|
-
#else
|
4909
|
+
|
5245
4910
|
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
5246
4911
|
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
5247
|
-
#endif
|
5248
4912
|
|
5249
4913
|
sc += nb*sizeof(block_iq1_m)/2;
|
5250
4914
|
qs += nb*sizeof(block_iq1_m);
|
@@ -5356,7 +5020,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5356
5020
|
}
|
5357
5021
|
}
|
5358
5022
|
|
5359
|
-
#if QK_K != 64
|
5360
5023
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
5361
5024
|
device const void * src0,
|
5362
5025
|
device const float * src1,
|
@@ -5451,7 +5114,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5451
5114
|
}
|
5452
5115
|
}
|
5453
5116
|
}
|
5454
|
-
#endif
|
5455
5117
|
|
5456
5118
|
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
5457
5119
|
kernel void kernel_mul_mv_iq1_s_f32(
|
@@ -5564,11 +5226,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
5564
5226
|
uint tiisg[[thread_index_in_simdgroup]],
|
5565
5227
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5566
5228
|
|
5567
|
-
#if QK_K == 64
|
5568
|
-
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5569
|
-
#else
|
5570
5229
|
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5571
|
-
#endif
|
5572
5230
|
}
|
5573
5231
|
|
5574
5232
|
//============================= templates and their specializations =============================
|
@@ -5694,10 +5352,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
|
|
5694
5352
|
float dl, ml;
|
5695
5353
|
uint8_t sc = xb->scales[il];
|
5696
5354
|
|
5697
|
-
#if QK_K == 256
|
5698
5355
|
q = q + 32*(il/8) + 16*(il&1);
|
5699
5356
|
il = (il/2)%4;
|
5700
|
-
|
5357
|
+
|
5701
5358
|
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5702
5359
|
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5703
5360
|
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
@@ -5713,7 +5370,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5713
5370
|
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
5714
5371
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5715
5372
|
|
5716
|
-
#if QK_K == 256
|
5717
5373
|
q = q + 32 * (il/8) + 16 * (il&1);
|
5718
5374
|
h = h + 16 * (il&1);
|
5719
5375
|
uint8_t m = 1 << (il/2);
|
@@ -5734,17 +5390,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5734
5390
|
for (int i = 0; i < 16; ++i) {
|
5735
5391
|
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
5736
5392
|
}
|
5737
|
-
#else
|
5738
|
-
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
5739
|
-
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
5740
|
-
float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
|
5741
|
-
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5742
|
-
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5743
|
-
uint8_t m = 1<<(il*2);
|
5744
|
-
for (int i = 0; i < 16; ++i) {
|
5745
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
|
5746
|
-
}
|
5747
|
-
#endif
|
5748
5393
|
}
|
5749
5394
|
|
5750
5395
|
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
@@ -5756,7 +5401,6 @@ template <typename type4x4>
|
|
5756
5401
|
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
5757
5402
|
device const uchar * q = xb->qs;
|
5758
5403
|
|
5759
|
-
#if QK_K == 256
|
5760
5404
|
short is = (il/4) * 2;
|
5761
5405
|
q = q + (il/4) * 32 + 16 * (il&1);
|
5762
5406
|
il = il & 3;
|
@@ -5765,16 +5409,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
5765
5409
|
const float min = xb->dmin;
|
5766
5410
|
const float dl = d * sc[0];
|
5767
5411
|
const float ml = min * sc[1];
|
5768
|
-
|
5769
|
-
(void) get_scale_min_k4_just2;
|
5770
|
-
|
5771
|
-
q = q + 16 * (il&1);
|
5772
|
-
device const uint8_t * s = xb->scales;
|
5773
|
-
device const half2 * dh = (device const half2 *)xb->d;
|
5774
|
-
const float2 d = (float2)dh[0];
|
5775
|
-
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
5776
|
-
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
5777
|
-
#endif
|
5412
|
+
|
5778
5413
|
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5779
5414
|
for (int i = 0; i < 16; ++i) {
|
5780
5415
|
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
@@ -5786,7 +5421,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5786
5421
|
device const uint8_t * q = xb->qs;
|
5787
5422
|
device const uint8_t * qh = xb->qh;
|
5788
5423
|
|
5789
|
-
#if QK_K == 256
|
5790
5424
|
short is = (il/4) * 2;
|
5791
5425
|
q = q + 32 * (il/4) + 16 * (il&1);
|
5792
5426
|
qh = qh + 16 * (il&1);
|
@@ -5803,17 +5437,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5803
5437
|
for (int i = 0; i < 16; ++i) {
|
5804
5438
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
5805
5439
|
}
|
5806
|
-
#else
|
5807
|
-
q = q + 16 * (il&1);
|
5808
|
-
device const int8_t * s = xb->scales;
|
5809
|
-
const float dl = xb->d * s[il];
|
5810
|
-
uint8_t m = 1<<(il*2);
|
5811
|
-
const float coef = il<2 ? 1.f : 1.f/16.f;
|
5812
|
-
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5813
|
-
for (int i = 0; i < 16; ++i) {
|
5814
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
|
5815
|
-
}
|
5816
|
-
#endif
|
5817
5440
|
}
|
5818
5441
|
|
5819
5442
|
template <typename type4x4>
|
@@ -5823,15 +5446,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
5823
5446
|
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
5824
5447
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5825
5448
|
|
5826
|
-
#if QK_K == 256
|
5827
5449
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
5828
5450
|
qh = qh + 32*(il/8) + 16*(il&1);
|
5829
5451
|
float sc = scales[(il%2) + 2 * ((il/2))];
|
5830
5452
|
il = (il/2) & 3;
|
5831
|
-
|
5832
|
-
ql = ql + 16 * (il&1);
|
5833
|
-
float sc = scales[il];
|
5834
|
-
#endif
|
5453
|
+
|
5835
5454
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5836
5455
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
5837
5456
|
const float coef = il>1 ? 1.f/16.f : 1.f;
|
@@ -5988,20 +5607,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
|
|
5988
5607
|
const int ib32 = il/2;
|
5989
5608
|
il = il%2;
|
5990
5609
|
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
5991
|
-
|
5992
|
-
const float d = xb->d;
|
5993
|
-
#else
|
5610
|
+
|
5994
5611
|
iq1m_scale_t scale;
|
5995
5612
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5996
5613
|
const float d = scale.f16;
|
5997
|
-
|
5614
|
+
|
5998
5615
|
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
5999
5616
|
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
6000
|
-
|
6001
|
-
const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
|
6002
|
-
#else
|
5617
|
+
|
6003
5618
|
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
6004
|
-
#endif
|
6005
5619
|
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
6006
5620
|
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
6007
5621
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
@@ -6031,9 +5645,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|
6031
5645
|
|
6032
5646
|
template <typename type4x4>
|
6033
5647
|
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
6034
|
-
#if QK_K == 64
|
6035
|
-
dequantize_iq4_nl(xb, il, reg);
|
6036
|
-
#else
|
6037
5648
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
6038
5649
|
const int ib32 = il/2;
|
6039
5650
|
il = il%2;
|
@@ -6050,7 +5661,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|
6050
5661
|
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
6051
5662
|
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
6052
5663
|
}
|
6053
|
-
#endif
|
6054
5664
|
}
|
6055
5665
|
|
6056
5666
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
@@ -6555,11 +6165,7 @@ kernel void kernel_mul_mm_id(
|
|
6555
6165
|
sgitg);
|
6556
6166
|
}
|
6557
6167
|
|
6558
|
-
#if QK_K == 256
|
6559
6168
|
#define QK_NL 16
|
6560
|
-
#else
|
6561
|
-
#define QK_NL 4
|
6562
|
-
#endif
|
6563
6169
|
|
6564
6170
|
//
|
6565
6171
|
// get rows
|
@@ -6599,11 +6205,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
|
|
6599
6205
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6600
6206
|
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6601
6207
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6602
|
-
#if QK_K == 64
|
6603
|
-
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6604
|
-
#else
|
6605
6208
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6606
|
-
#endif
|
6607
6209
|
|
6608
6210
|
//
|
6609
6211
|
// matrix-matrix multiplication
|
@@ -6631,11 +6233,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
|
|
6631
6233
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6632
6234
|
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6633
6235
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6634
|
-
#if QK_K == 64
|
6635
|
-
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
6636
|
-
#else
|
6637
6236
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6638
|
-
#endif
|
6639
6237
|
|
6640
6238
|
//
|
6641
6239
|
// indirect matrix-matrix multiplication
|
@@ -6663,11 +6261,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
|
|
6663
6261
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6664
6262
|
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6665
6263
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6666
|
-
#if QK_K == 64
|
6667
|
-
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6668
|
-
#else
|
6669
6264
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6670
|
-
#endif
|
6671
6265
|
|
6672
6266
|
//
|
6673
6267
|
// matrix-vector multiplication
|
@@ -6876,7 +6470,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
|
|
6876
6470
|
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
|
6877
6471
|
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
6878
6472
|
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
6879
|
-
#if QK_K != 64
|
6880
6473
|
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
6881
|
-
#endif
|
6882
6474
|
|