llama_cpp 0.12.6 → 0.13.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
351
351
  kernel void kernel_soft_max(
352
352
  device const float * src0,
353
353
  device const float * src1,
354
+ device const float * src2,
354
355
  device float * dst,
355
356
  constant int64_t & ne00,
356
357
  constant int64_t & ne01,
357
358
  constant int64_t & ne02,
358
359
  constant float & scale,
359
- threadgroup float * buf [[threadgroup(0)]],
360
+ constant float & max_bias,
361
+ constant float & m0,
362
+ constant float & m1,
363
+ constant uint32_t & n_head_log2,
364
+ threadgroup float * buf [[threadgroup(0)]],
360
365
  uint tgpig[[threadgroup_position_in_grid]],
361
366
  uint tpitg[[thread_position_in_threadgroup]],
362
367
  uint sgitg[[simdgroup_index_in_threadgroup]],
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
368
373
 
369
374
  device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
375
  device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
376
+ device const float * ppos = src2 != src0 ? src2 : nullptr;
371
377
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
372
378
 
379
+ float slope = 0.0f;
380
+
381
+ // ALiBi
382
+ if (max_bias > 0.0f) {
383
+ const int64_t h = i02;
384
+
385
+ const float base = h < n_head_log2 ? m0 : m1;
386
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
387
+
388
+ slope = pow(base, exp);
389
+ }
390
+
373
391
  // parallel max
374
392
  float lmax = -INFINITY;
375
393
 
376
394
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
377
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
395
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
378
396
  }
379
397
 
380
398
  // find the max value in the block
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
399
417
  // parallel sum
400
418
  float lsum = 0.0f;
401
419
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
402
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
420
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
403
421
  lsum += exp_psrc0;
404
422
  pdst[i00] = exp_psrc0;
405
423
  }
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
437
455
  kernel void kernel_soft_max_4(
438
456
  device const float * src0,
439
457
  device const float * src1,
458
+ device const float * src2,
440
459
  device float * dst,
441
460
  constant int64_t & ne00,
442
461
  constant int64_t & ne01,
443
462
  constant int64_t & ne02,
444
463
  constant float & scale,
445
- threadgroup float * buf [[threadgroup(0)]],
464
+ constant float & max_bias,
465
+ constant float & m0,
466
+ constant float & m1,
467
+ constant uint32_t & n_head_log2,
468
+ threadgroup float * buf [[threadgroup(0)]],
446
469
  uint tgpig[[threadgroup_position_in_grid]],
447
470
  uint tpitg[[thread_position_in_threadgroup]],
448
471
  uint sgitg[[simdgroup_index_in_threadgroup]],
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
454
477
 
455
478
  device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
479
  device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
480
+ device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
457
481
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
458
482
 
483
+ float slope = 0.0f;
484
+
485
+ if (max_bias > 0.0f) {
486
+ const int64_t h = i02;
487
+
488
+ const float base = h < n_head_log2 ? m0 : m1;
489
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
490
+
491
+ slope = pow(base, exp);
492
+ }
493
+
459
494
  // parallel max
460
495
  float4 lmax4 = -INFINITY;
461
496
 
462
497
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
463
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
498
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
464
499
  }
465
500
 
466
501
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
486
521
  // parallel sum
487
522
  float4 lsum4 = 0.0f;
488
523
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
489
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
524
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
490
525
  lsum4 += exp_psrc4;
491
526
  pdst4[i00] = exp_psrc4;
492
527
  }
@@ -2484,12 +2519,58 @@ typedef struct {
2484
2519
  } block_iq2_xs;
2485
2520
  // 74 bytes / block for QK_K = 256, so 2.3125 bpw
2486
2521
 
2522
+ // 2.5625 bpw quants
2523
+ typedef struct {
2524
+ half d;
2525
+ uint8_t qs[QK_K/4];
2526
+ uint8_t qh[QK_K/32];
2527
+ uint8_t scales[QK_K/32];
2528
+ } block_iq2_s;
2529
+
2487
2530
  typedef struct {
2488
2531
  half d;
2489
2532
  uint8_t qs[3*QK_K/8];
2490
2533
  } block_iq3_xxs;
2491
2534
  // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2492
2535
 
2536
+ // 3.4375 bpw
2537
+ #if QK_K == 64
2538
+ #define IQ3S_N_SCALE 2
2539
+ #else
2540
+ #define IQ3S_N_SCALE QK_K/64
2541
+ #endif
2542
+ typedef struct {
2543
+ half d;
2544
+ uint8_t qs[QK_K/4];
2545
+ uint8_t qh[QK_K/32];
2546
+ uint8_t signs[QK_K/8];
2547
+ uint8_t scales[IQ3S_N_SCALE];
2548
+ } block_iq3_s;
2549
+
2550
+ typedef struct {
2551
+ half d;
2552
+ uint8_t qs[QK_K/8];
2553
+ uint8_t scales[QK_K/16];
2554
+ } block_iq1_s;
2555
+
2556
+ // Non-linear quants
2557
+ #define QK4_NL 32
2558
+ typedef struct {
2559
+ half d;
2560
+ uint8_t qs[QK4_NL/2];
2561
+ } block_iq4_nl;
2562
+
2563
+ #if QK_K == 64
2564
+ #define block_iq4_xs block_iq4_nl
2565
+ #else
2566
+ typedef struct {
2567
+ half d;
2568
+ uint16_t scales_h;
2569
+ uint8_t scales_l[QK_K/64];
2570
+ uint8_t qs[QK_K/2];
2571
+ } block_iq4_xs;
2572
+ #endif
2573
+
2493
2574
  //====================================== dot products =========================
2494
2575
 
2495
2576
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3712,6 +3793,265 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
3712
3793
  0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
3713
3794
  };
3714
3795
 
3796
+ constexpr constant static uint64_t iq2s_grid[1024] = {
3797
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
3798
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
3799
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
3800
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
3801
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
3802
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
3803
+ 0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
3804
+ 0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
3805
+ 0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
3806
+ 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
3807
+ 0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
3808
+ 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
3809
+ 0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
3810
+ 0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
3811
+ 0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
3812
+ 0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
3813
+ 0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
3814
+ 0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
3815
+ 0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
3816
+ 0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
3817
+ 0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
3818
+ 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
3819
+ 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
3820
+ 0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
3821
+ 0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
3822
+ 0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
3823
+ 0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
3824
+ 0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
3825
+ 0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
3826
+ 0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
3827
+ 0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
3828
+ 0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
3829
+ 0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
3830
+ 0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
3831
+ 0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
3832
+ 0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
3833
+ 0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
3834
+ 0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
3835
+ 0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
3836
+ 0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
3837
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
3838
+ 0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
3839
+ 0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
3840
+ 0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
3841
+ 0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
3842
+ 0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
3843
+ 0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
3844
+ 0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
3845
+ 0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
3846
+ 0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
3847
+ 0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
3848
+ 0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
3849
+ 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
3850
+ 0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
3851
+ 0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
3852
+ 0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
3853
+ 0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
3854
+ 0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
3855
+ 0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
3856
+ 0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
3857
+ 0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
3858
+ 0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
3859
+ 0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
3860
+ 0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
3861
+ 0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
3862
+ 0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
3863
+ 0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
3864
+ 0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
3865
+ 0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
3866
+ 0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
3867
+ 0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
3868
+ 0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
3869
+ 0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
3870
+ 0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
3871
+ 0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
3872
+ 0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
3873
+ 0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
3874
+ 0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
3875
+ 0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
3876
+ 0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
3877
+ 0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
3878
+ 0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
3879
+ 0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
3880
+ 0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
3881
+ 0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
3882
+ 0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
3883
+ 0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
3884
+ 0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
3885
+ 0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
3886
+ 0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
3887
+ 0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
3888
+ 0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
3889
+ 0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
3890
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
3891
+ 0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
3892
+ 0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
3893
+ 0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
3894
+ 0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
3895
+ 0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
3896
+ 0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
3897
+ 0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
3898
+ 0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
3899
+ 0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
3900
+ 0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
3901
+ 0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
3902
+ 0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
3903
+ 0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
3904
+ 0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
3905
+ 0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
3906
+ 0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
3907
+ 0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
3908
+ 0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
3909
+ 0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
3910
+ 0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
3911
+ 0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
3912
+ 0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
3913
+ 0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
3914
+ 0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
3915
+ 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
3916
+ 0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
3917
+ 0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
3918
+ 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
3919
+ 0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
3920
+ 0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
3921
+ 0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
3922
+ 0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
3923
+ 0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
3924
+ 0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
3925
+ 0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
3926
+ 0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
3927
+ 0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
3928
+ 0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
3929
+ 0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
3930
+ 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
3931
+ 0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
3932
+ 0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
3933
+ 0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
3934
+ 0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
3935
+ 0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
3936
+ 0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
3937
+ 0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
3938
+ 0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
3939
+ 0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
3940
+ 0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
3941
+ 0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
3942
+ 0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
3943
+ 0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
3944
+ 0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
3945
+ 0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
3946
+ 0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
3947
+ 0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
3948
+ 0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
3949
+ 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
3950
+ 0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
3951
+ 0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
3952
+ 0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
3953
+ 0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
3954
+ 0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
3955
+ 0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
3956
+ 0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
3957
+ 0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
3958
+ 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
3959
+ 0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
3960
+ 0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
3961
+ 0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
3962
+ 0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
3963
+ 0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
3964
+ 0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
3965
+ 0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
3966
+ 0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
3967
+ 0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
3968
+ 0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
3969
+ 0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
3970
+ 0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
3971
+ 0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
3972
+ 0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
3973
+ 0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
3974
+ 0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
3975
+ 0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
3976
+ 0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
3977
+ 0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
3978
+ 0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
3979
+ 0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
3980
+ 0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
3981
+ 0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
3982
+ 0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
3983
+ 0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
3984
+ 0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
3985
+ 0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
3986
+ 0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
3987
+ 0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
3988
+ 0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
3989
+ 0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
3990
+ 0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
3991
+ 0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
3992
+ 0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
3993
+ 0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
3994
+ 0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
3995
+ 0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
3996
+ 0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
3997
+ 0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
3998
+ 0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
3999
+ 0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
4000
+ 0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
4001
+ 0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
4002
+ 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
4003
+ 0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
4004
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
4005
+ 0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
4006
+ 0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
4007
+ 0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
4008
+ 0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
4009
+ 0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
4010
+ 0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
4011
+ 0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
4012
+ 0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
4013
+ 0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
4014
+ 0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
4015
+ 0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
4016
+ 0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
4017
+ 0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
4018
+ 0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
4019
+ 0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
4020
+ 0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
4021
+ 0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
4022
+ 0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
4023
+ 0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
4024
+ 0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
4025
+ 0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
4026
+ 0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
4027
+ 0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
4028
+ 0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
4029
+ 0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
4030
+ 0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
4031
+ 0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
4032
+ 0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
4033
+ 0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
4034
+ 0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
4035
+ 0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
4036
+ 0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
4037
+ 0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
4038
+ 0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
4039
+ 0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
4040
+ 0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
4041
+ 0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
4042
+ 0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
4043
+ 0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
4044
+ 0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
4045
+ 0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
4046
+ 0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
4047
+ 0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
4048
+ 0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
4049
+ 0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
4050
+ 0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
4051
+ 0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
4052
+ 0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
4053
+ };
4054
+
3715
4055
  constexpr constant static uint32_t iq3xxs_grid[256] = {
3716
4056
  0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
3717
4057
  0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -3747,6 +4087,204 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
3747
4087
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3748
4088
  };
3749
4089
 
4090
+ constexpr constant static uint32_t iq3xs_grid[512] = {
4091
+ 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
4092
+ 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
4093
+ 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
4094
+ 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
4095
+ 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
4096
+ 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
4097
+ 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
4098
+ 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
4099
+ 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
4100
+ 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
4101
+ 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
4102
+ 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
4103
+ 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
4104
+ 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
4105
+ 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
4106
+ 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
4107
+ 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
4108
+ 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
4109
+ 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
4110
+ 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
4111
+ 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
4112
+ 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
4113
+ 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
4114
+ 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
4115
+ 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
4116
+ 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
4117
+ 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
4118
+ 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
4119
+ 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
4120
+ 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
4121
+ 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
4122
+ 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
4123
+ 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
4124
+ 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
4125
+ 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
4126
+ 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
4127
+ 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
4128
+ 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
4129
+ 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
4130
+ 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
4131
+ 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
4132
+ 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
4133
+ 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
4134
+ 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
4135
+ 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
4136
+ 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
4137
+ 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
4138
+ 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
4139
+ 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
4140
+ 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
4141
+ 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
4142
+ 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
4143
+ 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
4144
+ 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
4145
+ 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
4146
+ 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
4147
+ 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
4148
+ 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
4149
+ 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
4150
+ 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
4151
+ 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
4152
+ 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
4153
+ 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
4154
+ 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
4155
+ };
4156
+
4157
+ #define NGRID_IQ1S 512
4158
+ constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
4159
+ 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
4160
+ 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
4161
+ 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
4162
+ 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
4163
+ 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
4164
+ 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
4165
+ 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
4166
+ 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
4167
+ 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
4168
+ 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
4169
+ 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
4170
+ 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
4171
+ 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
4172
+ 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
4173
+ 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
4174
+ 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
4175
+ 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
4176
+ 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
4177
+ 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
4178
+ 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
4179
+ 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
4180
+ 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
4181
+ 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
4182
+ 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
4183
+ 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
4184
+ 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
4185
+ 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
4186
+ 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
4187
+ 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
4188
+ 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
4189
+ 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
4190
+ 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
4191
+ 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
4192
+ 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
4193
+ 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
4194
+ 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
4195
+ 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
4196
+ 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
4197
+ 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
4198
+ 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
4199
+ 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
4200
+ 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
4201
+ 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
4202
+ 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
4203
+ 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
4204
+ 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
4205
+ 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
4206
+ 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
4207
+ 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
4208
+ 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
4209
+ 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
4210
+ 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
4211
+ 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
4212
+ 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
4213
+ 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
4214
+ 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
4215
+ 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
4216
+ 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
4217
+ 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
4218
+ 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
4219
+ 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
4220
+ 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
4221
+ 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
4222
+ 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
4223
+ 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
4224
+ 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
4225
+ 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
4226
+ 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
4227
+ 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
4228
+ 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
4229
+ 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
4230
+ 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
4231
+ 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
4232
+ 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
4233
+ 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
4234
+ 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
4235
+ 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
4236
+ 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
4237
+ 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
4238
+ 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
4239
+ 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
4240
+ 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
4241
+ 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
4242
+ 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
4243
+ 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
4244
+ 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
4245
+ 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
4246
+ 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
4247
+ 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
4248
+ 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
4249
+ 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
4250
+ 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
4251
+ 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
4252
+ 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
4253
+ 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
4254
+ 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
4255
+ 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
4256
+ 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
4257
+ 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
4258
+ 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
4259
+ 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
4260
+ 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
4261
+ 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
4262
+ 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
4263
+ 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
4264
+ 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
4265
+ 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
4266
+ 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
4267
+ 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
4268
+ 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
4269
+ 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
4270
+ 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
4271
+ 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
4272
+ 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
4273
+ 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
4274
+ 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
4275
+ 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
4276
+ 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
4277
+ 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
4278
+ 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
4279
+ 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
4280
+ 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
4281
+ 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
4282
+ 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
4283
+ 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
4284
+ 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
4285
+ 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
4286
+ 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
4287
+ };
3750
4288
 
3751
4289
  constexpr constant static uint8_t ksigns_iq2xs[128] = {
3752
4290
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
@@ -3812,7 +4350,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3812
4350
  threadgroup_barrier(mem_flags::mem_threadgroup);
3813
4351
  }
3814
4352
 
3815
- #if QK_K == 256
3816
4353
  const int ix = tiisg;
3817
4354
 
3818
4355
  device const float * y4 = y + 32 * ix;
@@ -3853,9 +4390,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3853
4390
 
3854
4391
  y4 += 32 * 32;
3855
4392
  }
3856
- #else
3857
- // TODO
3858
- #endif
3859
4393
 
3860
4394
  for (int row = 0; row < N_DST; ++row) {
3861
4395
  all_sum = simd_sum(sumf[row]);
@@ -3945,7 +4479,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
3945
4479
  threadgroup_barrier(mem_flags::mem_threadgroup);
3946
4480
  }
3947
4481
 
3948
- #if QK_K == 256
3949
4482
  const int ix = tiisg;
3950
4483
 
3951
4484
  device const float * y4 = y + 32 * ix;
@@ -3996,9 +4529,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
3996
4529
 
3997
4530
  y4 += 32 * 32;
3998
4531
  }
3999
- #else
4000
- // TODO
4001
- #endif
4002
4532
 
4003
4533
  for (int row = 0; row < N_DST; ++row) {
4004
4534
  all_sum = simd_sum(sumf[row]);
@@ -4088,7 +4618,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4088
4618
  threadgroup_barrier(mem_flags::mem_threadgroup);
4089
4619
  }
4090
4620
 
4091
- #if QK_K == 256
4092
4621
  const int ix = tiisg;
4093
4622
 
