llama_cpp 0.15.1 → 0.15.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -229,6 +229,13 @@ kernel void kernel_relu(
229
229
  dst[tpig] = max(0.0f, src0[tpig]);
230
230
  }
231
231
 
232
+ kernel void kernel_sigmoid(
233
+ device const float * src0,
234
+ device float * dst,
235
+ uint tpig[[thread_position_in_grid]]) {
236
+ dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
237
+ }
238
+
232
239
  kernel void kernel_tanh(
233
240
  device const float * src0,
234
241
  device float * dst,
@@ -356,7 +363,6 @@ template<typename T>
356
363
  kernel void kernel_soft_max(
357
364
  device const char * src0,
358
365
  device const char * src1,
359
- device const char * src2,
360
366
  device char * dst,
361
367
  constant int64_t & ne00,
362
368
  constant int64_t & ne01,
@@ -378,10 +384,9 @@ kernel void kernel_soft_max(
378
384
 
379
385
  device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
380
386
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
381
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
382
387
  device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
383
388
 
384
- float slope = 0.0f;
389
+ float slope = 1.0f;
385
390
 
386
391
  // ALiBi
387
392
  if (max_bias > 0.0f) {
@@ -397,7 +402,7 @@ kernel void kernel_soft_max(
397
402
  float lmax = -INFINITY;
398
403
 
399
404
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
400
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
405
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
401
406
  }
402
407
 
403
408
  // find the max value in the block
@@ -422,7 +427,7 @@ kernel void kernel_soft_max(
422
427
  // parallel sum
423
428
  float lsum = 0.0f;
424
429
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
425
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
430
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
426
431
  lsum += exp_psrc0;
427
432
  pdst[i00] = exp_psrc0;
428
433
  }
@@ -461,7 +466,6 @@ template<typename T>
461
466
  kernel void kernel_soft_max_4(
462
467
  device const char * src0,
463
468
  device const char * src1,
464
- device const char * src2,
465
469
  device char * dst,
466
470
  constant int64_t & ne00,
467
471
  constant int64_t & ne01,
@@ -483,10 +487,9 @@ kernel void kernel_soft_max_4(
483
487
 
484
488
  device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
485
489
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
486
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
487
490
  device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
488
491
 
489
- float slope = 0.0f;
492
+ float slope = 1.0f;
490
493
 
491
494
  if (max_bias > 0.0f) {
492
495
  const int64_t h = i02;
@@ -501,7 +504,7 @@ kernel void kernel_soft_max_4(
501
504
  float4 lmax4 = -INFINITY;
502
505
 
503
506
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
504
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
507
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
505
508
  }
506
509
 
507
510
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -527,7 +530,7 @@ kernel void kernel_soft_max_4(
527
530
  // parallel sum
528
531
  float4 lsum4 = 0.0f;
529
532
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
530
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
533
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
531
534
  lsum4 += exp_psrc4;
532
535
  pdst4[i00] = exp_psrc4;
533
536
  }
@@ -1595,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
1595
1598
  }
1596
1599
  }
1597
1600
 
1598
- kernel void kernel_alibi_f32(
1599
- device const float * src0,
1600
- device float * dst,
1601
- constant int64_t & ne00,
1602
- constant int64_t & ne01,
1603
- constant int64_t & ne02,
1604
- constant int64_t & ne03,
1605
- constant uint64_t & nb00,
1606
- constant uint64_t & nb01,
1607
- constant uint64_t & nb02,
1608
- constant uint64_t & nb03,
1609
- constant int64_t & ne0,
1610
- constant int64_t & ne1,
1611
- constant int64_t & ne2,
1612
- constant int64_t & ne3,
1613
- constant uint64_t & nb0,
1614
- constant uint64_t & nb1,
1615
- constant uint64_t & nb2,
1616
- constant uint64_t & nb3,
1617
- constant float & m0,
1618
- constant float & m1,
1619
- constant int & n_heads_log2_floor,
1620
- uint3 tgpig[[threadgroup_position_in_grid]],
1621
- uint3 tpitg[[thread_position_in_threadgroup]],
1622
- uint3 ntg[[threads_per_threadgroup]]) {
1623
- const int64_t i03 = tgpig[2];
1624
- const int64_t i02 = tgpig[1];
1625
- const int64_t i01 = tgpig[0];
1626
-
1627
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1628
-
1629
- const int64_t i3 = n / (ne2*ne1*ne0);
1630
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1631
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1632
- //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1633
-
1634
- const int64_t k = i3*ne3 + i2;
1635
-
1636
- float m_k;
1637
- if (k < n_heads_log2_floor) {
1638
- m_k = pow(m0, k + 1);
1639
- } else {
1640
- m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1641
- }
1642
-
1643
- device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1644
- device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1645
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1646
- const float src_v = *(device float *)(src_row + i00*nb00);
1647
- device float * dst_v = (device float *)(dst_row + i00*nb0);
1648
- *dst_v = i00 * m_k + src_v;
1649
- }
1650
- }
1651
-
1652
1601
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
1653
1602
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
1654
1603
  return 1.0f - min(1.0f, max(0.0f, y));
@@ -1903,7 +1852,10 @@ kernel void kernel_upscale_f32(
1903
1852
  constant uint64_t & nb1,
1904
1853
  constant uint64_t & nb2,
1905
1854
  constant uint64_t & nb3,
1906
- constant int32_t & sf,
1855
+ constant float & sf0,
1856
+ constant float & sf1,
1857
+ constant float & sf2,
1858
+ constant float & sf3,
1907
1859
  uint3 tgpig[[threadgroup_position_in_grid]],
1908
1860
  uint3 tpitg[[thread_position_in_threadgroup]],
1909
1861
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -1912,15 +1864,17 @@ kernel void kernel_upscale_f32(
1912
1864
  const int64_t i2 = tgpig.y;
1913
1865
  const int64_t i1 = tgpig.x;
1914
1866
 
1915
- const int64_t i03 = i3;
1916
- const int64_t i02 = i2;
1917
- const int64_t i01 = i1/sf;
1918
-
1919
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1920
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1867
+ const int64_t i03 = i3/sf3;
1868
+ const int64_t i02 = i2/sf2;
1869
+ const int64_t i01 = i1/sf1;
1921
1870
 
1922
1871
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1923
- dst_ptr[i0] = src0_ptr[i0/sf];
1872
+ const int64_t i00 = i0/sf0;
1873
+
1874
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1875
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1876
+
1877
+ dst_ptr[0] = src0_ptr[0];
1924
1878
  }
1925
1879
  }
1926
1880
 
@@ -2100,29 +2054,29 @@ typedef void (flash_attn_ext_f16_t)(
2100
2054
  device const char * v,
2101
2055
  device const char * mask,
2102
2056
  device float * dst,
2103
- constant int64_t & ne00,
2104
2057
  constant int64_t & ne01,
2105
2058
  constant int64_t & ne02,
2106
2059
  constant int64_t & ne03,
2107
- constant uint64_t & nb00,
2108
2060
  constant uint64_t & nb01,
2109
2061
  constant uint64_t & nb02,
2110
2062
  constant uint64_t & nb03,
2111
- constant int64_t & ne10,
2112
2063
  constant int64_t & ne11,
2113
2064
  constant int64_t & ne12,
2114
2065
  constant int64_t & ne13,
2115
- constant uint64_t & nb10,
2116
2066
  constant uint64_t & nb11,
2117
2067
  constant uint64_t & nb12,
2118
2068
  constant uint64_t & nb13,
2119
- constant int64_t & ne31,
2069
+ constant uint64_t & nb21,
2070
+ constant uint64_t & nb22,
2071
+ constant uint64_t & nb23,
2120
2072
  constant uint64_t & nb31,
2121
- constant int64_t & ne0,
2122
2073
  constant int64_t & ne1,
2123
2074
  constant int64_t & ne2,
2124
- constant int64_t & ne3,
2125
2075
  constant float & scale,
2076
+ constant float & max_bias,
2077
+ constant float & m0,
2078
+ constant float & m1,
2079
+ constant uint32_t & n_head_log2,
2126
2080
  threadgroup half * shared,
2127
2081
  uint3 tgpig[[threadgroup_position_in_grid]],
2128
2082
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2138,29 +2092,29 @@ kernel void kernel_flash_attn_ext_f16(
2138
2092
  device const char * v,
2139
2093
  device const char * mask,
2140
2094
  device float * dst,
2141
- constant int64_t & ne00,
2142
2095
  constant int64_t & ne01,
2143
2096
  constant int64_t & ne02,
2144
2097
  constant int64_t & ne03,
2145
- constant uint64_t & nb00,
2146
2098
  constant uint64_t & nb01,
2147
2099
  constant uint64_t & nb02,
2148
2100
  constant uint64_t & nb03,
2149
- constant int64_t & ne10,
2150
2101
  constant int64_t & ne11,
2151
2102
  constant int64_t & ne12,
2152
2103
  constant int64_t & ne13,
2153
- constant uint64_t & nb10,
2154
2104
  constant uint64_t & nb11,
2155
2105
  constant uint64_t & nb12,
2156
2106
  constant uint64_t & nb13,
2157
- constant int64_t & ne31,
2107
+ constant uint64_t & nb21,
2108
+ constant uint64_t & nb22,
2109
+ constant uint64_t & nb23,
2158
2110
  constant uint64_t & nb31,
2159
- constant int64_t & ne0,
2160
2111
  constant int64_t & ne1,
2161
2112
  constant int64_t & ne2,
2162
- constant int64_t & ne3,
2163
2113
  constant float & scale,
2114
+ constant float & max_bias,
2115
+ constant float & m0,
2116
+ constant float & m1,
2117
+ constant uint32_t & n_head_log2,
2164
2118
  threadgroup half * shared [[threadgroup(0)]],
2165
2119
  uint3 tgpig[[threadgroup_position_in_grid]],
2166
2120
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2225,10 +2179,6 @@ kernel void kernel_flash_attn_ext_f16(
2225
2179
  const short ne22 = ne12;
2226
2180
  const short ne23 = ne13;
2227
2181
 
2228
- const uint nb21 = nb11;
2229
- const uint nb22 = nb12;
2230
- const uint nb23 = nb13;
2231
-
2232
2182
  // broadcast
2233
2183
  const short rk2 = ne02/ne12;
2234
2184
  const short rk3 = ne03/ne13;
@@ -2257,6 +2207,19 @@ kernel void kernel_flash_attn_ext_f16(
2257
2207
  // prepare diagonal scale matrix
2258
2208
  simdgroup_float8x8 mscale(scale);
2259
2209
 
2210
+ // prepare diagonal slope matrix
2211
+ simdgroup_float8x8 mslope(1.0f);
2212
+
2213
+ // ALiBi
2214
+ if (max_bias > 0.0f) {
2215
+ const uint32_t h = iq2;
2216
+
2217
+ const float base = h < n_head_log2 ? m0 : m1;
2218
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2219
+
2220
+ mslope = simdgroup_float8x8(pow(base, exph));
2221
+ }
2222
+
2260
2223
  // loop over the KV cache
2261
2224
  // each simdgroup handles blocks of Q rows and C columns
2262
2225
  for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@@ -2279,10 +2242,16 @@ kernel void kernel_flash_attn_ext_f16(
2279
2242
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2280
2243
  }
2281
2244
 
2282
- // mqk = mqk*scale + mask
2283
- simdgroup_half8x8 mm;
2284
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2285
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2245
+ if (mask != q) {
2246
+ // mqk = mqk*scale + mask*slope
2247
+ simdgroup_half8x8 mm;
2248
+ simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2249
+ simdgroup_multiply(mm, mslope, mm);
2250
+ simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2251
+ } else {
2252
+ // mqk = mqk*scale
2253
+ simdgroup_multiply(mqk, mscale, mqk);
2254
+ }
2286
2255
 
2287
2256
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2288
2257
  }
@@ -2456,29 +2425,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
2456
2425
  device const char * v,
2457
2426
  device const char * mask,
2458
2427
  device float * dst,
2459
- constant int64_t & ne00,
2460
2428
  constant int64_t & ne01,
2461
2429
  constant int64_t & ne02,
2462
2430
  constant int64_t & ne03,
2463
- constant uint64_t & nb00,
2464
2431
  constant uint64_t & nb01,
2465
2432
  constant uint64_t & nb02,
2466
2433
  constant uint64_t & nb03,
2467
- constant int64_t & ne10,
2468
2434
  constant int64_t & ne11,
2469
2435
  constant int64_t & ne12,
2470
2436
  constant int64_t & ne13,
2471
- constant uint64_t & nb10,
2472
2437
  constant uint64_t & nb11,
2473
2438
  constant uint64_t & nb12,
2474
2439
  constant uint64_t & nb13,
2475
- constant int64_t & ne31,
2440
+ constant uint64_t & nb21,
2441
+ constant uint64_t & nb22,
2442
+ constant uint64_t & nb23,
2476
2443
  constant uint64_t & nb31,
2477
- constant int64_t & ne0,
2478
2444
  constant int64_t & ne1,
2479
2445
  constant int64_t & ne2,
2480
- constant int64_t & ne3,
2481
2446
  constant float & scale,
2447
+ constant float & max_bias,
2448
+ constant float & m0,
2449
+ constant float & m1,
2450
+ constant uint32_t & n_head_log2,
2482
2451
  threadgroup half * shared [[threadgroup(0)]],
2483
2452
  uint3 tgpig[[threadgroup_position_in_grid]],
2484
2453
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2497,6 +2466,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
2497
2466
 
2498
2467
  const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2499
2468
 
2469
+ float slope = 1.0f;
2470
+
2471
+ // ALiBi
2472
+ if (max_bias > 0.0f) {
2473
+ const uint32_t h = iq2;
2474
+
2475
+ const float base = h < n_head_log2 ? m0 : m1;
2476
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2477
+
2478
+ slope = pow(base, exp);
2479
+ }
2480
+
2500
2481
  //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2501
2482
  threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2502
2483
  threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@@ -2537,10 +2518,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
2537
2518
  const short ne22 = ne12;
2538
2519
  const short ne23 = ne13;
2539
2520
 
2540
- const uint nb21 = nb11;
2541
- const uint nb22 = nb12;
2542
- const uint nb23 = nb13;
2543
-
2544
2521
  // broadcast
2545
2522
  const short rk2 = ne02/ne12;
2546
2523
  const short rk3 = ne03/ne13;
@@ -2603,10 +2580,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
2603
2580
  mqk += simd_shuffle_down(mqk, 2);
2604
2581
  mqk += simd_shuffle_down(mqk, 1);
2605
2582
 
2606
- // mqk = mqk*scale + mask
2583
+ // mqk = mqk*scale + mask*slope
2607
2584
  if (tiisg == 0) {
2608
- float4 mm = (float4) mp4[ic/4 + cc];
2609
- mqk = mqk*scale + mm;
2585
+ mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
2610
2586
 
2611
2587
  ss4[cc] = mqk;
2612
2588
  }
@@ -2840,7 +2816,8 @@ kernel void kernel_cpy_f32_f16(
2840
2816
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2841
2817
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2842
2818
 
2843
- dst_data[i00] = src[0];
2819
+ // TODO: is there a better way to handle -INFINITY?
2820
+ dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
2844
2821
  }
2845
2822
  }
2846
2823