llama_cpp 0.8.0 → 0.9.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/examples/chat.rb +8 -6
- data/ext/llama_cpp/extconf.rb +3 -11
- data/ext/llama_cpp/llama_cpp.cpp +228 -165
- data/ext/llama_cpp/src/ggml-cuda.cu +441 -77
- data/ext/llama_cpp/src/ggml-impl.h +237 -0
- data/ext/llama_cpp/src/ggml-metal.m +71 -42
- data/ext/llama_cpp/src/ggml-metal.metal +171 -35
- data/ext/llama_cpp/src/ggml-opencl.cpp +161 -169
- data/ext/llama_cpp/src/{k_quants.c → ggml-quants.c} +3329 -1099
- data/ext/llama_cpp/src/{k_quants.h → ggml-quants.h} +81 -22
- data/ext/llama_cpp/src/ggml.c +1303 -3419
- data/ext/llama_cpp/src/ggml.h +33 -11
- data/ext/llama_cpp/src/llama.cpp +1925 -2655
- data/ext/llama_cpp/src/llama.h +48 -33
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +4 -4
- data/sig/llama_cpp.rbs +34 -14
- metadata +5 -4
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
|
|
125
125
|
}
|
126
126
|
|
127
127
|
kernel void kernel_scale(
|
128
|
+
device const float * src0,
|
129
|
+
device float * dst,
|
130
|
+
constant float & scale,
|
131
|
+
uint tpig[[thread_position_in_grid]]) {
|
132
|
+
dst[tpig] = src0[tpig] * scale;
|
133
|
+
}
|
134
|
+
|
135
|
+
kernel void kernel_scale_4(
|
128
136
|
device const float4 * src0,
|
129
137
|
device float4 * dst,
|
130
|
-
constant float
|
138
|
+
constant float & scale,
|
131
139
|
uint tpig[[thread_position_in_grid]]) {
|
132
140
|
dst[tpig] = src0[tpig] * scale;
|
133
141
|
}
|
@@ -176,36 +184,73 @@ kernel void kernel_soft_max(
|
|
176
184
|
constant int64_t & ne00,
|
177
185
|
constant int64_t & ne01,
|
178
186
|
constant int64_t & ne02,
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
187
|
+
threadgroup float * buf [[threadgroup(0)]],
|
188
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
189
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
190
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
191
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
192
|
+
uint ntg[[threads_per_threadgroup]]) {
|
193
|
+
const int64_t i03 = (tgpig) / (ne02*ne01);
|
194
|
+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
195
|
+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
185
196
|
|
186
197
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
187
198
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
188
199
|
|
189
200
|
// parallel max
|
190
|
-
float lmax = tpitg
|
191
|
-
|
201
|
+
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
|
202
|
+
|
203
|
+
for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
|
192
204
|
lmax = MAX(lmax, psrc0[i00]);
|
193
205
|
}
|
194
|
-
|
206
|
+
|
207
|
+
float max = simd_max(lmax);
|
208
|
+
if (tiisg == 0) {
|
209
|
+
buf[sgitg] = max;
|
210
|
+
}
|
211
|
+
|
212
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
213
|
+
|
214
|
+
// broadcast, simd group number is ntg / 32
|
215
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
216
|
+
if (tpitg < i) {
|
217
|
+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
218
|
+
}
|
219
|
+
}
|
220
|
+
|
221
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
222
|
+
|
223
|
+
max = buf[0];
|
195
224
|
|
196
225
|
// parallel sum
|
197
226
|
float lsum = 0.0f;
|
198
|
-
for (int i00 = tpitg
|
227
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
199
228
|
const float exp_psrc0 = exp(psrc0[i00] - max);
|
200
229
|
lsum += exp_psrc0;
|
201
230
|
// Remember the result of exp here. exp is expensive, so we really do not
|
202
|
-
//
|
231
|
+
// wish to compute it twice.
|
203
232
|
pdst[i00] = exp_psrc0;
|
204
233
|
}
|
205
234
|
|
206
|
-
|
235
|
+
float sum = simd_sum(lsum);
|
236
|
+
if (tiisg == 0) {
|
237
|
+
buf[sgitg] = sum;
|
238
|
+
}
|
239
|
+
|
240
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
207
241
|
|
208
|
-
|
242
|
+
// broadcast, simd group number is ntg / 32
|
243
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
244
|
+
if (tpitg < i) {
|
245
|
+
buf[tpitg] += buf[tpitg + i];
|
246
|
+
}
|
247
|
+
}
|
248
|
+
|
249
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
250
|
+
|
251
|
+
sum = buf[0];
|
252
|
+
|
253
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
209
254
|
pdst[i00] /= sum;
|
210
255
|
}
|
211
256
|
}
|
@@ -216,37 +261,73 @@ kernel void kernel_soft_max_4(
|
|
216
261
|
constant int64_t & ne00,
|
217
262
|
constant int64_t & ne01,
|
218
263
|
constant int64_t & ne02,
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
264
|
+
threadgroup float * buf [[threadgroup(0)]],
|
265
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
266
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
267
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
268
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
269
|
+
uint ntg[[threads_per_threadgroup]]) {
|
270
|
+
const int64_t i03 = (tgpig) / (ne02*ne01);
|
271
|
+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
272
|
+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
225
273
|
|
226
274
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
227
275
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
228
276
|
|
229
277
|
// parallel max
|
230
|
-
float4 lmax4 = tpitg
|
231
|
-
|
278
|
+
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
|
279
|
+
|
280
|
+
for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
|
232
281
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
233
282
|
}
|
234
|
-
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
235
283
|
|
236
|
-
const float
|
284
|
+
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
285
|
+
float max = simd_max(lmax);
|
286
|
+
if (tiisg == 0) {
|
287
|
+
buf[sgitg] = max;
|
288
|
+
}
|
289
|
+
|
290
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
291
|
+
|
292
|
+
// broadcast, simd group number is ntg / 32
|
293
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
294
|
+
if (tpitg < i) {
|
295
|
+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
296
|
+
}
|
297
|
+
}
|
298
|
+
|
299
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
300
|
+
|
301
|
+
max = buf[0];
|
237
302
|
|
238
303
|
// parallel sum
|
239
304
|
float4 lsum4 = 0.0f;
|
240
|
-
for (int i00 = tpitg
|
305
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
241
306
|
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
242
307
|
lsum4 += exp_psrc4;
|
243
308
|
pdst4[i00] = exp_psrc4;
|
244
309
|
}
|
245
|
-
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
246
310
|
|
247
|
-
const float
|
311
|
+
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
312
|
+
float sum = simd_sum(lsum);
|
313
|
+
if (tiisg == 0) {
|
314
|
+
buf[sgitg] = sum;
|
315
|
+
}
|
316
|
+
|
317
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
318
|
+
|
319
|
+
// broadcast, simd group number is ntg / 32
|
320
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
321
|
+
if (tpitg < i) {
|
322
|
+
buf[tpitg] += buf[tpitg + i];
|
323
|
+
}
|
324
|
+
}
|
325
|
+
|
326
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
327
|
+
|
328
|
+
sum = buf[0];
|
248
329
|
|
249
|
-
for (int i00 = tpitg
|
330
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
250
331
|
pdst4[i00] /= sum;
|
251
332
|
}
|
252
333
|
}
|
@@ -266,7 +347,7 @@ kernel void kernel_diag_mask_inf(
|
|
266
347
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
267
348
|
} else {
|
268
349
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
269
|
-
|
350
|
+
}
|
270
351
|
}
|
271
352
|
|
272
353
|
kernel void kernel_diag_mask_inf_8(
|
@@ -980,6 +1061,45 @@ kernel void kernel_alibi_f32(
|
|
980
1061
|
}
|
981
1062
|
}
|
982
1063
|
|
1064
|
+
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
1065
|
+
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
1066
|
+
return 1.0f - min(1.0f, max(0.0f, y));
|
1067
|
+
}
|
1068
|
+
|
1069
|
+
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
1070
|
+
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
1071
|
+
static void rope_yarn(
|
1072
|
+
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
1073
|
+
thread float * cos_theta, thread float * sin_theta
|
1074
|
+
) {
|
1075
|
+
// Get n-d rotational scaling corrected for extrapolation
|
1076
|
+
float theta_interp = freq_scale * theta_extrap;
|
1077
|
+
float theta = theta_interp;
|
1078
|
+
if (ext_factor != 0.0f) {
|
1079
|
+
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
1080
|
+
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
1081
|
+
|
1082
|
+
// Get n-d magnitude scaling corrected for interpolation
|
1083
|
+
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
|
1084
|
+
}
|
1085
|
+
*cos_theta = cos(theta) * mscale;
|
1086
|
+
*sin_theta = sin(theta) * mscale;
|
1087
|
+
}
|
1088
|
+
|
1089
|
+
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
1090
|
+
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
1091
|
+
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
1092
|
+
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
1093
|
+
}
|
1094
|
+
|
1095
|
+
static void rope_yarn_corr_dims(
|
1096
|
+
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
1097
|
+
) {
|
1098
|
+
// start and end correction dims
|
1099
|
+
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
1100
|
+
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
1101
|
+
}
|
1102
|
+
|
983
1103
|
typedef void (rope_t)(
|
984
1104
|
device const void * src0,
|
985
1105
|
device const int32_t * src1,
|
@@ -1003,8 +1123,13 @@ typedef void (rope_t)(
|
|
1003
1123
|
constant int & n_past,
|
1004
1124
|
constant int & n_dims,
|
1005
1125
|
constant int & mode,
|
1126
|
+
constant int & n_orig_ctx,
|
1006
1127
|
constant float & freq_base,
|
1007
1128
|
constant float & freq_scale,
|
1129
|
+
constant float & ext_factor,
|
1130
|
+
constant float & attn_factor,
|
1131
|
+
constant float & beta_fast,
|
1132
|
+
constant float & beta_slow,
|
1008
1133
|
uint tiitg[[thread_index_in_threadgroup]],
|
1009
1134
|
uint3 tptg[[threads_per_threadgroup]],
|
1010
1135
|
uint3 tgpig[[threadgroup_position_in_grid]]);
|
@@ -1033,8 +1158,13 @@ kernel void kernel_rope(
|
|
1033
1158
|
constant int & n_past,
|
1034
1159
|
constant int & n_dims,
|
1035
1160
|
constant int & mode,
|
1161
|
+
constant int & n_orig_ctx,
|
1036
1162
|
constant float & freq_base,
|
1037
1163
|
constant float & freq_scale,
|
1164
|
+
constant float & ext_factor,
|
1165
|
+
constant float & attn_factor,
|
1166
|
+
constant float & beta_fast,
|
1167
|
+
constant float & beta_slow,
|
1038
1168
|
uint tiitg[[thread_index_in_threadgroup]],
|
1039
1169
|
uint3 tptg[[threads_per_threadgroup]],
|
1040
1170
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
@@ -1044,19 +1174,22 @@ kernel void kernel_rope(
|
|
1044
1174
|
|
1045
1175
|
const bool is_neox = mode & 2;
|
1046
1176
|
|
1177
|
+
float corr_dims[2];
|
1178
|
+
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
1179
|
+
|
1047
1180
|
device const int32_t * pos = src1;
|
1048
1181
|
|
1049
1182
|
const int64_t p = pos[i2];
|
1050
1183
|
|
1051
|
-
const float theta_0 =
|
1184
|
+
const float theta_0 = (float)p;
|
1052
1185
|
const float inv_ndims = -1.f/n_dims;
|
1053
1186
|
|
1054
1187
|
if (!is_neox) {
|
1055
1188
|
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
1056
1189
|
|
1057
1190
|
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
1058
|
-
|
1059
|
-
|
1191
|
+
float cos_theta, sin_theta;
|
1192
|
+
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1060
1193
|
|
1061
1194
|
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
1062
1195
|
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
@@ -1071,9 +1204,12 @@ kernel void kernel_rope(
|
|
1071
1204
|
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
1072
1205
|
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
1073
1206
|
|
1074
|
-
|
1075
|
-
const float
|
1076
|
-
|
1207
|
+
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
1208
|
+
const float cur_rot = inv_ndims*ic - ib;
|
1209
|
+
|
1210
|
+
const float theta = theta_0 * pow(freq_base, cur_rot);
|
1211
|
+
float cos_theta, sin_theta;
|
1212
|
+
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
1077
1213
|
|
1078
1214
|
const int64_t i0 = ib*n_dims + ic/2;
|
1079
1215
|
|