llama_cpp 0.15.1 → 0.15.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +9 -20
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +87 -37
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +47 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +13 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +177 -190
- data/vendor/tmp/llama.cpp/ggml-metal.metal +97 -505
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +3660 -2057
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1155 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +60 -639
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +203 -224
- data/vendor/tmp/llama.cpp/ggml.c +1168 -1470
- data/vendor/tmp/llama.cpp/ggml.h +67 -44
- data/vendor/tmp/llama.cpp/llama.cpp +1371 -944
- data/vendor/tmp/llama.cpp/llama.h +13 -3
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -2169
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -12
- data/vendor/tmp/llama.cpp/unicode.cpp +89 -111
- data/vendor/tmp/llama.cpp/unicode.h +44 -12
- metadata +5 -3
@@ -229,6 +229,13 @@ kernel void kernel_relu(
|
|
229
229
|
dst[tpig] = max(0.0f, src0[tpig]);
|
230
230
|
}
|
231
231
|
|
232
|
+
kernel void kernel_sigmoid(
|
233
|
+
device const float * src0,
|
234
|
+
device float * dst,
|
235
|
+
uint tpig[[thread_position_in_grid]]) {
|
236
|
+
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
237
|
+
}
|
238
|
+
|
232
239
|
kernel void kernel_tanh(
|
233
240
|
device const float * src0,
|
234
241
|
device float * dst,
|
@@ -356,7 +363,6 @@ template<typename T>
|
|
356
363
|
kernel void kernel_soft_max(
|
357
364
|
device const char * src0,
|
358
365
|
device const char * src1,
|
359
|
-
device const char * src2,
|
360
366
|
device char * dst,
|
361
367
|
constant int64_t & ne00,
|
362
368
|
constant int64_t & ne01,
|
@@ -378,10 +384,9 @@ kernel void kernel_soft_max(
|
|
378
384
|
|
379
385
|
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
380
386
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
381
|
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
382
387
|
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
383
388
|
|
384
|
-
float slope =
|
389
|
+
float slope = 1.0f;
|
385
390
|
|
386
391
|
// ALiBi
|
387
392
|
if (max_bias > 0.0f) {
|
@@ -397,7 +402,7 @@ kernel void kernel_soft_max(
|
|
397
402
|
float lmax = -INFINITY;
|
398
403
|
|
399
404
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
400
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ?
|
405
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
|
401
406
|
}
|
402
407
|
|
403
408
|
// find the max value in the block
|
@@ -422,7 +427,7 @@ kernel void kernel_soft_max(
|
|
422
427
|
// parallel sum
|
423
428
|
float lsum = 0.0f;
|
424
429
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
425
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ?
|
430
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
426
431
|
lsum += exp_psrc0;
|
427
432
|
pdst[i00] = exp_psrc0;
|
428
433
|
}
|
@@ -461,7 +466,6 @@ template<typename T>
|
|
461
466
|
kernel void kernel_soft_max_4(
|
462
467
|
device const char * src0,
|
463
468
|
device const char * src1,
|
464
|
-
device const char * src2,
|
465
469
|
device char * dst,
|
466
470
|
constant int64_t & ne00,
|
467
471
|
constant int64_t & ne01,
|
@@ -483,10 +487,9 @@ kernel void kernel_soft_max_4(
|
|
483
487
|
|
484
488
|
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
485
489
|
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
486
|
-
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
487
490
|
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
488
491
|
|
489
|
-
float slope =
|
492
|
+
float slope = 1.0f;
|
490
493
|
|
491
494
|
if (max_bias > 0.0f) {
|
492
495
|
const int64_t h = i02;
|
@@ -501,7 +504,7 @@ kernel void kernel_soft_max_4(
|
|
501
504
|
float4 lmax4 = -INFINITY;
|
502
505
|
|
503
506
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
504
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ?
|
507
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
505
508
|
}
|
506
509
|
|
507
510
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
@@ -527,7 +530,7 @@ kernel void kernel_soft_max_4(
|
|
527
530
|
// parallel sum
|
528
531
|
float4 lsum4 = 0.0f;
|
529
532
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
530
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ?
|
533
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
531
534
|
lsum4 += exp_psrc4;
|
532
535
|
pdst4[i00] = exp_psrc4;
|
533
536
|
}
|
@@ -1595,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|
1595
1598
|
}
|
1596
1599
|
}
|
1597
1600
|
|
1598
|
-
kernel void kernel_alibi_f32(
|
1599
|
-
device const float * src0,
|
1600
|
-
device float * dst,
|
1601
|
-
constant int64_t & ne00,
|
1602
|
-
constant int64_t & ne01,
|
1603
|
-
constant int64_t & ne02,
|
1604
|
-
constant int64_t & ne03,
|
1605
|
-
constant uint64_t & nb00,
|
1606
|
-
constant uint64_t & nb01,
|
1607
|
-
constant uint64_t & nb02,
|
1608
|
-
constant uint64_t & nb03,
|
1609
|
-
constant int64_t & ne0,
|
1610
|
-
constant int64_t & ne1,
|
1611
|
-
constant int64_t & ne2,
|
1612
|
-
constant int64_t & ne3,
|
1613
|
-
constant uint64_t & nb0,
|
1614
|
-
constant uint64_t & nb1,
|
1615
|
-
constant uint64_t & nb2,
|
1616
|
-
constant uint64_t & nb3,
|
1617
|
-
constant float & m0,
|
1618
|
-
constant float & m1,
|
1619
|
-
constant int & n_heads_log2_floor,
|
1620
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
1621
|
-
uint3 tpitg[[thread_position_in_threadgroup]],
|
1622
|
-
uint3 ntg[[threads_per_threadgroup]]) {
|
1623
|
-
const int64_t i03 = tgpig[2];
|
1624
|
-
const int64_t i02 = tgpig[1];
|
1625
|
-
const int64_t i01 = tgpig[0];
|
1626
|
-
|
1627
|
-
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
1628
|
-
|
1629
|
-
const int64_t i3 = n / (ne2*ne1*ne0);
|
1630
|
-
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
1631
|
-
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
1632
|
-
//const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
1633
|
-
|
1634
|
-
const int64_t k = i3*ne3 + i2;
|
1635
|
-
|
1636
|
-
float m_k;
|
1637
|
-
if (k < n_heads_log2_floor) {
|
1638
|
-
m_k = pow(m0, k + 1);
|
1639
|
-
} else {
|
1640
|
-
m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
|
1641
|
-
}
|
1642
|
-
|
1643
|
-
device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
|
1644
|
-
device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
1645
|
-
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
1646
|
-
const float src_v = *(device float *)(src_row + i00*nb00);
|
1647
|
-
device float * dst_v = (device float *)(dst_row + i00*nb0);
|
1648
|
-
*dst_v = i00 * m_k + src_v;
|
1649
|
-
}
|
1650
|
-
}
|
1651
|
-
|
1652
1601
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
1653
1602
|
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
1654
1603
|
return 1.0f - min(1.0f, max(0.0f, y));
|
@@ -1691,6 +1640,7 @@ static void rope_yarn_corr_dims(
|
|
1691
1640
|
typedef void (rope_t)(
|
1692
1641
|
device const void * src0,
|
1693
1642
|
device const int32_t * src1,
|
1643
|
+
device const float * src2,
|
1694
1644
|
device float * dst,
|
1695
1645
|
constant int64_t & ne00,
|
1696
1646
|
constant int64_t & ne01,
|
@@ -1726,6 +1676,7 @@ template<typename T>
|
|
1726
1676
|
kernel void kernel_rope(
|
1727
1677
|
device const void * src0,
|
1728
1678
|
device const int32_t * src1,
|
1679
|
+
device const float * src2,
|
1729
1680
|
device float * dst,
|
1730
1681
|
constant int64_t & ne00,
|
1731
1682
|
constant int64_t & ne01,
|
@@ -1795,8 +1746,10 @@ kernel void kernel_rope(
|
|
1795
1746
|
|
1796
1747
|
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
1797
1748
|
const float cur_rot = inv_ndims*ic - ib;
|
1749
|
+
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
|
1750
|
+
|
1751
|
+
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
|
1798
1752
|
|
1799
|
-
const float theta = theta_0 * pow(freq_base, cur_rot);
|
1800
1753
|
float cos_theta, sin_theta;
|
1801
1754
|
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1802
1755
|
|
@@ -1903,7 +1856,10 @@ kernel void kernel_upscale_f32(
|
|
1903
1856
|
constant uint64_t & nb1,
|
1904
1857
|
constant uint64_t & nb2,
|
1905
1858
|
constant uint64_t & nb3,
|
1906
|
-
constant
|
1859
|
+
constant float & sf0,
|
1860
|
+
constant float & sf1,
|
1861
|
+
constant float & sf2,
|
1862
|
+
constant float & sf3,
|
1907
1863
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1908
1864
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
1909
1865
|
uint3 ntg[[threads_per_threadgroup]]) {
|
@@ -1912,15 +1868,17 @@ kernel void kernel_upscale_f32(
|
|
1912
1868
|
const int64_t i2 = tgpig.y;
|
1913
1869
|
const int64_t i1 = tgpig.x;
|
1914
1870
|
|
1915
|
-
const int64_t i03 = i3;
|
1916
|
-
const int64_t i02 = i2;
|
1917
|
-
const int64_t i01 = i1/
|
1918
|
-
|
1919
|
-
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
|
1920
|
-
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
|
1871
|
+
const int64_t i03 = i3/sf3;
|
1872
|
+
const int64_t i02 = i2/sf2;
|
1873
|
+
const int64_t i01 = i1/sf1;
|
1921
1874
|
|
1922
1875
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1923
|
-
|
1876
|
+
const int64_t i00 = i0/sf0;
|
1877
|
+
|
1878
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
1879
|
+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
1880
|
+
|
1881
|
+
dst_ptr[0] = src0_ptr[0];
|
1924
1882
|
}
|
1925
1883
|
}
|
1926
1884
|
|
@@ -2100,29 +2058,29 @@ typedef void (flash_attn_ext_f16_t)(
|
|
2100
2058
|
device const char * v,
|
2101
2059
|
device const char * mask,
|
2102
2060
|
device float * dst,
|
2103
|
-
constant int64_t & ne00,
|
2104
2061
|
constant int64_t & ne01,
|
2105
2062
|
constant int64_t & ne02,
|
2106
2063
|
constant int64_t & ne03,
|
2107
|
-
constant uint64_t & nb00,
|
2108
2064
|
constant uint64_t & nb01,
|
2109
2065
|
constant uint64_t & nb02,
|
2110
2066
|
constant uint64_t & nb03,
|
2111
|
-
constant int64_t & ne10,
|
2112
2067
|
constant int64_t & ne11,
|
2113
2068
|
constant int64_t & ne12,
|
2114
2069
|
constant int64_t & ne13,
|
2115
|
-
constant uint64_t & nb10,
|
2116
2070
|
constant uint64_t & nb11,
|
2117
2071
|
constant uint64_t & nb12,
|
2118
2072
|
constant uint64_t & nb13,
|
2119
|
-
constant
|
2073
|
+
constant uint64_t & nb21,
|
2074
|
+
constant uint64_t & nb22,
|
2075
|
+
constant uint64_t & nb23,
|
2120
2076
|
constant uint64_t & nb31,
|
2121
|
-
constant int64_t & ne0,
|
2122
2077
|
constant int64_t & ne1,
|
2123
2078
|
constant int64_t & ne2,
|
2124
|
-
constant int64_t & ne3,
|
2125
2079
|
constant float & scale,
|
2080
|
+
constant float & max_bias,
|
2081
|
+
constant float & m0,
|
2082
|
+
constant float & m1,
|
2083
|
+
constant uint32_t & n_head_log2,
|
2126
2084
|
threadgroup half * shared,
|
2127
2085
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2128
2086
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2138,29 +2096,29 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2138
2096
|
device const char * v,
|
2139
2097
|
device const char * mask,
|
2140
2098
|
device float * dst,
|
2141
|
-
constant int64_t & ne00,
|
2142
2099
|
constant int64_t & ne01,
|
2143
2100
|
constant int64_t & ne02,
|
2144
2101
|
constant int64_t & ne03,
|
2145
|
-
constant uint64_t & nb00,
|
2146
2102
|
constant uint64_t & nb01,
|
2147
2103
|
constant uint64_t & nb02,
|
2148
2104
|
constant uint64_t & nb03,
|
2149
|
-
constant int64_t & ne10,
|
2150
2105
|
constant int64_t & ne11,
|
2151
2106
|
constant int64_t & ne12,
|
2152
2107
|
constant int64_t & ne13,
|
2153
|
-
constant uint64_t & nb10,
|
2154
2108
|
constant uint64_t & nb11,
|
2155
2109
|
constant uint64_t & nb12,
|
2156
2110
|
constant uint64_t & nb13,
|
2157
|
-
constant
|
2111
|
+
constant uint64_t & nb21,
|
2112
|
+
constant uint64_t & nb22,
|
2113
|
+
constant uint64_t & nb23,
|
2158
2114
|
constant uint64_t & nb31,
|
2159
|
-
constant int64_t & ne0,
|
2160
2115
|
constant int64_t & ne1,
|
2161
2116
|
constant int64_t & ne2,
|
2162
|
-
constant int64_t & ne3,
|
2163
2117
|
constant float & scale,
|
2118
|
+
constant float & max_bias,
|
2119
|
+
constant float & m0,
|
2120
|
+
constant float & m1,
|
2121
|
+
constant uint32_t & n_head_log2,
|
2164
2122
|
threadgroup half * shared [[threadgroup(0)]],
|
2165
2123
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2166
2124
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2225,10 +2183,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2225
2183
|
const short ne22 = ne12;
|
2226
2184
|
const short ne23 = ne13;
|
2227
2185
|
|
2228
|
-
const uint nb21 = nb11;
|
2229
|
-
const uint nb22 = nb12;
|
2230
|
-
const uint nb23 = nb13;
|
2231
|
-
|
2232
2186
|
// broadcast
|
2233
2187
|
const short rk2 = ne02/ne12;
|
2234
2188
|
const short rk3 = ne03/ne13;
|
@@ -2254,8 +2208,17 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2254
2208
|
// pointer to the mask
|
2255
2209
|
device const half * mp = (device const half *) (mask + iq1*nb31);
|
2256
2210
|
|
2257
|
-
|
2258
|
-
|
2211
|
+
float slope = 1.0f;
|
2212
|
+
|
2213
|
+
// ALiBi
|
2214
|
+
if (max_bias > 0.0f) {
|
2215
|
+
const uint32_t h = iq2;
|
2216
|
+
|
2217
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
2218
|
+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2219
|
+
|
2220
|
+
slope = pow(base, exph);
|
2221
|
+
}
|
2259
2222
|
|
2260
2223
|
// loop over the KV cache
|
2261
2224
|
// each simdgroup handles blocks of Q rows and C columns
|
@@ -2279,12 +2242,20 @@ kernel void kernel_flash_attn_ext_f16(
|
|
2279
2242
|
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
2280
2243
|
}
|
2281
2244
|
|
2282
|
-
// mqk = mqk*scale + mask
|
2283
|
-
simdgroup_half8x8 mm;
|
2284
|
-
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
|
2285
|
-
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
|
2286
|
-
|
2287
2245
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
2246
|
+
|
2247
|
+
const short tx = tiisg%4;
|
2248
|
+
const short ty = tiisg/4;
|
2249
|
+
|
2250
|
+
if (mask != q) {
|
2251
|
+
// mqk = mqk*scale + mask*slope
|
2252
|
+
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
2253
|
+
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
2254
|
+
} else {
|
2255
|
+
// mqk = mqk*scale
|
2256
|
+
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
2257
|
+
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
2258
|
+
}
|
2288
2259
|
}
|
2289
2260
|
}
|
2290
2261
|
|
@@ -2456,29 +2427,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2456
2427
|
device const char * v,
|
2457
2428
|
device const char * mask,
|
2458
2429
|
device float * dst,
|
2459
|
-
constant int64_t & ne00,
|
2460
2430
|
constant int64_t & ne01,
|
2461
2431
|
constant int64_t & ne02,
|
2462
2432
|
constant int64_t & ne03,
|
2463
|
-
constant uint64_t & nb00,
|
2464
2433
|
constant uint64_t & nb01,
|
2465
2434
|
constant uint64_t & nb02,
|
2466
2435
|
constant uint64_t & nb03,
|
2467
|
-
constant int64_t & ne10,
|
2468
2436
|
constant int64_t & ne11,
|
2469
2437
|
constant int64_t & ne12,
|
2470
2438
|
constant int64_t & ne13,
|
2471
|
-
constant uint64_t & nb10,
|
2472
2439
|
constant uint64_t & nb11,
|
2473
2440
|
constant uint64_t & nb12,
|
2474
2441
|
constant uint64_t & nb13,
|
2475
|
-
constant
|
2442
|
+
constant uint64_t & nb21,
|
2443
|
+
constant uint64_t & nb22,
|
2444
|
+
constant uint64_t & nb23,
|
2476
2445
|
constant uint64_t & nb31,
|
2477
|
-
constant int64_t & ne0,
|
2478
2446
|
constant int64_t & ne1,
|
2479
2447
|
constant int64_t & ne2,
|
2480
|
-
constant int64_t & ne3,
|
2481
2448
|
constant float & scale,
|
2449
|
+
constant float & max_bias,
|
2450
|
+
constant float & m0,
|
2451
|
+
constant float & m1,
|
2452
|
+
constant uint32_t & n_head_log2,
|
2482
2453
|
threadgroup half * shared [[threadgroup(0)]],
|
2483
2454
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
2484
2455
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
@@ -2497,6 +2468,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2497
2468
|
|
2498
2469
|
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
2499
2470
|
|
2471
|
+
float slope = 1.0f;
|
2472
|
+
|
2473
|
+
// ALiBi
|
2474
|
+
if (max_bias > 0.0f) {
|
2475
|
+
const uint32_t h = iq2;
|
2476
|
+
|
2477
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
2478
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
2479
|
+
|
2480
|
+
slope = pow(base, exp);
|
2481
|
+
}
|
2482
|
+
|
2500
2483
|
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
|
2501
2484
|
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
|
2502
2485
|
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
@@ -2537,10 +2520,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2537
2520
|
const short ne22 = ne12;
|
2538
2521
|
const short ne23 = ne13;
|
2539
2522
|
|
2540
|
-
const uint nb21 = nb11;
|
2541
|
-
const uint nb22 = nb12;
|
2542
|
-
const uint nb23 = nb13;
|
2543
|
-
|
2544
2523
|
// broadcast
|
2545
2524
|
const short rk2 = ne02/ne12;
|
2546
2525
|
const short rk3 = ne03/ne13;
|
@@ -2603,10 +2582,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
2603
2582
|
mqk += simd_shuffle_down(mqk, 2);
|
2604
2583
|
mqk += simd_shuffle_down(mqk, 1);
|
2605
2584
|
|
2606
|
-
// mqk = mqk*scale + mask
|
2585
|
+
// mqk = mqk*scale + mask*slope
|
2607
2586
|
if (tiisg == 0) {
|
2608
|
-
|
2609
|
-
mqk = mqk*scale + mm;
|
2587
|
+
mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
|
2610
2588
|
|
2611
2589
|
ss4[cc] = mqk;
|
2612
2590
|
}
|
@@ -3408,7 +3386,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3408
3386
|
|
3409
3387
|
const int step = sizeof(block_q2_K) * nb;
|
3410
3388
|
|
3411
|
-
#if QK_K == 256
|
3412
3389
|
const int ix = tiisg/8; // 0...3
|
3413
3390
|
const int it = tiisg%8; // 0...7
|
3414
3391
|
const int iq = it/4; // 0 or 1
|
@@ -3460,57 +3437,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
3460
3437
|
|
3461
3438
|
y4 += 4 * QK_K;
|
3462
3439
|
}
|
3463
|
-
#else
|
3464
|
-
const int ix = tiisg/2; // 0...15
|
3465
|
-
const int it = tiisg%2; // 0...1
|
3466
|
-
|
3467
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3468
|
-
|
3469
|
-
for (int ib = ix; ib < nb; ib += 16) {
|
3470
|
-
|
3471
|
-
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
3472
|
-
for (int i = 0; i < 8; ++i) {
|
3473
|
-
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
3474
|
-
yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
|
3475
|
-
yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
|
3476
|
-
yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
|
3477
|
-
}
|
3478
|
-
|
3479
|
-
device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
|
3480
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3481
|
-
device const half * dh = &x[ib].d;
|
3482
|
-
|
3483
|
-
for (int row = 0; row < N_DST; row++) {
|
3484
|
-
|
3485
|
-
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
3486
|
-
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
3487
|
-
for (int i = 0; i < 8; i += 2) {
|
3488
|
-
acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
|
3489
|
-
acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
|
3490
|
-
acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
|
3491
|
-
acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
|
3492
|
-
acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
|
3493
|
-
acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
|
3494
|
-
acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
|
3495
|
-
acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
|
3496
|
-
}
|
3497
|
-
|
3498
|
-
float dall = dh[0];
|
3499
|
-
float dmin = dh[1];
|
3500
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
|
3501
|
-
(acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
|
3502
|
-
(acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
|
3503
|
-
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
|
3504
|
-
dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
|
3505
|
-
|
3506
|
-
qs += step/2;
|
3507
|
-
sc += step;
|
3508
|
-
dh += step/2;
|
3509
|
-
}
|
3510
|
-
|
3511
|
-
y4 += 16 * QK_K;
|
3512
|
-
}
|
3513
|
-
#endif
|
3514
3440
|
|
3515
3441
|
for (int row = 0; row < N_DST; ++row) {
|
3516
3442
|
all_sum = simd_sum(sumf[row]);
|
@@ -3548,7 +3474,6 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
3548
3474
|
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3549
3475
|
}
|
3550
3476
|
|
3551
|
-
#if QK_K == 256
|
3552
3477
|
void kernel_mul_mv_q3_K_f32_impl(
|
3553
3478
|
device const void * src0,
|
3554
3479
|
device const float * src1,
|
@@ -3707,84 +3632,6 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
3707
3632
|
}
|
3708
3633
|
}
|
3709
3634
|
}
|
3710
|
-
#else
|
3711
|
-
void kernel_mul_mv_q3_K_f32_impl(
|
3712
|
-
device const void * src0,
|
3713
|
-
device const float * src1,
|
3714
|
-
device float * dst,
|
3715
|
-
constant int64_t & ne00,
|
3716
|
-
constant int64_t & ne01,
|
3717
|
-
constant int64_t & ne02,
|
3718
|
-
constant int64_t & ne10,
|
3719
|
-
constant int64_t & ne12,
|
3720
|
-
constant int64_t & ne0,
|
3721
|
-
constant int64_t & ne1,
|
3722
|
-
constant uint & r2,
|
3723
|
-
constant uint & r3,
|
3724
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3725
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3726
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3727
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3728
|
-
|
3729
|
-
const int nb = ne00/QK_K;
|
3730
|
-
|
3731
|
-
const int64_t r0 = tgpig.x;
|
3732
|
-
const int64_t r1 = tgpig.y;
|
3733
|
-
const int64_t im = tgpig.z;
|
3734
|
-
|
3735
|
-
const int row = 2 * r0 + sgitg;
|
3736
|
-
|
3737
|
-
const uint i12 = im%ne12;
|
3738
|
-
const uint i13 = im/ne12;
|
3739
|
-
|
3740
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3741
|
-
|
3742
|
-
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
|
3743
|
-
device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3744
|
-
|
3745
|
-
const int ix = tiisg/4;
|
3746
|
-
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
3747
|
-
const int iq = il/8; // 0, 0, 1, 1
|
3748
|
-
const int in = il%8; // 0, 4, 0, 4
|
3749
|
-
|
3750
|
-
float2 sum = {0.f, 0.f};
|
3751
|
-
|
3752
|
-
for (int i = ix; i < nb; i += 8) {
|
3753
|
-
|
3754
|
-
const float d_all = (float)(x[i].d);
|
3755
|
-
|
3756
|
-
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
3757
|
-
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
3758
|
-
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
3759
|
-
device const float * y = yy + i * QK_K + il;
|
3760
|
-
|
3761
|
-
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
3762
|
-
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
3763
|
-
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
3764
|
-
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
3765
|
-
|
3766
|
-
for (int l = 0; l < 4; l += 2) {
|
3767
|
-
const uint16_t hm = h[l/2] >> iq;
|
3768
|
-
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
3769
|
-
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
3770
|
-
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
3771
|
-
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
3772
|
-
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
3773
|
-
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
3774
|
-
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
3775
|
-
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
3776
|
-
}
|
3777
|
-
|
3778
|
-
}
|
3779
|
-
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
3780
|
-
|
3781
|
-
const float tot = simd_sum(sumf);
|
3782
|
-
if (tiisg == 0) {
|
3783
|
-
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
3784
|
-
}
|
3785
|
-
|
3786
|
-
}
|
3787
|
-
#endif
|
3788
3635
|
|
3789
3636
|
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
3790
3637
|
kernel void kernel_mul_mv_q3_K_f32(
|
@@ -3814,7 +3661,6 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
3814
3661
|
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
|
3815
3662
|
}
|
3816
3663
|
|
3817
|
-
#if QK_K == 256
|
3818
3664
|
void kernel_mul_mv_q4_K_f32_impl(
|
3819
3665
|
device const void * src0,
|
3820
3666
|
device const float * src1,
|
@@ -3928,103 +3774,6 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
3928
3774
|
}
|
3929
3775
|
}
|
3930
3776
|
}
|
3931
|
-
#else
|
3932
|
-
void kernel_mul_mv_q4_K_f32_impl(
|
3933
|
-
device const void * src0,
|
3934
|
-
device const float * src1,
|
3935
|
-
device float * dst,
|
3936
|
-
constant int64_t & ne00,
|
3937
|
-
constant int64_t & ne01,
|
3938
|
-
constant int64_t & ne02,
|
3939
|
-
constant int64_t & ne10,
|
3940
|
-
constant int64_t & ne12,
|
3941
|
-
constant int64_t & ne0,
|
3942
|
-
constant int64_t & ne1,
|
3943
|
-
constant uint & r2,
|
3944
|
-
constant uint & r3,
|
3945
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
3946
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
3947
|
-
uint tiisg[[thread_index_in_simdgroup]],
|
3948
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
3949
|
-
|
3950
|
-
const int ix = tiisg/4; // 0...7
|
3951
|
-
const int it = tiisg%4; // 0...3
|
3952
|
-
|
3953
|
-
const int nb = ne00/QK_K;
|
3954
|
-
const int r0 = tgpig.x;
|
3955
|
-
const int r1 = tgpig.y;
|
3956
|
-
const int im = tgpig.z;
|
3957
|
-
const int first_row = r0 * N_DST;
|
3958
|
-
const int ib_row = first_row * nb;
|
3959
|
-
|
3960
|
-
const uint i12 = im%ne12;
|
3961
|
-
const uint i13 = im/ne12;
|
3962
|
-
|
3963
|
-
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
3964
|
-
|
3965
|
-
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
3966
|
-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
3967
|
-
|
3968
|
-
float yl[8];
|
3969
|
-
float yh[8];
|
3970
|
-
float sumf[N_DST]={0.f}, all_sum;
|
3971
|
-
|
3972
|
-
const int step = sizeof(block_q4_K) * nb / 2;
|
3973
|
-
|
3974
|
-
device const float * y4 = y + ix * QK_K + 8 * it;
|
3975
|
-
|
3976
|
-
uint16_t sc16[4];
|
3977
|
-
|
3978
|
-
for (int ib = ix; ib < nb; ib += 8) {
|
3979
|
-
|
3980
|
-
float2 sumy = {0.f, 0.f};
|
3981
|
-
for (int i = 0; i < 8; ++i) {
|
3982
|
-
yl[i] = y4[i+ 0]; sumy[0] += yl[i];
|
3983
|
-
yh[i] = y4[i+32]; sumy[1] += yh[i];
|
3984
|
-
}
|
3985
|
-
|
3986
|
-
device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
|
3987
|
-
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
|
3988
|
-
device const half * dh = x[ib].d;
|
3989
|
-
|
3990
|
-
for (int row = 0; row < N_DST; row++) {
|
3991
|
-
|
3992
|
-
sc16[0] = sc[0] & 0x000f;
|
3993
|
-
sc16[1] = sc[0] & 0x0f00;
|
3994
|
-
sc16[2] = sc[0] & 0x00f0;
|
3995
|
-
sc16[3] = sc[0] & 0xf000;
|
3996
|
-
|
3997
|
-
float2 acc1 = {0.f, 0.f};
|
3998
|
-
float2 acc2 = {0.f, 0.f};
|
3999
|
-
for (int i = 0; i < 8; i += 2) {
|
4000
|
-
acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
|
4001
|
-
acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
|
4002
|
-
acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
|
4003
|
-
acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
|
4004
|
-
}
|
4005
|
-
|
4006
|
-
float dall = dh[0];
|
4007
|
-
float dmin = dh[1];
|
4008
|
-
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
|
4009
|
-
(acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
|
4010
|
-
dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
|
4011
|
-
|
4012
|
-
qs += step;
|
4013
|
-
sc += step;
|
4014
|
-
dh += step;
|
4015
|
-
}
|
4016
|
-
|
4017
|
-
y4 += 8 * QK_K;
|
4018
|
-
}
|
4019
|
-
|
4020
|
-
for (int row = 0; row < N_DST; ++row) {
|
4021
|
-
all_sum = simd_sum(sumf[row]);
|
4022
|
-
if (tiisg == 0) {
|
4023
|
-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4024
|
-
}
|
4025
|
-
}
|
4026
|
-
}
|
4027
|
-
#endif
|
4028
3777
|
|
4029
3778
|
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
4030
3779
|
kernel void kernel_mul_mv_q4_K_f32(
|
@@ -4092,8 +3841,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4092
3841
|
|
4093
3842
|
const int step = sizeof(block_q5_K) * nb;
|
4094
3843
|
|
4095
|
-
#if QK_K == 256
|
4096
|
-
#
|
4097
3844
|
float yl[16], yh[16];
|
4098
3845
|
|
4099
3846
|
const uint16_t kmask1 = 0x3f3f;
|
@@ -4176,54 +3923,6 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
4176
3923
|
y1 += 4 * QK_K;
|
4177
3924
|
|
4178
3925
|
}
|
4179
|
-
#else
|
4180
|
-
float yl[8], yh[8];
|
4181
|
-
|
4182
|
-
const int il = 4 * (tiisg/8); // 0, 4, 8, 12
|
4183
|
-
const int ix = tiisg%8;
|
4184
|
-
const int iq = il/8; // 0, 0, 1, 1
|
4185
|
-
const int in = il%8; // 0, 4, 0, 4
|
4186
|
-
|
4187
|
-
device const float * y = yy + ix*QK_K + il;
|
4188
|
-
|
4189
|
-
for (int i = ix; i < nb; i += 8) {
|
4190
|
-
|
4191
|
-
for (int l = 0; l < 4; ++l) {
|
4192
|
-
yl[l+0] = y[l+ 0];
|
4193
|
-
yl[l+4] = y[l+16];
|
4194
|
-
yh[l+0] = y[l+32];
|
4195
|
-
yh[l+4] = y[l+48];
|
4196
|
-
}
|
4197
|
-
|
4198
|
-
device const half * dh = &x[i].d;
|
4199
|
-
device const uint8_t * q = x[i].qs + il;
|
4200
|
-
device const uint8_t * h = x[i].qh + in;
|
4201
|
-
device const int8_t * s = x[i].scales;
|
4202
|
-
|
4203
|
-
for (int row = 0; row < 2; ++row) {
|
4204
|
-
|
4205
|
-
const float d = dh[0];
|
4206
|
-
|
4207
|
-
float2 acc = {0.f, 0.f};
|
4208
|
-
for (int l = 0; l < 4; ++l) {
|
4209
|
-
const uint8_t hl = h[l] >> iq;
|
4210
|
-
acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
|
4211
|
-
+ yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
|
4212
|
-
acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
|
4213
|
-
+ yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
|
4214
|
-
}
|
4215
|
-
sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
|
4216
|
-
|
4217
|
-
q += step;
|
4218
|
-
h += step;
|
4219
|
-
s += step;
|
4220
|
-
dh += step/2;
|
4221
|
-
|
4222
|
-
}
|
4223
|
-
|
4224
|
-
y += 8 * QK_K;
|
4225
|
-
}
|
4226
|
-
#endif
|
4227
3926
|
|
4228
3927
|
for (int row = 0; row < 2; ++row) {
|
4229
3928
|
const float tot = simd_sum(sumf[row]);
|
@@ -4302,7 +4001,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4302
4001
|
|
4303
4002
|
float sumf = 0;
|
4304
4003
|
|
4305
|
-
#if QK_K == 256
|
4306
4004
|
const int tid = tiisg/2;
|
4307
4005
|
const int ix = tiisg%2;
|
4308
4006
|
const int ip = tid/8; // 0 or 1
|
@@ -4338,30 +4036,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
4338
4036
|
|
4339
4037
|
}
|
4340
4038
|
|
4341
|
-
#else
|
4342
|
-
const int ix = tiisg/4;
|
4343
|
-
const int il = 4*(tiisg%4);
|
4344
|
-
|
4345
|
-
for (int i = ix; i < nb; i += 8) {
|
4346
|
-
device const float * y = yy + i * QK_K + il;
|
4347
|
-
device const uint8_t * ql = x[i].ql + il;
|
4348
|
-
device const uint8_t * qh = x[i].qh + il;
|
4349
|
-
device const int8_t * s = x[i].scales;
|
4350
|
-
|
4351
|
-
const float d = x[i].d;
|
4352
|
-
|
4353
|
-
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
4354
|
-
for (int l = 0; l < 4; ++l) {
|
4355
|
-
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
4356
|
-
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
4357
|
-
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
|
4358
|
-
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
4359
|
-
}
|
4360
|
-
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
|
4361
|
-
}
|
4362
|
-
|
4363
|
-
#endif
|
4364
|
-
|
4365
4039
|
const float tot = simd_sum(sumf);
|
4366
4040
|
if (tiisg == 0) {
|
4367
4041
|
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
|
@@ -5195,9 +4869,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5195
4869
|
|
5196
4870
|
device const float * y4 = y + 32 * ix;
|
5197
4871
|
|
5198
|
-
#if QK_K != 64
|
5199
4872
|
iq1m_scale_t scale;
|
5200
|
-
#endif
|
5201
4873
|
|
5202
4874
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
5203
4875
|
|
@@ -5218,10 +4890,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5218
4890
|
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
5219
4891
|
|
5220
4892
|
for (int row = 0; row < N_DST; row++) {
|
5221
|
-
|
5222
|
-
#if QK_K != 64
|
5223
4893
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5224
|
-
#endif
|
5225
4894
|
|
5226
4895
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
5227
4896
|
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
@@ -5237,14 +4906,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
5237
4906
|
}
|
5238
4907
|
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5239
4908
|
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
5240
|
-
|
5241
|
-
const float d = (float) *((device const half *)(sc - 1));
|
5242
|
-
sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
|
5243
|
-
(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
|
5244
|
-
#else
|
4909
|
+
|
5245
4910
|
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
5246
4911
|
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
5247
|
-
#endif
|
5248
4912
|
|
5249
4913
|
sc += nb*sizeof(block_iq1_m)/2;
|
5250
4914
|
qs += nb*sizeof(block_iq1_m);
|
@@ -5356,7 +5020,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5356
5020
|
}
|
5357
5021
|
}
|
5358
5022
|
|
5359
|
-
#if QK_K != 64
|
5360
5023
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
5361
5024
|
device const void * src0,
|
5362
5025
|
device const float * src1,
|
@@ -5451,7 +5114,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5451
5114
|
}
|
5452
5115
|
}
|
5453
5116
|
}
|
5454
|
-
#endif
|
5455
5117
|
|
5456
5118
|
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
5457
5119
|
kernel void kernel_mul_mv_iq1_s_f32(
|
@@ -5564,11 +5226,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
5564
5226
|
uint tiisg[[thread_index_in_simdgroup]],
|
5565
5227
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5566
5228
|
|
5567
|
-
#if QK_K == 64
|
5568
|
-
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5569
|
-
#else
|
5570
5229
|
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
5571
|
-
#endif
|
5572
5230
|
}
|
5573
5231
|
|
5574
5232
|
//============================= templates and their specializations =============================
|
@@ -5694,10 +5352,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
|
|
5694
5352
|
float dl, ml;
|
5695
5353
|
uint8_t sc = xb->scales[il];
|
5696
5354
|
|
5697
|
-
#if QK_K == 256
|
5698
5355
|
q = q + 32*(il/8) + 16*(il&1);
|
5699
5356
|
il = (il/2)%4;
|
5700
|
-
|
5357
|
+
|
5701
5358
|
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5702
5359
|
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5703
5360
|
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
@@ -5713,7 +5370,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5713
5370
|
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
5714
5371
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5715
5372
|
|
5716
|
-
#if QK_K == 256
|
5717
5373
|
q = q + 32 * (il/8) + 16 * (il&1);
|
5718
5374
|
h = h + 16 * (il&1);
|
5719
5375
|
uint8_t m = 1 << (il/2);
|
@@ -5734,17 +5390,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
5734
5390
|
for (int i = 0; i < 16; ++i) {
|
5735
5391
|
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
5736
5392
|
}
|
5737
|
-
#else
|
5738
|
-
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
5739
|
-
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
5740
|
-
float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
|
5741
|
-
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
5742
|
-
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5743
|
-
uint8_t m = 1<<(il*2);
|
5744
|
-
for (int i = 0; i < 16; ++i) {
|
5745
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
|
5746
|
-
}
|
5747
|
-
#endif
|
5748
5393
|
}
|
5749
5394
|
|
5750
5395
|
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
@@ -5756,7 +5401,6 @@ template <typename type4x4>
|
|
5756
5401
|
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
5757
5402
|
device const uchar * q = xb->qs;
|
5758
5403
|
|
5759
|
-
#if QK_K == 256
|
5760
5404
|
short is = (il/4) * 2;
|
5761
5405
|
q = q + (il/4) * 32 + 16 * (il&1);
|
5762
5406
|
il = il & 3;
|
@@ -5765,16 +5409,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
5765
5409
|
const float min = xb->dmin;
|
5766
5410
|
const float dl = d * sc[0];
|
5767
5411
|
const float ml = min * sc[1];
|
5768
|
-
|
5769
|
-
(void) get_scale_min_k4_just2;
|
5770
|
-
|
5771
|
-
q = q + 16 * (il&1);
|
5772
|
-
device const uint8_t * s = xb->scales;
|
5773
|
-
device const half2 * dh = (device const half2 *)xb->d;
|
5774
|
-
const float2 d = (float2)dh[0];
|
5775
|
-
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
5776
|
-
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
5777
|
-
#endif
|
5412
|
+
|
5778
5413
|
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5779
5414
|
for (int i = 0; i < 16; ++i) {
|
5780
5415
|
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
@@ -5786,7 +5421,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5786
5421
|
device const uint8_t * q = xb->qs;
|
5787
5422
|
device const uint8_t * qh = xb->qh;
|
5788
5423
|
|
5789
|
-
#if QK_K == 256
|
5790
5424
|
short is = (il/4) * 2;
|
5791
5425
|
q = q + 32 * (il/4) + 16 * (il&1);
|
5792
5426
|
qh = qh + 16 * (il&1);
|
@@ -5803,17 +5437,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
5803
5437
|
for (int i = 0; i < 16; ++i) {
|
5804
5438
|
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
5805
5439
|
}
|
5806
|
-
#else
|
5807
|
-
q = q + 16 * (il&1);
|
5808
|
-
device const int8_t * s = xb->scales;
|
5809
|
-
const float dl = xb->d * s[il];
|
5810
|
-
uint8_t m = 1<<(il*2);
|
5811
|
-
const float coef = il<2 ? 1.f : 1.f/16.f;
|
5812
|
-
const ushort mask = il<2 ? 0x0F : 0xF0;
|
5813
|
-
for (int i = 0; i < 16; ++i) {
|
5814
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
|
5815
|
-
}
|
5816
|
-
#endif
|
5817
5440
|
}
|
5818
5441
|
|
5819
5442
|
template <typename type4x4>
|
@@ -5823,15 +5446,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
5823
5446
|
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
5824
5447
|
device const int8_t * scales = (device const int8_t *)xb->scales;
|
5825
5448
|
|
5826
|
-
#if QK_K == 256
|
5827
5449
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
5828
5450
|
qh = qh + 32*(il/8) + 16*(il&1);
|
5829
5451
|
float sc = scales[(il%2) + 2 * ((il/2))];
|
5830
5452
|
il = (il/2) & 3;
|
5831
|
-
|
5832
|
-
ql = ql + 16 * (il&1);
|
5833
|
-
float sc = scales[il];
|
5834
|
-
#endif
|
5453
|
+
|
5835
5454
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
5836
5455
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
5837
5456
|
const float coef = il>1 ? 1.f/16.f : 1.f;
|
@@ -5988,20 +5607,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
|
|
5988
5607
|
const int ib32 = il/2;
|
5989
5608
|
il = il%2;
|
5990
5609
|
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
5991
|
-
|
5992
|
-
const float d = xb->d;
|
5993
|
-
#else
|
5610
|
+
|
5994
5611
|
iq1m_scale_t scale;
|
5995
5612
|
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
5996
5613
|
const float d = scale.f16;
|
5997
|
-
|
5614
|
+
|
5998
5615
|
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
5999
5616
|
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
6000
|
-
|
6001
|
-
const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
|
6002
|
-
#else
|
5617
|
+
|
6003
5618
|
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
6004
|
-
#endif
|
6005
5619
|
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
6006
5620
|
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
6007
5621
|
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
@@ -6031,9 +5645,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|
6031
5645
|
|
6032
5646
|
template <typename type4x4>
|
6033
5647
|
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
6034
|
-
#if QK_K == 64
|
6035
|
-
dequantize_iq4_nl(xb, il, reg);
|
6036
|
-
#else
|
6037
5648
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
6038
5649
|
const int ib32 = il/2;
|
6039
5650
|
il = il%2;
|
@@ -6050,7 +5661,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|
6050
5661
|
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
6051
5662
|
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
6052
5663
|
}
|
6053
|
-
#endif
|
6054
5664
|
}
|
6055
5665
|
|
6056
5666
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
@@ -6555,11 +6165,7 @@ kernel void kernel_mul_mm_id(
|
|
6555
6165
|
sgitg);
|
6556
6166
|
}
|
6557
6167
|
|
6558
|
-
#if QK_K == 256
|
6559
6168
|
#define QK_NL 16
|
6560
|
-
#else
|
6561
|
-
#define QK_NL 4
|
6562
|
-
#endif
|
6563
6169
|
|
6564
6170
|
//
|
6565
6171
|
// get rows
|
@@ -6599,11 +6205,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r
|
|
6599
6205
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6600
6206
|
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6601
6207
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6602
|
-
#if QK_K == 64
|
6603
|
-
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6604
|
-
#else
|
6605
6208
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6606
|
-
#endif
|
6607
6209
|
|
6608
6210
|
//
|
6609
6211
|
// matrix-matrix multiplication
|
@@ -6631,11 +6233,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m
|
|
6631
6233
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6632
6234
|
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6633
6235
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6634
|
-
#if QK_K == 64
|
6635
|
-
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
6636
|
-
#else
|
6637
6236
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6638
|
-
#endif
|
6639
6237
|
|
6640
6238
|
//
|
6641
6239
|
// indirect matrix-matrix multiplication
|
@@ -6663,11 +6261,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel
|
|
6663
6261
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6664
6262
|
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
6665
6263
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6666
|
-
#if QK_K == 64
|
6667
|
-
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6668
|
-
#else
|
6669
6264
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6670
|
-
#endif
|
6671
6265
|
|
6672
6266
|
//
|
6673
6267
|
// matrix-vector multiplication
|
@@ -6876,7 +6470,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t
|
|
6876
6470
|
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
|
6877
6471
|
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
|
6878
6472
|
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
|
6879
|
-
#if QK_K != 64
|
6880
6473
|
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
|
6881
|
-
#endif
|
6882
6474
|
|