llama_cpp 0.15.1 → 0.15.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -229,6 +229,13 @@ kernel void kernel_relu(
229
229
  dst[tpig] = max(0.0f, src0[tpig]);
230
230
  }
231
231
 
232
+ kernel void kernel_sigmoid(
233
+ device const float * src0,
234
+ device float * dst,
235
+ uint tpig[[thread_position_in_grid]]) {
236
+ dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
237
+ }
238
+
232
239
  kernel void kernel_tanh(
233
240
  device const float * src0,
234
241
  device float * dst,
@@ -356,7 +363,6 @@ template<typename T>
356
363
  kernel void kernel_soft_max(
357
364
  device const char * src0,
358
365
  device const char * src1,
359
- device const char * src2,
360
366
  device char * dst,
361
367
  constant int64_t & ne00,
362
368
  constant int64_t & ne01,
@@ -378,10 +384,9 @@ kernel void kernel_soft_max(
378
384
 
379
385
  device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
380
386
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
381
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
382
387
  device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
383
388
 
384
- float slope = 0.0f;
389
+ float slope = 1.0f;
385
390
 
386
391
  // ALiBi
387
392
  if (max_bias > 0.0f) {
@@ -397,7 +402,7 @@ kernel void kernel_soft_max(
397
402
  float lmax = -INFINITY;
398
403
 
399
404
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
400
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
405
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
401
406
  }
402
407
 
403
408
  // find the max value in the block
@@ -422,7 +427,7 @@ kernel void kernel_soft_max(
422
427
  // parallel sum
423
428
  float lsum = 0.0f;
424
429
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
425
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
430
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
426
431
  lsum += exp_psrc0;
427
432
  pdst[i00] = exp_psrc0;
428
433
  }
@@ -461,7 +466,6 @@ template<typename T>
461
466
  kernel void kernel_soft_max_4(
462
467
  device const char * src0,
463
468
  device const char * src1,
464
- device const char * src2,
465
469
  device char * dst,
466
470
  constant int64_t & ne00,
467
471
  constant int64_t & ne01,
@@ -483,10 +487,9 @@ kernel void kernel_soft_max_4(
483
487
 
484
488
  device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
485
489
  device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
486
- device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
487
490
  device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
488
491
 
489
- float slope = 0.0f;
492
+ float slope = 1.0f;
490
493
 
491
494
  if (max_bias > 0.0f) {
492
495
  const int64_t h = i02;
@@ -501,7 +504,7 @@ kernel void kernel_soft_max_4(
501
504
  float4 lmax4 = -INFINITY;
502
505
 
503
506
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
504
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
507
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
505
508
  }
506
509
 
507
510
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -527,7 +530,7 @@ kernel void kernel_soft_max_4(
527
530
  // parallel sum
528
531
  float4 lsum4 = 0.0f;
529
532
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
530
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
533
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
531
534
  lsum4 += exp_psrc4;
532
535
  pdst4[i00] = exp_psrc4;
533
536
  }
@@ -1595,60 +1598,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
1595
1598
  }
1596
1599
  }
1597
1600
 
1598
- kernel void kernel_alibi_f32(
1599
- device const float * src0,
1600
- device float * dst,
1601
- constant int64_t & ne00,
1602
- constant int64_t & ne01,
1603
- constant int64_t & ne02,
1604
- constant int64_t & ne03,
1605
- constant uint64_t & nb00,
1606
- constant uint64_t & nb01,
1607
- constant uint64_t & nb02,
1608
- constant uint64_t & nb03,
1609
- constant int64_t & ne0,
1610
- constant int64_t & ne1,
1611
- constant int64_t & ne2,
1612
- constant int64_t & ne3,
1613
- constant uint64_t & nb0,
1614
- constant uint64_t & nb1,
1615
- constant uint64_t & nb2,
1616
- constant uint64_t & nb3,
1617
- constant float & m0,
1618
- constant float & m1,
1619
- constant int & n_heads_log2_floor,
1620
- uint3 tgpig[[threadgroup_position_in_grid]],
1621
- uint3 tpitg[[thread_position_in_threadgroup]],
1622
- uint3 ntg[[threads_per_threadgroup]]) {
1623
- const int64_t i03 = tgpig[2];
1624
- const int64_t i02 = tgpig[1];
1625
- const int64_t i01 = tgpig[0];
1626
-
1627
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1628
-
1629
- const int64_t i3 = n / (ne2*ne1*ne0);
1630
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1631
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1632
- //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
1633
-
1634
- const int64_t k = i3*ne3 + i2;
1635
-
1636
- float m_k;
1637
- if (k < n_heads_log2_floor) {
1638
- m_k = pow(m0, k + 1);
1639
- } else {
1640
- m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
1641
- }
1642
-
1643
- device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
1644
- device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
1645
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
1646
- const float src_v = *(device float *)(src_row + i00*nb00);
1647
- device float * dst_v = (device float *)(dst_row + i00*nb0);
1648
- *dst_v = i00 * m_k + src_v;
1649
- }
1650
- }
1651
-
1652
1601
  static float rope_yarn_ramp(const float low, const float high, const int i0) {
1653
1602
  const float y = (i0 / 2 - low) / max(0.001f, high - low);
1654
1603
  return 1.0f - min(1.0f, max(0.0f, y));
@@ -1903,7 +1852,10 @@ kernel void kernel_upscale_f32(
1903
1852
  constant uint64_t & nb1,
1904
1853
  constant uint64_t & nb2,
1905
1854
  constant uint64_t & nb3,
1906
- constant int32_t & sf,
1855
+ constant float & sf0,
1856
+ constant float & sf1,
1857
+ constant float & sf2,
1858
+ constant float & sf3,
1907
1859
  uint3 tgpig[[threadgroup_position_in_grid]],
1908
1860
  uint3 tpitg[[thread_position_in_threadgroup]],
1909
1861
  uint3 ntg[[threads_per_threadgroup]]) {
@@ -1912,15 +1864,17 @@ kernel void kernel_upscale_f32(
1912
1864
  const int64_t i2 = tgpig.y;
1913
1865
  const int64_t i1 = tgpig.x;
1914
1866
 
1915
- const int64_t i03 = i3;
1916
- const int64_t i02 = i2;
1917
- const int64_t i01 = i1/sf;
1918
-
1919
- device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
1920
- device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
1867
+ const int64_t i03 = i3/sf3;
1868
+ const int64_t i02 = i2/sf2;
1869
+ const int64_t i01 = i1/sf1;
1921
1870
 
1922
1871
  for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1923
- dst_ptr[i0] = src0_ptr[i0/sf];
1872
+ const int64_t i00 = i0/sf0;
1873
+
1874
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1875
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1876
+
1877
+ dst_ptr[0] = src0_ptr[0];
1924
1878
  }
1925
1879
  }
1926
1880
 
@@ -2100,29 +2054,29 @@ typedef void (flash_attn_ext_f16_t)(
2100
2054
  device const char * v,
2101
2055
  device const char * mask,
2102
2056
  device float * dst,
2103
- constant int64_t & ne00,
2104
2057
  constant int64_t & ne01,
2105
2058
  constant int64_t & ne02,
2106
2059
  constant int64_t & ne03,
2107
- constant uint64_t & nb00,
2108
2060
  constant uint64_t & nb01,
2109
2061
  constant uint64_t & nb02,
2110
2062
  constant uint64_t & nb03,
2111
- constant int64_t & ne10,
2112
2063
  constant int64_t & ne11,
2113
2064
  constant int64_t & ne12,
2114
2065
  constant int64_t & ne13,
2115
- constant uint64_t & nb10,
2116
2066
  constant uint64_t & nb11,
2117
2067
  constant uint64_t & nb12,
2118
2068
  constant uint64_t & nb13,
2119
- constant int64_t & ne31,
2069
+ constant uint64_t & nb21,
2070
+ constant uint64_t & nb22,
2071
+ constant uint64_t & nb23,
2120
2072
  constant uint64_t & nb31,
2121
- constant int64_t & ne0,
2122
2073
  constant int64_t & ne1,
2123
2074
  constant int64_t & ne2,
2124
- constant int64_t & ne3,
2125
2075
  constant float & scale,
2076
+ constant float & max_bias,
2077
+ constant float & m0,
2078
+ constant float & m1,
2079
+ constant uint32_t & n_head_log2,
2126
2080
  threadgroup half * shared,
2127
2081
  uint3 tgpig[[threadgroup_position_in_grid]],
2128
2082
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2138,29 +2092,29 @@ kernel void kernel_flash_attn_ext_f16(
2138
2092
  device const char * v,
2139
2093
  device const char * mask,
2140
2094
  device float * dst,
2141
- constant int64_t & ne00,
2142
2095
  constant int64_t & ne01,
2143
2096
  constant int64_t & ne02,
2144
2097
  constant int64_t & ne03,
2145
- constant uint64_t & nb00,
2146
2098
  constant uint64_t & nb01,
2147
2099
  constant uint64_t & nb02,
2148
2100
  constant uint64_t & nb03,
2149
- constant int64_t & ne10,
2150
2101
  constant int64_t & ne11,
2151
2102
  constant int64_t & ne12,
2152
2103
  constant int64_t & ne13,
2153
- constant uint64_t & nb10,
2154
2104
  constant uint64_t & nb11,
2155
2105
  constant uint64_t & nb12,
2156
2106
  constant uint64_t & nb13,
2157
- constant int64_t & ne31,
2107
+ constant uint64_t & nb21,
2108
+ constant uint64_t & nb22,
2109
+ constant uint64_t & nb23,
2158
2110
  constant uint64_t & nb31,
2159
- constant int64_t & ne0,
2160
2111
  constant int64_t & ne1,
2161
2112
  constant int64_t & ne2,
2162
- constant int64_t & ne3,
2163
2113
  constant float & scale,
2114
+ constant float & max_bias,
2115
+ constant float & m0,
2116
+ constant float & m1,
2117
+ constant uint32_t & n_head_log2,
2164
2118
  threadgroup half * shared [[threadgroup(0)]],
2165
2119
  uint3 tgpig[[threadgroup_position_in_grid]],
2166
2120
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2225,10 +2179,6 @@ kernel void kernel_flash_attn_ext_f16(
2225
2179
  const short ne22 = ne12;
2226
2180
  const short ne23 = ne13;
2227
2181
 
2228
- const uint nb21 = nb11;
2229
- const uint nb22 = nb12;
2230
- const uint nb23 = nb13;
2231
-
2232
2182
  // broadcast
2233
2183
  const short rk2 = ne02/ne12;
2234
2184
  const short rk3 = ne03/ne13;
@@ -2257,6 +2207,19 @@ kernel void kernel_flash_attn_ext_f16(
2257
2207
  // prepare diagonal scale matrix
2258
2208
  simdgroup_float8x8 mscale(scale);
2259
2209
 
2210
+ // prepare diagonal slope matrix
2211
+ simdgroup_float8x8 mslope(1.0f);
2212
+
2213
+ // ALiBi
2214
+ if (max_bias > 0.0f) {
2215
+ const uint32_t h = iq2;
2216
+
2217
+ const float base = h < n_head_log2 ? m0 : m1;
2218
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2219
+
2220
+ mslope = simdgroup_float8x8(pow(base, exph));
2221
+ }
2222
+
2260
2223
  // loop over the KV cache
2261
2224
  // each simdgroup handles blocks of Q rows and C columns
2262
2225
  for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@@ -2279,10 +2242,16 @@ kernel void kernel_flash_attn_ext_f16(
2279
2242
  simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
2280
2243
  }
2281
2244
 
2282
- // mqk = mqk*scale + mask
2283
- simdgroup_half8x8 mm;
2284
- simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2285
- simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2245
+ if (mask != q) {
2246
+ // mqk = mqk*scale + mask*slope
2247
+ simdgroup_half8x8 mm;
2248
+ simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
2249
+ simdgroup_multiply(mm, mslope, mm);
2250
+ simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
2251
+ } else {
2252
+ // mqk = mqk*scale
2253
+ simdgroup_multiply(mqk, mscale, mqk);
2254
+ }
2286
2255
 
2287
2256
  simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
2288
2257
  }
@@ -2456,29 +2425,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
2456
2425
  device const char * v,
2457
2426
  device const char * mask,
2458
2427
  device float * dst,
2459
- constant int64_t & ne00,
2460
2428
  constant int64_t & ne01,
2461
2429
  constant int64_t & ne02,
2462
2430
  constant int64_t & ne03,
2463
- constant uint64_t & nb00,
2464
2431
  constant uint64_t & nb01,
2465
2432
  constant uint64_t & nb02,
2466
2433
  constant uint64_t & nb03,
2467
- constant int64_t & ne10,
2468
2434
  constant int64_t & ne11,
2469
2435
  constant int64_t & ne12,
2470
2436
  constant int64_t & ne13,
2471
- constant uint64_t & nb10,
2472
2437
  constant uint64_t & nb11,
2473
2438
  constant uint64_t & nb12,
2474
2439
  constant uint64_t & nb13,
2475
- constant int64_t & ne31,
2440
+ constant uint64_t & nb21,
2441
+ constant uint64_t & nb22,
2442
+ constant uint64_t & nb23,
2476
2443
  constant uint64_t & nb31,
2477
- constant int64_t & ne0,
2478
2444
  constant int64_t & ne1,
2479
2445
  constant int64_t & ne2,
2480
- constant int64_t & ne3,
2481
2446
  constant float & scale,
2447
+ constant float & max_bias,
2448
+ constant float & m0,
2449
+ constant float & m1,
2450
+ constant uint32_t & n_head_log2,
2482
2451
  threadgroup half * shared [[threadgroup(0)]],
2483
2452
  uint3 tgpig[[threadgroup_position_in_grid]],
2484
2453
  uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2497,6 +2466,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
2497
2466
 
2498
2467
  const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2499
2468
 
2469
+ float slope = 1.0f;
2470
+
2471
+ // ALiBi
2472
+ if (max_bias > 0.0f) {
2473
+ const uint32_t h = iq2;
2474
+
2475
+ const float base = h < n_head_log2 ? m0 : m1;
2476
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
2477
+
2478
+ slope = pow(base, exp);
2479
+ }
2480
+
2500
2481
  //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2501
2482
  threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2502
2483
  threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@@ -2537,10 +2518,6 @@ kernel void kernel_flash_attn_ext_vec_f16(
2537
2518
  const short ne22 = ne12;
2538
2519
  const short ne23 = ne13;
2539
2520
 
2540
- const uint nb21 = nb11;
2541
- const uint nb22 = nb12;
2542
- const uint nb23 = nb13;
2543
-
2544
2521
  // broadcast
2545
2522
  const short rk2 = ne02/ne12;
2546
2523
  const short rk3 = ne03/ne13;
@@ -2603,10 +2580,9 @@ kernel void kernel_flash_attn_ext_vec_f16(
2603
2580
  mqk += simd_shuffle_down(mqk, 2);
2604
2581
  mqk += simd_shuffle_down(mqk, 1);
2605
2582
 
2606
- // mqk = mqk*scale + mask
2583
+ // mqk = mqk*scale + mask*slope
2607
2584
  if (tiisg == 0) {
2608
- float4 mm = (float4) mp4[ic/4 + cc];
2609
- mqk = mqk*scale + mm;
2585
+ mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
2610
2586
 
2611
2587
  ss4[cc] = mqk;
2612
2588
  }
@@ -2840,7 +2816,8 @@ kernel void kernel_cpy_f32_f16(
2840
2816
  for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
2841
2817
  device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2842
2818
 
2843
- dst_data[i00] = src[0];
2819
+ // TODO: is there a better way to handle -INFINITY?
2820
+ dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
2844
2821
  }
2845
2822
  }
2846
2823