node-llama-cpp 2.8.5 → 2.8.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/llama/addon.cpp +2 -2
- package/llama/binariesGithubRelease.json +1 -1
- package/llama/gitRelease.bundle +0 -0
- package/llamaBins/linux-arm64/llama-addon.node +0 -0
- package/llamaBins/linux-armv7l/llama-addon.node +0 -0
- package/llamaBins/linux-x64/llama-addon.node +0 -0
- package/llamaBins/mac-arm64/ggml-metal.metal +681 -10
- package/llamaBins/mac-arm64/llama-addon.node +0 -0
- package/llamaBins/mac-x64/ggml-metal.metal +681 -10
- package/llamaBins/mac-x64/llama-addon.node +0 -0
- package/llamaBins/win-x64/llama-addon.node +0 -0
- package/package.json +1 -1
|
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
|
|
|
351
351
|
kernel void kernel_soft_max(
|
|
352
352
|
device const float * src0,
|
|
353
353
|
device const float * src1,
|
|
354
|
+
device const float * src2,
|
|
354
355
|
device float * dst,
|
|
355
356
|
constant int64_t & ne00,
|
|
356
357
|
constant int64_t & ne01,
|
|
357
358
|
constant int64_t & ne02,
|
|
358
359
|
constant float & scale,
|
|
359
|
-
|
|
360
|
+
constant float & max_bias,
|
|
361
|
+
constant float & m0,
|
|
362
|
+
constant float & m1,
|
|
363
|
+
constant uint32_t & n_head_log2,
|
|
364
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
360
365
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
361
366
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
362
367
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
|
|
|
368
373
|
|
|
369
374
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
370
375
|
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
|
376
|
+
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
|
371
377
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
372
378
|
|
|
379
|
+
float slope = 0.0f;
|
|
380
|
+
|
|
381
|
+
// ALiBi
|
|
382
|
+
if (max_bias > 0.0f) {
|
|
383
|
+
const int64_t h = i02;
|
|
384
|
+
|
|
385
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
|
386
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
387
|
+
|
|
388
|
+
slope = pow(base, exp);
|
|
389
|
+
}
|
|
390
|
+
|
|
373
391
|
// parallel max
|
|
374
392
|
float lmax = -INFINITY;
|
|
375
393
|
|
|
376
394
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
377
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
395
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
|
378
396
|
}
|
|
379
397
|
|
|
380
398
|
// find the max value in the block
|
|
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
|
|
|
399
417
|
// parallel sum
|
|
400
418
|
float lsum = 0.0f;
|
|
401
419
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
402
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
420
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
|
403
421
|
lsum += exp_psrc0;
|
|
404
422
|
pdst[i00] = exp_psrc0;
|
|
405
423
|
}
|
|
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
|
|
|
437
455
|
kernel void kernel_soft_max_4(
|
|
438
456
|
device const float * src0,
|
|
439
457
|
device const float * src1,
|
|
458
|
+
device const float * src2,
|
|
440
459
|
device float * dst,
|
|
441
460
|
constant int64_t & ne00,
|
|
442
461
|
constant int64_t & ne01,
|
|
443
462
|
constant int64_t & ne02,
|
|
444
463
|
constant float & scale,
|
|
445
|
-
|
|
464
|
+
constant float & max_bias,
|
|
465
|
+
constant float & m0,
|
|
466
|
+
constant float & m1,
|
|
467
|
+
constant uint32_t & n_head_log2,
|
|
468
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
446
469
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
447
470
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
448
471
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
|
|
|
454
477
|
|
|
455
478
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
456
479
|
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
480
|
+
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
|
457
481
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
458
482
|
|
|
483
|
+
float slope = 0.0f;
|
|
484
|
+
|
|
485
|
+
if (max_bias > 0.0f) {
|
|
486
|
+
const int64_t h = i02;
|
|
487
|
+
|
|
488
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
|
489
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
490
|
+
|
|
491
|
+
slope = pow(base, exp);
|
|
492
|
+
}
|
|
493
|
+
|
|
459
494
|
// parallel max
|
|
460
495
|
float4 lmax4 = -INFINITY;
|
|
461
496
|
|
|
462
497
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
463
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
498
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
|
464
499
|
}
|
|
465
500
|
|
|
466
501
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|
|
486
521
|
// parallel sum
|
|
487
522
|
float4 lsum4 = 0.0f;
|
|
488
523
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
489
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
524
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
|
490
525
|
lsum4 += exp_psrc4;
|
|
491
526
|
pdst4[i00] = exp_psrc4;
|
|
492
527
|
}
|
|
@@ -1775,9 +1810,29 @@ kernel void kernel_rope(
|
|
|
1775
1810
|
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
|
1776
1811
|
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
|
1777
1812
|
|
|
1778
|
-
|
|
1813
|
+
typedef void (im2col_t)(
|
|
1779
1814
|
device const float * x,
|
|
1780
|
-
device
|
|
1815
|
+
device char * dst,
|
|
1816
|
+
constant int32_t & ofs0,
|
|
1817
|
+
constant int32_t & ofs1,
|
|
1818
|
+
constant int32_t & IW,
|
|
1819
|
+
constant int32_t & IH,
|
|
1820
|
+
constant int32_t & CHW,
|
|
1821
|
+
constant int32_t & s0,
|
|
1822
|
+
constant int32_t & s1,
|
|
1823
|
+
constant int32_t & p0,
|
|
1824
|
+
constant int32_t & p1,
|
|
1825
|
+
constant int32_t & d0,
|
|
1826
|
+
constant int32_t & d1,
|
|
1827
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1828
|
+
uint3 tgpg[[threadgroups_per_grid]],
|
|
1829
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1830
|
+
uint3 ntg[[threads_per_threadgroup]]);
|
|
1831
|
+
|
|
1832
|
+
template <typename T>
|
|
1833
|
+
kernel void kernel_im2col(
|
|
1834
|
+
device const float * x,
|
|
1835
|
+
device char * dst,
|
|
1781
1836
|
constant int32_t & ofs0,
|
|
1782
1837
|
constant int32_t & ofs1,
|
|
1783
1838
|
constant int32_t & IW,
|
|
@@ -1800,14 +1855,19 @@ kernel void kernel_im2col_f16(
|
|
|
1800
1855
|
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
|
1801
1856
|
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
|
1802
1857
|
|
|
1858
|
+
device T * pdst = (device T *) (dst);
|
|
1859
|
+
|
|
1803
1860
|
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
|
1804
|
-
|
|
1861
|
+
pdst[offset_dst] = 0.0f;
|
|
1805
1862
|
} else {
|
|
1806
1863
|
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
|
1807
|
-
|
|
1864
|
+
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
|
|
1808
1865
|
}
|
|
1809
1866
|
}
|
|
1810
1867
|
|
|
1868
|
+
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
|
1869
|
+
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
|
1870
|
+
|
|
1811
1871
|
kernel void kernel_upscale_f32(
|
|
1812
1872
|
device const char * src0,
|
|
1813
1873
|
device char * dst,
|
|
@@ -2459,6 +2519,19 @@ typedef struct {
|
|
|
2459
2519
|
} block_iq2_xs;
|
|
2460
2520
|
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
|
2461
2521
|
|
|
2522
|
+
typedef struct {
|
|
2523
|
+
half d;
|
|
2524
|
+
uint8_t qs[3*QK_K/8];
|
|
2525
|
+
} block_iq3_xxs;
|
|
2526
|
+
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
|
2527
|
+
|
|
2528
|
+
typedef struct {
|
|
2529
|
+
half d;
|
|
2530
|
+
uint8_t qs[QK_K/8];
|
|
2531
|
+
uint8_t scales[QK_K/16];
|
|
2532
|
+
} block_iq1_s;
|
|
2533
|
+
|
|
2534
|
+
|
|
2462
2535
|
//====================================== dot products =========================
|
|
2463
2536
|
|
|
2464
2537
|
void kernel_mul_mv_q2_K_f32_impl(
|
|
@@ -3681,6 +3754,173 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
|
|
|
3681
3754
|
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
|
3682
3755
|
};
|
|
3683
3756
|
|
|
3757
|
+
constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
3758
|
+
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
|
3759
|
+
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
|
3760
|
+
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
|
|
3761
|
+
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
|
|
3762
|
+
0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
|
|
3763
|
+
0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
|
|
3764
|
+
0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
|
|
3765
|
+
0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
|
|
3766
|
+
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
|
|
3767
|
+
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
|
|
3768
|
+
0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
|
|
3769
|
+
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
|
|
3770
|
+
0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
|
|
3771
|
+
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
|
|
3772
|
+
0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
|
|
3773
|
+
0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
|
|
3774
|
+
0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
|
|
3775
|
+
0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
|
|
3776
|
+
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
|
|
3777
|
+
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
|
|
3778
|
+
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
|
|
3779
|
+
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
|
|
3780
|
+
0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
|
|
3781
|
+
0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
|
|
3782
|
+
0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
|
|
3783
|
+
0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
|
|
3784
|
+
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
|
|
3785
|
+
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
|
|
3786
|
+
0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
|
|
3787
|
+
0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
|
|
3788
|
+
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
|
|
3789
|
+
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
|
3790
|
+
};
|
|
3791
|
+
|
|
3792
|
+
#define NGRID_IQ1S 512
|
|
3793
|
+
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
|
3794
|
+
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
|
3795
|
+
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
|
3796
|
+
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
|
3797
|
+
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
|
3798
|
+
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
|
3799
|
+
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
|
3800
|
+
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
|
3801
|
+
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
|
3802
|
+
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
|
3803
|
+
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
|
3804
|
+
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
|
3805
|
+
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
|
3806
|
+
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
|
3807
|
+
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
|
3808
|
+
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
|
3809
|
+
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
|
3810
|
+
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
|
3811
|
+
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
|
3812
|
+
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
|
3813
|
+
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
|
3814
|
+
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
|
3815
|
+
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
|
3816
|
+
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
|
3817
|
+
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
|
3818
|
+
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
|
3819
|
+
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
|
3820
|
+
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
|
3821
|
+
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
|
3822
|
+
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
|
3823
|
+
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
|
3824
|
+
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
|
3825
|
+
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
|
3826
|
+
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
|
3827
|
+
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
|
3828
|
+
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
|
3829
|
+
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
|
3830
|
+
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
|
3831
|
+
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
|
3832
|
+
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
|
3833
|
+
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
|
3834
|
+
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
|
3835
|
+
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
|
3836
|
+
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
|
3837
|
+
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
|
3838
|
+
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
|
3839
|
+
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
|
3840
|
+
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
|
3841
|
+
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
|
3842
|
+
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
|
3843
|
+
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
|
3844
|
+
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
|
3845
|
+
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
|
3846
|
+
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
|
3847
|
+
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
|
3848
|
+
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
|
3849
|
+
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
|
3850
|
+
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
|
3851
|
+
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
|
3852
|
+
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
|
3853
|
+
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
|
3854
|
+
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
|
3855
|
+
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
|
3856
|
+
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
|
3857
|
+
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
|
3858
|
+
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
|
3859
|
+
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
|
3860
|
+
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
|
3861
|
+
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
|
3862
|
+
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
|
3863
|
+
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
|
3864
|
+
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
|
3865
|
+
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
|
3866
|
+
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
|
3867
|
+
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
|
3868
|
+
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
|
3869
|
+
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
|
3870
|
+
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
|
3871
|
+
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
|
3872
|
+
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
|
3873
|
+
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
|
3874
|
+
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
|
3875
|
+
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
|
3876
|
+
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
|
3877
|
+
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
|
3878
|
+
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
|
3879
|
+
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
|
3880
|
+
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
|
3881
|
+
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
|
3882
|
+
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
|
3883
|
+
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
|
3884
|
+
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
|
3885
|
+
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
|
3886
|
+
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
|
3887
|
+
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
|
3888
|
+
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
|
3889
|
+
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
|
3890
|
+
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
|
3891
|
+
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
|
3892
|
+
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
|
3893
|
+
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
|
3894
|
+
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
|
3895
|
+
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
|
3896
|
+
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
|
3897
|
+
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
|
3898
|
+
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
|
3899
|
+
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
|
3900
|
+
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
|
3901
|
+
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
|
3902
|
+
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
|
3903
|
+
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
|
3904
|
+
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
|
3905
|
+
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
|
3906
|
+
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
|
3907
|
+
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
|
3908
|
+
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
|
3909
|
+
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
|
3910
|
+
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
|
3911
|
+
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
|
3912
|
+
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
|
3913
|
+
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
|
3914
|
+
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
|
3915
|
+
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
|
3916
|
+
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
|
3917
|
+
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
|
3918
|
+
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
|
3919
|
+
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
|
3920
|
+
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
|
3921
|
+
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
|
3922
|
+
};
|
|
3923
|
+
|
|
3684
3924
|
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
|
3685
3925
|
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
|
3686
3926
|
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
|
|
@@ -3970,6 +4210,260 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
|
|
3970
4210
|
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
3971
4211
|
}
|
|
3972
4212
|
|
|
4213
|
+
void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
4214
|
+
device const void * src0,
|
|
4215
|
+
device const float * src1,
|
|
4216
|
+
device float * dst,
|
|
4217
|
+
constant int64_t & ne00,
|
|
4218
|
+
constant int64_t & ne01,
|
|
4219
|
+
constant int64_t & ne02,
|
|
4220
|
+
constant int64_t & ne10,
|
|
4221
|
+
constant int64_t & ne12,
|
|
4222
|
+
constant int64_t & ne0,
|
|
4223
|
+
constant int64_t & ne1,
|
|
4224
|
+
constant uint & r2,
|
|
4225
|
+
constant uint & r3,
|
|
4226
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
4227
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4228
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4229
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4230
|
+
|
|
4231
|
+
const int nb = ne00/QK_K;
|
|
4232
|
+
const int r0 = tgpig.x;
|
|
4233
|
+
const int r1 = tgpig.y;
|
|
4234
|
+
const int im = tgpig.z;
|
|
4235
|
+
|
|
4236
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
4237
|
+
const int ib_row = first_row * nb;
|
|
4238
|
+
|
|
4239
|
+
const uint i12 = im%ne12;
|
|
4240
|
+
const uint i13 = im/ne12;
|
|
4241
|
+
|
|
4242
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
4243
|
+
|
|
4244
|
+
device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
|
|
4245
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4246
|
+
|
|
4247
|
+
float yl[32];
|
|
4248
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
4249
|
+
|
|
4250
|
+
const int nb32 = nb * (QK_K / 32);
|
|
4251
|
+
|
|
4252
|
+
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
|
4253
|
+
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
|
|
4254
|
+
{
|
|
4255
|
+
int nval = 4;
|
|
4256
|
+
int pos = (32*sgitg + tiisg)*nval;
|
|
4257
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
|
|
4258
|
+
nval = 2;
|
|
4259
|
+
pos = (32*sgitg + tiisg)*nval;
|
|
4260
|
+
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
4261
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4262
|
+
}
|
|
4263
|
+
|
|
4264
|
+
#if QK_K == 256
|
|
4265
|
+
const int ix = tiisg;
|
|
4266
|
+
|
|
4267
|
+
device const float * y4 = y + 32 * ix;
|
|
4268
|
+
|
|
4269
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
4270
|
+
|
|
4271
|
+
for (int i = 0; i < 32; ++i) {
|
|
4272
|
+
yl[i] = y4[i];
|
|
4273
|
+
}
|
|
4274
|
+
|
|
4275
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
4276
|
+
const int ib = ib32 % (QK_K / 32);
|
|
4277
|
+
|
|
4278
|
+
device const block_iq3_xxs * xr = x + ibl;
|
|
4279
|
+
device const uint8_t * q3 = xr->qs + 8 * ib;
|
|
4280
|
+
device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
|
|
4281
|
+
device const half * dh = &xr->d;
|
|
4282
|
+
|
|
4283
|
+
for (int row = 0; row < N_DST; row++) {
|
|
4284
|
+
|
|
4285
|
+
const float db = dh[0];
|
|
4286
|
+
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
|
4287
|
+
const float d = db * (0.5f + (aux32 >> 28));
|
|
4288
|
+
|
|
4289
|
+
float2 sum = {0};
|
|
4290
|
+
for (int l = 0; l < 4; ++l) {
|
|
4291
|
+
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
|
|
4292
|
+
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
|
|
4293
|
+
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
|
|
4294
|
+
for (int j = 0; j < 4; ++j) {
|
|
4295
|
+
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
|
4296
|
+
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
|
4297
|
+
}
|
|
4298
|
+
}
|
|
4299
|
+
sumf[row] += d * (sum[0] + sum[1]);
|
|
4300
|
+
|
|
4301
|
+
dh += nb*sizeof(block_iq3_xxs)/2;
|
|
4302
|
+
q3 += nb*sizeof(block_iq3_xxs);
|
|
4303
|
+
gas += nb*sizeof(block_iq3_xxs)/2;
|
|
4304
|
+
}
|
|
4305
|
+
|
|
4306
|
+
y4 += 32 * 32;
|
|
4307
|
+
}
|
|
4308
|
+
#else
|
|
4309
|
+
// TODO
|
|
4310
|
+
#endif
|
|
4311
|
+
|
|
4312
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
4313
|
+
all_sum = simd_sum(sumf[row]);
|
|
4314
|
+
if (tiisg == 0) {
|
|
4315
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
|
|
4316
|
+
}
|
|
4317
|
+
}
|
|
4318
|
+
}
|
|
4319
|
+
|
|
4320
|
+
[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
|
|
4321
|
+
kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
4322
|
+
device const void * src0,
|
|
4323
|
+
device const float * src1,
|
|
4324
|
+
device float * dst,
|
|
4325
|
+
constant int64_t & ne00,
|
|
4326
|
+
constant int64_t & ne01,
|
|
4327
|
+
constant int64_t & ne02,
|
|
4328
|
+
constant uint64_t & nb00,
|
|
4329
|
+
constant uint64_t & nb01,
|
|
4330
|
+
constant uint64_t & nb02,
|
|
4331
|
+
constant int64_t & ne10,
|
|
4332
|
+
constant int64_t & ne11,
|
|
4333
|
+
constant int64_t & ne12,
|
|
4334
|
+
constant uint64_t & nb10,
|
|
4335
|
+
constant uint64_t & nb11,
|
|
4336
|
+
constant uint64_t & nb12,
|
|
4337
|
+
constant int64_t & ne0,
|
|
4338
|
+
constant int64_t & ne1,
|
|
4339
|
+
constant uint & r2,
|
|
4340
|
+
constant uint & r3,
|
|
4341
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
4342
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4343
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4344
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4345
|
+
|
|
4346
|
+
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
4347
|
+
}
|
|
4348
|
+
|
|
4349
|
+
// Matrix-vector product of IQ1_S-quantized rows (src0) with f32 columns (src1).
//
// Work split:
//   - each SIMD group owns N_DST consecutive src0 rows starting at
//     first_row = (tgpig.x * N_SIMDGROUP + sgitg) * N_DST;
//   - tgpig.y selects the src1 column r1, tgpig.z the batch matrix im
//     (src0 is broadcast across batches via the ratios r2 / r3);
//   - inside a SIMD group, lane pair ix = tiisg/2 walks the 32-value
//     sub-blocks of the row with a stride of 16 sub-blocks, and il = tiisg%2
//     picks the first or second 16 values of each sub-block
//     (NOTE(review): the stride of 16 assumes 32-wide SIMD groups — confirm).
// Partial sums are reduced with simd_sum and lane 0 writes one value per row.
//
// NOTE(review): rows first_row..first_row+N_DST-1 are not checked against
// ne01; this presumably relies on the dispatch covering ne01 exactly — confirm.
// NOTE(review): only QK_K == 256 is implemented; for other QK_K the partial
// sums stay zero and zeros are written to dst.
void kernel_mul_mv_iq1_s_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;          // QK_K-sized blocks per src0 row

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row    = first_row * nb;

    // Broadcast the batch index im onto src0's (smaller) batch dimensions.
    const uint i12 = im%ne12;
    const uint i13 = im/ne12;
    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
    device const float       * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;

    float sumf[N_DST] = {0.f};

