llama_cpp 0.12.2 → 0.12.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -1775,9 +1775,29 @@ kernel void kernel_rope(
1775
1775
  template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
1776
1776
  template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
1777
1777
 
1778
- kernel void kernel_im2col_f16(
1778
+ typedef void (im2col_t)(
1779
1779
  device const float * x,
1780
- device half * dst,
1780
+ device char * dst,
1781
+ constant int32_t & ofs0,
1782
+ constant int32_t & ofs1,
1783
+ constant int32_t & IW,
1784
+ constant int32_t & IH,
1785
+ constant int32_t & CHW,
1786
+ constant int32_t & s0,
1787
+ constant int32_t & s1,
1788
+ constant int32_t & p0,
1789
+ constant int32_t & p1,
1790
+ constant int32_t & d0,
1791
+ constant int32_t & d1,
1792
+ uint3 tgpig[[threadgroup_position_in_grid]],
1793
+ uint3 tgpg[[threadgroups_per_grid]],
1794
+ uint3 tpitg[[thread_position_in_threadgroup]],
1795
+ uint3 ntg[[threads_per_threadgroup]]);
1796
+
1797
+ template <typename T>
1798
+ kernel void kernel_im2col(
1799
+ device const float * x,
1800
+ device char * dst,
1781
1801
  constant int32_t & ofs0,
1782
1802
  constant int32_t & ofs1,
1783
1803
  constant int32_t & IW,
@@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16(
1800
1820
  (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
1801
1821
  (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
1802
1822
 
1823
+ device T * pdst = (device T *) (dst);
1824
+
1803
1825
  if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
1804
- dst[offset_dst] = 0.0f;
1826
+ pdst[offset_dst] = 0.0f;
1805
1827
  } else {
1806
1828
  const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
1807
- dst[offset_dst] = x[offset_src + iih * IW + iiw];
1829
+ pdst[offset_dst] = x[offset_src + iih * IW + iiw];
1808
1830
  }
1809
1831
  }
1810
1832
 
1833
+ template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
1834
+ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
1835
+
1811
1836
  kernel void kernel_upscale_f32(
1812
1837
  device const char * src0,
1813
1838
  device char * dst,
@@ -2459,6 +2484,12 @@ typedef struct {
2459
2484
  } block_iq2_xs;
2460
2485
  // 74 bytes / block for QK_K = 256, so 2.3125 bpw
2461
2486
 
2487
+ typedef struct {
2488
+ half d;
2489
+ uint8_t qs[3*QK_K/8];
2490
+ } block_iq3_xxs;
2491
+ // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2492
+
2462
2493
  //====================================== dot products =========================
2463
2494
 
2464
2495
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3681,6 +3712,42 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
3681
3712
  0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3682
3713
  };
3683
3714
 
3715
+ constexpr constant static uint32_t iq3xxs_grid[256] = {
3716
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
3717
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
3718
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
3719
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
3720
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
3721
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
3722
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
3723
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
3724
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
3725
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
3726
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
3727
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
3728
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
3729
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
3730
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
3731
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
3732
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
3733
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
3734
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
3735
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
3736
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
3737
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
3738
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
3739
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
3740
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
3741
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
3742
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
3743
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
3744
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
3745
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
3746
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
3747
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3748
+ };
3749
+
3750
+
3684
3751
  constexpr constant static uint8_t ksigns_iq2xs[128] = {
3685
3752
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
3686
3753
  144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -3970,6 +4037,143 @@ kernel void kernel_mul_mv_iq2_xs_f32(
3970
4037
  kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
3971
4038
  }
3972
4039
 
4040
+ void kernel_mul_mv_iq3_xxs_f32_impl(
4041
+ device const void * src0,
4042
+ device const float * src1,
4043
+ device float * dst,
4044
+ constant int64_t & ne00,
4045
+ constant int64_t & ne01,
4046
+ constant int64_t & ne02,
4047
+ constant int64_t & ne10,
4048
+ constant int64_t & ne12,
4049
+ constant int64_t & ne0,
4050
+ constant int64_t & ne1,
4051
+ constant uint & r2,
4052
+ constant uint & r3,
4053
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4054
+ uint3 tgpig[[threadgroup_position_in_grid]],
4055
+ uint tiisg[[thread_index_in_simdgroup]],
4056
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4057
+
4058
+ const int nb = ne00/QK_K;
4059
+ const int r0 = tgpig.x;
4060
+ const int r1 = tgpig.y;
4061
+ const int im = tgpig.z;
4062
+
4063
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4064
+ const int ib_row = first_row * nb;
4065
+
4066
+ const uint i12 = im%ne12;
4067
+ const uint i13 = im/ne12;
4068
+
4069
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4070
+
4071
+ device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
4072
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4073
+
4074
+ float yl[32];
4075
+ float sumf[N_DST]={0.f}, all_sum;
4076
+
4077
+ const int nb32 = nb * (QK_K / 32);
4078
+
4079
+ threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
4080
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
4081
+ {
4082
+ int nval = 4;
4083
+ int pos = (32*sgitg + tiisg)*nval;
4084
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
4085
+ nval = 2;
4086
+ pos = (32*sgitg + tiisg)*nval;
4087
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
4088
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4089
+ }
4090
+
4091
+ #if QK_K == 256
4092
+ const int ix = tiisg;
4093
+
4094
+ device const float * y4 = y + 32 * ix;
4095
+
4096
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4097
+
4098
+ for (int i = 0; i < 32; ++i) {
4099
+ yl[i] = y4[i];
4100
+ }
4101
+
4102
+ const int ibl = ib32 / (QK_K / 32);
4103
+ const int ib = ib32 % (QK_K / 32);
4104
+
4105
+ device const block_iq3_xxs * xr = x + ibl;
4106
+ device const uint8_t * q3 = xr->qs + 8 * ib;
4107
+ device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
4108
+ device const half * dh = &xr->d;
4109
+
4110
+ for (int row = 0; row < N_DST; row++) {
4111
+
4112
+ const float db = dh[0];
4113
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
4114
+ const float d = db * (0.5f + (aux32 >> 28));
4115
+
4116
+ float2 sum = {0};
4117
+ for (int l = 0; l < 4; ++l) {
4118
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
4119
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
4120
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
4121
+ for (int j = 0; j < 4; ++j) {
4122
+ sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
4123
+ sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
4124
+ }
4125
+ }
4126
+ sumf[row] += d * (sum[0] + sum[1]);
4127
+
4128
+ dh += nb*sizeof(block_iq3_xxs)/2;
4129
+ q3 += nb*sizeof(block_iq3_xxs);
4130
+ gas += nb*sizeof(block_iq3_xxs)/2;
4131
+ }
4132
+
4133
+ y4 += 32 * 32;
4134
+ }
4135
+ #else
4136
+ // TODO
4137
+ #endif
4138
+
4139
+ for (int row = 0; row < N_DST; ++row) {
4140
+ all_sum = simd_sum(sumf[row]);
4141
+ if (tiisg == 0) {
4142
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
4143
+ }
4144
+ }
4145
+ }
4146
+
4147
+ [[host_name("kernel_mul_mv_iq3_xxs_f32")]]
4148
+ kernel void kernel_mul_mv_iq3_xxs_f32(
4149
+ device const void * src0,
4150
+ device const float * src1,
4151
+ device float * dst,
4152
+ constant int64_t & ne00,
4153
+ constant int64_t & ne01,
4154
+ constant int64_t & ne02,
4155
+ constant uint64_t & nb00,
4156
+ constant uint64_t & nb01,
4157
+ constant uint64_t & nb02,
4158
+ constant int64_t & ne10,
4159
+ constant int64_t & ne11,
4160
+ constant int64_t & ne12,
4161
+ constant uint64_t & nb10,
4162
+ constant uint64_t & nb11,
4163
+ constant uint64_t & nb12,
4164
+ constant int64_t & ne0,
4165
+ constant int64_t & ne1,
4166
+ constant uint & r2,
4167
+ constant uint & r3,
4168
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4169
+ uint3 tgpig[[threadgroup_position_in_grid]],
4170
+ uint tiisg[[thread_index_in_simdgroup]],
4171
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4172
+
4173
+ kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4174
+ }
4175
+
4176
+
3973
4177
  //============================= templates and their specializations =============================
3974
4178
 
3975
4179
  // NOTE: this is not dequantizing - we are simply fitting the template
@@ -4287,6 +4491,33 @@ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4
4287
4491
  }
4288
4492
  }
4289
4493
 
4494
+ template <typename type4x4>
4495
+ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
4496
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4497
+ const float d = xb->d;
4498
+ const int ib32 = il/2;
4499
+ il = il%2;
4500
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4501
+ device const uint8_t * q3 = xb->qs + 8*ib32;
4502
+ device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
4503
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
4504
+ const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
4505
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
4506
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
4507
+ uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
4508
+ for (int i = 0; i < 4; ++i) {
4509
+ reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
4510
+ reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
4511
+ }
4512
+ grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
4513
+ grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
4514
+ signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
4515
+ for (int i = 0; i < 4; ++i) {
4516
+ reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
4517
+ reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
4518
+ }
4519
+ }
4520
+
4290
4521
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4291
4522
  kernel void kernel_get_rows(
4292
4523
  device const void * src0,
@@ -4828,6 +5059,7 @@ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows
4828
5059
  template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
4829
5060
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4830
5061
  template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5062
+ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
4831
5063
 
4832
5064
  //
4833
5065
  // matrix-matrix multiplication
@@ -4866,6 +5098,7 @@ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
4866
5098
  template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
4867
5099
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4868
5100
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5101
+ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
4869
5102
 
4870
5103
  //
4871
5104
  // indirect matrix-matrix multiplication
@@ -4916,6 +5149,7 @@ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mu
4916
5149
  template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
4917
5150
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
4918
5151
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5152
+ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
4919
5153
 
4920
5154
  //
4921
5155
  // matrix-vector multiplication
@@ -5818,3 +6052,68 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
5818
6052
  tiisg,
5819
6053
  sgitg);
5820
6054
  }
