llama_cpp 0.8.0 → 0.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
125
125
  }
126
126
 
127
127
// Element-wise scaling: writes src0[i] * scale to dst[i].
//
//   src0  - input values
//   dst   - output values, written at the same index that is read from src0
//   scale - scalar multiplier applied to every element
//   tpig  - this thread's position in the grid; used directly as the element index
//
// Each thread handles exactly one element and performs no bounds check.
// NOTE(review): presumably the host dispatches exactly one thread per element — confirm at the call site.
kernel void kernel_scale(
        device const float * src0,
        device       float * dst,
        constant     float & scale,
        uint tpig[[thread_position_in_grid]]) {
    const float value = src0[tpig];

    dst[tpig] = value * scale;
}
134
+
135
// Element-wise scaling on float4 elements: writes src0[i] * scale to dst[i],
// where each element packs four floats.
//
//   src0  - input values, addressed as float4
//   dst   - output values, addressed as float4
//   scale - scalar multiplier applied to all four lanes of every element
//   tpig  - this thread's position in the grid; used directly as the float4 index
//
// Each thread handles exactly one float4 and performs no bounds check.
// NOTE(review): presumably the host dispatches one thread per float4 and the
// element count is a multiple of 4 — confirm at the call site.
kernel void kernel_scale_4(
        device const float4 * src0,
        device       float4 * dst,
        constant     float  & scale,
        uint tpig[[thread_position_in_grid]]) {
    const float4 value = src0[tpig];

    dst[tpig] = value * scale;
}
@@ -176,36 +184,73 @@ kernel void kernel_soft_max(
176
184
  constant int64_t & ne00,
177
185
  constant int64_t & ne01,
178
186
  constant int64_t & ne02,
179
- uint3 tgpig[[threadgroup_position_in_grid]],
180
- uint3 tpitg[[thread_position_in_threadgroup]],
181
- uint3 ntg[[threads_per_threadgroup]]) {
182
- const int64_t i03 = tgpig[2];
183
- const int64_t i02 = tgpig[1];
184
- const int64_t i01 = tgpig[0];
187
+ threadgroup float * buf [[threadgroup(0)]],
188
+ uint tgpig[[threadgroup_position_in_grid]],
189
+ uint tpitg[[thread_position_in_threadgroup]],
190
+ uint sgitg[[simdgroup_index_in_threadgroup]],
191
+ uint tiisg[[thread_index_in_simdgroup]],
192
+ uint ntg[[threads_per_threadgroup]]) {
193
+ const int64_t i03 = (tgpig) / (ne02*ne01);
194
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
195
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
185
196
 
186
197
  device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
187
198
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
188
199
 
189
200
  // parallel max
190
- float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
191
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
201
+ float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
202
+
203
+ for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
192
204
  lmax = MAX(lmax, psrc0[i00]);
193
205
  }
194
- const float max = simd_max(lmax);
206
+
207
+ float max = simd_max(lmax);
208
+ if (tiisg == 0) {
209
+ buf[sgitg] = max;
210
+ }
211
+
212
+ threadgroup_barrier(mem_flags::mem_threadgroup);
213
+
214
+ // broadcast, simd group number is ntg / 32
215
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
216
+ if (tpitg < i) {
217
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
218
+ }
219
+ }
220
+
221
+ threadgroup_barrier(mem_flags::mem_threadgroup);
222
+
223
+ max = buf[0];
195
224
 
196
225
  // parallel sum
197
226
  float lsum = 0.0f;
198
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
227
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
199
228
  const float exp_psrc0 = exp(psrc0[i00] - max);
200
229
  lsum += exp_psrc0;
201
230
  // Remember the result of exp here. exp is expensive, so we really do not
202
- // whish to compute it twice.
231
+ // wish to compute it twice.
203
232
  pdst[i00] = exp_psrc0;
204
233
  }
205
234
 
206
- const float sum = simd_sum(lsum);
235
+ float sum = simd_sum(lsum);
236
+ if (tiisg == 0) {
237
+ buf[sgitg] = sum;
238
+ }
239
+
240
+ threadgroup_barrier(mem_flags::mem_threadgroup);
207
241
 
208
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
242
+ // broadcast, simd group number is ntg / 32
243
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
244
+ if (tpitg < i) {
245
+ buf[tpitg] += buf[tpitg + i];
246
+ }
247
+ }
248
+
249
+ threadgroup_barrier(mem_flags::mem_threadgroup);
250
+
251
+ sum = buf[0];
252
+
253
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
209
254
  pdst[i00] /= sum;
210
255
  }
211
256
  }