4094
4623
  device const float * y4 = y + 32 * ix;
@@ -4132,9 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4132
4661
 
4133
4662
  y4 += 32 * 32;
4134
4663
  }
4135
- #else
4136
- // TODO
4137
- #endif
4138
4664
 
4139
4665
  for (int row = 0; row < N_DST; ++row) {
4140
4666
  all_sum = simd_sum(sumf[row]);
@@ -4173,767 +4699,1802 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
4173
4699
  kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4174
4700
  }
4175
4701
 
4702
+ void kernel_mul_mv_iq3_s_f32_impl(
4703
+ device const void * src0,
4704
+ device const float * src1,
4705
+ device float * dst,
4706
+ constant int64_t & ne00,
4707
+ constant int64_t & ne01,
4708
+ constant int64_t & ne02,
4709
+ constant int64_t & ne10,
4710
+ constant int64_t & ne12,
4711
+ constant int64_t & ne0,
4712
+ constant int64_t & ne1,
4713
+ constant uint & r2,
4714
+ constant uint & r3,
4715
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4716
+ uint3 tgpig[[threadgroup_position_in_grid]],
4717
+ uint tiisg[[thread_index_in_simdgroup]],
4718
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4176
4719
 
4177
- //============================= templates and their specializations =============================
4178
-
4179
- // NOTE: this is not dequantizing - we are simply fitting the template
4180
- template <typename type4x4>
4181
- void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
4182
- float4x4 temp = *(((device float4x4 *)src));
4183
- for (int i = 0; i < 16; i++){
4184
- reg[i/4][i%4] = temp[i/4][i%4];
4185
- }
4186
- }
4187
-
4188
- template <typename type4x4>
4189
- void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
4190
- half4x4 temp = *(((device half4x4 *)src));
4191
- for (int i = 0; i < 16; i++){
4192
- reg[i/4][i%4] = temp[i/4][i%4];
4193
- }
4194
- }
4720
+ const int nb = ne00/QK_K;
4721
+ const int r0 = tgpig.x;
4722
+ const int r1 = tgpig.y;
4723
+ const int im = tgpig.z;
4195
4724
 
4196
- template <typename type4x4>
4197
- void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
4198
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
4199
- const float d1 = il ? (xb->d / 16.h) : xb->d;
4200
- const float d2 = d1 / 256.f;
4201
- const float md = -8.h * xb->d;
4202
- const ushort mask0 = il ? 0x00F0 : 0x000F;
4203
- const ushort mask1 = mask0 << 8;
4725
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4726
+ const int ib_row = first_row * nb;
4204
4727
 
4205
- for (int i=0;i<8;i++) {
4206
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
4207
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
4208
- }
4209
- }
4728
+ const uint i12 = im%ne12;
4729
+ const uint i13 = im/ne12;
4210
4730
 
4211
- template <typename type4x4>
4212
- void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
4213
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
4214
- const float d1 = il ? (xb->d / 16.h) : xb->d;
4215
- const float d2 = d1 / 256.f;
4216
- const float m = xb->m;
4217
- const ushort mask0 = il ? 0x00F0 : 0x000F;
4218
- const ushort mask1 = mask0 << 8;
4731
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4219
4732
 
4220
- for (int i=0;i<8;i++) {
4221
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
4222
- reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
4223
- }
4224
- }
4733
+ device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
4734
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4225
4735
 
4226
- template <typename type4x4>
4227
- void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
4228
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
4229
- const float d = xb->d;
4230
- const float md = -16.h * xb->d;
4231
- const ushort mask = il ? 0x00F0 : 0x000F;
4736
+ float yl[32];
4737
+ float sumf[N_DST]={0.f}, all_sum;
4232
4738
 
4233
- const uint32_t qh = *((device const uint32_t *)xb->qh);
4739
+ const int nb32 = nb * (QK_K / 32);
4234
4740
 
4235
- const int x_mv = il ? 4 : 0;
4741
+ threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
4742
+ {
4743
+ int nval = 8;
4744
+ int pos = (32*sgitg + tiisg)*nval;
4745
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
4746
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4747
+ }
4236
4748
 
4237
- const int gh_mv = il ? 12 : 0;
4238
- const int gh_bk = il ? 0 : 4;
4749
+ const int ix = tiisg;
4239
4750
 
4240
- for (int i = 0; i < 8; i++) {
4241
- // extract the 5-th bits for x0 and x1
4242
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
4243
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
4751
+ device const float * y4 = y + 32 * ix;
4244
4752
 
4245
- // combine the 4-bits from qs with the 5th bit
4246
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
4247
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
4753
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4248
4754
 
4249
- reg[i/2][2*(i%2)+0] = d * x0 + md;
4250
- reg[i/2][2*(i%2)+1] = d * x1 + md;
4251
- }
4252
- }
4755
+ for (int i = 0; i < 32; ++i) {
4756
+ yl[i] = y4[i];
4757
+ }
4253
4758
 
4254
- template <typename type4x4>
4255
- void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
4256
- device const uint16_t * qs = ((device const uint16_t *)xb + 4);
4257
- const float d = xb->d;
4258
- const float m = xb->m;
4259
- const ushort mask = il ? 0x00F0 : 0x000F;
4759
+ const int ibl = ib32 / (QK_K / 32);
4760
+ const int ib = ib32 % (QK_K / 32);
4260
4761
 
4261
- const uint32_t qh = *((device const uint32_t *)xb->qh);
4762
+ device const block_iq3_s * xr = x + ibl;
4763
+ device const uint8_t * qs = xr->qs + 8 * ib;
4764
+ device const uint8_t * qh = xr->qh + ib;
4765
+ device const uint8_t * sc = xr->scales + (ib/2);
4766
+ device const uint8_t * signs = xr->signs + 4 * ib;
4767
+ device const half * dh = &xr->d;
4262
4768
 
4263
- const int x_mv = il ? 4 : 0;
4769
+ for (int row = 0; row < N_DST; row++) {
4264
4770
 
4265
- const int gh_mv = il ? 12 : 0;
4266
- const int gh_bk = il ? 0 : 4;
4771
+ const float db = dh[0];
4772
+ const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
4267
4773
 
4268
- for (int i = 0; i < 8; i++) {
4269
- // extract the 5-th bits for x0 and x1
4270
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
4271
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
4774
+ float2 sum = {0};
4775
+ for (int l = 0; l < 4; ++l) {
4776
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
4777
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
4778
+ for (int j = 0; j < 4; ++j) {
4779
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
4780
+ sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
4781
+ }
4782
+ }
4783
+ sumf[row] += d * (sum[0] + sum[1]);
4272
4784
 
4273
- // combine the 4-bits from qs with the 5th bit
4274
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
4275
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
4785
+ dh += nb*sizeof(block_iq3_s)/2;
4786
+ qs += nb*sizeof(block_iq3_s);
4787
+ qh += nb*sizeof(block_iq3_s);
4788
+ sc += nb*sizeof(block_iq3_s);
4789
+ signs += nb*sizeof(block_iq3_s);
4790
+ }
4276
4791
 
4277
- reg[i/2][2*(i%2)+0] = d * x0 + m;
4278
- reg[i/2][2*(i%2)+1] = d * x1 + m;
4792
+ y4 += 32 * 32;
4279
4793
  }
4280
- }
4281
-
4282
- template <typename type4x4>
4283
- void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
4284
- device const int8_t * qs = ((device const int8_t *)xb->qs);
4285
- const half d = xb->d;
4286
4794
 
4287
- for (int i = 0; i < 16; i++) {
4288
- reg[i/4][i%4] = (qs[i + 16*il] * d);
4795
+ for (int row = 0; row < N_DST; ++row) {
4796
+ all_sum = simd_sum(sumf[row]);
4797
+ if (tiisg == 0) {
4798
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
4799
+ }
4289
4800
  }
4290
4801
  }
4291
4802
 
4292
- template <typename type4x4>
4293
- void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
4294
- const float d = xb->d;
4295
- const float min = xb->dmin;
4296
- device const uint8_t * q = (device const uint8_t *)xb->qs;
4297
- float dl, ml;
4298
- uint8_t sc = xb->scales[il];
4803
+ [[host_name("kernel_mul_mv_iq3_s_f32")]]
4804
+ kernel void kernel_mul_mv_iq3_s_f32(
4805
+ device const void * src0,
4806
+ device const float * src1,
4807
+ device float * dst,
4808
+ constant int64_t & ne00,
4809
+ constant int64_t & ne01,
4810
+ constant int64_t & ne02,
4811
+ constant uint64_t & nb00,
4812
+ constant uint64_t & nb01,
4813
+ constant uint64_t & nb02,
4814
+ constant int64_t & ne10,
4815
+ constant int64_t & ne11,
4816
+ constant int64_t & ne12,
4817
+ constant uint64_t & nb10,
4818
+ constant uint64_t & nb11,
4819
+ constant uint64_t & nb12,
4820
+ constant int64_t & ne0,
4821
+ constant int64_t & ne1,
4822
+ constant uint & r2,
4823
+ constant uint & r3,
4824
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4825
+ uint3 tgpig[[threadgroup_position_in_grid]],
4826
+ uint tiisg[[thread_index_in_simdgroup]],
4827
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4299
4828
 
4300
- #if QK_K == 256
4301
- q = q + 32*(il/8) + 16*(il&1);
4302
- il = (il/2)%4;
4303
- #endif
4304
- half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
4305
- uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
4306
- dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
4307
- for (int i = 0; i < 16; ++i) {
4308
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
4309
- }
4829
+ kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4310
4830
  }
4311
4831
 
4312
- template <typename type4x4>
4313
- void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
4314
- const half d_all = xb->d;
4315
- device const uint8_t * q = (device const uint8_t *)xb->qs;
4316
- device const uint8_t * h = (device const uint8_t *)xb->hmask;
4317
- device const int8_t * scales = (device const int8_t *)xb->scales;
4318
-
4319
- #if QK_K == 256
4320
- q = q + 32 * (il/8) + 16 * (il&1);
4321
- h = h + 16 * (il&1);
4322
- uint8_t m = 1 << (il/2);
4323
- uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
4324
- ((il/4)>0 ? 12 : 3);
4325
- uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
4326
- uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
4327
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
4328
- : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
4329
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
4330
- const float ml = 4.f * dl;
4832
+ void kernel_mul_mv_iq2_s_f32_impl(
4833
+ device const void * src0,
4834
+ device const float * src1,
4835
+ device float * dst,
4836
+ constant int64_t & ne00,
4837
+ constant int64_t & ne01,
4838
+ constant int64_t & ne02,
4839
+ constant int64_t & ne10,
4840
+ constant int64_t & ne12,
4841
+ constant int64_t & ne0,
4842
+ constant int64_t & ne1,
4843
+ constant uint & r2,
4844
+ constant uint & r3,
4845
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4846
+ uint3 tgpig[[threadgroup_position_in_grid]],
4847
+ uint tiisg[[thread_index_in_simdgroup]],
4848
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4331
4849
 
4332
- il = (il/2) & 3;
4333
- const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
4334
- const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
4335
- dl *= coef;
4850
+ const int nb = ne00/QK_K;
4851
+ const int r0 = tgpig.x;
4852
+ const int r1 = tgpig.y;
4853
+ const int im = tgpig.z;
4336
4854
 
4337
- for (int i = 0; i < 16; ++i) {
4338
- reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
4339
- }
4340
- #else
4341
- float kcoef = il&1 ? 1.f/16.f : 1.f;
4342
- uint16_t kmask = il&1 ? 0xF0 : 0x0F;
4343
- float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
4344
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
4345
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
4346
- uint8_t m = 1<<(il*2);
4347
- for (int i = 0; i < 16; ++i) {
4348
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
4349
- }
4350
- #endif
4351
- }
4855
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4856
+ const int ib_row = first_row * nb;
4352
4857
 
4353
- static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
4354
- return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
4355
- : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
4356
- }
4858
+ const uint i12 = im%ne12;
4859
+ const uint i13 = im/ne12;
4357
4860
 
4358
- template <typename type4x4>
4359
- void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
4360
- device const uchar * q = xb->qs;
4861
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4361
4862
 
4362
- #if QK_K == 256
4363
- short is = (il/4) * 2;
4364
- q = q + (il/4) * 32 + 16 * (il&1);
4365
- il = il & 3;
4366
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
4367
- const float d = il < 2 ? xb->d : xb->d / 16.h;
4368
- const float min = xb->dmin;
4369
- const float dl = d * sc[0];
4370
- const float ml = min * sc[1];
4371
- #else
4372
- q = q + 16 * (il&1);
4373
- device const uint8_t * s = xb->scales;
4374
- device const half2 * dh = (device const half2 *)xb->d;
4375
- const float2 d = (float2)dh[0];
4376
- const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
4377
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
4378
- #endif
4379
- const ushort mask = il<2 ? 0x0F : 0xF0;
4380
- for (int i = 0; i < 16; ++i) {
4381
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
4382
- }
4383
- }
4863
+ device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
4864
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4384
4865
 
4385
- template <typename type4x4>
4386
- void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
4387
- device const uint8_t * q = xb->qs;
4388
- device const uint8_t * qh = xb->qh;
4866
+ float yl[32];
4867
+ float sumf[N_DST]={0.f}, all_sum;
4389
4868
 
4390
- #if QK_K == 256
4391
- short is = (il/4) * 2;
4392
- q = q + 32 * (il/4) + 16 * (il&1);
4393
- qh = qh + 16 * (il&1);
4394
- uint8_t ul = 1 << (il/2);
4395
- il = il & 3;
4396
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
4397
- const float d = il < 2 ? xb->d : xb->d / 16.f;
4398
- const float min = xb->dmin;
4399
- const float dl = d * sc[0];
4400
- const float ml = min * sc[1];
4869
+ const int nb32 = nb * (QK_K / 32);
4401
4870
 
4402
- const ushort mask = il<2 ? 0x0F : 0xF0;
4403
- const float qh_val = il<2 ? 16.f : 256.f;
4404
- for (int i = 0; i < 16; ++i) {
4405
- reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
4406
- }
4407
- #else
4408
- q = q + 16 * (il&1);
4409
- device const int8_t * s = xb->scales;
4410
- const float dl = xb->d * s[il];
4411
- uint8_t m = 1<<(il*2);
4412
- const float coef = il<2 ? 1.f : 1.f/16.f;
4413
- const ushort mask = il<2 ? 0x0F : 0xF0;
4414
- for (int i = 0; i < 16; ++i) {
4415
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
4416
- }
4417
- #endif
4418
- }
4871
+ //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
4872
+ //{
4873
+ // int nval = 32;
4874
+ // int pos = (32*sgitg + tiisg)*nval;
4875
+ // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
4876
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
4877
+ //}
4419
4878
 
4420
- template <typename type4x4>
4421
- void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
4422
- const half d_all = xb->d;
4423
- device const uint8_t * ql = (device const uint8_t *)xb->ql;
4424
- device const uint8_t * qh = (device const uint8_t *)xb->qh;
4425
- device const int8_t * scales = (device const int8_t *)xb->scales;
4879
+ const int ix = tiisg;
4426
4880
 
4427
- #if QK_K == 256
4428
- ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
4429
- qh = qh + 32*(il/8) + 16*(il&1);
4430
- float sc = scales[(il%2) + 2 * ((il/2))];
4431
- il = (il/2) & 3;
4432
- #else
4433
- ql = ql + 16 * (il&1);
4434
- float sc = scales[il];
4435
- #endif
4436
- const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
4437
- const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
4438
- const float coef = il>1 ? 1.f/16.f : 1.f;
4439
- const float ml = d_all * sc * 32.f;
4440
- const float dl = d_all * sc * coef;
4441
- for (int i = 0; i < 16; ++i) {
4442
- const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
4443
- : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
4444
- reg[i/4][i%4] = dl * q - ml;
4445
- }
4446
- }
4881
+ device const float * y4 = y + 32 * ix;
4447
4882
 
4448
- template <typename type4x4>
4449
- void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
4450
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4451
- const float d = xb->d;
4452
- const int ib32 = il/2;
4453
- il = il%2;
4454
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4455
- // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
4456
- device const uint16_t * q2 = xb->qs + 4*ib32;
4457
- const uint32_t aux32_g = q2[0] | (q2[1] << 16);
4458
- const uint32_t aux32_s = q2[2] | (q2[3] << 16);
4459
- thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
4460
- const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
4461
- constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
4462
- uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
4463
- for (int i = 0; i < 8; ++i) {
4464
- reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4465
- }
4466
- grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
4467
- signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
4468
- for (int i = 0; i < 8; ++i) {
4469
- reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4470
- }
4471
- }
4883
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4472
4884
 
4473
- template <typename type4x4>
4474
- void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
4475
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4476
- const float d = xb->d;
4477
- const int ib32 = il/2;
4478
- il = il%2;
4479
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4480
- device const uint16_t * q2 = xb->qs + 4*ib32;
4481
- const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
4482
- constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
4483
- uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
4484
- for (int i = 0; i < 8; ++i) {
4485
- reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4486
- }
4487
- grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
4488
- signs = ksigns_iq2xs[q2[2*il+1] >> 9];
4489
- for (int i = 0; i < 8; ++i) {
4490
- reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
4491
- }
4492
- }
4885
+ for (int i = 0; i < 32; ++i) {
4886
+ yl[i] = y4[i];
4887
+ }
4493
4888
 
