llama_cpp 0.3.8 → 0.5.0

@@ -18,10 +18,16 @@ typedef struct {
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 
+#define QK8_0 32
+typedef struct {
+    half    d;         // delta
+    int8_t  qs[QK8_0]; // quants
+} block_q8_0;
+
 kernel void kernel_add(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig];
 }
@@ -29,18 +35,18 @@ kernel void kernel_add(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_mul(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
@@ -48,12 +54,12 @@ kernel void kernel_mul(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(
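
Note: with the float4 rewrite above, the row kernels now index in float4 units, which is presumably why the broadcast length parameter is renamed from ne00 to nb (the row length in float4 elements, i.e. ne00/4 on the host side). A minimal CPU-side C++ sketch of the same tpig % nb broadcast, with scalars standing in for float4 (a hypothetical reference, not code from the gem):

#include <cstdio>

// Mirror of kernel_add_row's indexing: src1 is a single row of nb
// elements, reused (broadcast) for every row of src0.
void add_row_ref(const float* src0, const float* src1, float* dst,
                 long nb, long n) {
    for (long tpig = 0; tpig < n; ++tpig) {
        dst[tpig] = src0[tpig] + src1[tpig % nb];
    }
}

int main() {
    const float src0[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    const float src1[4] = {10, 20, 30, 40}; // one row, nb = 4
    float dst[8];
    add_row_ref(src0, src1, dst, 4, 8);
    for (float v : dst) std::printf("%g ", v); // 11 22 33 44 15 26 37 48
    std::printf("\n");
    return 0;
}
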
@@ -87,7 +93,12 @@ kernel void kernel_gelu(
     device float * dst,
     uint tpig[[thread_position_in_grid]]) {
     float x = src0[tpig];
-    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+    // This was observed with the Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 kernel void kernel_soft_max(
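
Note: the formula guarded here is the standard tanh approximation of GELU; the in-kernel comment attributes the NaNs to Metal's default fast-math tanh, hence the switch to precise::tanh. For reference, a CPU-side C++ sketch of the same formula (the two constants are defined elsewhere in the shader and are reproduced here from ggml's usual values, so treat them as assumptions):

#include <cmath>
#include <cstdio>

// GELU, tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x*x)))
// Constants assumed to match the shader's SQRT_2_OVER_PI / GELU_COEF_A.
static const float kSqrt2OverPi = 0.7978845608028654f;
static const float kGeluCoefA   = 0.044715f;

float gelu_ref(float x) {
    return 0.5f*x*(1.0f + std::tanh(kSqrt2OverPi*x*(1.0f + kGeluCoefA*x*x)));
}

int main() {
    const float xs[] = {-3.0f, -1.0f, 0.0f, 1.0f, 3.0f};
    for (float x : xs) {
        std::printf("gelu(%+.1f) = %+.6f\n", x, gelu_ref(x));
    }
    return 0;
}
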
@@ -352,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     const int first_row = (r0 * nsg + sgitg) * nr;
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
     float yl[16];       // src1 vector cache
     float sumf[nr]={0.f};
 
@@ -424,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mat_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * yb = y + ix * QK8_0 + 16*il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < 16; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += QK8_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32(
         device const char * src0,
         device const char * src1,
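
Note: a q8_0 block stores 32 signed 8-bit quants plus one half-precision scale d, and a dequantized value is simply qs[i]*d. The kernel above splits each block between two SIMD-group lanes (16 quants each) and reduces with simd_sum; the math per output element is just a blockwise scaled dot product. A serial CPU-side C++ sketch of that math (a hypothetical reference, ignoring the SIMD work split):

#include <cstdint>
#include <cstdio>

#define QK8_0 32

// CPU stand-in for block_q8_0 (the Metal struct uses half for d).
struct block_q8_0_ref {
    float  d;          // delta
    int8_t qs[QK8_0];  // quants
};

// dot(q8_0 row, float vector) = sum over blocks of d * sum(qs[i] * y[i])
float dot_q8_0_f32(const block_q8_0_ref* x, const float* y, int nb) {
    float sumf = 0.f;
    for (int ib = 0; ib < nb; ++ib) {
        float sumq = 0.f;
        for (int i = 0; i < QK8_0; ++i) {
            sumq += x[ib].qs[i] * y[ib*QK8_0 + i];
        }
        sumf += sumq * x[ib].d;
    }
    return sumf;
}

int main() {
    block_q8_0_ref b = {0.5f, {}};
    float y[QK8_0];
    for (int i = 0; i < QK8_0; ++i) { b.qs[i] = int8_t(i - 16); y[i] = 1.0f; }
    std::printf("%g\n", dot_q8_0_f32(&b, y, 1)); // 0.5 * sum(i-16) = -8
    return 0;
}
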
@@ -455,26 +528,43 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    sum[tpitg.x] = 0.0f;
+    uint ith = tpitg.x;
+    uint nth = tptg.x;
 
-    for (int i = tpitg.x; i < ne00; i += tptg.x) {
-        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    sum[ith] = 0.0f;
+
+    for (int i = ith; i < ne00; i += nth) {
+        sum[ith] += (float) x[i] * (float) y[i];
     }
 
     // accumulate the sum from all threads in the threadgroup
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = tptg.x/2; i > 0; i /= 2) {
-        if (tpitg.x < i) {
-            sum[tpitg.x] += sum[tpitg.x + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
     }
-
-    if (tpitg.x == 0) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
     }
-}
 
+    // Original implementation. Left behind commented out for now
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = tptg.x/2; i > 0; i /= 2) {
+    //    if (tpitg.x < i) {
+    //        sum[tpitg.x] += sum[tpitg.x + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+    //
+    //if (tpitg.x == 0) {
+    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    //}
+}
 
 kernel void kernel_alibi_f32(
         device const float * src0,
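
Note: the rewritten threadgroup accumulation above trades the generic log2 tree for three fixed passes: leaders of each group of 4 fold their 3 neighbours, leaders of each group of 16 fold the group-of-4 partials, and thread 0 folds the remaining partials at stride 16, so only two barriers separate the passes. A serial C++ model of the same pass structure (assuming, as the stride pattern implies, a thread count that is a multiple of 16):

#include <cstdio>

// Serial model of the 3-pass reduction over per-thread partial sums.
float reduce3(float* sum, int nth) {
    for (int ith = 0; ith < nth; ith += 4)   // pass 1: ith%4 == 0
        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
    for (int ith = 0; ith < nth; ith += 16)  // pass 2: ith%16 == 0
        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
    for (int i = 16; i < nth; i += 16)       // pass 3: ith == 0
        sum[0] += sum[i];
    return sum[0];
}

int main() {
    float sum[64];
    for (int i = 0; i < 64; ++i) sum[i] = 1.0f;
    std::printf("%g\n", reduce3(sum, 64)); // 64
    return 0;
}
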
@@ -571,7 +661,25 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
-        // TODO: implement
+        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                const float cos_theta = cos(theta);
+                const float sin_theta = sin(theta);
+
+                theta *= theta_scale;
+
+                const int64_t i0 = ib*n_dims + ic/2;
+
+                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                const float x0 = src[0];
+                const float x1 = src[n_dims/2];
+
+                dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            }
+        }
     }
 }
 
@@ -1598,12 +1706,12 @@ template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? (-8.h * 16.h) : -8.h;
+    const half m = il ? ( -8.h * 16.h) : -8.h;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)) + m) * d;
+        reg[i/2][2*(i%2)]   = (((qs[i] & mask0) ) + m) * d;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
     }
 }
@@ -1617,11 +1725,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)) * d) + m;
+        reg[i/2][2*(i%2)]   = (((qs[i] & mask0) ) * d) + m;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
     }
 }
 
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
+
+    for (int i=0;i<16;i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const half d = xb->d;
@@ -1850,6 +1968,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
         //load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
@@ -1895,6 +2014,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
+        threadgroup_barrier(mem_flags::mem_threadgroup);
         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
@@ -1922,9 +2042,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1935,9 +2056,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
                         constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
                         constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
 
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;