llama_cpp 0.8.0 → 0.9.1

@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
  }
 
  kernel void kernel_scale(
+ device const float * src0,
+ device float * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+ }
+
+ kernel void kernel_scale_4(
  device const float4 * src0,
  device float4 * dst,
- constant float & scale,
+ constant float & scale,
  uint tpig[[thread_position_in_grid]]) {
  dst[tpig] = src0[tpig] * scale;
  }
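
This hunk splits the scale kernel in two: the existing float4 body becomes kernel_scale_4 and a scalar kernel_scale is added (presumably as the fallback for element counts that are not a multiple of 4; the dispatch side is not part of this diff). As a host-side orientation aid, a minimal C++ sketch of the op both kernels compute:

// Minimal CPU reference for the scale op (a sketch, not code from the gem).
#include <cstddef>

static void scale_ref(const float * src0, float * dst, float scale, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = src0[i] * scale;   // same body as dst[tpig] = src0[tpig] * scale
    }
}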
@@ -176,36 +184,73 @@ kernel void kernel_soft_max(
  constant int64_t & ne00,
  constant int64_t & ne01,
  constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
  device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
  // parallel max
- float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
  lmax = MAX(lmax, psrc0[i00]);
  }
- const float max = simd_max(lmax);
+
+ float max = simd_max(lmax);
+ if (tiisg == 0) {
+ buf[sgitg] = max;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max = buf[0];
 
  // parallel sum
  float lsum = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
  const float exp_psrc0 = exp(psrc0[i00] - max);
  lsum += exp_psrc0;
  // Remember the result of exp here. exp is expensive, so we really do not
- // whish to compute it twice.
+ // wish to compute it twice.
  pdst[i00] = exp_psrc0;
  }
 
- const float sum = simd_sum(lsum);
+ float sum = simd_sum(lsum);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
 
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] += buf[tpitg + i];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[0];
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
  pdst[i00] /= sum;
  }
  }
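
The kernel_soft_max rewrite replaces the per-row uint3 threadgroup indexing with a flat threadgroup index decoded into i03/i02/i01, and extends the reduction: simd_max/simd_sum still reduce within a simdgroup, and the new threadgroup buffer buf collects one partial per simdgroup (buf[sgitg]) which is then reduced across simdgroups. The row-wise math is the usual numerically stable softmax; a serial C++ reference for one row (names are ours, not from the gem) may help keep the three passes apart:

// Serial reference of the per-row softmax the kernel parallelises (a sketch).
#include <cmath>
#include <cstdint>

static void soft_max_row_ref(const float * psrc0, float * pdst, int64_t ne00) {
    float max = -INFINITY;
    for (int64_t i00 = 0; i00 < ne00; ++i00) {        // "parallel max" pass
        max = std::fmax(max, psrc0[i00]);
    }
    float sum = 0.0f;
    for (int64_t i00 = 0; i00 < ne00; ++i00) {        // "parallel sum" pass
        const float exp_psrc0 = std::exp(psrc0[i00] - max);
        pdst[i00] = exp_psrc0;                        // exp stored once, as in the kernel
        sum += exp_psrc0;
    }
    for (int64_t i00 = 0; i00 < ne00; ++i00) {        // normalise
        pdst[i00] /= sum;
    }
}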
@@ -216,37 +261,73 @@ kernel void kernel_soft_max_4(
  constant int64_t & ne00,
  constant int64_t & ne01,
  constant int64_t & ne02,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
  device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
  // parallel max
- float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+ for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
  lmax4 = fmax(lmax4, psrc4[i00]);
  }
- float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
- const float max = simd_max(lmax);
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+ float max = simd_max(lmax);
+ if (tiisg == 0) {
+ buf[sgitg] = max;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max = buf[0];
 
  // parallel sum
  float4 lsum4 = 0.0f;
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
  const float4 exp_psrc4 = exp(psrc4[i00] - max);
  lsum4 += exp_psrc4;
  pdst4[i00] = exp_psrc4;
  }
- float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
- const float sum = simd_sum(lsum);
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+ float sum = simd_sum(lsum);
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // broadcast, simd group number is ntg / 32
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ buf[tpitg] += buf[tpitg + i];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[0];
 
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
  pdst4[i00] /= sum;
  }
  }
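
kernel_soft_max_4 receives the same treatment over float4 lanes. The new cross-simdgroup step in both kernels is a small tree reduction over buf, with ntg / 32 partials (one per 32-wide simdgroup) and, as the halving loop implies, a power-of-two simdgroup count. A serial C++ sketch of that reduction, shown for the max case:

// Serial sketch of the cross-simdgroup reduction over the threadgroup buffer
// (n = ntg / 32 partials; the sum variant uses += instead of max).
#include <algorithm>

static float reduce_max_ref(float * buf, unsigned n) {
    for (unsigned i = n / 2; i > 0; i /= 2) {
        for (unsigned t = 0; t < i; ++t) {     // on the GPU, thread t handles one slot per step
            buf[t] = std::max(buf[t], buf[t + i]);
        }
    }
    return buf[0];                             // the kernel reads max = buf[0] afterwards
}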
@@ -266,7 +347,7 @@ kernel void kernel_diag_mask_inf(
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
  } else {
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
- }
+ }
  }
 
  kernel void kernel_diag_mask_inf_8(
@@ -980,6 +1061,45 @@ kernel void kernel_alibi_f32(
  }
  }
 
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+ }
+
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+ static void rope_yarn(
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+ thread float * cos_theta, thread float * sin_theta
+ ) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+ }
+ *cos_theta = cos(theta) * mscale;
+ *sin_theta = sin(theta) * mscale;
+ }
+
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+ static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+ return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+ }
+
+ static void rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+ ) {
+ // start and end correction dims
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+ }
+
  typedef void (rope_t)(
  device const void * src0,
  device const int32_t * src1,
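
The new helpers implement YaRN: rope_yarn blends interpolated and extrapolated rotation angles according to a per-dimension ramp, and rope_yarn_corr_dims turns the beta_fast/beta_slow rotation counts into the start and end dimensions of that ramp. To get a feel for where the correction window lands, here is a small host-side C++ sketch around rope_yarn_corr_factor, using illustrative LLaMA-style values (n_dims = 128, n_orig_ctx = 4096, freq_base = 10000, beta_fast = 32, beta_slow = 1 are assumptions, not values read from the gem):

// Host-side sketch of the correction-dims computation (illustrative values only).
#include <algorithm>
#include <cmath>
#include <cstdio>

static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
    const float pi = 3.14159265f;   // stands in for Metal's M_PI_F
    return n_dims * std::log(n_orig_ctx / (n_rot * 2 * pi)) / (2 * std::log(base));
}

int main() {
    const int   n_dims = 128, n_orig_ctx = 4096;
    const float freq_base = 10000.0f, beta_fast = 32.0f, beta_slow = 1.0f;

    // same clamping as rope_yarn_corr_dims in the diff
    const float dim0 = std::max(0.0f, std::floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
    const float dim1 = std::min(n_dims - 1.0f, std::ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));

    // dims below dim0 are extrapolated, dims above dim1 interpolated,
    // and rope_yarn_ramp blends the two in between
    std::printf("correction dims: [%.0f, %.0f]\n", dim0, dim1);
    return 0;
}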
@@ -1003,8 +1123,13 @@ typedef void (rope_t)(
  constant int & n_past,
  constant int & n_dims,
  constant int & mode,
+ constant int & n_orig_ctx,
  constant float & freq_base,
  constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
  uint tiitg[[thread_index_in_threadgroup]],
  uint3 tptg[[threads_per_threadgroup]],
  uint3 tgpig[[threadgroup_position_in_grid]]);
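
For reference, the five parameters added to the rope_t signature here (and mirrored on kernel_rope below) are the knobs YaRN needs; a hypothetical host-side grouping with the roles summarised in comments (the struct itself is not part of the gem):

// Hypothetical grouping of the new rope arguments (not a struct from the gem).
#include <cstdint>

struct rope_yarn_params {
    int32_t n_orig_ctx;   // original training context length the correction dims are derived from
    float   ext_factor;   // 0.0f disables the YaRN ramp; otherwise blends extrapolated angles back in
    float   attn_factor;  // base magnitude scale applied to cos/sin (the kernel's mscale)
    float   beta_fast;    // rotation-count threshold giving the start correction dim
    float   beta_slow;    // rotation-count threshold giving the end correction dim
};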
@@ -1033,8 +1158,13 @@ kernel void kernel_rope(
  constant int & n_past,
  constant int & n_dims,
  constant int & mode,
+ constant int & n_orig_ctx,
  constant float & freq_base,
  constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
  uint tiitg[[thread_index_in_threadgroup]],
  uint3 tptg[[threads_per_threadgroup]],
  uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1044,19 +1174,22 @@ kernel void kernel_rope(
 
  const bool is_neox = mode & 2;
 
+ float corr_dims[2];
+ rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
  device const int32_t * pos = src1;
 
  const int64_t p = pos[i2];
 
- const float theta_0 = freq_scale * (float)p;
+ const float theta_0 = (float)p;
  const float inv_ndims = -1.f/n_dims;
 
  if (!is_neox) {
  for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
 
  const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
- const float cos_theta = cos(theta);
- const float sin_theta = sin(theta);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
  device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -1071,9 +1204,12 @@ kernel void kernel_rope(
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
  for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
 
- const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
- const float cos_theta = cos(theta);
- const float sin_theta = sin(theta);
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
+ const float cur_rot = inv_ndims*ic - ib;
+
+ const float theta = theta_0 * pow(freq_base, cur_rot);
+ float cos_theta, sin_theta;
+ rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
  const int64_t i0 = ib*n_dims + ic/2;
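
In the neox branch the exponent is now computed once as cur_rot and also passed to rope_yarn as the ramp index, while freq_scale moves out of theta_0 and into rope_yarn's theta_interp. The inline comment's simplification is easy to verify: with inv_ndims = -1/n_dims, (ib * n_dims + ic) * inv_ndims equals inv_ndims*ic - ib. A tiny C++ check (n_dims = 128 is an arbitrary illustrative value):

// Quick check of the `cur_rot` simplification noted in the comment above.
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
    const int   n_dims    = 128;                 // illustrative value
    const float inv_ndims = -1.0f / n_dims;
    for (int64_t ib = 0; ib < 4; ++ib) {
        for (int64_t ic = 0; ic < n_dims; ic += 2) {
            const float a = (ib * n_dims + ic) * inv_ndims;
            const float b = inv_ndims * ic - ib;
            assert(std::fabs(a - b) < 1e-6f);    // identical up to rounding
        }
    }
    return 0;
}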