6055
+
6056
+ [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6057
+ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6058
+ device const char * ids,
6059
+ device const char * src1,
6060
+ device float * dst,
6061
+ constant uint64_t & nbi1,
6062
+ constant int64_t & ne00,
6063
+ constant int64_t & ne01,
6064
+ constant int64_t & ne02,
6065
+ constant uint64_t & nb00,
6066
+ constant uint64_t & nb01,
6067
+ constant uint64_t & nb02,
6068
+ constant int64_t & ne10,
6069
+ constant int64_t & ne11,
6070
+ constant int64_t & ne12,
6071
+ constant int64_t & ne13,
6072
+ constant uint64_t & nb10,
6073
+ constant uint64_t & nb11,
6074
+ constant uint64_t & nb12,
6075
+ constant int64_t & ne0,
6076
+ constant int64_t & ne1,
6077
+ constant uint64_t & nb1,
6078
+ constant uint & r2,
6079
+ constant uint & r3,
6080
+ constant int & idx,
6081
+ device const char * src00,
6082
+ device const char * src01,
6083
+ device const char * src02,
6084
+ device const char * src03,
6085
+ device const char * src04,
6086
+ device const char * src05,
6087
+ device const char * src06,
6088
+ device const char * src07,
6089
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
6090
+ uint3 tgpig[[threadgroup_position_in_grid]],
6091
+ uint tiitg[[thread_index_in_threadgroup]],
6092
+ uint tiisg[[thread_index_in_simdgroup]],
6093
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6094
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6095
+
6096
+ const int64_t bid = tgpig.z/(ne12*ne13);
6097
+
6098
+ tgpig.z = tgpig.z%(ne12*ne13);
6099
+
6100
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6101
+
6102
+ kernel_mul_mv_iq3_xxs_f32_impl(
6103
+ src0[id],
6104
+ (device const float *) (src1 + bid*nb11),
6105
+ dst + bid*ne0,
6106
+ ne00,
6107
+ ne01,
6108
+ ne02,
6109
+ ne10,
6110
+ ne12,
6111
+ ne0,
6112
+ ne1,
6113
+ r2,
6114
+ r3,
6115
+ shared_values,
6116
+ tgpig,
6117
+ tiisg,
6118
+ sgitg);
6119
+ }
@@ -714,7 +714,6 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
714
714
  dst[row] = tmp[0];