@@ -216,37 +261,73 @@ kernel void kernel_soft_max_4(
216
261
  constant int64_t & ne00,
217
262
  constant int64_t & ne01,
218
263
  constant int64_t & ne02,
219
- uint3 tgpig[[threadgroup_position_in_grid]],
220
- uint3 tpitg[[thread_position_in_threadgroup]],
221
- uint3 ntg[[threads_per_threadgroup]]) {
222
- const int64_t i03 = tgpig[2];
223
- const int64_t i02 = tgpig[1];
224
- const int64_t i01 = tgpig[0];
264
+ threadgroup float * buf [[threadgroup(0)]],
265
+ uint tgpig[[threadgroup_position_in_grid]],
266
+ uint tpitg[[thread_position_in_threadgroup]],
267
+ uint sgitg[[simdgroup_index_in_threadgroup]],
268
+ uint tiisg[[thread_index_in_simdgroup]],
269
+ uint ntg[[threads_per_threadgroup]]) {
270
+ const int64_t i03 = (tgpig) / (ne02*ne01);
271
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
272
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
225
273
 
226
274
  device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
227
275
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
228
276
 
229
277
  // parallel max
230
- float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
231
- for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
278
+ float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
279
+
280
+ for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
232
281
  lmax4 = fmax(lmax4, psrc4[i00]);
233
282
  }
234
- float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
235
283
 
236
- const float max = simd_max(lmax);
284
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
285
+ float max = simd_max(lmax);
286
+ if (tiisg == 0) {
287
+ buf[sgitg] = max;
288
+ }
289
+
290
+ threadgroup_barrier(mem_flags::mem_threadgroup);
291
+
292
+ // broadcast, simd group number is ntg / 32
293
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
294
+ if (tpitg < i) {
295
+ buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
296
+ }
297
+ }
298
+
299
+ threadgroup_barrier(mem_flags::mem_threadgroup);
300
+
301
+ max = buf[0];
237
302
 
238
303
  // parallel sum
239
304
  float4 lsum4 = 0.0f;
240
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
305
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
241
306
  const float4 exp_psrc4 = exp(psrc4[i00] - max);
242
307
  lsum4 += exp_psrc4;
243
308
  pdst4[i00] = exp_psrc4;
244
309
  }
245
- float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
246
310
 
247
- const float sum = simd_sum(lsum);
311
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
312
+ float sum = simd_sum(lsum);
313
+ if (tiisg == 0) {
314
+ buf[sgitg] = sum;
315
+ }
316
+
317
+ threadgroup_barrier(mem_flags::mem_threadgroup);
318
+
319
+ // broadcast, simd group number is ntg / 32
320
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
321
+ if (tpitg < i) {
322
+ buf[tpitg] += buf[tpitg + i];
323
+ }
324
+ }
325
+
326
+ threadgroup_barrier(mem_flags::mem_threadgroup);
327
+
328
+ sum = buf[0];
248
329
 
249
- for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
330
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
250
331
  pdst4[i00] /= sum;
251
332
  }
252
333
  }
@@ -266,7 +347,7 @@ kernel void kernel_diag_mask_inf(
266
347
  dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
267
348
  } else {
268
349
  dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
269
- }
350
+ }
270
351
  }
271
352
 
272
353
  kernel void kernel_diag_mask_inf_8(
@@ -980,6 +1061,45 @@ kernel void kernel_alibi_f32(
980
1061
  }
981
1062
  }
982
1063
 
