llama_cpp 0.14.7 → 0.15.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +53 -9
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +18 -3
- data/vendor/tmp/llama.cpp/Makefile +41 -16
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -5
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +6 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +376 -176
- data/vendor/tmp/llama.cpp/ggml-metal.metal +654 -18
- data/vendor/tmp/llama.cpp/ggml-quants.c +284 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +17 -7
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml.c +391 -27
- data/vendor/tmp/llama.cpp/ggml.h +22 -0
- data/vendor/tmp/llama.cpp/llama.cpp +623 -395
- data/vendor/tmp/llama.cpp/llama.h +27 -9
- data/vendor/tmp/llama.cpp/sgemm.cpp +83 -87
- data/vendor/tmp/llama.cpp/sgemm.h +4 -2
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1 -1
- data/vendor/tmp/llama.cpp/unicode-data.h +2 -2
- data/vendor/tmp/llama.cpp/unicode.cpp +448 -39
- data/vendor/tmp/llama.cpp/unicode.h +2 -1
- metadata +3 -3
@@ -352,11 +352,12 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
+template<typename T>
 kernel void kernel_soft_max(
-        device const float * src0,
-        device const float * src1,
-        device const float * src2,
-        device float * dst,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device char * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
@@ -375,10 +376,10 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
-    device const float * ppos  = src2 != src0 ? src2 : nullptr;
-    device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+    device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
+    device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     float slope = 0.0f;
 
@@ -456,11 +457,12 @@ kernel void kernel_soft_max(
     }
 }
 
+template<typename T>
 kernel void kernel_soft_max_4(
-        device const float * src0,
-        device const float * src1,
-        device const float * src2,
-        device float * dst,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device char * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
@@ -479,10 +481,10 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
-    device const float4 * ppos  = src2 != src0 ? (device const float4 *)(src2) : nullptr;
-    device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+    device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
+    device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
 
     float slope = 0.0f;
 
@@ -499,7 +501,7 @@ kernel void kernel_soft_max_4(
     float4 lmax4 = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -525,7 +527,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
     }
 }
 
+typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device float * dst,
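The largest addition in this release is the pair of Metal flash-attention kernels in the next hunk (kernel_flash_attn_ext_f16 and kernel_flash_attn_ext_vec_f16, which cite https://arxiv.org/pdf/2307.08691.pdf). Both keep a running row maximum M and running sum S, and rescale the already-accumulated output by exp(m_old - M_new) each time a new block of scores is processed, so softmax(Q*K^T)*V is computed in a single pass over the KV cache. Below is a minimal scalar sketch of that online-softmax recurrence in plain C++; the names and data are illustrative only and are not code from this gem.

```cpp
// Illustrative scalar sketch of the online-softmax recurrence used by the
// new flash-attention kernels: a running max M, running sum S, and running
// output O are updated per score, and the previous state is rescaled by
// exp(m_old - M_new) whenever the maximum grows.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // one query row against a stream of (attention score, value) pairs
    const std::vector<float> scores = {0.2f, 1.5f, -0.3f, 2.0f};
    const std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f};

    float M = -INFINITY; // running maximum of the scores seen so far
    float S = 0.0f;      // running sum of exp(score - M)
    float O = 0.0f;      // running, still unnormalized, output

    for (size_t i = 0; i < scores.size(); ++i) {
        const float m  = M;
        M              = std::fmax(M, scores[i]);
        const float ms = std::exp(m - M);         // rescale factor for the old state
        const float vs = std::exp(scores[i] - M); // weight of the new score

        S = S*ms + vs;           // mirrors S[j] = S[j]*ms[j] + simd_sum(vs)
        O = O*ms + vs*values[i]; // mirrors the diag(ms)*O rescale and accumulate
    }

    // final rescale with 1/S, as in the kernels' last step
    std::printf("softmax-weighted value: %f\n", O / S);
    return 0;
}
```

The kernels apply this same update per simdgroup and then merge the per-simdgroup (S, M, O) partials in their reduction steps using the identical rescaling rule.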
@@ -2084,6 +2094,632 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
 }
 
+typedef void (flash_attn_ext_f16_t)(
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant int64_t & ne31,
+        constant uint64_t & nb31,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant float & scale,
+        threadgroup half * shared,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant int64_t & ne31,
+        constant uint64_t & nb31,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant float & scale,
+        threadgroup half * shared [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0]*Q;
+
+    const short D4 = D/4;
+    const short D8 = D/8;
+    const short Q8 = Q/8;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+    const short TF = T/2; // shared memory size per query in (float)
+    const short T4 = T/4; // shared memory size per query in (half4)
+
+    threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    simdgroup_half8x8 lo[D8];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < ne01) {
+                sq4[j*T4 + i] = (half4) q4[i];
+            } else {
+                sq4[j*T4 + i] = 0.0h;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
+            ss[j*TF + i] = 0.0f;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S[Q] = { [0 ... Q-1] = 0.0h };
+        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // v indices
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
+
+        // load the queries from shared memory into local memory
+        simdgroup_half8x8 mq[D8];
+
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, T);
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        // prepare diagonal scale matrix
+        simdgroup_float8x8 mscale(scale);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+
+                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                    }
+
+                    // mqk = mqk*scale + mask
+                    simdgroup_half8x8 mm;
+                    simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+                    simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+
+                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+                }
+            }
+
+            // used to detect blocks full of -INF
+            float smax = -INFINITY;
+
+            // online softmax
+            {
+                float ms[Q];
+
+                for (short j = 0; j < Q; ++j) {
+                    const short p = tiisg;
+
+                    const float m = M[j];
+                    const float s = ss[j*TF + p];
+
+                    smax = simd_max(max(smax, s));
+                    M[j] = simd_max(max(M[j], s));
+
+                    ms[j] = exp(m - M[j]);
+                    const float vs = exp(s - M[j]);
+
+                    S[j] = S[j]*ms[j] + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*TF + p] = vs;
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            {
+                simdgroup_float8x8 mm;
+                simdgroup_load(mm, ss + C, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[i], mm, lo[i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+                        simdgroup_float8x8 mv;
+                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*TF + 0] = S[j];
+                ss[j*TF + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
+    for (short sg = 1; sg < nsg; ++sg) {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], sq + i*8, T, 0, false);
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (short j = 0; j < Q; ++j) {
+                const float S0 = ss[j*TF + 0];
+                const float S1 = ss[j*TF + sg*SH + 0];
+
+                const float M0 = ss[j*TF + 1];
+                const float M1 = ss[j*TF + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*TF + 0] = S;
+                    ss[j*TF + 1] = M;
+
+                    ss[j*TF + C + j ] = ms0;
+                    ss[j*TF + C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            {
+                simdgroup_half8x8 t;
+                simdgroup_float8x8 ms0;
+                simdgroup_float8x8 ms1;
+
+                simdgroup_load(ms0, ss + C, TF, 0, false);
+                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_load (t, sq + i*8, T, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], sq + i*8, T, 0, false);
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const float S = ss[j*TF + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant int64_t & ne31,
+        constant uint64_t & nb31,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant float & scale,
+        threadgroup half * shared [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0];
+
+    const short D4 = D/4;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+
+    //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
+    threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    half4 lo[D4/NW];
+
+    // load heads from Q to shared memory
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < ne01) {
+            sq4[i] = (half4) q4[i];
+        } else {
+            sq4[i] = 0.0h;
+        }
+    }
+
+    // zero out lo
+    for (short i = tiisg; i < D4; i += NW) {
+        lo[i/NW] = 0.0h;
+    }
+
+    // zero out shared memory SH
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = 0.0h;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2 / rk2;
+        const short ik3 = iq3 / rk3;
+
+        // v indices
+        const short iv2 = iq2 / rv2;
+        const short iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half4 mq[D4];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            mq[i] = sq4[i];
+        }
+
+        // pointer to the mask
+        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    float4 mqk = { 0.0h };
+
+                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        half4x4 mk;
+                        mk[0] = pk4[i + 0*(nb11/8)];
+                        mk[1] = pk4[i + 1*(nb11/8)];
+                        mk[2] = pk4[i + 2*(nb11/8)];
+                        mk[3] = pk4[i + 3*(nb11/8)];
+
+                        mqk += (float4) (mq[i] * mk);
+                    }
+
+                    // reduce the results from the threads in the simdgroup
+                    mqk += simd_shuffle_down(mqk, 16);
+                    mqk += simd_shuffle_down(mqk, 8);
+                    mqk += simd_shuffle_down(mqk, 4);
+                    mqk += simd_shuffle_down(mqk, 2);
+                    mqk += simd_shuffle_down(mqk, 1);
+
+                    // mqk = mqk*scale + mask
+                    if (tiisg == 0) {
+                        float4 mm = (float4) mp4[ic/4 + cc];
+                        mqk = mqk*scale + mm;
+
+                        ss4[cc] = mqk;
+                    }
+                }
+            }
+
+            // online softmax
+            {
+                const short p = tiisg;
+
+                const float m = M;
+                const float s = ss[p];
+
+                M = simd_max(max(M, s));
+
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[p] = vs;
+
+                // O = diag(ms)*O
+#pragma unroll
+                for (short ii = 0; ii < D4; ii += NW) {
+                    const short i = ii + tiisg;
+                    lo[i/NW] *= ms;
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+#pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+#pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                    }
+                }
+            }
+
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = S;
+            ss[1] = M;
+        }
+    }
+
+    // store results to shared memory
+    for (short ii = 0; ii < D4; ii += NW) {
+        short i = ii + tiisg;
+        sr4[i] = lo[ii/NW];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const float S0 = ss[ 0];
+            const float S1 = ss[r*SH + 0];
+
+            const float M0 = ss[ 1];
+            const float M1 = ss[r*SH + 1];
+
+            const float M = max(M0, M1);
+
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
+
+            const float S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short ii = 0; ii < D4; ii += NW) {
+                short i = ii + tiisg;
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device half * dst,