4494
- template <typename type4x4>
4495
- void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
4496
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4497
- const float d = xb->d;
4498
- const int ib32 = il/2;
4499
- il = il%2;
4500
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
4501
- device const uint8_t * q3 = xb->qs + 8*ib32;
4502
- device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
4503
- const uint32_t aux32 = gas[0] | (gas[1] << 16);
4504
- const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
4505
- constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
4506
- constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
4507
- uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
4508
- for (int i = 0; i < 4; ++i) {
4509
- reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
4510
- reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
4889
+ const int ibl = ib32 / (QK_K / 32);
4890
+ const int ib = ib32 % (QK_K / 32);
4891
+
4892
+ device const block_iq2_s * xr = x + ibl;
4893
+ device const uint8_t * qs = xr->qs + 4 * ib;
4894
+ device const uint8_t * qh = xr->qh + ib;
4895
+ device const uint8_t * sc = xr->scales + ib;
4896
+ device const uint8_t * signs = qs + QK_K/8;
4897
+ device const half * dh = &xr->d;
4898
+
4899
+ for (int row = 0; row < N_DST; row++) {
4900
+
4901
+ const float db = dh[0];
4902
+ const float d1 = db * (0.5f + (sc[0] & 0xf));
4903
+ const float d2 = db * (0.5f + (sc[0] >> 4));
4904
+
4905
+ float2 sum = {0};
4906
+ for (int l = 0; l < 2; ++l) {
4907
+ //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
4908
+ //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
4909
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
4910
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
4911
+ for (int j = 0; j < 8; ++j) {
4912
+ sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
4913
+ sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
4914
+ }
4915
+ }
4916
+ sumf[row] += d1 * sum[0] + d2 * sum[1];
4917
+
4918
+ dh += nb*sizeof(block_iq2_s)/2;
4919
+ qs += nb*sizeof(block_iq2_s);
4920
+ qh += nb*sizeof(block_iq2_s);
4921
+ sc += nb*sizeof(block_iq2_s);
4922
+ signs += nb*sizeof(block_iq2_s);
4923
+ }
4924
+
4925
+ y4 += 32 * 32;
4511
4926
  }
4512
- grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
4513
- grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
4514
- signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
4515
- for (int i = 0; i < 4; ++i) {
4516
- reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
4517
- reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
4927
+
4928
+ for (int row = 0; row < N_DST; ++row) {
4929
+ all_sum = simd_sum(sumf[row]);
4930
+ if (tiisg == 0) {
4931
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
4932
+ }
4518
4933
  }
4519
4934
  }
4520
4935
 
4521
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4522
- kernel void kernel_get_rows(
4936
+ [[host_name("kernel_mul_mv_iq2_s_f32")]]
4937
+ kernel void kernel_mul_mv_iq2_s_f32(
4523
4938
  device const void * src0,
4524
- device const char * src1,
4939
+ device const float * src1,
4525
4940
  device float * dst,
4526
4941
  constant int64_t & ne00,
4942
+ constant int64_t & ne01,
4943
+ constant int64_t & ne02,
4944
+ constant uint64_t & nb00,
4527
4945
  constant uint64_t & nb01,
4528
4946
  constant uint64_t & nb02,
4529
4947
  constant int64_t & ne10,
4948
+ constant int64_t & ne11,
4949
+ constant int64_t & ne12,
4530
4950
  constant uint64_t & nb10,
4531
4951
  constant uint64_t & nb11,
4532
- constant uint64_t & nb1,
4533
- constant uint64_t & nb2,
4534
- uint3 tgpig[[threadgroup_position_in_grid]],
4535
- uint tiitg[[thread_index_in_threadgroup]],
4536
- uint3 tptg [[threads_per_threadgroup]]) {
4537
- //const int64_t i = tgpig;
4538
- //const int64_t r = ((device int32_t *) src1)[i];
4539
-
4540
- const int64_t i10 = tgpig.x;
4541
- const int64_t i11 = tgpig.y;
4542
-
4543
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4544
-
4545
- const int64_t i02 = i11;
4952
+ constant uint64_t & nb12,
4953
+ constant int64_t & ne0,
4954
+ constant int64_t & ne1,
4955
+ constant uint & r2,
4956
+ constant uint & r3,
4957
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
4958
+ uint3 tgpig[[threadgroup_position_in_grid]],
4959
+ uint tiisg[[thread_index_in_simdgroup]],
4960
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4546
4961
 
4547
- for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
4548
- float4x4 temp;
4549
- dequantize_func(
4550
- ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
4551
- *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
4552
- }
4962
+ kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4553
4963
  }
4554
4964
 
4555
- kernel void kernel_get_rows_f32(
4965
+ void kernel_mul_mv_iq1_s_f32_impl(
4556
4966
  device const void * src0,
4557
- device const char * src1,
4967
+ device const float * src1,
4558
4968
  device float * dst,
4559
4969
  constant int64_t & ne00,
4560
- constant uint64_t & nb01,
4561
- constant uint64_t & nb02,
4970
+ constant int64_t & ne01,
4971
+ constant int64_t & ne02,
4562
4972
  constant int64_t & ne10,
4563
- constant uint64_t & nb10,
4564
- constant uint64_t & nb11,
4565
- constant uint64_t & nb1,
4566
- constant uint64_t & nb2,
4567
- uint3 tgpig[[threadgroup_position_in_grid]],
4568
- uint tiitg[[thread_index_in_threadgroup]],
4569
- uint3 tptg [[threads_per_threadgroup]]) {
4570
- const int64_t i10 = tgpig.x;
4571
- const int64_t i11 = tgpig.y;
4572
-
4573
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4973
+ constant int64_t & ne12,
4974
+ constant int64_t & ne0,
4975
+ constant int64_t & ne1,
4976
+ constant uint & r2,
4977
+ constant uint & r3,
4978
+ uint3 tgpig[[threadgroup_position_in_grid]],
4979
+ uint tiisg[[thread_index_in_simdgroup]],
4980
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4574
4981
 
4575
- const int64_t i02 = i11;
4982
+ const int nb = ne00/QK_K;
4983
+ const int r0 = tgpig.x;
4984
+ const int r1 = tgpig.y;
4985
+ const int im = tgpig.z;
4576
4986
 
4577
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4578
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4579
- ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4580
- }
4581
- }
4987
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4988
+ const int ib_row = first_row * nb;
4582
4989
 
4583
- kernel void kernel_get_rows_f16(
4584
- device const void * src0,
4585
- device const char * src1,
4586
- device float * dst,
4587
- constant int64_t & ne00,
4588
- constant uint64_t & nb01,
4589
- constant uint64_t & nb02,
4590
- constant int64_t & ne10,
4591
- constant uint64_t & nb10,
4592
- constant uint64_t & nb11,
4593
- constant uint64_t & nb1,
4594
- constant uint64_t & nb2,
4595
- uint3 tgpig[[threadgroup_position_in_grid]],
4596
- uint tiitg[[thread_index_in_threadgroup]],
4597
- uint3 tptg [[threads_per_threadgroup]]) {
4598
- const int64_t i10 = tgpig.x;
4599
- const int64_t i11 = tgpig.y;
4990
+ const uint i12 = im%ne12;
4991
+ const uint i13 = im/ne12;
4600
4992
 
4601
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
4993
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4994
+ device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
4995
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4602
4996
 
4603
- const int64_t i02 = i11;
4997
+ float yl[16];
4998
+ float sumf[N_DST]={0.f}, all_sum;
4604
4999
 
4605
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4606
- ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4607
- ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4608
- }
4609
- }
5000
+ const int nb32 = nb * (QK_K / 32);
4610
5001
 
4611
- kernel void kernel_get_rows_i32(
4612
- device const void * src0,
4613
- device const char * src1,
4614
- device int32_t * dst,
4615
- constant int64_t & ne00,
4616
- constant uint64_t & nb01,
4617
- constant uint64_t & nb02,
4618
- constant int64_t & ne10,
4619
- constant uint64_t & nb10,
4620
- constant uint64_t & nb11,
4621
- constant uint64_t & nb1,
4622
- constant uint64_t & nb2,
4623
- uint3 tgpig[[threadgroup_position_in_grid]],
4624
- uint tiitg[[thread_index_in_threadgroup]],
4625
- uint3 tptg [[threads_per_threadgroup]]) {
4626
- const int64_t i10 = tgpig.x;
4627
- const int64_t i11 = tgpig.y;
5002
+ const int ix = tiisg/2;
5003
+ const int il = tiisg%2;
4628
5004
 
4629
- const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
5005
+ device const float * y4 = y + 32 * ix + 16 * il;
4630
5006
 
4631
- const int64_t i02 = i11;
5007
+ for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
4632
5008
 
4633
- for (int ind = tiitg; ind < ne00; ind += tptg.x) {
4634
- ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
4635
- ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
4636
- }
4637
- }
5009
+ for (int i = 0; i < 16; ++i) {
5010
+ yl[i] = y4[i];
5011
+ }
4638
5012
 
5013
+ const int ibl = ib32 / (QK_K / 32);
5014
+ const int ib = ib32 % (QK_K / 32);
4639
5015
 
4640
- #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
4641
- #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
4642
- #define BLOCK_SIZE_K 32
4643
- #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
4644
- #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
4645
- #define THREAD_PER_BLOCK 128
4646
- #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
4647
- #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
4648
- #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
4649
- #define SG_MAT_ROW 8
5016
+ device const block_iq1_s * xr = x + ibl;
5017
+ device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
5018
+ device const uint8_t * sc = xr->scales + 2 * ib + il;
5019
+ device const half * dh = &xr->d;
4650
5020
 
4651
- // each block_q contains 16*nl weights
4652
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
4653
- void kernel_mul_mm_impl(device const uchar * src0,
4654
- device const uchar * src1,
4655
- device float * dst,
4656
- constant int64_t & ne00,
4657
- constant int64_t & ne02,
4658
- constant uint64_t & nb01,
4659
- constant uint64_t & nb02,
4660
- constant int64_t & ne12,
4661
- constant uint64_t & nb10,
4662
- constant uint64_t & nb11,
4663
- constant uint64_t & nb12,
4664
- constant int64_t & ne0,
4665
- constant int64_t & ne1,
4666
- constant uint & r2,
4667
- constant uint & r3,
4668
- threadgroup uchar * shared_memory [[threadgroup(0)]],
4669
- uint3 tgpig[[threadgroup_position_in_grid]],
4670
- uint tiitg[[thread_index_in_threadgroup]],
4671
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5021
+ for (int row = 0; row < N_DST; row++) {
4672
5022
 
4673
- threadgroup half * sa = (threadgroup half *)(shared_memory);
4674
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
5023
+ constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
5024
+ constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
4675
5025
 
4676
- const uint r0 = tgpig.y;
4677
- const uint r1 = tgpig.x;
4678
- const uint im = tgpig.z;
5026
+ float2 sum = {0};
5027
+ for (int j = 0; j < 8; ++j) {
5028
+ sum[0] += yl[j+ 0] * grid1[j];
5029
+ sum[1] += yl[j+ 8] * grid2[j];
5030
+ }
5031
+ sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
4679
5032
 
4680
- // if this block is of 64x32 shape or smaller
4681
- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
4682
- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
5033
+ dh += nb*sizeof(block_iq1_s)/2;
5034
+ qs += nb*sizeof(block_iq1_s);
5035
+ sc += nb*sizeof(block_iq1_s);
5036
+ }
4683
5037
 
4684
- // a thread shouldn't load data outside of the matrix
4685
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
4686
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
5038
+ y4 += 16 * 32;
5039
+ }
4687
5040
 
4688
- simdgroup_half8x8 ma[4];
4689
- simdgroup_float8x8 mb[2];
4690
- simdgroup_float8x8 c_res[8];
4691
- for (int i = 0; i < 8; i++){
4692
- c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
5041
+ for (int row = 0; row < N_DST; ++row) {
5042
+ all_sum = simd_sum(sumf[row]);
5043
+ if (tiisg == 0) {
5044
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
5045
+ }
4693
5046
  }
5047
+ }
4694
5048
 
