node-llama-cpp 2.7.3 → 2.7.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -2
- package/dist/cli/commands/DownloadCommand.js +38 -32
- package/dist/cli/commands/DownloadCommand.js.map +1 -1
- package/dist/config.d.ts +3 -0
- package/dist/config.js +4 -0
- package/dist/config.js.map +1 -1
- package/dist/utils/cloneLlamaCppRepo.js +24 -6
- package/dist/utils/cloneLlamaCppRepo.js.map +1 -1
- package/dist/utils/cmake.js +5 -0
- package/dist/utils/cmake.js.map +1 -1
- package/dist/utils/compileLLamaCpp.js +25 -21
- package/dist/utils/compileLLamaCpp.js.map +1 -1
- package/dist/utils/gbnfJson/terminals/GbnfStringValue.js +4 -2
- package/dist/utils/gbnfJson/terminals/GbnfStringValue.js.map +1 -1
- package/llama/CMakeLists.txt +2 -2
- package/llama/addon.cpp +11 -18
- package/llama/binariesGithubRelease.json +1 -1
- package/llama/gitRelease.bundle +0 -0
- package/llama/package.json +5 -0
- package/llamaBins/linux-arm64/llama-addon.node +0 -0
- package/llamaBins/linux-armv7l/llama-addon.node +0 -0
- package/llamaBins/linux-x64/llama-addon.node +0 -0
- package/llamaBins/mac-arm64/ggml-metal.metal +333 -36
- package/llamaBins/mac-arm64/llama-addon.node +0 -0
- package/llamaBins/mac-x64/ggml-metal.metal +333 -36
- package/llamaBins/mac-x64/llama-addon.node +0 -0
- package/llamaBins/win-x64/llama-addon.exp +0 -0
- package/llamaBins/win-x64/llama-addon.lib +0 -0
- package/llamaBins/win-x64/llama-addon.node +0 -0
- package/package.json +2 -2
|
@@ -18,6 +18,21 @@ typedef struct {
|
|
|
18
18
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
|
19
19
|
} block_q4_1;
|
|
20
20
|
|
|
21
|
+
#define QK5_0 32
|
|
22
|
+
typedef struct {
|
|
23
|
+
half d; // delta
|
|
24
|
+
uint8_t qh[4]; // 5-th bit of quants
|
|
25
|
+
uint8_t qs[QK5_0 / 2]; // nibbles / quants
|
|
26
|
+
} block_q5_0;
|
|
27
|
+
|
|
28
|
+
#define QK5_1 32
|
|
29
|
+
typedef struct {
|
|
30
|
+
half d; // delta
|
|
31
|
+
half m; // min
|
|
32
|
+
uint8_t qh[4]; // 5-th bit of quants
|
|
33
|
+
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
|
34
|
+
} block_q5_1;
|
|
35
|
+
|
|
21
36
|
#define QK8_0 32
|
|
22
37
|
typedef struct {
|
|
23
38
|
half d; // delta
|
|
@@ -110,9 +125,17 @@ kernel void kernel_mul_row(
|
|
|
110
125
|
}
|
|
111
126
|
|
|
112
127
|
kernel void kernel_scale(
|
|
128
|
+
device const float * src0,
|
|
129
|
+
device float * dst,
|
|
130
|
+
constant float & scale,
|
|
131
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
132
|
+
dst[tpig] = src0[tpig] * scale;
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
kernel void kernel_scale_4(
|
|
113
136
|
device const float4 * src0,
|
|
114
137
|
device float4 * dst,
|
|
115
|
-
constant float
|
|
138
|
+
constant float & scale,
|
|
116
139
|
uint tpig[[thread_position_in_grid]]) {
|
|
117
140
|
dst[tpig] = src0[tpig] * scale;
|
|
118
141
|
}
|
|
@@ -161,36 +184,73 @@ kernel void kernel_soft_max(
|
|
|
161
184
|
constant int64_t & ne00,
|
|
162
185
|
constant int64_t & ne01,
|
|
163
186
|
constant int64_t & ne02,
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
187
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
188
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
189
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
190
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
191
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
192
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
193
|
+
const int64_t i03 = (tgpig) / (ne02*ne01);
|
|
194
|
+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
195
|
+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
170
196
|
|
|
171
197
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
172
198
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
173
199
|
|
|
174
200
|
// parallel max
|
|
175
|
-
float lmax = tpitg
|
|
176
|
-
|
|
201
|
+
float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
|
|
202
|
+
|
|
203
|
+
for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
|
|
177
204
|
lmax = MAX(lmax, psrc0[i00]);
|
|
178
205
|
}
|
|
179
|
-
|
|
206
|
+
|
|
207
|
+
float max = simd_max(lmax);
|
|
208
|
+
if (tiisg == 0) {
|
|
209
|
+
buf[sgitg] = max;
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
213
|
+
|
|
214
|
+
// broadcast, simd group number is ntg / 32
|
|
215
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
216
|
+
if (tpitg < i) {
|
|
217
|
+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
222
|
+
|
|
223
|
+
max = buf[0];
|
|
180
224
|
|
|
181
225
|
// parallel sum
|
|
182
226
|
float lsum = 0.0f;
|
|
183
|
-
for (int i00 = tpitg
|
|
227
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
184
228
|
const float exp_psrc0 = exp(psrc0[i00] - max);
|
|
185
229
|
lsum += exp_psrc0;
|
|
186
230
|
// Remember the result of exp here. exp is expensive, so we really do not
|
|
187
|
-
//
|
|
231
|
+
// wish to compute it twice.
|
|
188
232
|
pdst[i00] = exp_psrc0;
|
|
189
233
|
}
|
|
190
234
|
|
|
191
|
-
|
|
235
|
+
float sum = simd_sum(lsum);
|
|
236
|
+
if (tiisg == 0) {
|
|
237
|
+
buf[sgitg] = sum;
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
241
|
+
|
|
242
|
+
// broadcast, simd group number is ntg / 32
|
|
243
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
244
|
+
if (tpitg < i) {
|
|
245
|
+
buf[tpitg] += buf[tpitg + i];
|
|
246
|
+
}
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
250
|
+
|
|
251
|
+
sum = buf[0];
|
|
192
252
|
|
|
193
|
-
for (int i00 = tpitg
|
|
253
|
+
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
194
254
|
pdst[i00] /= sum;
|
|
195
255
|
}
|
|
196
256
|
}
|
|
@@ -201,37 +261,73 @@ kernel void kernel_soft_max_4(
|
|
|
201
261
|
constant int64_t & ne00,
|
|
202
262
|
constant int64_t & ne01,
|
|
203
263
|
constant int64_t & ne02,
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
264
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
265
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
266
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
267
|
+
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
268
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
269
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
270
|
+
const int64_t i03 = (tgpig) / (ne02*ne01);
|
|
271
|
+
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
|
|
272
|
+
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
|
210
273
|
|
|
211
274
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
212
275
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
213
276
|
|
|
214
277
|
// parallel max
|
|
215
|
-
float4 lmax4 = tpitg
|
|
216
|
-
|
|
278
|
+
float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
|
|
279
|
+
|
|
280
|
+
for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
|
|
217
281
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
|
218
282
|
}
|
|
219
|
-
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
220
283
|
|
|
221
|
-
const float
|
|
284
|
+
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
285
|
+
float max = simd_max(lmax);
|
|
286
|
+
if (tiisg == 0) {
|
|
287
|
+
buf[sgitg] = max;
|
|
288
|
+
}
|
|
289
|
+
|
|
290
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
291
|
+
|
|
292
|
+
// broadcast, simd group number is ntg / 32
|
|
293
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
294
|
+
if (tpitg < i) {
|
|
295
|
+
buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
|
|
296
|
+
}
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
300
|
+
|
|
301
|
+
max = buf[0];
|
|
222
302
|
|
|
223
303
|
// parallel sum
|
|
224
304
|
float4 lsum4 = 0.0f;
|
|
225
|
-
for (int i00 = tpitg
|
|
305
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
226
306
|
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
|
227
307
|
lsum4 += exp_psrc4;
|
|
228
308
|
pdst4[i00] = exp_psrc4;
|
|
229
309
|
}
|
|
230
|
-
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
231
310
|
|
|
232
|
-
const float
|
|
311
|
+
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
|
312
|
+
float sum = simd_sum(lsum);
|
|
313
|
+
if (tiisg == 0) {
|
|
314
|
+
buf[sgitg] = sum;
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
233
318
|
|
|
234
|
-
|
|
319
|
+
// broadcast, simd group number is ntg / 32
|
|
320
|
+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
|
321
|
+
if (tpitg < i) {
|
|
322
|
+
buf[tpitg] += buf[tpitg + i];
|
|
323
|
+
}
|
|
324
|
+
}
|
|
325
|
+
|
|
326
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
327
|
+
|
|
328
|
+
sum = buf[0];
|
|
329
|
+
|
|
330
|
+
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
235
331
|
pdst4[i00] /= sum;
|
|
236
332
|
}
|
|
237
333
|
}
|
|
@@ -251,7 +347,7 @@ kernel void kernel_diag_mask_inf(
|
|
|
251
347
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
|
252
348
|
} else {
|
|
253
349
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
|
254
|
-
|
|
350
|
+
}
|
|
255
351
|
}
|
|
256
352
|
|
|
257
353
|
kernel void kernel_diag_mask_inf_8(
|
|
@@ -399,8 +495,11 @@ kernel void kernel_rms_norm(
|
|
|
399
495
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
|
400
496
|
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
|
401
497
|
float d = qb_curr->d;
|
|
498
|
+
|
|
402
499
|
float2 acc = 0.f;
|
|
500
|
+
|
|
403
501
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
|
502
|
+
|
|
404
503
|
for (int i = 0; i < 8; i+=2) {
|
|
405
504
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
|
406
505
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
|
@@ -417,8 +516,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
|
|
|
417
516
|
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
|
418
517
|
float d = qb_curr->d;
|
|
419
518
|
float m = qb_curr->m;
|
|
420
|
-
|
|
519
|
+
|
|
421
520
|
float2 acc = 0.f;
|
|
521
|
+
|
|
522
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
|
523
|
+
|
|
422
524
|
for (int i = 0; i < 8; i+=2) {
|
|
423
525
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
|
424
526
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
|
@@ -428,6 +530,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|
|
428
530
|
return d * (acc[0] + acc[1]) + sumy * m;
|
|
429
531
|
}
|
|
430
532
|
|
|
533
|
+
// function to calculate the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
|
|
534
|
+
// il indicates where the q5 quants begin (0 or QK5_0/4)
|
|
535
|
+
// we assume that the yl's have been multiplied with the appropriate scale factor
|
|
536
|
+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
|
537
|
+
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
|
538
|
+
float d = qb_curr->d;
|
|
539
|
+
|
|
540
|
+
float2 acc = 0.f;
|
|
541
|
+
|
|
542
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
|
543
|
+
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
|
544
|
+
|
|
545
|
+
for (int i = 0; i < 8; i+=2) {
|
|
546
|
+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
|
547
|
+
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
|
548
|
+
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
|
549
|
+
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
|
550
|
+
}
|
|
551
|
+
return d * (sumy * -16.f + acc[0] + acc[1]);
|
|
552
|
+
}
|
|
553
|
+
|
|
554
|
+
// function to calculate the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
|
|
555
|
+
// il indicates where the q5 quants begin (0 or QK5_1/4)
|
|
556
|
+
// we assume that the yl's have been multiplied with the appropriate scale factor
|
|
557
|
+
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
|
558
|
+
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
|
|
559
|
+
float d = qb_curr->d;
|
|
560
|
+
float m = qb_curr->m;
|
|
561
|
+
|
|
562
|
+
float2 acc = 0.f;
|
|
563
|
+
|
|
564
|
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
|
|
565
|
+
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
|
566
|
+
|
|
567
|
+
for (int i = 0; i < 8; i+=2) {
|
|
568
|
+
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
|
569
|
+
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
|
570
|
+
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
|
571
|
+
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
|
572
|
+
}
|
|
573
|
+
return d * (acc[0] + acc[1]) + sumy * m;
|
|
574
|
+
}
|
|
575
|
+
|
|
431
576
|
// putting them in the kernel causes a significant performance penalty
|
|
432
577
|
#define N_DST 4 // each SIMD group works on 4 rows
|
|
433
578
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
|
@@ -525,6 +670,43 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
|
525
670
|
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
|
526
671
|
}
|
|
527
672
|
|
|
673
|
+
kernel void kernel_mul_mv_q5_0_f32(
|
|
674
|
+
device const void * src0,
|
|
675
|
+
device const float * src1,
|
|
676
|
+
device float * dst,
|
|
677
|
+
constant int64_t & ne00,
|
|
678
|
+
constant int64_t & ne01[[buffer(4)]],
|
|
679
|
+
constant int64_t & ne02[[buffer(5)]],
|
|
680
|
+
constant int64_t & ne10[[buffer(9)]],
|
|
681
|
+
constant int64_t & ne12[[buffer(11)]],
|
|
682
|
+
constant int64_t & ne0[[buffer(15)]],
|
|
683
|
+
constant int64_t & ne1[[buffer(16)]],
|
|
684
|
+
constant uint & gqa[[buffer(17)]],
|
|
685
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
686
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
687
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
688
|
+
mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
|
689
|
+
}
|
|
690
|
+
|
|
691
|
+
kernel void kernel_mul_mv_q5_1_f32(
|
|
692
|
+
device const void * src0,
|
|
693
|
+
device const float * src1,
|
|
694
|
+
device float * dst,
|
|
695
|
+
constant int64_t & ne00,
|
|
696
|
+
constant int64_t & ne01[[buffer(4)]],
|
|
697
|
+
constant int64_t & ne02[[buffer(5)]],
|
|
698
|
+
constant int64_t & ne10[[buffer(9)]],
|
|
699
|
+
constant int64_t & ne12[[buffer(11)]],
|
|
700
|
+
constant int64_t & ne0[[buffer(15)]],
|
|
701
|
+
constant int64_t & ne1[[buffer(16)]],
|
|
702
|
+
constant uint & gqa[[buffer(17)]],
|
|
703
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
704
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
705
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
706
|
+
mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
|
|
528
710
|
#define NB_Q8_0 8
|
|
529
711
|
|
|
530
712
|
kernel void kernel_mul_mv_q8_0_f32(
|
|
@@ -879,6 +1061,45 @@ kernel void kernel_alibi_f32(
|
|
|
879
1061
|
}
|
|
880
1062
|
}
|
|
881
1063
|
|
|
1064
|
+
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
1065
|
+
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
|
1066
|
+
return 1.0f - min(1.0f, max(0.0f, y));
|
|
1067
|
+
}
|
|
1068
|
+
|
|
1069
|
+
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
|
1070
|
+
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
|
1071
|
+
static void rope_yarn(
|
|
1072
|
+
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
|
1073
|
+
thread float * cos_theta, thread float * sin_theta
|
|
1074
|
+
) {
|
|
1075
|
+
// Get n-d rotational scaling corrected for extrapolation
|
|
1076
|
+
float theta_interp = freq_scale * theta_extrap;
|
|
1077
|
+
float theta = theta_interp;
|
|
1078
|
+
if (ext_factor != 0.0f) {
|
|
1079
|
+
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
|
1080
|
+
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
|
1081
|
+
|
|
1082
|
+
// Get n-d magnitude scaling corrected for interpolation
|
|
1083
|
+
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
|
|
1084
|
+
}
|
|
1085
|
+
*cos_theta = cos(theta) * mscale;
|
|
1086
|
+
*sin_theta = sin(theta) * mscale;
|
|
1087
|
+
}
|
|
1088
|
+
|
|
1089
|
+
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
|
1090
|
+
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
|
1091
|
+
static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
|
|
1092
|
+
return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
|
1093
|
+
}
|
|
1094
|
+
|
|
1095
|
+
static void rope_yarn_corr_dims(
|
|
1096
|
+
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
|
1097
|
+
) {
|
|
1098
|
+
// start and end correction dims
|
|
1099
|
+
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
|
|
1100
|
+
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
|
|
1101
|
+
}
|
|
1102
|
+
|
|
882
1103
|
typedef void (rope_t)(
|
|
883
1104
|
device const void * src0,
|
|
884
1105
|
device const int32_t * src1,
|
|
@@ -902,8 +1123,13 @@ typedef void (rope_t)(
|
|
|
902
1123
|
constant int & n_past,
|
|
903
1124
|
constant int & n_dims,
|
|
904
1125
|
constant int & mode,
|
|
1126
|
+
constant int & n_orig_ctx,
|
|
905
1127
|
constant float & freq_base,
|
|
906
1128
|
constant float & freq_scale,
|
|
1129
|
+
constant float & ext_factor,
|
|
1130
|
+
constant float & attn_factor,
|
|
1131
|
+
constant float & beta_fast,
|
|
1132
|
+
constant float & beta_slow,
|
|
907
1133
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
908
1134
|
uint3 tptg[[threads_per_threadgroup]],
|
|
909
1135
|
uint3 tgpig[[threadgroup_position_in_grid]]);
|
|
@@ -932,8 +1158,13 @@ kernel void kernel_rope(
|
|
|
932
1158
|
constant int & n_past,
|
|
933
1159
|
constant int & n_dims,
|
|
934
1160
|
constant int & mode,
|
|
1161
|
+
constant int & n_orig_ctx,
|
|
935
1162
|
constant float & freq_base,
|
|
936
1163
|
constant float & freq_scale,
|
|
1164
|
+
constant float & ext_factor,
|
|
1165
|
+
constant float & attn_factor,
|
|
1166
|
+
constant float & beta_fast,
|
|
1167
|
+
constant float & beta_slow,
|
|
937
1168
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
938
1169
|
uint3 tptg[[threads_per_threadgroup]],
|
|
939
1170
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
|
@@ -943,19 +1174,22 @@ kernel void kernel_rope(
|
|
|
943
1174
|
|
|
944
1175
|
const bool is_neox = mode & 2;
|
|
945
1176
|
|
|
1177
|
+
float corr_dims[2];
|
|
1178
|
+
rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
|
|
1179
|
+
|
|
946
1180
|
device const int32_t * pos = src1;
|
|
947
1181
|
|
|
948
1182
|
const int64_t p = pos[i2];
|
|
949
1183
|
|
|
950
|
-
const float theta_0 =
|
|
1184
|
+
const float theta_0 = (float)p;
|
|
951
1185
|
const float inv_ndims = -1.f/n_dims;
|
|
952
1186
|
|
|
953
1187
|
if (!is_neox) {
|
|
954
1188
|
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
|
955
1189
|
|
|
956
1190
|
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
|
957
|
-
|
|
958
|
-
|
|
1191
|
+
float cos_theta, sin_theta;
|
|
1192
|
+
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
959
1193
|
|
|
960
1194
|
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
961
1195
|
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
@@ -970,9 +1204,12 @@ kernel void kernel_rope(
|
|
|
970
1204
|
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
|
971
1205
|
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
|
972
1206
|
|
|
973
|
-
|
|
974
|
-
const float
|
|
975
|
-
|
|
1207
|
+
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
|
1208
|
+
const float cur_rot = inv_ndims*ic - ib;
|
|
1209
|
+
|
|
1210
|
+
const float theta = theta_0 * pow(freq_base, cur_rot);
|
|
1211
|
+
float cos_theta, sin_theta;
|
|
1212
|
+
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
976
1213
|
|
|
977
1214
|
const int64_t i0 = ib*n_dims + ic/2;
|
|
978
1215
|
|
|
@@ -2149,6 +2386,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|
|
2149
2386
|
}
|
|
2150
2387
|
}
|
|
2151
2388
|
|
|
2389
|
+
template <typename type4x4>
|
|
2390
|
+
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
|
|
2391
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
|
2392
|
+
const float d = xb->d;
|
|
2393
|
+
const float md = -16.h * xb->d;
|
|
2394
|
+
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
2395
|
+
|
|
2396
|
+
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
|
2397
|
+
|
|
2398
|
+
const int x_mv = il ? 4 : 0;
|
|
2399
|
+
|
|
2400
|
+
const int gh_mv = il ? 12 : 0;
|
|
2401
|
+
const int gh_bk = il ? 0 : 4;
|
|
2402
|
+
|
|
2403
|
+
for (int i = 0; i < 8; i++) {
|
|
2404
|
+
// extract the 5-th bits for x0 and x1
|
|
2405
|
+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
2406
|
+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
|
2407
|
+
|
|
2408
|
+
// combine the 4-bits from qs with the 5th bit
|
|
2409
|
+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
|
2410
|
+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
|
2411
|
+
|
|
2412
|
+
reg[i/2][2*(i%2)+0] = d * x0 + md;
|
|
2413
|
+
reg[i/2][2*(i%2)+1] = d * x1 + md;
|
|
2414
|
+
}
|
|
2415
|
+
}
|
|
2416
|
+
|
|
2417
|
+
template <typename type4x4>
|
|
2418
|
+
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
|
|
2419
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
|
2420
|
+
const float d = xb->d;
|
|
2421
|
+
const float m = xb->m;
|
|
2422
|
+
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
2423
|
+
|
|
2424
|
+
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
|
2425
|
+
|
|
2426
|
+
const int x_mv = il ? 4 : 0;
|
|
2427
|
+
|
|
2428
|
+
const int gh_mv = il ? 12 : 0;
|
|
2429
|
+
const int gh_bk = il ? 0 : 4;
|
|
2430
|
+
|
|
2431
|
+
for (int i = 0; i < 8; i++) {
|
|
2432
|
+
// extract the 5-th bits for x0 and x1
|
|
2433
|
+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
2434
|
+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
|
2435
|
+
|
|
2436
|
+
// combine the 4-bits from qs with the 5th bit
|
|
2437
|
+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
|
2438
|
+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
|
2439
|
+
|
|
2440
|
+
reg[i/2][2*(i%2)+0] = d * x0 + m;
|
|
2441
|
+
reg[i/2][2*(i%2)+1] = d * x1 + m;
|
|
2442
|
+
}
|
|
2443
|
+
}
|
|
2444
|
+
|
|
2152
2445
|
template <typename type4x4>
|
|
2153
2446
|
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
|
2154
2447
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
|
@@ -2490,6 +2783,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
|
|
|
2490
2783
|
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
|
2491
2784
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
|
2492
2785
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
|
2786
|
+
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
|
2787
|
+
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
|
|
2493
2788
|
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
|
2494
2789
|
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
2495
2790
|
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
@@ -2518,6 +2813,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
|
|
|
2518
2813
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
2519
2814
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
|
2520
2815
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
|
2816
|
+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
|
2817
|
+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
|
2521
2818
|
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
|
2522
2819
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
2523
2820
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
Binary file
|