715
715
  }
716
716
  }
717
-
718
717
  );
719
718
 
720
719
 
@@ -784,6 +783,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
784
783
  dst[row] = tmp[0];
785
784
  }
786
785
  }
786
+
787
787
  );
788
788
 
789
789
 
@@ -799,6 +799,18 @@ __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y
799
799
  }
800
800
  );
801
801
 
802
+ std::string add_template = MULTILINE_QUOTE(
803
+ __kernel void add_f32(__global float * x, const int x_offset, __global float * y, const int y_offset, __global float * dst, const int dst_offset, const int ky) {
804
+ const int i = get_group_id(0)*get_local_size(0) + get_local_id(0);
805
+
806
+ if (i >= get_global_size(0)) {
807
+ return;
808
+ }
809
+
810
+ dst[dst_offset + i] = x[x_offset + i] + y[y_offset + i%ky];
811
+ }
812
+ );
813
+
802
814
  #define CL_CHECK(err) \
803
815
  do { \
804
816
  cl_int err_ = (err); \
@@ -878,6 +890,7 @@ static std::string generate_kernels() {
878
890
  }
879
891
  src << mul_kernel << '\n';
880
892
  }
893
+ src << add_template << '\n';
881
894
 
882
895
  return src.str();
883
896
  }
@@ -893,6 +906,7 @@ static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl,
893
906
  static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl, dequantize_block_q5_k_cl, dequantize_block_q6_k_cl;
