llama_cpp 0.9.0 → 0.9.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
184
184
  constant int64_t & ne00,
185
185
  constant int64_t & ne01,
186
186
  constant int64_t & ne02,
187
- uint3 tgpig[[threadgroup_position_in_grid]],
188
- uint3 tpitg[[thread_position_in_threadgroup]],
189
- uint3 ntg[[threads_per_threadgroup]]) {
190
- const int64_t i03 = tgpig[2];
191
- const int64_t i02 = tgpig[1];
192
- const int64_t i01 = tgpig[0];
187
+ threadgroup float * buf [[threadgroup(0)]],
188
+ uint tgpig[[threadgroup_position_in_grid]],
189
+ uint tpitg[[thread_position_in_threadgroup]],
190
+ uint sgitg[[simdgroup_index_in_threadgroup]],
191
+ uint tiisg[[thread_index_in_simdgroup]],
192
+ uint ntg[[threads_per_threadgroup]]) {
193
+ const int64_t i03 = (tgpig) / (ne02*ne01);
194
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
193
196
 
194
197
  device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
195
198
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
196
199
 
197
200
  // parallel max
198
- float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
199
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
201
+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
202
+
203
+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
200
204
  lmax = MAX(lmax, psrc0[i00]);
201
205
  }
202
- const float max = simd_max(lmax);
206
+
207
+ float max = simd_max(lmax);
208
+ if (tiisg == 0) {
209
+ buf[sgitg] = max;
210
+ }
211
+
212
+ threadgroup_barrier(mem_flags::mem_threadgroup);
213
+
214
+ // broadcast, simd group number is ntg / 32
215
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
216
+ if (tpitg < i) {
217
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
218
+ }
219
+ }
220
+
221
+ threadgroup_barrier(mem_flags::mem_threadgroup);
222
+
223
+ max = buf[0];
203
224
 
204
225
  // parallel sum
205
226
  float lsum = 0.0f;
206
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
227
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
207
228
  const float exp_psrc0 = exp(psrc0[i00] - max);
208
229
  lsum += exp_psrc0;
209
230
  // Remember the result of exp here. exp is expensive, so we really do not
210
- // whish to compute it twice.
231
+ // wish to compute it twice.
211
232
  pdst[i00] = exp_psrc0;
212
233
  }
213
234
 
214
- const float sum = simd_sum(lsum);
235
+ float sum = simd_sum(lsum);
236
+ if (tiisg == 0) {
237
+ buf[sgitg] = sum;
238
+ }
239
+
240
+ threadgroup_barrier(mem_flags::mem_threadgroup);
215
241
 
216
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
242
+ // broadcast, simd group number is ntg / 32
243
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
244
+ if (tpitg < i) {
245
+ buf[tpitg] += buf[tpitg + i];
246
+ }
247
+ }
248
+
249
+ threadgroup_barrier(mem_flags::mem_threadgroup);
250
+
251
+ sum = buf[0];
252
+
253
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
217
254
  pdst[i00] /= sum;
218
255
  }
219
256
  }
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
224
261
  constant int64_t & ne00,
225
262
  constant int64_t & ne01,
226
263
  constant int64_t & ne02,
227
- uint3 tgpig[[threadgroup_position_in_grid]],
228
- uint3 tpitg[[thread_position_in_threadgroup]],
229
- uint3 ntg[[threads_per_threadgroup]]) {
230
- const int64_t i03 = tgpig[2];
231
- const int64_t i02 = tgpig[1];
232
- const int64_t i01 = tgpig[0];
264
+ threadgroup float * buf [[threadgroup(0)]],
265
+ uint tgpig[[threadgroup_position_in_grid]],
266
+ uint tpitg[[thread_position_in_threadgroup]],
267
+ uint sgitg[[simdgroup_index_in_threadgroup]],
268
+ uint tiisg[[thread_index_in_simdgroup]],
269
+ uint ntg[[threads_per_threadgroup]]) {
270
+ const int64_t i03 = (tgpig) / (ne02*ne01);
271
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
233
273
 
234
274
  device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
235
275
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
236
276
 
237
277
  // parallel max
238
- float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
239
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
278
+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
279
+
280
+ for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
240
281
  lmax4 = fmax(lmax4, psrc4[i00]);
241
282
  }
