llama_cpp 0.14.6 → 0.15.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -352,11 +352,12 @@ kernel void kernel_sum_rows(
352
352
  dst_row[0] = row_sum;
353
353
  }
354
354
 
355
+ template<typename T>
355
356
  kernel void kernel_soft_max(
356
- device const float * src0,
357
- device const float * src1,
358
- device const float * src2,
359
- device float * dst,
357
+ device const char * src0,
358
+ device const char * src1,
359
+ device const char * src2,
360
+ device char * dst,
360
361
  constant int64_t & ne00,
361
362
  constant int64_t & ne01,
362
363
  constant int64_t & ne02,
@@ -375,10 +376,10 @@ kernel void kernel_soft_max(
375
376
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
376
377
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
377
378
 
378
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
379
- device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
380
- device const float * ppos = src2 != src0 ? src2 : nullptr;
381
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
379
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
380
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
381
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
382
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
382
383
 
383
384
  float slope = 0.0f;
384
385
 
@@ -456,11 +457,12 @@ kernel void kernel_soft_max(
456
457
  }
457
458
  }
458
459
 
460
+ template<typename T>
459
461
  kernel void kernel_soft_max_4(
460
- device const float * src0,
461
- device const float * src1,
462
- device const float * src2,
463
- device float * dst,
462
+ device const char * src0,
463
+ device const char * src1,
464
+ device const char * src2,
465
+ device char * dst,
464
466
  constant int64_t & ne00,
465
467
  constant int64_t & ne01,
466
468
  constant int64_t & ne02,
@@ -479,10 +481,10 @@ kernel void kernel_soft_max_4(
479
481
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
480
482
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
481
483
 
482
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
483
- device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
484
- device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
485
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
484
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
485
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
486
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
487
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
486
488
 
487
489
  float slope = 0.0f;
488
490
 
@@ -499,7 +501,7 @@ kernel void kernel_soft_max_4(
499
501
  float4 lmax4 = -INFINITY;
500
502
 
501
503
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
502
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
504
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
503
505
  }
504
506
 
505
507
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -525,7 +527,7 @@ kernel void kernel_soft_max_4(
525
527
  // parallel sum
526
528
  float4 lsum4 = 0.0f;
527
529
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
528
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
530
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
529
531
  lsum4 += exp_psrc4;
530
532
  pdst4[i00] = exp_psrc4;
531
533
  }
@@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
562
564
  }
563
565
  }
564
566
 
567
+ typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
568
+ typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
569
+
570
+ template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
571
+ template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
572
+ template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
573
+ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
574
+
565
575
  kernel void kernel_diag_mask_inf(
566
576
  device const float * src0,
567
577
  device float * dst,
@@ -2084,6 +2094,632 @@ kernel void kernel_leaky_relu_f32(
2084
2094
  dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
2085
2095
  }
2086
2096
 
2097
+ typedef void (flash_attn_ext_f16_t)(
2098
+ device const char * q,
2099
+ device const char * k,
2100
+ device const char * v,
2101
+ device const char * mask,
2102
+ device float * dst,
2103
+ constant int64_t & ne00,
2104
+ constant int64_t & ne01,
2105
+ constant int64_t & ne02,
2106
+ constant int64_t & ne03,
2107
+ constant uint64_t & nb00,
2108
+ constant uint64_t & nb01,
2109
+ constant uint64_t & nb02,
2110
+ constant uint64_t & nb03,
2111
+ constant int64_t & ne10,
2112
+ constant int64_t & ne11,
2113
+ constant int64_t & ne12,
2114
+ constant int64_t & ne13,
2115
+ constant uint64_t & nb10,
2116
+ constant uint64_t & nb11,
2117
+ constant uint64_t & nb12,
2118
+ constant uint64_t & nb13,
2119
+ constant int64_t & ne31,
2120
+ constant uint64_t & nb31,
2121
+ constant int64_t & ne0,
2122
+ constant int64_t & ne1,
2123
+ constant int64_t & ne2,
2124
+ constant int64_t & ne3,
2125
+ constant float & scale,
2126
+ threadgroup half * shared,
2127
+ uint3 tgpig[[threadgroup_position_in_grid]],
2128
+ uint3 tpitg[[thread_position_in_threadgroup]],
2129
+ uint3 ntg[[threads_per_threadgroup]],
2130
+ ushort tiisg[[thread_index_in_simdgroup]],
2131
+ ushort sgitg[[simdgroup_index_in_threadgroup]]);
2132
+
2133
+ // ref: https://arxiv.org/pdf/2307.08691.pdf
2134
+ template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2135
+ kernel void kernel_flash_attn_ext_f16(
2136
+ device const char * q,
2137
+ device const char * k,
2138
+ device const char * v,
2139
+ device const char * mask,
2140
+ device float * dst,
2141
+ constant int64_t & ne00,
2142
+ constant int64_t & ne01,
2143
+ constant int64_t & ne02,
2144
+ constant int64_t & ne03,
2145
+ constant uint64_t & nb00,
2146
+ constant uint64_t & nb01,
2147
+ constant uint64_t & nb02,
2148
+ constant uint64_t & nb03,
2149
+ constant int64_t & ne10,
2150
+ constant int64_t & ne11,
2151
+ constant int64_t & ne12,
2152
+ constant int64_t & ne13,
2153
+ constant uint64_t & nb10,
2154
+ constant uint64_t & nb11,
2155
+ constant uint64_t & nb12,
2156
+ constant uint64_t & nb13,
2157
+ constant int64_t & ne31,
2158
+ constant uint64_t & nb31,
2159
+ constant int64_t & ne0,
2160
+ constant int64_t & ne1,
2161
+ constant int64_t & ne2,
2162
+ constant int64_t & ne3,
2163
+ constant float & scale,
2164
+ threadgroup half * shared [[threadgroup(0)]],
2165
+ uint3 tgpig[[threadgroup_position_in_grid]],
2166
+ uint3 tpitg[[thread_position_in_threadgroup]],
2167
+ uint3 ntg[[threads_per_threadgroup]],
2168
+ ushort tiisg[[thread_index_in_simdgroup]],
2169
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
2170
+ const short nsg = ntg.y; // number of simdgroups
2171
+
2172
+ const short iq3 = tgpig[2];
2173
+ const short iq2 = tgpig[1];
2174
+ const short iq1 = tgpig[0]*Q;
2175
+
2176
+ const short D4 = D/4;
2177
+ const short D8 = D/8;
2178
+ const short Q8 = Q/8;
2179
+ const short NW = N_SIMDWIDTH;
2180
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
2181
+
2182
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2183
+ const short TF = T/2; // shared memory size per query in (float)
2184
+ const short T4 = T/4; // shared memory size per query in (half4)
2185
+
2186
+ threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2187
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2188
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
2189
+
2190
+ // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
2191
+ simdgroup_half8x8 lo[D8];
2192
+
2193
+ // load heads from Q to shared memory
2194
+ for (short j = sgitg; j < Q; j += nsg) {
2195
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
2196
+
2197
+ for (short i = tiisg; i < D4; i += NW) {
2198
+ if (iq1 + j < ne01) {
2199
+ sq4[j*T4 + i] = (half4) q4[i];
2200
+ } else {
2201
+ sq4[j*T4 + i] = 0.0h;
2202
+ }
2203
+ }
2204
+ }
2205
+
2206
+ // zero out lo
2207
+ for (short i = 0; i < D8; ++i) {
2208
+ lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
2209
+ }
2210
+
2211
+ // zero out shared memory SH
2212
+ for (short j = 0; j < Q; ++j) {
2213
+ for (short i = tiisg; i < SH; i += NW) {
2214
+ ss[j*TF + i] = 0.0f;
2215
+ }
2216
+ }
2217
+
2218
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2219
+
2220
+ {
2221
+ float S[Q] = { [0 ... Q-1] = 0.0h };
2222
+ float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
2223
+
2224
+ // assume K and V are same shape
2225
+ const short ne22 = ne12;
2226
+ const short ne23 = ne13;
2227
+
2228
+ const uint nb21 = nb11;
2229
+ const uint nb22 = nb12;
2230
+ const uint nb23 = nb13;
2231
+
2232
+ // broadcast
2233
+ const short rk2 = ne02/ne12;
2234
+ const short rk3 = ne03/ne13;
2235
+
2236
+ const short rv2 = ne02/ne22;
2237
+ const short rv3 = ne03/ne23;
2238
+
2239
+ // k indices
2240
+ const short ik2 = iq2/rk2;
2241
+ const short ik3 = iq3/rk3;
2242
+
2243
+ // v indices
2244
+ const short iv2 = iq2/rv2;
2245
+ const short iv3 = iq3/rv3;
2246
+
2247
+ // load the queries from shared memory into local memory
2248
+ simdgroup_half8x8 mq[D8];
2249
+
2250
+ for (short i = 0; i < D8; ++i) {
2251
+ simdgroup_load(mq[i], sq + i*8, T);
2252
+ }
2253
+
2254
+ // pointer to the mask
2255
+ device const half * mp = (device const half *) (mask + iq1*nb31);
2256
+
2257
+ // prepare diagonal scale matrix
2258
+ simdgroup_float8x8 mscale(scale);
2259
+
2260
+ // loop over the KV cache
2261
+ // each simdgroup handles blocks of Q rows and C columns
2262
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
2263
+ const int ic = ic0 + C*sgitg;
2264
+ if (ic >= ne11) {
2265
+ break;
2266
+ }
2267
+
2268
+ // Q*K^T
2269
+ {
2270
+ for (short cc = 0; cc < C/8; ++cc) {
2271
+ simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
2272
+
2273
+ device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
2274
+
2275
+ for (short i = 0; i < D8; ++i) {
2276
+ simdgroup_half8x8 mk;
2277
+ simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
2278
+
2279
+ simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2280
+ }
2281
+
2282
+ // mqk = mqk*scale + mask
2283
+ simdgroup_half8x8 mm;
2284
+ simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2285
+ simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2286
+
2287
+ simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2288
+ }
2289
+ }
2290
+
2291
+ // used to detect blocks full of -INF
2292
+ float smax = -INFINITY;
2293
+
2294
+ // online softmax
2295
+ {
2296
+ float ms[Q];
2297
+
2298
+ for (short j = 0; j < Q; ++j) {
2299
+ const short p = tiisg;
2300
+
2301
+ const float m = M[j];
2302
+ const float s = ss[j*TF + p];
2303
+
2304
+ smax = simd_max(max(smax, s));
2305
+ M[j] = simd_max(max(M[j], s));
2306
+
2307
+ ms[j] = exp(m - M[j]);
2308
+ const float vs = exp(s - M[j]);
2309
+
2310
+ S[j] = S[j]*ms[j] + simd_sum(vs);
2311
+
2312
+ // the P matrix from the paper (Q rows, C columns)
2313
+ ss[j*TF + p] = vs;
2314
+ }
2315
+
2316
+ // create a QxQ diagonal matrix for rescaling the output
2317
+ if (tiisg < Q) {
2318
+ ss[tiisg*TF + C + tiisg] = ms[tiisg];
2319
+ }
2320
+ }
2321
+
2322
+ // skip -INF blocks
2323
+ if (smax == -INFINITY) {
2324
+ continue;
2325
+ }
2326
+
2327
+ // O = diag(ms)*O
2328
+ {
2329
+ simdgroup_float8x8 mm;
2330
+ simdgroup_load(mm, ss + C, TF, 0, false);
2331
+
2332
+ for (short i = 0; i < D8; ++i) {
2333
+ simdgroup_multiply(lo[i], mm, lo[i]);
2334
+ }
2335
+ }
2336
+
2337
+ // O = O + (Q*K^T)*V
2338
+ {
2339
+ for (short cc = 0; cc < C/8; ++cc) {
2340
+ device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
2341
+
2342
+ for (short i = 0; i < D8; ++i) {
2343
+ simdgroup_half8x8 mk;
2344
+ simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
2345
+
2346
+ simdgroup_float8x8 mv;
2347
+ simdgroup_load(mv, ss + 8*cc, TF, 0, false);
2348
+
2349
+ simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
2350
+ }
2351
+ }
2352
+ }
2353
+ }
2354
+
2355
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
2356
+ for (short j = 0; j < Q; ++j) {
2357
+ if (tiisg == 0) {
2358
+ ss[j*TF + 0] = S[j];
2359
+ ss[j*TF + 1] = M[j];
2360
+ }
2361
+ }
2362
+ }
2363
+
2364
+ // reduce the warps sequentially
2365
+ for (short sg = 1; sg < nsg; ++sg) {
2366
+ float S = { 0.0h };
2367
+ float M = { -FLT_MAX/2 };
2368
+
2369
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2370
+
2371
+ // each simdgroup stores its output to shared memory, reusing sq
2372
+ if (sgitg == sg) {
2373
+ for (short i = 0; i < D8; ++i) {
2374
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
2375
+ }
2376
+ }
2377
+
2378
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2379
+
2380
+ // the first simdgroup accumulates the results from the other simdgroups
2381
+ if (sgitg == 0) {
2382
+ for (short j = 0; j < Q; ++j) {
2383
+ const float S0 = ss[j*TF + 0];
2384
+ const float S1 = ss[j*TF + sg*SH + 0];
2385
+
2386
+ const float M0 = ss[j*TF + 1];
2387
+ const float M1 = ss[j*TF + sg*SH + 1];
2388
+
2389
+ M = max(M0, M1);
2390
+
2391
+ const float ms0 = exp(M0 - M);
2392
+ const float ms1 = exp(M1 - M);
2393
+
2394
+ S = S0*ms0 + S1*ms1;
2395
+
2396
+ if (tiisg == 0) {
2397
+ ss[j*TF + 0] = S;
2398
+ ss[j*TF + 1] = M;
2399
+
2400
+ ss[j*TF + C + j ] = ms0;
2401
+ ss[j*TF + C + j + sg*SH] = ms1;
2402
+ }
2403
+ }
2404
+
2405
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2406
+ {
2407
+ simdgroup_half8x8 t;
2408
+ simdgroup_float8x8 ms0;
2409
+ simdgroup_float8x8 ms1;
2410
+
2411
+ simdgroup_load(ms0, ss + C, TF, 0, false);
2412
+ simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
2413
+
2414
+ for (short i = 0; i < D8; ++i) {
2415
+ simdgroup_load (t, sq + i*8, T, 0, false);
2416
+ simdgroup_multiply(t, ms1, t);
2417
+
2418
+ simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
2419
+ }
2420
+ }
2421
+ }
2422
+ }
2423
+
2424
+ // store result to shared memory (reuse sq)
2425
+ if (sgitg == 0) {
2426
+ for (short i = 0; i < D8; ++i) {
2427
+ simdgroup_store(lo[i], sq + i*8, T, 0, false);
2428
+ }
2429
+ }
2430
+
2431
+ device float4 * dst4 = (device float4 *) dst;
2432
+
2433
+ // final rescale with 1/S and store to global memory
2434
+ if (sgitg == 0) {
2435
+ for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
2436
+ const float S = ss[j*TF + 0];
2437
+
2438
+ for (short i = tiisg; i < D4; i += NW) {
2439
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
2440
+ }
2441
+ }
2442
+ }
2443
+ }
2444
+
2445
+ template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
2446
+ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
2447
+ template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
2448
+ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
2449
+ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
2450
+ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
2451
+
2452
+ template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
2453
+ kernel void kernel_flash_attn_ext_vec_f16(
2454
+ device const char * q,
2455
+ device const char * k,
2456
+ device const char * v,
2457
+ device const char * mask,
2458
+ device float * dst,
2459
+ constant int64_t & ne00,
2460
+ constant int64_t & ne01,
2461
+ constant int64_t & ne02,
2462
+ constant int64_t & ne03,
2463
+ constant uint64_t & nb00,
2464
+ constant uint64_t & nb01,
2465
+ constant uint64_t & nb02,
2466
+ constant uint64_t & nb03,
2467
+ constant int64_t & ne10,
2468
+ constant int64_t & ne11,
2469
+ constant int64_t & ne12,
2470
+ constant int64_t & ne13,
2471
+ constant uint64_t & nb10,
2472
+ constant uint64_t & nb11,
2473
+ constant uint64_t & nb12,
2474
+ constant uint64_t & nb13,
2475
+ constant int64_t & ne31,
2476
+ constant uint64_t & nb31,
2477
+ constant int64_t & ne0,
2478
+ constant int64_t & ne1,
2479
+ constant int64_t & ne2,
2480
+ constant int64_t & ne3,
2481
+ constant float & scale,
2482
+ threadgroup half * shared [[threadgroup(0)]],
2483
+ uint3 tgpig[[threadgroup_position_in_grid]],
2484
+ uint3 tpitg[[thread_position_in_threadgroup]],
2485
+ uint3 ntg[[threads_per_threadgroup]],
2486
+ ushort tiisg[[thread_index_in_simdgroup]],
2487
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
2488
+ const short nsg = ntg.y; // number of simdgroups
2489
+
2490
+ const short iq3 = tgpig[2];
2491
+ const short iq2 = tgpig[1];
2492
+ const short iq1 = tgpig[0];
2493
+
2494
+ const short D4 = D/4;
2495
+ const short NW = N_SIMDWIDTH;
2496
+ const short SH = (C + Q); // shared memory per simdgroup in (half)
2497
+
2498
+ const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2499
+
2500
+ //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2501
+ threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2502
+ threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
2503
+ threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
2504
+ threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
2505
+
2506
+ // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
2507
+ half4 lo[D4/NW];
2508
+
2509
+ // load heads from Q to shared memory
2510
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
2511
+
2512
+ for (short i = tiisg; i < D4; i += NW) {
2513
+ if (iq1 < ne01) {
2514
+ sq4[i] = (half4) q4[i];
2515
+ } else {
2516
+ sq4[i] = 0.0h;
2517
+ }
2518
+ }
2519
+
2520
+ // zero out lo
2521
+ for (short i = tiisg; i < D4; i += NW) {
2522
+ lo[i/NW] = 0.0h;
2523
+ }
2524
+
2525
+ // zero out shared memory SH
2526
+ for (short i = tiisg; i < SH/4; i += NW) {
2527
+ ss4[i] = 0.0h;
2528
+ }
2529
+
2530
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2531
+
2532
+ {
2533
+ float S = { 0.0h };
2534
+ float M = { -FLT_MAX/2 };
2535
+
2536
+ // assume K and V are same shape
2537
+ const short ne22 = ne12;
2538
+ const short ne23 = ne13;
2539
+
2540
+ const uint nb21 = nb11;
2541
+ const uint nb22 = nb12;
2542
+ const uint nb23 = nb13;
2543
+
2544
+ // broadcast
2545
+ const short rk2 = ne02/ne12;
2546
+ const short rk3 = ne03/ne13;
2547
+
2548
+ const short rv2 = ne02/ne22;
2549
+ const short rv3 = ne03/ne23;
2550
+
2551
+ // k indices
2552
+ const short ik2 = iq2 / rk2;
2553
+ const short ik3 = iq3 / rk3;
2554
+
2555
+ // v indices
2556
+ const short iv2 = iq2 / rv2;
2557
+ const short iv3 = iq3 / rv3;
2558
+
2559
+ // load the queries from shared memory into local memory
2560
+ half4 mq[D4];
2561
+
2562
+ for (short ii = 0; ii < D4; ii += NW) {
2563
+ short i = ii + tiisg;
2564
+ mq[i] = sq4[i];
2565
+ }
2566
+
2567
+ // pointer to the mask
2568
+ device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
2569
+
2570
+ // loop over the KV cache
2571
+ // each simdgroup handles blocks of Q rows and C columns
2572
+ for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
2573
+ const int ic = ic0 + C*sgitg;
2574
+ if (ic >= ne11) {
2575
+ break;
2576
+ }
2577
+
2578
+ // Q*K^T
2579
+ {
2580
+ #pragma unroll
2581
+ for (short cc = 0; cc < C/4; ++cc) {
2582
+ float4 mqk = { 0.0h };
2583
+
2584
+ device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
2585
+
2586
+ #pragma unroll
2587
+ for (short ii = 0; ii < D4; ii += NW) {
2588
+ const short i = ii + tiisg;
2589
+
2590
+ half4x4 mk;
2591
+ mk[0] = pk4[i + 0*(nb11/8)];
2592
+ mk[1] = pk4[i + 1*(nb11/8)];
2593
+ mk[2] = pk4[i + 2*(nb11/8)];
2594
+ mk[3] = pk4[i + 3*(nb11/8)];
2595
+
2596
+ mqk += (float4) (mq[i] * mk);
2597
+ }
2598
+
2599
+ // reduce the results from the threads in the simdgroup
2600
+ mqk += simd_shuffle_down(mqk, 16);
2601
+ mqk += simd_shuffle_down(mqk, 8);
2602
+ mqk += simd_shuffle_down(mqk, 4);
2603
+ mqk += simd_shuffle_down(mqk, 2);
2604
+ mqk += simd_shuffle_down(mqk, 1);
2605
+
2606
+ // mqk = mqk*scale + mask
2607
+ if (tiisg == 0) {
2608
+ float4 mm = (float4) mp4[ic/4 + cc];
2609
+ mqk = mqk*scale + mm;
2610
+
2611
+ ss4[cc] = mqk;
2612
+ }
2613
+ }
2614
+ }
2615
+
2616
+ // online softmax
2617
+ {
2618
+ const short p = tiisg;
2619
+
2620
+ const float m = M;
2621
+ const float s = ss[p];
2622
+
2623
+ M = simd_max(max(M, s));
2624
+
2625
+ const float ms = exp(m - M);
2626
+ const float vs = exp(s - M);
2627
+
2628
+ S = S*ms + simd_sum(vs);
2629
+
2630
+ // the P matrix from the paper (Q rows, C columns)
2631
+ ss[p] = vs;
2632
+
2633
+ // O = diag(ms)*O
2634
+ #pragma unroll
2635
+ for (short ii = 0; ii < D4; ii += NW) {
2636
+ const short i = ii + tiisg;
2637
+ lo[i/NW] *= ms;
2638
+ }
2639
+ }
2640
+
2641
+ // O = O + (Q*K^T)*V
2642
+ {
2643
+ #pragma unroll
2644
+ for (short cc = 0; cc < C/4; ++cc) {
2645
+ device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
2646
+
2647
+ #pragma unroll
2648
+ for (short ii = 0; ii < D4; ii += NW) {
2649
+ const short i = ii + tiisg;
2650
+
2651
+ lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
2652
+ lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
2653
+ lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
2654
+ lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
2655
+ }
2656
+ }
2657
+ }
2658
+
2659
+ }
2660
+
2661
+ // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
2662
+ if (tiisg == 0) {
2663
+ ss[0] = S;
2664
+ ss[1] = M;
2665
+ }
2666
+ }
2667
+
2668
+ // store results to shared memory
2669
+ for (short ii = 0; ii < D4; ii += NW) {
2670
+ short i = ii + tiisg;
2671
+ sr4[i] = lo[ii/NW];
2672
+ }
2673
+
2674
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2675
+
2676
+ // parallel reduce
2677
+ for (short r = nsg/2; r > 0; r >>= 1) {
2678
+ if (sgitg < r) {
2679
+ const float S0 = ss[ 0];
2680
+ const float S1 = ss[r*SH + 0];
2681
+
2682
+ const float M0 = ss[ 1];
2683
+ const float M1 = ss[r*SH + 1];
2684
+
2685
+ const float M = max(M0, M1);
2686
+
2687
+ const float ms0 = exp(M0 - M);
2688
+ const float ms1 = exp(M1 - M);
2689
+
2690
+ const float S = S0*ms0 + S1*ms1;
2691
+
2692
+ if (tiisg == 0) {
2693
+ ss[0] = S;
2694
+ ss[1] = M;
2695
+ }
2696
+
2697
+ // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
2698
+ for (short ii = 0; ii < D4; ii += NW) {
2699
+ short i = ii + tiisg;
2700
+ sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
2701
+ }
2702
+ }
2703
+
2704
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2705
+ }
2706
+
2707
+ device float4 * dst4 = (device float4 *) dst;
2708
+
2709
+ // final rescale with 1/S and store to global memory
2710
+ if (sgitg == 0) {
2711
+ const float S = ss[0];
2712
+
2713
+ for (short ii = 0; ii < D4; ii += NW) {
2714
+ short i = ii + tiisg;
2715
+ dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
2716
+ }
2717
+ }
2718
+ }
2719
+
2720
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2721
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
2722
+
2087
2723
  kernel void kernel_cpy_f16_f16(
2088
2724
  device const half * src0,
2089
2725
  device half * dst,