#if QK_K == 256
    const int nb32 = nb * (QK_K / 32);  // 32-value sub-blocks per row

    const int ix = tiisg/2;             // starting sub-block for this lane pair
    const int il = tiisg%2;             // 0: values 0..15, 1: values 16..31

    device const float * y4 = y + 32 * ix + 16 * il;

    for (int ib32 = ix; ib32 < nb32; ib32 += 16, y4 += 16 * 32) {

        float yl[16];
        for (int i = 0; i < 16; ++i) {
            yl[i] = y4[i];
        }

        const int ib = ib32 % (QK_K / 32);       // sub-block inside the block

        // Block holding this sub-block in the first row; advancing by nb
        // blocks moves to the same block of the next row.
        device const block_iq1_s * xb = x + ib32 / (QK_K / 32);

        for (int row = 0; row < N_DST; ++row, xb += nb) {
            device const uint8_t * qs = xb->qs + 4 * ib + 2 * il;
            const uint8_t          sc = xb->scales[2 * ib + il];

            // Bits 3 and 7 of the scale byte supply the 9th bit of each grid index.
            constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc & 0x08) << 5)));
            constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc & 0x80) << 1)));

            float acc1 = 0.f;
            float acc2 = 0.f;
            for (int j = 0; j < 8; ++j) {
                acc1 += yl[j + 0] * grid1[j];
                acc2 += yl[j + 8] * grid2[j];
            }

            // Odd 3-bit multipliers: low nibble for the first 8 values, high nibble for the next 8.
            sumf[row] += (float)xb->d * (acc1 * (2*(sc & 7) + 1) + acc2 * (2*((sc >> 4) & 7) + 1));
        }
    }
