llama_cpp 0.3.8 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,10 +18,16 @@ typedef struct {
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 
+#define QK8_0 32
+typedef struct {
+    half    d;         // delta
+    int8_t  qs[QK8_0]; // quants
+} block_q8_0;
+
 kernel void kernel_add(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig];
 }
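The new `block_q8_0` stores 32 signed 8-bit quants behind a single half-precision scale, so each value is recovered as `d * qs[i]`. A minimal host-side C++ sketch of the round trip — assuming ggml's usual Q8_0 quantization rule (`d = max|x| / 127`), which this diff does not show, and substituting `float` for `half`:

    #include <cmath>
    #include <cstdint>

    #define QK8_0 32

    struct block_q8_0 {
        float  d;          // delta (scale); half on the GPU side
        int8_t qs[QK8_0];  // quants
    };

    // Quantize one block of 32 floats: d = max|x| / 127, q_i = round(x_i / d).
    block_q8_0 quantize_q8_0_ref(const float * x) {
        float amax = 0.0f;
        for (int i = 0; i < QK8_0; ++i) amax = fmaxf(amax, fabsf(x[i]));
        block_q8_0 b;
        b.d = amax / 127.0f;
        const float id = b.d != 0.0f ? 1.0f/b.d : 0.0f;
        for (int i = 0; i < QK8_0; ++i) b.qs[i] = (int8_t) roundf(x[i]*id);
        return b;
    }

    // Dequantize: x_i ≈ d * q_i — the same math the new kernels below apply.
    void dequantize_q8_0_ref(const block_q8_0 & b, float * x) {
        for (int i = 0; i < QK8_0; ++i) x[i] = b.d * (float) b.qs[i];
    }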
@@ -29,18 +35,18 @@ kernel void kernel_add(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_mul(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
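The element-wise kernels now operate on `float4`, so each thread processes four floats and the host dispatches a quarter as many threads; `nb` is the broadcast row length in `float4` units (`ne00/4`). A scalar C++ model of the new `kernel_add_row` — a sketch assuming the row length is a multiple of 4, which the host side must guarantee:

    #include <cstdint>

    struct float4 { float v[4]; };

    float4 operator+(float4 a, const float4 & b) {
        for (int i = 0; i < 4; ++i) a.v[i] += b.v[i];
        return a;
    }

    // src1 is a single row of nb float4s, broadcast across every row of src0.
    void add_row_ref(const float4 * src0, const float4 * src1, float4 * dst,
                     int64_t n /* total float4s */, int64_t nb /* row length in float4s */) {
        for (int64_t tpig = 0; tpig < n; ++tpig) {
            dst[tpig] = src0[tpig] + src1[tpig % nb];
        }
    }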
@@ -48,12 +54,12 @@ kernel void kernel_mul(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(
@@ -87,7 +93,12 @@ kernel void kernel_gelu(
         device float * dst,
         uint tpig[[thread_position_in_grid]]) {
     float x = src0[tpig];
-    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+    // This was observed with the Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 kernel void kernel_soft_max(
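For reference, the kernel evaluates the usual tanh approximation of GELU; the fix only swaps `tanh` for Metal's `precise::tanh` to avoid the fast-math NaNs noted in the comment. A host-side C++ equivalent — assuming the standard constant values, since the `#define`s for `GELU_COEF_A` and `SQRT_2_OVER_PI` sit earlier in the file and are not part of this diff:

    #include <cmath>

    // Assumed to match the file's #defines (standard tanh-GELU constants).
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535588f; // sqrt(2/pi)

    // GELU(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
    float gelu_ref(float x) {
        return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }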
@@ -352,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     const int first_row = (r0 * nsg + sgitg) * nr;
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
     float yl[16]; // src1 vector cache
     float sumf[nr]={0.f};
 
@@ -424,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mat_q8_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * yb = y + ix * QK8_0 + 16*il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < 16; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += QK8_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32(
         device const char * src0,
         device const char * src1,
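In the new kernel each simdgroup produces `N_DST` output rows, each thread covers half a 32-quant block (the half picked by `il`), and the per-thread partials are combined with `simd_sum`. Stripped of the parallel indexing, the per-row math is a block-scaled int8 dot product — a scalar C++ reference, reusing the `block_q8_0` sketch from above:

    // Dot product of one Q8_0-quantized row with a float vector.
    float dot_q8_0_ref(const block_q8_0 * row, const float * y, int ne00) {
        const int nb = ne00 / QK8_0; // blocks per row
        float sumf = 0.0f;
        for (int ib = 0; ib < nb; ++ib) {
            float sumq = 0.0f;
            for (int i = 0; i < QK8_0; ++i) {
                sumq += (float) row[ib].qs[i] * y[ib*QK8_0 + i];
            }
            sumf += sumq * row[ib].d; // apply the block scale once
        }
        return sumf;
    }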
@@ -455,26 +528,43 @@ kernel void kernel_mul_mat_f16_f32(
     device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    sum[tpitg.x] = 0.0f;
+    uint ith = tpitg.x;
+    uint nth = tptg.x;
 
-    for (int i = tpitg.x; i < ne00; i += tptg.x) {
-        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    sum[ith] = 0.0f;
+
+    for (int i = ith; i < ne00; i += nth) {
+        sum[ith] += (float) x[i] * (float) y[i];
     }
 
     // accumulate the sum from all threads in the threadgroup
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = tptg.x/2; i > 0; i /= 2) {
-        if (tpitg.x < i) {
-            sum[tpitg.x] += sum[tpitg.x + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
     }
-
-    if (tpitg.x == 0) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
     }
-}
 
+    // Original implementation. Left behind commented out for now
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = tptg.x/2; i > 0; i /= 2) {
+    //    if (tpitg.x < i) {
+    //        sum[tpitg.x] += sum[tpitg.x + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+    //
+    //if (tpitg.x == 0) {
+    //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+    //}
+}
 
 kernel void kernel_alibi_f32(
         device const float * src0,
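The old log2(nth)-step tree reduction, with a barrier per step, becomes three fixed stages separated by just two barriers: every 4th thread folds in its three neighbours, every 16th thread folds in the group-of-4 sums, and thread 0 folds in the group-of-16 sums. A sequential C++ model of the three stages — assuming `nth` is a multiple of 16, which the launch configuration must guarantee for the partial reads to stay in bounds:

    // Sequential model of the three-stage threadgroup reduction above.
    float reduce_ref(float * sum, int nth) {
        for (int ith = 0; ith < nth; ith += 4)   // stage 1: every 4th thread
            for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
        // threadgroup barrier
        for (int ith = 0; ith < nth; ith += 16)  // stage 2: every 16th thread
            for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
        // threadgroup barrier
        for (int i = 16; i < nth; i += 16)       // stage 3: thread 0 only
            sum[0] += sum[i];
        return sum[0];
    }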
@@ -571,7 +661,25 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
-        // TODO: implement
+        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                const float cos_theta = cos(theta);
+                const float sin_theta = sin(theta);
+
+                theta *= theta_scale;
+
+                const int64_t i0 = ib*n_dims + ic/2;
+
+                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                const float x0 = src[0];
+                const float x1 = src[n_dims/2];
+
+                dst_data[0] = x0*cos_theta - x1*sin_theta;
+                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            }
+        }
     }
 }
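The previously unimplemented branch now rotates pairs split by half the rotary dimension, `(i0, i0 + n_dims/2)`, instead of the adjacent pairs `(i0, i0 + 1)` used by the first branch, advancing `theta` by `theta_scale` after each pair. The rotation itself is an ordinary 2-D rotation, as a small C++ helper:

    #include <cmath>

    // x0' = x0*cos(theta) - x1*sin(theta)
    // x1' = x0*sin(theta) + x1*cos(theta)
    void rope_rotate(float theta, float & x0, float & x1) {
        const float c = cosf(theta);
        const float s = sinf(theta);
        const float r0 = x0*c - x1*s;
        const float r1 = x0*s + x1*c;
        x0 = r0;
        x1 = r1;
    }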
 
@@ -1598,12 +1706,12 @@ template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? (-8.h * 16.h) : -8.h;
+    const half m = il ? ( -8.h * 16.h) : -8.h;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
+        reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
     }
 }
@@ -1617,11 +1725,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
+        reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
     }
 }
 
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
+
+    for (int i=0;i<16;i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const half d = xb->d;
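Here `il` (0 or 1) selects which half of the 32-quant block fills the 4x4 register tile consumed by `kernel_mul_mm` — the same `d * q` dequantization as above, laid out for the matrix path. In scalar C++ terms:

    // Scalar model of dequantize_q8_0: il picks the half-block that fills
    // a 4x4 tile of dequantized values.
    void dequant_q8_0_tile(const block_q8_0 & b, int il, float reg[4][4]) {
        for (int i = 0; i < 16; ++i) {
            reg[i/4][i%4] = (float) b.qs[i + 16*il] * b.d;
        }
    }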
@@ -1850,6 +1968,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
         //load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
@@ -1895,6 +2014,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
             }
         } else {
             // block is smaller than 64x32, we should avoid writing data outside of the matrix
+            threadgroup_barrier(mem_flags::mem_threadgroup);
             threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
                                           + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
             for (int i = 0; i < 8; i++) {
@@ -1922,9 +2042,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1935,9 +2056,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
                         constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
                         constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
 
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;