llama_cpp 0.12.6 → 0.13.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/ext/llama_cpp/llama_cpp.cpp +90 -269
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +28 -23
- data/vendor/tmp/llama.cpp/Makefile +51 -15
- data/vendor/tmp/llama.cpp/ggml-alloc.c +73 -43
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +2 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +32 -11
- data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +560 -346
- data/vendor/tmp/llama.cpp/ggml-impl.h +20 -7
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +191 -22
- data/vendor/tmp/llama.cpp/ggml-metal.metal +2472 -862
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +25 -25
- data/vendor/tmp/llama.cpp/ggml-quants.c +3176 -667
- data/vendor/tmp/llama.cpp/ggml-quants.h +77 -2
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +373 -424
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +186 -102
- data/vendor/tmp/llama.cpp/ggml.c +1266 -699
- data/vendor/tmp/llama.cpp/ggml.h +59 -30
- data/vendor/tmp/llama.cpp/llama.cpp +1517 -717
- data/vendor/tmp/llama.cpp/llama.h +87 -63
- data/vendor/tmp/llama.cpp/scripts/get-flags.mk +1 -1
- data/vendor/tmp/llama.cpp/unicode.h +310 -1
- metadata +2 -2
|
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
|
|
|
351
351
|
kernel void kernel_soft_max(
|
|
352
352
|
device const float * src0,
|
|
353
353
|
device const float * src1,
|
|
354
|
+
device const float * src2,
|
|
354
355
|
device float * dst,
|
|
355
356
|
constant int64_t & ne00,
|
|
356
357
|
constant int64_t & ne01,
|
|
357
358
|
constant int64_t & ne02,
|
|
358
359
|
constant float & scale,
|
|
359
|
-
|
|
360
|
+
constant float & max_bias,
|
|
361
|
+
constant float & m0,
|
|
362
|
+
constant float & m1,
|
|
363
|
+
constant uint32_t & n_head_log2,
|
|
364
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
360
365
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
361
366
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
362
367
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
|
|
|
368
373
|
|
|
369
374
|
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
370
375
|
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
|
376
|
+
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
|
371
377
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
|
372
378
|
|
|
379
|
+
float slope = 0.0f;
|
|
380
|
+
|
|
381
|
+
// ALiBi
|
|
382
|
+
if (max_bias > 0.0f) {
|
|
383
|
+
const int64_t h = i02;
|
|
384
|
+
|
|
385
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
|
386
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
387
|
+
|
|
388
|
+
slope = pow(base, exp);
|
|
389
|
+
}
|
|
390
|
+
|
|
373
391
|
// parallel max
|
|
374
392
|
float lmax = -INFINITY;
|
|
375
393
|
|
|
376
394
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
377
|
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
395
|
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
|
378
396
|
}
|
|
379
397
|
|
|
380
398
|
// find the max value in the block
|
|
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
|
|
|
399
417
|
// parallel sum
|
|
400
418
|
float lsum = 0.0f;
|
|
401
419
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
|
402
|
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
420
|
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
|
403
421
|
lsum += exp_psrc0;
|
|
404
422
|
pdst[i00] = exp_psrc0;
|
|
405
423
|
}
|
|
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
|
|
|
437
455
|
kernel void kernel_soft_max_4(
|
|
438
456
|
device const float * src0,
|
|
439
457
|
device const float * src1,
|
|
458
|
+
device const float * src2,
|
|
440
459
|
device float * dst,
|
|
441
460
|
constant int64_t & ne00,
|
|
442
461
|
constant int64_t & ne01,
|
|
443
462
|
constant int64_t & ne02,
|
|
444
463
|
constant float & scale,
|
|
445
|
-
|
|
464
|
+
constant float & max_bias,
|
|
465
|
+
constant float & m0,
|
|
466
|
+
constant float & m1,
|
|
467
|
+
constant uint32_t & n_head_log2,
|
|
468
|
+
threadgroup float * buf [[threadgroup(0)]],
|
|
446
469
|
uint tgpig[[threadgroup_position_in_grid]],
|
|
447
470
|
uint tpitg[[thread_position_in_threadgroup]],
|
|
448
471
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
|
|
|
454
477
|
|
|
455
478
|
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
456
479
|
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
480
|
+
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
|
457
481
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
|
458
482
|
|
|
483
|
+
float slope = 0.0f;
|
|
484
|
+
|
|
485
|
+
if (max_bias > 0.0f) {
|
|
486
|
+
const int64_t h = i02;
|
|
487
|
+
|
|
488
|
+
const float base = h < n_head_log2 ? m0 : m1;
|
|
489
|
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
490
|
+
|
|
491
|
+
slope = pow(base, exp);
|
|
492
|
+
}
|
|
493
|
+
|
|
459
494
|
// parallel max
|
|
460
495
|
float4 lmax4 = -INFINITY;
|
|
461
496
|
|
|
462
497
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
463
|
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
|
498
|
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
|
|
464
499
|
}
|
|
465
500
|
|
|
466
501
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|
|
486
521
|
// parallel sum
|
|
487
522
|
float4 lsum4 = 0.0f;
|
|
488
523
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
|
489
|
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
|
524
|
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
|
|
490
525
|
lsum4 += exp_psrc4;
|
|
491
526
|
pdst4[i00] = exp_psrc4;
|
|
492
527
|
}
|
|
@@ -2484,12 +2519,58 @@ typedef struct {
|
|
|
2484
2519
|
} block_iq2_xs;
|
|
2485
2520
|
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
|
2486
2521
|
|
|
2522
|
+
// 2.5625 bpw quants
|
|
2523
|
+
typedef struct {
|
|
2524
|
+
half d;
|
|
2525
|
+
uint8_t qs[QK_K/4];
|
|
2526
|
+
uint8_t qh[QK_K/32];
|
|
2527
|
+
uint8_t scales[QK_K/32];
|
|
2528
|
+
} block_iq2_s;
|
|
2529
|
+
|
|
2487
2530
|
typedef struct {
|
|
2488
2531
|
half d;
|
|
2489
2532
|
uint8_t qs[3*QK_K/8];
|
|
2490
2533
|
} block_iq3_xxs;
|
|
2491
2534
|
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
|
2492
2535
|
|
|
2536
|
+
// 3.4375 bpw
|
|
2537
|
+
#if QK_K == 64
|
|
2538
|
+
#define IQ3S_N_SCALE 2
|
|
2539
|
+
#else
|
|
2540
|
+
#define IQ3S_N_SCALE QK_K/64
|
|
2541
|
+
#endif
|
|
2542
|
+
typedef struct {
|
|
2543
|
+
half d;
|
|
2544
|
+
uint8_t qs[QK_K/4];
|
|
2545
|
+
uint8_t qh[QK_K/32];
|
|
2546
|
+
uint8_t signs[QK_K/8];
|
|
2547
|
+
uint8_t scales[IQ3S_N_SCALE];
|
|
2548
|
+
} block_iq3_s;
|
|
2549
|
+
|
|
2550
|
+
typedef struct {
|
|
2551
|
+
half d;
|
|
2552
|
+
uint8_t qs[QK_K/8];
|
|
2553
|
+
uint8_t scales[QK_K/16];
|
|
2554
|
+
} block_iq1_s;
|
|
2555
|
+
|
|
2556
|
+
// Non-linear quants
|
|
2557
|
+
#define QK4_NL 32
|
|
2558
|
+
typedef struct {
|
|
2559
|
+
half d;
|
|
2560
|
+
uint8_t qs[QK4_NL/2];
|
|
2561
|
+
} block_iq4_nl;
|
|
2562
|
+
|
|
2563
|
+
#if QK_K == 64
|
|
2564
|
+
#define block_iq4_xs block_iq4_nl
|
|
2565
|
+
#else
|
|
2566
|
+
typedef struct {
|
|
2567
|
+
half d;
|
|
2568
|
+
uint16_t scales_h;
|
|
2569
|
+
uint8_t scales_l[QK_K/64];
|
|
2570
|
+
uint8_t qs[QK_K/2];
|
|
2571
|
+
} block_iq4_xs;
|
|
2572
|
+
#endif
|
|
2573
|
+
|
|
2493
2574
|
//====================================== dot products =========================
|
|
2494
2575
|
|
|
2495
2576
|
void kernel_mul_mv_q2_K_f32_impl(
|
|
@@ -3712,6 +3793,265 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
|
|
|
3712
3793
|
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
|
3713
3794
|
};
|
|
3714
3795
|
|
|
3796
|
+
constexpr constant static uint64_t iq2s_grid[1024] = {
|
|
3797
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
|
3798
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
|
3799
|
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
|
3800
|
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
|
3801
|
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
|
3802
|
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
|
|
3803
|
+
0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
|
|
3804
|
+
0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
|
|
3805
|
+
0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
|
|
3806
|
+
0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
|
|
3807
|
+
0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
|
|
3808
|
+
0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
|
|
3809
|
+
0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
|
|
3810
|
+
0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
|
|
3811
|
+
0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
|
|
3812
|
+
0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
|
|
3813
|
+
0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
|
|
3814
|
+
0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
|
|
3815
|
+
0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
|
|
3816
|
+
0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
|
|
3817
|
+
0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
|
|
3818
|
+
0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
|
|
3819
|
+
0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
|
|
3820
|
+
0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
|
|
3821
|
+
0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
|
|
3822
|
+
0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
|
|
3823
|
+
0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
|
|
3824
|
+
0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
|
|
3825
|
+
0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
|
|
3826
|
+
0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
|
|
3827
|
+
0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
|
|
3828
|
+
0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
|
|
3829
|
+
0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
|
|
3830
|
+
0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
|
|
3831
|
+
0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
|
|
3832
|
+
0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
|
|
3833
|
+
0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
|
|
3834
|
+
0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
|
|
3835
|
+
0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
|
|
3836
|
+
0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
|
|
3837
|
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
|
|
3838
|
+
0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
|
|
3839
|
+
0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
|
|
3840
|
+
0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
|
|
3841
|
+
0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
|
|
3842
|
+
0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
|
|
3843
|
+
0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
|
|
3844
|
+
0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
|
|
3845
|
+
0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
|
|
3846
|
+
0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
|
|
3847
|
+
0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
|
|
3848
|
+
0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
|
|
3849
|
+
0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
|
|
3850
|
+
0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
|
|
3851
|
+
0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
|
|
3852
|
+
0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
|
|
3853
|
+
0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
|
|
3854
|
+
0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
|
|
3855
|
+
0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
|
|
3856
|
+
0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
|
|
3857
|
+
0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
|
|
3858
|
+
0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
|
|
3859
|
+
0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
|
|
3860
|
+
0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
|
|
3861
|
+
0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
|
|
3862
|
+
0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
|
|
3863
|
+
0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
|
|
3864
|
+
0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
|
|
3865
|
+
0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
|
|
3866
|
+
0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
|
|
3867
|
+
0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
|
|
3868
|
+
0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
|
|
3869
|
+
0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
|
|
3870
|
+
0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
|
|
3871
|
+
0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
|
|
3872
|
+
0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
|
|
3873
|
+
0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
|
|
3874
|
+
0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
|
|
3875
|
+
0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
|
|
3876
|
+
0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
|
|
3877
|
+
0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
|
|
3878
|
+
0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
|
|
3879
|
+
0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
|
|
3880
|
+
0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
|
|
3881
|
+
0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
|
|
3882
|
+
0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
|
|
3883
|
+
0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
|
|
3884
|
+
0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
|
|
3885
|
+
0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
|
|
3886
|
+
0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
|
|
3887
|
+
0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
|
|
3888
|
+
0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
|
|
3889
|
+
0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
|
|
3890
|
+
0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
|
|
3891
|
+
0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
|
|
3892
|
+
0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
|
|
3893
|
+
0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
|
|
3894
|
+
0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
|
|
3895
|
+
0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
|
|
3896
|
+
0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
|
|
3897
|
+
0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
|
|
3898
|
+
0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
|
|
3899
|
+
0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
|
|
3900
|
+
0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
|
|
3901
|
+
0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
|
|
3902
|
+
0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
|
|
3903
|
+
0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
|
|
3904
|
+
0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
|
|
3905
|
+
0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
|
|
3906
|
+
0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
|
|
3907
|
+
0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
|
|
3908
|
+
0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
|
|
3909
|
+
0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
|
|
3910
|
+
0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
|
|
3911
|
+
0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
|
|
3912
|
+
0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
|
|
3913
|
+
0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
|
|
3914
|
+
0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
|
|
3915
|
+
0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
|
|
3916
|
+
0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
|
|
3917
|
+
0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
|
|
3918
|
+
0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
|
|
3919
|
+
0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
|
|
3920
|
+
0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
|
|
3921
|
+
0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
|
|
3922
|
+
0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
|
|
3923
|
+
0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
|
|
3924
|
+
0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
|
|
3925
|
+
0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
|
|
3926
|
+
0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
|
|
3927
|
+
0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
|
|
3928
|
+
0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
|
|
3929
|
+
0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
|
|
3930
|
+
0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
|
|
3931
|
+
0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
|
|
3932
|
+
0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
|
|
3933
|
+
0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
|
|
3934
|
+
0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
|
|
3935
|
+
0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
|
|
3936
|
+
0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
|
|
3937
|
+
0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
|
|
3938
|
+
0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
|
|
3939
|
+
0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
|
|
3940
|
+
0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
|
|
3941
|
+
0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
|
|
3942
|
+
0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
|
|
3943
|
+
0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
|
|
3944
|
+
0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
|
|
3945
|
+
0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
|
|
3946
|
+
0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
|
|
3947
|
+
0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
|
|
3948
|
+
0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
|
|
3949
|
+
0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
|
|
3950
|
+
0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
|
|
3951
|
+
0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
|
|
3952
|
+
0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
|
|
3953
|
+
0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
|
|
3954
|
+
0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
|
|
3955
|
+
0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
|
|
3956
|
+
0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
|
|
3957
|
+
0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
|
|
3958
|
+
0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
|
|
3959
|
+
0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
|
|
3960
|
+
0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
|
|
3961
|
+
0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
|
|
3962
|
+
0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
|
|
3963
|
+
0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
|
|
3964
|
+
0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
|
|
3965
|
+
0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
|
|
3966
|
+
0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
|
|
3967
|
+
0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
|
|
3968
|
+
0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
|
|
3969
|
+
0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
|
|
3970
|
+
0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
|
|
3971
|
+
0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
|
|
3972
|
+
0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
|
|
3973
|
+
0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
|
|
3974
|
+
0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
|
|
3975
|
+
0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
|
|
3976
|
+
0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
|
|
3977
|
+
0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
|
|
3978
|
+
0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
|
|
3979
|
+
0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
|
|
3980
|
+
0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
|
|
3981
|
+
0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
|
|
3982
|
+
0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
|
|
3983
|
+
0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
|
|
3984
|
+
0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
|
|
3985
|
+
0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
|
|
3986
|
+
0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
|
|
3987
|
+
0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
|
|
3988
|
+
0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
|
|
3989
|
+
0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
|
|
3990
|
+
0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
|
|
3991
|
+
0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
|
|
3992
|
+
0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
|
|
3993
|
+
0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
|
|
3994
|
+
0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
|
|
3995
|
+
0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
|
|
3996
|
+
0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
|
|
3997
|
+
0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
|
|
3998
|
+
0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
|
|
3999
|
+
0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
|
|
4000
|
+
0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
|
|
4001
|
+
0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
|
|
4002
|
+
0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
|
|
4003
|
+
0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
|
|
4004
|
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
|
|
4005
|
+
0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
|
|
4006
|
+
0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
|
|
4007
|
+
0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
|
|
4008
|
+
0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
|
|
4009
|
+
0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
|
|
4010
|
+
0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
|
|
4011
|
+
0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
|
|
4012
|
+
0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
|
|
4013
|
+
0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
|
|
4014
|
+
0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
|
|
4015
|
+
0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
|
|
4016
|
+
0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
|
|
4017
|
+
0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
|
|
4018
|
+
0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
|
|
4019
|
+
0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
|
|
4020
|
+
0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
|
|
4021
|
+
0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
|
|
4022
|
+
0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
|
|
4023
|
+
0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
|
|
4024
|
+
0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
|
|
4025
|
+
0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
|
|
4026
|
+
0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
|
|
4027
|
+
0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
|
|
4028
|
+
0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
|
|
4029
|
+
0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
|
|
4030
|
+
0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
|
|
4031
|
+
0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
|
|
4032
|
+
0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
|
|
4033
|
+
0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
|
|
4034
|
+
0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
|
|
4035
|
+
0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
|
|
4036
|
+
0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
|
|
4037
|
+
0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
|
|
4038
|
+
0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
|
|
4039
|
+
0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
|
|
4040
|
+
0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
|
|
4041
|
+
0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
|
|
4042
|
+
0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
|
|
4043
|
+
0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
|
|
4044
|
+
0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
|
|
4045
|
+
0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
|
|
4046
|
+
0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
|
|
4047
|
+
0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
|
|
4048
|
+
0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
|
|
4049
|
+
0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
|
|
4050
|
+
0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
|
|
4051
|
+
0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
|
|
4052
|
+
0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
|
|
4053
|
+
};
|
|
4054
|
+
|
|
3715
4055
|
constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
3716
4056
|
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
|
3717
4057
|
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
|
@@ -3747,6 +4087,204 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
|
3747
4087
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
|
3748
4088
|
};
|
|
3749
4089
|
|
|
4090
|
+
constexpr constant static uint32_t iq3xs_grid[512] = {
|
|
4091
|
+
0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
|
|
4092
|
+
0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
|
|
4093
|
+
0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
|
|
4094
|
+
0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
|
|
4095
|
+
0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
|
|
4096
|
+
0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
|
|
4097
|
+
0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
|
|
4098
|
+
0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
|
|
4099
|
+
0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
|
|
4100
|
+
0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
|
|
4101
|
+
0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
|
|
4102
|
+
0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
|
|
4103
|
+
0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
|
|
4104
|
+
0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
|
|
4105
|
+
0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
|
|
4106
|
+
0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
|
|
4107
|
+
0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
|
|
4108
|
+
0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
|
|
4109
|
+
0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
|
|
4110
|
+
0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
|
|
4111
|
+
0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
|
|
4112
|
+
0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
|
|
4113
|
+
0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
|
|
4114
|
+
0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
|
|
4115
|
+
0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
|
|
4116
|
+
0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
|
|
4117
|
+
0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
|
|
4118
|
+
0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
|
|
4119
|
+
0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
|
|
4120
|
+
0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
|
|
4121
|
+
0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
|
|
4122
|
+
0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
|
|
4123
|
+
0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
|
|
4124
|
+
0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
|
|
4125
|
+
0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
|
|
4126
|
+
0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
|
|
4127
|
+
0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
|
|
4128
|
+
0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
|
|
4129
|
+
0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
|
|
4130
|
+
0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
|
|
4131
|
+
0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
|
|
4132
|
+
0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
|
|
4133
|
+
0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
|
|
4134
|
+
0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
|
|
4135
|
+
0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
|
|
4136
|
+
0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
|
|
4137
|
+
0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
|
|
4138
|
+
0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
|
|
4139
|
+
0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
|
|
4140
|
+
0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
|
|
4141
|
+
0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
|
|
4142
|
+
0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
|
|
4143
|
+
0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
|
|
4144
|
+
0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
|
|
4145
|
+
0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
|
|
4146
|
+
0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
|
|
4147
|
+
0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
|
|
4148
|
+
0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
|
|
4149
|
+
0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
|
|
4150
|
+
0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
|
|
4151
|
+
0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
|
|
4152
|
+
0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
|
|
4153
|
+
0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
|
|
4154
|
+
0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
|
|
4155
|
+
};
|
|
4156
|
+
|
|
4157
|
+
#define NGRID_IQ1S 512
|
|
4158
|
+
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
|
4159
|
+
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
|
4160
|
+
0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
|
|
4161
|
+
0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
|
|
4162
|
+
0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
|
|
4163
|
+
0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
|
|
4164
|
+
0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
|
|
4165
|
+
0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
|
|
4166
|
+
0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
|
|
4167
|
+
0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
|
|
4168
|
+
0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
|
|
4169
|
+
0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
|
|
4170
|
+
0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
|
|
4171
|
+
0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
|
|
4172
|
+
0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
|
|
4173
|
+
0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
|
|
4174
|
+
0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
|
|
4175
|
+
0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
|
|
4176
|
+
0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
|
|
4177
|
+
0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
|
|
4178
|
+
0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
|
|
4179
|
+
0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
|
|
4180
|
+
0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
|
|
4181
|
+
0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
|
|
4182
|
+
0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
|
|
4183
|
+
0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
|
|
4184
|
+
0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
|
|
4185
|
+
0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
|
|
4186
|
+
0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
|
|
4187
|
+
0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
|
|
4188
|
+
0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
|
|
4189
|
+
0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
|
|
4190
|
+
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
|
|
4191
|
+
0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
|
|
4192
|
+
0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
|
|
4193
|
+
0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
|
|
4194
|
+
0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
|
|
4195
|
+
0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
|
|
4196
|
+
0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
|
|
4197
|
+
0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
|
|
4198
|
+
0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
|
|
4199
|
+
0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
|
|
4200
|
+
0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
|
|
4201
|
+
0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
|
|
4202
|
+
0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
|
|
4203
|
+
0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
|
|
4204
|
+
0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
|
|
4205
|
+
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
|
|
4206
|
+
0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
|
|
4207
|
+
0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
|
|
4208
|
+
0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
|
|
4209
|
+
0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
|
|
4210
|
+
0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
|
|
4211
|
+
0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
|
|
4212
|
+
0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
|
|
4213
|
+
0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
|
|
4214
|
+
0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
|
|
4215
|
+
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
|
|
4216
|
+
0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
|
|
4217
|
+
0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
|
|
4218
|
+
0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
|
|
4219
|
+
0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
|
|
4220
|
+
0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
|
|
4221
|
+
0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
|
|
4222
|
+
0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
|
|
4223
|
+
0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
|
|
4224
|
+
0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
|
|
4225
|
+
0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
|
|
4226
|
+
0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
|
|
4227
|
+
0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
|
|
4228
|
+
0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
|
|
4229
|
+
0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
|
|
4230
|
+
0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
|
|
4231
|
+
0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
|
|
4232
|
+
0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
|
|
4233
|
+
0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
|
|
4234
|
+
0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
|
|
4235
|
+
0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
|
|
4236
|
+
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
|
|
4237
|
+
0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
|
|
4238
|
+
0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
|
|
4239
|
+
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
|
|
4240
|
+
0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
|
|
4241
|
+
0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
|
|
4242
|
+
0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
|
|
4243
|
+
0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
|
|
4244
|
+
0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
|
|
4245
|
+
0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
|
|
4246
|
+
0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
|
|
4247
|
+
0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
|
|
4248
|
+
0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
|
|
4249
|
+
0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
|
|
4250
|
+
0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
|
|
4251
|
+
0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
|
|
4252
|
+
0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
|
|
4253
|
+
0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
|
|
4254
|
+
0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
|
|
4255
|
+
0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
|
|
4256
|
+
0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
|
|
4257
|
+
0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
|
|
4258
|
+
0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
|
|
4259
|
+
0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
|
|
4260
|
+
0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
|
|
4261
|
+
0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
|
|
4262
|
+
0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
|
|
4263
|
+
0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
|
|
4264
|
+
0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
|
|
4265
|
+
0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
|
|
4266
|
+
0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
|
|
4267
|
+
0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
|
|
4268
|
+
0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
|
|
4269
|
+
0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
|
|
4270
|
+
0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
|
|
4271
|
+
0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
|
|
4272
|
+
0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
|
|
4273
|
+
0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
|
|
4274
|
+
0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
|
|
4275
|
+
0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
|
|
4276
|
+
0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
|
|
4277
|
+
0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
|
|
4278
|
+
0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
|
|
4279
|
+
0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
|
|
4280
|
+
0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
|
|
4281
|
+
0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
|
|
4282
|
+
0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
|
|
4283
|
+
0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
|
|
4284
|
+
0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
|
|
4285
|
+
0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
|
|
4286
|
+
0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
|
|
4287
|
+
};
|
|
3750
4288
|
|
|
3751
4289
|
constexpr constant static uint8_t ksigns_iq2xs[128] = {
|
|
3752
4290
|
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
|
|
@@ -3812,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
|
3812
4350
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3813
4351
|
}
|
|
3814
4352
|
|
|
3815
|
-
#if QK_K == 256
|
|
3816
4353
|
const int ix = tiisg;
|
|
3817
4354
|
|
|
3818
4355
|
device const float * y4 = y + 32 * ix;
|
|
@@ -3853,9 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
|
3853
4390
|
|
|
3854
4391
|
y4 += 32 * 32;
|
|
3855
4392
|
}
|
|
3856
|
-
#else
|
|
3857
|
-
// TODO
|
|
3858
|
-
#endif
|
|
3859
4393
|
|
|
3860
4394
|
for (int row = 0; row < N_DST; ++row) {
|
|
3861
4395
|
all_sum = simd_sum(sumf[row]);
|
|
@@ -3945,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
|
3945
4479
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3946
4480
|
}
|
|
3947
4481
|
|
|
3948
|
-
#if QK_K == 256
|
|
3949
4482
|
const int ix = tiisg;
|
|
3950
4483
|
|
|
3951
4484
|
device const float * y4 = y + 32 * ix;
|
|
@@ -3996,9 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
|
3996
4529
|
|
|
3997
4530
|
y4 += 32 * 32;
|
|
3998
4531
|
}
|
|
3999
|
-
#else
|
|
4000
|
-
// TODO
|
|
4001
|
-
#endif
|
|
4002
4532
|
|
|
4003
4533
|
for (int row = 0; row < N_DST; ++row) {
|
|
4004
4534
|
all_sum = simd_sum(sumf[row]);
|
|
@@ -4088,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
|
4088
4618
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4089
4619
|
}
|
|
4090
4620
|
|
|
4091
|
-
#if QK_K == 256
|
|
4092
4621
|
const int ix = tiisg;
|
|
4093
4622
|
|
|
4094
4623
|
device const float * y4 = y + 32 * ix;
|
|
@@ -4132,9 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
|
4132
4661
|
|
|
4133
4662
|
y4 += 32 * 32;
|
|
4134
4663
|
}
|
|
4135
|
-
#else
|
|
4136
|
-
// TODO
|
|
4137
|
-
#endif
|
|
4138
4664
|
|
|
4139
4665
|
for (int row = 0; row < N_DST; ++row) {
|
|
4140
4666
|
all_sum = simd_sum(sumf[row]);
|
|
@@ -4173,767 +4699,1802 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
|
4173
4699
|
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
4174
4700
|
}
|
|
4175
4701
|
|
|
4702
|
+
void kernel_mul_mv_iq3_s_f32_impl(
|
|
4703
|
+
device const void * src0,
|
|
4704
|
+
device const float * src1,
|
|
4705
|
+
device float * dst,
|
|
4706
|
+
constant int64_t & ne00,
|
|
4707
|
+
constant int64_t & ne01,
|
|
4708
|
+
constant int64_t & ne02,
|
|
4709
|
+
constant int64_t & ne10,
|
|
4710
|
+
constant int64_t & ne12,
|
|
4711
|
+
constant int64_t & ne0,
|
|
4712
|
+
constant int64_t & ne1,
|
|
4713
|
+
constant uint & r2,
|
|
4714
|
+
constant uint & r3,
|
|
4715
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
4716
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4717
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4718
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4176
4719
|
|
|
4177
|
-
|
|
4178
|
-
|
|
4179
|
-
|
|
4180
|
-
|
|
4181
|
-
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
4182
|
-
float4x4 temp = *(((device float4x4 *)src));
|
|
4183
|
-
for (int i = 0; i < 16; i++){
|
|
4184
|
-
reg[i/4][i%4] = temp[i/4][i%4];
|
|
4185
|
-
}
|
|
4186
|
-
}
|
|
4187
|
-
|
|
4188
|
-
template <typename type4x4>
|
|
4189
|
-
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
|
4190
|
-
half4x4 temp = *(((device half4x4 *)src));
|
|
4191
|
-
for (int i = 0; i < 16; i++){
|
|
4192
|
-
reg[i/4][i%4] = temp[i/4][i%4];
|
|
4193
|
-
}
|
|
4194
|
-
}
|
|
4720
|
+
const int nb = ne00/QK_K;
|
|
4721
|
+
const int r0 = tgpig.x;
|
|
4722
|
+
const int r1 = tgpig.y;
|
|
4723
|
+
const int im = tgpig.z;
|
|
4195
4724
|
|
|
4196
|
-
|
|
4197
|
-
|
|
4198
|
-
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
4199
|
-
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
|
4200
|
-
const float d2 = d1 / 256.f;
|
|
4201
|
-
const float md = -8.h * xb->d;
|
|
4202
|
-
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
4203
|
-
const ushort mask1 = mask0 << 8;
|
|
4725
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
4726
|
+
const int ib_row = first_row * nb;
|
|
4204
4727
|
|
|
4205
|
-
|
|
4206
|
-
|
|
4207
|
-
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
|
4208
|
-
}
|
|
4209
|
-
}
|
|
4728
|
+
const uint i12 = im%ne12;
|
|
4729
|
+
const uint i13 = im/ne12;
|
|
4210
4730
|
|
|
4211
|
-
|
|
4212
|
-
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
|
4213
|
-
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
|
4214
|
-
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
|
4215
|
-
const float d2 = d1 / 256.f;
|
|
4216
|
-
const float m = xb->m;
|
|
4217
|
-
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
4218
|
-
const ushort mask1 = mask0 << 8;
|
|
4731
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
4219
4732
|
|
|
4220
|
-
|
|
4221
|
-
|
|
4222
|
-
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
|
|
4223
|
-
}
|
|
4224
|
-
}
|
|
4733
|
+
device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
|
|
4734
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4225
4735
|
|
|
4226
|
-
|
|
4227
|
-
|
|
4228
|
-
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
|
4229
|
-
const float d = xb->d;
|
|
4230
|
-
const float md = -16.h * xb->d;
|
|
4231
|
-
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
4736
|
+
float yl[32];
|
|
4737
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
4232
4738
|
|
|
4233
|
-
const
|
|
4739
|
+
const int nb32 = nb * (QK_K / 32);
|
|
4234
4740
|
|
|
4235
|
-
|
|
4741
|
+
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
|
4742
|
+
{
|
|
4743
|
+
int nval = 8;
|
|
4744
|
+
int pos = (32*sgitg + tiisg)*nval;
|
|
4745
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
|
|
4746
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4747
|
+
}
|
|
4236
4748
|
|
|
4237
|
-
const int
|
|
4238
|
-
const int gh_bk = il ? 0 : 4;
|
|
4749
|
+
const int ix = tiisg;
|
|
4239
4750
|
|
|
4240
|
-
|
|
4241
|
-
// extract the 5-th bits for x0 and x1
|
|
4242
|
-
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
4243
|
-
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
|
4751
|
+
device const float * y4 = y + 32 * ix;
|
|
4244
4752
|
|
|
4245
|
-
|
|
4246
|
-
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
|
4247
|
-
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
|
4753
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
4248
4754
|
|
|
4249
|
-
|
|
4250
|
-
|
|
4251
|
-
|
|
4252
|
-
}
|
|
4755
|
+
for (int i = 0; i < 32; ++i) {
|
|
4756
|
+
yl[i] = y4[i];
|
|
4757
|
+
}
|
|
4253
4758
|
|
|
4254
|
-
|
|
4255
|
-
|
|
4256
|
-
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
|
4257
|
-
const float d = xb->d;
|
|
4258
|
-
const float m = xb->m;
|
|
4259
|
-
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
4759
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
4760
|
+
const int ib = ib32 % (QK_K / 32);
|
|
4260
4761
|
|
|
4261
|
-
|
|
4762
|
+
device const block_iq3_s * xr = x + ibl;
|
|
4763
|
+
device const uint8_t * qs = xr->qs + 8 * ib;
|
|
4764
|
+
device const uint8_t * qh = xr->qh + ib;
|
|
4765
|
+
device const uint8_t * sc = xr->scales + (ib/2);
|
|
4766
|
+
device const uint8_t * signs = xr->signs + 4 * ib;
|
|
4767
|
+
device const half * dh = &xr->d;
|
|
4262
4768
|
|
|
4263
|
-
|
|
4769
|
+
for (int row = 0; row < N_DST; row++) {
|
|
4264
4770
|
|
|
4265
|
-
|
|
4266
|
-
|
|
4771
|
+
const float db = dh[0];
|
|
4772
|
+
const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
|
|
4267
4773
|
|
|
4268
|
-
|
|
4269
|
-
|
|
4270
|
-
|
|
4271
|
-
|
|
4774
|
+
float2 sum = {0};
|
|
4775
|
+
for (int l = 0; l < 4; ++l) {
|
|
4776
|
+
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
|
|
4777
|
+
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
|
|
4778
|
+
for (int j = 0; j < 4; ++j) {
|
|
4779
|
+
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
|
4780
|
+
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
|
4781
|
+
}
|
|
4782
|
+
}
|
|
4783
|
+
sumf[row] += d * (sum[0] + sum[1]);
|
|
4272
4784
|
|
|
4273
|
-
|
|
4274
|
-
|
|
4275
|
-
|
|
4785
|
+
dh += nb*sizeof(block_iq3_s)/2;
|
|
4786
|
+
qs += nb*sizeof(block_iq3_s);
|
|
4787
|
+
qh += nb*sizeof(block_iq3_s);
|
|
4788
|
+
sc += nb*sizeof(block_iq3_s);
|
|
4789
|
+
signs += nb*sizeof(block_iq3_s);
|
|
4790
|
+
}
|
|
4276
4791
|
|
|
4277
|
-
|
|
4278
|
-
reg[i/2][2*(i%2)+1] = d * x1 + m;
|
|
4792
|
+
y4 += 32 * 32;
|
|
4279
4793
|
}
|
|
4280
|
-
}
|
|
4281
|
-
|
|
4282
|
-
template <typename type4x4>
|
|
4283
|
-
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
|
4284
|
-
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
|
4285
|
-
const half d = xb->d;
|
|
4286
4794
|
|
|
4287
|
-
for (int
|
|
4288
|
-
|
|
4795
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
4796
|
+
all_sum = simd_sum(sumf[row]);
|
|
4797
|
+
if (tiisg == 0) {
|
|
4798
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
|
|
4799
|
+
}
|
|
4289
4800
|
}
|
|
4290
4801
|
}
|
|
4291
4802
|
|
|
4292
|
-
|
|
4293
|
-
void
|
|
4294
|
-
|
|
4295
|
-
|
|
4296
|
-
|
|
4297
|
-
|
|
4298
|
-
|
|
4803
|
+
[[host_name("kernel_mul_mv_iq3_s_f32")]]
|
|
4804
|
+
kernel void kernel_mul_mv_iq3_s_f32(
|
|
4805
|
+
device const void * src0,
|
|
4806
|
+
device const float * src1,
|
|
4807
|
+
device float * dst,
|
|
4808
|
+
constant int64_t & ne00,
|
|
4809
|
+
constant int64_t & ne01,
|
|
4810
|
+
constant int64_t & ne02,
|
|
4811
|
+
constant uint64_t & nb00,
|
|
4812
|
+
constant uint64_t & nb01,
|
|
4813
|
+
constant uint64_t & nb02,
|
|
4814
|
+
constant int64_t & ne10,
|
|
4815
|
+
constant int64_t & ne11,
|
|
4816
|
+
constant int64_t & ne12,
|
|
4817
|
+
constant uint64_t & nb10,
|
|
4818
|
+
constant uint64_t & nb11,
|
|
4819
|
+
constant uint64_t & nb12,
|
|
4820
|
+
constant int64_t & ne0,
|
|
4821
|
+
constant int64_t & ne1,
|
|
4822
|
+
constant uint & r2,
|
|
4823
|
+
constant uint & r3,
|
|
4824
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
4825
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4826
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4827
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4299
4828
|
|
|
4300
|
-
|
|
4301
|
-
q = q + 32*(il/8) + 16*(il&1);
|
|
4302
|
-
il = (il/2)%4;
|
|
4303
|
-
#endif
|
|
4304
|
-
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
4305
|
-
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
4306
|
-
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
|
4307
|
-
for (int i = 0; i < 16; ++i) {
|
|
4308
|
-
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
|
4309
|
-
}
|
|
4829
|
+
kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
4310
4830
|
}
|
|
4311
4831
|
|
|
4312
|
-
|
|
4313
|
-
|
|
4314
|
-
|
|
4315
|
-
|
|
4316
|
-
|
|
4317
|
-
|
|
4318
|
-
|
|
4319
|
-
|
|
4320
|
-
|
|
4321
|
-
|
|
4322
|
-
|
|
4323
|
-
|
|
4324
|
-
|
|
4325
|
-
|
|
4326
|
-
|
|
4327
|
-
|
|
4328
|
-
|
|
4329
|
-
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
|
4330
|
-
const float ml = 4.f * dl;
|
|
4832
|
+
void kernel_mul_mv_iq2_s_f32_impl(
|
|
4833
|
+
device const void * src0,
|
|
4834
|
+
device const float * src1,
|
|
4835
|
+
device float * dst,
|
|
4836
|
+
constant int64_t & ne00,
|
|
4837
|
+
constant int64_t & ne01,
|
|
4838
|
+
constant int64_t & ne02,
|
|
4839
|
+
constant int64_t & ne10,
|
|
4840
|
+
constant int64_t & ne12,
|
|
4841
|
+
constant int64_t & ne0,
|
|
4842
|
+
constant int64_t & ne1,
|
|
4843
|
+
constant uint & r2,
|
|
4844
|
+
constant uint & r3,
|
|
4845
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
4846
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4847
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4848
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4331
4849
|
|
|
4332
|
-
|
|
4333
|
-
const
|
|
4334
|
-
const
|
|
4335
|
-
|
|
4850
|
+
const int nb = ne00/QK_K;
|
|
4851
|
+
const int r0 = tgpig.x;
|
|
4852
|
+
const int r1 = tgpig.y;
|
|
4853
|
+
const int im = tgpig.z;
|
|
4336
4854
|
|
|
4337
|
-
|
|
4338
|
-
|
|
4339
|
-
}
|
|
4340
|
-
#else
|
|
4341
|
-
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
|
4342
|
-
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
|
4343
|
-
float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
|
|
4344
|
-
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
4345
|
-
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
4346
|
-
uint8_t m = 1<<(il*2);
|
|
4347
|
-
for (int i = 0; i < 16; ++i) {
|
|
4348
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
|
|
4349
|
-
}
|
|
4350
|
-
#endif
|
|
4351
|
-
}
|
|
4855
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
4856
|
+
const int ib_row = first_row * nb;
|
|
4352
4857
|
|
|
4353
|
-
|
|
4354
|
-
|
|
4355
|
-
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
|
|
4356
|
-
}
|
|
4858
|
+
const uint i12 = im%ne12;
|
|
4859
|
+
const uint i13 = im/ne12;
|
|
4357
4860
|
|
|
4358
|
-
|
|
4359
|
-
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
|
4360
|
-
device const uchar * q = xb->qs;
|
|
4861
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
4361
4862
|
|
|
4362
|
-
|
|
4363
|
-
|
|
4364
|
-
q = q + (il/4) * 32 + 16 * (il&1);
|
|
4365
|
-
il = il & 3;
|
|
4366
|
-
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
4367
|
-
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
|
4368
|
-
const float min = xb->dmin;
|
|
4369
|
-
const float dl = d * sc[0];
|
|
4370
|
-
const float ml = min * sc[1];
|
|
4371
|
-
#else
|
|
4372
|
-
q = q + 16 * (il&1);
|
|
4373
|
-
device const uint8_t * s = xb->scales;
|
|
4374
|
-
device const half2 * dh = (device const half2 *)xb->d;
|
|
4375
|
-
const float2 d = (float2)dh[0];
|
|
4376
|
-
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
|
4377
|
-
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
|
4378
|
-
#endif
|
|
4379
|
-
const ushort mask = il<2 ? 0x0F : 0xF0;
|
|
4380
|
-
for (int i = 0; i < 16; ++i) {
|
|
4381
|
-
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
|
4382
|
-
}
|
|
4383
|
-
}
|
|
4863
|
+
device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
|
|
4864
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4384
4865
|
|
|
4385
|
-
|
|
4386
|
-
|
|
4387
|
-
device const uint8_t * q = xb->qs;
|
|
4388
|
-
device const uint8_t * qh = xb->qh;
|
|
4866
|
+
float yl[32];
|
|
4867
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
4389
4868
|
|
|
4390
|
-
|
|
4391
|
-
short is = (il/4) * 2;
|
|
4392
|
-
q = q + 32 * (il/4) + 16 * (il&1);
|
|
4393
|
-
qh = qh + 16 * (il&1);
|
|
4394
|
-
uint8_t ul = 1 << (il/2);
|
|
4395
|
-
il = il & 3;
|
|
4396
|
-
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
4397
|
-
const float d = il < 2 ? xb->d : xb->d / 16.f;
|
|
4398
|
-
const float min = xb->dmin;
|
|
4399
|
-
const float dl = d * sc[0];
|
|
4400
|
-
const float ml = min * sc[1];
|
|
4869
|
+
const int nb32 = nb * (QK_K / 32);
|
|
4401
4870
|
|
|
4402
|
-
|
|
4403
|
-
|
|
4404
|
-
|
|
4405
|
-
|
|
4406
|
-
|
|
4407
|
-
|
|
4408
|
-
|
|
4409
|
-
device const int8_t * s = xb->scales;
|
|
4410
|
-
const float dl = xb->d * s[il];
|
|
4411
|
-
uint8_t m = 1<<(il*2);
|
|
4412
|
-
const float coef = il<2 ? 1.f : 1.f/16.f;
|
|
4413
|
-
const ushort mask = il<2 ? 0x0F : 0xF0;
|
|
4414
|
-
for (int i = 0; i < 16; ++i) {
|
|
4415
|
-
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
|
|
4416
|
-
}
|
|
4417
|
-
#endif
|
|
4418
|
-
}
|
|
4871
|
+
//threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
|
4872
|
+
//{
|
|
4873
|
+
// int nval = 32;
|
|
4874
|
+
// int pos = (32*sgitg + tiisg)*nval;
|
|
4875
|
+
// for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
|
|
4876
|
+
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4877
|
+
//}
|
|
4419
4878
|
|
|
4420
|
-
|
|
4421
|
-
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
|
4422
|
-
const half d_all = xb->d;
|
|
4423
|
-
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
|
4424
|
-
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
|
4425
|
-
device const int8_t * scales = (device const int8_t *)xb->scales;
|
|
4879
|
+
const int ix = tiisg;
|
|
4426
4880
|
|
|
4427
|
-
|
|
4428
|
-
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
|
4429
|
-
qh = qh + 32*(il/8) + 16*(il&1);
|
|
4430
|
-
float sc = scales[(il%2) + 2 * ((il/2))];
|
|
4431
|
-
il = (il/2) & 3;
|
|
4432
|
-
#else
|
|
4433
|
-
ql = ql + 16 * (il&1);
|
|
4434
|
-
float sc = scales[il];
|
|
4435
|
-
#endif
|
|
4436
|
-
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
4437
|
-
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
|
4438
|
-
const float coef = il>1 ? 1.f/16.f : 1.f;
|
|
4439
|
-
const float ml = d_all * sc * 32.f;
|
|
4440
|
-
const float dl = d_all * sc * coef;
|
|
4441
|
-
for (int i = 0; i < 16; ++i) {
|
|
4442
|
-
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
|
4443
|
-
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
|
4444
|
-
reg[i/4][i%4] = dl * q - ml;
|
|
4445
|
-
}
|
|
4446
|
-
}
|
|
4881
|
+
device const float * y4 = y + 32 * ix;
|
|
4447
4882
|
|
|
4448
|
-
|
|
4449
|
-
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
|
|
4450
|
-
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
4451
|
-
const float d = xb->d;
|
|
4452
|
-
const int ib32 = il/2;
|
|
4453
|
-
il = il%2;
|
|
4454
|
-
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
4455
|
-
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
|
|
4456
|
-
device const uint16_t * q2 = xb->qs + 4*ib32;
|
|
4457
|
-
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
|
|
4458
|
-
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
|
|
4459
|
-
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
|
|
4460
|
-
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
|
|
4461
|
-
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
|
4462
|
-
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
|
|
4463
|
-
for (int i = 0; i < 8; ++i) {
|
|
4464
|
-
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
4465
|
-
}
|
|
4466
|
-
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
|
4467
|
-
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
|
|
4468
|
-
for (int i = 0; i < 8; ++i) {
|
|
4469
|
-
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
4470
|
-
}
|
|
4471
|
-
}
|
|
4883
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
|
4472
4884
|
|
|
4473
|
-
|
|
4474
|
-
|
|
4475
|
-
|
|
4476
|
-
const float d = xb->d;
|
|
4477
|
-
const int ib32 = il/2;
|
|
4478
|
-
il = il%2;
|
|
4479
|
-
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
4480
|
-
device const uint16_t * q2 = xb->qs + 4*ib32;
|
|
4481
|
-
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
|
|
4482
|
-
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
|
|
4483
|
-
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
|
|
4484
|
-
for (int i = 0; i < 8; ++i) {
|
|
4485
|
-
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
4486
|
-
}
|
|
4487
|
-
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
|
|
4488
|
-
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
|
|
4489
|
-
for (int i = 0; i < 8; ++i) {
|
|
4490
|
-
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
4491
|
-
}
|
|
4492
|
-
}
|
|
4885
|
+
for (int i = 0; i < 32; ++i) {
|
|
4886
|
+
yl[i] = y4[i];
|
|
4887
|
+
}
|
|
4493
4888
|
|
|
4494
|
-
|
|
4495
|
-
|
|
4496
|
-
|
|
4497
|
-
|
|
4498
|
-
|
|
4499
|
-
|
|
4500
|
-
|
|
4501
|
-
|
|
4502
|
-
|
|
4503
|
-
|
|
4504
|
-
|
|
4505
|
-
|
|
4506
|
-
|
|
4507
|
-
|
|
4508
|
-
|
|
4509
|
-
|
|
4510
|
-
|
|
4889
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
4890
|
+
const int ib = ib32 % (QK_K / 32);
|
|
4891
|
+
|
|
4892
|
+
device const block_iq2_s * xr = x + ibl;
|
|
4893
|
+
device const uint8_t * qs = xr->qs + 4 * ib;
|
|
4894
|
+
device const uint8_t * qh = xr->qh + ib;
|
|
4895
|
+
device const uint8_t * sc = xr->scales + ib;
|
|
4896
|
+
device const uint8_t * signs = qs + QK_K/8;
|
|
4897
|
+
device const half * dh = &xr->d;
|
|
4898
|
+
|
|
4899
|
+
for (int row = 0; row < N_DST; row++) {
|
|
4900
|
+
|
|
4901
|
+
const float db = dh[0];
|
|
4902
|
+
const float d1 = db * (0.5f + (sc[0] & 0xf));
|
|
4903
|
+
const float d2 = db * (0.5f + (sc[0] >> 4));
|
|
4904
|
+
|
|
4905
|
+
float2 sum = {0};
|
|
4906
|
+
for (int l = 0; l < 2; ++l) {
|
|
4907
|
+
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
|
4908
|
+
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
|
4909
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
|
4910
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
|
4911
|
+
for (int j = 0; j < 8; ++j) {
|
|
4912
|
+
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
|
|
4913
|
+
sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
|
|
4914
|
+
}
|
|
4915
|
+
}
|
|
4916
|
+
sumf[row] += d1 * sum[0] + d2 * sum[1];
|
|
4917
|
+
|
|
4918
|
+
dh += nb*sizeof(block_iq2_s)/2;
|
|
4919
|
+
qs += nb*sizeof(block_iq2_s);
|
|
4920
|
+
qh += nb*sizeof(block_iq2_s);
|
|
4921
|
+
sc += nb*sizeof(block_iq2_s);
|
|
4922
|
+
signs += nb*sizeof(block_iq2_s);
|
|
4923
|
+
}
|
|
4924
|
+
|
|
4925
|
+
y4 += 32 * 32;
|
|
4511
4926
|
}
|
|
4512
|
-
|
|
4513
|
-
|
|
4514
|
-
|
|
4515
|
-
|
|
4516
|
-
|
|
4517
|
-
|
|
4927
|
+
|
|
4928
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
4929
|
+
all_sum = simd_sum(sumf[row]);
|
|
4930
|
+
if (tiisg == 0) {
|
|
4931
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
|
4932
|
+
}
|
|
4518
4933
|
}
|
|
4519
4934
|
}
|
|
4520
4935
|
|
|
4521
|
-
|
|
4522
|
-
kernel void
|
|
4936
|
+
[[host_name("kernel_mul_mv_iq2_s_f32")]]
|
|
4937
|
+
kernel void kernel_mul_mv_iq2_s_f32(
|
|
4523
4938
|
device const void * src0,
|
|
4524
|
-
device const
|
|
4939
|
+
device const float * src1,
|
|
4525
4940
|
device float * dst,
|
|
4526
4941
|
constant int64_t & ne00,
|
|
4942
|
+
constant int64_t & ne01,
|
|
4943
|
+
constant int64_t & ne02,
|
|
4944
|
+
constant uint64_t & nb00,
|
|
4527
4945
|
constant uint64_t & nb01,
|
|
4528
4946
|
constant uint64_t & nb02,
|
|
4529
4947
|
constant int64_t & ne10,
|
|
4948
|
+
constant int64_t & ne11,
|
|
4949
|
+
constant int64_t & ne12,
|
|
4530
4950
|
constant uint64_t & nb10,
|
|
4531
4951
|
constant uint64_t & nb11,
|
|
4532
|
-
constant uint64_t &
|
|
4533
|
-
constant
|
|
4534
|
-
|
|
4535
|
-
uint
|
|
4536
|
-
|
|
4537
|
-
|
|
4538
|
-
|
|
4539
|
-
|
|
4540
|
-
|
|
4541
|
-
const int64_t i11 = tgpig.y;
|
|
4542
|
-
|
|
4543
|
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
4544
|
-
|
|
4545
|
-
const int64_t i02 = i11;
|
|
4952
|
+
constant uint64_t & nb12,
|
|
4953
|
+
constant int64_t & ne0,
|
|
4954
|
+
constant int64_t & ne1,
|
|
4955
|
+
constant uint & r2,
|
|
4956
|
+
constant uint & r3,
|
|
4957
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
4958
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4959
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4960
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4546
4961
|
|
|
4547
|
-
|
|
4548
|
-
float4x4 temp;
|
|
4549
|
-
dequantize_func(
|
|
4550
|
-
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
|
4551
|
-
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
|
4552
|
-
}
|
|
4962
|
+
kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
4553
4963
|
}
|
|
4554
4964
|
|
|
4555
|
-
|
|
4965
|
+
void kernel_mul_mv_iq1_s_f32_impl(
|
|
4556
4966
|
device const void * src0,
|
|
4557
|
-
device const
|
|
4967
|
+
device const float * src1,
|
|
4558
4968
|
device float * dst,
|
|
4559
4969
|
constant int64_t & ne00,
|
|
4560
|
-
constant
|
|
4561
|
-
constant
|
|
4970
|
+
constant int64_t & ne01,
|
|
4971
|
+
constant int64_t & ne02,
|
|
4562
4972
|
constant int64_t & ne10,
|
|
4563
|
-
constant
|
|
4564
|
-
constant
|
|
4565
|
-
constant
|
|
4566
|
-
constant
|
|
4567
|
-
|
|
4568
|
-
|
|
4569
|
-
|
|
4570
|
-
|
|
4571
|
-
const int64_t i11 = tgpig.y;
|
|
4572
|
-
|
|
4573
|
-
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
4973
|
+
constant int64_t & ne12,
|
|
4974
|
+
constant int64_t & ne0,
|
|
4975
|
+
constant int64_t & ne1,
|
|
4976
|
+
constant uint & r2,
|
|
4977
|
+
constant uint & r3,
|
|
4978
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4979
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4980
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4574
4981
|
|
|
4575
|
-
const
|
|
4982
|
+
const int nb = ne00/QK_K;
|
|
4983
|
+
const int r0 = tgpig.x;
|
|
4984
|
+
const int r1 = tgpig.y;
|
|
4985
|
+
const int im = tgpig.z;
|
|
4576
4986
|
|
|
4577
|
-
|
|
4578
|
-
|
|
4579
|
-
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
4580
|
-
}
|
|
4581
|
-
}
|
|
4987
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
|
4988
|
+
const int ib_row = first_row * nb;
|
|
4582
4989
|
|
|
4583
|
-
|
|
4584
|
-
|
|
4585
|
-
device const char * src1,
|
|
4586
|
-
device float * dst,
|
|
4587
|
-
constant int64_t & ne00,
|
|
4588
|
-
constant uint64_t & nb01,
|
|
4589
|
-
constant uint64_t & nb02,
|
|
4590
|
-
constant int64_t & ne10,
|
|
4591
|
-
constant uint64_t & nb10,
|
|
4592
|
-
constant uint64_t & nb11,
|
|
4593
|
-
constant uint64_t & nb1,
|
|
4594
|
-
constant uint64_t & nb2,
|
|
4595
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4596
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
|
4597
|
-
uint3 tptg [[threads_per_threadgroup]]) {
|
|
4598
|
-
const int64_t i10 = tgpig.x;
|
|
4599
|
-
const int64_t i11 = tgpig.y;
|
|
4990
|
+
const uint i12 = im%ne12;
|
|
4991
|
+
const uint i13 = im/ne12;
|
|
4600
4992
|
|
|
4601
|
-
const
|
|
4993
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
4994
|
+
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
|
4995
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4602
4996
|
|
|
4603
|
-
|
|
4997
|
+
float yl[16];
|
|
4998
|
+
float sumf[N_DST]={0.f}, all_sum;
|
|
4604
4999
|
|
|
4605
|
-
|
|
4606
|
-
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
4607
|
-
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
4608
|
-
}
|
|
4609
|
-
}
|
|
5000
|
+
const int nb32 = nb * (QK_K / 32);
|
|
4610
5001
|
|
|
4611
|
-
|
|
4612
|
-
|
|
4613
|
-
device const char * src1,
|
|
4614
|
-
device int32_t * dst,
|
|
4615
|
-
constant int64_t & ne00,
|
|
4616
|
-
constant uint64_t & nb01,
|
|
4617
|
-
constant uint64_t & nb02,
|
|
4618
|
-
constant int64_t & ne10,
|
|
4619
|
-
constant uint64_t & nb10,
|
|
4620
|
-
constant uint64_t & nb11,
|
|
4621
|
-
constant uint64_t & nb1,
|
|
4622
|
-
constant uint64_t & nb2,
|
|
4623
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4624
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
|
4625
|
-
uint3 tptg [[threads_per_threadgroup]]) {
|
|
4626
|
-
const int64_t i10 = tgpig.x;
|
|
4627
|
-
const int64_t i11 = tgpig.y;
|
|
5002
|
+
const int ix = tiisg/2;
|
|
5003
|
+
const int il = tiisg%2;
|
|
4628
5004
|
|
|
4629
|
-
const
|
|
5005
|
+
device const float * y4 = y + 32 * ix + 16 * il;
|
|
4630
5006
|
|
|
4631
|
-
|
|
5007
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
|
4632
5008
|
|
|
4633
|
-
|
|
4634
|
-
|
|
4635
|
-
|
|
4636
|
-
}
|
|
4637
|
-
}
|
|
5009
|
+
for (int i = 0; i < 16; ++i) {
|
|
5010
|
+
yl[i] = y4[i];
|
|
5011
|
+
}
|
|
4638
5012
|
|
|
5013
|
+
const int ibl = ib32 / (QK_K / 32);
|
|
5014
|
+
const int ib = ib32 % (QK_K / 32);
|
|
4639
5015
|
|
|
4640
|
-
|
|
4641
|
-
|
|
4642
|
-
|
|
4643
|
-
|
|
4644
|
-
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
|
4645
|
-
#define THREAD_PER_BLOCK 128
|
|
4646
|
-
#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
|
|
4647
|
-
#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
|
|
4648
|
-
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
|
|
4649
|
-
#define SG_MAT_ROW 8
|
|
5016
|
+
device const block_iq1_s * xr = x + ibl;
|
|
5017
|
+
device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
|
|
5018
|
+
device const uint8_t * sc = xr->scales + 2 * ib + il;
|
|
5019
|
+
device const half * dh = &xr->d;
|
|
4650
5020
|
|
|
4651
|
-
|
|
4652
|
-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
4653
|
-
void kernel_mul_mm_impl(device const uchar * src0,
|
|
4654
|
-
device const uchar * src1,
|
|
4655
|
-
device float * dst,
|
|
4656
|
-
constant int64_t & ne00,
|
|
4657
|
-
constant int64_t & ne02,
|
|
4658
|
-
constant uint64_t & nb01,
|
|
4659
|
-
constant uint64_t & nb02,
|
|
4660
|
-
constant int64_t & ne12,
|
|
4661
|
-
constant uint64_t & nb10,
|
|
4662
|
-
constant uint64_t & nb11,
|
|
4663
|
-
constant uint64_t & nb12,
|
|
4664
|
-
constant int64_t & ne0,
|
|
4665
|
-
constant int64_t & ne1,
|
|
4666
|
-
constant uint & r2,
|
|
4667
|
-
constant uint & r3,
|
|
4668
|
-
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
4669
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4670
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
|
4671
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5021
|
+
for (int row = 0; row < N_DST; row++) {
|
|
4672
5022
|
|
|
4673
|
-
|
|
4674
|
-
|
|
5023
|
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
|
5024
|
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
|
4675
5025
|
|
|
4676
|
-
|
|
4677
|
-
|
|
4678
|
-
|
|
5026
|
+
float2 sum = {0};
|
|
5027
|
+
for (int j = 0; j < 8; ++j) {
|
|
5028
|
+
sum[0] += yl[j+ 0] * grid1[j];
|
|
5029
|
+
sum[1] += yl[j+ 8] * grid2[j];
|
|
5030
|
+
}
|
|
5031
|
+
sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
|
|
4679
5032
|
|
|
4680
|
-
|
|
4681
|
-
|
|
4682
|
-
|
|
5033
|
+
dh += nb*sizeof(block_iq1_s)/2;
|
|
5034
|
+
qs += nb*sizeof(block_iq1_s);
|
|
5035
|
+
sc += nb*sizeof(block_iq1_s);
|
|
5036
|
+
}
|
|
4683
5037
|
|
|
4684
|
-
|
|
4685
|
-
|
|
4686
|
-
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
|
5038
|
+
y4 += 16 * 32;
|
|
5039
|
+
}
|
|
4687
5040
|
|
|
4688
|
-
|
|
4689
|
-
|
|
4690
|
-
|
|
4691
|
-
|
|
4692
|
-
|
|
5041
|
+
for (int row = 0; row < N_DST; ++row) {
|
|
5042
|
+
all_sum = simd_sum(sumf[row]);
|
|
5043
|
+
if (tiisg == 0) {
|
|
5044
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
5045
|
+
}
|
|
4693
5046
|
}
|
|
5047
|
+
}
|
|
4694
5048
|
|
|
4695
|
-
|
|
5049
|
+
constexpr constant static float kvalues_iq4nl_f[16] = {
|
|
5050
|
+
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
|
5051
|
+
};
|
|
5052
|
+
|
|
5053
|
+
void kernel_mul_mv_iq4_nl_f32_impl(
|
|
5054
|
+
device const void * src0,
|
|
5055
|
+
device const float * src1,
|
|
5056
|
+
device float * dst,
|
|
5057
|
+
constant int64_t & ne00,
|
|
5058
|
+
constant int64_t & ne01,
|
|
5059
|
+
constant int64_t & ne02,
|
|
5060
|
+
constant int64_t & ne10,
|
|
5061
|
+
constant int64_t & ne12,
|
|
5062
|
+
constant int64_t & ne0,
|
|
5063
|
+
constant int64_t & ne1,
|
|
5064
|
+
constant uint & r2,
|
|
5065
|
+
constant uint & r3,
|
|
5066
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
|
5067
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5068
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5069
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5070
|
+
|
|
5071
|
+
const int nb = ne00/QK4_NL;
|
|
5072
|
+
const int r0 = tgpig.x;
|
|
5073
|
+
const int r1 = tgpig.y;
|
|
5074
|
+
const int im = tgpig.z;
|
|
5075
|
+
const int first_row = (r0 * 2 + sgitg) * 2;
|
|
5076
|
+
const int ib_row = first_row * nb;
|
|
4696
5077
|
|
|
4697
5078
|
const uint i12 = im%ne12;
|
|
4698
5079
|
const uint i13 = im/ne12;
|
|
4699
5080
|
|
|
4700
|
-
uint
|
|
4701
|
-
|
|
5081
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
5082
|
+
device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
|
|
5083
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4702
5084
|
|
|
4703
|
-
|
|
4704
|
-
|
|
4705
|
-
+ nb12 * im
|
|
4706
|
-
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
|
4707
|
-
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
|
5085
|
+
const int ix = tiisg/2; // 0...15
|
|
5086
|
+
const int it = tiisg%2; // 0 or 1
|
|
4708
5087
|
|
|
4709
|
-
|
|
4710
|
-
|
|
4711
|
-
half4x4 temp_a;
|
|
4712
|
-
dequantize_func(x, il, temp_a);
|
|
4713
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
5088
|
+
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
|
5089
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4714
5090
|
|
|
4715
|
-
|
|
4716
|
-
|
|
4717
|
-
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
|
4718
|
-
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
|
4719
|
-
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
|
4720
|
-
}
|
|
5091
|
+
float4 yl[4];
|
|
5092
|
+
float sumf[2]={0.f}, all_sum;
|
|
4721
5093
|
|
|
4722
|
-
|
|
5094
|
+
device const float * yb = y + ix * QK4_NL + it * 8;
|
|
4723
5095
|
|
|
4724
|
-
|
|
4725
|
-
|
|
4726
|
-
y += BLOCK_SIZE_K;
|
|
5096
|
+
uint32_t aux32[2];
|
|
5097
|
+
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
|
4727
5098
|
|
|
4728
|
-
|
|
5099
|
+
float4 qf1, qf2;
|
|
4729
5100
|
|
|
4730
|
-
|
|
4731
|
-
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
|
4732
|
-
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
|
5101
|
+
for (int ib = ix; ib < nb; ib += 16) {
|
|
4733
5102
|
|
|
4734
|
-
|
|
4735
|
-
|
|
4736
|
-
#pragma unroll(4)
|
|
4737
|
-
for (int i = 0; i < 4; i++) {
|
|
4738
|
-
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
|
4739
|
-
}
|
|
4740
|
-
simdgroup_barrier(mem_flags::mem_none);
|
|
4741
|
-
#pragma unroll(2)
|
|
4742
|
-
for (int i = 0; i < 2; i++) {
|
|
4743
|
-
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
|
4744
|
-
}
|
|
5103
|
+
device const float4 * y4 = (device const float4 *)yb;
|
|
5104
|
+
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
|
4745
5105
|
|
|
4746
|
-
|
|
4747
|
-
|
|
5106
|
+
for (int row = 0; row < 2; ++row) {
|
|
5107
|
+
|
|
5108
|
+
device const block_iq4_nl & xb = x[row*nb + ib];
|
|
5109
|
+
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
|
5110
|
+
|
|
5111
|
+
float4 acc1 = {0.f}, acc2 = {0.f};
|
|
5112
|
+
|
|
5113
|
+
aux32[0] = q4[0] | (q4[1] << 16);
|
|
5114
|
+
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
|
5115
|
+
aux32[0] &= 0x0f0f0f0f;
|
|
5116
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
|
5117
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
|
5118
|
+
acc1 += yl[0] * qf1;
|
|
5119
|
+
acc2 += yl[1] * qf2;
|
|
5120
|
+
|
|
5121
|
+
aux32[0] = q4[2] | (q4[3] << 16);
|
|
5122
|
+
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
|
5123
|
+
aux32[0] &= 0x0f0f0f0f;
|
|
5124
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
|
5125
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
|
5126
|
+
acc1 += yl[2] * qf1;
|
|
5127
|
+
acc2 += yl[3] * qf2;
|
|
5128
|
+
|
|
5129
|
+
acc1 += acc2;
|
|
5130
|
+
|
|
5131
|
+
sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
|
4748
5132
|
|
|
4749
|
-
#pragma unroll(8)
|
|
4750
|
-
for (int i = 0; i < 8; i++){
|
|
4751
|
-
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
|
4752
|
-
}
|
|
4753
5133
|
}
|
|
5134
|
+
|
|
5135
|
+
yb += 16 * QK4_NL;
|
|
4754
5136
|
}
|
|
4755
5137
|
|
|
4756
|
-
|
|
4757
|
-
|
|
4758
|
-
|
|
4759
|
-
|
|
4760
|
-
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
|
4761
|
-
}
|
|
4762
|
-
} else {
|
|
4763
|
-
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
|
4764
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4765
|
-
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
|
4766
|
-
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
|
4767
|
-
for (int i = 0; i < 8; i++) {
|
|
4768
|
-
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
|
4769
|
-
}
|
|
4770
|
-
|
|
4771
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4772
|
-
|
|
4773
|
-
device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
|
|
4774
|
-
if (sgitg == 0) {
|
|
4775
|
-
for (int i = 0; i < n_rows; i++) {
|
|
4776
|
-
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
|
4777
|
-
*(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
|
4778
|
-
}
|
|
4779
|
-
}
|
|
5138
|
+
for (int row = 0; row < 2; ++row) {
|
|
5139
|
+
all_sum = simd_sum(sumf[row]);
|
|
5140
|
+
if (tiisg == 0) {
|
|
5141
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
4780
5142
|
}
|
|
4781
5143
|
}
|
|
4782
5144
|
}
|
|
4783
5145
|
|
|
4784
|
-
|
|
4785
|
-
|
|
4786
|
-
void
|
|
4787
|
-
device const
|
|
4788
|
-
device
|
|
4789
|
-
|
|
4790
|
-
|
|
4791
|
-
constant
|
|
4792
|
-
constant
|
|
4793
|
-
constant
|
|
4794
|
-
constant
|
|
4795
|
-
constant
|
|
4796
|
-
constant
|
|
4797
|
-
constant
|
|
4798
|
-
|
|
4799
|
-
|
|
4800
|
-
|
|
4801
|
-
|
|
4802
|
-
constant uint & r3,
|
|
4803
|
-
threadgroup uchar * shared_memory,
|
|
4804
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4805
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
|
4806
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5146
|
+
#if QK_K != 64
|
|
5147
|
+
void kernel_mul_mv_iq4_xs_f32_impl(
|
|
5148
|
+
device const void * src0,
|
|
5149
|
+
device const float * src1,
|
|
5150
|
+
device float * dst,
|
|
5151
|
+
constant int64_t & ne00,
|
|
5152
|
+
constant int64_t & ne01,
|
|
5153
|
+
constant int64_t & ne02,
|
|
5154
|
+
constant int64_t & ne10,
|
|
5155
|
+
constant int64_t & ne12,
|
|
5156
|
+
constant int64_t & ne0,
|
|
5157
|
+
constant int64_t & ne1,
|
|
5158
|
+
constant uint & r2,
|
|
5159
|
+
constant uint & r3,
|
|
5160
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
|
5161
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5162
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5163
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4807
5164
|
|
|
4808
|
-
|
|
4809
|
-
|
|
5165
|
+
const int nb = ne00/QK_K;
|
|
5166
|
+
const int r0 = tgpig.x;
|
|
5167
|
+
const int r1 = tgpig.y;
|
|
5168
|
+
const int im = tgpig.z;
|
|
5169
|
+
const int first_row = (r0 * 2 + sgitg) * 2;
|
|
5170
|
+
const int ib_row = first_row * nb;
|
|
4810
5171
|
|
|
4811
|
-
const uint
|
|
4812
|
-
const uint
|
|
4813
|
-
const uint im = tgpig.z;
|
|
5172
|
+
const uint i12 = im%ne12;
|
|
5173
|
+
const uint i13 = im/ne12;
|
|
4814
5174
|
|
|
4815
|
-
|
|
5175
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
|
5176
|
+
device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
|
|
5177
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
|
4816
5178
|
|
|
4817
|
-
|
|
4818
|
-
|
|
4819
|
-
|
|
5179
|
+
const int ix = tiisg/16; // 0 or 1
|
|
5180
|
+
const int it = tiisg%16; // 0...15
|
|
5181
|
+
const int ib = it/2;
|
|
5182
|
+
const int il = it%2;
|
|
4820
5183
|
|
|
4821
|
-
|
|
4822
|
-
|
|
4823
|
-
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
|
5184
|
+
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
|
5185
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4824
5186
|
|
|
4825
|
-
|
|
4826
|
-
|
|
4827
|
-
simdgroup_float8x8 c_res[8];
|
|
4828
|
-
for (int i = 0; i < 8; i++){
|
|
4829
|
-
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
|
4830
|
-
}
|
|
5187
|
+
float4 yl[4];
|
|
5188
|
+
float sumf[2]={0.f}, all_sum;
|
|
4831
5189
|
|
|
4832
|
-
|
|
5190
|
+
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
|
4833
5191
|
|
|
4834
|
-
|
|
4835
|
-
const
|
|
5192
|
+
uint32_t aux32[2];
|
|
5193
|
+
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
|
4836
5194
|
|
|
4837
|
-
|
|
4838
|
-
ushort offset1 = il/nl;
|
|
5195
|
+
float4 qf1, qf2;
|
|
4839
5196
|
|
|
4840
|
-
|
|
4841
|
-
device const float * y = (device const float *)(src1
|
|
4842
|
-
+ nb12 * im
|
|
4843
|
-
+ nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
|
|
4844
|
-
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
|
5197
|
+
for (int ibl = ix; ibl < nb; ibl += 2) {
|
|
4845
5198
|
|
|
4846
|
-
|
|
4847
|
-
|
|
4848
|
-
half4x4 temp_a;
|
|
4849
|
-
dequantize_func(x, il, temp_a);
|
|
4850
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
5199
|
+
device const float4 * y4 = (device const float4 *)yb;
|
|
5200
|
+
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
|
4851
5201
|
|
|
4852
|
-
for (int
|
|
4853
|
-
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
|
4854
|
-
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
|
4855
|
-
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
|
4856
|
-
}
|
|
5202
|
+
for (int row = 0; row < 2; ++row) {
|
|
4857
5203
|
|
|
4858
|
-
|
|
5204
|
+
device const block_iq4_xs & xb = x[row*nb + ibl];
|
|
5205
|
+
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
|
4859
5206
|
|
|
4860
|
-
|
|
4861
|
-
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
|
4862
|
-
y += BLOCK_SIZE_K;
|
|
5207
|
+
float4 acc1 = {0.f}, acc2 = {0.f};
|
|
4863
5208
|
|
|
4864
|
-
|
|
5209
|
+
aux32[0] = q4[0] & 0x0f0f0f0f;
|
|
5210
|
+
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
|
5211
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
|
5212
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
|
5213
|
+
acc1 += yl[0] * qf1;
|
|
5214
|
+
acc2 += yl[1] * qf2;
|
|
4865
5215
|
|
|
4866
|
-
|
|
4867
|
-
|
|
4868
|
-
|
|
5216
|
+
aux32[0] = q4[1] & 0x0f0f0f0f;
|
|
5217
|
+
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
|
5218
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
|
5219
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
|
5220
|
+
acc1 += yl[2] * qf1;
|
|
5221
|
+
acc2 += yl[3] * qf2;
|
|
4869
5222
|
|
|
4870
|
-
|
|
4871
|
-
for (int i = 0; i < 4; i++) {
|
|
4872
|
-
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
|
4873
|
-
}
|
|
4874
|
-
simdgroup_barrier(mem_flags::mem_none);
|
|
4875
|
-
for (int i = 0; i < 2; i++) {
|
|
4876
|
-
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
|
4877
|
-
}
|
|
5223
|
+
acc1 += acc2;
|
|
4878
5224
|
|
|
4879
|
-
|
|
4880
|
-
|
|
5225
|
+
const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
|
|
5226
|
+
sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
|
4881
5227
|
|
|
4882
|
-
for (int i = 0; i < 8; i++){
|
|
4883
|
-
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
|
4884
|
-
}
|
|
4885
5228
|
}
|
|
5229
|
+
|
|
5230
|
+
yb += 2 * QK_K;
|
|
5231
|
+
}
|
|
5232
|
+
|
|
5233
|
+
for (int row = 0; row < 2; ++row) {
|
|
5234
|
+
all_sum = simd_sum(sumf[row]);
|
|
5235
|
+
if (tiisg == 0) {
|
|
5236
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
5237
|
+
}
|
|
5238
|
+
}
|
|
5239
|
+
}
|
|
5240
|
+
#endif
|
|
5241
|
+
|
|
5242
|
+
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
|
5243
|
+
kernel void kernel_mul_mv_iq1_s_f32(
|
|
5244
|
+
device const void * src0,
|
|
5245
|
+
device const float * src1,
|
|
5246
|
+
device float * dst,
|
|
5247
|
+
constant int64_t & ne00,
|
|
5248
|
+
constant int64_t & ne01,
|
|
5249
|
+
constant int64_t & ne02,
|
|
5250
|
+
constant uint64_t & nb00,
|
|
5251
|
+
constant uint64_t & nb01,
|
|
5252
|
+
constant uint64_t & nb02,
|
|
5253
|
+
constant int64_t & ne10,
|
|
5254
|
+
constant int64_t & ne11,
|
|
5255
|
+
constant int64_t & ne12,
|
|
5256
|
+
constant uint64_t & nb10,
|
|
5257
|
+
constant uint64_t & nb11,
|
|
5258
|
+
constant uint64_t & nb12,
|
|
5259
|
+
constant int64_t & ne0,
|
|
5260
|
+
constant int64_t & ne1,
|
|
5261
|
+
constant uint & r2,
|
|
5262
|
+
constant uint & r3,
|
|
5263
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5264
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5265
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5266
|
+
|
|
5267
|
+
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
|
5268
|
+
}
|
|
5269
|
+
|
|
5270
|
+
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
|
5271
|
+
kernel void kernel_mul_mv_iq4_nl_f32(
|
|
5272
|
+
device const void * src0,
|
|
5273
|
+
device const float * src1,
|
|
5274
|
+
device float * dst,
|
|
5275
|
+
constant int64_t & ne00,
|
|
5276
|
+
constant int64_t & ne01,
|
|
5277
|
+
constant int64_t & ne02,
|
|
5278
|
+
constant uint64_t & nb00,
|
|
5279
|
+
constant uint64_t & nb01,
|
|
5280
|
+
constant uint64_t & nb02,
|
|
5281
|
+
constant int64_t & ne10,
|
|
5282
|
+
constant int64_t & ne11,
|
|
5283
|
+
constant int64_t & ne12,
|
|
5284
|
+
constant uint64_t & nb10,
|
|
5285
|
+
constant uint64_t & nb11,
|
|
5286
|
+
constant uint64_t & nb12,
|
|
5287
|
+
constant int64_t & ne0,
|
|
5288
|
+
constant int64_t & ne1,
|
|
5289
|
+
constant uint & r2,
|
|
5290
|
+
constant uint & r3,
|
|
5291
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
|
5292
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5293
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5294
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5295
|
+
|
|
5296
|
+
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
5297
|
+
}
|
|
5298
|
+
|
|
5299
|
+
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
|
5300
|
+
kernel void kernel_mul_mv_iq4_xs_f32(
|
|
5301
|
+
device const void * src0,
|
|
5302
|
+
device const float * src1,
|
|
5303
|
+
device float * dst,
|
|
5304
|
+
constant int64_t & ne00,
|
|
5305
|
+
constant int64_t & ne01,
|
|
5306
|
+
constant int64_t & ne02,
|
|
5307
|
+
constant uint64_t & nb00,
|
|
5308
|
+
constant uint64_t & nb01,
|
|
5309
|
+
constant uint64_t & nb02,
|
|
5310
|
+
constant int64_t & ne10,
|
|
5311
|
+
constant int64_t & ne11,
|
|
5312
|
+
constant int64_t & ne12,
|
|
5313
|
+
constant uint64_t & nb10,
|
|
5314
|
+
constant uint64_t & nb11,
|
|
5315
|
+
constant uint64_t & nb12,
|
|
5316
|
+
constant int64_t & ne0,
|
|
5317
|
+
constant int64_t & ne1,
|
|
5318
|
+
constant uint & r2,
|
|
5319
|
+
constant uint & r3,
|
|
5320
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
|
5321
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5322
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
5323
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
5324
|
+
|
|
5325
|
+
#if QK_K == 64
|
|
5326
|
+
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
5327
|
+
#else
|
|
5328
|
+
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
|
5329
|
+
#endif
|
|
5330
|
+
}
|
|
5331
|
+
|
|
5332
|
+
//============================= templates and their specializations =============================
|
|
5333
|
+
|
|
5334
|
+
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
5335
|
+
template <typename type4x4>
|
|
5336
|
+
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
5337
|
+
float4x4 temp = *(((device float4x4 *)src));
|
|
5338
|
+
for (int i = 0; i < 16; i++){
|
|
5339
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
|
5340
|
+
}
|
|
5341
|
+
}
|
|
5342
|
+
|
|
5343
|
+
template <typename type4x4>
|
|
5344
|
+
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
|
5345
|
+
half4x4 temp = *(((device half4x4 *)src));
|
|
5346
|
+
for (int i = 0; i < 16; i++){
|
|
5347
|
+
reg[i/4][i%4] = temp[i/4][i%4];
|
|
4886
5348
|
}
|
|
5349
|
+
}
|
|
5350
|
+
|
|
5351
|
+
template <typename type4x4>
|
|
5352
|
+
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
|
5353
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
5354
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
|
5355
|
+
const float d2 = d1 / 256.f;
|
|
5356
|
+
const float md = -8.h * xb->d;
|
|
5357
|
+
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
5358
|
+
const ushort mask1 = mask0 << 8;
|
|
5359
|
+
|
|
5360
|
+
for (int i=0;i<8;i++) {
|
|
5361
|
+
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
|
5362
|
+
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
|
5363
|
+
}
|
|
5364
|
+
}
|
|
5365
|
+
|
|
5366
|
+
template <typename type4x4>
|
|
5367
|
+
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
|
5368
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
|
5369
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
|
5370
|
+
const float d2 = d1 / 256.f;
|
|
5371
|
+
const float m = xb->m;
|
|
5372
|
+
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
|
5373
|
+
const ushort mask1 = mask0 << 8;
|
|
5374
|
+
|
|
5375
|
+
for (int i=0;i<8;i++) {
|
|
5376
|
+
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
|
|
5377
|
+
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
|
|
5378
|
+
}
|
|
5379
|
+
}
|
|
5380
|
+
|
|
5381
|
+
template <typename type4x4>
|
|
5382
|
+
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
|
|
5383
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
|
5384
|
+
const float d = xb->d;
|
|
5385
|
+
const float md = -16.h * xb->d;
|
|
5386
|
+
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
5387
|
+
|
|
5388
|
+
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
|
5389
|
+
|
|
5390
|
+
const int x_mv = il ? 4 : 0;
|
|
5391
|
+
|
|
5392
|
+
const int gh_mv = il ? 12 : 0;
|
|
5393
|
+
const int gh_bk = il ? 0 : 4;
|
|
5394
|
+
|
|
5395
|
+
for (int i = 0; i < 8; i++) {
|
|
5396
|
+
// extract the 5-th bits for x0 and x1
|
|
5397
|
+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
5398
|
+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
|
5399
|
+
|
|
5400
|
+
// combine the 4-bits from qs with the 5th bit
|
|
5401
|
+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
|
5402
|
+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
|
5403
|
+
|
|
5404
|
+
reg[i/2][2*(i%2)+0] = d * x0 + md;
|
|
5405
|
+
reg[i/2][2*(i%2)+1] = d * x1 + md;
|
|
5406
|
+
}
|
|
5407
|
+
}
|
|
5408
|
+
|
|
5409
|
+
template <typename type4x4>
|
|
5410
|
+
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
|
|
5411
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
|
5412
|
+
const float d = xb->d;
|
|
5413
|
+
const float m = xb->m;
|
|
5414
|
+
const ushort mask = il ? 0x00F0 : 0x000F;
|
|
5415
|
+
|
|
5416
|
+
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
|
5417
|
+
|
|
5418
|
+
const int x_mv = il ? 4 : 0;
|
|
5419
|
+
|
|
5420
|
+
const int gh_mv = il ? 12 : 0;
|
|
5421
|
+
const int gh_bk = il ? 0 : 4;
|
|
5422
|
+
|
|
5423
|
+
for (int i = 0; i < 8; i++) {
|
|
5424
|
+
// extract the 5-th bits for x0 and x1
|
|
5425
|
+
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
5426
|
+
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
|
5427
|
+
|
|
5428
|
+
// combine the 4-bits from qs with the 5th bit
|
|
5429
|
+
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
|
5430
|
+
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
|
5431
|
+
|
|
5432
|
+
reg[i/2][2*(i%2)+0] = d * x0 + m;
|
|
5433
|
+
reg[i/2][2*(i%2)+1] = d * x1 + m;
|
|
5434
|
+
}
|
|
5435
|
+
}
|
|
5436
|
+
|
|
5437
|
+
template <typename type4x4>
|
|
5438
|
+
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
|
5439
|
+
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
|
5440
|
+
const half d = xb->d;
|
|
5441
|
+
|
|
5442
|
+
for (int i = 0; i < 16; i++) {
|
|
5443
|
+
reg[i/4][i%4] = (qs[i + 16*il] * d);
|
|
5444
|
+
}
|
|
5445
|
+
}
|
|
5446
|
+
|
|
5447
|
+
template <typename type4x4>
|
|
5448
|
+
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
|
5449
|
+
const float d = xb->d;
|
|
5450
|
+
const float min = xb->dmin;
|
|
5451
|
+
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
|
5452
|
+
float dl, ml;
|
|
5453
|
+
uint8_t sc = xb->scales[il];
|
|
5454
|
+
|
|
5455
|
+
#if QK_K == 256
|
|
5456
|
+
q = q + 32*(il/8) + 16*(il&1);
|
|
5457
|
+
il = (il/2)%4;
|
|
5458
|
+
#endif
|
|
5459
|
+
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
5460
|
+
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
5461
|
+
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
|
5462
|
+
for (int i = 0; i < 16; ++i) {
|
|
5463
|
+
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
|
5464
|
+
}
|
|
5465
|
+
}
|
|
5466
|
+
|
|
5467
|
+
template <typename type4x4>
|
|
5468
|
+
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
|
5469
|
+
const half d_all = xb->d;
|
|
5470
|
+
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
|
5471
|
+
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
|
5472
|
+
device const int8_t * scales = (device const int8_t *)xb->scales;
|
|
5473
|
+
|
|
5474
|
+
#if QK_K == 256
|
|
5475
|
+
q = q + 32 * (il/8) + 16 * (il&1);
|
|
5476
|
+
h = h + 16 * (il&1);
|
|
5477
|
+
uint8_t m = 1 << (il/2);
|
|
5478
|
+
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
|
|
5479
|
+
((il/4)>0 ? 12 : 3);
|
|
5480
|
+
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
|
5481
|
+
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
|
5482
|
+
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
|
5483
|
+
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
|
5484
|
+
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
|
5485
|
+
const float ml = 4.f * dl;
|
|
5486
|
+
|
|
5487
|
+
il = (il/2) & 3;
|
|
5488
|
+
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
5489
|
+
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
5490
|
+
dl *= coef;
|
|
5491
|
+
|
|
5492
|
+
for (int i = 0; i < 16; ++i) {
|
|
5493
|
+
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
|
5494
|
+
}
|
|
5495
|
+
#else
|
|
5496
|
+
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
|
5497
|
+
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
|
5498
|
+
float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
|
|
5499
|
+
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
|
5500
|
+
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
5501
|
+
uint8_t m = 1<<(il*2);
|
|
5502
|
+
for (int i = 0; i < 16; ++i) {
|
|
5503
|
+
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
|
|
5504
|
+
}
|
|
5505
|
+
#endif
|
|
5506
|
+
}
|
|
5507
|
+
|
|
5508
|
+
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
|
5509
|
+
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
|
|
5510
|
+
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
|
|
5511
|
+
}
|
|
5512
|
+
|
|
5513
|
+
template <typename type4x4>
|
|
5514
|
+
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
|
5515
|
+
device const uchar * q = xb->qs;
|
|
5516
|
+
|
|
5517
|
+
#if QK_K == 256
|
|
5518
|
+
short is = (il/4) * 2;
|
|
5519
|
+
q = q + (il/4) * 32 + 16 * (il&1);
|
|
5520
|
+
il = il & 3;
|
|
5521
|
+
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
5522
|
+
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
|
5523
|
+
const float min = xb->dmin;
|
|
5524
|
+
const float dl = d * sc[0];
|
|
5525
|
+
const float ml = min * sc[1];
|
|
5526
|
+
#else
|
|
5527
|
+
(void) get_scale_min_k4_just2;
|
|
5528
|
+
|
|
5529
|
+
q = q + 16 * (il&1);
|
|
5530
|
+
device const uint8_t * s = xb->scales;
|
|
5531
|
+
device const half2 * dh = (device const half2 *)xb->d;
|
|
5532
|
+
const float2 d = (float2)dh[0];
|
|
5533
|
+
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
|
5534
|
+
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
|
5535
|
+
#endif
|
|
5536
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
|
5537
|
+
for (int i = 0; i < 16; ++i) {
|
|
5538
|
+
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
|
5539
|
+
}
|
|
5540
|
+
}
|
|
5541
|
+
|
|
5542
|
+
template <typename type4x4>
|
|
5543
|
+
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
|
|
5544
|
+
device const uint8_t * q = xb->qs;
|
|
5545
|
+
device const uint8_t * qh = xb->qh;
|
|
5546
|
+
|
|
5547
|
+
#if QK_K == 256
|
|
5548
|
+
short is = (il/4) * 2;
|
|
5549
|
+
q = q + 32 * (il/4) + 16 * (il&1);
|
|
5550
|
+
qh = qh + 16 * (il&1);
|
|
5551
|
+
uint8_t ul = 1 << (il/2);
|
|
5552
|
+
il = il & 3;
|
|
5553
|
+
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
|
5554
|
+
const float d = il < 2 ? xb->d : xb->d / 16.f;
|
|
5555
|
+
const float min = xb->dmin;
|
|
5556
|
+
const float dl = d * sc[0];
|
|
5557
|
+
const float ml = min * sc[1];
|
|
5558
|
+
|
|
5559
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
|
5560
|
+
const float qh_val = il<2 ? 16.f : 256.f;
|
|
5561
|
+
for (int i = 0; i < 16; ++i) {
|
|
5562
|
+
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
|
5563
|
+
}
|
|
5564
|
+
#else
|
|
5565
|
+
q = q + 16 * (il&1);
|
|
5566
|
+
device const int8_t * s = xb->scales;
|
|
5567
|
+
const float dl = xb->d * s[il];
|
|
5568
|
+
uint8_t m = 1<<(il*2);
|
|
5569
|
+
const float coef = il<2 ? 1.f : 1.f/16.f;
|
|
5570
|
+
const ushort mask = il<2 ? 0x0F : 0xF0;
|
|
5571
|
+
for (int i = 0; i < 16; ++i) {
|
|
5572
|
+
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
|
|
5573
|
+
}
|
|
5574
|
+
#endif
|
|
5575
|
+
}
|
|
5576
|
+
|
|
5577
|
+
template <typename type4x4>
|
|
5578
|
+
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
|
5579
|
+
const half d_all = xb->d;
|
|
5580
|
+
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
|
5581
|
+
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
|
5582
|
+
device const int8_t * scales = (device const int8_t *)xb->scales;
|
|
5583
|
+
|
|
5584
|
+
#if QK_K == 256
|
|
5585
|
+
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
|
5586
|
+
qh = qh + 32*(il/8) + 16*(il&1);
|
|
5587
|
+
float sc = scales[(il%2) + 2 * ((il/2))];
|
|
5588
|
+
il = (il/2) & 3;
|
|
5589
|
+
#else
|
|
5590
|
+
ql = ql + 16 * (il&1);
|
|
5591
|
+
float sc = scales[il];
|
|
5592
|
+
#endif
|
|
5593
|
+
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
|
5594
|
+
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
|
5595
|
+
const float coef = il>1 ? 1.f/16.f : 1.f;
|
|
5596
|
+
const float ml = d_all * sc * 32.f;
|
|
5597
|
+
const float dl = d_all * sc * coef;
|
|
5598
|
+
for (int i = 0; i < 16; ++i) {
|
|
5599
|
+
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
|
5600
|
+
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
|
5601
|
+
reg[i/4][i%4] = dl * q - ml;
|
|
5602
|
+
}
|
|
5603
|
+
}
|
|
5604
|
+
|
|
5605
|
+
template <typename type4x4>
|
|
5606
|
+
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
|
|
5607
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
5608
|
+
const float d = xb->d;
|
|
5609
|
+
const int ib32 = il/2;
|
|
5610
|
+
il = il%2;
|
|
5611
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
5612
|
+
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
|
|
5613
|
+
device const uint16_t * q2 = xb->qs + 4*ib32;
|
|
5614
|
+
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
|
|
5615
|
+
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
|
|
5616
|
+
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
|
|
5617
|
+
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
|
|
5618
|
+
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
|
5619
|
+
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
|
|
5620
|
+
for (int i = 0; i < 8; ++i) {
|
|
5621
|
+
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
5622
|
+
}
|
|
5623
|
+
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
|
5624
|
+
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
|
|
5625
|
+
for (int i = 0; i < 8; ++i) {
|
|
5626
|
+
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
5627
|
+
}
|
|
5628
|
+
}
|
|
5629
|
+
|
|
5630
|
+
template <typename type4x4>
|
|
5631
|
+
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
|
|
5632
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
5633
|
+
const float d = xb->d;
|
|
5634
|
+
const int ib32 = il/2;
|
|
5635
|
+
il = il%2;
|
|
5636
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
5637
|
+
device const uint16_t * q2 = xb->qs + 4*ib32;
|
|
5638
|
+
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
|
|
5639
|
+
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
|
|
5640
|
+
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
|
|
5641
|
+
for (int i = 0; i < 8; ++i) {
|
|
5642
|
+
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
5643
|
+
}
|
|
5644
|
+
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
|
|
5645
|
+
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
|
|
5646
|
+
for (int i = 0; i < 8; ++i) {
|
|
5647
|
+
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
|
5648
|
+
}
|
|
5649
|
+
}
|
|
5650
|
+
|
|
5651
|
+
template <typename type4x4>
|
|
5652
|
+
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
|
|
5653
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
5654
|
+
const float d = xb->d;
|
|
5655
|
+
const int ib32 = il/2;
|
|
5656
|
+
il = il%2;
|
|
5657
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
5658
|
+
device const uint8_t * q3 = xb->qs + 8*ib32;
|
|
5659
|
+
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
|
|
5660
|
+
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
|
5661
|
+
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
|
|
5662
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
|
|
5663
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
|
|
5664
|
+
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
|
5665
|
+
for (int i = 0; i < 4; ++i) {
|
|
5666
|
+
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
|
5667
|
+
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
|
5668
|
+
}
|
|
5669
|
+
grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
|
|
5670
|
+
grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
|
|
5671
|
+
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
|
|
5672
|
+
for (int i = 0; i < 4; ++i) {
|
|
5673
|
+
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
|
5674
|
+
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
|
5675
|
+
}
|
|
5676
|
+
}
|
|
5677
|
+
|
|
5678
|
+
template <typename type4x4>
|
|
5679
|
+
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
|
|
5680
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
5681
|
+
const float d = xb->d;
|
|
5682
|
+
const int ib32 = il/2;
|
|
5683
|
+
il = il%2;
|
|
5684
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
5685
|
+
device const uint8_t * qs = xb->qs + 8*ib32;
|
|
5686
|
+
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
|
5687
|
+
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
|
5688
|
+
const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
|
|
5689
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
|
5690
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
|
5691
|
+
for (int i = 0; i < 4; ++i) {
|
|
5692
|
+
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
|
5693
|
+
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
|
5694
|
+
}
|
|
5695
|
+
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
|
5696
|
+
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
|
5697
|
+
for (int i = 0; i < 4; ++i) {
|
|
5698
|
+
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
|
5699
|
+
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
|
5700
|
+
}
|
|
5701
|
+
}
|
|
5702
|
+
|
|
5703
|
+
template <typename type4x4>
|
|
5704
|
+
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
|
|
5705
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
5706
|
+
const float d = xb->d;
|
|
5707
|
+
const int ib32 = il/2;
|
|
5708
|
+
il = il%2;
|
|
5709
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
5710
|
+
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
|
5711
|
+
device const uint8_t * signs = qs + QK_K/8;
|
|
5712
|
+
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
|
5713
|
+
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
|
|
5714
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
|
|
5715
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
|
|
5716
|
+
for (int i = 0; i < 8; ++i) {
|
|
5717
|
+
reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
|
|
5718
|
+
reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
|
|
5719
|
+
}
|
|
5720
|
+
}
|
|
5721
|
+
|
|
5722
|
+
template <typename type4x4>
|
|
5723
|
+
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
|
5724
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
5725
|
+
const float d = xb->d;
|
|
5726
|
+
device const uint8_t * qs = xb->qs + 2*il;
|
|
5727
|
+
device const uint8_t * sc = xb->scales + il;
|
|
5728
|
+
const float dl1 = d * (2*(sc[0] & 7) + 1);
|
|
5729
|
+
const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
|
|
5730
|
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
|
5731
|
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
|
5732
|
+
for (int i = 0; i < 8; ++i) {
|
|
5733
|
+
reg[i/4+0][i%4] = dl1 * grid1[i];
|
|
5734
|
+
reg[i/4+2][i%4] = dl2 * grid2[i];
|
|
5735
|
+
}
|
|
5736
|
+
}
|
|
5737
|
+
|
|
5738
|
+
template <typename type4x4>
|
|
5739
|
+
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
|
5740
|
+
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
|
5741
|
+
const float d = xb->d;
|
|
5742
|
+
uint32_t aux32;
|
|
5743
|
+
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
|
5744
|
+
for (int i = 0; i < 4; ++i) {
|
|
5745
|
+
aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
|
|
5746
|
+
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
|
|
5747
|
+
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
|
|
5748
|
+
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
|
5749
|
+
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
|
5750
|
+
}
|
|
5751
|
+
}
|
|
5752
|
+
|
|
5753
|
+
template <typename type4x4>
|
|
5754
|
+
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
|
5755
|
+
#if QK_K == 64
|
|
5756
|
+
dequantize_iq4_nl(xb, il, reg);
|
|
5757
|
+
#else
|
|
5758
|
+
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
|
5759
|
+
const int ib32 = il/2;
|
|
5760
|
+
il = il%2;
|
|
5761
|
+
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
|
5762
|
+
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
|
|
5763
|
+
const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
|
|
5764
|
+
const float d = (float)xb->d * (ls - 32);
|
|
5765
|
+
uint32_t aux32;
|
|
5766
|
+
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
|
5767
|
+
for (int i = 0; i < 4; ++i) {
|
|
5768
|
+
aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
|
|
5769
|
+
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
|
|
5770
|
+
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
|
|
5771
|
+
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
|
5772
|
+
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
|
5773
|
+
}
|
|
5774
|
+
#endif
|
|
5775
|
+
}
|
|
5776
|
+
|
|
5777
|
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
5778
|
+
kernel void kernel_get_rows(
|
|
5779
|
+
device const void * src0,
|
|
5780
|
+
device const char * src1,
|
|
5781
|
+
device float * dst,
|
|
5782
|
+
constant int64_t & ne00,
|
|
5783
|
+
constant uint64_t & nb01,
|
|
5784
|
+
constant uint64_t & nb02,
|
|
5785
|
+
constant int64_t & ne10,
|
|
5786
|
+
constant uint64_t & nb10,
|
|
5787
|
+
constant uint64_t & nb11,
|
|
5788
|
+
constant uint64_t & nb1,
|
|
5789
|
+
constant uint64_t & nb2,
|
|
5790
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5791
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5792
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
5793
|
+
//const int64_t i = tgpig;
|
|
5794
|
+
//const int64_t r = ((device int32_t *) src1)[i];
|
|
5795
|
+
|
|
5796
|
+
const int64_t i10 = tgpig.x;
|
|
5797
|
+
const int64_t i11 = tgpig.y;
|
|
5798
|
+
|
|
5799
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5800
|
+
|
|
5801
|
+
const int64_t i02 = i11;
|
|
5802
|
+
|
|
5803
|
+
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
|
5804
|
+
float4x4 temp;
|
|
5805
|
+
dequantize_func(
|
|
5806
|
+
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
|
5807
|
+
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
|
5808
|
+
}
|
|
5809
|
+
}
|
|
5810
|
+
|
|
5811
|
+
kernel void kernel_get_rows_f32(
|
|
5812
|
+
device const void * src0,
|
|
5813
|
+
device const char * src1,
|
|
5814
|
+
device float * dst,
|
|
5815
|
+
constant int64_t & ne00,
|
|
5816
|
+
constant uint64_t & nb01,
|
|
5817
|
+
constant uint64_t & nb02,
|
|
5818
|
+
constant int64_t & ne10,
|
|
5819
|
+
constant uint64_t & nb10,
|
|
5820
|
+
constant uint64_t & nb11,
|
|
5821
|
+
constant uint64_t & nb1,
|
|
5822
|
+
constant uint64_t & nb2,
|
|
5823
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5824
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5825
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
5826
|
+
const int64_t i10 = tgpig.x;
|
|
5827
|
+
const int64_t i11 = tgpig.y;
|
|
5828
|
+
|
|
5829
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5830
|
+
|
|
5831
|
+
const int64_t i02 = i11;
|
|
5832
|
+
|
|
5833
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
5834
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
5835
|
+
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
5836
|
+
}
|
|
5837
|
+
}
|
|
5838
|
+
|
|
5839
|
+
kernel void kernel_get_rows_f16(
|
|
5840
|
+
device const void * src0,
|
|
5841
|
+
device const char * src1,
|
|
5842
|
+
device float * dst,
|
|
5843
|
+
constant int64_t & ne00,
|
|
5844
|
+
constant uint64_t & nb01,
|
|
5845
|
+
constant uint64_t & nb02,
|
|
5846
|
+
constant int64_t & ne10,
|
|
5847
|
+
constant uint64_t & nb10,
|
|
5848
|
+
constant uint64_t & nb11,
|
|
5849
|
+
constant uint64_t & nb1,
|
|
5850
|
+
constant uint64_t & nb2,
|
|
5851
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5852
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5853
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
5854
|
+
const int64_t i10 = tgpig.x;
|
|
5855
|
+
const int64_t i11 = tgpig.y;
|
|
5856
|
+
|
|
5857
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5858
|
+
|
|
5859
|
+
const int64_t i02 = i11;
|
|
5860
|
+
|
|
5861
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
5862
|
+
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
5863
|
+
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
5864
|
+
}
|
|
5865
|
+
}
|
|
5866
|
+
|
|
5867
|
+
kernel void kernel_get_rows_i32(
|
|
5868
|
+
device const void * src0,
|
|
5869
|
+
device const char * src1,
|
|
5870
|
+
device int32_t * dst,
|
|
5871
|
+
constant int64_t & ne00,
|
|
5872
|
+
constant uint64_t & nb01,
|
|
5873
|
+
constant uint64_t & nb02,
|
|
5874
|
+
constant int64_t & ne10,
|
|
5875
|
+
constant uint64_t & nb10,
|
|
5876
|
+
constant uint64_t & nb11,
|
|
5877
|
+
constant uint64_t & nb1,
|
|
5878
|
+
constant uint64_t & nb2,
|
|
5879
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5880
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
5881
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
5882
|
+
const int64_t i10 = tgpig.x;
|
|
5883
|
+
const int64_t i11 = tgpig.y;
|
|
5884
|
+
|
|
5885
|
+
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
|
5886
|
+
|
|
5887
|
+
const int64_t i02 = i11;
|
|
5888
|
+
|
|
5889
|
+
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
|
5890
|
+
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
|
5891
|
+
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
|
5892
|
+
}
|
|
5893
|
+
}
|
|
5894
|
+
|
|
5895
|
+
|
|
5896
|
+
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
|
5897
|
+
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
5898
|
+
#define BLOCK_SIZE_K 32
|
|
5899
|
+
#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
|
|
5900
|
+
#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
|
|
5901
|
+
#define THREAD_PER_BLOCK 128
|
|
5902
|
+
#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
|
|
5903
|
+
#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
|
|
5904
|
+
#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
|
|
5905
|
+
#define SG_MAT_ROW 8
|
|
5906
|
+
|
|
5907
|
+
// each block_q contains 16*nl weights
//
// Tiled matrix-matrix multiplication dst = src0 x src1.
//
// One threadgroup produces a tile of dst that is at most BLOCK_SIZE_M elements
// along ne0 and BLOCK_SIZE_N elements along ne1; the tile position is
// (r0 = tgpig.y, r1 = tgpig.x) and tgpig.z selects the batch.
//
//   src0            - weights, decoded 16 values at a time through dequantize_func
//   src1            - float activations, addressed with the byte strides nb10/nb11/nb12
//   dst             - float output, written with ne0 as the leading dimension
//   r2, r3          - divisors applied to the src1 batch indices to find the src0 batch
//   shared_memory   - threadgroup scratch: halves for the src0 tile at offset 0,
//                     floats for the src1 tile at byte offset 4096; the partial-tile
//                     path reuses it as a BLOCK_SIZE_M x BLOCK_SIZE_N float buffer
//                     (so it presumably needs to be at least 8192 bytes - confirm
//                     against the host-side dispatch)
//
// NOTE(review): the sgitg&1 / sgitg>>1 indexing looks like it expects 4 simdgroups
// per threadgroup (THREAD_PER_BLOCK = 128) - confirm against the host-side dispatch.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_impl(device const  uchar * src0,
                        device const  uchar * src1,
                        device        float * dst,
                        constant    int64_t & ne00,
                        constant    int64_t & ne02,
                        constant   uint64_t & nb01,
                        constant   uint64_t & nb02,
                        constant    int64_t & ne12,
                        constant   uint64_t & nb10,
                        constant   uint64_t & nb11,
                        constant   uint64_t & nb12,
                        constant    int64_t & ne0,
                        constant    int64_t & ne1,
                        constant       uint & r2,
                        constant       uint & r3,
                        threadgroup   uchar * shared_memory [[threadgroup(0)]],
                        uint3                 tgpig[[threadgroup_position_in_grid]],
                        uint                  tiitg[[thread_index_in_threadgroup]],
                        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {

    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;

    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;

    // a thread shouldn't load data outside of the matrix
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    short il = (tiitg % THREAD_PER_ROW);

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // byte offset of the src0 batch; kept in 64 bits because nb02 and nb02*ne02
    // are 64-bit byte strides and a 32-bit value would wrap for offsets >= 4 GiB
    uint64_t offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
    ushort   offset1 = il/nl;

    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    device const float   * y = (device const float   *)(src1
        + nb12 * im
        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);
        // wait until every thread is done reading the previous contents of sa/sb
        threadgroup_barrier(mem_flags::mem_threadgroup);

        #pragma unroll(16)
        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8)
            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8)
            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
        }

        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);

        // advance to the next 16-value group of src0 and the next K slice of src1
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2+nl-1)/nl : x;
        y += BLOCK_SIZE_K;

        // make the freshly written tiles visible to the whole threadgroup
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // load matrices from threadgroup memory and conduct outer products
        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

        #pragma unroll(4)
        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
            #pragma unroll(4)
            for (int i = 0; i < 4; i++) {
                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
            }
            simdgroup_barrier(mem_flags::mem_none);
            #pragma unroll(2)
            for (int i = 0; i < 2; i++) {
                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
            }

            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;

            #pragma unroll(8)
            for (int i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
            }
        }
    }

    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
        // full tile: every simdgroup stores its 8 accumulators straight to dst
        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1))
                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
        }
    } else {
        // block is smaller than 64x32, we should avoid writing data outside of the matrix
        threadgroup_barrier(mem_flags::mem_threadgroup);
        threadgroup float * temp_str = ((threadgroup float *)shared_memory)
                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // only the threads of simdgroup 0 copy the valid n_rows x n_cols part to dst
        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
        if (sgitg == 0) {
            for (int i = 0; i < n_rows; i++) {
                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
        }
    }
}
|
|
6039
|
+
|
|
6040
|
+
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
//
// src1ids[k] is the src1/dst row used for the k-th column handled by this call and
// ne1 (passed by value) is the number of such columns. Threadgroups whose first
// column is at or beyond ne1 return immediately. Results always go through the
// threadgroup staging buffer, because the destination rows are scattered.
//
// NOTE(review): the sgitg&1 / sgitg>>1 indexing looks like it expects 4 simdgroups
// per threadgroup (THREAD_PER_BLOCK = 128) - confirm against the host-side dispatch.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_id_impl(
        device const  uchar * src0,
        device const  uchar * src1,
        thread        short * src1ids,
        device        float * dst,
        constant    int64_t & ne00,
        constant    int64_t & ne02,
        constant   uint64_t & nb01,
        constant   uint64_t & nb02,
        constant    int64_t & ne12,
        constant   uint64_t & nb10,
        constant   uint64_t & nb11,
        constant   uint64_t & nb12,
        constant    int64_t & ne0,
                    int64_t   ne1,
        constant       uint & r2,
        constant       uint & r3,
        threadgroup   uchar * shared_memory,
        uint3                 tgpig[[threadgroup_position_in_grid]],
        uint                  tiitg[[thread_index_in_threadgroup]],
        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {

    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;

    // nothing to do for column blocks past the selected rows; r1 is the same for
    // the whole threadgroup, so all of its threads leave together
    if (r1 * BLOCK_SIZE_N >= ne1) return;

    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;

    // a thread shouldn't load data outside of the matrix
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    short il = (tiitg % THREAD_PER_ROW);

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // byte offset of the src0 batch; kept in 64 bits because nb02 and nb02*ne02
    // are 64-bit byte strides and a 32-bit value would wrap for offsets >= 4 GiB
    uint64_t offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
    ushort   offset1 = il/nl;

    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    device const float   * y = (device const float   *)(src1
        + nb12 * im
        + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);
        // wait until every thread is done reading the previous contents of sa/sb
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8)
            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8)
            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
        }

        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);

        // advance to the next 16-value group of src0 and the next K slice of src1
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2+nl-1)/nl : x;
        y += BLOCK_SIZE_K;

        // make the freshly written tiles visible to the whole threadgroup
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // load matrices from threadgroup memory and conduct outer products
        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
            for (int i = 0; i < 4; i++) {
                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
            }
            simdgroup_barrier(mem_flags::mem_none);
            for (int i = 0; i < 2; i++) {
                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
            }

            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;

            for (int i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
            }
        }
    }

    {
        // stage the accumulators in threadgroup memory, then scatter to dst
        threadgroup_barrier(mem_flags::mem_threadgroup);
        threadgroup float * temp_str = ((threadgroup float *)shared_memory)
                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // only the threads of simdgroup 0 copy; column j lands in dst row src1ids[j + r1*BLOCK_SIZE_N]
        device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
        if (sgitg == 0) {
            for (int i = 0; i < n_rows; i++) {
                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                    *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
        }
    }
}
|
|
6164
|
+
|
|
6165
|
+
// Kernel entry point for the tiled matrix-matrix multiplication.
// It adds nothing of its own: every argument is forwarded unchanged to
// kernel_mul_mm_impl, instantiated with the same template parameters.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const  uchar * src0,
                          device const  uchar * src1,
                          device        float * dst,
                          constant    int64_t & ne00,
                          constant    int64_t & ne02,
                          constant   uint64_t & nb01,
                          constant   uint64_t & nb02,
                          constant    int64_t & ne12,
                          constant   uint64_t & nb10,
                          constant   uint64_t & nb11,
                          constant   uint64_t & nb12,
                          constant    int64_t & ne0,
                          constant    int64_t & ne1,
                          constant       uint & r2,
                          constant       uint & r3,
                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
                          uint3                 tgpig[[threadgroup_position_in_grid]],
                          uint                  tiitg[[thread_index_in_threadgroup]],
                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
        src0, src1, dst,                // buffers
        ne00, ne02, nb01, nb02,         // src0 sizes and strides
        ne12, nb10, nb11, nb12,         // src1 sizes and strides
        ne0,  ne1,                      // dst sizes
        r2,   r3,                       // batch divisors
        shared_memory,
        tgpig, tiitg, sgitg);
}
|
|
6206
|
+
|
|
6207
|
+
// Kernel entry point for the indirect matrix-matrix multiplication.
//
// tgpig.z encodes both the expert and the batch: expert = z / (ne12*ne13), and
// the remainder is written back into tgpig.z before forwarding. The kernel scans
// the ne1 rows of ids, collects the rows whose entry [idx] equals the expert,
// and multiplies only those rows with the expert's matrix (one of src00..src07)
// through kernel_mul_mm_id_impl.
//
// NOTE(review): the row list is a fixed 512-entry array filled without a bounds
// check, and the expert index selects one of 8 matrices - this assumes the
// caller guarantees ne1 <= 512 and expert ids in [0, 8); verify against the host code.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
        device const   uchar * ids,
        device const   uchar * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne02,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const   uchar * src00,
        device const   uchar * src01,
        device const   uchar * src02,
        device const   uchar * src03,
        device const   uchar * src04,
        device const   uchar * src05,
        device const   uchar * src06,
        device const   uchar * src07,
        threadgroup    uchar * shared_memory [[threadgroup(0)]],
        uint3                  tgpig[[threadgroup_position_in_grid]],
        uint                   tiitg[[thread_index_in_threadgroup]],
        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
    device const uchar * experts[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // split the z grid coordinate into (expert, batch)
    const int64_t n_batch = ne12*ne13;
    const int32_t expert  = tgpig.z/n_batch;

    tgpig.z = tgpig.z%n_batch;

    // rows of src1 that were routed to this expert
    short   rows[512];
    int64_t n_rows = 0;

    for (int64_t i1 = 0; i1 < ne1; i1++) {
        device const int32_t * row_ids = (device const int32_t *) (ids + i1*nbi1);
        if (row_ids[idx] != expert) {
            continue;
        }
        rows[n_rows++] = i1;
    }

    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
        experts[expert],
        src1,
        rows,
        dst,
        ne00, ne02, nb01, nb02,
        ne12, nb10, nb11, nb12,
        ne0,
        n_rows,
        r2,   r3,
        shared_memory,
        tgpig, tiitg, sgitg);
}
|
|
6279
|
+
|
|
6280
|
+
#if QK_K == 256
|
|
6281
|
+
#define QK_NL 16
|
|
6282
|
+
#else
|
|
6283
|
+
#define QK_NL 4
|
|
6284
|
+
#endif
|
|
6285
|
+
|
|
6286
|
+
//
|
|
6287
|
+
// get rows
|
|
6288
|
+
//
|
|
6289
|
+
|
|
6290
|
+
// Function type used to spell the explicit kernel_get_rows<> instantiations.
// The three trailing unnamed parameters correspond to the kernel's thread
// position arguments (uint3, uint, uint3).
typedef void get_rows_t(
        device const    void * src0,
        device const    char * src1,
        device         float * dst,
        constant     int64_t & ne00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb1,
        constant    uint64_t & nb2,
        uint3, uint, uint3);