4695
- short il = (tiitg % THREAD_PER_ROW);
5049
+ constexpr constant static float kvalues_iq4nl_f[16] = {
5050
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
5051
+ };
5052
+
5053
+ void kernel_mul_mv_iq4_nl_f32_impl(
5054
+ device const void * src0,
5055
+ device const float * src1,
5056
+ device float * dst,
5057
+ constant int64_t & ne00,
5058
+ constant int64_t & ne01,
5059
+ constant int64_t & ne02,
5060
+ constant int64_t & ne10,
5061
+ constant int64_t & ne12,
5062
+ constant int64_t & ne0,
5063
+ constant int64_t & ne1,
5064
+ constant uint & r2,
5065
+ constant uint & r3,
5066
+ threadgroup float * shared_values [[threadgroup(0)]],
5067
+ uint3 tgpig[[threadgroup_position_in_grid]],
5068
+ uint tiisg[[thread_index_in_simdgroup]],
5069
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
5070
+
5071
+ const int nb = ne00/QK4_NL;
5072
+ const int r0 = tgpig.x;
5073
+ const int r1 = tgpig.y;
5074
+ const int im = tgpig.z;
5075
+ const int first_row = (r0 * 2 + sgitg) * 2;
5076
+ const int ib_row = first_row * nb;
4696
5077
 
4697
5078
  const uint i12 = im%ne12;
4698
5079
  const uint i13 = im/ne12;
4699
5080
 
4700
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
4701
- ushort offset1 = il/nl;
5081
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5082
+ device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
5083
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4702
5084
 
4703
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
4704
- device const float * y = (device const float *)(src1
4705
- + nb12 * im
4706
- + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
4707
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
5085
+ const int ix = tiisg/2; // 0...15
5086
+ const int it = tiisg%2; // 0 or 1
4708
5087
 
4709
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
4710
- // load data and store to threadgroup memory
4711
- half4x4 temp_a;
4712
- dequantize_func(x, il, temp_a);
4713
- threadgroup_barrier(mem_flags::mem_threadgroup);
5088
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
5089
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4714
5090
 
4715
- #pragma unroll(16)
4716
- for (int i = 0; i < 16; i++) {
4717
- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
4718
- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
4719
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
4720
- }
5091
+ float4 yl[4];
5092
+ float sumf[2]={0.f}, all_sum;
4721
5093
 
4722
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
5094
+ device const float * yb = y + ix * QK4_NL + it * 8;
4723
5095
 
4724
- il = (il + 2 < nl) ? il + 2 : il % 2;
4725
- x = (il < 2) ? x + (2+nl-1)/nl : x;
4726
- y += BLOCK_SIZE_K;
5096
+ uint32_t aux32[2];
5097
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
4727
5098
 
4728
- threadgroup_barrier(mem_flags::mem_threadgroup);
5099
+ float4 qf1, qf2;
4729
5100
 
4730
- // load matrices from threadgroup memory and conduct outer products
4731
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
4732
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
5101
+ for (int ib = ix; ib < nb; ib += 16) {
4733
5102
 
4734
- #pragma unroll(4)
4735
- for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
4736
- #pragma unroll(4)
4737
- for (int i = 0; i < 4; i++) {
4738
- simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
4739
- }
4740
- simdgroup_barrier(mem_flags::mem_none);
4741
- #pragma unroll(2)
4742
- for (int i = 0; i < 2; i++) {
4743
- simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
4744
- }
5103
+ device const float4 * y4 = (device const float4 *)yb;
5104
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
4745
5105
 
4746
- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
4747
- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
5106
+ for (int row = 0; row < 2; ++row) {
5107
+
5108
+ device const block_iq4_nl & xb = x[row*nb + ib];
5109
+ device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
5110
+
5111
+ float4 acc1 = {0.f}, acc2 = {0.f};
5112
+
5113
+ aux32[0] = q4[0] | (q4[1] << 16);
5114
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
5115
+ aux32[0] &= 0x0f0f0f0f;
5116
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
5117
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
5118
+ acc1 += yl[0] * qf1;
5119
+ acc2 += yl[1] * qf2;
5120
+
5121
+ aux32[0] = q4[2] | (q4[3] << 16);
5122
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
5123
+ aux32[0] &= 0x0f0f0f0f;
5124
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
5125
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
5126
+ acc1 += yl[2] * qf1;
5127
+ acc2 += yl[3] * qf2;
5128
+
5129
+ acc1 += acc2;
5130
+
5131
+ sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
4748
5132
 
4749
- #pragma unroll(8)
4750
- for (int i = 0; i < 8; i++){
4751
- simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
4752
- }
4753
5133
  }
5134
+
5135
+ yb += 16 * QK4_NL;
4754
5136
  }
4755
5137
 
4756
- if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
4757
- device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
4758
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
4759
- for (int i = 0; i < 8; i++) {
4760
- simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
4761
- }
4762
- } else {
4763
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
4764
- threadgroup_barrier(mem_flags::mem_threadgroup);
4765
- threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
4766
- + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
4767
- for (int i = 0; i < 8; i++) {
4768
- simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
4769
- }
4770
-
4771
- threadgroup_barrier(mem_flags::mem_threadgroup);
4772
-
4773
- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
4774
- if (sgitg == 0) {
4775
- for (int i = 0; i < n_rows; i++) {
4776
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
4777
- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
4778
- }
4779
- }
5138
+ for (int row = 0; row < 2; ++row) {
5139
+ all_sum = simd_sum(sumf[row]);
5140
+ if (tiisg == 0) {
5141
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4780
5142
  }
4781
5143
  }
4782
5144
  }
4783
5145
 
4784
- // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
4785
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
4786
- void kernel_mul_mm_id_impl(
4787
- device const uchar * src0,
4788
- device const uchar * src1,
4789
- thread short * src1ids,
4790
- device float * dst,
4791
- constant int64_t & ne00,
4792
- constant int64_t & ne02,
4793
- constant uint64_t & nb01,
4794
- constant uint64_t & nb02,
4795
- constant int64_t & ne12,
4796
- constant uint64_t & nb10,
4797
- constant uint64_t & nb11,
4798
- constant uint64_t & nb12,
4799
- constant int64_t & ne0,
4800
- int64_t ne1,
4801
- constant uint & r2,
4802
- constant uint & r3,
4803
- threadgroup uchar * shared_memory,
4804
- uint3 tgpig[[threadgroup_position_in_grid]],
4805
- uint tiitg[[thread_index_in_threadgroup]],
4806
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
5146
+ #if QK_K != 64
5147
+ void kernel_mul_mv_iq4_xs_f32_impl(
5148
+ device const void * src0,
5149
+ device const float * src1,
5150
+ device float * dst,
5151
+ constant int64_t & ne00,
5152
+ constant int64_t & ne01,
5153
+ constant int64_t & ne02,
5154
+ constant int64_t & ne10,
5155
+ constant int64_t & ne12,
5156
+ constant int64_t & ne0,
5157
+ constant int64_t & ne1,
5158
+ constant uint & r2,
5159
+ constant uint & r3,
5160
+ threadgroup float * shared_values [[threadgroup(0)]],
5161
+ uint3 tgpig[[threadgroup_position_in_grid]],
5162
+ uint tiisg[[thread_index_in_simdgroup]],
5163
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4807
5164
 
4808
- threadgroup half * sa = (threadgroup half *)(shared_memory);
4809
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
5165
+ const int nb = ne00/QK_K;
5166
+ const int r0 = tgpig.x;
5167
+ const int r1 = tgpig.y;
5168
+ const int im = tgpig.z;
5169
+ const int first_row = (r0 * 2 + sgitg) * 2;
5170
+ const int ib_row = first_row * nb;
4810
5171
 
4811
- const uint r0 = tgpig.y;
4812
- const uint r1 = tgpig.x;
4813
- const uint im = tgpig.z;
5172
+ const uint i12 = im%ne12;
5173
+ const uint i13 = im/ne12;
4814
5174
 
4815
- if (r1 * BLOCK_SIZE_N >= ne1) return;
5175
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
5176
+ device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
5177
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4816
5178
 
4817
- // if this block is of 64x32 shape or smaller
4818
- short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
4819
- short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
5179
+ const int ix = tiisg/16; // 0 or 1
5180
+ const int it = tiisg%16; // 0...15
5181
+ const int ib = it/2;
5182
+ const int il = it%2;
4820
5183
 
4821
- // a thread shouldn't load data outside of the matrix
4822
- short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
4823
- short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
5184
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
5185
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4824
5186
 
4825
- simdgroup_half8x8 ma[4];
4826
- simdgroup_float8x8 mb[2];
4827
- simdgroup_float8x8 c_res[8];
4828
- for (int i = 0; i < 8; i++){
4829
- c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
4830
- }
5187
+ float4 yl[4];
5188
+ float sumf[2]={0.f}, all_sum;
4831
5189
 
4832
- short il = (tiitg % THREAD_PER_ROW);
5190
+ device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
4833
5191
 
4834
- const uint i12 = im%ne12;
4835
- const uint i13 = im/ne12;
5192
+ uint32_t aux32[2];
5193
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
4836
5194
 
4837
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
4838
- ushort offset1 = il/nl;
5195
+ float4 qf1, qf2;
4839
5196
 
4840
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
4841
- device const float * y = (device const float *)(src1
4842
- + nb12 * im
4843
- + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
4844
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
5197
+ for (int ibl = ix; ibl < nb; ibl += 2) {
4845
5198
 
4846
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
4847
- // load data and store to threadgroup memory
4848
- half4x4 temp_a;
4849
- dequantize_func(x, il, temp_a);
4850
- threadgroup_barrier(mem_flags::mem_threadgroup);
5199
+ device const float4 * y4 = (device const float4 *)yb;
5200
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
4851
5201
 
4852
- for (int i = 0; i < 16; i++) {
4853
- *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
4854
- + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
4855
- + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
4856
- }
5202
+ for (int row = 0; row < 2; ++row) {
4857
5203
 
4858
- *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
5204
+ device const block_iq4_xs & xb = x[row*nb + ibl];
5205
+ device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
4859
5206
 
4860
- il = (il + 2 < nl) ? il + 2 : il % 2;
4861
- x = (il < 2) ? x + (2+nl-1)/nl : x;
4862
- y += BLOCK_SIZE_K;
5207
+ float4 acc1 = {0.f}, acc2 = {0.f};
4863
5208
 
4864
- threadgroup_barrier(mem_flags::mem_threadgroup);
5209
+ aux32[0] = q4[0] & 0x0f0f0f0f;
5210
+ aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
5211
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
5212
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
5213
+ acc1 += yl[0] * qf1;
5214
+ acc2 += yl[1] * qf2;
4865
5215
 
4866
- // load matrices from threadgroup memory and conduct outer products
4867
- threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
4868
- threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
5216
+ aux32[0] = q4[1] & 0x0f0f0f0f;
5217
+ aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
5218
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
5219
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
5220
+ acc1 += yl[2] * qf1;
5221
+ acc2 += yl[3] * qf2;
4869
5222
 
4870
- for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
4871
- for (int i = 0; i < 4; i++) {
4872
- simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
4873
- }
4874
- simdgroup_barrier(mem_flags::mem_none);
4875
- for (int i = 0; i < 2; i++) {
4876
- simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
4877
- }
5223
+ acc1 += acc2;
4878
5224
 
4879
- lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
4880
- lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
5225
+ const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
5226
+ sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
4881
5227
 
4882
- for (int i = 0; i < 8; i++){
4883
- simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
4884
- }
4885
5228
  }
5229
+
5230
+ yb += 2 * QK_K;
5231
+ }
5232
+
5233
+ for (int row = 0; row < 2; ++row) {
5234
+ all_sum = simd_sum(sumf[row]);
5235
+ if (tiisg == 0) {
5236
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
5237
+ }
5238
+ }
5239
+ }
5240
+ #endif
5241
+
5242
// Matrix-vector product entry point: IQ1_S-quantized src0 times f32 src1.
// All work is done by kernel_mul_mv_iq1_s_f32_impl. The byte strides nb00, nb01,
// nb02, nb10, nb11, nb12 and ne11 are part of the signature but are not forwarded
// (NOTE(review): presumably kept so every mul_mv kernel shares one argument
// layout on the host side — confirm against the host dispatch code).
[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq1_s_f32_impl(
            src0, src1, dst,
            ne00, ne01, ne02,
            ne10, ne12,
            ne0, ne1,
            r2, r3,
            tgpig, tiisg, sgitg);
}
5269
+
5270
// Matrix-vector product entry point: IQ4_NL-quantized src0 times f32 src1.
// Forwards to kernel_mul_mv_iq4_nl_f32_impl together with the threadgroup
// buffer bound at index 0 (shared_values). The strides nb00, nb01, nb02, nb10,
// nb11, nb12 and ne11 are accepted but not forwarded.
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        threadgroup  float * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq4_nl_f32_impl(
            src0, src1, dst,
            ne00, ne01, ne02,
            ne10, ne12,
            ne0, ne1,
            r2, r3,
            shared_values,
            tgpig, tiisg, sgitg);
}
5298
+
5299
// Matrix-vector product entry point: IQ4_XS-quantized src0 times f32 src1.
// When QK_K == 64 the work is routed to the IQ4_NL implementation; otherwise to
// the dedicated IQ4_XS implementation. Both receive the threadgroup buffer bound
// at index 0. The strides nb00..nb12 and ne11 are accepted but not forwarded.
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        threadgroup  float * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
#if QK_K == 64
    // 64-value super-blocks: reuse the IQ4_NL kernel body.
    kernel_mul_mv_iq4_nl_f32_impl(
            src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
            shared_values, tgpig, tiisg, sgitg);
#else
    kernel_mul_mv_iq4_xs_f32_impl(
            src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3,
            shared_values, tgpig, tiisg, sgitg);
#endif
}
5331
+
5332
+ //============================= templates and their specializations =============================
5333
+
5334
// NOTE: no dequantization happens here. This copies one float4x4 (16 floats)
// from src into reg so that f32 data fits the same template signature as the
// real dequantize_* helpers. `il` is unused.
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
    const float4x4 values = *src;
    for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 4; ++col) {
            reg[row][col] = values[row][col];
        }
    }
}
5342
+
5343
// Copies one half4x4 (16 halves) from src into reg, converting each element to
// reg's element type. Like dequantize_f32, this only fits the dequantize_*
// template signature; `il` is unused.
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    const half4x4 values = *src;
    for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 4; ++col) {
            reg[row][col] = values[row][col];
        }
    }
}
5350
+
5351
// Dequantizes 16 weights of a Q4_0 block into reg as d*q - 8*d.
// The quants are read as 8 uint16_t words starting one uint16_t past the block
// start. il == 0 takes the low nibble of each byte; any other il takes the high
// nibble, which is why the scale is pre-divided by 16. The word's high byte
// sits 8 bits further up, so its scale is additionally divided by 256.
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * packed = (device const uint16_t *)xb + 1;

    const float  scale_lo = il ? (xb->d / 16.h) : xb->d;
    const float  scale_hi = scale_lo / 256.f;
    const float  offset   = -8.h * xb->d;
    const ushort lo_mask  = il ? 0x00F0 : 0x000F;
    const ushort hi_mask  = lo_mask << 8;

    for (int w = 0; w < 8; ++w) {
        const uint16_t pair = packed[w];
        reg[w/2][2*(w%2) + 0] = scale_lo * (pair & lo_mask) + offset;
        reg[w/2][2*(w%2) + 1] = scale_hi * (pair & hi_mask) + offset;
    }
}
5365
+
5366
// Dequantizes 16 weights of a Q4_1 block into reg as d*q + m.
// The quants are read as 8 uint16_t words starting two uint16_t past the block
// start. il == 0 takes the low nibble of each byte; any other il takes the high
// nibble (scale pre-divided by 16). The high byte of each word gets a further
// 1/256 scale.
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * packed = (device const uint16_t *)xb + 2;

    const float  scale_lo = il ? (xb->d / 16.h) : xb->d;
    const float  scale_hi = scale_lo / 256.f;
    const float  bias     = xb->m;
    const ushort lo_mask  = il ? 0x00F0 : 0x000F;
    const ushort hi_mask  = lo_mask << 8;

    for (int w = 0; w < 8; ++w) {
        const uint16_t pair = packed[w];
        reg[w/2][2*(w%2) + 0] = ((pair & lo_mask) * scale_lo) + bias;
        reg[w/2][2*(w%2) + 1] = ((pair & hi_mask) * scale_hi) + bias;
    }
}
5380
+
5381
// Dequantizes 16 weights of a Q5_0 block into reg as d*q - 16*d.
// Each 5-bit q combines a 4-bit nibble from qs (read as uint16_t words starting
// three uint16_t past the block start) with one bit of the 32-bit qh mask.
// - il == 0: low nibbles, using qh bits 2j and 2j+1 for word j.
// - any other il: high nibbles, using qh bits 16+2j and 17+2j.
// The selected qh bit is moved to bit position 4 (value 0x10).
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * packed = (device const uint16_t *)xb + 3;
    const float  d    = xb->d;
    const float  md   = -16.h * xb->d;
    const ushort mask = il ? 0x00F0 : 0x000F;

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int nib_shift = il ? 4 : 0;   // bring a high nibble down to bits 0..3
    const int qh_base   = il ? 12 : 0;  // first qh bit consumed by this half
    const int qh_lift   = il ? 0 : 4;   // left shift that lands the bit at 0x10

    for (int j = 0; j < 8; ++j) {
        const uint16_t pair = packed[j];
        const int bit0 = qh_base + 2*j;

        const uint8_t hi0 = ((qh >> bit0)       << qh_lift) & 0x10;
        const uint8_t hi1 = ((qh >> (bit0 + 1)) << qh_lift) & 0x10;

        const int32_t v0 = (((pair     ) & mask) >> nib_shift) | hi0;
        const int32_t v1 = (((pair >> 8) & mask) >> nib_shift) | hi1;

        reg[j/2][2*(j%2) + 0] = d * v0 + md;
        reg[j/2][2*(j%2) + 1] = d * v1 + md;
    }
}
5408
+
5409
// Dequantizes 16 weights of a Q5_1 block into reg as d*q + m.
// The 5-bit q is assembled exactly as in Q5_0: a 4-bit nibble from qs (uint16_t
// words starting four uint16_t past the block start) plus one qh bit placed at
// 0x10. il == 0 uses the low nibbles and qh bits 0..15; any other il uses the
// high nibbles and qh bits 16..31.
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * packed = (device const uint16_t *)xb + 4;
    const float  d    = xb->d;
    const float  m    = xb->m;
    const ushort mask = il ? 0x00F0 : 0x000F;

    const uint32_t qh = *((device const uint32_t *)xb->qh);

    const int nib_shift = il ? 4 : 0;
    const int qh_base   = il ? 12 : 0;
    const int qh_lift   = il ? 0 : 4;

    for (int j = 0; j < 8; ++j) {
        const uint16_t pair = packed[j];
        const int bit0 = qh_base + 2*j;

        const uint8_t hi0 = ((qh >> bit0)       << qh_lift) & 0x10;
        const uint8_t hi1 = ((qh >> (bit0 + 1)) << qh_lift) & 0x10;

        const int32_t v0 = (((pair     ) & mask) >> nib_shift) | hi0;
        const int32_t v1 = (((pair >> 8) & mask) >> nib_shift) | hi1;

        reg[j/2][2*(j%2) + 0] = d * v0 + m;
        reg[j/2][2*(j%2) + 1] = d * v1 + m;
    }
}
5436
+
5437
// Dequantizes 16 consecutive int8 quants of a Q8_0 block (starting at offset
// 16*il) into reg as q * d. The product is formed in half precision, as d is a half.
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
    device const int8_t * slice = (device const int8_t *)xb->qs + 16*il;
    const half d = xb->d;

    for (int row = 0; row < 4; ++row) {
        for (int col = 0; col < 4; ++col) {
            reg[row][col] = slice[4*row + col] * d;
        }
    }
}
5446
+
5447
// Dequantizes 16 weights of a Q2_K super-block into reg as dl*q - ml.
// The scale byte is chosen with the caller's il. Its low nibble scales d and its
// high nibble scales dmin. With QK_K == 256, il also picks a 16-byte run of qs,
// and is then remapped to 0..3 to select one of the four 2-bit fields in each byte.
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
    const float d   = xb->d;
    const float min = xb->dmin;
    const uint8_t sc = xb->scales[il];   // read before il is remapped below
    device const uint8_t * q = (device const uint8_t *)xb->qs;

#if QK_K == 256
    q += 32*(il/8) + 16*(il&1);
    il = (il/2)%4;
#endif

    // Field k (k = il) is masked in place, so coef = 1/4^k undoes its bit offset.
    const half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    const uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);

    const float dl = d * (sc & 0xF) * coef;
    const float ml = min * (sc >> 4);

    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
    }
}
5466
+
5467
// Dequantizes 16 weights of a Q3_K super-block into reg.
// A weight is a 2-bit field from qs, shifted down by 4 when its hmask bit is clear.
// QK_K == 256:
//   - il (0..15) selects the qs/hmask run.
//   - The 6-bit group scale is built from a nibble of scales[il%8] plus two bits
//     of scales[8 + il%4], and biased by 32.
// Otherwise a 4-bit scale is taken from scales[il/2], biased by 8.
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
    const half d_all = xb->d;
    device const uint8_t * q      = (device const uint8_t *)xb->qs;
    device const uint8_t * h      = (device const uint8_t *)xb->hmask;
    device const int8_t  * scales = (device const int8_t *)xb->scales;