242
- float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
243
283
 
244
- const float max = simd_max(lmax);
284
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
285
+ float max = simd_max(lmax);
286
+ if (tiisg == 0) {
287
+ buf[sgitg] = max;
288
+ }
289
+
290
+ threadgroup_barrier(mem_flags::mem_threadgroup);
291
+
292
+ // broadcast, simd group number is ntg / 32
293
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
294
+ if (tpitg < i) {
295
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
296
+ }
297
+ }
298
+
299
+ threadgroup_barrier(mem_flags::mem_threadgroup);
300
+
301
+ max = buf[0];
245
302
 
246
303
  // parallel sum
247
304
  float4 lsum4 = 0.0f;
248
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
305
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
249
306
  const float4 exp_psrc4 = exp(psrc4[i00] - max);
250
307
  lsum4 += exp_psrc4;
251
308
  pdst4[i00] = exp_psrc4;
252
309
  }
253
- float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
254
310
 
255
- const float sum = simd_sum(lsum);
311
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
312
+ float sum = simd_sum(lsum);
313
+ if (tiisg == 0) {
314
+ buf[sgitg] = sum;
315
+ }
316
+
317
+ threadgroup_barrier(mem_flags::mem_threadgroup);
318
+
319
+ // broadcast, simd group number is ntg / 32
320
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
321
+ if (tpitg < i) {
322
+ buf[tpitg] += buf[tpitg + i];
323
+ }
324
+ }
325
+
326
+ threadgroup_barrier(mem_flags::mem_threadgroup);
327
+
328
+ sum = buf[0];
256
329
 
257
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
330
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
258
331
  pdst4[i00] /= sum;
259
332
  }
260
333
  }
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
274
347
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
275
348
  } else {
276
349
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
277
- }
350
+ }
278
351
  }
279
352
 
280
353
  kernel void kernel_diag_mask_inf_8(
@@ -988,6 +1061,45 @@ kernel void kernel_alibi_f32(
988
1061
  }
989
1062
  }
990
1063
 
