llama_cpp 0.14.7 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -352,11 +352,12 @@ kernel void kernel_sum_rows(
352
352
  dst_row[0] = row_sum;
353
353
  }
354
354
 
355
+ template<typename T>
355
356
  kernel void kernel_soft_max(
356
- device const float * src0,
357
- device const float * src1,
358
- device const float * src2,
359
- device float * dst,
357
+ device const char * src0,
358
+ device const char * src1,
359
+ device const char * src2,
360
+ device char * dst,
360
361
  constant int64_t & ne00,
361
362
  constant int64_t & ne01,
362
363
  constant int64_t & ne02,
@@ -375,10 +376,10 @@ kernel void kernel_soft_max(
375
376
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
376
377
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
377
378
 
378
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
379
- device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
380
- device const float * ppos = src2 != src0 ? src2 : nullptr;
381
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
379
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
380
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
381
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
382
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
382
383
 
383
384
  float slope = 0.0f;
384
385
 
@@ -456,11 +457,12 @@ kernel void kernel_soft_max(
456
457
  }
457
458
  }
458
459
 
460
+ template<typename T>
459
461
  kernel void kernel_soft_max_4(
460
- device const float * src0,
461
- device const float * src1,
462
- device const float * src2,
463
- device float * dst,
462
+ device const char * src0,
463
+ device const char * src1,
464
+ device const char * src2,
465
+ device char * dst,
464
466
  constant int64_t & ne00,
465
467
  constant int64_t & ne01,
466
468
  constant int64_t & ne02,
@@ -479,10 +481,10 @@ kernel void kernel_soft_max_4(
479
481
  const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
480
482
  const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
481
483
 
482
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
483
- device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
484
- device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
485
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
484
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
485
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
486
+ device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
487
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
486
488
 
487
489
  float slope = 0.0f;
488
490
 
@@ -499,7 +501,7 @@ kernel void kernel_soft_max_4(
499
501
  float4 lmax4 = -INFINITY;
500
502
 
501
503
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
502
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
504
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
503
505
  }
504
506
 
505
507
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -525,7 +527,7 @@ kernel void kernel_soft_max_4(
525
527
  // parallel sum
526
528
  float4 lsum4 = 0.0f;
527
529
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
528
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
530
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
529
531
  lsum4 += exp_psrc4;
530
532
  pdst4[i00] = exp_psrc4;
531
533
  }
@@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
562
564
  }
563
565
  }
564
566
 
567
// Function types of the softmax kernels; the explicit instantiations below are declared with them.
using kernel_soft_max_t   = decltype(kernel_soft_max<float>);
using kernel_soft_max_4_t = decltype(kernel_soft_max_4<float4>);

// Host-visible softmax entry points. The template argument is the element type used to read
// the mask (src1) and the positions buffer (src2); src0 and dst are always read/written as float.
template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
574
+
565
575
  kernel void kernel_diag_mask_inf(
566
576
  device const float * src0,
567
577
  device float * dst,
@@ -2084,6 +2094,632 @@ kernel void kernel_leaky_relu_f32(
2084
2094
  dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
2085
2095
  }
2086
2096
 
2097
// Function type shared by the flash-attention kernels. The explicit host-named instantiations
// of kernel_flash_attn_ext_f16 / kernel_flash_attn_ext_vec_f16 are declared with this type,
// so its parameter list must match those kernels exactly.
typedef void flash_attn_ext_f16_t(
        // input/output buffers
        device const char  * q,
        device const char  * k,
        device const char  * v,
        device const char  * mask,
        device       float * dst,
        // q: element counts (ne0x) and byte strides (nb0x)
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        // k: element counts (ne1x) and byte strides (nb1x)
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant   int64_t & ne13,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant  uint64_t & nb13,
        // mask: rows and row stride in bytes
        constant   int64_t & ne31,
        constant  uint64_t & nb31,
        // dst: element counts
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant     float & scale,
        // scratch and thread-position builtins
        threadgroup   half * shared,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]);
2132
+
2133
// ref: https://arxiv.org/pdf/2307.08691.pdf
//
// Flash attention over F32 queries, F16 K/V, an F16 mask and F32 output, built on 8x8
// simdgroup matrices. Each simdgroup walks its own C-wide slices of the KV cache with an
// online softmax; the simdgroups are then merged sequentially into simdgroup 0.
//
// Template parameters:
//   D - head size; must be a multiple of 8 (Q and O rows are held as D/8 8x8 matrices)
//   Q - queries per threadgroup; every Q-row tile below is exactly one 8x8 matrix, so Q == 8
//   C - KV items per simdgroup per iteration; the online softmax gives one lane to each cache
//       column, so C must equal N_SIMDWIDTH
//
// Layout, as read by the code:
//   tgpig[0] -> block of Q consecutive query rows (iq1 = tgpig[0]*Q)
//   tgpig[1] -> query head iq2, tgpig[2] -> iq3
//   ntg.y    -> number of simdgroups nsg
//
// `shared` (in half units) is Q rows of T = D + 2*nsg*(C + Q) elements: the first D hold the
// query row (later reused for the output), followed by a (C + Q)-float scratch area per simdgroup.
//
// NOTE(review): K/V rows ic .. ic + C - 1 and mask rows iq1 .. iq1 + 7 are loaded without bounds
// checks, so this presumably relies on the host padding ne11 to a multiple of C and the mask rows
// to a multiple of Q -- confirm against the dispatch code.
template<int64_t D, int64_t Q = 8, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_f16(
        device const char  * q,
        device const char  * k,
        device const char  * v,
        device const char  * mask,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant   int64_t & ne13,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant  uint64_t & nb13,
        constant   int64_t & ne31,
        constant  uint64_t & nb31,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant     float & scale,
        threadgroup   half * shared [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    // other shapes would silently compute wrong results with the tiling below -> reject at compile time
    static_assert(D % 8 == 0,       "kernel_flash_attn_ext_f16: D must be a multiple of 8");
    static_assert(Q == 8,           "kernel_flash_attn_ext_f16: the 8x8 simdgroup tiling requires Q == 8");
    static_assert(C == N_SIMDWIDTH, "kernel_flash_attn_ext_f16: the online softmax maps one lane to each cache column");

    const short nsg = ntg.y; // number of simdgroups

    const short iq3 = tgpig[2];
    const short iq2 = tgpig[1];
    // int, not short: tgpig[0]*Q wraps past 32767 query rows and would produce negative offsets
    const int   iq1 = tgpig[0]*Q;

    const short D4 = D/4;
    const short D8 = D/8;
    const short NW = N_SIMDWIDTH;
    const short SH = (C + Q); // shared memory per simdgroup in (half)

    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
    const short TF = T/2;          // shared memory size per query in (float)
    const short T4 = T/4;          // shared memory size per query in (half4)

    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix

    // this simdgroup's output accumulator for the Q queries, as D/8 8x8 matrices (the O matrix from the paper)
    simdgroup_half8x8 lo[D8];

    // load heads from Q to shared memory; rows past ne01 are zero-filled
    for (short j = sgitg; j < Q; j += nsg) {
        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));

        for (short i = tiisg; i < D4; i += NW) {
            if (iq1 + j < ne01) {
                sq4[j*T4 + i] = (half4) q4[i];
            } else {
                sq4[j*T4 + i] = 0.0h;
            }
        }
    }

    // zero out lo
    for (short i = 0; i < D8; ++i) {
        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
    }

    // zero out this simdgroup's scratch area (SH floats per query row)
    for (short j = 0; j < Q; ++j) {
        for (short i = tiisg; i < SH; i += NW) {
            ss[j*TF + i] = 0.0f;
        }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    {
        float S[Q] = { [0 ... Q-1] = 0.0f };       // running softmax denominators
        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; // running row maxima

        // assume K and V are same shape
        const short ne22 = ne12;
        const short ne23 = ne13;

        // V strides mirror the K strides; kept 64-bit like nb11..nb13 so V byte offsets cannot wrap
        const uint64_t nb21 = nb11;
        const uint64_t nb22 = nb12;
        const uint64_t nb23 = nb13;

        // broadcast
        const short rk2 = ne02/ne12;
        const short rk3 = ne03/ne13;

        const short rv2 = ne02/ne22;
        const short rv3 = ne03/ne23;

        // k indices
        const short ik2 = iq2/rk2;
        const short ik3 = iq3/rk3;

        // v indices
        const short iv2 = iq2/rv2;
        const short iv3 = iq3/rv3;

        // load the queries from shared memory into local memory
        simdgroup_half8x8 mq[D8];

        for (short i = 0; i < D8; ++i) {
            simdgroup_load(mq[i], sq + i*8, T);
        }

        // pointer to the mask
        device const half * mp = (device const half *) (mask + iq1*nb31);

        // diagonal matrix with `scale` on the diagonal
        simdgroup_float8x8 mscale(scale);

        // loop over the KV cache
        // each simdgroup handles blocks of Q rows and C columns
        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
            const int ic = ic0 + C*sgitg;
            if (ic >= ne11) {
                break;
            }

            // Q*K^T
            {
                for (short cc = 0; cc < C/8; ++cc) {
                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.0f);

                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));

                    for (short i = 0; i < D8; ++i) {
                        simdgroup_half8x8 mk;
                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose

                        simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                    }

                    // mqk = mqk*scale + mask
                    simdgroup_half8x8 mm;
                    simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
                    simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);

                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
                }
            }

            // used to detect blocks full of -INF
            float smax = -INFINITY;

            // online softmax
            {
                float ms[Q];

                for (short j = 0; j < Q; ++j) {
                    const short p = tiisg;

                    const float m = M[j];
                    const float s = ss[j*TF + p];

                    smax = simd_max(max(smax, s));
                    M[j] = simd_max(max(M[j], s));

                    ms[j] = exp(m - M[j]);
                    const float vs = exp(s - M[j]);

                    S[j] = S[j]*ms[j] + simd_sum(vs);

                    // the P matrix from the paper (Q rows, C columns)
                    ss[j*TF + p] = vs;
                }

                // create a QxQ diagonal matrix for rescaling the output
                if (tiisg < Q) {
                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
                }
            }

            // skip -INF blocks (M and S were left unchanged by them)
            if (smax == -INFINITY) {
                continue;
            }

            // O = diag(ms)*O
            {
                simdgroup_float8x8 mm;
                simdgroup_load(mm, ss + C, TF, 0, false);

                for (short i = 0; i < D8; ++i) {
                    simdgroup_multiply(lo[i], mm, lo[i]);
                }
            }

            // O = O + (Q*K^T)*V
            {
                for (short cc = 0; cc < C/8; ++cc) {
                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));

                    for (short i = 0; i < D8; ++i) {
                        simdgroup_half8x8 mk;
                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);

                        simdgroup_float8x8 mv;
                        simdgroup_load(mv, ss + 8*cc, TF, 0, false);

                        simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]);
                    }
                }
            }
        }

        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
        for (short j = 0; j < Q; ++j) {
            if (tiisg == 0) {
                ss[j*TF + 0] = S[j];
                ss[j*TF + 1] = M[j];
            }
        }
    }

    // reduce the simdgroups sequentially into simdgroup 0
    for (short sg = 1; sg < nsg; ++sg) {
        float S = 0.0f;
        float M = -FLT_MAX/2;

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // each simdgroup stores its output to shared memory, reusing sq
        if (sgitg == sg) {
            for (short i = 0; i < D8; ++i) {
                simdgroup_store(lo[i], sq + i*8, T, 0, false);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // the first simdgroup accumulates the results from the other simdgroups
        if (sgitg == 0) {
            for (short j = 0; j < Q; ++j) {
                const float S0 = ss[j*TF +         0];
                const float S1 = ss[j*TF + sg*SH + 0];

                const float M0 = ss[j*TF +         1];
                const float M1 = ss[j*TF + sg*SH + 1];

                M = max(M0, M1);

                const float ms0 = exp(M0 - M);
                const float ms1 = exp(M1 - M);

                S = S0*ms0 + S1*ms1;

                if (tiisg == 0) {
                    ss[j*TF + 0] = S;
                    ss[j*TF + 1] = M;

                    ss[j*TF + C + j        ] = ms0;
                    ss[j*TF + C + j + sg*SH] = ms1;
                }
            }

            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
            {
                simdgroup_half8x8 t;
                simdgroup_float8x8 ms0;
                simdgroup_float8x8 ms1;

                simdgroup_load(ms0, ss + C,         TF, 0, false);
                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);

                for (short i = 0; i < D8; ++i) {
                    simdgroup_load    (t, sq + i*8, T, 0, false);
                    simdgroup_multiply(t, ms1, t);

                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
                }
            }
        }
    }

    // store result to shared memory (reuse sq)
    if (sgitg == 0) {
        for (short i = 0; i < D8; ++i) {
            simdgroup_store(lo[i], sq + i*8, T, 0, false);
        }
    }

    device float4 * dst4 = (device float4 *) dst;

    // final rescale with 1/S and store to global memory (only rows that exist in q)
    if (sgitg == 0) {
        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
            const float S = ss[j*TF + 0];

            for (short i = tiisg; i < D4; i += NW) {
                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
            }
        }
    }
}
2444
+
2445
// Host-visible entry points of kernel_flash_attn_ext_f16, one per supported head size D
// (Q and C keep their template defaults). Listed from the largest head size down.
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
2451
+
2452
// Single-query ("vec") flash attention: one query row per threadgroup, F32 query, F16 K/V,
// F16 mask, F32 output. Dot products use half4 vector math instead of simdgroup matrices.
// Each simdgroup walks its own C-wide slices of the KV cache with an online softmax; the
// simdgroups are then merged with a tree reduction.
//
// Template parameters:
//   D - head size; each lane owns D/(4*N_SIMDWIDTH) half4 chunks of a row, so D/4 must be a
//       multiple of N_SIMDWIDTH
//   Q - queries per threadgroup; S and M are scalars here, so Q == 1
//   C - KV items per simdgroup per iteration; one lane per cache column, so C == N_SIMDWIDTH
//       (the fixed 16/8/4/2/1 shuffle reduction below also assumes a 32-wide simdgroup)
//
// Layout, as read by the code:
//   tgpig[0] -> query row iq1, tgpig[1] -> head iq2, tgpig[2] -> iq3, ntg.y -> nsg simdgroups
//
// `shared` (in half units): T = D + 2*nsg*(C + Q) elements holding the query row and one
// (C + Q)-float scratch area per simdgroup, followed by nsg*D halfs of per-simdgroup result rows.
//
// NOTE(review): the tree reduction over simdgroups (r = nsg/2, nsg/4, ...) only merges every
// simdgroup when nsg is a power of two -- presumably guaranteed by the host, TODO confirm.
// NOTE(review): K/V rows are read in blocks of 4 without bounds checks, so ne11 is presumably
// padded to a multiple of C by the host -- confirm.
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_vec_f16(
        device const char  * q,
        device const char  * k,
        device const char  * v,
        device const char  * mask,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant   int64_t & ne13,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant  uint64_t & nb13,
        constant   int64_t & ne31,
        constant  uint64_t & nb31,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant     float & scale,
        threadgroup   half * shared [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    // other shapes would silently compute wrong results -> reject at compile time
    static_assert(Q == 1,                   "kernel_flash_attn_ext_vec_f16: exactly one query row per threadgroup");
    static_assert(C == N_SIMDWIDTH,         "kernel_flash_attn_ext_vec_f16: the online softmax maps one lane to each cache column");
    static_assert((D/4) % N_SIMDWIDTH == 0, "kernel_flash_attn_ext_vec_f16: D/4 must be a multiple of the simdgroup width");

    const short nsg = ntg.y; // number of simdgroups

    const short iq3 = tgpig[2];
    const short iq2 = tgpig[1];
    // int, not short: avoids wrapping to negative offsets past 32767 query rows
    const int   iq1 = tgpig[0];

    const short D4 = D/4;
    const short NW = N_SIMDWIDTH;
    const short SH = (C + Q); // shared memory per simdgroup in (half)

    const short T = D + 2*nsg*SH; // shared memory size per query in (half)

    threadgroup half4  * sq4 = (threadgroup half4  *) (shared +              0*D); // holds the query data in half4
    threadgroup float  * ss  = (threadgroup float  *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention scores and S/M
    // NOTE(review): the byte offset of ss4 is 4*sgitg*SH + 2*D = 132*sgitg + 2*D for SH = 33, which is
    // not 16-byte aligned (float4 alignment) unless sgitg % 4 == 0. Fixing it needs a larger SH and
    // therefore a matching threadgroup-memory size on the host, so it is left as is here.
    threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
    threadgroup half4  * sr4 = (threadgroup half4  *) (shared +   sgitg*D  + 1*T); // per-simdgroup result row

    // this lane's slice of the output row: half4 chunks tiisg, tiisg + NW, tiisg + 2*NW, ...
    half4 lo[D4/NW];

    // load the query row into shared memory; zero-filled when iq1 is past ne01
    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));

    for (short i = tiisg; i < D4; i += NW) {
        if (iq1 < ne01) {
            sq4[i] = (half4) q4[i];
        } else {
            sq4[i] = 0.0h;
        }
    }

    // zero out lo
    for (short i = tiisg; i < D4; i += NW) {
        lo[i/NW] = 0.0h;
    }

    // zero out the scratch scores (SH/4 float4 entries)
    for (short i = tiisg; i < SH/4; i += NW) {
        ss4[i] = 0.0f;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    {
        float S = 0.0f;       // running softmax denominator
        float M = -FLT_MAX/2; // running maximum

        // assume K and V are same shape
        const short ne22 = ne12;
        const short ne23 = ne13;

        // V strides mirror the K strides; kept 64-bit like nb11..nb13 so V byte offsets cannot wrap
        const uint64_t nb21 = nb11;
        const uint64_t nb22 = nb12;
        const uint64_t nb23 = nb13;

        // broadcast
        const short rk2 = ne02/ne12;
        const short rk3 = ne03/ne13;

        const short rv2 = ne02/ne22;
        const short rv3 = ne03/ne23;

        // k indices
        const short ik2 = iq2 / rk2;
        const short ik3 = iq3 / rk3;

        // v indices
        const short iv2 = iq2 / rv2;
        const short iv3 = iq3 / rv3;

        // this lane's half4 chunks of the query row (only D4/NW of them are ever touched)
        half4 mq[D4/NW];

        for (short ii = 0; ii < D4; ii += NW) {
            mq[ii/NW] = sq4[ii + tiisg];
        }

        // pointer to the mask row of this query
        device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);

        // loop over the KV cache
        // each simdgroup handles blocks of Q rows and C columns
        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
            const int ic = ic0 + C*sgitg;
            if (ic >= ne11) {
                break;
            }

            // Q*K^T, four cache rows at a time
            {
#pragma unroll
                for (short cc = 0; cc < C/4; ++cc) {
                    float4 mqk = { 0.0f };

                    device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));

#pragma unroll
                    for (short ii = 0; ii < D4; ii += NW) {
                        const short i = ii + tiisg;

                        // nb11/8 is the K row stride in half4 units
                        half4x4 mk;
                        mk[0] = pk4[i + 0*(nb11/8)];
                        mk[1] = pk4[i + 1*(nb11/8)];
                        mk[2] = pk4[i + 2*(nb11/8)];
                        mk[3] = pk4[i + 3*(nb11/8)];

                        mqk += (float4) (mq[ii/NW] * mk);
                    }

                    // reduce the results from the threads in the simdgroup
                    mqk += simd_shuffle_down(mqk, 16);
                    mqk += simd_shuffle_down(mqk,  8);
                    mqk += simd_shuffle_down(mqk,  4);
                    mqk += simd_shuffle_down(mqk,  2);
                    mqk += simd_shuffle_down(mqk,  1);

                    // mqk = mqk*scale + mask
                    if (tiisg == 0) {
                        float4 mm = (float4) mp4[ic/4 + cc];
                        mqk = mqk*scale + mm;

                        ss4[cc] = mqk;
                    }
                }
            }

            // online softmax
            {
                const short p = tiisg;

                const float m = M;
                const float s = ss[p];

                M = simd_max(max(M, s));

                const float ms = exp(m - M);
                const float vs = exp(s - M);

                S = S*ms + simd_sum(vs);

                // the P matrix from the paper (Q rows, C columns)
                ss[p] = vs;

                // O = diag(ms)*O
#pragma unroll
                for (short ii = 0; ii < D4; ii += NW) {
                    const short i = ii + tiisg;
                    lo[i/NW] *= ms;
                }
            }

            // O = O + (Q*K^T)*V
            {
#pragma unroll
                for (short cc = 0; cc < C/4; ++cc) {
                    device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));