#if QK_K == 256
    const short group = il/4;        // selects the 2-bit slot inside scales[8..11]
    q += 32 * (il/8) + 16 * (il&1);
    h += 16 * (il&1);
    const uint8_t hbit = 1 << (il/2);

    const uint16_t hi_mask = group>1 ? (group>2 ? 192 : 48) : (group>0 ? 12 : 3);
    const uint16_t lo_mask = il/8 ? 0xF0 : 0x0F;
    const uint16_t lo_src  = scales[il%8];
    const uint16_t hi_src  = scales[8 + il%4];
    const int16_t raw_scale = (group & 1)
        ? (lo_src & lo_mask) | ((hi_src & hi_mask) << 2)
        : (lo_src & lo_mask) | ((hi_src & hi_mask) << 4);
    // For il >= 8 the scale nibble sits in the high half of the byte, hence /16.
    float dl = il < 8 ? d_all * (raw_scale - 32.f) : d_all * (raw_scale / 16.f - 32.f);
    const float ml = 4.f * dl;

    const short slot = (il/2) & 3;   // which 2-bit field of each qs byte
    const half    coef = slot>1 ? (slot>2 ? 1/64.h : 1/16.h) : (slot>0 ? 1/4.h : 1.h);
    const uint8_t mask = slot>1 ? (slot>2 ? 192 : 48) : (slot>0 ? 12 : 3);
    dl *= coef;

    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & hbit ? 0 : ml);
    }
#else
    const float    kcoef = il&1 ? 1.f/16.f : 1.f;
    const uint16_t kmask = il&1 ? 0xF0 : 0x0F;
    const float    dl    = d_all * ((scales[il/2] & kmask) * kcoef - 8);
    const float    coef  = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    const uint8_t  mask  = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
    const uint8_t  hbit  = 1<<(il*2);
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (hbit * (1 + i/8))) ? 0 : 4.f/coef));
    }
#endif
}
5507
+
5508
// Returns the 6-bit (scale, min) pair with index j from a K-quant scales array,
// offset by k bytes. Pairs 0..3 are the low 6 bits of bytes j and j+4. Pairs 4..7
// combine a nibble of byte j+4 with the top two bits of bytes j-4 / j.
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
    if (j < 4) {
        return uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)};
    }
    const uchar scale = uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2));
    const uchar min   = uchar((q[j+4+k] >> 4)  | ((q[j-0+k] & 0xc0) >> 2));
    return uchar2{scale, min};
}
5512
+
5513
// Dequantizes 16 weights of a Q4_K super-block into reg as dl*q - ml.
// With QK_K == 256, il selects a 16-byte run of qs and the (scale, min) pair via
// get_scale_min_k4_just2. Otherwise scale/min nibbles come from scales[0..1] and a
// half2 (d, dmin) stored at xb->d. After remapping, il < 2 reads low nibbles; the
// high-nibble case divides the scale by 16 to undo the 4-bit offset.
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
    device const uchar * q = xb->qs;

#if QK_K == 256
    const short scale_idx = (il/4) * 2;
    q += (il/4) * 32 + 16 * (il&1);
    il = il & 3;
    const uchar2 sc   = get_scale_min_k4_just2(scale_idx, il/2, xb->scales);
    const float  d    = il < 2 ? xb->d : xb->d / 16.h;
    const float  min  = xb->dmin;
    const float  dl   = d * sc[0];
    const float  ml   = min * sc[1];
#else
    (void) get_scale_min_k4_just2;

    q += 16 * (il&1);
    device const uint8_t * s  = xb->scales;
    device const half2   * dh = (device const half2 *)xb->d;
    const float2 d = (float2)dh[0];
    const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1] * (s[1]>>4);
#endif

    const ushort mask = il<2 ? 0x0F : 0xF0;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
    }
}
5541
+
5542
// Dequantizes 16 weights of a Q5_K super-block into reg.
// QK_K == 256: value = dl*(nibble + high bit) - ml. The high bit comes from qh and
// is worth 16 (low nibbles) or 256 (high nibbles, whose scale is pre-divided by 16).
// Otherwise a signed per-group scale from scales[il] is used, and values whose qh
// bit is clear are shifted down by 16.
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * q  = xb->qs;
    device const uint8_t * qh = xb->qh;

#if QK_K == 256
    const short scale_idx = (il/4) * 2;
    q  += 32 * (il/4) + 16 * (il&1);
    qh += 16 * (il&1);
    const uint8_t hbit = 1 << (il/2);
    il = il & 3;
    const bool low_nibble = il < 2;

    const uchar2 sc  = get_scale_min_k4_just2(scale_idx, il/2, xb->scales);
    const float  d   = low_nibble ? xb->d : xb->d / 16.f;
    const float  min = xb->dmin;
    const float  dl  = d * sc[0];
    const float  ml  = min * sc[1];

    const ushort mask   = low_nibble ? 0x0F : 0xF0;
    const float  qh_val = low_nibble ? 16.f : 256.f;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & hbit ? qh_val : 0)) - ml;
    }
#else
    q += 16 * (il&1);
    device const int8_t * s = xb->scales;
    const float   dl   = xb->d * s[il];
    const uint8_t hbit = 1<<(il*2);
    const bool low_nibble = il < 2;
    const float  coef = low_nibble ? 1.f : 1.f/16.f;
    const ushort mask = low_nibble ? 0x0F : 0xF0;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (hbit*(1+i/8)) ? 0.f : 16.f/coef));
    }
#endif
}
5576
+
5577
// Dequantizes 16 weights of a Q6_K super-block into reg as d*sc*(q - 32).
// Each 6-bit q joins a nibble from ql with a 2-bit field from qh. The qh field is
// shifted up by 4 (even remapped il) or by 2 (odd il, whose ql nibble is the high
// one and is compensated by coef = 1/16).
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
    const half d_all = xb->d;
    device const uint8_t * ql     = (device const uint8_t *)xb->ql;
    device const uint8_t * qh     = (device const uint8_t *)xb->qh;
    device const int8_t  * scales = (device const int8_t *)xb->scales;

    // Each 16-value slice has its own scale. With QK_K == 256 the original index
    // (il%2) + 2*(il/2) is exactly il for C++ integer division.
    const float sc = scales[il];

#if QK_K == 256
    ql += 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
    qh += 32*(il/8) + 16*(il&1);
    il = (il/2) & 3;
#else
    ql += 16 * (il&1);
#endif

    const uint16_t hi_mask  = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
    const uint16_t lo_mask  = il>1 ? 0xF0 : 0x0F;
    const float    coef     = il>1 ? 1.f/16.f : 1.f;
    const short    hi_shift = il&1 ? 2 : 4;
    const float    ml = d_all * sc * 32.f;
    const float    dl = d_all * sc * coef;

    for (int i = 0; i < 16; ++i) {
        const half q = (ql[i] & lo_mask) | ((qh[i] & hi_mask) << hi_shift);
        reg[i/4][i%4] = dl * q - ml;
    }
}
5604
+
5605
// Dequantizes 16 weights of an IQ2_XXS block into reg.
// il / 2 selects the 32-weight group. il % 2 selects its first or second 16 weights.
// Each group stores four uint16_t words:
//   - words 0..1: four 8-bit grid indices.
//   - words 2..3: four 7-bit sign-pattern indices plus a 4-bit scale in the top bits.
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    device const uint16_t * q2 = xb->qs + 4*ib32;

    // Widen each half-word to uint32_t before shifting. A bare uint16_t promotes to
    // signed int, so `q2[1] << 16` could push a set bit into int's sign bit.
    const uint32_t aux32_g = uint32_t(q2[0]) | (uint32_t(q2[1]) << 16);
    const uint32_t aux32_s = uint32_t(q2[2]) | (uint32_t(q2[3]) << 16);
    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;

    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;

    // Two runs of 8 weights: each has its own grid entry and 7-bit sign index.
    for (int part = 0; part < 2; ++part) {
        constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il + part]);
        const uint8_t signs = ksigns_iq2xs[(aux32_s >> (14*il + 7*part)) & 127];
        for (int i = 0; i < 8; ++i) {
            reg[2*part + i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
        }
    }
}
5629
+
5630
// Dequantizes 16 weights of an IQ2_XS block into reg.
// il / 2 selects the 32-weight group and il % 2 its half. Each uint16_t code
// holds a 9-bit grid index (low bits) and a 7-bit sign index (high bits). The
// group scale is a nibble of scales[il/2].
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
    const float d    = xb->d;
    const int   ib32 = il/2;
    const int   sub  = il%2;
    device const uint16_t * q2 = xb->qs + 4*ib32 + 2*sub;

    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*sub) & 0xf)) * 0.25f;

    for (int part = 0; part < 2; ++part) {
        const uint16_t code = q2[part];
        constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (code & 511));
        const uint8_t signs = ksigns_iq2xs[code >> 9];
        for (int i = 0; i < 8; ++i) {
            reg[2*part + i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
        }
    }
}
5650
+
5651
// Dequantizes 16 weights of an IQ3_XXS block into reg.
// il / 2 selects the 32-weight group and il % 2 its half. Grid indices are bytes
// of qs. The scale/sign word for the group is a pair of uint16_t starting at
// qs + QK_K/4: four 7-bit sign indices plus a 4-bit scale in the top bits.
template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
    const float d = xb->d;
    const int ib32 = il/2;
    il = il%2;
    device const uint8_t  * q3  = xb->qs + 8*ib32;
    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;

    // Widen before shifting so `<< 16` never operates on a signed int.
    const uint32_t aux32 = uint32_t(gas[0]) | (uint32_t(gas[1]) << 16);
    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;

    // Two runs of 8 weights. Each uses two 4-value grid entries and one sign index.
    for (int part = 0; part < 2; ++part) {
        constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il + 2*part + 0]);
        constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il + 2*part + 1]);
        const uint8_t signs = ksigns_iq2xs[(aux32 >> (14*il + 7*part)) & 127];
        for (int i = 0; i < 4; ++i) {
            reg[2*part + 0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
            reg[2*part + 1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
        }
    }
}
5677
+
5678
// Dequantizes 16 weights of an IQ3_S block into reg.
// il / 2 selects the 32-weight group and il % 2 its half. Each 4-value grid index
// is a qs byte extended to 9 bits by one bit of qh. Signs are stored as explicit
// bytes, and the group scale is a nibble of scales[ib32/2].
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
    const float d    = xb->d;
    const int   ib32 = il/2;
    const int   sub  = il%2;
    device const uint8_t * qs    = xb->qs + 8*ib32 + 4*sub;
    device const uint8_t * signs = xb->signs + 4*ib32 + 2*sub;
    const uint8_t qh = xb->qh[ib32] >> 4*sub;

    const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;

    // qh bit k (k = 0..3) becomes bit 8 of the k-th grid index.
    for (int part = 0; part < 2; ++part) {
        constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[2*part + 0] | ((qh << (8 - 2*part)) & 256)));
        constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[2*part + 1] | ((qh << (7 - 2*part)) & 256)));
        for (int i = 0; i < 4; ++i) {
            reg[2*part + 0][i] = dl * grid1[i] * select(1, -1, signs[part] & kmask_iq2xs[i+0]);
            reg[2*part + 1][i] = dl * grid2[i] * select(1, -1, signs[part] & kmask_iq2xs[i+4]);
        }
    }
}
5702
+
5703
// Dequantizes 16 weights of an IQ2_S block into reg.
// il / 2 selects the 32-weight group and il % 2 its half. Each 8-value grid index
// is a qs byte extended to 10 bits by two bits of qh. Sign bytes live QK_K/8
// bytes after the grid bytes, and the group scale is a nibble of scales[il/2].
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
    const float d    = xb->d;
    const int   ib32 = il/2;
    const int   sub  = il%2;
    device const uint8_t * qs    = xb->qs + 4*ib32 + 2*sub;
    device const uint8_t * signs = qs + QK_K/8;
    const uint8_t qh = xb->qh[ib32] >> 4*sub;

    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*sub) & 0xf)) * 0.25f;

    for (int part = 0; part < 2; ++part) {
        constant uint8_t * grid = (constant uint8_t *)(iq2s_grid + (qs[part] | ((qh << (8 - 2*part)) & 0x300)));
        for (int i = 0; i < 8; ++i) {
            reg[2*part + i/4][i%4] = dl * grid[i] * select(1, -1, signs[part] & kmask_iq2xs[i]);
        }
    }
}
5721
+
5722
// Dequantizes 16 weights of an IQ1_S block into reg.
// Two qs bytes (offset 2*il) index 8-value int8 grid entries. One scales byte
// supplies, for each half:
//   - a 3-bit odd multiplier (2*s + 1);
//   - bit 8 of the grid index (bit 0x08 for the first half, 0x80 for the second).
template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
    const float d = xb->d;
    device const uint8_t * qs = xb->qs + 2*il;
    device const uint8_t * sc = xb->scales + il;
    const uint8_t s = sc[0];

    const float dl1 = d * (2*(s & 7) + 1);
    const float dl2 = d * (2*((s >> 4) & 7) + 1);

    constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((s & 0x08) << 5)));
    constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((s & 0x80) << 1)));

    for (int i = 0; i < 8; ++i) {
        reg[i/4 + 0][i%4] = dl1 * grid1[i];
        reg[i/4 + 2][i%4] = dl2 * grid2[i];
    }
}
5737
+
5738
// Dequantizes 16 weights of an IQ4_NL block into reg as d * kvalues_iq4nl_f[nibble].
// qs is read as 16-bit words, two words per output row. The nibble plane is
// selected by the shift 4*il: il == 0 takes low nibbles, il == 1 high nibbles
// (NOTE(review): assumes callers pass il in {0, 1}).
template <typename type4x4>
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
    const float d = xb->d;
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
    for (int i = 0; i < 4; ++i) {
        // Build the 32-bit word in unsigned arithmetic. With bare uint16_t operands
        // the expression was evaluated in signed int: `<< 16` could set the sign bit,
        // making the following `>> 4*il` an implementation-defined signed shift.
        const uint32_t packed = uint32_t(q4[2*i]) | (uint32_t(q4[2*i+1]) << 16);
        aux32 = (packed >> 4*il) & 0x0f0f0f0f;
        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
    }
}
5752
+
5753
// Dequantizes 16 weights of an IQ4_XS block into reg.
// With QK_K == 64 this is delegated to dequantize_iq4_nl. Otherwise il / 2
// selects the 32-weight group and il % 2 the nibble plane.
// The 6-bit group scale combines:
//   - a nibble of scales_l;
//   - two bits of scales_h;
// and is biased by 32.
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
#if QK_K == 64
    dequantize_iq4_nl(xb, il, reg);