1064
// Linear ramp used to weight the YaRN blend.
//   low, high : start and end of the ramp (compared against i0 / 2)
//   i0        : element index; it is halved with integer division before use
// Returns 1.0 when i0/2 <= low, 0.0 when i0/2 >= high, and a linearly
// decreasing value in between.
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
1065
+ const float y = (i0 / 2 - low) / max(0.001f, high - low); // denominator floored at 0.001 so high == low cannot divide by zero
1066
+ return 1.0f - min(1.0f, max(0.0f, y)); // clamp y to [0, 1], then invert
1067
+ }
1068
+
1069
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
1070
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1071
// Computes the scaled cosine and sine of a rotation angle and writes them
// through cos_theta / sin_theta.
//   theta_extrap : angle before frequency scaling
//   freq_scale   : multiplier giving the interpolated angle (theta_interp)
//   corr_dims    : ramp bounds, passed to rope_yarn_ramp as (low, high)
//   i0           : element index forwarded to rope_yarn_ramp (narrowed to int there)
//   ext_factor   : weight of the blend; 0 disables both the blend and the mscale correction
//   mscale       : magnitude multiplier applied to both outputs
// With ext_factor == 0 the result is cos/sin(freq_scale * theta_extrap) * mscale.
+ static void rope_yarn(
1072
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
1073
+ thread float * cos_theta, thread float * sin_theta
1074
+ ) {
1075
+ // Get n-d rotational scaling corrected for extrapolation
1076
+ float theta_interp = freq_scale * theta_extrap;
1077
+ float theta = theta_interp;
1078
+ if (ext_factor != 0.0f) {
1079
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; // ramp value scaled by ext_factor
1080
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; // linear blend of the two angles
1081
+
1082
+ // Get n-d magnitude scaling corrected for interpolation
1083
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); // only modifies the local by-value copy
1084
+ }
1085
+ *cos_theta = cos(theta) * mscale;
1086
+ *sin_theta = sin(theta) * mscale;
1087
+ }
1088
+
1089
+ // Solving `max_pos_emb = n_rot * 2pi * base^((2 * x) / n_dims)` for x, we get
1090
+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
1091
// Returns n_dims * log(n_orig_ctx / (2 * pi * n_rot)) / (2 * log(base)).
// The result is a float and is not rounded or clamped here.
+ static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
1092
+ return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
1093
+ }
1094
+
1095
// Fills dims[0..1] with the two correction bounds:
//   dims[0] = floor(rope_yarn_corr_factor(..., beta_fast, freq_base)), clamped to be >= 0
//   dims[1] = ceil (rope_yarn_corr_factor(..., beta_slow, freq_base)), clamped to be <= n_dims - 1
+ static void rope_yarn_corr_dims(
1096
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
1097
+ ) {
1098
+ // start and end correction dims
1099
+ dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); // lower bound, never negative
1100
+ dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); // upper bound, capped at n_dims - 1
1101
+ }
1102
+
991
1103
  typedef void (rope_t)(
992
1104
  device const void * src0,
993
1105
  device const int32_t * src1,
@@ -1011,8 +1123,13 @@ typedef void (rope_t)(
1011
1123
  constant int & n_past,
1012
1124
  constant int & n_dims,
1013
1125
  constant int & mode,
1126
+ constant int & n_orig_ctx,
1014
1127
  constant float & freq_base,
1015
1128
  constant float & freq_scale,
1129
+ constant float & ext_factor,
1130
+ constant float & attn_factor,
1131
+ constant float & beta_fast,
1132
+ constant float & beta_slow,
1016
1133
  uint tiitg[[thread_index_in_threadgroup]],
1017
1134
  uint3 tptg[[threads_per_threadgroup]],
1018
1135
  uint3 tgpig[[threadgroup_position_in_grid]]);
@@ -1041,8 +1158,13 @@ kernel void kernel_rope(
1041
1158
  constant int & n_past,
1042
1159
  constant int & n_dims,
1043
1160
  constant int & mode,
1161
+ constant int & n_orig_ctx,
1044
1162
  constant float & freq_base,
1045
1163
  constant float & freq_scale,
1164
+ constant float & ext_factor,
1165
+ constant float & attn_factor,
1166
+ constant float & beta_fast,
1167
+ constant float & beta_slow,
1046
1168
  uint tiitg[[thread_index_in_threadgroup]],
1047
1169
  uint3 tptg[[threads_per_threadgroup]],
1048
1170
  uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1052,19 +1174,22 @@ kernel void kernel_rope(
1052
1174
 
1053
1175
  const bool is_neox = mode & 2;
1054
1176
 
1177
+ float corr_dims[2];
1178
+ rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1179
+
1055
1180
  device const int32_t * pos = src1;
1056
1181
 
1057
1182
  const int64_t p = pos[i2];
1058
1183
 
1059
- const float theta_0 = freq_scale * (float)p;
1184
+ const float theta_0 = (float)p;
1060
1185
  const float inv_ndims = -1.f/n_dims;
1061
1186
 
1062
1187
  if (!is_neox) {
1063
1188
  for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1064
1189
 
1065
1190
  const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
1066
- const float cos_theta = cos(theta);
1067
- const float sin_theta = sin(theta);
1191
+ float cos_theta, sin_theta;
1192
+ rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1068
1193
 
1069
1194
  device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1070
1195
  device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -1079,9 +1204,12 @@ kernel void kernel_rope(
1079
1204
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1080
1205
  for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1081
1206
 
1082
- const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
1083
- const float cos_theta = cos(theta);
1084
- const float sin_theta = sin(theta);
1207
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
1208
+ const float cur_rot = inv_ndims*ic - ib;
1209
+
1210
+ const float theta = theta_0 * pow(freq_base, cur_rot);
1211
+ float cos_theta, sin_theta;
1212
+ rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1085
1213
 
1086
1214
  const int64_t i0 = ib*n_dims + ic/2;
1087
1215