llama_cpp 0.14.7 → 0.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +53 -9
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +18 -3
- data/vendor/tmp/llama.cpp/Makefile +41 -16
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -5
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +6 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +376 -176
- data/vendor/tmp/llama.cpp/ggml-metal.metal +654 -18
- data/vendor/tmp/llama.cpp/ggml-quants.c +284 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +17 -7
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml.c +391 -27
- data/vendor/tmp/llama.cpp/ggml.h +22 -0
- data/vendor/tmp/llama.cpp/llama.cpp +623 -395
- data/vendor/tmp/llama.cpp/llama.h +27 -9
- data/vendor/tmp/llama.cpp/sgemm.cpp +83 -87
- data/vendor/tmp/llama.cpp/sgemm.h +4 -2
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1 -1
- data/vendor/tmp/llama.cpp/unicode-data.h +2 -2
- data/vendor/tmp/llama.cpp/unicode.cpp +448 -39
- data/vendor/tmp/llama.cpp/unicode.h +2 -1
- metadata +3 -3
data/vendor/tmp/llama.cpp/ggml-metal.metal

@@ -352,11 +352,12 @@ kernel void kernel_sum_rows(
     dst_row[0] = row_sum;
 }
 
+template<typename T>
 kernel void kernel_soft_max(
-        device const
-        device const
-        device const
-        device
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device char * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
@@ -375,10 +376,10 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 =
-    device const
-    device const
-    device float * pdst =
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+    device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
+    device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     float slope = 0.0f;
 
@@ -456,11 +457,12 @@ kernel void kernel_soft_max(
     }
 }
 
+template<typename T>
 kernel void kernel_soft_max_4(
-        device const
-        device const
-        device const
-        device
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device char * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
@@ -479,10 +481,10 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 =
-    device const
-    device const
-    device float4 * pdst4 =
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+    device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
+    device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
 
     float slope = 0.0f;
 
@@ -499,7 +501,7 @@ kernel void kernel_soft_max_4(
     float4 lmax4 = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -525,7 +527,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
     }
 }
 
+typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device float * dst,
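The hunks above turn `kernel_soft_max` and `kernel_soft_max_4` into templates over the mask/position element type and export explicit F16 and F32 instantiations under separate `[[host_name]]` pipeline names, so the host can dispatch the variant that matches the mask tensor instead of converting the mask to F32 first. As a minimal sketch of that template-plus-`[[host_name]]` idiom (the kernel below is hypothetical and not part of this package):

```metal
#include <metal_stdlib>
using namespace metal;

// Hypothetical illustration of the pattern used by kernel_soft_max above:
// one templated kernel body, with explicit instantiations exported to the
// host under distinct pipeline names via [[host_name]].
template<typename T>
kernel void kernel_scale_example(
        device const char  * src,   // raw bytes, reinterpreted as T below
        device       float * dst,
        constant     float & scale,
        uint tpig[[thread_position_in_grid]]) {
    device const T * x = (device const T *) src;
    dst[tpig] = (float) x[tpig] * scale;
}

typedef decltype(kernel_scale_example<float>) kernel_scale_example_t;

template [[host_name("kernel_scale_example_f16")]] kernel kernel_scale_example_t kernel_scale_example<half>;
template [[host_name("kernel_scale_example_f32")]] kernel kernel_scale_example_t kernel_scale_example<float>;
```

The matching host-side selection (looking up `kernel_soft_max_f16` vs `kernel_soft_max_f32` by name based on the mask's data type) presumably lives in the `ggml-metal.m` portion of this diff, which is not shown here.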
@@ -2084,6 +2094,632 @@ kernel void kernel_leaky_relu_f32(
     dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
 }
 
+typedef void (flash_attn_ext_f16_t)(
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant int64_t & ne31,
+        constant uint64_t & nb31,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant float & scale,
+        threadgroup half * shared,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant int64_t & ne31,
+        constant uint64_t & nb31,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant float & scale,
+        threadgroup half * shared [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0]*Q;
+
+    const short D4 = D/4;
+    const short D8 = D/8;
+    const short Q8 = Q/8;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+    const short TF = T/2; // shared memory size per query in (float)
+    const short T4 = T/4; // shared memory size per query in (half4)
+
+    threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    simdgroup_half8x8 lo[D8];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < ne01) {
+                sq4[j*T4 + i] = (half4) q4[i];
+            } else {
+                sq4[j*T4 + i] = 0.0h;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
+            ss[j*TF + i] = 0.0f;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S[Q] = { [0 ... Q-1] = 0.0h };
+        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2/rk2;
+        const short ik3 = iq3/rk3;
+
+        // v indices
+        const short iv2 = iq2/rv2;
+        const short iv3 = iq3/rv3;
+
+        // load the queries from shared memory into local memory
+        simdgroup_half8x8 mq[D8];
+
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, T);
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        // prepare diagonal scale matrix
+        simdgroup_float8x8 mscale(scale);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+
+                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                    }
+
+                    // mqk = mqk*scale + mask
+                    simdgroup_half8x8 mm;
+                    simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+                    simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+
+                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+                }
+            }
+
+            // used to detect blocks full of -INF
+            float smax = -INFINITY;
+
+            // online softmax
+            {
+                float ms[Q];
+
+                for (short j = 0; j < Q; ++j) {
+                    const short p = tiisg;
+
+                    const float m = M[j];
+                    const float s = ss[j*TF + p];
+
+                    smax = simd_max(max(smax, s));
+                    M[j] = simd_max(max(M[j], s));
+
+                    ms[j] = exp(m - M[j]);
+                    const float vs = exp(s - M[j]);
+
+                    S[j] = S[j]*ms[j] + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*TF + p] = vs;
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            {
+                simdgroup_float8x8 mm;
+                simdgroup_load(mm, ss + C, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[i], mm, lo[i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    for (short i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+                        simdgroup_float8x8 mv;
+                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);
+
+                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*TF + 0] = S[j];
+                ss[j*TF + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
+    for (short sg = 1; sg < nsg; ++sg) {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], sq + i*8, T, 0, false);
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (short j = 0; j < Q; ++j) {
+                const float S0 = ss[j*TF + 0];
+                const float S1 = ss[j*TF + sg*SH + 0];
+
+                const float M0 = ss[j*TF + 1];
+                const float M1 = ss[j*TF + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*TF + 0] = S;
+                    ss[j*TF + 1] = M;
+
+                    ss[j*TF + C + j ] = ms0;
+                    ss[j*TF + C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            {
+                simdgroup_half8x8 t;
+                simdgroup_float8x8 ms0;
+                simdgroup_float8x8 ms1;
+
+                simdgroup_load(ms0, ss + C, TF, 0, false);
+                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_load (t, sq + i*8, T, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], sq + i*8, T, 0, false);
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const float S = ss[j*TF + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+
+template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec_f16(
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne10,
+        constant int64_t & ne11,
+        constant int64_t & ne12,
+        constant int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant int64_t & ne31,
+        constant uint64_t & nb31,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant float & scale,
+        threadgroup half * shared [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const short iq3 = tgpig[2];
+    const short iq2 = tgpig[1];
+    const short iq1 = tgpig[0];
+
+    const short D4 = D/4;
+    const short NW = N_SIMDWIDTH;
+    const short SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+
+    //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
+    threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
+    threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    half4 lo[D4/NW];
+
+    // load heads from Q to shared memory
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < ne01) {
+            sq4[i] = (half4) q4[i];
+        } else {
+            sq4[i] = 0.0h;
+        }
+    }
+
+    // zero out lo
+    for (short i = tiisg; i < D4; i += NW) {
+        lo[i/NW] = 0.0h;
+    }
+
+    // zero out shared memory SH
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = 0.0h;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        float S = { 0.0h };
+        float M = { -FLT_MAX/2 };
+
+        // assume K and V are same shape
+        const short ne22 = ne12;
+        const short ne23 = ne13;
+
+        const uint nb21 = nb11;
+        const uint nb22 = nb12;
+        const uint nb23 = nb13;
+
+        // broadcast
+        const short rk2 = ne02/ne12;
+        const short rk3 = ne03/ne13;
+
+        const short rv2 = ne02/ne22;
+        const short rv3 = ne03/ne23;
+
+        // k indices
+        const short ik2 = iq2 / rk2;
+        const short ik3 = iq3 / rk3;
+
+        // v indices
+        const short iv2 = iq2 / rv2;
+        const short iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half4 mq[D4];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            mq[i] = sq4[i];
+        }
+
+        // pointer to the mask
+        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+                #pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    float4 mqk = { 0.0h };
+
+                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    #pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        half4x4 mk;
+                        mk[0] = pk4[i + 0*(nb11/8)];
+                        mk[1] = pk4[i + 1*(nb11/8)];
+                        mk[2] = pk4[i + 2*(nb11/8)];
+                        mk[3] = pk4[i + 3*(nb11/8)];
+
+                        mqk += (float4) (mq[i] * mk);
+                    }
+
+                    // reduce the results from the threads in the simdgroup
+                    mqk += simd_shuffle_down(mqk, 16);
+                    mqk += simd_shuffle_down(mqk, 8);
+                    mqk += simd_shuffle_down(mqk, 4);
+                    mqk += simd_shuffle_down(mqk, 2);
+                    mqk += simd_shuffle_down(mqk, 1);
+
+                    // mqk = mqk*scale + mask
+                    if (tiisg == 0) {
+                        float4 mm = (float4) mp4[ic/4 + cc];
+                        mqk = mqk*scale + mm;
+
+                        ss4[cc] = mqk;
+                    }
+                }
+            }
+
+            // online softmax
+            {
+                const short p = tiisg;
+
+                const float m = M;
+                const float s = ss[p];
+
+                M = simd_max(max(M, s));
+
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[p] = vs;
+
+                // O = diag(ms)*O
+                #pragma unroll
+                for (short ii = 0; ii < D4; ii += NW) {
+                    const short i = ii + tiisg;
+                    lo[i/NW] *= ms;
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                #pragma unroll
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    #pragma unroll
+                    for (short ii = 0; ii < D4; ii += NW) {
+                        const short i = ii + tiisg;
+
+                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
+                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
+                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
+                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
+                    }
+                }
+            }
+
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = S;
+            ss[1] = M;
+        }
+    }
+
+    // store results to shared memory
+    for (short ii = 0; ii < D4; ii += NW) {
+        short i = ii + tiisg;
+        sr4[i] = lo[ii/NW];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const float S0 = ss[ 0];
+            const float S1 = ss[r*SH + 0];
+
+            const float M0 = ss[ 1];
+            const float M1 = ss[r*SH + 1];
+
+            const float M = max(M0, M1);
+
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
+
+            const float S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short ii = 0; ii < D4; ii += NW) {
+                short i = ii + tiisg;
+                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short ii = 0; ii < D4; ii += NW) {
+            short i = ii + tiisg;
+            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+
 kernel void kernel_cpy_f16_f16(
         device const half * src0,
         device half * dst,
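The new `kernel_flash_attn_ext_f16` and `kernel_flash_attn_ext_vec_f16` kernels follow the streaming (online-softmax) formulation of the FlashAttention-2 paper cited in the code (https://arxiv.org/pdf/2307.08691.pdf): each simdgroup walks the KV cache in blocks and keeps a running maximum `M`, a running normalizer `S`, and an un-normalized output `O`, so the full attention matrix is never materialized. Reading the `// online softmax` sections above, one block of scaled, masked scores s_j updates the state roughly as in this sketch:

```latex
M' = \max\!\big(M,\ \max_j s_j\big), \qquad m = e^{M - M'}
S' = m\,S + \textstyle\sum_j e^{s_j - M'}
O' = m\,O + \textstyle\sum_j e^{s_j - M'}\,V_j
```

After the last block, and after the per-simdgroup partials are merged with the same max/rescale rule in the reduction sections, the kernel stores O/S, which corresponds to the final `(float4) sq4[...]/S` and `(float4) sr4[i]/S` writes above.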