llama_cpp 0.8.0 → 0.9.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/examples/chat.rb +8 -6
- data/ext/llama_cpp/extconf.rb +3 -11
- data/ext/llama_cpp/llama_cpp.cpp +228 -165
- data/ext/llama_cpp/src/ggml-cuda.cu +441 -77
- data/ext/llama_cpp/src/ggml-impl.h +237 -0
- data/ext/llama_cpp/src/ggml-metal.m +71 -42
- data/ext/llama_cpp/src/ggml-metal.metal +171 -35
- data/ext/llama_cpp/src/ggml-opencl.cpp +161 -169
- data/ext/llama_cpp/src/{k_quants.c → ggml-quants.c} +3329 -1099
- data/ext/llama_cpp/src/{k_quants.h → ggml-quants.h} +81 -22
- data/ext/llama_cpp/src/ggml.c +1303 -3419
- data/ext/llama_cpp/src/ggml.h +33 -11
- data/ext/llama_cpp/src/llama.cpp +1925 -2655
- data/ext/llama_cpp/src/llama.h +48 -33
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +4 -4
- data/sig/llama_cpp.rbs +34 -14
- metadata +5 -4
```diff
@@ -125,9 +125,17 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
+        device const float * src0,
+        device float * dst,
+        constant float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
         device const float4 * src0,
         device float4 * dst,
-        constant float
+        constant float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
```
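The rewritten `kernel_scale` now works on plain `float` elements, while the new `kernel_scale_4` keeps the previous `float4` layout, presumably so the scalar kernel can cover sizes that are not a multiple of four. The sketch below is a host-side C++ illustration of that split (bulk of the buffer scaled four elements at a time, scalar fallback for the tail); the `scale_buffer` helper and its dispatch rule are assumptions for illustration, not code from the gem or from ggml-metal.m.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Scale n elements in place: a 4-wide path mirroring kernel_scale_4 for the
// bulk of the buffer, and a scalar path mirroring kernel_scale for the tail.
// Hypothetical reference code, not the gem's actual dispatch logic.
static void scale_buffer(std::vector<float> & x, float scale) {
    const std::size_t n  = x.size();
    const std::size_t n4 = n - n % 4;            // part divisible by 4
    for (std::size_t i = 0; i < n4; i += 4) {    // "kernel_scale_4" path
        x[i + 0] *= scale; x[i + 1] *= scale;
        x[i + 2] *= scale; x[i + 3] *= scale;
    }
    for (std::size_t i = n4; i < n; ++i) {       // "kernel_scale" path
        x[i] *= scale;
    }
}

int main() {
    std::vector<float> v = {1, 2, 3, 4, 5, 6, 7};
    scale_buffer(v, 0.5f);
    for (float f : v) std::printf("%g ", f);
    std::printf("\n");
}
```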
```diff
@@ -176,36 +184,73 @@ kernel void kernel_soft_max(
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
-
-
-
-
-
-
+        threadgroup float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = tpitg
-
+    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
         lmax = MAX(lmax, psrc0[i00]);
     }
-
+
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
         lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
-        //
+        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         pdst[i00] /= sum;
     }
 }
```
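Structurally, the new `kernel_soft_max` is still the standard numerically stable softmax — subtract the row max, exponentiate (caching the `exp` results in the output), then divide by the sum — but both reductions are now performed cooperatively across simdgroups through the `buf` threadgroup array rather than a single `simd_max`/`simd_sum`. A scalar C++ reference of the underlying three-pass computation, shown only for comparison (it is not part of the gem):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one row, mirroring the structure of
// kernel_soft_max: max pass, exp-and-sum pass (exp results are cached in the
// output, since exp is expensive), then a normalization pass.
static void soft_max_row(const float * src, float * dst, int ne00) {
    float max = -INFINITY;
    for (int i = 0; i < ne00; ++i) {
        max = std::max(max, src[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < ne00; ++i) {
        const float e = std::exp(src[i] - max);
        dst[i] = e;          // cache the exp result, as the kernel does in pdst
        sum += e;
    }
    for (int i = 0; i < ne00; ++i) {
        dst[i] /= sum;
    }
}

int main() {
    std::vector<float> in = {1.0f, 2.0f, 3.0f, 4.0f};
    std::vector<float> out(in.size());
    soft_max_row(in.data(), out.data(), (int)in.size());
    for (float v : out) std::printf("%.4f ", v);  // values sum to 1
    std::printf("\n");
}
```

In the Metal kernel each thread first reduces a strided slice of the row, `simd_max`/`simd_sum` combine values within a simdgroup, and the per-simdgroup partials stored in `buf` are folded by the `ntg / 32 / 2` loop between `threadgroup_barrier` calls, so rows processed by more than one simdgroup are reduced correctly.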
```diff
@@ -216,37 +261,73 @@ kernel void kernel_soft_max_4(
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
-
-
-
-
-
-
+        threadgroup float * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = tpitg
-
+    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+
+    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
         lmax4 = fmax(lmax4, psrc4[i00]);
     }
-    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
 
-    const float
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+    float max = simd_max(lmax);
+    if (tiisg == 0) {
+        buf[sgitg] = max;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    max = buf[0];
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         const float4 exp_psrc4 = exp(psrc4[i00] - max);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
-    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
-    const float
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+    float sum = simd_sum(lsum);
+    if (tiisg == 0) {
+        buf[sgitg] = sum;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            buf[tpitg] += buf[tpitg + i];
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sum = buf[0];
 
-    for (int i00 = tpitg
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         pdst4[i00] /= sum;
     }
 }
```
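`kernel_soft_max_4` follows the same pattern but accumulates into a `float4`, folding the four lanes into a scalar (`MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]))`, or the lane sum) before the simdgroup reduction. A small C++ sketch of that accumulate-in-lanes-then-reduce idea, with a plain array standing in for `float4` (illustrative only):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Find the max of n floats (n divisible by 4 for simplicity) by keeping four
// running maxima and folding them at the end, as kernel_soft_max_4 does with
// its float4 accumulator before calling simd_max.
static float max_by_lanes(const float * x, int n) {
    float lane[4] = {-INFINITY, -INFINITY, -INFINITY, -INFINITY};
    for (int i = 0; i < n; i += 4) {
        for (int l = 0; l < 4; ++l) {
            lane[l] = std::max(lane[l], x[i + l]);
        }
    }
    // horizontal reduction, same shape as MAX(MAX(l0, l1), MAX(l2, l3))
    return std::max(std::max(lane[0], lane[1]), std::max(lane[2], lane[3]));
}

int main() {
    const float data[8] = {0.5f, -1.0f, 3.0f, 2.0f, 7.0f, -2.0f, 1.0f, 0.0f};
    std::printf("max = %g\n", max_by_lanes(data, 8));  // prints 7
}
```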
```diff
@@ -266,7 +347,7 @@ kernel void kernel_diag_mask_inf(
         dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
     } else {
         dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
-
+    }
 }
 
 kernel void kernel_diag_mask_inf_8(
```
```diff
@@ -980,6 +1061,45 @@ kernel void kernel_alibi_f32(
     }
 }
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
 typedef void (rope_t)(
         device const void * src0,
         device const int32_t * src1,
```
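The new `rope_yarn_*` helpers implement the YaRN frequency-correction math directly in the shader. Because the formulas are self-contained, they can be transcribed into ordinary C++ to experiment with on the host; the sketch below does exactly that, and the values passed in `main` are arbitrary example inputs rather than defaults taken from the gem or from llama.cpp:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Host-side transcription of the rope_yarn helpers from the Metal diff above.
static const float PI_F = 3.14159265358979f;

static float rope_yarn_ramp(float low, float high, int i0) {
    const float y = (i0 / 2 - low) / std::max(0.001f, high - low);
    return 1.0f - std::min(1.0f, std::max(0.0f, y));
}

// corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
    return n_dims * std::log(n_orig_ctx / (n_rot * 2 * PI_F)) / (2 * std::log(base));
}

static void rope_yarn_corr_dims(int n_dims, int n_orig_ctx, float freq_base,
                                float beta_fast, float beta_slow, float dims[2]) {
    dims[0] = std::max(0.0f, std::floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
    dims[1] = std::min(n_dims - 1.0f, std::ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
}

static void rope_yarn(float theta_extrap, float freq_scale, const float corr_dims[2], int i0,
                      float ext_factor, float mscale, float * cos_theta, float * sin_theta) {
    const float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        const float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
        mscale *= 1.0f + 0.1f * std::log(1.0f / freq_scale);
    }
    *cos_theta = std::cos(theta) * mscale;
    *sin_theta = std::sin(theta) * mscale;
}

int main() {
    float dims[2];
    // example parameters only: head size 128, original context 4096, base 10000
    rope_yarn_corr_dims(128, 4096, 10000.0f, 32.0f, 1.0f, dims);
    std::printf("correction dims: [%g, %g]\n", dims[0], dims[1]);

    float c, s;
    rope_yarn(100.0f, 0.25f, dims, 0, 1.0f, 1.0f, &c, &s);
    std::printf("cos_theta = %g, sin_theta = %g\n", c, s);
}
```

The correction dims returned by `rope_yarn_corr_dims` mark where the ramp from `rope_yarn_ramp` transitions between the interpolated angle (`theta_interp`) and the extrapolated one (`theta_extrap`) when `ext_factor` is non-zero.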
```diff
@@ -1003,8 +1123,13 @@ typedef void (rope_t)(
         constant int & n_past,
         constant int & n_dims,
         constant int & mode,
+        constant int & n_orig_ctx,
         constant float & freq_base,
         constant float & freq_scale,
+        constant float & ext_factor,
+        constant float & attn_factor,
+        constant float & beta_fast,
+        constant float & beta_slow,
         uint tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]);
```
```diff
@@ -1033,8 +1158,13 @@ kernel void kernel_rope(
         constant int & n_past,
         constant int & n_dims,
         constant int & mode,
+        constant int & n_orig_ctx,
         constant float & freq_base,
         constant float & freq_scale,
+        constant float & ext_factor,
+        constant float & attn_factor,
+        constant float & beta_fast,
+        constant float & beta_slow,
         uint tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]) {
```
```diff
@@ -1044,19 +1174,22 @@ kernel void kernel_rope(
 
     const bool is_neox = mode & 2;
 
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
     device const int32_t * pos = src1;
 
     const int64_t p = pos[i2];
 
-    const float theta_0 =
+    const float theta_0 = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
 
             const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
-
-
+            float cos_theta, sin_theta;
+            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
             device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
```
```diff
@@ -1071,9 +1204,12 @@ kernel void kernel_rope(
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
             for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
 
-
-                const float
-
+                // simplified from `(ib * n_dims + ic) * inv_ndims`
+                const float cur_rot = inv_ndims*ic - ib;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot);
+                float cos_theta, sin_theta;
+                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 const int64_t i0 = ib*n_dims + ic/2;
 
```
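The comment added in the last hunk says `cur_rot` is "simplified from `(ib * n_dims + ic) * inv_ndims`"; since `inv_ndims` is `-1.f/n_dims`, expanding the product gives `-ib - ic/n_dims`, which is exactly `inv_ndims*ic - ib`. A few lines of C++ confirming the identity numerically (a hypothetical check, not code from the package):

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
    const int n_dims = 128;
    const float inv_ndims = -1.0f / n_dims;
    for (int ib = 0; ib < 4; ++ib) {
        for (int ic = 0; ic < n_dims; ic += 2) {
            const float full       = (ib * n_dims + ic) * inv_ndims; // original form
            const float simplified = inv_ndims * ic - ib;            // form used in the kernel
            assert(std::fabs(full - simplified) < 1e-5f);
        }
    }
    std::printf("cur_rot simplification holds\n");
}
```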