#else
    const int ib32      = il/2;
    const int nib_shift = 4*(il%2);
    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;

    const int lo_bits = (xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf;
    const int hi_bits = (xb->scales_h >> 2*ib32) & 3;
    const float d = (float)xb->d * ((lo_bits | (hi_bits << 4)) - 32);

    uint32_t nibbles;
    thread const uint8_t * idx = (thread const uint8_t *)&nibbles;
    for (int row = 0; row < 4; ++row) {
        nibbles = (q4[row] >> nib_shift) & 0x0f0f0f0f;
        reg[row][0] = d * kvalues_iq4nl_f[idx[0]];
        reg[row][1] = d * kvalues_iq4nl_f[idx[1]];
        reg[row][2] = d * kvalues_iq4nl_f[idx[2]];
        reg[row][3] = d * kvalues_iq4nl_f[idx[3]];
    }
#endif
}
5776
+
5777
// Row gather with on-the-fly dequantization.
// Threadgroup (tgpig.x, tgpig.y) = (i10, i11):
//   - reads an int32 row id r from src1 at byte offset i11*nb11 + i10*nb10;
//   - dequantizes row r of plane i11 of src0 into dst at byte offset i11*nb2 + i10*nb1.
// Threads step over the row in float4x4 chunks (16 values) with stride tptg.x.
// Each block_q covers nl chunks. ne10 is unused.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
        device const  void * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;   // also the src0 plane index

    const int64_t r = *((device const int32_t *)(src1 + i11*nb11 + i10*nb10));

    device const block_q * src_row = (device const block_q *)((device const char *)src0 + r*nb01 + i11*nb02);
    device float4x4      * dst_row = (device float4x4 *)((device char *)dst + i11*nb2 + i10*nb1);

    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
        float4x4 temp;
        dequantize_func(src_row + ind/nl, ind%nl, temp);
        dst_row[ind] = temp;
    }
}
5810
+
5811
// Row gather for f32 src0.
// Threadgroup (i10, i11) reads an int32 row id r from src1 and copies ne00 floats
// of row r, plane i11 of src0 to the dst row at byte offset i11*nb2 + i10*nb1.
// Threads stride the row by tptg.x. ne10 is unused.
kernel void kernel_get_rows_f32(
        device const  void * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

    const int64_t r = *((device const int32_t *)(src1 + i11*nb11 + i10*nb10));

    device const float * src_row = (device const float *)((device const char *)src0 + r*nb01 + i11*nb02);
    device       float * dst_row = (device float *)((device char *)dst + i11*nb2 + i10*nb1);

    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
        dst_row[ind] = src_row[ind];
    }
}
5838
+
5839
// Row gather for f16 src0, widening each element to float in dst.
// Threadgroup (i10, i11) reads an int32 row id r from src1 and copies ne00 values
// of row r, plane i11 of src0 to the dst row at byte offset i11*nb2 + i10*nb1.
// Threads stride the row by tptg.x. ne10 is unused.
kernel void kernel_get_rows_f16(
        device const  void * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

    const int64_t r = *((device const int32_t *)(src1 + i11*nb11 + i10*nb10));

    device const half  * src_row = (device const half *)((device const char *)src0 + r*nb01 + i11*nb02);
    device       float * dst_row = (device float *)((device char *)dst + i11*nb2 + i10*nb1);

    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
        dst_row[ind] = src_row[ind];
    }
}
5866
+
5867
// Row gather for int32 src0.
// Threadgroup (i10, i11) reads an int32 row id r from src1 and copies ne00 values
// of row r, plane i11 of src0 to the dst row at byte offset i11*nb2 + i10*nb1.
// Threads stride the row by tptg.x. ne10 is unused.
kernel void kernel_get_rows_i32(
        device const    void * src0,
        device const    char * src1,
        device       int32_t * dst,
        constant     int64_t & ne00,
        constant    uint64_t & nb01,
        constant    uint64_t & nb02,
        constant     int64_t & ne10,
        constant    uint64_t & nb10,
        constant    uint64_t & nb11,
        constant    uint64_t & nb1,
        constant    uint64_t & nb2,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

    const int64_t r = *((device const int32_t *)(src1 + i11*nb11 + i10*nb10));

    device const int32_t * src_row = (device const int32_t *)((device const char *)src0 + r*nb01 + i11*nb02);
    device       int32_t * dst_row = (device int32_t *)((device char *)dst + i11*nb2 + i10*nb1);

    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
        dst_row[ind] = src_row[ind];
    }
}
5894
+
5895
+
5896
// Tiling parameters for the simdgroup-matrix mat-mat kernel that follows.
// A threadgroup covers BLOCK_SIZE_M rows of src0 and BLOCK_SIZE_N rows of src1,
// walking the shared dimension BLOCK_SIZE_K values at a time.
#define BLOCK_SIZE_M     64   // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N     32   // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K     32
#define THREAD_MAT_M     4    // simdgroup matrices of A taken by each simdgroup
#define THREAD_MAT_N     2    // simdgroup matrices of B taken by each simdgroup
#define THREAD_PER_BLOCK 128
#define THREAD_PER_ROW   2    // threads that cooperate to load one row of A
#define THREAD_PER_COL   4    // threads that cooperate to load one row of B
#define SG_MAT_SIZE      64   // elements in one 8x8 simdgroup matrix
#define SG_MAT_ROW       8
5906
+
5907
// each block_q contains 16*nl weights
//
// Tiled matrix-matrix multiplication: dst = src0 x src1, where src0 rows are dequantized on the
// fly via dequantize_func (16 weights per call into a half4x4) and src1 is float.
//
// Grid / threadgroup layout (as read by this function):
//   tgpig.y (r0) selects the BLOCK_SIZE_M (64) -row tile along ne0,
//   tgpig.x (r1) selects the BLOCK_SIZE_N (32) -column tile along ne1,
//   tgpig.z (im) is the flattened (i12, i13) batch index.
//   The indexing assumes 128 threads per threadgroup (tiitg / THREAD_PER_ROW covers 64 rows)
//   organized as 4 simdgroups (sgitg % 2 picks the A half, sgitg / 2 the B half) — TODO confirm
//   against the host launch.
// Shared memory: needs 8192 bytes — [0, 4096) holds the half A tile, [4096, 8192) the float B
// tile; the partial-tile write-back reuses the whole buffer as a 64x32 float staging area.
// NOTE(review): the K loop always loads full BLOCK_SIZE_K slices of src1, so ne00 is presumably
// required to be a multiple of 32 by the host.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_impl(device const uchar * src0,
                        device const uchar * src1,
                        device float * dst,
                        constant int64_t & ne00,
                        constant int64_t & ne02,
                        constant uint64_t & nb01,
                        constant uint64_t & nb02,
                        constant int64_t & ne12,
                        constant uint64_t & nb10,
                        constant uint64_t & nb11,
                        constant uint64_t & nb12,
                        constant int64_t & ne0,
                        constant int64_t & ne1,
                        constant uint & r2,
                        constant uint & r3,
                        threadgroup uchar * shared_memory [[threadgroup(0)]],
                        uint3 tgpig[[threadgroup_position_in_grid]],
                        uint tiitg[[thread_index_in_threadgroup]],
                        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;

    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;

    // a thread shouldn't load data outside of the matrix: clamp to the last valid row/column
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    // which 16-weight chunk of the current block this thread dequantizes
    short il = (tiitg % THREAD_PER_ROW);

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // byte offset of the (i12, i13) batch slice of src0 (r2/r3 broadcast src1 batches onto src0).
    // Kept 64-bit: nb02/ne02 are 64-bit and storing the result in a 32-bit uint silently truncated
    // offsets beyond 4 GiB.
    uint64_t offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
    ushort offset1 = il/nl;

    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    device const float   * y = (device const float   *)(src1
        + nb12 * im
        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);
        // wait until every simdgroup finished reading sa/sb from the previous iteration
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // scatter the 16 dequantized weights into the 8x8-tiled layout expected by simdgroup_load
        #pragma unroll(16)
        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
        }

        // each thread copies 8 consecutive floats of its src1 row into sb
        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);

        // advance to the next K slice: step il by 2 chunks, moving to the next block when it wraps
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2+nl-1)/nl : x;
        y += BLOCK_SIZE_K;

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // load matrices from threadgroup memory and conduct outer products
        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

        #pragma unroll(4)
        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
            #pragma unroll(4)
            for (int i = 0; i < 4; i++) {
                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
            }
            simdgroup_barrier(mem_flags::mem_none);
            #pragma unroll(2)
            for (int i = 0; i < 2; i++) {
                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
            }

            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;

            #pragma unroll(8)
            for (int i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
            }
        }
    }

    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
        // full tile: each simdgroup stores its 32x16 sub-tile straight to dst
        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
        }
    } else {
        // block is smaller than 64x32, we should avoid writing data outside of the matrix:
        // stage the whole tile in threadgroup memory, then copy only the valid part
        // (the branch condition is uniform per threadgroup, so the barriers are reached by all threads)
        threadgroup_barrier(mem_flags::mem_threadgroup);
        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
        if (sgitg == 0) {
            // for sgitg == 0, temp_str is the start of the staged tile
            for (int i = 0; i < n_rows; i++) {
                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
        }
    }
}
6039
+
6040
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
//
// src1ids[0 .. ne1) lists the src1 rows to multiply; results are written to the dst row with the
// same index. ne1 here is that compacted count, not the full number of src1 rows. A threadgroup
// whose BLOCK_SIZE_N-wide slice of src1ids starts at or beyond ne1 returns immediately; r1 and ne1
// are the same for every thread of the threadgroup, so the early return never splits a barrier.
// Threadgroup shape and the 8192-byte shared-memory layout are the same as in kernel_mul_mm_impl
// (128 threads / 4 simdgroups assumed — TODO confirm against the host launch).
// NOTE(review): the dst batch offset im*ne1*ne0 uses the compacted ne1, which only matches dst's
// batch stride when im == 0; presumably callers launch with a single (i12, i13) batch — TODO confirm.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_id_impl(
        device const uchar * src0,
        device const uchar * src1,
        thread short * src1ids,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne02,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne12,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        int64_t ne1,
        constant uint & r2,
        constant uint & r3,
        threadgroup uchar * shared_memory,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {

    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;

    // no selected src1 rows fall into this tile
    if (r1 * BLOCK_SIZE_N >= ne1) return;

    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;

    // a thread shouldn't load data outside of the matrix: clamp to the last valid row/column
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    short il = (tiitg % THREAD_PER_ROW);

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // byte offset of the (i12, i13) batch slice of src0; kept 64-bit (as in kernel_mul_mm_impl)
    // because a 32-bit uint silently truncated offsets beyond 4 GiB
    uint64_t offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
    ushort offset1 = il/nl;

    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    // src1 row is looked up through src1ids instead of being r1 * BLOCK_SIZE_N + thread_col directly
    device const float   * y = (device const float   *)(src1
        + nb12 * im
        + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
        }

        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);

        // advance to the next K slice
        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2+nl-1)/nl : x;
        y += BLOCK_SIZE_K;

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // load matrices from threadgroup memory and conduct outer products
        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
            for (int i = 0; i < 4; i++) {
                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
            }
            simdgroup_barrier(mem_flags::mem_none);
            for (int i = 0; i < 2; i++) {
                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
            }

            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;

            for (int i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
            }
        }
    }

    {
        // always stage through threadgroup memory: dst rows are scattered via src1ids
        threadgroup_barrier(mem_flags::mem_threadgroup);
        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
        if (sgitg == 0) {
            for (int i = 0; i < n_rows; i++) {
                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                    *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
        }
    }
}
6164
+
6165
// Kernel entry point for the tiled src0 x src1 matmul.
// Contains no logic of its own: every argument is handed, unchanged and in order,
// to kernel_mul_mm_impl instantiated with the same template parameters.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0,
                          device const uchar * src1,
                          device float * dst,
                          constant int64_t & ne00,
                          constant int64_t & ne02,
                          constant uint64_t & nb01,
                          constant uint64_t & nb02,
                          constant int64_t & ne12,
                          constant uint64_t & nb10,
                          constant uint64_t & nb11,
                          constant uint64_t & nb12,
                          constant int64_t & ne0,
                          constant int64_t & ne1,
                          constant uint & r2,
                          constant uint & r3,
                          threadgroup uchar * shared_memory [[threadgroup(0)]],
                          uint3 tgpig[[threadgroup_position_in_grid]],
                          uint tiitg[[thread_index_in_threadgroup]],
                          uint sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
        /* buffers      */ src0, src1, dst,
        /* src0 layout  */ ne00, ne02, nb01, nb02,
        /* src1 layout  */ ne12, nb10, nb11, nb12,
        /* dst shape    */ ne0, ne1,
        /* broadcast    */ r2, r3,
        /* thread state */ shared_memory, tgpig, tiitg, sgitg);
}
6206
+
6207
// Indirect (per-expert) matrix-matrix multiplication entry point.
//
// The grid's z dimension enumerates (expert id, batch) pairs: id = tgpig.z / (ne12*ne13), and the
// remainder replaces tgpig.z as the batch index seen by kernel_mul_mm_id_impl.
// For expert `id`, every row i1 < ne1 whose ids entry at column `idx` equals `id` is collected into
// src1ids; kernel_mul_mm_id_impl then multiplies src0s[id] by only those src1 rows and writes the
// results back to the dst rows with the same indices.
// NOTE(review): src1ids holds at most 512 entries; if ne1 > 512 and enough rows select this expert,
// the collection loop writes past the end of the array — presumably the host guarantees
// ne1 <= 512, TODO confirm.
// NOTE(review): only 8 expert buffers (src00..src07) are bound, so id must be < 8 — presumably the
// z extent of the grid is limited accordingly by the host.
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
        device const uchar * ids,
        device const uchar * src1,
        device float * dst,
        constant uint64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne02,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant uint64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,
        device const uchar * src00,
        device const uchar * src01,
        device const uchar * src02,
        device const uchar * src03,
        device const uchar * src04,
        device const uchar * src05,
        device const uchar * src06,
        device const uchar * src07,
        threadgroup uchar * shared_memory [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // expert id
    const int32_t id = tgpig.z/(ne12*ne13);

    // from here on tgpig.z is the (i12, i13) batch index only
    tgpig.z = tgpig.z%(ne12*ne13);

    // row indices of src1 for expert id
    // (every thread builds the same list, so _ne1 is uniform across the threadgroup)
    int64_t _ne1 = 0;
    short src1ids[512];

    for (int64_t i1 = 0; i1 < ne1; i1++) {
        if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
            src1ids[_ne1++] = i1;
        }
    }

    // _ne1 (the number of selected rows) is passed in place of ne1
    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
        src0s[id],
        src1,
        src1ids,
        dst,
        ne00,
        ne02,
        nb01,
        nb02,
        ne12,
        nb10,
        nb11,
        nb12,
        ne0,
        _ne1,
        r2,
        r3,
        shared_memory,
        tgpig,
        tiitg,
        sgitg);
}
6279
+
6280
// nl template argument used by the QK_K super-block types below: the number of 16-weight chunks
// per block (the matmul kernels state "each block_q contains 16*nl weights"), i.e. 256/16 = 16
// when QK_K == 256 and 64/16 = 4 otherwise (presumably the QK_K == 64 build).
#if QK_K == 256
#define QK_NL 16
#else
#define QK_NL 4
#endif
6285
+
6286
+ //
6287
+ // get rows
6288
+ //
6289
+
6290
// Common function type of the kernel_get_rows instantiations below
// (the kernel_get_rows template itself is not shown in this section).
typedef void (get_rows_t)(
        device const void * src0,
        device const char * src1,
        device float * dst,
        constant int64_t & ne00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne10,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        uint3, uint, uint3);

// Host-visible instantiations, looked up by host_name.
// Template arguments: <block type, nl (16-weight chunks per block), dequantizer>;
// nl = 2 for the 32-weight block types and iq4_nl, QK_NL for the QK_K super-block types.
// f32/f16 variants are intentionally left disabled here.
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
// iq4_xs: with QK_K == 64 it is instantiated with nl = 2 (like iq4_nl), otherwise with QK_NL
#if QK_K == 64
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
#else
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
#endif
6328
+
6329
+ //
6330
+ // matrix-matrix multiplication
6331
+ //
6332
+
6333
// Common function type of the kernel_mul_mm instantiations below; must match the
// parameter list of the kernel_mul_mm template exactly.
typedef void (mat_mm_t)(
        device const uchar * src0,
        device const uchar * src1,
        device float * dst,
        constant int64_t & ne00,
        constant int64_t & ne02,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne12,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant uint & r2,
        constant uint & r3,
        threadgroup uchar *,
        uint3, uint, uint);

// Host-visible instantiations, looked up by host_name.
// Template arguments: <block type, nl (16-weight chunks per block), dequantizer>;
// f32/f16 use nl = 1, the 32-weight block types and iq4_nl use nl = 2, and the QK_K
// super-block types use QK_NL.
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
// iq4_xs: with QK_K == 64 it uses nl = 2 (like iq4_nl), otherwise QK_NL. The block type is
// block_iq4_xs in both branches so it matches dequantize_iq4_xs and the get_rows / mul_mm_id
// instantiations of the same type.
#if QK_K == 64
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, 2, dequantize_iq4_xs>;
#else
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
#endif
6376
+
6377
+ //
6378
+ // indirect matrix-matrix multiplication
6379
+ //
6380
+
6381
// Common function type of the kernel_mul_mm_id instantiations below; must match the
// parameter list of the kernel_mul_mm_id template exactly (ids buffer, eight expert src0
// buffers src00..src07, and the column `idx` of ids to match against).
typedef void (mat_mm_id_t)(
        device const uchar * ids,
        device const uchar * src1,
        device float * dst,
        constant uint64_t & nbi1,
        constant int64_t & ne00,
        constant int64_t & ne02,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant int64_t & ne12,
        constant int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant int64_t & ne0,
        constant int64_t & ne1,
        constant uint64_t & nb1,
        constant uint & r2,
        constant uint & r3,
        constant int & idx,
        device const uchar * src00,
        device const uchar * src01,
        device const uchar * src02,
        device const uchar * src03,
        device const uchar * src04,
        device const uchar * src05,
        device const uchar * src06,
        device const uchar * src07,
        threadgroup uchar *,
        uint3, uint, uint);

// Host-visible instantiations, looked up by host_name; template arguments follow the same
// <block type, nl, dequantizer> convention as the kernel_mul_mm instantiations.
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
// iq4_xs: with QK_K == 64 it is instantiated with nl = 2 (like iq4_nl), otherwise with QK_NL
#if QK_K == 64
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
#else
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
#endif
6436
+
6437
+ //
6438
+ // matrix-vector multiplication
6439
+ //
6440
+
6441
+ [[host_name("kernel_mul_mv_id_f32_f32")]]
6442
+ kernel void kernel_mul_mv_id_f32_f32(
6443
+ device const char * ids,
6444
+ device const char * src1,
6445
+ device float * dst,
6446
+ constant uint64_t & nbi1,
6447
+ constant int64_t & ne00,
6448
+ constant int64_t & ne01,
6449
+ constant int64_t & ne02,
6450
+ constant uint64_t & nb00,
6451
+ constant uint64_t & nb01,
6452
+ constant uint64_t & nb02,
6453
+ constant int64_t & ne10,
6454
+ constant int64_t & ne11,
6455
+ constant int64_t & ne12,
6456
+ constant int64_t & ne13,
6457
+ constant uint64_t & nb10,
6458
+ constant uint64_t & nb11,
6459
+ constant uint64_t & nb12,
6460
+ constant int64_t & ne0,
6461
+ constant int64_t & ne1,
6462
+ constant uint64_t & nb1,
6463
+ constant uint & r2,
6464
+ constant uint & r3,
6465
+ constant int & idx,
6466
+ device const char * src00,
6467
+ device const char * src01,
6468
+ device const char * src02,
6469
+ device const char * src03,
6470
+ device const char * src04,
6471
+ device const char * src05,
6472
+ device const char * src06,
6473
+ device const char * src07,
6474
+ uint3 tgpig[[threadgroup_position_in_grid]],
6475
+ uint tiitg[[thread_index_in_threadgroup]],
6476
+ uint tiisg[[thread_index_in_simdgroup]],
6477
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6478
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4887
6479
 
4888
- {
4889
- threadgroup_barrier(mem_flags::mem_threadgroup);
4890
- threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
4891
- + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
4892
- for (int i = 0; i < 8; i++) {
4893
- simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
4894
- }
6480
+ const int64_t bid = tgpig.z/(ne12*ne13);
4895
6481
 
4896
- threadgroup_barrier(mem_flags::mem_threadgroup);
6482
+ tgpig.z = tgpig.z%(ne12*ne13);
4897
6483
 
4898
- device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
4899
- if (sgitg == 0) {
4900
- for (int i = 0; i < n_rows; i++) {
4901
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
4902
- *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
4903
- }
4904
- }
4905
- }
4906
- }
4907
- }
6484
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
4908
6485
 