|
|
6303
|
+
|
|
6304
|
+
// Explicit instantiations of kernel_get_rows, one per supported block type.
// host_name is the name under which each kernel is exported; the second
// template argument is nl (2 for the 32-weight block types, QK_NL otherwise).
//template [[host_name("kernel_get_rows_f32")]]   kernel get_rows_t kernel_get_rows<float4x4,      1,     dequantize_f32>;
//template [[host_name("kernel_get_rows_f16")]]   kernel get_rows_t kernel_get_rows<half4x4,       1,     dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]]    kernel get_rows_t kernel_get_rows<block_q4_0,    2,     dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_t kernel_get_rows<block_q4_1,    2,     dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_t kernel_get_rows<block_q5_0,    2,     dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_t kernel_get_rows<block_q5_1,    2,     dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_t kernel_get_rows<block_q8_0,    2,     dequantize_q8_0>;
template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_t kernel_get_rows<block_q2_K,    QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_t kernel_get_rows<block_q3_K,    QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_t kernel_get_rows<block_q4_K,    QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]]    kernel get_rows_t kernel_get_rows<block_q5_K,    QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]]    kernel get_rows_t kernel_get_rows<block_q6_K,    QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_t kernel_get_rows<block_iq3_s,   QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_t kernel_get_rows<block_iq2_s,   QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2,     dequantize_iq4_nl>;
// iq4_xs uses nl = 2 when QK_K == 64 and QK_NL otherwise
#if QK_K == 64
template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_t kernel_get_rows<block_iq4_xs,  2,     dequantize_iq4_xs>;
#else
template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_t kernel_get_rows<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
#endif
|
|
6328
|
+
|
|
6329
|
+
//
|
|
6330
|
+
// matrix-matrix multiplication
|
|
6331
|
+
//
|
|
6332
|
+
|
|
6333
|
+
// Function type used to spell the explicit kernel_mul_mm<> instantiations.
// The unnamed trailing parameters correspond to the kernel's threadgroup
// scratch pointer and its thread position arguments (uint3, uint, uint).
typedef void mat_mm_t(
        device const   uchar * src0,
        device const   uchar * src1,
        device         float * dst,
        constant     int64_t & ne00,
        constant     int64_t & ne02,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne12,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant        uint & r2,
        constant        uint & r3,
        threadgroup    uchar *,
        uint3, uint, uint);
