llama_cpp 0.12.0 → 0.12.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +14 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +8 -2
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -3
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +758 -39
- data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +86 -7
- data/vendor/tmp/llama.cpp/ggml-metal.metal +692 -8
- data/vendor/tmp/llama.cpp/ggml-quants.c +635 -1
- data/vendor/tmp/llama.cpp/ggml-quants.h +25 -1
- data/vendor/tmp/llama.cpp/ggml.c +91 -52
- data/vendor/tmp/llama.cpp/ggml.h +14 -11
- data/vendor/tmp/llama.cpp/llama.cpp +79 -30
- data/vendor/tmp/llama.cpp/llama.h +14 -0
- metadata +2 -2
@@ -2446,6 +2446,19 @@ typedef struct {
|
|
2446
2446
|
} block_q6_K;
|
2447
2447
|
// 210 bytes / block
|
2448
2448
|
|
2449
|
+
// IQ2_XXS super-block: QK_K weights that share one fp16 scale `d`.
// For every group of 32 weights, qs holds four uint16_t values:
//   - the first two are read as four bytes, each an index into iq2xxs_grid
//     (one grid entry = 8 weight magnitudes);
//   - the last two form a 32-bit word. It holds four 7-bit sign indices into
//     ksigns_iq2xs (bits 7*l .. 7*l+6) and a 4-bit group scale in bits 28..31.
// NOTE(review): the layout must match the host-side definition (presumably
// block_iq2_xxs in ggml-quants.h); confirm before changing any field.
typedef struct {
    half d;
    uint16_t qs[QK_K/8];
} block_iq2_xxs;
// 66 bytes / block for QK_K = 256, so 2.0625 bpw
|
2454
|
+
|
2455
|
+
// IQ2_XS super-block: QK_K weights that share one fp16 scale `d`.
// qs: every uint16_t describes 8 weights.
//   - The low 9 bits index iq2xs_grid (512 entries).
//   - The high 7 bits index ksigns_iq2xs.
//   - Four consecutive entries cover one group of 32 weights.
// scales: one byte per group of 32 weights. The low nibble scales the first 16 weights
// and the high nibble scales the second 16.
// NOTE(review): the layout must match the host-side definition (presumably
// block_iq2_xs in ggml-quants.h); confirm before changing any field.
typedef struct {
    half d;
    uint16_t qs[QK_K/8];
    uint8_t scales[QK_K/32];
} block_iq2_xs;
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
2461
|
+
|
2449
2462
|
//====================================== dot products =========================
|
2450
2463
|
|
2451
2464
|
void kernel_mul_mv_q2_K_f32_impl(
|
@@ -3468,6 +3481,495 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
3468
3481
|
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
3469
3482
|
}
|
3470
3483
|
|
3484
|
+
// ======================= "True" 2-bit
|
3485
|
+
|
3486
|
+
// Codebook for IQ2_XXS: 256 entries, selected by one byte of block_iq2_xxs::qs.
// Each uint64_t packs eight unsigned byte magnitudes, all drawn from {0x08, 0x19, 0x2b}.
// The mul_mv kernel reads an entry through a uint8_t pointer, so byte j in memory order
// multiplies weight j of the 8-weight sub-group.
constexpr constant static uint64_t iq2xxs_grid[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
|
3552
|
+
|
3553
|
+
// Codebook for IQ2_XS: 512 entries, selected by the low 9 bits of a block_iq2_xs::qs word.
// Each uint64_t packs eight unsigned byte magnitudes, all drawn from {0x08, 0x19, 0x2b}.
// The mul_mv kernel reads an entry through a uint8_t pointer, so byte j in memory order
// multiplies weight j of the 8-weight sub-group.
constexpr constant static uint64_t iq2xs_grid[512] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
|
3683
|
+
|
3684
|
+
// Sign patterns for 8 consecutive weights, indexed by a 7-bit code (shared by IQ2_XXS and IQ2_XS).
// Entry i keeps bits 0..6 of i and sets bit 7 exactly when i has an odd number of set bits.
// As a result, every pattern has an even number of set bits.
// In the mul_mv kernels, bit j set (tested with kmask_iq2xs[j]) negates weight j.
constexpr constant static uint8_t ksigns_iq2xs[128] = {
      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
|
3694
|
+
|
3695
|
+
// Single-bit masks: kmask_iq2xs[j] == 1 << j, used to test sign bit j of a ksigns_iq2xs entry.
constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
|
3696
|
+
|
3697
|
+
// Matrix-vector product: IQ2_XXS-quantized src0 rows times an f32 src1 vector.
//
// Each simdgroup produces N_DST consecutive output rows starting at first_row.
// Each of its 32 lanes handles every 32nd group of 32 weights along the row.
// The per-lane partial sums are reduced with simd_sum, and lane 0 writes the result.
//
// Decoding of one 32-weight group (four uint16_t at qs[4*ib]):
//   - q2[0], q2[1] are read as four bytes, each an index into iq2xxs_grid (8 magnitudes);
//   - aux32 = q2[2] | q2[3] << 16 holds four 7-bit sign indices into ksigns_iq2xs
//     (bits 7*l .. 7*l+6) and a 4-bit group scale in bits 28..31;
//   - group contribution = d * (0.5 + scale4) * sum(y * magnitude * sign).
// The common 0.25 factor is applied once, to the reduced sum.
//
// shared_values (threadgroup memory index 0) must hold 256 uint64_t grid entries
// followed by 128 sign bytes, i.e. at least 256*8 + 128 bytes.
// NOTE: only QK_K == 256 is implemented. For any other QK_K the accumulation loop is
// compiled out and zeros are written to dst.
void kernel_mul_mv_iq2_xxs_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row    = first_row * nb;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // Broadcast src0 over the batch dimensions (r2, r3 are the broadcast ratios).
    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[32];
    float sumf[N_DST]={0.f};

    const int nb32 = nb * (QK_K / 32);

    // Stage the grid and sign tables in threadgroup memory.
    // Every thread copies a strided slice, so both tables are fully covered for any
    // threadgroup of N_SIMDGROUP simdgroups without writing past either region.
    // A fixed per-thread count would only be exact for one specific simdgroup count.
    threadgroup uint64_t * values       = (threadgroup uint64_t *)shared_values;
    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t  *)(values + 256);
    {
        const int tid      = 32*sgitg + tiisg;
        const int nthreads = 32*N_SIMDGROUP;
        for (int i = tid; i < 256; i += nthreads) values[i]       = iq2xxs_grid[i];
        for (int i = tid; i < 128; i += nthreads) shared_signs[i] = ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

#if QK_K == 256
    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {

        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);   // super-block within the row
        const int ib  = ib32 % (QK_K / 32);   // 32-weight group within the super-block

        device const block_iq2_xxs * xr = x + ibl;
        device const uint16_t * q2 = xr->qs + 4 * ib;
        device const half     * dh = &xr->d;

        for (int row = 0; row < N_DST; row++) {

            const float db = dh[0];
            device const uint8_t * aux8 = (device const uint8_t *)q2;
            // Widen before shifting so the packed word is assembled in unsigned arithmetic.
            const uint32_t aux32 = (uint32_t)q2[2] | ((uint32_t)q2[3] << 16);
            const float d = db * (0.5f + (aux32 >> 28));

            float sum = 0;
            for (int l = 0; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
                for (int j = 0; j < 8; ++j) {
                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            sumf[row] += d * sum;

            // Step to the same group in the next row: nb blocks, counted in 2-byte units.
            dh += nb*sizeof(block_iq2_xxs)/2;
            q2 += nb*sizeof(block_iq2_xxs)/2;
        }

        y4 += 32 * 32;
    }
#else
    // TODO: QK_K != 256 is not implemented; sumf stays zero.
#endif

    for (int row = 0; row < N_DST; ++row) {
        const float all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
        }
    }
}
|
3800
|
+
|
3801
|
+
// Kernel entry point for the IQ2_XXS x f32 matrix-vector product; forwards to
// kernel_mul_mv_iq2_xxs_f32_impl.
// The argument list is bound positionally by the host, so keep its order intact.
// nb00..nb02, ne11 and nb10..nb12 are accepted but not forwarded. Presumably this keeps
// the argument layout shared with the other mul_mv kernels; TODO confirm.
// Threadgroup memory index 0 must provide at least 256*8 + 128 bytes. The impl stages
// the 256-entry grid plus the 128 sign bytes there.
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
kernel void kernel_mul_mv_iq2_xxs_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
|
3829
|
+
|
3830
|
+
// Matrix-vector product: IQ2_XS-quantized src0 rows times an f32 src1 vector.
//
// Each simdgroup produces N_DST consecutive output rows starting at first_row.
// Each of its 32 lanes handles every 32nd group of 32 weights along the row.
// The per-lane partial sums are reduced with simd_sum, and lane 0 writes the result.
//
// Decoding of one 32-weight group:
//   - four uint16_t at qs[4*ib]: low 9 bits index iq2xs_grid (8 magnitudes),
//     high 7 bits index ksigns_iq2xs;
//   - scales[ib]: the low nibble scales the first 16 weights and the high nibble
//     the second 16; the effective scale is d * (0.5 + nibble).
// The common 0.25 factor is applied once, to the reduced sum.
//
// shared_values (threadgroup memory index 0) must hold 512 uint64_t grid entries
// followed by 128 sign bytes, i.e. at least 512*8 + 128 bytes.
// NOTE: only QK_K == 256 is implemented. For any other QK_K the accumulation loop is
// compiled out and zeros are written to dst.
void kernel_mul_mv_iq2_xs_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row    = first_row * nb;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // Broadcast src0 over the batch dimensions (r2, r3 are the broadcast ratios).
    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[32];
    float sumf[N_DST]={0.f};

    const int nb32 = nb * (QK_K / 32);

    // Stage the grid and sign tables in threadgroup memory.
    // Every thread copies a strided slice, so both tables are fully covered for any
    // threadgroup of N_SIMDGROUP simdgroups without writing past either region.
    // A fixed per-thread count would only be exact for one specific simdgroup count.
    threadgroup uint64_t * values       = (threadgroup uint64_t *)shared_values;
    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t  *)(values + 512);
    {
        const int tid      = 32*sgitg + tiisg;
        const int nthreads = 32*N_SIMDGROUP;
        for (int i = tid; i < 512; i += nthreads) values[i]       = iq2xs_grid[i];
        for (int i = tid; i < 128; i += nthreads) shared_signs[i] = ksigns_iq2xs[i];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

#if QK_K == 256
    const int ix = tiisg;

    device const float * y4 = y + 32 * ix;

    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {

        for (int i = 0; i < 32; ++i) {
            yl[i] = y4[i];
        }

        const int ibl = ib32 / (QK_K / 32);   // super-block within the row
        const int ib  = ib32 % (QK_K / 32);   // 32-weight group within the super-block

        device const block_iq2_xs * xr = x + ibl;
        device const uint16_t * q2 = xr->qs + 4 * ib;
        device const uint8_t  * sc = xr->scales + ib;
        device const half     * dh = &xr->d;

        for (int row = 0; row < N_DST; row++) {

            const float db = dh[0];
            const uint8_t ls1 = sc[0] & 0xf;
            const uint8_t ls2 = sc[0] >> 4;
            const float d1 = db * (0.5f + ls1);
            const float d2 = db * (0.5f + ls2);

            // First 16 weights (l = 0, 1) use d1, second 16 (l = 2, 3) use d2.
            float sum1 = 0, sum2 = 0;
            for (int l = 0; l < 2; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
                const uint8_t signs = shared_signs[(q2[l] >> 9)];
                for (int j = 0; j < 8; ++j) {
                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            for (int l = 2; l < 4; ++l) {
                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
                const uint8_t signs = shared_signs[(q2[l] >> 9)];
                for (int j = 0; j < 8; ++j) {
                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                }
            }
            sumf[row] += d1 * sum1 + d2 * sum2;

            // Step to the same group in the next row: nb blocks, counted in each pointer's units
            // (2-byte units for dh/q2, bytes for sc).
            dh += nb*sizeof(block_iq2_xs)/2;
            q2 += nb*sizeof(block_iq2_xs)/2;
            sc += nb*sizeof(block_iq2_xs);
        }

        y4 += 32 * 32;
    }
#else
    // TODO: QK_K != 256 is not implemented; sumf stays zero.
#endif

    for (int row = 0; row < N_DST; ++row) {
        const float all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
        }
    }
}
|
3943
|
+
|
3944
|
+
// Kernel entry point for the IQ2_XS x f32 matrix-vector product; forwards to
// kernel_mul_mv_iq2_xs_f32_impl.
// The argument list is bound positionally by the host, so keep its order intact.
// nb00..nb02, ne11 and nb10..nb12 are accepted but not forwarded. Presumably this keeps
// the argument layout shared with the other mul_mv kernels; TODO confirm.
// Threadgroup memory index 0 must provide at least 512*8 + 128 bytes. The impl stages
// the 512-entry grid plus the 128 sign bytes there.
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
kernel void kernel_mul_mv_iq2_xs_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
|
3972
|
+
|
3471
3973
|
//============================= templates and their specializations =============================
|
3472
3974
|
|
3473
3975
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
@@ -3620,8 +4122,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|
3620
4122
|
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
3621
4123
|
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
3622
4124
|
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
3623
|
-
|
3624
|
-
const
|
4125
|
+
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
4126
|
+
const float ml = 4.f * dl;
|
3625
4127
|
|
3626
4128
|
il = (il/2) & 3;
|
3627
4129
|
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
@@ -3688,7 +4190,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
|
3688
4190
|
uint8_t ul = 1 << (il/2);
|
3689
4191
|
il = il & 3;
|
3690
4192
|
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
3691
|
-
const float d = il < 2 ? xb->d : xb->d / 16.
|
4193
|
+
const float d = il < 2 ? xb->d : xb->d / 16.f;
|
3692
4194
|
const float min = xb->dmin;
|
3693
4195
|
const float dl = d * sc[0];
|
3694
4196
|
const float ml = min * sc[1];
|
@@ -3721,17 +4223,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
3721
4223
|
#if QK_K == 256
|
3722
4224
|
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
3723
4225
|
qh = qh + 32*(il/8) + 16*(il&1);
|
3724
|
-
|
4226
|
+
float sc = scales[(il%2) + 2 * ((il/2))];
|
3725
4227
|
il = (il/2) & 3;
|
3726
4228
|
#else
|
3727
4229
|
ql = ql + 16 * (il&1);
|
3728
|
-
|
4230
|
+
float sc = scales[il];
|
3729
4231
|
#endif
|
3730
4232
|
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
3731
4233
|
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
3732
|
-
const
|
3733
|
-
const
|
3734
|
-
const
|
4234
|
+
const float coef = il>1 ? 1.f/16.f : 1.f;
|
4235
|
+
const float ml = d_all * sc * 32.f;
|
4236
|
+
const float dl = d_all * sc * coef;
|
3735
4237
|
for (int i = 0; i < 16; ++i) {
|
3736
4238
|
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
3737
4239
|
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
@@ -3739,6 +4241,52 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
|
3739
4241
|
}
|
3740
4242
|
}
|
3741
4243
|
|
4244
|
+
template <typename type4x4>
|
4245
|
+
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
|
4246
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
4247
|
+
const float d = xb->d;
|
4248
|
+
const int ib32 = il/2;
|
4249
|
+
il = il%2;
|
4250
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
4251
|
+
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
|
4252
|
+
device const uint16_t * q2 = xb->qs + 4*ib32;
|
4253
|
+
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
|
4254
|
+
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
|
4255
|
+
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
|
4256
|
+
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
|
4257
|
+
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
4258
|
+
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
|
4259
|
+
for (int i = 0; i < 8; ++i) {
|
4260
|
+
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
4261
|
+
}
|
4262
|
+
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
4263
|
+
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
|
4264
|
+
for (int i = 0; i < 8; ++i) {
|
4265
|
+
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
4266
|
+
}
|
4267
|
+
}
|
4268
|
+
|
4269
|
+
template <typename type4x4>
|
4270
|
+
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
|
4271
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
4272
|
+
const float d = xb->d;
|
4273
|
+
const int ib32 = il/2;
|
4274
|
+
il = il%2;
|
4275
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
4276
|
+
device const uint16_t * q2 = xb->qs + 4*ib32;
|
4277
|
+
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
|
4278
|
+
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
|
4279
|
+
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
|
4280
|
+
for (int i = 0; i < 8; ++i) {
|
4281
|
+
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
4282
|
+
}
|
4283
|
+
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
|
4284
|
+
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
|
4285
|
+
for (int i = 0; i < 8; ++i) {
|
4286
|
+
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
4287
|
+
}
|
4288
|
+
}
|
4289
|
+
|
3742
4290
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
3743
4291
|
kernel void kernel_get_rows(
|
3744
4292
|
device const void * src0,
|
@@ -4278,6 +4826,8 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
|
|
4278
4826
|
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
4279
4827
|
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
4280
4828
|
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
4829
|
+
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
4830
|
+
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
4281
4831
|
|
4282
4832
|
//
|
4283
4833
|
// matrix-matrix multiplication
|
@@ -4314,6 +4864,8 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
4314
4864
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
4315
4865
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
4316
4866
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
4867
|
+
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
4868
|
+
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
4317
4869
|
|
4318
4870
|
//
|
4319
4871
|
// indirect matrix-matrix multiplication
|
@@ -4362,6 +4914,8 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
4362
4914
|
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
4363
4915
|
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
4364
4916
|
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
4917
|
+
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
4918
|
+
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
4365
4919
|
|
4366
4920
|
//
|
4367
4921
|
// matrix-vector multiplication
|
@@ -5134,3 +5688,133 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
5134
5688
|
tiisg,
|
5135
5689
|
sgitg);
|
5136
5690
|
}
|
5691
|
+
|
5692
|
+
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
5693
|
+
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
5694
|
+
device const char * ids,
|
5695
|
+
device const char * src1,
|
5696
|
+
device float * dst,
|
5697
|
+
constant uint64_t & nbi1,
|
5698
|
+
constant int64_t & ne00,
|
5699
|
+
constant int64_t & ne01,
|
5700
|
+
constant int64_t & ne02,
|
5701
|
+
constant uint64_t & nb00,
|
5702
|
+
constant uint64_t & nb01,
|
5703
|
+
constant uint64_t & nb02,
|
5704
|
+
constant int64_t & ne10,
|
5705
|
+
constant int64_t & ne11,
|
5706
|
+
constant int64_t & ne12,
|
5707
|
+
constant int64_t & ne13,
|
5708
|
+
constant uint64_t & nb10,
|
5709
|
+
constant uint64_t & nb11,
|
5710
|
+
constant uint64_t & nb12,
|
5711
|
+
constant int64_t & ne0,
|
5712
|
+
constant int64_t & ne1,
|
5713
|
+
constant uint64_t & nb1,
|
5714
|
+
constant uint & r2,
|
5715
|
+
constant uint & r3,
|
5716
|
+
constant int & idx,
|
5717
|
+
device const char * src00,
|
5718
|
+
device const char * src01,
|
5719
|
+
device const char * src02,
|
5720
|
+
device const char * src03,
|
5721
|
+
device const char * src04,
|
5722
|
+
device const char * src05,
|
5723
|
+
device const char * src06,
|
5724
|
+
device const char * src07,
|
5725
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
5726
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
5727
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
5728
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
5729
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5730
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5731
|
+
|
5732
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
5733
|
+
|
5734
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
5735
|
+
|
5736
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
5737
|
+
|
5738
|
+
kernel_mul_mv_iq2_xxs_f32_impl(
|
5739
|
+
src0[id],
|
5740
|
+
(device const float *) (src1 + bid*nb11),
|
5741
|
+
dst + bid*ne0,
|
5742
|
+
ne00,
|
5743
|
+
ne01,
|
5744
|
+
ne02,
|
5745
|
+
ne10,
|
5746
|
+
ne12,
|
5747
|
+
ne0,
|
5748
|
+
ne1,
|
5749
|
+
r2,
|
5750
|
+
r3,
|
5751
|
+
shared_values,
|
5752
|
+
tgpig,
|
5753
|
+
tiisg,
|
5754
|
+
sgitg);
|
5755
|
+
}
|
5756
|
+
|
5757
|
+
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
5758
|
+
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
5759
|
+
device const char * ids,
|
5760
|
+
device const char * src1,
|
5761
|
+
device float * dst,
|
5762
|
+
constant uint64_t & nbi1,
|
5763
|
+
constant int64_t & ne00,
|
5764
|
+
constant int64_t & ne01,
|
5765
|
+
constant int64_t & ne02,
|
5766
|
+
constant uint64_t & nb00,
|
5767
|
+
constant uint64_t & nb01,
|
5768
|
+
constant uint64_t & nb02,
|
5769
|
+
constant int64_t & ne10,
|
5770
|
+
constant int64_t & ne11,
|
5771
|
+
constant int64_t & ne12,
|
5772
|
+
constant int64_t & ne13,
|
5773
|
+
constant uint64_t & nb10,
|
5774
|
+
constant uint64_t & nb11,
|
5775
|
+
constant uint64_t & nb12,
|
5776
|
+
constant int64_t & ne0,
|
5777
|
+
constant int64_t & ne1,
|
5778
|
+
constant uint64_t & nb1,
|
5779
|
+
constant uint & r2,
|
5780
|
+
constant uint & r3,
|
5781
|
+
constant int & idx,
|
5782
|
+
device const char * src00,
|
5783
|
+
device const char * src01,
|
5784
|
+
device const char * src02,
|
5785
|
+
device const char * src03,
|
5786
|
+
device const char * src04,
|
5787
|
+
device const char * src05,
|
5788
|
+
device const char * src06,
|
5789
|
+
device const char * src07,
|
5790
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
5791
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
5792
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
5793
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
5794
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5795
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5796
|
+
|
5797
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
5798
|
+
|
5799
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
5800
|
+
|
5801
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
5802
|
+
|
5803
|
+
kernel_mul_mv_iq2_xs_f32_impl(
|
5804
|
+
src0[id],
|
5805
|
+
(device const float *) (src1 + bid*nb11),
|
5806
|
+
dst + bid*ne0,
|
5807
|
+
ne00,
|
5808
|
+
ne01,
|
5809
|
+
ne02,
|
5810
|
+
ne10,
|
5811
|
+
ne12,
|
5812
|
+
ne0,
|
5813
|
+
ne1,
|
5814
|
+
r2,
|
5815
|
+
r3,
|
5816
|
+
shared_values,
|
5817
|
+
tgpig,
|
5818
|
+
tiisg,
|
5819
|
+
sgitg);
|
5820
|
+
}
|