4909
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
4910
- kernel void kernel_mul_mm(device const uchar * src0,
4911
- device const uchar * src1,
4912
- device float * dst,
4913
- constant int64_t & ne00,
4914
- constant int64_t & ne02,
4915
- constant uint64_t & nb01,
4916
- constant uint64_t & nb02,
4917
- constant int64_t & ne12,
4918
- constant uint64_t & nb10,
4919
- constant uint64_t & nb11,
4920
- constant uint64_t & nb12,
4921
- constant int64_t & ne0,
4922
- constant int64_t & ne1,
4923
- constant uint & r2,
4924
- constant uint & r3,
4925
- threadgroup uchar * shared_memory [[threadgroup(0)]],
4926
- uint3 tgpig[[threadgroup_position_in_grid]],
4927
- uint tiitg[[thread_index_in_threadgroup]],
4928
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4929
- kernel_mul_mm_impl<block_q, nl, dequantize_func>(
4930
- src0,
4931
- src1,
4932
- dst,
6486
+ kernel_mul_mv_f32_f32_impl(
6487
+ src0[id],
6488
+ src1 + bid*nb11,
6489
+ dst + bid*ne0,
4933
6490
  ne00,
6491
+ ne01,
4934
6492
  ne02,
6493
+ nb00,
4935
6494
  nb01,
4936
6495
  nb02,
6496
+ ne10,
6497
+ ne11,
4937
6498
  ne12,
4938
6499
  nb10,
4939
6500
  nb11,
@@ -4942,22 +6503,24 @@ kernel void kernel_mul_mm(device const uchar * src0,
4942
6503
  ne1,
4943
6504
  r2,
4944
6505
  r3,
4945
- shared_memory,
4946
6506
  tgpig,
4947
- tiitg,
4948
- sgitg);
6507
+ tiisg);
4949
6508
  }
4950
6509
 
4951
- template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
4952
- kernel void kernel_mul_mm_id(
4953
- device const uchar * ids,
4954
- device const uchar * src1,
6510
+ [[host_name("kernel_mul_mv_id_f16_f32")]]
6511
+ kernel void kernel_mul_mv_id_f16_f32(
6512
+ device const char * ids,
6513
+ device const char * src1,
4955
6514
  device float * dst,
4956
6515
  constant uint64_t & nbi1,
4957
6516
  constant int64_t & ne00,
6517
+ constant int64_t & ne01,
4958
6518
  constant int64_t & ne02,
6519
+ constant uint64_t & nb00,
4959
6520
  constant uint64_t & nb01,
4960
6521
  constant uint64_t & nb02,
6522
+ constant int64_t & ne10,
6523
+ constant int64_t & ne11,
4961
6524
  constant int64_t & ne12,
4962
6525
  constant int64_t & ne13,
4963
6526
  constant uint64_t & nb10,
@@ -4969,150 +6532,190 @@ kernel void kernel_mul_mm_id(
4969
6532
  constant uint & r2,
4970
6533
  constant uint & r3,
4971
6534
  constant int & idx,
4972
- device const uchar * src00,
4973
- device const uchar * src01,
4974
- device const uchar * src02,
4975
- device const uchar * src03,
4976
- device const uchar * src04,
4977
- device const uchar * src05,
4978
- device const uchar * src06,
4979
- device const uchar * src07,
4980
- threadgroup uchar * shared_memory [[threadgroup(0)]],
6535
+ device const char * src00,
6536
+ device const char * src01,
6537
+ device const char * src02,
6538
+ device const char * src03,
6539
+ device const char * src04,
6540
+ device const char * src05,
6541
+ device const char * src06,
6542
+ device const char * src07,
4981
6543
  uint3 tgpig[[threadgroup_position_in_grid]],
4982
6544
  uint tiitg[[thread_index_in_threadgroup]],
6545
+ uint tiisg[[thread_index_in_simdgroup]],
4983
6546
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
4984
- device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6547
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
4985
6548
 
4986
- // expert id
4987
- const int32_t id = tgpig.z/(ne12*ne13);
6549
+ const int64_t bid = tgpig.z/(ne12*ne13);
4988
6550
 
4989
6551
  tgpig.z = tgpig.z%(ne12*ne13);
4990
6552
 
4991
- // row indices of src1 for expert id
4992
- int64_t _ne1 = 0;
4993
- short src1ids[512];
4994
-
4995
- for (int64_t i1 = 0; i1 < ne1; i1++) {
4996
- if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
4997
- src1ids[_ne1++] = i1;
4998
- }
4999
- }
6553
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5000
6554
 
5001
- kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5002
- src0s[id],
5003
- src1,
5004
- src1ids,
5005
- dst,
6555
+ kernel_mul_mv_f16_f32_impl(
6556
+ src0[id],
6557
+ src1 + bid*nb11,
6558
+ dst + bid*ne0,
5006
6559
  ne00,
6560
+ ne01,
5007
6561
  ne02,
6562
+ nb00,
5008
6563
  nb01,
5009
6564
  nb02,
6565
+ ne10,
6566
+ ne11,
5010
6567
  ne12,
5011
6568
  nb10,
5012
6569
  nb11,
5013
6570
  nb12,
5014
6571
  ne0,
5015
- _ne1,
6572
+ ne1,
5016
6573
  r2,
5017
6574
  r3,
5018
- shared_memory,
5019
6575
  tgpig,
5020
- tiitg,
5021
- sgitg);
6576
+ tiisg);
5022
6577
  }
5023
6578
 
5024
- #if QK_K == 256
5025
- #define QK_NL 16
5026
- #else
5027
- #define QK_NL 4
5028
- #endif
6579
+ [[host_name("kernel_mul_mv_id_q8_0_f32")]]
6580
+ kernel void kernel_mul_mv_id_q8_0_f32(
6581
+ device const char * ids,
6582
+ device const char * src1,
6583
+ device float * dst,
6584
+ constant uint64_t & nbi1,
6585
+ constant int64_t & ne00,
6586
+ constant int64_t & ne01,
6587
+ constant int64_t & ne02,
6588
+ constant uint64_t & nb00,
6589
+ constant uint64_t & nb01,
6590
+ constant uint64_t & nb02,
6591
+ constant int64_t & ne10,
6592
+ constant int64_t & ne11,
6593
+ constant int64_t & ne12,
6594
+ constant int64_t & ne13,
6595
+ constant uint64_t & nb10,
6596
+ constant uint64_t & nb11,
6597
+ constant uint64_t & nb12,
6598
+ constant int64_t & ne0,
6599
+ constant int64_t & ne1,
6600
+ constant uint64_t & nb1,
6601
+ constant uint & r2,
6602
+ constant uint & r3,
6603
+ constant int & idx,
6604
+ device const char * src00,
6605
+ device const char * src01,
6606
+ device const char * src02,
6607
+ device const char * src03,
6608
+ device const char * src04,
6609
+ device const char * src05,
6610
+ device const char * src06,
6611
+ device const char * src07,
6612
+ uint3 tgpig[[threadgroup_position_in_grid]],
6613
+ uint tiitg[[thread_index_in_threadgroup]],
6614
+ uint tiisg[[thread_index_in_simdgroup]],
6615
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6616
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5029
6617
 
5030
- //
5031
- // get rows
5032
- //
6618
+ const int64_t bid = tgpig.z/(ne12*ne13);
5033
6619
 
5034
- typedef void (get_rows_t)(
5035
- device const void * src0,
5036
- device const char * src1,
5037
- device float * dst,
5038
- constant int64_t & ne00,
5039
- constant uint64_t & nb01,
5040
- constant uint64_t & nb02,
5041
- constant int64_t & ne10,
5042
- constant uint64_t & nb10,
5043
- constant uint64_t & nb11,
5044
- constant uint64_t & nb1,
5045
- constant uint64_t & nb2,
5046
- uint3, uint, uint3);
6620
+ tgpig.z = tgpig.z%(ne12*ne13);
5047
6621
 
5048
- //template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
5049
- //template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
5050
- template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
5051
- template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
5052
- template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
5053
- template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
5054
- template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
5055
- template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
5056
- template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
5057
- template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
5058
- template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
5059
- template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
5060
- template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5061
- template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5062
- template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6622
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6623
+
6624
+ kernel_mul_mv_q8_0_f32_impl(
6625
+ src0[id],
6626
+ (device const float *) (src1 + bid*nb11),
6627
+ dst + bid*ne0,
6628
+ ne00,
6629
+ ne01,
6630
+ ne02,
6631
+ ne10,
6632
+ ne12,
6633
+ ne0,
6634
+ ne1,
6635
+ r2,
6636
+ r3,
6637
+ tgpig,
6638
+ tiisg,
6639
+ sgitg);
6640
+ }
6641
+
6642
+ [[host_name("kernel_mul_mv_id_q4_0_f32")]]
6643
+ kernel void kernel_mul_mv_id_q4_0_f32(
6644
+ device const char * ids,
6645
+ device const char * src1,
6646
+ device float * dst,
6647
+ constant uint64_t & nbi1,
6648
+ constant int64_t & ne00,
6649
+ constant int64_t & ne01,
6650
+ constant int64_t & ne02,
6651
+ constant uint64_t & nb00,
6652
+ constant uint64_t & nb01,
6653
+ constant uint64_t & nb02,
6654
+ constant int64_t & ne10,
6655
+ constant int64_t & ne11,
6656
+ constant int64_t & ne12,
6657
+ constant int64_t & ne13,
6658
+ constant uint64_t & nb10,
6659
+ constant uint64_t & nb11,
6660
+ constant uint64_t & nb12,
6661
+ constant int64_t & ne0,
6662
+ constant int64_t & ne1,
6663
+ constant uint64_t & nb1,
6664
+ constant uint & r2,
6665
+ constant uint & r3,
6666
+ constant int & idx,
6667
+ device const char * src00,
6668
+ device const char * src01,
6669
+ device const char * src02,
6670
+ device const char * src03,
6671
+ device const char * src04,
6672
+ device const char * src05,
6673
+ device const char * src06,
6674
+ device const char * src07,
6675
+ uint3 tgpig[[threadgroup_position_in_grid]],
6676
+ uint tiitg[[thread_index_in_threadgroup]],
6677
+ uint tiisg[[thread_index_in_simdgroup]],
6678
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6679
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5063
6680
 
5064
- //
5065
- // matrix-matrix multiplication
5066
- //
6681
+ const int64_t bid = tgpig.z/(ne12*ne13);
5067
6682
 
5068
- typedef void (mat_mm_t)(
5069
- device const uchar * src0,
5070
- device const uchar * src1,
5071
- device float * dst,
5072
- constant int64_t & ne00,
5073
- constant int64_t & ne02,
5074
- constant uint64_t & nb01,
5075
- constant uint64_t & nb02,
5076
- constant int64_t & ne12,
5077
- constant uint64_t & nb10,
5078
- constant uint64_t & nb11,
5079
- constant uint64_t & nb12,
5080
- constant int64_t & ne0,
5081
- constant int64_t & ne1,
5082
- constant uint & r2,
5083
- constant uint & r3,
5084
- threadgroup uchar *,
5085
- uint3, uint, uint);
6683
+ tgpig.z = tgpig.z%(ne12*ne13);
5086
6684
 
5087
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
5088
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
5089
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
5090
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
5091
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
5092
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
5093
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
5094
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
5095
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
5096
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
5097
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
5098
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
5099
- template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5100
- template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5101
- template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6685
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5102
6686
 
5103
- //
5104
- // indirect matrix-matrix multiplication
5105
- //
6687
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6688
+ src0[id],
6689
+ (device const float *) (src1 + bid*nb11),
6690
+ dst + bid*ne0,
6691
+ ne00,
6692
+ ne01,
6693
+ ne02,
6694
+ ne10,
6695
+ ne12,
6696
+ ne0,
6697
+ ne1,
6698
+ r2,
6699
+ r3,
6700
+ tgpig,
6701
+ tiisg,
6702
+ sgitg);
6703
+ }
5106
6704
 