|
|
6351
|
+
|
|
6352
|
+
// Explicit instantiations of kernel_mul_mm, one per supported src0 type.
// host_name is the name under which each kernel is exported; the second
// template argument is nl (1 for f32/f16, 2 for the 32-weight block types,
// QK_NL otherwise).
template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<float4x4,      1,     dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half4x4,       1,     dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_0,    2,     dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_1,    2,     dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_0,    2,     dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_1,    2,     dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q8_0,    2,     dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q2_K,    QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q3_K,    QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_K,    QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_K,    QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q6_K,    QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq3_s,   QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq2_s,   QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_nl>;
// iq4_xs uses nl = 2 when QK_K == 64 and QK_NL otherwise; the block type is
// spelled block_iq4_xs in both branches, matching the kernel_get_rows and
// kernel_mul_mm_id instantiations for the same type
#if QK_K == 64
template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_xs,  2,     dequantize_iq4_xs>;
#else
template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
#endif
|
|
6376
|
+
|
|
6377
|
+
//
|
|
6378
|
+
// indirect matrix-matrix multiplication
|
|
6379
|
+
//
|
|
6380
|
+
|
|
6381
|
+
// Function type used to spell the explicit kernel_mul_mm_id<> instantiations.
// src00..src07 are the eight candidate src0 buffers; the unnamed trailing
// parameters correspond to the kernel's threadgroup scratch pointer and its
// thread position arguments (uint3, uint, uint).
typedef void mat_mm_id_t(
        device const   uchar * ids,
        device const   uchar * src1,
        device         float * dst,
        constant    uint64_t & nbi1,
        constant     int64_t & ne00,
        constant     int64_t & ne02,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne12,
        constant     int64_t & ne13,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb12,
        constant     int64_t & ne0,
        constant     int64_t & ne1,
        constant    uint64_t & nb1,
        constant        uint & r2,
        constant        uint & r3,
        constant         int & idx,
        device const   uchar * src00,
        device const   uchar * src01,
        device const   uchar * src02,
        device const   uchar * src03,
        device const   uchar * src04,
        device const   uchar * src05,
        device const   uchar * src06,
        device const   uchar * src07,
        threadgroup    uchar *,
        uint3, uint, uint);