#else
    // TODO: QK_K != 256 is not implemented; sumf stays zero.
#endif

    for (int row = 0; row < N_DST; ++row) {
        const float total = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = total;
        }
    }
}
|
|
4437
|
+
|
|
4438
|
+
// Entry point for the IQ1_S x f32 matrix-vector product.
// The argument list mirrors the other mul_mv kernels so the host can bind
// buffers by the same indices. The byte strides (nb00..nb12) and ne11 are
// accepted but not used here. All work is delegated to
// kernel_mul_mv_iq1_s_f32_impl.
[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq1_s_f32_impl(
            src0, src1, dst,
            ne00, ne01, ne02,
            ne10, ne12,
            ne0, ne1,
            r2, r3,
            tgpig, tiisg, sgitg);
}
|
|
4465
|
+
|
|
4466
|
+
|
|
3973
4467
|
//============================= templates and their specializations =============================
|
|
3974
4468
|
|
|
3975
4469
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
@@ -4287,6 +4781,49 @@ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4
|
|
|
4287
4781
|
}
|
|
4288
4782
|
}
|
|
4289
4783
|
|
|
4784
|
+
// Expand 16 values of an IQ3_XXS block into a 4x4 register tile.
//
// il is 0..15 for QK_K == 256: il/2 selects the 32-value sub-block and il%2
// selects its first (0) or second (1) group of 16 values.
// Each group of 16 uses four q3 bytes, each indexing a 4-value grid entry.
// The group also uses two 7-bit sign-pattern indices and the sub-block scale,
// all taken from one 32-bit word:
//   bits  0..27  four 7-bit sign indices (7 bits per pair of grid entries)
//   bits 28..31  the sub-block scale
template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
    const int ib32 = il/2;
    const int part = il%2;

    const float d = xb->d;

    device const uint8_t  * q3  = xb->qs + 8*ib32;
    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;

    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    const float    dl    = d * (0.5f + (aux32 >> 28)) * 0.5f;

    // Two passes: pass k fills tile rows 2k and 2k+1 from one pair of grid entries.
    for (int k = 0; k < 2; ++k) {
        constant uint8_t * ga = (constant uint8_t *)(iq3xxs_grid + q3[4*part + 2*k + 0]);
        constant uint8_t * gb = (constant uint8_t *)(iq3xxs_grid + q3[4*part + 2*k + 1]);
        const uint8_t signs = ksigns_iq2xs[(aux32 >> (14*part + 7*k)) & 127];

        for (int i = 0; i < 4; ++i) {
            reg[2*k + 0][i] = dl * ga[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
            reg[2*k + 1][i] = dl * gb[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
        }
    }
}
|
|
4810
|
+
|
|
4811
|
+
// Expand 16 values of an IQ1_S block into a 4x4 register tile.
//
// il is 0..15 for QK_K == 256. The 16 values come from qs[2*il] and
// qs[2*il + 1], each indexing an 8-value signed grid entry, and from the
// single scale byte scales[il]:
//   bits 0..2 / 4..6  odd multipliers (2*s + 1) for the first / second 8 values
//   bit  3    / 7     9th bit of the first / second grid index
template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
    const float   d  = xb->d;
    const uint8_t sc = xb->scales[il];

    device const uint8_t * qs = xb->qs + 2*il;

    const float dl1 = d * (2*(sc & 7) + 1);
    const float dl2 = d * (2*((sc >> 4) & 7) + 1);

    constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc & 0x08) << 5)));
    constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc & 0x80) << 1)));

    // Tile rows 0-1 hold the first grid entry, rows 2-3 the second.
    for (int i = 0; i < 4; ++i) {
        reg[0][i] = dl1 * grid1[i + 0];
        reg[1][i] = dl1 * grid1[i + 4];
        reg[2][i] = dl2 * grid2[i + 0];
        reg[3][i] = dl2 * grid2[i + 4];
    }
}
|
|
4826
|
+
|
|
4290
4827
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
4291
4828
|
kernel void kernel_get_rows(
|
|
4292
4829
|
device const void * src0,
|
|
@@ -4828,6 +5365,8 @@ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows
|
|
|
4828
5365
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4829
5366
|
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4830
5367
|
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5368
|
+
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5369
|
+
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
4831
5370
|
|
|
4832
5371
|
//
|
|
4833
5372
|
// matrix-matrix multiplication
|
|
@@ -4866,6 +5405,8 @@ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
|
4866
5405
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4867
5406
|
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4868
5407
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5408
|
+
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5409
|
+
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
4869
5410
|
|
|
4870
5411
|
//
|
|
4871
5412
|
// indirect matrix-matrix multiplication
|
|
@@ -4916,6 +5457,8 @@ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
|
4916
5457
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
4917
5458
|
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
4918
5459
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5460
|
+
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
5461
|
+
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
|
4919
5462
|
|
|
4920
5463
|
//
|
|
4921
5464
|
// matrix-vector multiplication
|
|
@@ -5818,3 +6361,131 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
|
5818
6361
|
tiisg,
|
|
5819
6362
|
sgitg);
|
|
5820
6363
|
}
|
|
6364
|
+
|
|
6365
|
+
// Indirect (per-row expert-selected) IQ3_XXS x f32 matrix-vector product.
//
// tgpig.z packs two indices: bid = tgpig.z / (ne12*ne13) picks the src1 row
// and the matching ids row. The remainder is passed to the impl as its own
// tgpig.z. The ids row (stride nbi1 bytes) is read at column idx to choose
// one of the eight src0 candidates src00..src07. The chosen matrix is then
// multiplied with src1 row bid into dst row bid.
// NOTE(review): id is not range-checked against the 8 src0 slots; this
// presumably relies on ids holding valid expert indices — confirm.
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
kernel void kernel_mul_mv_id_iq3_xxs_f32(
        device const    char * ids,
        device const    char * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        threadgroup   int8_t * shared_values [[threadgroup(0)]],
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // Split tgpig.z into the batch row and the matrix index seen by the impl.
    const int64_t bid = tgpig.z/(ne12*ne13);
    tgpig.z = tgpig.z%(ne12*ne13);

    device const int32_t * row_ids = (device const int32_t *) (ids + bid*nbi1);
    const int32_t id = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + bid*nb11);
    device       float * dst_row  = dst + bid*ne0;

    kernel_mul_mv_iq3_xxs_f32_impl(
            src0[id], src1_row, dst_row,
            ne00, ne01, ne02,
            ne10, ne12,
            ne0, ne1,
            r2, r3,
            shared_values,
            tgpig, tiisg, sgitg);
}
|
|
6429
|
+
|
|
6430
|
+
// Indirect (per-row expert-selected) IQ1_S x f32 matrix-vector product.
//
// tgpig.z packs two indices: bid = tgpig.z / (ne12*ne13) picks the src1 row
// and the matching ids row. The remainder is passed to the impl as its own
// tgpig.z. The ids row (stride nbi1 bytes) is read at column idx to choose
// one of the eight src0 candidates src00..src07. The chosen matrix is then
// multiplied with src1 row bid into dst row bid.
// NOTE(review): id is not range-checked against the 8 src0 slots; this
// presumably relies on ids holding valid expert indices — confirm.
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
kernel void kernel_mul_mv_id_iq1_s_f32(
        device const    char * ids,
        device const    char * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne01,
        constant     int64_t & ne02,
        constant    uint64_t & nb00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant     int64_t & ne11,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const    char * src00,
        device const    char * src01,
        device const    char * src02,
        device const    char * src03,
        device const    char * src04,
        device const    char * src05,
        device const    char * src06,
        device const    char * src07,
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   tiisg[[thread_index_in_simdgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // Split tgpig.z into the batch row and the matrix index seen by the impl.
    const int64_t bid = tgpig.z/(ne12*ne13);
    tgpig.z = tgpig.z%(ne12*ne13);

    device const int32_t * row_ids = (device const int32_t *) (ids + bid*nbi1);
    const int32_t id = row_ids[idx];

    device const float * src1_row = (device const float *) (src1 + bid*nb11);
    device       float * dst_row  = dst + bid*ne0;

    kernel_mul_mv_iq1_s_f32_impl(
            src0[id], src1_row, dst_row,
            ne00, ne01, ne02,
            ne10, ne12,
            ne0, ne1,
            r2, r3,
            tgpig, tiisg, sgitg);
}
|