5107
- typedef void (mat_mm_id_t)(
5108
- device const uchar * ids,
5109
- device const uchar * src1,
6705
+ [[host_name("kernel_mul_mv_id_q4_1_f32")]]
6706
+ kernel void kernel_mul_mv_id_q4_1_f32(
6707
+ device const char * ids,
6708
+ device const char * src1,
5110
6709
  device float * dst,
5111
6710
  constant uint64_t & nbi1,
5112
6711
  constant int64_t & ne00,
6712
+ constant int64_t & ne01,
5113
6713
  constant int64_t & ne02,
6714
+ constant uint64_t & nb00,
5114
6715
  constant uint64_t & nb01,
5115
6716
  constant uint64_t & nb02,
6717
+ constant int64_t & ne10,
6718
+ constant int64_t & ne11,
5116
6719
  constant int64_t & ne12,
5117
6720
  constant int64_t & ne13,
5118
6721
  constant uint64_t & nb10,
@@ -5124,39 +6727,46 @@ typedef void (mat_mm_id_t)(
5124
6727
  constant uint & r2,
5125
6728
  constant uint & r3,
5126
6729
  constant int & idx,
5127
- device const uchar * src00,
5128
- device const uchar * src01,
5129
- device const uchar * src02,
5130
- device const uchar * src03,
5131
- device const uchar * src04,
5132
- device const uchar * src05,
5133
- device const uchar * src06,
5134
- device const uchar * src07,
5135
- threadgroup uchar *,
5136
- uint3, uint, uint);
6730
+ device const char * src00,
6731
+ device const char * src01,
6732
+ device const char * src02,
6733
+ device const char * src03,
6734
+ device const char * src04,
6735
+ device const char * src05,
6736
+ device const char * src06,
6737
+ device const char * src07,
6738
+ uint3 tgpig[[threadgroup_position_in_grid]],
6739
+ uint tiitg[[thread_index_in_threadgroup]],
6740
+ uint tiisg[[thread_index_in_simdgroup]],
6741
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6742
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5137
6743
 
5138
- template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
5139
- template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
5140
- template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
5141
- template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
5142
- template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
5143
- template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
5144
- template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
5145
- template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
5146
- template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
5147
- template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
5148
- template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
5149
- template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
5150
- template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5151
- template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5152
- template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
6744
+ const int64_t bid = tgpig.z/(ne12*ne13);
5153
6745
 
5154
- //
5155
- // matrix-vector multiplication
5156
- //
6746
+ tgpig.z = tgpig.z%(ne12*ne13);
5157
6747
 
5158
- [[host_name("kernel_mul_mv_id_f32_f32")]]
5159
- kernel void kernel_mul_mv_id_f32_f32(
6748
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6749
+
6750
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6751
+ src0[id],
6752
+ (device const float *) (src1 + bid*nb11),
6753
+ dst + bid*ne0,
6754
+ ne00,
6755
+ ne01,
6756
+ ne02,
6757
+ ne10,
6758
+ ne12,
6759
+ ne0,
6760
+ ne1,
6761
+ r2,
6762
+ r3,
6763
+ tgpig,
6764
+ tiisg,
6765
+ sgitg);
6766
+ }
6767
+
6768
+ [[host_name("kernel_mul_mv_id_q5_0_f32")]]
6769
+ kernel void kernel_mul_mv_id_q5_0_f32(
5160
6770
  device const char * ids,
5161
6771
  device const char * src1,
5162
6772
  device float * dst,
@@ -5200,32 +6810,26 @@ kernel void kernel_mul_mv_id_f32_f32(
5200
6810
 
5201
6811
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5202
6812
 
5203
- kernel_mul_mv_f32_f32_impl(
6813
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5204
6814
  src0[id],
5205
- src1 + bid*nb11,
5206
- dst + bid*ne0,
6815
+ (device const float *) (src1 + bid*nb11),
6816
+ dst + bid*ne0,
5207
6817
  ne00,
5208
6818
  ne01,
5209
6819
  ne02,
5210
- nb00,
5211
- nb01,
5212
- nb02,
5213
6820
  ne10,
5214
- ne11,
5215
6821
  ne12,
5216
- nb10,
5217
- nb11,
5218
- nb12,
5219
6822
  ne0,
5220
6823
  ne1,
5221
6824
  r2,
5222
6825
  r3,
5223
6826
  tgpig,
5224
- tiisg);
6827
+ tiisg,
6828
+ sgitg);
5225
6829
  }
5226
6830
 
5227
- [[host_name("kernel_mul_mv_id_f16_f32")]]
5228
- kernel void kernel_mul_mv_id_f16_f32(
6831
+ [[host_name("kernel_mul_mv_id_q5_1_f32")]]
6832
+ kernel void kernel_mul_mv_id_q5_1_f32(
5229
6833
  device const char * ids,
5230
6834
  device const char * src1,
5231
6835
  device float * dst,
@@ -5269,32 +6873,26 @@ kernel void kernel_mul_mv_id_f16_f32(
5269
6873
 
5270
6874
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5271
6875
 
5272
- kernel_mul_mv_f16_f32_impl(
6876
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5273
6877
  src0[id],
5274
- src1 + bid*nb11,
5275
- dst + bid*ne0,
6878
+ (device const float *) (src1 + bid*nb11),
6879
+ dst + bid*ne0,
5276
6880
  ne00,
5277
6881
  ne01,
5278
6882
  ne02,
5279
- nb00,
5280
- nb01,
5281
- nb02,
5282
6883
  ne10,
5283
- ne11,
5284
6884
  ne12,
5285
- nb10,
5286
- nb11,
5287
- nb12,
5288
6885
  ne0,
5289
6886
  ne1,
5290
6887
  r2,
5291
6888
  r3,
5292
6889
  tgpig,
5293
- tiisg);
6890
+ tiisg,
6891
+ sgitg);
5294
6892
  }
5295
6893
 
5296
- [[host_name("kernel_mul_mv_id_q8_0_f32")]]
5297
- kernel void kernel_mul_mv_id_q8_0_f32(
6894
+ [[host_name("kernel_mul_mv_id_q2_K_f32")]]
6895
+ kernel void kernel_mul_mv_id_q2_K_f32(
5298
6896
  device const char * ids,
5299
6897
  device const char * src1,
5300
6898
  device float * dst,
@@ -5338,7 +6936,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
5338
6936
 
5339
6937
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5340
6938
 
5341
- kernel_mul_mv_q8_0_f32_impl(
6939
+ kernel_mul_mv_q2_K_f32_impl(
5342
6940
  src0[id],
5343
6941
  (device const float *) (src1 + bid*nb11),
5344
6942
  dst + bid*ne0,
@@ -5356,8 +6954,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
5356
6954
  sgitg);
5357
6955
  }
5358
6956
 
5359
- [[host_name("kernel_mul_mv_id_q4_0_f32")]]
5360
- kernel void kernel_mul_mv_id_q4_0_f32(
6957
+ [[host_name("kernel_mul_mv_id_q3_K_f32")]]
6958
+ kernel void kernel_mul_mv_id_q3_K_f32(
5361
6959
  device const char * ids,
5362
6960
  device const char * src1,
5363
6961
  device float * dst,
@@ -5401,7 +6999,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
5401
6999
 
5402
7000
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5403
7001
 
5404
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
7002
+ kernel_mul_mv_q3_K_f32_impl(
5405
7003
  src0[id],
5406
7004
  (device const float *) (src1 + bid*nb11),
5407
7005
  dst + bid*ne0,
@@ -5419,8 +7017,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
5419
7017
  sgitg);
5420
7018
  }
5421
7019
 
5422
- [[host_name("kernel_mul_mv_id_q4_1_f32")]]
5423
- kernel void kernel_mul_mv_id_q4_1_f32(
7020
+ [[host_name("kernel_mul_mv_id_q4_K_f32")]]
7021
+ kernel void kernel_mul_mv_id_q4_K_f32(
5424
7022
  device const char * ids,
5425
7023
  device const char * src1,
5426
7024
  device float * dst,
@@ -5464,7 +7062,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
5464
7062
 
5465
7063
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5466
7064
 
5467
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
7065
+ kernel_mul_mv_q4_K_f32_impl(
5468
7066
  src0[id],
5469
7067
  (device const float *) (src1 + bid*nb11),
5470
7068
  dst + bid*ne0,
@@ -5482,8 +7080,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
5482
7080
  sgitg);
5483
7081
  }
5484
7082
 
5485
- [[host_name("kernel_mul_mv_id_q5_0_f32")]]
5486
- kernel void kernel_mul_mv_id_q5_0_f32(
7083
+ [[host_name("kernel_mul_mv_id_q5_K_f32")]]
7084
+ kernel void kernel_mul_mv_id_q5_K_f32(
5487
7085
  device const char * ids,
5488
7086
  device const char * src1,
5489
7087
  device float * dst,
@@ -5527,7 +7125,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
5527
7125
 
5528
7126
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5529
7127
 
5530
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
7128
+ kernel_mul_mv_q5_K_f32_impl(
5531
7129
  src0[id],
5532
7130
  (device const float *) (src1 + bid*nb11),
5533
7131
  dst + bid*ne0,
@@ -5545,8 +7143,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
5545
7143
  sgitg);
5546
7144
  }
5547
7145
 
5548
- [[host_name("kernel_mul_mv_id_q5_1_f32")]]
5549
- kernel void kernel_mul_mv_id_q5_1_f32(
7146
+ [[host_name("kernel_mul_mv_id_q6_K_f32")]]
7147
+ kernel void kernel_mul_mv_id_q6_K_f32(
5550
7148
  device const char * ids,
5551
7149
  device const char * src1,
5552
7150
  device float * dst,
@@ -5590,7 +7188,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
5590
7188
 
5591
7189
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5592
7190
 
5593
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
7191
+ kernel_mul_mv_q6_K_f32_impl(
5594
7192
  src0[id],
5595
7193
  (device const float *) (src1 + bid*nb11),
5596
7194
  dst + bid*ne0,
@@ -5608,8 +7206,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
5608
7206
  sgitg);
5609
7207
  }
5610
7208
 
5611
- [[host_name("kernel_mul_mv_id_q2_K_f32")]]
5612
- kernel void kernel_mul_mv_id_q2_K_f32(
7209
+ [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
7210
+ kernel void kernel_mul_mv_id_iq2_xxs_f32(
5613
7211
  device const char * ids,
5614
7212
  device const char * src1,
5615
7213
  device float * dst,
@@ -5641,6 +7239,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
5641
7239
  device const char * src05,
5642
7240
  device const char * src06,
5643
7241
  device const char * src07,
7242
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5644
7243
  uint3 tgpig[[threadgroup_position_in_grid]],
5645
7244
  uint tiitg[[thread_index_in_threadgroup]],
5646
7245
  uint tiisg[[thread_index_in_simdgroup]],
@@ -5653,7 +7252,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
5653
7252
 
5654
7253
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5655
7254
 
5656
- kernel_mul_mv_q2_K_f32_impl(
7255
+ kernel_mul_mv_iq2_xxs_f32_impl(
5657
7256
  src0[id],
5658
7257
  (device const float *) (src1 + bid*nb11),
5659
7258
  dst + bid*ne0,
@@ -5666,13 +7265,14 @@ kernel void kernel_mul_mv_id_q2_K_f32(
5666
7265
  ne1,
5667
7266
  r2,
5668
7267
  r3,
7268
+ shared_values,
5669
7269
  tgpig,
5670
7270
  tiisg,
5671
7271
  sgitg);
5672
7272
  }
5673
7273
 
5674
- [[host_name("kernel_mul_mv_id_q3_K_f32")]]
5675
- kernel void kernel_mul_mv_id_q3_K_f32(
7274
+ [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
7275
+ kernel void kernel_mul_mv_id_iq2_xs_f32(
5676
7276
  device const char * ids,
5677
7277
  device const char * src1,
5678
7278
  device float * dst,
@@ -5704,6 +7304,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
5704
7304
  device const char * src05,
5705
7305
  device const char * src06,
5706
7306
  device const char * src07,
7307
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5707
7308
  uint3 tgpig[[threadgroup_position_in_grid]],
5708
7309
  uint tiitg[[thread_index_in_threadgroup]],
5709
7310
  uint tiisg[[thread_index_in_simdgroup]],
@@ -5716,7 +7317,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
5716
7317
 
5717
7318
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5718
7319
 
5719
- kernel_mul_mv_q3_K_f32_impl(
7320
+ kernel_mul_mv_iq2_xs_f32_impl(
5720
7321
  src0[id],
5721
7322
  (device const float *) (src1 + bid*nb11),
5722
7323
  dst + bid*ne0,
@@ -5729,13 +7330,14 @@ kernel void kernel_mul_mv_id_q3_K_f32(
5729
7330
  ne1,
5730
7331
  r2,
5731
7332
  r3,
7333
+ shared_values,
5732
7334
  tgpig,
5733
7335
  tiisg,
5734
7336
  sgitg);
5735
7337
  }
5736
7338
 
5737
- [[host_name("kernel_mul_mv_id_q4_K_f32")]]
5738
- kernel void kernel_mul_mv_id_q4_K_f32(
7339
+ [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
7340
+ kernel void kernel_mul_mv_id_iq3_xxs_f32(
5739
7341
  device const char * ids,
5740
7342
  device const char * src1,
5741
7343
  device float * dst,
@@ -5767,6 +7369,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
5767
7369
  device const char * src05,
5768
7370
  device const char * src06,
5769
7371
  device const char * src07,
7372
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5770
7373
  uint3 tgpig[[threadgroup_position_in_grid]],
5771
7374
  uint tiitg[[thread_index_in_threadgroup]],
5772
7375
  uint tiisg[[thread_index_in_simdgroup]],
@@ -5779,7 +7382,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
5779
7382
 
5780
7383
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5781
7384
 
5782
- kernel_mul_mv_q4_K_f32_impl(
7385
+ kernel_mul_mv_iq3_xxs_f32_impl(
5783
7386
  src0[id],
5784
7387
  (device const float *) (src1 + bid*nb11),
5785
7388
  dst + bid*ne0,
@@ -5792,13 +7395,14 @@ kernel void kernel_mul_mv_id_q4_K_f32(
5792
7395
  ne1,
5793
7396
  r2,
5794
7397
  r3,
7398
+ shared_values,
5795
7399
  tgpig,
5796
7400
  tiisg,
5797
7401
  sgitg);
5798
7402
  }
5799
7403
 
5800
- [[host_name("kernel_mul_mv_id_q5_K_f32")]]
5801
- kernel void kernel_mul_mv_id_q5_K_f32(
7404
+ [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
7405
+ kernel void kernel_mul_mv_id_iq3_s_f32(
5802
7406
  device const char * ids,
5803
7407
  device const char * src1,
5804
7408
  device float * dst,
@@ -5830,6 +7434,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
5830
7434
  device const char * src05,
5831
7435
  device const char * src06,
5832
7436
  device const char * src07,
7437
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5833
7438
  uint3 tgpig[[threadgroup_position_in_grid]],
5834
7439
  uint tiitg[[thread_index_in_threadgroup]],
5835
7440
  uint tiisg[[thread_index_in_simdgroup]],
@@ -5842,7 +7447,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
5842
7447
 
5843
7448
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5844
7449
 
5845
- kernel_mul_mv_q5_K_f32_impl(
7450
+ kernel_mul_mv_iq3_s_f32_impl(
5846
7451
  src0[id],
5847
7452
  (device const float *) (src1 + bid*nb11),
5848
7453
  dst + bid*ne0,
@@ -5855,13 +7460,14 @@ kernel void kernel_mul_mv_id_q5_K_f32(
5855
7460
  ne1,
5856
7461
  r2,
5857
7462
  r3,
7463
+ shared_values,
5858
7464
  tgpig,
5859
7465
  tiisg,
5860
7466
  sgitg);
5861
7467
  }
5862
7468
 
5863
- [[host_name("kernel_mul_mv_id_q6_K_f32")]]
5864
- kernel void kernel_mul_mv_id_q6_K_f32(
7469
+ [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
7470
+ kernel void kernel_mul_mv_id_iq2_s_f32(
5865
7471
  device const char * ids,
5866
7472
  device const char * src1,
5867
7473
  device float * dst,
@@ -5893,6 +7499,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
5893
7499
  device const char * src05,
5894
7500
  device const char * src06,
5895
7501
  device const char * src07,
7502
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
5896
7503
  uint3 tgpig[[threadgroup_position_in_grid]],
5897
7504
  uint tiitg[[thread_index_in_threadgroup]],
5898
7505
  uint tiisg[[thread_index_in_simdgroup]],
@@ -5905,7 +7512,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
5905
7512
 
5906
7513
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5907
7514
 
5908
- kernel_mul_mv_q6_K_f32_impl(
7515
+ kernel_mul_mv_iq2_s_f32_impl(
5909
7516
  src0[id],
5910
7517
  (device const float *) (src1 + bid*nb11),
5911
7518
  dst + bid*ne0,
@@ -5918,13 +7525,14 @@ kernel void kernel_mul_mv_id_q6_K_f32(
5918
7525
  ne1,
5919
7526
  r2,
5920
7527
  r3,
7528
+ shared_values,
5921
7529
  tgpig,
5922
7530
  tiisg,
5923
7531
  sgitg);
5924
7532
  }
5925
7533
 
5926
- [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
5927
- kernel void kernel_mul_mv_id_iq2_xxs_f32(
7534
+ [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
7535
+ kernel void kernel_mul_mv_id_iq1_s_f32(
5928
7536
  device const char * ids,
5929
7537
  device const char * src1,
5930
7538
  device float * dst,
@@ -5956,7 +7564,6 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
5956
7564
  device const char * src05,
5957
7565
  device const char * src06,
5958
7566
  device const char * src07,
5959
- threadgroup int8_t * shared_values [[threadgroup(0)]],
5960
7567
  uint3 tgpig[[threadgroup_position_in_grid]],
5961
7568
  uint tiitg[[thread_index_in_threadgroup]],
5962
7569
  uint tiisg[[thread_index_in_simdgroup]],
@@ -5969,7 +7576,7 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
5969
7576
 
5970
7577
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
5971
7578
 
5972
- kernel_mul_mv_iq2_xxs_f32_impl(
7579
+ kernel_mul_mv_iq1_s_f32_impl(
5973
7580
  src0[id],
5974
7581
  (device const float *) (src1 + bid*nb11),
5975
7582
  dst + bid*ne0,
@@ -5982,14 +7589,13 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
5982
7589
  ne1,
5983
7590
  r2,
5984
7591
  r3,
5985
- shared_values,
5986
7592
  tgpig,
5987
7593
  tiisg,
5988
7594
  sgitg);
5989
7595
  }
5990
7596
 
5991
- [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
5992
- kernel void kernel_mul_mv_id_iq2_xs_f32(
7597
+ [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
7598
+ kernel void kernel_mul_mv_id_iq4_nl_f32(
5993
7599
  device const char * ids,
5994
7600
  device const char * src1,
5995
7601
  device float * dst,
@@ -6021,7 +7627,7 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6021
7627
  device const char * src05,
6022
7628
  device const char * src06,
6023
7629
  device const char * src07,
6024
- threadgroup int8_t * shared_values [[threadgroup(0)]],
7630
+ threadgroup float * shared_values [[threadgroup(0)]],
6025
7631
  uint3 tgpig[[threadgroup_position_in_grid]],
6026
7632
  uint tiitg[[thread_index_in_threadgroup]],
6027
7633
  uint tiisg[[thread_index_in_simdgroup]],
@@ -6034,7 +7640,7 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6034
7640
 
6035
7641
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6036
7642
 
6037
- kernel_mul_mv_iq2_xs_f32_impl(
7643
+ kernel_mul_mv_iq4_nl_f32_impl(
6038
7644
  src0[id],
6039
7645
  (device const float *) (src1 + bid*nb11),
6040
7646
  dst + bid*ne0,
@@ -6053,8 +7659,8 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6053
7659
  sgitg);
6054
7660
  }
6055
7661
 
6056
- [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6057
- kernel void kernel_mul_mv_id_iq3_xxs_f32(
7662
+ [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
7663
+ kernel void kernel_mul_mv_id_iq4_xs_f32(
6058
7664
  device const char * ids,
6059
7665
  device const char * src1,
6060
7666
  device float * dst,
@@ -6086,7 +7692,7 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6086
7692
  device const char * src05,
6087
7693
  device const char * src06,
6088
7694
  device const char * src07,
6089
- threadgroup int8_t * shared_values [[threadgroup(0)]],
7695
+ threadgroup float * shared_values [[threadgroup(0)]],
6090
7696
  uint3 tgpig[[threadgroup_position_in_grid]],
6091
7697
  uint tiitg[[thread_index_in_threadgroup]],
6092
7698
  uint tiisg[[thread_index_in_simdgroup]],
@@ -6099,7 +7705,11 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6099
7705
 
6100
7706
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6101
7707
 
6102
- kernel_mul_mv_iq3_xxs_f32_impl(
7708
+ #if QK_K == 64
7709
+ kernel_mul_mv_iq4_nl_f32_impl(
7710
+ #else
7711
+ kernel_mul_mv_iq4_xs_f32_impl(
7712
+ #endif
6103
7713
  src0[id],
6104
7714
  (device const float *) (src1 + bid*nb11),
6105
7715
  dst + bid*ne0,