llama_cpp 0.12.6 → 0.12.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/ext/llama_cpp/llama_cpp.cpp +21 -10
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +8 -1
- data/vendor/tmp/llama.cpp/Makefile +43 -12
- data/vendor/tmp/llama.cpp/ggml-alloc.c +73 -43
- data/vendor/tmp/llama.cpp/ggml-backend.c +18 -9
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +560 -346
- data/vendor/tmp/llama.cpp/ggml-impl.h +20 -7
- data/vendor/tmp/llama.cpp/ggml-metal.m +99 -11
- data/vendor/tmp/llama.cpp/ggml-metal.metal +608 -9
- data/vendor/tmp/llama.cpp/ggml-quants.c +908 -54
- data/vendor/tmp/llama.cpp/ggml-quants.h +25 -2
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +81 -203
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +124 -52
- data/vendor/tmp/llama.cpp/ggml.c +948 -504
- data/vendor/tmp/llama.cpp/ggml.h +24 -11
- data/vendor/tmp/llama.cpp/llama.cpp +688 -163
- data/vendor/tmp/llama.cpp/llama.h +37 -1
- data/vendor/tmp/llama.cpp/scripts/get-flags.mk +1 -1
- metadata +2 -2
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
|
|
351
351
|
kernel void kernel_soft_max(
|
352
352
|
device const float * src0,
|
353
353
|
device const float * src1,
|
354
|
+
device const float * src2,
|
354
355
|
device float * dst,
|
355
356
|
constant int64_t & ne00,
|
356
357
|
constant int64_t & ne01,
|
357
358
|
constant int64_t & ne02,
|
358
359
|
constant float & scale,
|
359
|
-
|
360
|
+
constant float & max_bias,
|
361
|
+
constant float & m0,
|
362
|
+
constant float & m1,
|
363
|
+
constant uint32_t & n_head_log2,
|
364
|
+
threadgroup float * buf [[threadgroup(0)]],
|
360
365
|
uint tgpig[[threadgroup_position_in_grid]],
|
361
366
|
uint tpitg[[thread_position_in_threadgroup]],
|
362
367
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
|
|
368
373
|
|
369
374
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
370
375
|
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
376
|
+
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
371
377
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
372
378
|
|
379
|
+
float slope = 0.0f;
|
380
|
+
|
381
|
+
// ALiBi
|
382
|
+
if (max_bias > 0.0f) {
|
383
|
+
const int64_t h = i02;
|
384
|
+
|
385
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
386
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
387
|
+
|
388
|
+
slope = pow(base, exp);
|
389
|
+
}
|
390
|
+
|
373
391
|
// parallel max
|
374
392
|
float lmax = -INFINITY;
|
375
393
|
|
376
394
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
377
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
395
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
378
396
|
}
|
379
397
|
|
380
398
|
// find the max value in the block
|
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
|
|
399
417
|
// parallel sum
|
400
418
|
float lsum = 0.0f;
|
401
419
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
402
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
420
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
403
421
|
lsum += exp_psrc0;
|
404
422
|
pdst[i00] = exp_psrc0;
|
405
423
|
}
|
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
|
|
437
455
|
kernel void kernel_soft_max_4(
|
438
456
|
device const float * src0,
|
439
457
|
device const float * src1,
|
458
|
+
device const float * src2,
|
440
459
|
device float * dst,
|
441
460
|
constant int64_t & ne00,
|
442
461
|
constant int64_t & ne01,
|
443
462
|
constant int64_t & ne02,
|
444
463
|
constant float & scale,
|
445
|
-
|
464
|
+
constant float & max_bias,
|
465
|
+
constant float & m0,
|
466
|
+
constant float & m1,
|
467
|
+
constant uint32_t & n_head_log2,
|
468
|
+
threadgroup float * buf [[threadgroup(0)]],
|
446
469
|
uint tgpig[[threadgroup_position_in_grid]],
|
447
470
|
uint tpitg[[thread_position_in_threadgroup]],
|
448
471
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
|
|
454
477
|
|
455
478
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
456
479
|
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
480
|
+
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
457
481
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
458
482
|
|
483
|
+
float slope = 0.0f;
|
484
|
+
|
485
|
+
if (max_bias > 0.0f) {
|
486
|
+
const int64_t h = i02;
|
487
|
+
|
488
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
489
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
490
|
+
|
491
|
+
slope = pow(base, exp);
|
492
|
+
}
|
493
|
+
|
459
494
|
// parallel max
|
460
495
|
float4 lmax4 = -INFINITY;
|
461
496
|
|
462
497
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
463
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
498
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
464
499
|
}
|
465
500
|
|
466
501
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|
486
521
|
// parallel sum
|
487
522
|
float4 lsum4 = 0.0f;
|
488
523
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
489
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
524
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
490
525
|
lsum4 += exp_psrc4;
|
491
526
|
pdst4[i00] = exp_psrc4;
|
492
527
|
}
|
@@ -2490,6 +2525,19 @@ typedef struct {
|
|
2490
2525
|
} block_iq3_xxs;
|
2491
2526
|
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
2492
2527
|
|
2528
|
+
typedef struct {
|
2529
|
+
half d;
|
2530
|
+
uint8_t qs[QK_K/8];
|
2531
|
+
uint8_t scales[QK_K/16];
|
2532
|
+
} block_iq1_s;
|
2533
|
+
|
2534
|
+
// Non-linear quants
|
2535
|
+
#define QK4_NL 32
|
2536
|
+
typedef struct {
|
2537
|
+
half d;
|
2538
|
+
uint8_t qs[QK4_NL/2];
|
2539
|
+
} block_iq4_nl;
|
2540
|
+
|
2493
2541
|
//====================================== dot products =========================
|
2494
2542
|
|
2495
2543
|
void kernel_mul_mv_q2_K_f32_impl(
|
@@ -3747,6 +3795,137 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
3747
3795
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
3748
3796
|
};
|
3749
3797
|
|
3798
|
+
#define NGRID_IQ1S 512
|
3799
|
+
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
3800
|
+
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
3801
|
+
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
3802
|
+
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
3803
|
+
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
3804
|
+
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
3805
|
+
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
3806
|
+
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
3807
|
+
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
3808
|
+
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
3809
|
+
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
3810
|
+
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
3811
|
+
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
3812
|
+
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
3813
|
+
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
3814
|
+
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
3815
|
+
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
3816
|
+
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
3817
|
+
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
3818
|
+
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
3819
|
+
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
3820
|
+
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
3821
|
+
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
3822
|
+
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
3823
|
+
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
3824
|
+
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
3825
|
+
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
3826
|
+
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
3827
|
+
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
3828
|
+
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
3829
|
+
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
3830
|
+
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
3831
|
+
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
3832
|
+
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
3833
|
+
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
3834
|
+
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
3835
|
+
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
3836
|
+
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
3837
|
+
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
3838
|
+
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
3839
|
+
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
3840
|
+
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
3841
|
+
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
3842
|
+
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
3843
|
+
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
3844
|
+
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
3845
|
+
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
3846
|
+
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
3847
|
+
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
3848
|
+
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
3849
|
+
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
3850
|
+
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
3851
|
+
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
3852
|
+
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
3853
|
+
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
3854
|
+
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
3855
|
+
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
3856
|
+
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
3857
|
+
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
3858
|
+
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
3859
|
+
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
3860
|
+
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
3861
|
+
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
3862
|
+
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
3863
|
+
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
3864
|
+
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
3865
|
+
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
3866
|
+
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
3867
|
+
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
3868
|
+
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
3869
|
+
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
3870
|
+
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
3871
|
+
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
3872
|
+
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
3873
|
+
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
3874
|
+
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
3875
|
+
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
3876
|
+
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
3877
|
+
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
3878
|
+
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
3879
|
+
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
3880
|
+
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
3881
|
+
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
3882
|
+
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
3883
|
+
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
3884
|
+
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
3885
|
+
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
3886
|
+
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
3887
|
+
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
3888
|
+
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
3889
|
+
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
3890
|
+
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
3891
|
+
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
3892
|
+
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
3893
|
+
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
3894
|
+
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
3895
|
+
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
3896
|
+
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
3897
|
+
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
3898
|
+
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
3899
|
+
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
3900
|
+
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
3901
|
+
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
3902
|
+
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
3903
|
+
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
3904
|
+
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
3905
|
+
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
3906
|
+
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
3907
|
+
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
3908
|
+
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
3909
|
+
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
3910
|
+
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
3911
|
+
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
3912
|
+
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
3913
|
+
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
3914
|
+
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
3915
|
+
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
3916
|
+
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
3917
|
+
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
3918
|
+
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
3919
|
+
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
3920
|
+
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
3921
|
+
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
3922
|
+
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
3923
|
+
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
3924
|
+
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
3925
|
+
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
3926
|
+
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
3927
|
+
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
3928
|
+
};
|
3750
3929
|
|
3751
3930
|
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
3752
3931
|
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
@@ -3854,7 +4033,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
3854
4033
|
y4 += 32 * 32;
|
3855
4034
|
}
|
3856
4035
|
#else
|
3857
|
-
|
4036
|
+
(void) x;
|
4037
|
+
(void) y;
|
4038
|
+
(void) yl;
|
4039
|
+
(void) nb32;
|
3858
4040
|
#endif
|
3859
4041
|
|
3860
4042
|
for (int row = 0; row < N_DST; ++row) {
|
@@ -3997,7 +4179,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
3997
4179
|
y4 += 32 * 32;
|
3998
4180
|
}
|
3999
4181
|
#else
|
4000
|
-
|
4182
|
+
(void) x;
|
4183
|
+
(void) y;
|
4184
|
+
(void) yl;
|
4185
|
+
(void) nb32;
|
4001
4186
|
#endif
|
4002
4187
|
|
4003
4188
|
for (int row = 0; row < N_DST; ++row) {
|
@@ -4133,7 +4318,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
4133
4318
|
y4 += 32 * 32;
|
4134
4319
|
}
|
4135
4320
|
#else
|
4136
|
-
|
4321
|
+
(void) x;
|
4322
|
+
(void) y;
|
4323
|
+
(void) yl;
|
4324
|
+
(void) nb32;
|
4137
4325
|
#endif
|
4138
4326
|
|
4139
4327
|
for (int row = 0; row < N_DST; ++row) {
|
@@ -4173,6 +4361,250 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
4173
4361
|
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
4174
4362
|
}
|
4175
4363
|
|
4364
|
+
void kernel_mul_mv_iq1_s_f32_impl(
|
4365
|
+
device const void * src0,
|
4366
|
+
device const float * src1,
|
4367
|
+
device float * dst,
|
4368
|
+
constant int64_t & ne00,
|
4369
|
+
constant int64_t & ne01,
|
4370
|
+
constant int64_t & ne02,
|
4371
|
+
constant int64_t & ne10,
|
4372
|
+
constant int64_t & ne12,
|
4373
|
+
constant int64_t & ne0,
|
4374
|
+
constant int64_t & ne1,
|
4375
|
+
constant uint & r2,
|
4376
|
+
constant uint & r3,
|
4377
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4378
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4379
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4380
|
+
|
4381
|
+
const int nb = ne00/QK_K;
|
4382
|
+
const int r0 = tgpig.x;
|
4383
|
+
const int r1 = tgpig.y;
|
4384
|
+
const int im = tgpig.z;
|
4385
|
+
|
4386
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
4387
|
+
const int ib_row = first_row * nb;
|
4388
|
+
|
4389
|
+
const uint i12 = im%ne12;
|
4390
|
+
const uint i13 = im/ne12;
|
4391
|
+
|
4392
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
4393
|
+
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
4394
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4395
|
+
|
4396
|
+
float yl[16];
|
4397
|
+
float sumf[N_DST]={0.f}, all_sum;
|
4398
|
+
|
4399
|
+
const int nb32 = nb * (QK_K / 32);
|
4400
|
+
|
4401
|
+
#if QK_K == 256
|
4402
|
+
const int ix = tiisg/2;
|
4403
|
+
const int il = tiisg%2;
|
4404
|
+
|
4405
|
+
device const float * y4 = y + 32 * ix + 16 * il;
|
4406
|
+
|
4407
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
4408
|
+
|
4409
|
+
for (int i = 0; i < 16; ++i) {
|
4410
|
+
yl[i] = y4[i];
|
4411
|
+
}
|
4412
|
+
|
4413
|
+
const int ibl = ib32 / (QK_K / 32);
|
4414
|
+
const int ib = ib32 % (QK_K / 32);
|
4415
|
+
|
4416
|
+
device const block_iq1_s * xr = x + ibl;
|
4417
|
+
device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
|
4418
|
+
device const uint8_t * sc = xr->scales + 2 * ib + il;
|
4419
|
+
device const half * dh = &xr->d;
|
4420
|
+
|
4421
|
+
for (int row = 0; row < N_DST; row++) {
|
4422
|
+
|
4423
|
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
4424
|
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
4425
|
+
|
4426
|
+
float2 sum = {0};
|
4427
|
+
for (int j = 0; j < 8; ++j) {
|
4428
|
+
sum[0] += yl[j+ 0] * grid1[j];
|
4429
|
+
sum[1] += yl[j+ 8] * grid2[j];
|
4430
|
+
}
|
4431
|
+
sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
|
4432
|
+
|
4433
|
+
dh += nb*sizeof(block_iq1_s)/2;
|
4434
|
+
qs += nb*sizeof(block_iq1_s);
|
4435
|
+
sc += nb*sizeof(block_iq1_s);
|
4436
|
+
}
|
4437
|
+
|
4438
|
+
y4 += 16 * 32;
|
4439
|
+
}
|
4440
|
+
#else
|
4441
|
+
(void) x;
|
4442
|
+
(void) y;
|
4443
|
+
(void) yl;
|
4444
|
+
(void) nb32;
|
4445
|
+
#endif
|
4446
|
+
|
4447
|
+
for (int row = 0; row < N_DST; ++row) {
|
4448
|
+
all_sum = simd_sum(sumf[row]);
|
4449
|
+
if (tiisg == 0) {
|
4450
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4451
|
+
}
|
4452
|
+
}
|
4453
|
+
}
|
4454
|
+
|
4455
|
+
constexpr constant static float kvalues_iq4nl_f[16] = {
|
4456
|
+
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
4457
|
+
};
|
4458
|
+
|
4459
|
+
void kernel_mul_mv_iq4_nl_f32_impl(
|
4460
|
+
device const void * src0,
|
4461
|
+
device const float * src1,
|
4462
|
+
device float * dst,
|
4463
|
+
constant int64_t & ne00,
|
4464
|
+
constant int64_t & ne01,
|
4465
|
+
constant int64_t & ne02,
|
4466
|
+
constant int64_t & ne10,
|
4467
|
+
constant int64_t & ne12,
|
4468
|
+
constant int64_t & ne0,
|
4469
|
+
constant int64_t & ne1,
|
4470
|
+
constant uint & r2,
|
4471
|
+
constant uint & r3,
|
4472
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
4473
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4474
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4475
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4476
|
+
|
4477
|
+
const int nb = ne00/QK4_NL;
|
4478
|
+
const int r0 = tgpig.x;
|
4479
|
+
const int r1 = tgpig.y;
|
4480
|
+
const int im = tgpig.z;
|
4481
|
+
const int first_row = (r0 * 2 + sgitg) * 2;
|
4482
|
+
const int ib_row = first_row * nb;
|
4483
|
+
|
4484
|
+
const uint i12 = im%ne12;
|
4485
|
+
const uint i13 = im/ne12;
|
4486
|
+
|
4487
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
4488
|
+
device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
|
4489
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4490
|
+
|
4491
|
+
const int ix = tiisg/2; // 0...15
|
4492
|
+
const int it = tiisg%2; // 0 or 1
|
4493
|
+
|
4494
|
+
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
4495
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4496
|
+
|
4497
|
+
float4 yl[4];
|
4498
|
+
float sumf[2]={0.f}, all_sum;
|
4499
|
+
|
4500
|
+
device const float * yb = y + ix * QK4_NL + it * 8;
|
4501
|
+
|
4502
|
+
uint32_t aux32[2];
|
4503
|
+
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
4504
|
+
|
4505
|
+
float4 qf1, qf2;
|
4506
|
+
|
4507
|
+
for (int ib = ix; ib < nb; ib += 16) {
|
4508
|
+
|
4509
|
+
device const float4 * y4 = (device const float4 *)yb;
|
4510
|
+
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
4511
|
+
|
4512
|
+
for (int row = 0; row < 2; ++row) {
|
4513
|
+
|
4514
|
+
device const block_iq4_nl & xb = x[row*nb + ib];
|
4515
|
+
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
4516
|
+
|
4517
|
+
float4 acc1 = {0.f}, acc2 = {0.f};
|
4518
|
+
|
4519
|
+
aux32[0] = q4[0] | (q4[1] << 16);
|
4520
|
+
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
4521
|
+
aux32[0] &= 0x0f0f0f0f;
|
4522
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
4523
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
4524
|
+
acc1 += yl[0] * qf1;
|
4525
|
+
acc2 += yl[1] * qf2;
|
4526
|
+
|
4527
|
+
aux32[0] = q4[2] | (q4[3] << 16);
|
4528
|
+
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
4529
|
+
aux32[0] &= 0x0f0f0f0f;
|
4530
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
4531
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
4532
|
+
acc1 += yl[2] * qf1;
|
4533
|
+
acc2 += yl[3] * qf2;
|
4534
|
+
|
4535
|
+
acc1 += acc2;
|
4536
|
+
|
4537
|
+
sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
4538
|
+
|
4539
|
+
}
|
4540
|
+
|
4541
|
+
yb += 16 * QK4_NL;
|
4542
|
+
}
|
4543
|
+
|
4544
|
+
for (int row = 0; row < 2; ++row) {
|
4545
|
+
all_sum = simd_sum(sumf[row]);
|
4546
|
+
if (tiisg == 0) {
|
4547
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4548
|
+
}
|
4549
|
+
}
|
4550
|
+
}
|
4551
|
+
|
4552
|
+
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
4553
|
+
kernel void kernel_mul_mv_iq1_s_f32(
|
4554
|
+
device const void * src0,
|
4555
|
+
device const float * src1,
|
4556
|
+
device float * dst,
|
4557
|
+
constant int64_t & ne00,
|
4558
|
+
constant int64_t & ne01,
|
4559
|
+
constant int64_t & ne02,
|
4560
|
+
constant uint64_t & nb00,
|
4561
|
+
constant uint64_t & nb01,
|
4562
|
+
constant uint64_t & nb02,
|
4563
|
+
constant int64_t & ne10,
|
4564
|
+
constant int64_t & ne11,
|
4565
|
+
constant int64_t & ne12,
|
4566
|
+
constant uint64_t & nb10,
|
4567
|
+
constant uint64_t & nb11,
|
4568
|
+
constant uint64_t & nb12,
|
4569
|
+
constant int64_t & ne0,
|
4570
|
+
constant int64_t & ne1,
|
4571
|
+
constant uint & r2,
|
4572
|
+
constant uint & r3,
|
4573
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4574
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4575
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4576
|
+
|
4577
|
+
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4578
|
+
}
|
4579
|
+
|
4580
|
+
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
4581
|
+
kernel void kernel_mul_mv_iq4_nl_f32(
|
4582
|
+
device const void * src0,
|
4583
|
+
device const float * src1,
|
4584
|
+
device float * dst,
|
4585
|
+
constant int64_t & ne00,
|
4586
|
+
constant int64_t & ne01,
|
4587
|
+
constant int64_t & ne02,
|
4588
|
+
constant uint64_t & nb00,
|
4589
|
+
constant uint64_t & nb01,
|
4590
|
+
constant uint64_t & nb02,
|
4591
|
+
constant int64_t & ne10,
|
4592
|
+
constant int64_t & ne11,
|
4593
|
+
constant int64_t & ne12,
|
4594
|
+
constant uint64_t & nb10,
|
4595
|
+
constant uint64_t & nb11,
|
4596
|
+
constant uint64_t & nb12,
|
4597
|
+
constant int64_t & ne0,
|
4598
|
+
constant int64_t & ne1,
|
4599
|
+
constant uint & r2,
|
4600
|
+
constant uint & r3,
|
4601
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
4602
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4603
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4604
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4605
|
+
|
4606
|
+
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
4607
|
+
}
|
4176
4608
|
|
4177
4609
|
//============================= templates and their specializations =============================
|
4178
4610
|
|
@@ -4369,6 +4801,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
|
|
4369
4801
|
const float dl = d * sc[0];
|
4370
4802
|
const float ml = min * sc[1];
|
4371
4803
|
#else
|
4804
|
+
(void) get_scale_min_k4_just2;
|
4805
|
+
|
4372
4806
|
q = q + 16 * (il&1);
|
4373
4807
|
device const uint8_t * s = xb->scales;
|
4374
4808
|
device const half2 * dh = (device const half2 *)xb->d;
|
@@ -4518,6 +4952,37 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
|
|
4518
4952
|
}
|
4519
4953
|
}
|
4520
4954
|
|
4955
|
+
template <typename type4x4>
|
4956
|
+
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
4957
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
4958
|
+
const float d = xb->d;
|
4959
|
+
device const uint8_t * qs = xb->qs + 2*il;
|
4960
|
+
device const uint8_t * sc = xb->scales + il;
|
4961
|
+
const float dl1 = d * (2*(sc[0] & 7) + 1);
|
4962
|
+
const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
|
4963
|
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
4964
|
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
4965
|
+
for (int i = 0; i < 8; ++i) {
|
4966
|
+
reg[i/4+0][i%4] = dl1 * grid1[i];
|
4967
|
+
reg[i/4+2][i%4] = dl2 * grid2[i];
|
4968
|
+
}
|
4969
|
+
}
|
4970
|
+
|
4971
|
+
template <typename type4x4>
|
4972
|
+
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
4973
|
+
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
4974
|
+
const float d = xb->d;
|
4975
|
+
uint32_t aux32;
|
4976
|
+
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
4977
|
+
for (int i = 0; i < 4; ++i) {
|
4978
|
+
aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
|
4979
|
+
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
|
4980
|
+
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
|
4981
|
+
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
4982
|
+
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
4983
|
+
}
|
4984
|
+
}
|
4985
|
+
|
4521
4986
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
4522
4987
|
kernel void kernel_get_rows(
|
4523
4988
|
device const void * src0,
|
@@ -5060,6 +5525,8 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
|
|
5060
5525
|
// kernel_get_rows instantiations for the i-quant formats.
// The second template argument is `nl`: the 256-element super-block formats pass QK_NL, while
// iq4_nl passes 2 because dequantize_iq4_nl only distinguishes il = 0 (low nibbles) and il = 1 (high nibbles).
template [[host_name("kernel_get_rows_iq2_xxs")]]
kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]]
kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]]
kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq1_s")]]
kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq4_nl")]]
kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5063
5530
|
|
5064
5531
|
//
|
5065
5532
|
// matrix-matrix multiplication
|
@@ -5099,6 +5566,8 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
5099
5566
|
// kernel_mul_mm (matrix-matrix) instantiations for the i-quant formats.
// iq4_nl uses nl = 2 (il = 0 / 1 select low / high nibbles in dequantize_iq4_nl); the others use QK_NL.
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]]
kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]]
kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]]
kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]]
kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]]
kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5102
5571
|
|
5103
5572
|
//
|
5104
5573
|
// indirect matrix-matrix multiplication
|
@@ -5150,6 +5619,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
5150
5619
|
// kernel_mul_mm_id (indirect matrix-matrix) instantiations for the i-quant formats.
// iq4_nl uses nl = 2 (il = 0 / 1 select low / high nibbles in dequantize_iq4_nl); the others use QK_NL.
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]]
kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]
kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]]
kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]
kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]
kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5153
5624
|
|
5154
5625
|
//
|
5155
5626
|
// matrix-vector multiplication
|
@@ -6117,3 +6588,131 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6117
6588
|
tiisg,
|
6118
6589
|
sgitg);
|
6119
6590
|
}
|
6591
|
+
|
6592
|
+
// Indirect (id-selected) matrix-vector product for iq1_s weights.
// For each row `bid`, element `idx` of that row of `ids` picks one of the eight candidate
// src0 matrices (src00..src07). The product itself is delegated to kernel_mul_mv_iq1_s_f32_impl,
// with src1 advanced by bid*nb11 bytes and dst by bid*ne0 floats.
// tgpig.z is decoded as bid*(ne12*ne13) + batch and rewritten to `batch` before forwarding.
// nb00, nb01, nb02, ne11, nb10, nb12, nb1 and tiitg are not read here — presumably kept so the
// host can bind the same argument layout as the other mul_mv_id kernels.
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
kernel void kernel_mul_mv_id_iq1_s_f32(
        device const    char * ids,
        device const    char * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // Split the z coordinate: the quotient is the row, the remainder the broadcast batch.
    // bid must be taken before tgpig.z is overwritten.
    const int64_t n_batch = ne12*ne13;
    const int64_t bid     = tgpig.z/n_batch;
    tgpig.z               = tgpig.z%n_batch;

    // Matrix selector for this row (not range-checked here; the host is assumed to supply 0..7).
    device const int32_t * ids_row = (device const int32_t *) (ids + bid*nbi1);
    const int32_t id = ids_row[idx];

    device const float * src1_row = (device const float *) (src1 + bid*nb11);
    device       float * dst_row  = dst + bid*ne0; // assumes dst rows are ne0 contiguous floats

    kernel_mul_mv_iq1_s_f32_impl(
        candidates[id], src1_row, dst_row,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        tgpig, tiisg, sgitg);
}
|
6654
|
+
|
6655
|
+
// Indirect (id-selected) matrix-vector product for iq4_nl weights.
// For each row `bid`, element `idx` of that row of `ids` picks one of the eight candidate
// src0 matrices (src00..src07). The product itself is delegated to kernel_mul_mv_iq4_nl_f32_impl,
// with src1 advanced by bid*nb11 bytes, dst by bid*ne0 floats, and the threadgroup buffer
// `shared_values` forwarded unchanged (its required size is set by the host / impl — not visible here).
// tgpig.z is decoded as bid*(ne12*ne13) + batch and rewritten to `batch` before forwarding.
// nb00, nb01, nb02, ne11, nb10, nb12, nb1 and tiitg are not read here — presumably kept so the
// host can bind the same argument layout as the other mul_mv_id kernels.
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
kernel void kernel_mul_mv_id_iq4_nl_f32(
        device const    char * ids,
        device const    char * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        threadgroup float    * shared_values [[threadgroup(0)]],
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // Split the z coordinate: the quotient is the row, the remainder the broadcast batch.
    // bid must be taken before tgpig.z is overwritten.
    const int64_t n_batch = ne12*ne13;
    const int64_t bid     = tgpig.z/n_batch;
    tgpig.z               = tgpig.z%n_batch;

    // Matrix selector for this row (not range-checked here; the host is assumed to supply 0..7).
    device const int32_t * ids_row = (device const int32_t *) (ids + bid*nbi1);
    const int32_t id = ids_row[idx];

    device const float * src1_row = (device const float *) (src1 + bid*nb11);
    device       float * dst_row  = dst + bid*ne0; // assumes dst rows are ne0 contiguous floats

    kernel_mul_mv_iq4_nl_f32_impl(
        candidates[id], src1_row, dst_row,
        ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|