894
907
  static cl_kernel dequantize_mul_mat_vec_q2_K_cl, dequantize_mul_mat_vec_q3_K_cl, dequantize_mul_mat_vec_q4_K_cl, dequantize_mul_mat_vec_q5_K_cl, dequantize_mul_mat_vec_q6_K_cl;
895
908
  static cl_kernel mul_f32_cl;
909
+ static cl_kernel add_f32_cl;
896
910
  static bool fp16_support;
897
911
 
898
912
  static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) {
@@ -1100,9 +1114,10 @@ void ggml_cl_init(void) {
1100
1114
  char *ext_buffer = (char *)alloca(ext_str_size + 1);
1101
1115
  clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL);
1102
1116
  ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated
1117
+ // Disabled due to faulty outputs
1103
1118
  // Check if ext_buffer contains cl_khr_fp16
1104
- fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL;
1105
- fprintf(stderr, "ggml_opencl: device FP16 support: %s\n", fp16_support ? "true" : "false");
1119
+ fp16_support = false; // strstr(ext_buffer, "cl_khr_fp16") != NULL;
1120
+ // fprintf(stderr, "ggml_opencl: device FP16 support: %s\n", fp16_support ? "true" : "false");
1106
1121
 
1107
1122
  cl_context_properties properties[] = {
1108
1123
  (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)platform, 0
@@ -1150,6 +1165,8 @@ void ggml_cl_init(void) {
1150
1165
 
1151
1166
  // mul kernel
1152
1167
  CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
1168
+
1169
+ CL_CHECK((add_f32_cl = clCreateKernel(program, "add_f32", &err), err));
1153
1170
  }
1154
1171
 
1155
1172
  static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
@@ -1458,6 +1475,70 @@ void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src
1458
1475
  ggml_cl_mul_f32(src0, src1, dst);
1459
1476
  }
1460
1477
 
1478
+ static void ggml_cl_add_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1479
+ GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
1480
+ const int64_t ne00 = src0->ne[0];
1481
+ const int64_t ne01 = src0->ne[1];
1482
+ const int64_t ne02 = src0->ne[2];
1483
+ const int64_t ne03 = src0->ne[3];
1484
+ const int64_t ne10 = src1->ne[0];
1485
+ const int64_t ne11 = src1->ne[1];
1486
+ const int64_t ne12 = src1->ne[2];
1487
+ const int64_t ne13 = src1->ne[3];
1488
+ const int nb2 = dst->nb[2];
1489
+ const int nb3 = dst->nb[3];
1490
+ size_t x_size;
1491
+ size_t d_size;
1492
+
1493
+ cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0
1494
+ cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
1495
+ cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst
1496
+
1497
+
1498
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
1499
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
1500
+ cl_event ev;
1501
+
1502
+ // copy src0 to device
1503
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev));
1504
+
1505
+ const int64_t i13 = i03%ne13;
1506
+ const int64_t i12 = i02%ne12;
1507
+ const int i1 = i13*ne12*ne11 + i12*ne11;
1508
+
1509
+ cl_int x_offset = 0;
1510
+ cl_int y_offset = i1*ne10;
1511
+ cl_int d_offset = 0;
1512
+
1513
+ size_t global = ne00 * ne01;
1514
+ cl_int ky = ne10 * ne11;
1515
+
1516
+ CL_CHECK(clSetKernelArg(add_f32_cl, 0, sizeof(cl_mem), &d_X));
1517
+ CL_CHECK(clSetKernelArg(add_f32_cl, 1, sizeof(cl_int), &x_offset));
1518
+ CL_CHECK(clSetKernelArg(add_f32_cl, 2, sizeof(cl_mem), &d_Y));
1519
+ CL_CHECK(clSetKernelArg(add_f32_cl, 3, sizeof(cl_int), &y_offset));
1520
+ CL_CHECK(clSetKernelArg(add_f32_cl, 4, sizeof(cl_mem), &d_D));
1521
+ CL_CHECK(clSetKernelArg(add_f32_cl, 5, sizeof(cl_int), &d_offset));
1522
+ CL_CHECK(clSetKernelArg(add_f32_cl, 6, sizeof(cl_int), &ky));
1523
+ CL_CHECK(clEnqueueNDRangeKernel(queue, add_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
1524
+
1525
+ CL_CHECK(clReleaseEvent(ev));
1526
+ CL_CHECK(clFinish(queue));
1527
+
1528
+ // copy dst to host
1529
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
1530
+ CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * ne00*ne01, d, 0, NULL, NULL));
1531
+ }
1532
+ }
1533
+ ggml_cl_pool_free(d_X, x_size);
1534
+ ggml_cl_pool_free(d_D, d_size);
1535
+ }
1536
+
1537
+ void ggml_cl_add(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
1538
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
1539
+ ggml_cl_add_f32(src0, src1, dst);
1540
+ }
1541
+
1461
1542
  static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1462
