llama_cpp 0.9.0 → 0.9.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/ext/llama_cpp/extconf.rb +3 -11
- data/ext/llama_cpp/llama_cpp.cpp +147 -3
- data/ext/llama_cpp/src/ggml-cuda.cu +288 -92
- data/ext/llama_cpp/src/ggml-impl.h +237 -0
- data/ext/llama_cpp/src/ggml-metal.m +58 -37
- data/ext/llama_cpp/src/ggml-metal.metal +162 -34
- data/ext/llama_cpp/src/{k_quants.c → ggml-quants.c} +3329 -1099
- data/ext/llama_cpp/src/{k_quants.h → ggml-quants.h} +81 -22
- data/ext/llama_cpp/src/ggml.c +939 -3333
- data/ext/llama_cpp/src/ggml.h +25 -4
- data/ext/llama_cpp/src/llama.cpp +1819 -2554
- data/ext/llama_cpp/src/llama.h +32 -12
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +23 -2
- metadata +5 -4
@@ -184,36 +184,73 @@ kernel void kernel_soft_max(
|
|
184
184
|
constant int64_t & ne00,
|
185
185
|
constant int64_t & ne01,
|
186
186
|
constant int64_t & ne02,
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
187
|
+
threadgroup float * buf [[threadgroup(0)]],
|
188
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
189
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
190
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
191
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
192
|
+
uint ntg[[threads_per_threadgroup]]) {
|
193
|
+
const int64_t i03 = (tgpig) / (ne02*ne01);
|
194
|
+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
195
|
+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
193
196
|
|
194
197
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
195
198
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
196
199
|
|
197
200
|
// parallel max
|
198
|
-
float lmax = tpitg
|
199
|
-
|
201
|
+
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
|
202
|
+
|
203
|
+
for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
|
200
204
|
lmax = MAX(lmax, psrc0[i00]);
|
201
205
|
}
|
202
|
-
|
206
|
+
|
207
|
+
float max = simd_max(lmax);
|
208
|
+
if (tiisg == 0) {
|
209
|
+
buf[sgitg] = max;
|
210
|
+
}
|
211
|
+
|
212
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
213
|
+
|
214
|
+
// reduce the per-simdgroup maxima into buf[0]; the number of simdgroups is ntg / 32
|
215
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
216
|
+
if (tpitg < i) {
|
217
|
+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
218
|
+
}
|
219
|
+
}
|
220
|
+
|
221
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
222
|
+
|
223
|
+
max = buf[0];
|
203
224
|
|
204
225
|
// parallel sum
|
205
226
|
float lsum = 0.0f;
|
206
|
-
for (int i00 = tpitg
|
227
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
207
228
|
const float exp_psrc0 = exp(psrc0[i00] - max);
|
208
229
|
lsum += exp_psrc0;
|
209
230
|
// Remember the result of exp here. exp is expensive, so we really do not
|
210
|
-
//
|
231
|
+
// wish to compute it twice.
|
211
232
|
pdst[i00] = exp_psrc0;
|
212
233
|
}
|
213
234
|
|
214
|
-
|
235
|
+
float sum = simd_sum(lsum);
|
236
|
+
if (tiisg == 0) {
|
237
|
+
buf[sgitg] = sum;
|
238
|
+
}
|
239
|
+
|
240
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
215
241
|
|
216
|
-
|
242
|
+
// reduce the per-simdgroup sums into buf[0]; the number of simdgroups is ntg / 32
|
243
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
244
|
+
if (tpitg < i) {
|
245
|
+
buf[tpitg] += buf[tpitg + i];
|
246
|
+
}
|
247
|
+
}
|
248
|
+
|
249
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
250
|
+
|
251
|
+
sum = buf[0];
|
252
|
+
|
253
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
217
254
|
pdst[i00] /= sum;
|
218
255
|
}
|
219
256
|
}
|
@@ -224,37 +261,73 @@ kernel void kernel_soft_max_4(
|
|
224
261
|
constant int64_t & ne00,
|
225
262
|
constant int64_t & ne01,
|
226
263
|
constant int64_t & ne02,
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
264
|
+
threadgroup float * buf [[threadgroup(0)]],
|
265
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
266
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
267
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
268
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
269
|
+
uint ntg[[threads_per_threadgroup]]) {
|
270
|
+
const int64_t i03 = (tgpig) / (ne02*ne01);
|
271
|
+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
272
|
+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
233
273
|
|
234
274
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
235
275
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
236
276
|
|
237
277
|
// parallel max
|
238
|
-
float4 lmax4 = tpitg
|
239
|
-
|
278
|
+
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
|
279
|
+
|
280
|
+
for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
|
240
281
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
241
282
|
}
|
242
|
-
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
243
283
|
|
244
|
-
const float
|
284
|
+
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
285
|
+
float max = simd_max(lmax);
|
286
|
+
if (tiisg == 0) {
|
287
|
+
buf[sgitg] = max;
|
288
|
+
}
|
289
|
+
|
290
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
291
|
+
|
292
|
+
// reduce the per-simdgroup maxima into buf[0]; the number of simdgroups is ntg / 32
|
293
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
294
|
+
if (tpitg < i) {
|
295
|
+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
296
|
+
}
|
297
|
+
}
|
298
|
+
|
299
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
300
|
+
|
301
|
+
max = buf[0];
|
245
302
|
|
246
303
|
// parallel sum
|
247
304
|
float4 lsum4 = 0.0f;
|
248
|
-
for (int i00 = tpitg
|
305
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
249
306
|
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
250
307
|
lsum4 += exp_psrc4;
|
251
308
|
pdst4[i00] = exp_psrc4;
|
252
309
|
}
|
253
|
-
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
254
310
|
|
255
|
-
const float
|
311
|
+
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
312
|
+
float sum = simd_sum(lsum);
|
313
|
+
if (tiisg == 0) {
|
314
|
+
buf[sgitg] = sum;
|
315
|
+
}
|
316
|
+
|
317
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
318
|
+
|
319
|
+
// reduce the per-simdgroup sums into buf[0]; the number of simdgroups is ntg / 32
|
320
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
321
|
+
if (tpitg < i) {
|
322
|
+
buf[tpitg] += buf[tpitg + i];
|
323
|
+
}
|
324
|
+
}
|
325
|
+
|
326
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
327
|
+
|
328
|
+
sum = buf[0];
|
256
329
|
|
257
|
-
for (int i00 = tpitg
|
330
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
258
331
|
pdst4[i00] /= sum;
|
259
332
|
}
|
260
333
|
}
|
@@ -274,7 +347,7 @@ kernel void kernel_diag_mask_inf(
|
|
274
347
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
275
348
|
} else {
|
276
349
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
277
|
-
|
350
|
+
}
|
278
351
|
}
|
279
352
|
|
280
353
|
kernel void kernel_diag_mask_inf_8(
|
@@ -988,6 +1061,45 @@ kernel void kernel_alibi_f32(
|
|
988
1061
|
}
|
989
1062
|
}
|
990
1063
|
|
1064
|
+
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
1065
|
+
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
1066
|
+
return 1.0f - min(1.0f, max(0.0f, y));
|
1067
|
+
}
|
1068
|
+
|
1069
|
+
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
1070
|
+
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
1071
|
+
static void rope_yarn(
|
1072
|
+
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
1073
|
+
thread float * cos_theta, thread float * sin_theta
|
1074
|
+
) {
|
1075
|
+
// Get n-d rotational scaling corrected for extrapolation
|
1076
|
+
float theta_interp = freq_scale * theta_extrap;
|
1077
|
+
float theta = theta_interp;
|
1078
|
+
if (ext_factor != 0.0f) {
|
1079
|
+
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
1080
|
+
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
1081
|
+
|
1082
|
+
// Get n-d magnitude scaling corrected for interpolation
|
1083
|
+
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
|
1084
|
+
}
|
1085
|
+
*cos_theta = cos(theta) * mscale;
|
1086
|
+
*sin_theta = sin(theta) * mscale;
|
1087
|
+
}
|
1088
|
+
|
1089
|
+
// Solving `n_rot = max_pos_emb / (2pi * base^((2 * x) / n_dims))` for x, we get
|
1090
|
+
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
1091
|
+
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
1092
|
+
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
1093
|
+
}
|
1094
|
+
|
1095
|
+
static void rope_yarn_corr_dims(
|
1096
|
+
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
1097
|
+
) {
|
1098
|
+
// start and end correction dims
|
1099
|
+
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
1100
|
+
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
1101
|
+
}
|
1102
|
+
|
991
1103
|
typedef void (rope_t)(
|
992
1104
|
device const void * src0,
|
993
1105
|
device const int32_t * src1,
|
@@ -1011,8 +1123,13 @@ typedef void (rope_t)(
|
|
1011
1123
|
constant int & n_past,
|
1012
1124
|
constant int & n_dims,
|
1013
1125
|
constant int & mode,
|
1126
|
+
constant int & n_orig_ctx,
|
1014
1127
|
constant float & freq_base,
|
1015
1128
|
constant float & freq_scale,
|
1129
|
+
constant float & ext_factor,
|
1130
|
+
constant float & attn_factor,
|
1131
|
+
constant float & beta_fast,
|
1132
|
+
constant float & beta_slow,
|
1016
1133
|
uint tiitg[[thread_index_in_threadgroup]],
|
1017
1134
|
uint3 tptg[[threads_per_threadgroup]],
|
1018
1135
|
uint3 tgpig[[threadgroup_position_in_grid]]);
|
@@ -1041,8 +1158,13 @@ kernel void kernel_rope(
|
|
1041
1158
|
constant int & n_past,
|
1042
1159
|
constant int & n_dims,
|
1043
1160
|
constant int & mode,
|
1161
|
+
constant int & n_orig_ctx,
|
1044
1162
|
constant float & freq_base,
|
1045
1163
|
constant float & freq_scale,
|
1164
|
+
constant float & ext_factor,
|
1165
|
+
constant float & attn_factor,
|
1166
|
+
constant float & beta_fast,
|
1167
|
+
constant float & beta_slow,
|
1046
1168
|
uint tiitg[[thread_index_in_threadgroup]],
|
1047
1169
|
uint3 tptg[[threads_per_threadgroup]],
|
1048
1170
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
@@ -1052,19 +1174,22 @@ kernel void kernel_rope(
|
|
1052
1174
|
|
1053
1175
|
const bool is_neox = mode & 2;
|
1054
1176
|
|
1177
|
+
float corr_dims[2];
|
1178
|
+
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
1179
|
+
|
1055
1180
|
device const int32_t * pos = src1;
|
1056
1181
|
|
1057
1182
|
const int64_t p = pos[i2];
|
1058
1183
|
|
1059
|
-
const float theta_0 =
|
1184
|
+
const float theta_0 = (float)p;
|
1060
1185
|
const float inv_ndims = -1.f/n_dims;
|
1061
1186
|
|
1062
1187
|
if (!is_neox) {
|
1063
1188
|
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
1064
1189
|
|
1065
1190
|
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
1066
|
-
|
1067
|
-
|
1191
|
+
float cos_theta, sin_theta;
|
1192
|
+
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1068
1193
|
|
1069
1194
|
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1070
1195
|
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
@@ -1079,9 +1204,12 @@ kernel void kernel_rope(
|
|
1079
1204
|
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
1080
1205
|
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
1081
1206
|
|
1082
|
-
|
1083
|
-
const float
|
1084
|
-
|
1207
|
+
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
1208
|
+
const float cur_rot = inv_ndims*ic - ib;
|
1209
|
+
|
1210
|
+
const float theta = theta_0 * pow(freq_base, cur_rot);
|
1211
|
+
float cos_theta, sin_theta;
|
1212
|
+
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1085
1213
|
|
1086
1214
|
const int64_t i0 = ib*n_dims + ic/2;
|
1087
1215
|
|