#pragma unroll
                    for (short ii = 0; ii < D4; ii += NW) {
                        const short i = ii + tiisg;

                        lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
                        lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
                        lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
                        lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
                    }
                }
            }
        }

        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
        if (tiisg == 0) {
            ss[0] = S;
            ss[1] = M;
        }
    }

    // store results to shared memory
    for (short ii = 0; ii < D4; ii += NW) {
        const short i = ii + tiisg;
        sr4[i] = lo[ii/NW];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // parallel reduce: simdgroup sgitg merges simdgroup sgitg + r
    for (short r = nsg/2; r > 0; r >>= 1) {
        if (sgitg < r) {
            const float S0 = ss[       0];
            const float S1 = ss[r*SH + 0];

            const float M0 = ss[       1];
            const float M1 = ss[r*SH + 1];

            const float M = max(M0, M1);

            const float ms0 = exp(M0 - M);
            const float ms1 = exp(M1 - M);

            const float S = S0*ms0 + S1*ms1;

            if (tiisg == 0) {
                ss[0] = S;
                ss[1] = M;
            }

            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
            for (short ii = 0; ii < D4; ii += NW) {
                const short i = ii + tiisg;
                sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    device float4 * dst4 = (device float4 *) dst;

    // final rescale with 1/S and store to global memory; skipped for rows past ne01,
    // matching the guard used when the query row was loaded
    if (sgitg == 0 && iq1 < ne01) {
        const float S = ss[0];

        for (short ii = 0; ii < D4; ii += NW) {
            const short i = ii + tiisg;
            dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
        }
    }
}
2719
+
2720
// Host-visible entry points of the single-query kernel; only head sizes whose D/4 is a
// multiple of the simdgroup width are instantiated. Listed from the largest head size down.
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
2722
+
2087
2723
  kernel void kernel_cpy_f16_f16(
2088
2724
  device const half * src0,
2089
2725
  device half * dst,