|
|
6411
|
+
|
|
6412
|
+
// Explicit instantiations of kernel_mul_mm_id, one per supported src0 type.
// host_name is the name under which each kernel is exported; the second
// template argument is nl (1 for f32/f16, 2 for the 32-weight block types,
// QK_NL otherwise).
template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<float4x4,      1,     dequantize_f32>;
template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<half4x4,       1,     dequantize_f16>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0,    2,     dequantize_q4_0>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1,    2,     dequantize_q4_1>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0,    2,     dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1,    2,     dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0,    2,     dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K,    QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K,    QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K,    QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K,    QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K,    QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s,   QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s,   QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s,   QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl,  2,     dequantize_iq4_nl>;
// iq4_xs uses nl = 2 when QK_K == 64 and QK_NL otherwise
#if QK_K == 64
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs,  2,     dequantize_iq4_xs>;
#else
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
#endif
|
|
6436
|
+
|
|
6437
|
+
//
|
|
6438
|
+
// matrix-vector multiplication
|
|
6439
|
+
//
|
|
6440
|
+
|
|
6441
|
+
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
|
6442
|
+
kernel void kernel_mul_mv_id_f32_f32(
|
|
6443
|
+
device const char * ids,
|
|
6444
|
+
device const char * src1,
|
|
6445
|
+
device float * dst,
|
|
6446
|
+
constant uint64_t & nbi1,
|
|
6447
|
+
constant int64_t & ne00,
|
|
6448
|
+
constant int64_t & ne01,
|
|
6449
|
+
constant int64_t & ne02,
|
|
6450
|
+
constant uint64_t & nb00,
|
|
6451
|
+
constant uint64_t & nb01,
|
|
6452
|
+
constant uint64_t & nb02,
|
|
6453
|
+
constant int64_t & ne10,
|
|
6454
|
+
constant int64_t & ne11,
|
|
6455
|
+
constant int64_t & ne12,
|
|
6456
|
+
constant int64_t & ne13,
|
|
6457
|
+
constant uint64_t & nb10,
|
|
6458
|
+
constant uint64_t & nb11,
|
|
6459
|
+
constant uint64_t & nb12,
|
|
6460
|
+
constant int64_t & ne0,
|
|
6461
|
+
constant int64_t & ne1,
|
|
6462
|
+
constant uint64_t & nb1,
|
|
6463
|
+
constant uint & r2,
|
|
6464
|
+
constant uint & r3,
|
|
6465
|
+
constant int & idx,
|
|
6466
|
+
device const char * src00,
|
|
6467
|
+
device const char * src01,
|
|
6468
|
+
device const char * src02,
|
|
6469
|
+
device const char * src03,
|
|
6470
|
+
device const char * src04,
|
|
6471
|
+
device const char * src05,
|
|
6472
|
+
device const char * src06,
|
|
6473
|
+
device const char * src07,
|
|
6474
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6475
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6476
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
6477
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6478
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
4887
6479
|
|
|
4888
|
-
|
|
4889
|
-
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4890
|
-
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
|
4891
|
-
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
|
4892
|
-
for (int i = 0; i < 8; i++) {
|
|
4893
|
-
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
|
4894
|
-
}
|
|
6480
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
4895
6481
|
|
|
4896
|
-
|
|
6482
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
4897
6483
|
|
|
4898
|
-
|
|
4899
|
-
if (sgitg == 0) {
|
|
4900
|
-
for (int i = 0; i < n_rows; i++) {
|
|
4901
|
-
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
|
4902
|
-
*(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
|
|
4903
|
-
}
|
|
4904
|
-
}
|
|
4905
|
-
}
|
|
4906
|
-
}
|
|
4907
|
-
}
|
|
6484
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
4908
6485
|
|
|
4909
|
-
|
|
4910
|
-
|
|
4911
|
-
|
|
4912
|
-
|
|
4913
|
-
constant int64_t & ne00,
|
|
4914
|
-
constant int64_t & ne02,
|
|
4915
|
-
constant uint64_t & nb01,
|
|
4916
|
-
constant uint64_t & nb02,
|
|
4917
|
-
constant int64_t & ne12,
|
|
4918
|
-
constant uint64_t & nb10,
|
|
4919
|
-
constant uint64_t & nb11,
|
|
4920
|
-
constant uint64_t & nb12,
|
|
4921
|
-
constant int64_t & ne0,
|
|
4922
|
-
constant int64_t & ne1,
|
|
4923
|
-
constant uint & r2,
|
|
4924
|
-
constant uint & r3,
|
|
4925
|
-
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
4926
|
-
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4927
|
-
uint tiitg[[thread_index_in_threadgroup]],
|
|
4928
|
-
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4929
|
-
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
|
4930
|
-
src0,
|
|
4931
|
-
src1,
|
|
4932
|
-
dst,
|
|
6486
|
+
kernel_mul_mv_f32_f32_impl(
|
|
6487
|
+
src0[id],
|
|
6488
|
+
src1 + bid*nb11,
|
|
6489
|
+
dst + bid*ne0,
|
|
4933
6490
|
ne00,
|
|
6491
|
+
ne01,
|
|
4934
6492
|
ne02,
|
|
6493
|
+
nb00,
|
|
4935
6494
|
nb01,
|
|
4936
6495
|
nb02,
|
|
6496
|
+
ne10,
|
|
6497
|
+
ne11,
|
|
4937
6498
|
ne12,
|
|
4938
6499
|
nb10,
|
|
4939
6500
|
nb11,
|
|
@@ -4942,22 +6503,24 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
|
4942
6503
|
ne1,
|
|
4943
6504
|
r2,
|
|
4944
6505
|
r3,
|
|
4945
|
-
shared_memory,
|
|
4946
6506
|
tgpig,
|
|
4947
|
-
|
|
4948
|
-
sgitg);
|
|
6507
|
+
tiisg);
|
|
4949
6508
|
}
|
|
4950
6509
|
|
|
4951
|
-
|
|
4952
|
-
kernel void
|
|
4953
|
-
device const
|
|
4954
|
-
device const
|
|
6510
|
+
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
|
6511
|
+
kernel void kernel_mul_mv_id_f16_f32(
|
|
6512
|
+
device const char * ids,
|
|
6513
|
+
device const char * src1,
|
|
4955
6514
|
device float * dst,
|
|
4956
6515
|
constant uint64_t & nbi1,
|
|
4957
6516
|
constant int64_t & ne00,
|
|
6517
|
+
constant int64_t & ne01,
|
|
4958
6518
|
constant int64_t & ne02,
|
|
6519
|
+
constant uint64_t & nb00,
|
|
4959
6520
|
constant uint64_t & nb01,
|
|
4960
6521
|
constant uint64_t & nb02,
|
|
6522
|
+
constant int64_t & ne10,
|
|
6523
|
+
constant int64_t & ne11,
|
|
4961
6524
|
constant int64_t & ne12,
|
|
4962
6525
|
constant int64_t & ne13,
|
|
4963
6526
|
constant uint64_t & nb10,
|
|
@@ -4969,150 +6532,190 @@ kernel void kernel_mul_mm_id(
|
|
|
4969
6532
|
constant uint & r2,
|
|
4970
6533
|
constant uint & r3,
|
|
4971
6534
|
constant int & idx,
|
|
4972
|
-
device const
|
|
4973
|
-
device const
|
|
4974
|
-
device const
|
|
4975
|
-
device const
|
|
4976
|
-
device const
|
|
4977
|
-
device const
|
|
4978
|
-
device const
|
|
4979
|
-
device const
|
|
4980
|
-
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
6535
|
+
device const char * src00,
|
|
6536
|
+
device const char * src01,
|
|
6537
|
+
device const char * src02,
|
|
6538
|
+
device const char * src03,
|
|
6539
|
+
device const char * src04,
|
|
6540
|
+
device const char * src05,
|
|
6541
|
+
device const char * src06,
|
|
6542
|
+
device const char * src07,
|
|
4981
6543
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4982
6544
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6545
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
4983
6546
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
4984
|
-
device const
|
|
6547
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
4985
6548
|
|
|
4986
|
-
|
|
4987
|
-
const int32_t id = tgpig.z/(ne12*ne13);
|
|
6549
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
4988
6550
|
|
|
4989
6551
|
tgpig.z = tgpig.z%(ne12*ne13);
|
|
4990
6552
|
|
|
4991
|
-
|
|
4992
|
-
int64_t _ne1 = 0;
|
|
4993
|
-
short src1ids[512];
|
|
4994
|
-
|
|
4995
|
-
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
4996
|
-
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
|
4997
|
-
src1ids[_ne1++] = i1;
|
|
4998
|
-
}
|
|
4999
|
-
}
|
|
6553
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5000
6554
|
|
|
5001
|
-
|
|
5002
|
-
|
|
5003
|
-
src1,
|
|
5004
|
-
|
|
5005
|
-
dst,
|
|
6555
|
+
kernel_mul_mv_f16_f32_impl(
|
|
6556
|
+
src0[id],
|
|
6557
|
+
src1 + bid*nb11,
|
|
6558
|
+
dst + bid*ne0,
|
|
5006
6559
|
ne00,
|
|
6560
|
+
ne01,
|
|
5007
6561
|
ne02,
|
|
6562
|
+
nb00,
|
|
5008
6563
|
nb01,
|
|
5009
6564
|
nb02,
|
|
6565
|
+
ne10,
|
|
6566
|
+
ne11,
|
|
5010
6567
|
ne12,
|
|
5011
6568
|
nb10,
|
|
5012
6569
|
nb11,
|
|
5013
6570
|
nb12,
|
|
5014
6571
|
ne0,
|
|
5015
|
-
|
|
6572
|
+
ne1,
|
|
5016
6573
|
r2,
|
|
5017
6574
|
r3,
|
|
5018
|
-
shared_memory,
|
|
5019
6575
|
tgpig,
|
|
5020
|
-
|
|
5021
|
-
sgitg);
|
|
6576
|
+
tiisg);
|
|
5022
6577
|
}
|
|
5023
6578
|
|
|
5024
|
-
|
|
5025
|
-
|
|
5026
|
-
|
|
5027
|
-
|
|
5028
|
-
|
|
6579
|
+
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
|
6580
|
+
kernel void kernel_mul_mv_id_q8_0_f32(
|
|
6581
|
+
device const char * ids,
|
|
6582
|
+
device const char * src1,
|
|
6583
|
+
device float * dst,
|
|
6584
|
+
constant uint64_t & nbi1,
|
|
6585
|
+
constant int64_t & ne00,
|
|
6586
|
+
constant int64_t & ne01,
|
|
6587
|
+
constant int64_t & ne02,
|
|
6588
|
+
constant uint64_t & nb00,
|
|
6589
|
+
constant uint64_t & nb01,
|
|
6590
|
+
constant uint64_t & nb02,
|
|
6591
|
+
constant int64_t & ne10,
|
|
6592
|
+
constant int64_t & ne11,
|
|
6593
|
+
constant int64_t & ne12,
|
|
6594
|
+
constant int64_t & ne13,
|
|
6595
|
+
constant uint64_t & nb10,
|
|
6596
|
+
constant uint64_t & nb11,
|
|
6597
|
+
constant uint64_t & nb12,
|
|
6598
|
+
constant int64_t & ne0,
|
|
6599
|
+
constant int64_t & ne1,
|
|
6600
|
+
constant uint64_t & nb1,
|
|
6601
|
+
constant uint & r2,
|
|
6602
|
+
constant uint & r3,
|
|
6603
|
+
constant int & idx,
|
|
6604
|
+
device const char * src00,
|
|
6605
|
+
device const char * src01,
|
|
6606
|
+
device const char * src02,
|
|
6607
|
+
device const char * src03,
|
|
6608
|
+
device const char * src04,
|
|
6609
|
+
device const char * src05,
|
|
6610
|
+
device const char * src06,
|
|
6611
|
+
device const char * src07,
|
|
6612
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6613
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6614
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
6615
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6616
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5029
6617
|
|
|
5030
|
-
|
|
5031
|
-
// get rows
|
|
5032
|
-
//
|
|
6618
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5033
6619
|
|
|
5034
|
-
|
|
5035
|
-
device const void * src0,
|
|
5036
|
-
device const char * src1,
|
|
5037
|
-
device float * dst,
|
|
5038
|
-
constant int64_t & ne00,
|
|
5039
|
-
constant uint64_t & nb01,
|
|
5040
|
-
constant uint64_t & nb02,
|
|
5041
|
-
constant int64_t & ne10,
|
|
5042
|
-
constant uint64_t & nb10,
|
|
5043
|
-
constant uint64_t & nb11,
|
|
5044
|
-
constant uint64_t & nb1,
|
|
5045
|
-
constant uint64_t & nb2,
|
|
5046
|
-
uint3, uint, uint3);
|
|
6620
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5047
6621
|
|
|
5048
|
-
|
|
5049
|
-
|
|
5050
|
-
|
|
5051
|
-
|
|
5052
|
-
|
|
5053
|
-
|
|
5054
|
-
|
|
5055
|
-
|
|
5056
|
-
|
|
5057
|
-
|
|
5058
|
-
|
|
5059
|
-
|
|
5060
|
-
|
|
5061
|
-
|
|
5062
|
-
|
|
6622
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
6623
|
+
|
|
6624
|
+
kernel_mul_mv_q8_0_f32_impl(
|
|
6625
|
+
src0[id],
|
|
6626
|
+
(device const float *) (src1 + bid*nb11),
|
|
6627
|
+
dst + bid*ne0,
|
|
6628
|
+
ne00,
|
|
6629
|
+
ne01,
|
|
6630
|
+
ne02,
|
|
6631
|
+
ne10,
|
|
6632
|
+
ne12,
|
|
6633
|
+
ne0,
|
|
6634
|
+
ne1,
|
|
6635
|
+
r2,
|
|
6636
|
+
r3,
|
|
6637
|
+
tgpig,
|
|
6638
|
+
tiisg,
|
|
6639
|
+
sgitg);
|
|
6640
|
+
}
|
|
6641
|
+
|
|
6642
|
+
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
|
6643
|
+
kernel void kernel_mul_mv_id_q4_0_f32(
|
|
6644
|
+
device const char * ids,
|
|
6645
|
+
device const char * src1,
|
|
6646
|
+
device float * dst,
|
|
6647
|
+
constant uint64_t & nbi1,
|
|
6648
|
+
constant int64_t & ne00,
|
|
6649
|
+
constant int64_t & ne01,
|
|
6650
|
+
constant int64_t & ne02,
|
|
6651
|
+
constant uint64_t & nb00,
|
|
6652
|
+
constant uint64_t & nb01,
|
|
6653
|
+
constant uint64_t & nb02,
|
|
6654
|
+
constant int64_t & ne10,
|
|
6655
|
+
constant int64_t & ne11,
|
|
6656
|
+
constant int64_t & ne12,
|
|
6657
|
+
constant int64_t & ne13,
|
|
6658
|
+
constant uint64_t & nb10,
|
|
6659
|
+
constant uint64_t & nb11,
|
|
6660
|
+
constant uint64_t & nb12,
|
|
6661
|
+
constant int64_t & ne0,
|
|
6662
|
+
constant int64_t & ne1,
|
|
6663
|
+
constant uint64_t & nb1,
|
|
6664
|
+
constant uint & r2,
|
|
6665
|
+
constant uint & r3,
|
|
6666
|
+
constant int & idx,
|
|
6667
|
+
device const char * src00,
|
|
6668
|
+
device const char * src01,
|
|
6669
|
+
device const char * src02,
|
|
6670
|
+
device const char * src03,
|
|
6671
|
+
device const char * src04,
|
|
6672
|
+
device const char * src05,
|
|
6673
|
+
device const char * src06,
|
|
6674
|
+
device const char * src07,
|
|
6675
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6676
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6677
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
6678
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6679
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5063
6680
|
|
|
5064
|
-
|
|
5065
|
-
// matrix-matrix multiplication
|
|
5066
|
-
//
|
|
6681
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5067
6682
|
|
|
5068
|
-
|
|
5069
|
-
device const uchar * src0,
|
|
5070
|
-
device const uchar * src1,
|
|
5071
|
-
device float * dst,
|
|
5072
|
-
constant int64_t & ne00,
|
|
5073
|
-
constant int64_t & ne02,
|
|
5074
|
-
constant uint64_t & nb01,
|
|
5075
|
-
constant uint64_t & nb02,
|
|
5076
|
-
constant int64_t & ne12,
|
|
5077
|
-
constant uint64_t & nb10,
|
|
5078
|
-
constant uint64_t & nb11,
|
|
5079
|
-
constant uint64_t & nb12,
|
|
5080
|
-
constant int64_t & ne0,
|
|
5081
|
-
constant int64_t & ne1,
|
|
5082
|
-
constant uint & r2,
|
|
5083
|
-
constant uint & r3,
|
|
5084
|
-
threadgroup uchar *,
|
|
5085
|
-
uint3, uint, uint);
|
|
6683
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5086
6684
|
|
|
5087
|
-
|
|
5088
|
-
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
5089
|
-
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
|
5090
|
-
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
|
5091
|
-
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
|
5092
|
-
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
|
5093
|
-
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
|
5094
|
-
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
5095
|
-
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
5096
|
-
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
5097
|
-
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
5098
|
-
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
5099
|
-
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5100
|
-
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5101
|
-
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
6685
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5102
6686
|
|
|
5103
|
-
|
|
5104
|
-
|
|
5105
|
-
|
|
6687
|
+
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
6688
|
+
src0[id],
|
|
6689
|
+
(device const float *) (src1 + bid*nb11),
|
|
6690
|
+
dst + bid*ne0,
|
|
6691
|
+
ne00,
|
|
6692
|
+
ne01,
|
|
6693
|
+
ne02,
|
|
6694
|
+
ne10,
|
|
6695
|
+
ne12,
|
|
6696
|
+
ne0,
|
|
6697
|
+
ne1,
|
|
6698
|
+
r2,
|
|
6699
|
+
r3,
|
|
6700
|
+
tgpig,
|
|
6701
|
+
tiisg,
|
|
6702
|
+
sgitg);
|
|
6703
|
+
}
|
|
5106
6704
|
|
|
5107
|
-
|
|
5108
|
-
|
|
5109
|
-
device const
|
|
6705
|
+
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
|
6706
|
+
kernel void kernel_mul_mv_id_q4_1_f32(
|
|
6707
|
+
device const char * ids,
|
|
6708
|
+
device const char * src1,
|
|
5110
6709
|
device float * dst,
|
|
5111
6710
|
constant uint64_t & nbi1,
|
|
5112
6711
|
constant int64_t & ne00,
|
|
6712
|
+
constant int64_t & ne01,
|
|
5113
6713
|
constant int64_t & ne02,
|
|
6714
|
+
constant uint64_t & nb00,
|
|
5114
6715
|
constant uint64_t & nb01,
|
|
5115
6716
|
constant uint64_t & nb02,
|
|
6717
|
+
constant int64_t & ne10,
|
|
6718
|
+
constant int64_t & ne11,
|
|
5116
6719
|
constant int64_t & ne12,
|
|
5117
6720
|
constant int64_t & ne13,
|
|
5118
6721
|
constant uint64_t & nb10,
|
|
@@ -5124,39 +6727,46 @@ typedef void (mat_mm_id_t)(
|
|
|
5124
6727
|
constant uint & r2,
|
|
5125
6728
|
constant uint & r3,
|
|
5126
6729
|
constant int & idx,
|
|
5127
|
-
device const
|
|
5128
|
-
device const
|
|
5129
|
-
device const
|
|
5130
|
-
device const
|
|
5131
|
-
device const
|
|
5132
|
-
device const
|
|
5133
|
-
device const
|
|
5134
|
-
device const
|
|
5135
|
-
|
|
5136
|
-
|
|
6730
|
+
device const char * src00,
|
|
6731
|
+
device const char * src01,
|
|
6732
|
+
device const char * src02,
|
|
6733
|
+
device const char * src03,
|
|
6734
|
+
device const char * src04,
|
|
6735
|
+
device const char * src05,
|
|
6736
|
+
device const char * src06,
|
|
6737
|
+
device const char * src07,
|
|
6738
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6739
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6740
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
|
6741
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
6742
|
+
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
|
5137
6743
|
|
|
5138
|
-
|
|
5139
|
-
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
|
5140
|
-
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
|
5141
|
-
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
|
5142
|
-
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
|
5143
|
-
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
|
5144
|
-
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
|
5145
|
-
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
|
5146
|
-
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
|
5147
|
-
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
|
5148
|
-
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
|
5149
|
-
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
|
5150
|
-
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
|
5151
|
-
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
|
5152
|
-
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
|
6744
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
|
5153
6745
|
|
|
5154
|
-
|
|
5155
|
-
// matrix-vector multiplication
|
|
5156
|
-
//
|
|
6746
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
|
5157
6747
|
|
|
5158
|
-
|
|
5159
|
-
|
|
6748
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
6749
|
+
|
|
6750
|
+
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
6751
|
+
src0[id],
|
|
6752
|
+
(device const float *) (src1 + bid*nb11),
|
|
6753
|
+
dst + bid*ne0,
|
|
6754
|
+
ne00,
|
|
6755
|
+
ne01,
|
|
6756
|
+
ne02,
|
|
6757
|
+
ne10,
|
|
6758
|
+
ne12,
|
|
6759
|
+
ne0,
|
|
6760
|
+
ne1,
|
|
6761
|
+
r2,
|
|
6762
|
+
r3,
|
|
6763
|
+
tgpig,
|
|
6764
|
+
tiisg,
|
|
6765
|
+
sgitg);
|
|
6766
|
+
}
|
|
6767
|
+
|
|
6768
|
+
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
|
6769
|
+
kernel void kernel_mul_mv_id_q5_0_f32(
|
|
5160
6770
|
device const char * ids,
|
|
5161
6771
|
device const char * src1,
|
|
5162
6772
|
device float * dst,
|
|
@@ -5200,32 +6810,26 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
|
5200
6810
|
|
|
5201
6811
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5202
6812
|
|
|
5203
|
-
|
|
6813
|
+
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
5204
6814
|
src0[id],
|
|
5205
|
-
src1 + bid*nb11,
|
|
5206
|
-
dst
|
|
6815
|
+
(device const float *) (src1 + bid*nb11),
|
|
6816
|
+
dst + bid*ne0,
|
|
5207
6817
|
ne00,
|
|
5208
6818
|
ne01,
|
|
5209
6819
|
ne02,
|
|
5210
|
-
nb00,
|
|
5211
|
-
nb01,
|
|
5212
|
-
nb02,
|
|
5213
6820
|
ne10,
|
|
5214
|
-
ne11,
|
|
5215
6821
|
ne12,
|
|
5216
|
-
nb10,
|
|
5217
|
-
nb11,
|
|
5218
|
-
nb12,
|
|
5219
6822
|
ne0,
|
|
5220
6823
|
ne1,
|
|
5221
6824
|
r2,
|
|
5222
6825
|
r3,
|
|
5223
6826
|
tgpig,
|
|
5224
|
-
tiisg
|
|
6827
|
+
tiisg,
|
|
6828
|
+
sgitg);
|
|
5225
6829
|
}
|
|
5226
6830
|
|
|
5227
|
-
[[host_name("
|
|
5228
|
-
kernel void
|
|
6831
|
+
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
|
6832
|
+
kernel void kernel_mul_mv_id_q5_1_f32(
|
|
5229
6833
|
device const char * ids,
|
|
5230
6834
|
device const char * src1,
|
|
5231
6835
|
device float * dst,
|
|
@@ -5269,32 +6873,26 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
|
5269
6873
|
|
|
5270
6874
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5271
6875
|
|
|
5272
|
-
|
|
6876
|
+
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
|
5273
6877
|
src0[id],
|
|
5274
|
-
src1 + bid*nb11,
|
|
5275
|
-
dst
|
|
6878
|
+
(device const float *) (src1 + bid*nb11),
|
|
6879
|
+
dst + bid*ne0,
|
|
5276
6880
|
ne00,
|
|
5277
6881
|
ne01,
|
|
5278
6882
|
ne02,
|
|
5279
|
-
nb00,
|
|
5280
|
-
nb01,
|
|
5281
|
-
nb02,
|
|
5282
6883
|
ne10,
|
|
5283
|
-
ne11,
|
|
5284
6884
|
ne12,
|
|
5285
|
-
nb10,
|
|
5286
|
-
nb11,
|
|
5287
|
-
nb12,
|
|
5288
6885
|
ne0,
|
|
5289
6886
|
ne1,
|
|
5290
6887
|
r2,
|
|
5291
6888
|
r3,
|
|
5292
6889
|
tgpig,
|
|
5293
|
-
tiisg
|
|
6890
|
+
tiisg,
|
|
6891
|
+
sgitg);
|
|
5294
6892
|
}
|
|
5295
6893
|
|
|
5296
|
-
[[host_name("
|
|
5297
|
-
kernel void
|
|
6894
|
+
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
|
6895
|
+
kernel void kernel_mul_mv_id_q2_K_f32(
|
|
5298
6896
|
device const char * ids,
|
|
5299
6897
|
device const char * src1,
|
|
5300
6898
|
device float * dst,
|
|
@@ -5338,7 +6936,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
5338
6936
|
|
|
5339
6937
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5340
6938
|
|
|
5341
|
-
|
|
6939
|
+
kernel_mul_mv_q2_K_f32_impl(
|
|
5342
6940
|
src0[id],
|
|
5343
6941
|
(device const float *) (src1 + bid*nb11),
|
|
5344
6942
|
dst + bid*ne0,
|
|
@@ -5356,8 +6954,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
|
5356
6954
|
sgitg);
|
|
5357
6955
|
}
|
|
5358
6956
|
|
|
5359
|
-
[[host_name("
|
|
5360
|
-
kernel void
|
|
6957
|
+
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
|
6958
|
+
kernel void kernel_mul_mv_id_q3_K_f32(
|
|
5361
6959
|
device const char * ids,
|
|
5362
6960
|
device const char * src1,
|
|
5363
6961
|
device float * dst,
|
|
@@ -5401,7 +6999,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
5401
6999
|
|
|
5402
7000
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5403
7001
|
|
|
5404
|
-
|
|
7002
|
+
kernel_mul_mv_q3_K_f32_impl(
|
|
5405
7003
|
src0[id],
|
|
5406
7004
|
(device const float *) (src1 + bid*nb11),
|
|
5407
7005
|
dst + bid*ne0,
|
|
@@ -5419,8 +7017,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
|
5419
7017
|
sgitg);
|
|
5420
7018
|
}
|
|
5421
7019
|
|
|
5422
|
-
[[host_name("
|
|
5423
|
-
kernel void
|
|
7020
|
+
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
|
7021
|
+
kernel void kernel_mul_mv_id_q4_K_f32(
|
|
5424
7022
|
device const char * ids,
|
|
5425
7023
|
device const char * src1,
|
|
5426
7024
|
device float * dst,
|
|
@@ -5464,7 +7062,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
5464
7062
|
|
|
5465
7063
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5466
7064
|
|
|
5467
|
-
|
|
7065
|
+
kernel_mul_mv_q4_K_f32_impl(
|
|
5468
7066
|
src0[id],
|
|
5469
7067
|
(device const float *) (src1 + bid*nb11),
|
|
5470
7068
|
dst + bid*ne0,
|
|
@@ -5482,8 +7080,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
|
5482
7080
|
sgitg);
|
|
5483
7081
|
}
|
|
5484
7082
|
|
|
5485
|
-
[[host_name("
|
|
5486
|
-
kernel void
|
|
7083
|
+
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
|
7084
|
+
kernel void kernel_mul_mv_id_q5_K_f32(
|
|
5487
7085
|
device const char * ids,
|
|
5488
7086
|
device const char * src1,
|
|
5489
7087
|
device float * dst,
|
|
@@ -5527,7 +7125,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
5527
7125
|
|
|
5528
7126
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5529
7127
|
|
|
5530
|
-
|
|
7128
|
+
kernel_mul_mv_q5_K_f32_impl(
|
|
5531
7129
|
src0[id],
|
|
5532
7130
|
(device const float *) (src1 + bid*nb11),
|
|
5533
7131
|
dst + bid*ne0,
|
|
@@ -5545,8 +7143,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
|
5545
7143
|
sgitg);
|
|
5546
7144
|
}
|
|
5547
7145
|
|
|
5548
|
-
[[host_name("
|
|
5549
|
-
kernel void
|
|
7146
|
+
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
|
7147
|
+
kernel void kernel_mul_mv_id_q6_K_f32(
|
|
5550
7148
|
device const char * ids,
|
|
5551
7149
|
device const char * src1,
|
|
5552
7150
|
device float * dst,
|
|
@@ -5590,7 +7188,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
5590
7188
|
|
|
5591
7189
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5592
7190
|
|
|
5593
|
-
|
|
7191
|
+
kernel_mul_mv_q6_K_f32_impl(
|
|
5594
7192
|
src0[id],
|
|
5595
7193
|
(device const float *) (src1 + bid*nb11),
|
|
5596
7194
|
dst + bid*ne0,
|
|
@@ -5608,8 +7206,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
|
5608
7206
|
sgitg);
|
|
5609
7207
|
}
|
|
5610
7208
|
|
|
5611
|
-
[[host_name("
|
|
5612
|
-
kernel void
|
|
7209
|
+
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
|
7210
|
+
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
5613
7211
|
device const char * ids,
|
|
5614
7212
|
device const char * src1,
|
|
5615
7213
|
device float * dst,
|
|
@@ -5641,6 +7239,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
5641
7239
|
device const char * src05,
|
|
5642
7240
|
device const char * src06,
|
|
5643
7241
|
device const char * src07,
|
|
7242
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5644
7243
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5645
7244
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
5646
7245
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
@@ -5653,7 +7252,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
5653
7252
|
|
|
5654
7253
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5655
7254
|
|
|
5656
|
-
|
|
7255
|
+
kernel_mul_mv_iq2_xxs_f32_impl(
|
|
5657
7256
|
src0[id],
|
|
5658
7257
|
(device const float *) (src1 + bid*nb11),
|
|
5659
7258
|
dst + bid*ne0,
|
|
@@ -5666,13 +7265,14 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
|
5666
7265
|
ne1,
|
|
5667
7266
|
r2,
|
|
5668
7267
|
r3,
|
|
7268
|
+
shared_values,
|
|
5669
7269
|
tgpig,
|
|
5670
7270
|
tiisg,
|
|
5671
7271
|
sgitg);
|
|
5672
7272
|
}
|
|
5673
7273
|
|
|
5674
|
-
[[host_name("
|
|
5675
|
-
kernel void
|
|
7274
|
+
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
|
7275
|
+
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
5676
7276
|
device const char * ids,
|
|
5677
7277
|
device const char * src1,
|
|
5678
7278
|
device float * dst,
|
|
@@ -5704,6 +7304,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
5704
7304
|
device const char * src05,
|
|
5705
7305
|
device const char * src06,
|
|
5706
7306
|
device const char * src07,
|
|
7307
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5707
7308
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5708
7309
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
5709
7310
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
@@ -5716,7 +7317,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
5716
7317
|
|
|
5717
7318
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5718
7319
|
|
|
5719
|
-
|
|
7320
|
+
kernel_mul_mv_iq2_xs_f32_impl(
|
|
5720
7321
|
src0[id],
|
|
5721
7322
|
(device const float *) (src1 + bid*nb11),
|
|
5722
7323
|
dst + bid*ne0,
|
|
@@ -5729,13 +7330,14 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
|
5729
7330
|
ne1,
|
|
5730
7331
|
r2,
|
|
5731
7332
|
r3,
|
|
7333
|
+
shared_values,
|
|
5732
7334
|
tgpig,
|
|
5733
7335
|
tiisg,
|
|
5734
7336
|
sgitg);
|
|
5735
7337
|
}
|
|
5736
7338
|
|
|
5737
|
-
[[host_name("
|
|
5738
|
-
kernel void
|
|
7339
|
+
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
|
7340
|
+
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
5739
7341
|
device const char * ids,
|
|
5740
7342
|
device const char * src1,
|
|
5741
7343
|
device float * dst,
|
|
@@ -5767,6 +7369,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
5767
7369
|
device const char * src05,
|
|
5768
7370
|
device const char * src06,
|
|
5769
7371
|
device const char * src07,
|
|
7372
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5770
7373
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5771
7374
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
5772
7375
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
@@ -5779,7 +7382,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
5779
7382
|
|
|
5780
7383
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5781
7384
|
|
|
5782
|
-
|
|
7385
|
+
kernel_mul_mv_iq3_xxs_f32_impl(
|
|
5783
7386
|
src0[id],
|
|
5784
7387
|
(device const float *) (src1 + bid*nb11),
|
|
5785
7388
|
dst + bid*ne0,
|
|
@@ -5792,13 +7395,14 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
|
5792
7395
|
ne1,
|
|
5793
7396
|
r2,
|
|
5794
7397
|
r3,
|
|
7398
|
+
shared_values,
|
|
5795
7399
|
tgpig,
|
|
5796
7400
|
tiisg,
|
|
5797
7401
|
sgitg);
|
|
5798
7402
|
}
|
|
5799
7403
|
|
|
5800
|
-
[[host_name("
|
|
5801
|
-
kernel void
|
|
7404
|
+
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
|
7405
|
+
kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
5802
7406
|
device const char * ids,
|
|
5803
7407
|
device const char * src1,
|
|
5804
7408
|
device float * dst,
|
|
@@ -5830,6 +7434,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
5830
7434
|
device const char * src05,
|
|
5831
7435
|
device const char * src06,
|
|
5832
7436
|
device const char * src07,
|
|
7437
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5833
7438
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5834
7439
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
5835
7440
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
@@ -5842,7 +7447,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
5842
7447
|
|
|
5843
7448
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5844
7449
|
|
|
5845
|
-
|
|
7450
|
+
kernel_mul_mv_iq3_s_f32_impl(
|
|
5846
7451
|
src0[id],
|
|
5847
7452
|
(device const float *) (src1 + bid*nb11),
|
|
5848
7453
|
dst + bid*ne0,
|
|
@@ -5855,13 +7460,14 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
|
5855
7460
|
ne1,
|
|
5856
7461
|
r2,
|
|
5857
7462
|
r3,
|
|
7463
|
+
shared_values,
|
|
5858
7464
|
tgpig,
|
|
5859
7465
|
tiisg,
|
|
5860
7466
|
sgitg);
|
|
5861
7467
|
}
|
|
5862
7468
|
|
|
5863
|
-
[[host_name("
|
|
5864
|
-
kernel void
|
|
7469
|
+
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
|
7470
|
+
kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
5865
7471
|
device const char * ids,
|
|
5866
7472
|
device const char * src1,
|
|
5867
7473
|
device float * dst,
|
|
@@ -5893,6 +7499,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
5893
7499
|
device const char * src05,
|
|
5894
7500
|
device const char * src06,
|
|
5895
7501
|
device const char * src07,
|
|
7502
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5896
7503
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5897
7504
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
5898
7505
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
@@ -5905,7 +7512,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
5905
7512
|
|
|
5906
7513
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5907
7514
|
|
|
5908
|
-
|
|
7515
|
+
kernel_mul_mv_iq2_s_f32_impl(
|
|
5909
7516
|
src0[id],
|
|
5910
7517
|
(device const float *) (src1 + bid*nb11),
|
|
5911
7518
|
dst + bid*ne0,
|
|
@@ -5918,13 +7525,14 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
|
5918
7525
|
ne1,
|
|
5919
7526
|
r2,
|
|
5920
7527
|
r3,
|
|
7528
|
+
shared_values,
|
|
5921
7529
|
tgpig,
|
|
5922
7530
|
tiisg,
|
|
5923
7531
|
sgitg);
|
|
5924
7532
|
}
|
|
5925
7533
|
|
|
5926
|
-
[[host_name("
|
|
5927
|
-
kernel void
|
|
7534
|
+
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
|
7535
|
+
kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
5928
7536
|
device const char * ids,
|
|
5929
7537
|
device const char * src1,
|
|
5930
7538
|
device float * dst,
|
|
@@ -5956,7 +7564,6 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
|
5956
7564
|
device const char * src05,
|
|
5957
7565
|
device const char * src06,
|
|
5958
7566
|
device const char * src07,
|
|
5959
|
-
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
|
5960
7567
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
5961
7568
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
5962
7569
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
@@ -5969,7 +7576,7 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
|
5969
7576
|
|
|
5970
7577
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
5971
7578
|
|
|
5972
|
-
|
|
7579
|
+
kernel_mul_mv_iq1_s_f32_impl(
|
|
5973
7580
|
src0[id],
|
|
5974
7581
|
(device const float *) (src1 + bid*nb11),
|
|
5975
7582
|
dst + bid*ne0,
|
|
@@ -5982,14 +7589,13 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
|
5982
7589
|
ne1,
|
|
5983
7590
|
r2,
|
|
5984
7591
|
r3,
|
|
5985
|
-
shared_values,
|
|
5986
7592
|
tgpig,
|
|
5987
7593
|
tiisg,
|
|
5988
7594
|
sgitg);
|
|
5989
7595
|
}
|
|
5990
7596
|
|
|
5991
|
-
[[host_name("
|
|
5992
|
-
kernel void
|
|
7597
|
+
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
|
7598
|
+
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
5993
7599
|
device const char * ids,
|
|
5994
7600
|
device const char * src1,
|
|
5995
7601
|
device float * dst,
|
|
@@ -6021,7 +7627,7 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
|
6021
7627
|
device const char * src05,
|
|
6022
7628
|
device const char * src06,
|
|
6023
7629
|
device const char * src07,
|
|
6024
|
-
threadgroup
|
|
7630
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
|
6025
7631
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6026
7632
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6027
7633
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
@@ -6034,7 +7640,7 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
|
6034
7640
|
|
|
6035
7641
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
6036
7642
|
|
|
6037
|
-
|
|
7643
|
+
kernel_mul_mv_iq4_nl_f32_impl(
|
|
6038
7644
|
src0[id],
|
|
6039
7645
|
(device const float *) (src1 + bid*nb11),
|
|
6040
7646
|
dst + bid*ne0,
|
|
@@ -6053,8 +7659,8 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
|
6053
7659
|
sgitg);
|
|
6054
7660
|
}
|
|
6055
7661
|
|
|
6056
|
-
[[host_name("
|
|
6057
|
-
kernel void
|
|
7662
|
+
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
|
7663
|
+
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
|
6058
7664
|
device const char * ids,
|
|
6059
7665
|
device const char * src1,
|
|
6060
7666
|
device float * dst,
|
|
@@ -6086,7 +7692,7 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
|
6086
7692
|
device const char * src05,
|
|
6087
7693
|
device const char * src06,
|
|
6088
7694
|
device const char * src07,
|
|
6089
|
-
threadgroup
|
|
7695
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
|
6090
7696
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6091
7697
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6092
7698
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
@@ -6099,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
|
6099
7705
|
|
|
6100
7706
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
|
6101
7707
|
|
|
6102
|
-
|
|
7708
|
+
#if QK_K == 64
|
|
7709
|
+
kernel_mul_mv_iq4_nl_f32_impl(
|
|
7710
|
+
#else
|
|
7711
|
+
kernel_mul_mv_iq4_xs_f32_impl(
|
|
7712
|
+
#endif
|
|
6103
7713
|
src0[id],
|
|
6104
7714
|
(device const float *) (src1 + bid*nb11),
|
|
6105
7715
|
dst + bid*ne0,
|