node-llama-cpp 2.8.6 → 2.8.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/llama/addon.cpp +2 -2
- package/llama/binariesGithubRelease.json +1 -1
- package/llama/gitRelease.bundle +0 -0
- package/llamaBins/linux-arm64/llama-addon.node +0 -0
- package/llamaBins/linux-armv7l/llama-addon.node +0 -0
- package/llamaBins/linux-x64/llama-addon.node +0 -0
- package/llamaBins/mac-arm64/ggml-metal.metal +378 -6
- package/llamaBins/mac-arm64/llama-addon.node +0 -0
- package/llamaBins/mac-x64/ggml-metal.metal +378 -6
- package/llamaBins/mac-x64/llama-addon.node +0 -0
- package/llamaBins/win-x64/llama-addon.node +0 -0
- package/package.json +1 -1
package/llama/addon.cpp
CHANGED
|
@@ -39,7 +39,7 @@ class LLAMAModel : public Napi::ObjectWrap<LLAMAModel> {
|
|
|
39
39
|
}
|
|
40
40
|
}
|
|
41
41
|
|
|
42
|
-
llama_backend_init(
|
|
42
|
+
llama_backend_init();
|
|
43
43
|
model = llama_load_model_from_file(modelPath.c_str(), model_params);
|
|
44
44
|
|
|
45
45
|
if (model == NULL) {
|
|
@@ -436,7 +436,7 @@ Napi::Value LLAMAContext::Eval(const Napi::CallbackInfo& info) {
|
|
|
436
436
|
Napi::Value systemInfo(const Napi::CallbackInfo& info) { return Napi::String::From(info.Env(), llama_print_system_info()); }
|
|
437
437
|
|
|
438
438
|
Napi::Object registerCallback(Napi::Env env, Napi::Object exports) {
|
|
439
|
-
llama_backend_init(
|
|
439
|
+
llama_backend_init();
|
|
440
440
|
exports.DefineProperties({
|
|
441
441
|
Napi::PropertyDescriptor::Function("systemInfo", systemInfo),
|
|
442
442
|
});
|
package/llama/gitRelease.bundle
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
|
|
|
351
351
|
kernel void kernel_soft_max(
|
|
352
352
|
device const float * src0,
|
|
353
353
|
device const float * src1,
|
|
354
|
+
device const float * src2,
|
|
354
355
|
device float * dst,
|
|
355
356
|
constant int64_t & ne00,
|
|
356
357
|
constant int64_t & ne01,
|
|
357
358
|
constant int64_t & ne02,
|
|
358
359
|
constant float & scale,
|
|
359
|
-
|
|
360
|
+
constant float & max_bias,
|
|
361
|
+
constant float & m0,
|
|
362
|
+
constant float & m1,
|
|
363
|
+
constant uint32_t & n_head_log2,
|
|
364
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
360
365
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
361
366
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
362
367
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
|
|
|
368
373
|
|
|
369
374
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
370
375
|
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
|
376
|
+
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
|
371
377
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
372
378
|
|
|
379
|
+
float slope = 0.0f;
|
|
380
|
+
|
|
381
|
+
// ALiBi
|
|
382
|
+
if (max_bias > 0.0f) {
|
|
383
|
+
const int64_t h = i02;
|
|
384
|
+
|
|
385
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
|
386
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
387
|
+
|
|
388
|
+
slope = pow(base, exp);
|
|
389
|
+
}
|
|
390
|
+
|
|
373
391
|
// parallel max
|
|
374
392
|
float lmax = -INFINITY;
|
|
375
393
|
|
|
376
394
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
377
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
395
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
|
378
396
|
}
|
|
379
397
|
|
|
380
398
|
// find the max value in the block
|
|
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
|
|
|
399
417
|
// parallel sum
|
|
400
418
|
float lsum = 0.0f;
|
|
401
419
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
402
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
420
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
|
403
421
|
lsum += exp_psrc0;
|
|
404
422
|
pdst[i00] = exp_psrc0;
|
|
405
423
|
}
|
|
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
|
|
|
437
455
|
kernel void kernel_soft_max_4(
|
|
438
456
|
device const float * src0,
|
|
439
457
|
device const float * src1,
|
|
458
|
+
device const float * src2,
|
|
440
459
|
device float * dst,
|
|
441
460
|
constant int64_t & ne00,
|
|
442
461
|
constant int64_t & ne01,
|
|
443
462
|
constant int64_t & ne02,
|
|
444
463
|
constant float & scale,
|
|
445
|
-
|
|
464
|
+
constant float & max_bias,
|
|
465
|
+
constant float & m0,
|
|
466
|
+
constant float & m1,
|
|
467
|
+
constant uint32_t & n_head_log2,
|
|
468
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
446
469
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
447
470
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
448
471
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
|
|
|
454
477
|
|
|
455
478
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
456
479
|
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
480
|
+
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
|
457
481
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
458
482
|
|
|
483
|
+
float slope = 0.0f;
|
|
484
|
+
|
|
485
|
+
if (max_bias > 0.0f) {
|
|
486
|
+
const int64_t h = i02;
|
|
487
|
+
|
|
488
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
|
489
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
490
|
+
|
|
491
|
+
slope = pow(base, exp);
|
|
492
|
+
}
|
|
493
|
+
|
|
459
494
|
// parallel max
|
|
460
495
|
float4 lmax4 = -INFINITY;
|
|
461
496
|
|
|
462
497
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
463
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
498
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
|
464
499
|
}
|
|
465
500
|
|
|
466
501
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|
|
486
521
|
// parallel sum
|
|
487
522
|
float4 lsum4 = 0.0f;
|
|
488
523
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
489
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
524
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
|
490
525
|
lsum4 += exp_psrc4;
|
|
491
526
|
pdst4[i00] = exp_psrc4;
|
|
492
527
|
}
|
|
@@ -2490,6 +2525,13 @@ typedef struct {
|
|
|
2490
2525
|
} block_iq3_xxs;
|
|
2491
2526
|
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
|
2492
2527
|
|
|
2528
|
+
typedef struct {
|
|
2529
|
+
half d;
|
|
2530
|
+
uint8_t qs[QK_K/8];
|
|
2531
|
+
uint8_t scales[QK_K/16];
|
|
2532
|
+
} block_iq1_s;
|
|
2533
|
+
|
|
2534
|
+
|
|
2493
2535
|
//====================================== dot products =========================
|
|
2494
2536
|
|
|
2495
2537
|
void kernel_mul_mv_q2_K_f32_impl(
|
|
@@ -3747,6 +3789,137 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
|
3747
3789
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
|
3748
3790
|
};
|
|
3749
3791
|
|
|
3792
|
+
#define NGRID_IQ1S 512
|
|
3793
|
+
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
|
3794
|
+
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
|
3795
|
+
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
|
3796
|
+
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
|
3797
|
+
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
|
3798
|
+
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
|
3799
|
+
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
|
3800
|
+
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
|
3801
|
+
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
|
3802
|
+
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
|
3803
|
+
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
|
3804
|
+
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
|
3805
|
+
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
|
3806
|
+
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
|
3807
|
+
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
|
3808
|
+
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
|
3809
|
+
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
|
3810
|
+
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
|
3811
|
+
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
|
3812
|
+
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
|
3813
|
+
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
|
3814
|
+
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
|
3815
|
+
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
|
3816
|
+
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
|
3817
|
+
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
|
3818
|
+
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
|
3819
|
+
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
|
3820
|
+
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
|
3821
|
+
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
|
3822
|
+
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
|
3823
|
+
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
|
3824
|
+
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
|
3825
|
+
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
|
3826
|
+
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
|
3827
|
+
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
|
3828
|
+
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
|
3829
|
+
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
|
3830
|
+
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
|
3831
|
+
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
|
3832
|
+
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
|
3833
|
+
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
|
3834
|
+
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
|
3835
|
+
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
|
3836
|
+
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
|
3837
|
+
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
|
3838
|
+
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
|
3839
|
+
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
|
3840
|
+
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
|
3841
|
+
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
|
3842
|
+
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
|
3843
|
+
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
|
3844
|
+
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
|
3845
|
+
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
|
3846
|
+
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
|
3847
|
+
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
|
3848
|
+
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
|
3849
|
+
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
|
3850
|
+
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
|
3851
|
+
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
|
3852
|
+
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
|
3853
|
+
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
|
3854
|
+
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
|
3855
|
+
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
|
3856
|
+
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
|
3857
|
+
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
|
3858
|
+
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
|
3859
|
+
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
|
3860
|
+
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
|
3861
|
+
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
|
3862
|
+
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
|
3863
|
+
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
|
3864
|
+
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
|
3865
|
+
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
|
3866
|
+
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
|
3867
|
+
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
|
3868
|
+
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
|
3869
|
+
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
|
3870
|
+
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
|
3871
|
+
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
|
3872
|
+
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
|
3873
|
+
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
|
3874
|
+
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
|
3875
|
+
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
|
3876
|
+
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
|
3877
|
+
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
|
3878
|
+
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
|
3879
|
+
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
|
3880
|
+
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
|
3881
|
+
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
|
3882
|
+
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
|
3883
|
+
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
|
3884
|
+
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
|
3885
|
+
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
|
3886
|
+
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
|
3887
|
+
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
|
3888
|
+
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
|
3889
|
+
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
|
3890
|
+
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
|
3891
|
+
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
|
3892
|
+
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
|
3893
|
+
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
|
3894
|
+
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
|
3895
|
+
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
|
3896
|
+
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
|
3897
|
+
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
|
3898
|
+
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
|
3899
|
+
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
|
3900
|
+
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
|
3901
|
+
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
|
3902
|
+
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
|
3903
|
+
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
|
3904
|
+
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
|
3905
|
+
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
|
3906
|
+
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
|
3907
|
+
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
|
3908
|
+
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
|
3909
|
+
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
|
3910
|
+
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
|
3911
|
+
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
|
3912
|
+
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
|
3913
|
+
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
|
3914
|
+
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
|
3915
|
+
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
|
3916
|
+
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
|
3917
|
+
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
|
3918
|
+
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
|
3919
|
+
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
|
3920
|
+
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
|
3921
|
+
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
|
3922
|
+
};
|
|
3750
3923
|
|
|
3751
3924
|
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
|
3752
3925
|
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
|
@@ -4173,6 +4346,123 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
|
4173
4346
|
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
4174
4347
|
}
|
|
4175
4348
|
|
|
4349
|
+
void kernel_mul_mv_iq1_s_f32_impl(
|
|
4350
|
+
device const void * src0,
|
|
4351
|
+
device const float * src1,
|
|
4352
|
+
device float * dst,
|
|
4353
|
+
constant int64_t & ne00,
|
|
4354
|
+
constant int64_t & ne01,
|
|
4355
|
+
constant int64_t & ne02,
|
|
4356
|
+
constant int64_t & ne10,
|
|
4357
|
+
constant int64_t & ne12,
|
|
4358
|
+
constant int64_t & ne0,
|
|
4359
|
+
constant int64_t & ne1,
|
|
4360
|
+
constant uint & r2,
|
|
4361
|
+
constant uint & r3,
|
|
4362
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4363
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4364
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4365
|
+
|
|
4366
|
+
const int nb = ne00/QK_K;
|
|
4367
|
+
const int r0 = tgpig.x;
|
|
4368
|
+
const int r1 = tgpig.y;
|
|
4369
|
+
const int im = tgpig.z;
|
|
4370
|
+
|
|
4371
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
4372
|
+
const int ib_row = first_row * nb;
|
|
4373
|
+
|
|
4374
|
+
const uint i12 = im%ne12;
|
|
4375
|
+
const uint i13 = im/ne12;
|
|
4376
|
+
|
|
4377
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
4378
|
+
|
|
4379
|
+
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
|
4380
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4381
|
+
|
|
4382
|
+
float yl[16];
|
|
4383
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
4384
|
+
|
|
4385
|
+
const int nb32 = nb * (QK_K / 32);
|
|
4386
|
+
|
|
4387
|
+
#if QK_K == 256
|
|
4388
|
+
const int ix = tiisg/2;
|
|
4389
|
+
const int il = tiisg%2;
|
|
4390
|
+
|
|
4391
|
+
device const float * y4 = y + 32 * ix + 16 * il;
|
|
4392
|
+
|
|
4393
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
|
4394
|
+
|
|
4395
|
+
for (int i = 0; i < 16; ++i) {
|
|
4396
|
+
yl[i] = y4[i];
|
|
4397
|
+
}
|
|
4398
|
+
|
|
4399
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
4400
|
+
const int ib = ib32 % (QK_K / 32);
|
|
4401
|
+
|
|
4402
|
+
device const block_iq1_s * xr = x + ibl;
|
|
4403
|
+
device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
|
|
4404
|
+
device const uint8_t * sc = xr->scales + 2 * ib + il;
|
|
4405
|
+
device const half * dh = &xr->d;
|
|
4406
|
+
|
|
4407
|
+
for (int row = 0; row < N_DST; row++) {
|
|
4408
|
+
|
|
4409
|
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
|
4410
|
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
|
4411
|
+
|
|
4412
|
+
float2 sum = {0};
|
|
4413
|
+
for (int j = 0; j < 8; ++j) {
|
|
4414
|
+
sum[0] += yl[j+ 0] * grid1[j];
|
|
4415
|
+
sum[1] += yl[j+ 8] * grid2[j];
|
|
4416
|
+
}
|
|
4417
|
+
sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
|
|
4418
|
+
|
|
4419
|
+
dh += nb*sizeof(block_iq1_s)/2;
|
|
4420
|
+
qs += nb*sizeof(block_iq1_s);
|
|
4421
|
+
sc += nb*sizeof(block_iq1_s);
|
|
4422
|
+
}
|
|
4423
|
+
|
|
4424
|
+
y4 += 16 * 32;
|
|
4425
|
+
}
|
|
4426
|
+
#else
|
|
4427
|
+
// TODO
|
|
4428
|
+
#endif
|
|
4429
|
+
|
|
4430
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
4431
|
+
all_sum = simd_sum(sumf[row]);
|
|
4432
|
+
if (tiisg == 0) {
|
|
4433
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
4434
|
+
}
|
|
4435
|
+
}
|
|
4436
|
+
}
|
|
4437
|
+
|
|
4438
|
+
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
|
4439
|
+
kernel void kernel_mul_mv_iq1_s_f32(
|
|
4440
|
+
device const void * src0,
|
|
4441
|
+
device const float * src1,
|
|
4442
|
+
device float * dst,
|
|
4443
|
+
constant int64_t & ne00,
|
|
4444
|
+
constant int64_t & ne01,
|
|
4445
|
+
constant int64_t & ne02,
|
|
4446
|
+
constant uint64_t & nb00,
|
|
4447
|
+
constant uint64_t & nb01,
|
|
4448
|
+
constant uint64_t & nb02,
|
|
4449
|
+
constant int64_t & ne10,
|
|
4450
|
+
constant int64_t & ne11,
|
|
4451
|
+
constant int64_t & ne12,
|
|
4452
|
+
constant uint64_t & nb10,
|
|
4453
|
+
constant uint64_t & nb11,
|
|
4454
|
+
constant uint64_t & nb12,
|
|
4455
|
+
constant int64_t & ne0,
|
|
4456
|
+
constant int64_t & ne1,
|
|
4457
|
+
constant uint & r2,
|
|
4458
|
+
constant uint & r3,
|
|
4459
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4460
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4461
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4462
|
+
|
|
4463
|
+
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
4464
|
+
}
|
|
4465
|
+
|
|
4176
4466
|
|
|
4177
4467
|
//============================= templates and their specializations =============================
|
|
4178
4468
|
|
|
@@ -4518,6 +4808,22 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
|
|
|
4518
4808
|
}
|
|
4519
4809
|
}
|
|
4520
4810
|
|
|
4811
|
+
template <typename type4x4>
|
|
4812
|
+
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
|
4813
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
4814
|
+
const float d = xb->d;
|
|
4815
|
+
device const uint8_t * qs = xb->qs + 2*il;
|
|
4816
|
+
device const uint8_t * sc = xb->scales + il;
|
|
4817
|
+
const float dl1 = d * (2*(sc[0] & 7) + 1);
|
|
4818
|
+
const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
|
|
4819
|
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
|
4820
|
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
|
4821
|
+
for (int i = 0; i < 8; ++i) {
|
|
4822
|
+
reg[i/4+0][i%4] = dl1 * grid1[i];
|
|
4823
|
+
reg[i/4+2][i%4] = dl2 * grid2[i];
|
|
4824
|
+
}
|
|
4825
|
+
}
|
|
4826
|
+
|
|
4521
4827
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
4522
4828
|
kernel void kernel_get_rows(
|
|
4523
4829
|
device const void * src0,
|
|
@@ -5060,6 +5366,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
|
|
|
5060
5366
|
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5061
5367
|
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5062
5368
|
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5369
|
+
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5063
5370
|
|
|
5064
5371
|
//
|
|
5065
5372
|
// matrix-matrix multiplication
|
|
@@ -5099,6 +5406,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
5099
5406
|
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5100
5407
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5101
5408
|
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5409
|
+
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5102
5410
|
|
|
5103
5411
|
//
|
|
5104
5412
|
// indirect matrix-matrix multiplication
|
|
@@ -5150,6 +5458,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
5150
5458
|
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5151
5459
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5152
5460
|
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5461
|
+
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5153
5462
|
|
|
5154
5463
|
//
|
|
5155
5464
|
// matrix-vector multiplication
|
|
@@ -6117,3 +6426,66 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
|
6117
6426
|
tiisg,
|
|
6118
6427
|
sgitg);
|
|
6119
6428
|
}
|
|
6429
|
+
|
|
6430
|
+
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
|
6431
|
+
kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
6432
|
+
device const char * ids,
|
|
6433
|
+
device const char * src1,
|
|
6434
|
+
device float * dst,
|
|
6435
|
+
constant uint64_t & nbi1,
|
|
6436
|
+
constant int64_t & ne00,
|
|
6437
|
+
constant int64_t & ne01,
|
|
6438
|
+
constant int64_t & ne02,
|
|
6439
|
+
constant uint64_t & nb00,
|
|
6440
|
+
constant uint64_t & nb01,
|
|
6441
|
+
constant uint64_t & nb02,
|
|
6442
|
+
constant int64_t & ne10,
|
|
6443
|
+
constant int64_t & ne11,
|
|
6444
|
+
constant int64_t & ne12,
|
|
6445
|
+
constant int64_t & ne13,
|
|
6446
|
+
constant uint64_t & nb10,
|
|
6447
|
+
constant uint64_t & nb11,
|
|
6448
|
+
constant uint64_t & nb12,
|
|
6449
|
+
constant int64_t & ne0,
|
|
6450
|
+
constant int64_t & ne1,
|
|
6451
|
+
constant uint64_t & nb1,
|
|
6452
|
+
constant uint & r2,
|
|
6453
|
+
constant uint & r3,
|
|
6454
|
+
constant int & idx,
|
|
6455
|
+
device const char * src00,
|
|
6456
|
+
device const char * src01,
|
|
6457
|
+
device const char * src02,
|
|
6458
|
+
device const char * src03,
|
|
6459
|
+
device const char * src04,
|
|
6460
|
+
device const char * src05,
|
|
6461
|
+
device const char * src06,
|
|
6462
|
+
device const char * src07,
|
|
6463
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6464
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6465
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
6466
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6467
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
6468
|
+
|
|
6469
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
6470
|
+
|
|
6471
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
6472
|
+
|
|
6473
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
6474
|
+
|
|
6475
|
+
kernel_mul_mv_iq1_s_f32_impl(
|
|
6476
|
+
src0[id],
|
|
6477
|
+
(device const float *) (src1 + bid*nb11),
|
|
6478
|
+
dst + bid*ne0,
|
|
6479
|
+
ne00,
|
|
6480
|
+
ne01,
|
|
6481
|
+
ne02,
|
|
6482
|
+
ne10,
|
|
6483
|
+
ne12,
|
|
6484
|
+
ne0,
|
|
6485
|
+
ne1,
|
|
6486
|
+
r2,
|
|
6487
|
+
r3,
|
|
6488
|
+
tgpig,
|
|
6489
|
+
tiisg,
|
|
6490
|
+
sgitg);
|
|
6491
|
+
}
|
|
Binary file
|
|
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
|
|
|
351
351
|
kernel void kernel_soft_max(
|
|
352
352
|
device const float * src0,
|
|
353
353
|
device const float * src1,
|
|
354
|
+
device const float * src2,
|
|
354
355
|
device float * dst,
|
|
355
356
|
constant int64_t & ne00,
|
|
356
357
|
constant int64_t & ne01,
|
|
357
358
|
constant int64_t & ne02,
|
|
358
359
|
constant float & scale,
|
|
359
|
-
|
|
360
|
+
constant float & max_bias,
|
|
361
|
+
constant float & m0,
|
|
362
|
+
constant float & m1,
|
|
363
|
+
constant uint32_t & n_head_log2,
|
|
364
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
360
365
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
361
366
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
362
367
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
|
|
|
368
373
|
|
|
369
374
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
370
375
|
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
|
376
|
+
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
|
371
377
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
372
378
|
|
|
379
|
+
float slope = 0.0f;
|
|
380
|
+
|
|
381
|
+
// ALiBi
|
|
382
|
+
if (max_bias > 0.0f) {
|
|
383
|
+
const int64_t h = i02;
|
|
384
|
+
|
|
385
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
|
386
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
387
|
+
|
|
388
|
+
slope = pow(base, exp);
|
|
389
|
+
}
|
|
390
|
+
|
|
373
391
|
// parallel max
|
|
374
392
|
float lmax = -INFINITY;
|
|
375
393
|
|
|
376
394
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
377
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
395
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
|
378
396
|
}
|
|
379
397
|
|
|
380
398
|
// find the max value in the block
|
|
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
|
|
|
399
417
|
// parallel sum
|
|
400
418
|
float lsum = 0.0f;
|
|
401
419
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
402
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
420
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
|
403
421
|
lsum += exp_psrc0;
|
|
404
422
|
pdst[i00] = exp_psrc0;
|
|
405
423
|
}
|
|
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
|
|
|
437
455
|
kernel void kernel_soft_max_4(
|
|
438
456
|
device const float * src0,
|
|
439
457
|
device const float * src1,
|
|
458
|
+
device const float * src2,
|
|
440
459
|
device float * dst,
|
|
441
460
|
constant int64_t & ne00,
|
|
442
461
|
constant int64_t & ne01,
|
|
443
462
|
constant int64_t & ne02,
|
|
444
463
|
constant float & scale,
|
|
445
|
-
|
|
464
|
+
constant float & max_bias,
|
|
465
|
+
constant float & m0,
|
|
466
|
+
constant float & m1,
|
|
467
|
+
constant uint32_t & n_head_log2,
|
|
468
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
446
469
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
447
470
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
448
471
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
|
|
|
454
477
|
|
|
455
478
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
456
479
|
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
480
|
+
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
|
457
481
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
458
482
|
|
|
483
|
+
float slope = 0.0f;
|
|
484
|
+
|
|
485
|
+
if (max_bias > 0.0f) {
|
|
486
|
+
const int64_t h = i02;
|
|
487
|
+
|
|
488
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
|
489
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
490
|
+
|
|
491
|
+
slope = pow(base, exp);
|
|
492
|
+
}
|
|
493
|
+
|
|
459
494
|
// parallel max
|
|
460
495
|
float4 lmax4 = -INFINITY;
|
|
461
496
|
|
|
462
497
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
463
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
498
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
|
464
499
|
}
|
|
465
500
|
|
|
466
501
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|
|
486
521
|
// parallel sum
|
|
487
522
|
float4 lsum4 = 0.0f;
|
|
488
523
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
489
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
524
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
|
490
525
|
lsum4 += exp_psrc4;
|
|
491
526
|
pdst4[i00] = exp_psrc4;
|
|
492
527
|
}
|
|
@@ -2490,6 +2525,13 @@ typedef struct {
|
|
|
2490
2525
|
} block_iq3_xxs;
|
|
2491
2526
|
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
|
2492
2527
|
|
|
2528
|
+
typedef struct {
|
|
2529
|
+
half d;
|
|
2530
|
+
uint8_t qs[QK_K/8];
|
|
2531
|
+
uint8_t scales[QK_K/16];
|
|
2532
|
+
} block_iq1_s;
|
|
2533
|
+
|
|
2534
|
+
|
|
2493
2535
|
//====================================== dot products =========================
|
|
2494
2536
|
|
|
2495
2537
|
void kernel_mul_mv_q2_K_f32_impl(
|
|
@@ -3747,6 +3789,137 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
|
3747
3789
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
|
3748
3790
|
};
|
|
3749
3791
|
|
|
3792
|
+
#define NGRID_IQ1S 512
|
|
3793
|
+
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
|
3794
|
+
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
|
3795
|
+
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
|
3796
|
+
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
|
3797
|
+
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
|
3798
|
+
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
|
3799
|
+
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
|
3800
|
+
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
|
3801
|
+
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
|
3802
|
+
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
|
3803
|
+
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
|
3804
|
+
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
|
3805
|
+
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
|
3806
|
+
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
|
3807
|
+
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
|
3808
|
+
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
|
3809
|
+
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
|
3810
|
+
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
|
3811
|
+
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
|
3812
|
+
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
|
3813
|
+
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
|
3814
|
+
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
|
3815
|
+
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
|
3816
|
+
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
|
3817
|
+
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
|
3818
|
+
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
|
3819
|
+
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
|
3820
|
+
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
|
3821
|
+
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
|
3822
|
+
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
|
3823
|
+
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
|
3824
|
+
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
|
3825
|
+
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
|
3826
|
+
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
|
3827
|
+
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
|
3828
|
+
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
|
3829
|
+
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
|
3830
|
+
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
|
3831
|
+
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
|
3832
|
+
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
|
3833
|
+
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
|
3834
|
+
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
|
3835
|
+
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
|
3836
|
+
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
|
3837
|
+
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
|
3838
|
+
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
|
3839
|
+
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
|
3840
|
+
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
|
3841
|
+
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
|
3842
|
+
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
|
3843
|
+
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
|
3844
|
+
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
|
3845
|
+
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
|
3846
|
+
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
|
3847
|
+
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
|
3848
|
+
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
|
3849
|
+
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
|
3850
|
+
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
|
3851
|
+
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
|
3852
|
+
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
|
3853
|
+
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
|
3854
|
+
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
|
3855
|
+
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
|
3856
|
+
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
|
3857
|
+
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
|
3858
|
+
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
|
3859
|
+
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
|
3860
|
+
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
|
3861
|
+
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
|
3862
|
+
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
|
3863
|
+
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
|
3864
|
+
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
|
3865
|
+
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
|
3866
|
+
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
|
3867
|
+
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
|
3868
|
+
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
|
3869
|
+
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
|
3870
|
+
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
|
3871
|
+
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
|
3872
|
+
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
|
3873
|
+
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
|
3874
|
+
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
|
3875
|
+
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
|
3876
|
+
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
|
3877
|
+
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
|
3878
|
+
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
|
3879
|
+
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
|
3880
|
+
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
|
3881
|
+
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
|
3882
|
+
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
|
3883
|
+
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
|
3884
|
+
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
|
3885
|
+
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
|
3886
|
+
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
|
3887
|
+
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
|
3888
|
+
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
|
3889
|
+
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
|
3890
|
+
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
|
3891
|
+
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
|
3892
|
+
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
|
3893
|
+
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
|
3894
|
+
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
|
3895
|
+
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
|
3896
|
+
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
|
3897
|
+
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
|
3898
|
+
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
|
3899
|
+
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
|
3900
|
+
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
|
3901
|
+
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
|
3902
|
+
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
|
3903
|
+
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
|
3904
|
+
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
|
3905
|
+
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
|
3906
|
+
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
|
3907
|
+
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
|
3908
|
+
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
|
3909
|
+
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
|
3910
|
+
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
|
3911
|
+
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
|
3912
|
+
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
|
3913
|
+
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
|
3914
|
+
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
|
3915
|
+
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
|
3916
|
+
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
|
3917
|
+
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
|
3918
|
+
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
|
3919
|
+
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
|
3920
|
+
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
|
3921
|
+
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
|
3922
|
+
};
|
|
3750
3923
|
|
|
3751
3924
|
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
|
3752
3925
|
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
|
@@ -4173,6 +4346,123 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
|
4173
4346
|
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
4174
4347
|
}
|
|
4175
4348
|
|
|
4349
|
+
void kernel_mul_mv_iq1_s_f32_impl(
|
|
4350
|
+
device const void * src0,
|
|
4351
|
+
device const float * src1,
|
|
4352
|
+
device float * dst,
|
|
4353
|
+
constant int64_t & ne00,
|
|
4354
|
+
constant int64_t & ne01,
|
|
4355
|
+
constant int64_t & ne02,
|
|
4356
|
+
constant int64_t & ne10,
|
|
4357
|
+
constant int64_t & ne12,
|
|
4358
|
+
constant int64_t & ne0,
|
|
4359
|
+
constant int64_t & ne1,
|
|
4360
|
+
constant uint & r2,
|
|
4361
|
+
constant uint & r3,
|
|
4362
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4363
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4364
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4365
|
+
|
|
4366
|
+
const int nb = ne00/QK_K;
|
|
4367
|
+
const int r0 = tgpig.x;
|
|
4368
|
+
const int r1 = tgpig.y;
|
|
4369
|
+
const int im = tgpig.z;
|
|
4370
|
+
|
|
4371
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
4372
|
+
const int ib_row = first_row * nb;
|
|
4373
|
+
|
|
4374
|
+
const uint i12 = im%ne12;
|
|
4375
|
+
const uint i13 = im/ne12;
|
|
4376
|
+
|
|
4377
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
4378
|
+
|
|
4379
|
+
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
|
4380
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4381
|
+
|
|
4382
|
+
float yl[16];
|
|
4383
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
4384
|
+
|
|
4385
|
+
const int nb32 = nb * (QK_K / 32);
|
|
4386
|
+
|
|
4387
|
+
#if QK_K == 256
|
|
4388
|
+
const int ix = tiisg/2;
|
|
4389
|
+
const int il = tiisg%2;
|
|
4390
|
+
|
|
4391
|
+
device const float * y4 = y + 32 * ix + 16 * il;
|
|
4392
|
+
|
|
4393
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
|
4394
|
+
|
|
4395
|
+
for (int i = 0; i < 16; ++i) {
|
|
4396
|
+
yl[i] = y4[i];
|
|
4397
|
+
}
|
|
4398
|
+
|
|
4399
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
4400
|
+
const int ib = ib32 % (QK_K / 32);
|
|
4401
|
+
|
|
4402
|
+
device const block_iq1_s * xr = x + ibl;
|
|
4403
|
+
device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
|
|
4404
|
+
device const uint8_t * sc = xr->scales + 2 * ib + il;
|
|
4405
|
+
device const half * dh = &xr->d;
|
|
4406
|
+
|
|
4407
|
+
for (int row = 0; row < N_DST; row++) {
|
|
4408
|
+
|
|
4409
|
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
|
4410
|
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
|
4411
|
+
|
|
4412
|
+
float2 sum = {0};
|
|
4413
|
+
for (int j = 0; j < 8; ++j) {
|
|
4414
|
+
sum[0] += yl[j+ 0] * grid1[j];
|
|
4415
|
+
sum[1] += yl[j+ 8] * grid2[j];
|
|
4416
|
+
}
|
|
4417
|
+
sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
|
|
4418
|
+
|
|
4419
|
+
dh += nb*sizeof(block_iq1_s)/2;
|
|
4420
|
+
qs += nb*sizeof(block_iq1_s);
|
|
4421
|
+
sc += nb*sizeof(block_iq1_s);
|
|
4422
|
+
}
|
|
4423
|
+
|
|
4424
|
+
y4 += 16 * 32;
|
|
4425
|
+
}
|
|
4426
|
+
#else
|
|
4427
|
+
// TODO
|
|
4428
|
+
#endif
|
|
4429
|
+
|
|
4430
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
4431
|
+
all_sum = simd_sum(sumf[row]);
|
|
4432
|
+
if (tiisg == 0) {
|
|
4433
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
4434
|
+
}
|
|
4435
|
+
}
|
|
4436
|
+
}
|
|
4437
|
+
|
|
4438
|
+
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
|
4439
|
+
kernel void kernel_mul_mv_iq1_s_f32(
|
|
4440
|
+
device const void * src0,
|
|
4441
|
+
device const float * src1,
|
|
4442
|
+
device float * dst,
|
|
4443
|
+
constant int64_t & ne00,
|
|
4444
|
+
constant int64_t & ne01,
|
|
4445
|
+
constant int64_t & ne02,
|
|
4446
|
+
constant uint64_t & nb00,
|
|
4447
|
+
constant uint64_t & nb01,
|
|
4448
|
+
constant uint64_t & nb02,
|
|
4449
|
+
constant int64_t & ne10,
|
|
4450
|
+
constant int64_t & ne11,
|
|
4451
|
+
constant int64_t & ne12,
|
|
4452
|
+
constant uint64_t & nb10,
|
|
4453
|
+
constant uint64_t & nb11,
|
|
4454
|
+
constant uint64_t & nb12,
|
|
4455
|
+
constant int64_t & ne0,
|
|
4456
|
+
constant int64_t & ne1,
|
|
4457
|
+
constant uint & r2,
|
|
4458
|
+
constant uint & r3,
|
|
4459
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4460
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4461
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4462
|
+
|
|
4463
|
+
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
4464
|
+
}
|
|
4465
|
+
|
|
4176
4466
|
|
|
4177
4467
|
//============================= templates and their specializations =============================
|
|
4178
4468
|
|
|
@@ -4518,6 +4808,22 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
|
|
|
4518
4808
|
}
|
|
4519
4809
|
}
|
|
4520
4810
|
|
|
4811
|
+
template <typename type4x4>
|
|
4812
|
+
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
|
4813
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
4814
|
+
const float d = xb->d;
|
|
4815
|
+
device const uint8_t * qs = xb->qs + 2*il;
|
|
4816
|
+
device const uint8_t * sc = xb->scales + il;
|
|
4817
|
+
const float dl1 = d * (2*(sc[0] & 7) + 1);
|
|
4818
|
+
const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
|
|
4819
|
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
|
4820
|
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
|
4821
|
+
for (int i = 0; i < 8; ++i) {
|
|
4822
|
+
reg[i/4+0][i%4] = dl1 * grid1[i];
|
|
4823
|
+
reg[i/4+2][i%4] = dl2 * grid2[i];
|
|
4824
|
+
}
|
|
4825
|
+
}
|
|
4826
|
+
|
|
4521
4827
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
4522
4828
|
kernel void kernel_get_rows(
|
|
4523
4829
|
device const void * src0,
|
|
@@ -5060,6 +5366,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
|
|
|
5060
5366
|
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5061
5367
|
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5062
5368
|
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5369
|
+
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5063
5370
|
|
|
5064
5371
|
//
|
|
5065
5372
|
// matrix-matrix multiplication
|
|
@@ -5099,6 +5406,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
5099
5406
|
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5100
5407
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5101
5408
|
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5409
|
+
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5102
5410
|
|
|
5103
5411
|
//
|
|
5104
5412
|
// indirect matrix-matrix multiplication
|
|
@@ -5150,6 +5458,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
5150
5458
|
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5151
5459
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5152
5460
|
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5461
|
+
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
5153
5462
|
|
|
5154
5463
|
//
|
|
5155
5464
|
// matrix-vector multiplication
|
|
@@ -6117,3 +6426,66 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
|
6117
6426
|
tiisg,
|
|
6118
6427
|
sgitg);
|
|
6119
6428
|
}
|
|
6429
|
+
|
|
6430
|
+
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
|
6431
|
+
kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
6432
|
+
device const char * ids,
|
|
6433
|
+
device const char * src1,
|
|
6434
|
+
device float * dst,
|
|
6435
|
+
constant uint64_t & nbi1,
|
|
6436
|
+
constant int64_t & ne00,
|
|
6437
|
+
constant int64_t & ne01,
|
|
6438
|
+
constant int64_t & ne02,
|
|
6439
|
+
constant uint64_t & nb00,
|
|
6440
|
+
constant uint64_t & nb01,
|
|
6441
|
+
constant uint64_t & nb02,
|
|
6442
|
+
constant int64_t & ne10,
|
|
6443
|
+
constant int64_t & ne11,
|
|
6444
|
+
constant int64_t & ne12,
|
|
6445
|
+
constant int64_t & ne13,
|
|
6446
|
+
constant uint64_t & nb10,
|
|
6447
|
+
constant uint64_t & nb11,
|
|
6448
|
+
constant uint64_t & nb12,
|
|
6449
|
+
constant int64_t & ne0,
|
|
6450
|
+
constant int64_t & ne1,
|
|
6451
|
+
constant uint64_t & nb1,
|
|
6452
|
+
constant uint & r2,
|
|
6453
|
+
constant uint & r3,
|
|
6454
|
+
constant int & idx,
|
|
6455
|
+
device const char * src00,
|
|
6456
|
+
device const char * src01,
|
|
6457
|
+
device const char * src02,
|
|
6458
|
+
device const char * src03,
|
|
6459
|
+
device const char * src04,
|
|
6460
|
+
device const char * src05,
|
|
6461
|
+
device const char * src06,
|
|
6462
|
+
device const char * src07,
|
|
6463
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6464
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6465
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
6466
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6467
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
6468
|
+
|
|
6469
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
6470
|
+
|
|
6471
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
6472
|
+
|
|
6473
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
6474
|
+
|
|
6475
|
+
kernel_mul_mv_iq1_s_f32_impl(
|
|
6476
|
+
src0[id],
|
|
6477
|
+
(device const float *) (src1 + bid*nb11),
|
|
6478
|
+
dst + bid*ne0,
|
|
6479
|
+
ne00,
|
|
6480
|
+
ne01,
|
|
6481
|
+
ne02,
|
|
6482
|
+
ne10,
|
|
6483
|
+
ne12,
|
|
6484
|
+
ne0,
|
|
6485
|
+
ne1,
|
|
6486
|
+
r2,
|
|
6487
|
+
r3,
|
|
6488
|
+
tgpig,
|
|
6489
|
+
tiisg,
|
|
6490
|
+
sgitg);
|
|
6491
|
+
}
|
|
Binary file
|
|
Binary file
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "node-llama-cpp",
|
|
3
|
-
"version": "2.8.
|
|
3
|
+
"version": "2.8.7",
|
|
4
4
|
"description": "Run AI models locally on your machine with node.js bindings for llama.cpp. Force a JSON schema on the model output on the generation level",
|
|
5
5
|
"main": "dist/index.js",
|
|
6
6
|
"type": "module",
|