1064
// Linear ramp used to blend between two rotation angles.
//
// Returns 1 when i0/2 is at or below `low`, 0 when i0/2 is at or above `high`,
// and falls linearly in between. `i0 / 2` is an integer division.
// The width of the ramp is floored at 0.001 so that `high == low` does not
// divide by zero.
static float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float width    = max(0.001f, high - low);
    const float position = (i0 / 2 - low) / width;
    const float clamped  = min(1.0f, max(0.0f, position));

    return 1.0f - clamped;
}
1068
+
1069
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
1070
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1071
// Computes the scaled cosine and sine of a rotation angle.
//
//   theta_extrap - unscaled rotation angle
//   freq_scale   - factor applied to theta_extrap to get the interpolated angle
//   corr_dims    - {low, high} bounds handed to rope_yarn_ramp
//   i0           - index handed to rope_yarn_ramp
//   ext_factor   - weight of the ramp; 0 disables the blend and the magnitude correction
//   mscale       - magnitude multiplier applied to both outputs
//   cos_theta    - out: cos(angle) * magnitude
//   sin_theta    - out: sin(angle) * magnitude
static void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
    thread float * cos_theta, thread float * sin_theta
) {
    // Interpolated angle; used as-is when ext_factor is zero.
    const float theta_interp = freq_scale * theta_extrap;

    float angle     = theta_interp;
    float magnitude = mscale;

    if (ext_factor != 0.0f) {
        // Blend the interpolated and unscaled angles according to the ramp.
        const float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        angle = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

        // Magnitude correction that depends only on freq_scale.
        magnitude *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }

    *cos_theta = cos(angle) * magnitude;
    *sin_theta = sin(angle) * magnitude;
}
1088
+
1089
+ // Solving `n_rot = max_pos_emb / (2pi * base^((2 * x) / n_dims))` for x, we get
1090
+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
1091
// Evaluates n_dims * log(n_orig_ctx / (n_rot * 2pi)) / (2 * log(base)).
//
//   n_dims     - scales the numerator
//   n_orig_ctx - context length inside the logarithm
//   n_rot      - rotation count inside the logarithm
//   base       - base whose logarithm forms the denominator
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
    const float numerator   = n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F));
    const float denominator = 2 * log(base);

    return numerator / denominator;
}
1094
+
1095
// Fills dims[0..1] with the start and end correction dimensions.
//
//   dims[0] = floor of the correction factor for beta_fast, clamped to be >= 0
//   dims[1] = ceil  of the correction factor for beta_slow, clamped to be <= n_dims - 1
//
// Both factors come from rope_yarn_corr_factor with the same n_dims,
// n_orig_ctx and freq_base.
static void rope_yarn_corr_dims(
    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
    const float start = floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base));
    const float end   = ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base));

    dims[0] = max(0.0f, start);
    dims[1] = min(n_dims - 1.0f, end);
}
1102
+
983
1103
  typedef void (rope_t)(
984
1104
  device const void * src0,
985
1105
  device const int32_t * src1,
@@ -1003,8 +1123,13 @@ typedef void (rope_t)(
1003
1123
  constant int & n_past,
1004
1124
  constant int & n_dims,
1005
1125
  constant int & mode,
1126
+ constant int & n_orig_ctx,
1006
1127
  constant float & freq_base,
1007
1128
  constant float & freq_scale,
1129
+ constant float & ext_factor,
1130
+ constant float & attn_factor,
1131
+ constant float & beta_fast,
1132
+ constant float & beta_slow,
1008
1133
  uint tiitg[[thread_index_in_threadgroup]],
1009
1134
  uint3 tptg[[threads_per_threadgroup]],
1010
1135
  uint3 tgpig[[threadgroup_position_in_grid]]);
@@ -1033,8 +1158,13 @@ kernel void kernel_rope(
1033
1158
  constant int & n_past,
1034
1159
  constant int & n_dims,
1035
1160
  constant int & mode,
1161
+ constant int & n_orig_ctx,
1036
1162
  constant float & freq_base,
1037
1163
  constant float & freq_scale,
1164
+ constant float & ext_factor,
1165
+ constant float & attn_factor,
1166
+ constant float & beta_fast,
1167
+ constant float & beta_slow,
1038
1168
  uint tiitg[[thread_index_in_threadgroup]],
1039
1169
  uint3 tptg[[threads_per_threadgroup]],
1040
1170
  uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1044,19 +1174,22 @@ kernel void kernel_rope(
1044
1174
 
1045
1175
  const bool is_neox = mode & 2;
1046
1176
 
1177
+ float corr_dims[2];
1178
+ rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1179
+
1047
1180
  device const int32_t * pos = src1;
1048
1181
 
1049
1182
  const int64_t p = pos[i2];
1050
1183
 
1051
- const float theta_0 = freq_scale * (float)p;
1184
+ const float theta_0 = (float)p;
1052
1185
  const float inv_ndims = -1.f/n_dims;
1053
1186
 
1054
1187
  if (!is_neox) {
1055
1188
  for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1056
1189
 
1057
1190
  const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
1058
- const float cos_theta = cos(theta);
1059
- const float sin_theta = sin(theta);
1191
+ float cos_theta, sin_theta;
1192
+ rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1060
1193
 
1061
1194
  device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1062
1195
  device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -1071,9 +1204,12 @@ kernel void kernel_rope(
1071
1204
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
1072
1205
  for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
1073
1206
 
1074
- const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
1075
- const float cos_theta = cos(theta);
1076
- const float sin_theta = sin(theta);
1207
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
1208
+ const float cur_rot = inv_ndims*ic - ib;
1209
+
1210
+ const float theta = theta_0 * pow(freq_base, cur_rot);
1211
+ float cos_theta, sin_theta;
1212
+ rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1077
1213
 
1078
1214
  const int64_t i0 = ib*n_dims + ic/2;
1079
1215