1543
  const int64_t ne00 = src0->ne[0];
1463
1544
  const int64_t ne01 = src0->ne[1];
@@ -2044,6 +2125,15 @@ static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_
2044
2125
  GGML_UNUSED(buffer_type);
2045
2126
  }
2046
2127
 
2128
+ static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
2129
+ static size_t max_size = -1;
2130
+ if (max_size == (size_t)-1) {
2131
+ ggml_cl_init();
2132
+ clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &max_size, NULL);
2133
+ }
2134
+ return max_size;
2135
+ }
2136
+
2047
2137
  static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buffer_type, ggml_backend_t backend) {
2048
2138
  //return ggml_backend_is_opencl(backend); // opencl must be used through the cpu backend
2049
2139
  return ggml_backend_is_cpu(backend);
@@ -2055,6 +2145,7 @@ static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {
2055
2145
  /* .get_name = */ ggml_backend_opencl_buffer_type_name,
2056
2146
  /* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer,
2057
2147
  /* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment,
2148
+ /* .get_max_size = */ ggml_backend_opencl_buffer_type_get_max_size,
2058
2149
  /* .get_alloc_size = */ NULL,
2059
2150
  /* .supports_backend = */ ggml_backend_opencl_buffer_type_supports_backend,
2060
2151
  /* .is_host = */ NULL,
@@ -2111,6 +2202,7 @@ ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type() {
2111
2202
  /* .get_name = */ ggml_backend_opencl_host_buffer_type_name,
2112
2203
  /* .alloc_buffer = */ ggml_backend_opencl_host_buffer_type_alloc_buffer,
2113
2204
  /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
2205
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
2114
2206
  /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
2115
2207
  /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
2116
2208
  /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
@@ -10,6 +10,7 @@ extern "C" {
10
10
  GGML_API void ggml_cl_init(void);
11
11
 
12
12
  GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
13
+ GGML_API void ggml_cl_add(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
13
14
  GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
14
15
  GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
15
16
  GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);