llama_cpp 0.12.6 → 0.12.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
351
351
  kernel void kernel_soft_max(
352
352
  device const float * src0,
353
353
  device const float * src1,
354
+ device const float * src2,
354
355
  device float * dst,
355
356
  constant int64_t & ne00,
356
357
  constant int64_t & ne01,
357
358
  constant int64_t & ne02,
358
359
  constant float & scale,
359
- threadgroup float * buf [[threadgroup(0)]],
360
+ constant float & max_bias,
361
+ constant float & m0,
362
+ constant float & m1,
363
+ constant uint32_t & n_head_log2,
364
+ threadgroup float * buf [[threadgroup(0)]],
360
365
  uint tgpig[[threadgroup_position_in_grid]],
361
366
  uint tpitg[[thread_position_in_threadgroup]],
362
367
  uint sgitg[[simdgroup_index_in_threadgroup]],
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
368
373
 
369
374
  device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
375
  device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
376
+ device const float * ppos = src2 != src0 ? src2 : nullptr;
371
377
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
372
378
 
379
+ float slope = 0.0f;
380
+
381
+ // ALiBi
382
+ if (max_bias > 0.0f) {
383
+ const int64_t h = i02;
384
+
385
+ const float base = h < n_head_log2 ? m0 : m1;
386
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
387
+
388
+ slope = pow(base, exp);
389
+ }
390
+
373
391
  // parallel max
374
392
  float lmax = -INFINITY;
375
393
 
376
394
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
377
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
395
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
378
396
  }
379
397
 
380
398
  // find the max value in the block
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
399
417
  // parallel sum
400
418
  float lsum = 0.0f;
401
419
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
402
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
420
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
403
421
  lsum += exp_psrc0;
404
422
  pdst[i00] = exp_psrc0;
405
423
  }
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
437
455
  kernel void kernel_soft_max_4(
438
456
  device const float * src0,
439
457
  device const float * src1,
458
+ device const float * src2,
440
459
  device float * dst,
441
460
  constant int64_t & ne00,
442
461
  constant int64_t & ne01,
443
462
  constant int64_t & ne02,
444
463
  constant float & scale,
445
- threadgroup float * buf [[threadgroup(0)]],
464
+ constant float & max_bias,
465
+ constant float & m0,
466
+ constant float & m1,
467
+ constant uint32_t & n_head_log2,
468
+ threadgroup float * buf [[threadgroup(0)]],
446
469
  uint tgpig[[threadgroup_position_in_grid]],
447
470
  uint tpitg[[thread_position_in_threadgroup]],
448
471
  uint sgitg[[simdgroup_index_in_threadgroup]],
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
454
477
 
455
478
  device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
479
  device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
480
+ device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
457
481
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
458
482
 
483
+ float slope = 0.0f;
484
+
485
+ if (max_bias > 0.0f) {
486
+ const int64_t h = i02;
487
+
488
+ const float base = h < n_head_log2 ? m0 : m1;
489
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
490
+
491
+ slope = pow(base, exp);
492
+ }
493
+
459
494
  // parallel max
460
495
  float4 lmax4 = -INFINITY;
461
496
 
462
497
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
463
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
498
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
464
499
  }
465
500
 
466
501
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
486
521
  // parallel sum
487
522
  float4 lsum4 = 0.0f;
488
523
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
489
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
524
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
490
525
  lsum4 += exp_psrc4;
491
526
  pdst4[i00] = exp_psrc4;
492
527
  }
@@ -2490,6 +2525,19 @@ typedef struct {
2490
2525
  } block_iq3_xxs;
2491
2526
  // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2492
2527
 
2528
// IQ1_S quantization block: one half-precision scale plus two byte arrays sized from QK_K
// (2 + QK_K/8 + QK_K/16 bytes, i.e. 50 bytes when QK_K = 256).
// As read by the IQ1_S kernels in this file:
//   d      - per-block scale
//   qs     - one byte per 8 values: the low 8 bits of an index into iq1s_grid
//   scales - one byte per 16 values: bits 0-2 and 4-6 are 3-bit multipliers (used as 2*s + 1),
//            bits 3 and 7 supply the 9th bit of the two grid indices
+ typedef struct {
2529
+ half d;
2530
+ uint8_t qs[QK_K/8];
2531
+ uint8_t scales[QK_K/16];
2532
+ } block_iq1_s;
2533
+
2534
+ // Non-linear quants
2535
// QK4_NL: number of values stored in one block_iq4_nl
+ #define QK4_NL 32
2536
// IQ4_NL block: a half-precision scale `d` and QK4_NL/2 bytes of packed 4-bit entries.
// The code in this file uses each nibble as an index into the 16-entry kvalues_iq4nl_f
// table and multiplies the looked-up value by d.
+ typedef struct {
2537
+ half d;
2538
+ uint8_t qs[QK4_NL/2];
2539
+ } block_iq4_nl;
2540
+
2493
2541
  //====================================== dot products =========================
2494
2542
 
2495
2543
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3747,6 +3795,137 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
3747
3795
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3748
3796
  };
3749
3797
 
3798
+ #define NGRID_IQ1S 512
3799
+ constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
3800
+ 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
3801
+ 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
3802
+ 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
3803
+ 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
3804
+ 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
3805
+ 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
3806
+ 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
3807
+ 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
3808
+ 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
3809
+ 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
3810
+ 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
3811
+ 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
3812
+ 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
3813
+ 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
3814
+ 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
3815
+ 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
3816
+ 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
3817
+ 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
3818
+ 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
3819
+ 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
3820
+ 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
3821
+ 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
3822
+ 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
3823
+ 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
3824
+ 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
3825
+ 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
3826
+ 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
3827
+ 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
3828
+ 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
3829
+ 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
3830
+ 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
3831
+ 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
3832
+ 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
3833
+ 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
3834
+ 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
3835
+ 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
3836
+ 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
3837
+ 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
3838
+ 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
3839
+ 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
3840
+ 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
3841
+ 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
3842
+ 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
3843
+ 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
3844
+ 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
3845
+ 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
3846
+ 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
3847
+ 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
3848
+ 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
3849
+ 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
3850
+ 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
3851
+ 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
3852
+ 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
3853
+ 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
3854
+ 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
3855
+ 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
3856
+ 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
3857
+ 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
3858
+ 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
3859
+ 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
3860
+ 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
3861
+ 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
3862
+ 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
3863
+ 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
3864
+ 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
3865
+ 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
3866
+ 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
3867
+ 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
3868
+ 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
3869
+ 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
3870
+ 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
3871
+ 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
3872
+ 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
3873
+ 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
3874
+ 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
3875
+ 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
3876
+ 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
3877
+ 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
3878
+ 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
3879
+ 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
3880
+ 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
3881
+ 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
3882
+ 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
3883
+ 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
3884
+ 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
3885
+ 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
3886
+ 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
3887
+ 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
3888
+ 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
3889
+ 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
3890
+ 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
3891
+ 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
3892
+ 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
3893
+ 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
3894
+ 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
3895
+ 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
3896
+ 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
3897
+ 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
3898
+ 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
3899
+ 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
3900
+ 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
3901
+ 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
3902
+ 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
3903
+ 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
3904
+ 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
3905
+ 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
3906
+ 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
3907
+ 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
3908
+ 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
3909
+ 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
3910
+ 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
3911
+ 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
3912
+ 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
3913
+ 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
3914
+ 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
3915
+ 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
3916
+ 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
3917
+ 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
3918
+ 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
3919
+ 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
3920
+ 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
3921
+ 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
3922
+ 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
3923
+ 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
3924
+ 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
3925
+ 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
3926
+ 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
3927
+ 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
3928
+ };
3750
3929
 
3751
3930
  constexpr constant static uint8_t ksigns_iq2xs[128] = {
3752
3931
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
@@ -3854,7 +4033,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3854
4033
  y4 += 32 * 32;
3855
4034
  }
3856
4035
  #else
3857
- // TODO
4036
+ (void) x;
4037
+ (void) y;
4038
+ (void) yl;
4039
+ (void) nb32;
3858
4040
  #endif
3859
4041
 
3860
4042
  for (int row = 0; row < N_DST; ++row) {
@@ -3997,7 +4179,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
3997
4179
  y4 += 32 * 32;
3998
4180
  }
3999
4181
  #else
4000
- // TODO
4182
+ (void) x;
4183
+ (void) y;
4184
+ (void) yl;
4185
+ (void) nb32;
4001
4186
  #endif
4002
4187
 
4003
4188
  for (int row = 0; row < N_DST; ++row) {
@@ -4133,7 +4318,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4133
4318
  y4 += 32 * 32;
4134
4319
  }
4135
4320
  #else
4136
- // TODO
4321
+ (void) x;
4322
+ (void) y;
4323
+ (void) yl;
4324
+ (void) nb32;
4137
4325
  #endif
4138
4326
 
4139
4327
  for (int row = 0; row < N_DST; ++row) {
@@ -4173,6 +4361,250 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
4173
4361
  kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4174
4362
  }
4175
4363
 
4364
// Matrix-vector product of IQ1_S-quantized src0 with float src1.
// Each simdgroup accumulates N_DST consecutive output rows starting at first_row; the
// per-lane partial sums are combined with simd_sum and lane 0 stores the result in dst.
// Only the QK_K == 256 layout is implemented: in the #else branch sumf keeps its zero
// initializer, so zeros are written to dst.
// NOTE(review): first_row + row is not compared against ne01 before the store to dst —
// confirm the dispatch guarantees the row count is covered exactly.
+ void kernel_mul_mv_iq1_s_f32_impl(
4365
+ device const void * src0,
4366
+ device const float * src1,
4367
+ device float * dst,
4368
+ constant int64_t & ne00,
4369
+ constant int64_t & ne01,
4370
+ constant int64_t & ne02,
4371
+ constant int64_t & ne10,
4372
+ constant int64_t & ne12,
4373
+ constant int64_t & ne0,
4374
+ constant int64_t & ne1,
4375
+ constant uint & r2,
4376
+ constant uint & r3,
4377
+ uint3 tgpig[[threadgroup_position_in_grid]],
4378
+ uint tiisg[[thread_index_in_simdgroup]],
4379
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4380
+
4381
// nb: number of block_iq1_s blocks in one row of src0
+ const int nb = ne00/QK_K;
4382
+ const int r0 = tgpig.x;
4383
+ const int r1 = tgpig.y;
4384
+ const int im = tgpig.z;
4385
+
4386
// first of the N_DST rows handled by this simdgroup, and that row's block offset
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4387
+ const int ib_row = first_row * nb;
4388
+
4389
// im is split into (i12, i13); dividing by r2 / r3 selects the src0 slice they map to
+ const uint i12 = im%ne12;
4390
+ const uint i13 = im/ne12;
4391
+
4392
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4393
+ device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
4394
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4395
+
4396
+ float yl[16];
4397
+ float sumf[N_DST]={0.f}, all_sum;
4398
+
4399
// number of 32-value sub-blocks in one row
+ const int nb32 = nb * (QK_K / 32);
4400
+
4401
+ #if QK_K == 256
4402
// two lanes share one 32-value sub-block: ix picks the sub-block, il the 16-value half
+ const int ix = tiisg/2;
4403
+ const int il = tiisg%2;
4404
+
4405
+ device const float * y4 = y + 32 * ix + 16 * il;
4406
+
4407
// the stride of 16 sub-blocks presumably corresponds to 32 lanes per simdgroup — confirm
+ for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
4408
+
4409
// cache this lane's 16 src1 values
+ for (int i = 0; i < 16; ++i) {
4410
+ yl[i] = y4[i];
4411
+ }
4412
+
4413
// ibl: block index within the row; ib: 32-value sub-block inside that block
+ const int ibl = ib32 / (QK_K / 32);
4414
+ const int ib = ib32 % (QK_K / 32);
4415
+
4416
// each sub-block owns 4 qs bytes and 2 scale bytes; this lane uses 2 qs bytes and 1 scale byte
+ device const block_iq1_s * xr = x + ibl;
4417
+ device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
4418
+ device const uint8_t * sc = xr->scales + 2 * ib + il;
4419
+ device const half * dh = &xr->d;
4420
+
4421
+ for (int row = 0; row < N_DST; row++) {
4422
+
4423
// 9-bit grid index: low 8 bits from qs, bit 8 from scale bit 3 (grid1) or scale bit 7 (grid2);
// the selected 64-bit iq1s_grid entry is then read as 8 signed bytes
+ constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
4424
+ constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
4425
+
4426
+ float2 sum = {0};
4427
+ for (int j = 0; j < 8; ++j) {
4428
+ sum[0] += yl[j+ 0] * grid1[j];
4429
+ sum[1] += yl[j+ 8] * grid2[j];
4430
+ }
4431
// the 3-bit fields of the scale byte (bits 0-2 and 4-6) become odd multipliers 2*s + 1
+ sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
4432
+
4433
// step nb blocks ahead, i.e. to the same block position in the next row;
// dh points at half values, so its byte stride is divided by 2
+ dh += nb*sizeof(block_iq1_s)/2;
4434
+ qs += nb*sizeof(block_iq1_s);
4435
+ sc += nb*sizeof(block_iq1_s);
4436
+ }
4437
+
4438
// move past the 16 sub-blocks (32 values each) covered in this iteration
+ y4 += 16 * 32;
4439
+ }
4440
+ #else
4441
// other QK_K values: nothing is computed; the casts only reference the otherwise unused variables
+ (void) x;
4442
+ (void) y;
4443
+ (void) yl;
4444
+ (void) nb32;
4445
+ #endif
4446
+
4447
// reduce each row's partial sums across the simdgroup; lane 0 writes the result
+ for (int row = 0; row < N_DST; ++row) {
4448
+ all_sum = simd_sum(sumf[row]);
4449
+ if (tiisg == 0) {
4450
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4451
+ }
4452
+ }
4453
+ }
4454
+
4455
// 16-entry lookup table for IQ4_NL: a 4-bit quant index selects one of these float values.
// The steps between neighbouring entries are not uniform.
+ constexpr constant static float kvalues_iq4nl_f[16] = {
4456
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
4457
+ };
4458
+
4459
// Matrix-vector product of IQ4_NL-quantized src0 with float src1.
// Each simdgroup accumulates 2 consecutive output rows starting at first_row; partial sums
// are combined with simd_sum and lane 0 stores the result in dst.
// shared_values is threadgroup memory used to hold a copy of kvalues_iq4nl_f; it is written
// at indices tiisg, so it must provide at least as many floats as there are lanes in a simdgroup.
// NOTE(review): first_row + row is not compared against ne01 before the store to dst —
// confirm the dispatch guarantees the row count is covered exactly.
+ void kernel_mul_mv_iq4_nl_f32_impl(
4460
+ device const void * src0,
4461
+ device const float * src1,
4462
+ device float * dst,
4463
+ constant int64_t & ne00,
4464
+ constant int64_t & ne01,
4465
+ constant int64_t & ne02,
4466
+ constant int64_t & ne10,
4467
+ constant int64_t & ne12,
4468
+ constant int64_t & ne0,
4469
+ constant int64_t & ne1,
4470
+ constant uint & r2,
4471
+ constant uint & r3,
4472
+ threadgroup float * shared_values [[threadgroup(0)]],
4473
+ uint3 tgpig[[threadgroup_position_in_grid]],
4474
+ uint tiisg[[thread_index_in_simdgroup]],
4475
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4476
+
4477
// nb: number of block_iq4_nl blocks in one row of src0
+ const int nb = ne00/QK4_NL;
4478
+ const int r0 = tgpig.x;
4479
+ const int r1 = tgpig.y;
4480
+ const int im = tgpig.z;
4481
// literal 2s: presumably 2 simdgroups per threadgroup and 2 rows per simdgroup — confirm against the dispatch
+ const int first_row = (r0 * 2 + sgitg) * 2;
4482
+ const int ib_row = first_row * nb;
4483
+
4484
// im is split into (i12, i13); dividing by r2 / r3 selects the src0 slice they map to
+ const uint i12 = im%ne12;
4485
+ const uint i13 = im/ne12;
4486
+
4487
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4488
+ device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
4489
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4490
+
4491
// two lanes share one block: ix picks the block, it picks which 8 qs bytes of it
+ const int ix = tiisg/2; // 0...15
4492
+ const int it = tiisg%2; // 0 or 1
4493
+
4494
// every lane copies one table entry into threadgroup memory (entry tiisg%16 to slot tiisg);
// only slots 0..15 are read below because the indices are masked nibbles
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
4495
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4496
+
4497
+ float4 yl[4];
4498
+ float sumf[2]={0.f}, all_sum;
4499
+
4500
+ device const float * yb = y + ix * QK4_NL + it * 8;
4501
+
4502
// q8 aliases the 8 bytes of aux32, so the unpacked nibbles can be read one byte at a time
+ uint32_t aux32[2];
4503
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
4504
+
4505
+ float4 qf1, qf2;
4506
+
4507
+ for (int ib = ix; ib < nb; ib += 16) {
4508
+
4509
// relative to yb: yl[0] = floats 0..3, yl[1] = 16..19, yl[2] = 4..7, yl[3] = 20..23
// (low-nibble values pair with the first group, high-nibble values with the group 16 further on)
+ device const float4 * y4 = (device const float4 *)yb;
4510
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
4511
+
4512
+ for (int row = 0; row < 2; ++row) {
4513
+
4514
+ device const block_iq4_nl & xb = x[row*nb + ib];
4515
+ device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
4516
+
4517
+ float4 acc1 = {0.f}, acc2 = {0.f};
4518
+
4519
// first 4 qs bytes: aux32[1] receives the high nibbles, aux32[0] keeps the low nibbles
+ aux32[0] = q4[0] | (q4[1] << 16);
4520
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
4521
+ aux32[0] &= 0x0f0f0f0f;
4522
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
4523
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
4524
+ acc1 += yl[0] * qf1;
4525
+ acc2 += yl[1] * qf2;
4526
+
4527
// next 4 qs bytes, same unpacking
+ aux32[0] = q4[2] | (q4[3] << 16);
4528
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
4529
+ aux32[0] &= 0x0f0f0f0f;
4530
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
4531
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
4532
+ acc1 += yl[2] * qf1;
4533
+ acc2 += yl[3] * qf2;
4534
+
4535
+ acc1 += acc2;
4536
+
4537
// apply the block scale once to the summed products
+ sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
4538
+
4539
+ }
4540
+
4541
// move past the 16 blocks covered in this iteration
+ yb += 16 * QK4_NL;
4542
+ }
4543
+
4544
// reduce each row's partial sums across the simdgroup; lane 0 writes the result
+ for (int row = 0; row < 2; ++row) {
4545
+ all_sum = simd_sum(sumf[row]);
4546
+ if (tiisg == 0) {
4547
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4548
+ }
4549
+ }
4550
+ }
4551
+
4552
// Kernel entry point for the IQ1_S matrix-vector product.
// It forwards to kernel_mul_mv_iq1_s_f32_impl; nb00..nb02, ne11 and nb10..nb12 are part of
// the argument list but are not passed on.
+ [[host_name("kernel_mul_mv_iq1_s_f32")]]
4553
+ kernel void kernel_mul_mv_iq1_s_f32(
4554
+ device const void * src0,
4555
+ device const float * src1,
4556
+ device float * dst,
4557
+ constant int64_t & ne00,
4558
+ constant int64_t & ne01,
4559
+ constant int64_t & ne02,
4560
+ constant uint64_t & nb00,
4561
+ constant uint64_t & nb01,
4562
+ constant uint64_t & nb02,
4563
+ constant int64_t & ne10,
4564
+ constant int64_t & ne11,
4565
+ constant int64_t & ne12,
4566
+ constant uint64_t & nb10,
4567
+ constant uint64_t & nb11,
4568
+ constant uint64_t & nb12,
4569
+ constant int64_t & ne0,
4570
+ constant int64_t & ne1,
4571
+ constant uint & r2,
4572
+ constant uint & r3,
4573
+ uint3 tgpig[[threadgroup_position_in_grid]],
4574
+ uint tiisg[[thread_index_in_simdgroup]],
4575
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4576
+
4577
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4578
+ }
4579
+
4580
// Kernel entry point for the IQ4_NL matrix-vector product.
// It forwards to kernel_mul_mv_iq4_nl_f32_impl, including the threadgroup buffer
// shared_values bound at threadgroup(0); nb00..nb02, ne11 and nb10..nb12 are part of the
// argument list but are not passed on.
+ [[host_name("kernel_mul_mv_iq4_nl_f32")]]
4581
+ kernel void kernel_mul_mv_iq4_nl_f32(
4582
+ device const void * src0,
4583
+ device const float * src1,
4584
+ device float * dst,
4585
+ constant int64_t & ne00,
4586
+ constant int64_t & ne01,
4587
+ constant int64_t & ne02,
4588
+ constant uint64_t & nb00,
4589
+ constant uint64_t & nb01,
4590
+ constant uint64_t & nb02,
4591
+ constant int64_t & ne10,
4592
+ constant int64_t & ne11,
4593
+ constant int64_t & ne12,
4594
+ constant uint64_t & nb10,
4595
+ constant uint64_t & nb11,
4596
+ constant uint64_t & nb12,
4597
+ constant int64_t & ne0,
4598
+ constant int64_t & ne1,
4599
+ constant uint & r2,
4600
+ constant uint & r3,
4601
+ threadgroup float * shared_values [[threadgroup(0)]],
4602
+ uint3 tgpig[[threadgroup_position_in_grid]],
4603
+ uint tiisg[[thread_index_in_simdgroup]],
4604
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4605
+
4606
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4607
+ }
4176
4608
 
4177
4609
  //============================= templates and their specializations =============================
4178
4610
 
@@ -4369,6 +4801,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
4369
4801
  const float dl = d * sc[0];
4370
4802
  const float ml = min * sc[1];
4371
4803
  #else
4804
+ (void) get_scale_min_k4_just2;
4805
+
4372
4806
  q = q + 16 * (il&1);
4373
4807
  device const uint8_t * s = xb->scales;
4374
4808
  device const half2 * dh = (device const half2 *)xb->d;
@@ -4518,6 +4952,37 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
4518
4952
  }
4519
4953
  }
4520
4954
 
4955
// Expands 16 values of an IQ1_S block into reg (a 4x4 register tile).
//   xb  - block to decode
//   il  - selects which group of 16 values: 2 qs bytes and 1 scale byte at offset il
//   reg - output; rows 0-1 receive the 8 values of the first grid entry, rows 2-3 the second
+ template <typename type4x4>
4956
+ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
4957
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4958
+ const float d = xb->d;
4959
+ device const uint8_t * qs = xb->qs + 2*il;
4960
+ device const uint8_t * sc = xb->scales + il;
4961
// the two 3-bit fields of the scale byte (bits 0-2 and 4-6) become odd multipliers 2*s + 1
+ const float dl1 = d * (2*(sc[0] & 7) + 1);
4962
+ const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
4963
// 9-bit grid index: low 8 bits from qs, bit 8 from scale bit 3 (grid1) or scale bit 7 (grid2);
// the selected 64-bit iq1s_grid entry is read as 8 signed bytes
+ constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
4964
+ constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
4965
+ for (int i = 0; i < 8; ++i) {
4966
+ reg[i/4+0][i%4] = dl1 * grid1[i];
4967
+ reg[i/4+2][i%4] = dl2 * grid2[i];
4968
+ }
4969
+ }
4970
+
4971
// Expands 16 values of an IQ4_NL block into reg (a 4x4 register tile).
//   xb  - block to decode
//   il  - nibble selector: the packed bytes are shifted right by 4*il before masking, so
//         il = 0 yields the low nibbles and il = 1 the high nibbles
//   reg - output; row i holds the values decoded from qs bytes 4*i .. 4*i+3
+ template <typename type4x4>
4972
+ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
4973
+ device const uint16_t * q4 = (device const uint16_t *)xb->qs;
4974
+ const float d = xb->d;
4975
// q8 aliases the 4 bytes of aux32 so each masked nibble can be read as one byte
+ uint32_t aux32;
4976
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
4977
+ for (int i = 0; i < 4; ++i) {
4978
// combine two 16-bit reads into 4 packed bytes, then keep one nibble of each byte
+ aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
4979
// each nibble indexes the kvalues_iq4nl_f table; the result is scaled by d
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
4980
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
4981
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
4982
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
4983
+ }
4984
+ }
4985
+
4521
4986
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4522
4987
  kernel void kernel_get_rows(
4523
4988
  device const void * src0,
@@ -5060,6 +5525,8 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
5060
5525
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5061
5526
  template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5062
5527
  template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5528
+ template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5529
+ template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
5063
5530
 
5064
5531
  //
5065
5532
  // matrix-matrix multiplication
@@ -5099,6 +5566,8 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
5099
5566
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5100
5567
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5101
5568
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5569
+ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5570
+ template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
5102
5571
 
5103
5572
  //
5104
5573
  // indirect matrix-matrix multiplication
@@ -5150,6 +5619,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
5150
5619
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5151
5620
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5152
5621
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5622
+ template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
5623
+ template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
5153
5624
 
5154
5625
  //
5155
5626
  // matrix-vector multiplication
@@ -6117,3 +6588,131 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6117
6588
  tiisg,
6118
6589
  sgitg);
6119
6590
  }
6591
+
6592
// Indirect IQ1_S matrix-vector product: the src0 buffer is chosen at run time.
// The eight candidate buffers src00..src07 are gathered into an array, an index is read
// from ids, and the call is forwarded to kernel_mul_mv_iq1_s_f32_impl.
// nb00..nb02, ne11, nb10, nb12, nb1 and tiitg are part of the argument list but are not used here.
// NOTE(review): id is read from device memory and used unchecked as an index into the
// 8-entry src0 array — confirm the host guarantees 0 <= id < 8.
+ [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6593
+ kernel void kernel_mul_mv_id_iq1_s_f32(
6594
+ device const char * ids,
6595
+ device const char * src1,
6596
+ device float * dst,
6597
+ constant uint64_t & nbi1,
6598
+ constant int64_t & ne00,
6599
+ constant int64_t & ne01,
6600
+ constant int64_t & ne02,
6601
+ constant uint64_t & nb00,
6602
+ constant uint64_t & nb01,
6603
+ constant uint64_t & nb02,
6604
+ constant int64_t & ne10,
6605
+ constant int64_t & ne11,
6606
+ constant int64_t & ne12,
6607
+ constant int64_t & ne13,
6608
+ constant uint64_t & nb10,
6609
+ constant uint64_t & nb11,
6610
+ constant uint64_t & nb12,
6611
+ constant int64_t & ne0,
6612
+ constant int64_t & ne1,
6613
+ constant uint64_t & nb1,
6614
+ constant uint & r2,
6615
+ constant uint & r3,
6616
+ constant int & idx,
6617
+ device const char * src00,
6618
+ device const char * src01,
6619
+ device const char * src02,
6620
+ device const char * src03,
6621
+ device const char * src04,
6622
+ device const char * src05,
6623
+ device const char * src06,
6624
+ device const char * src07,
6625
+ uint3 tgpig[[threadgroup_position_in_grid]],
6626
+ uint tiitg[[thread_index_in_threadgroup]],
6627
+ uint tiisg[[thread_index_in_simdgroup]],
6628
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6629
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6630
+
6631
// the z grid coordinate carries two indices: bid (quotient) and the remainder kept in tgpig.z
+ const int64_t bid = tgpig.z/(ne12*ne13);
6632
+
6633
+ tgpig.z = tgpig.z%(ne12*ne13);
6634
+
6635
// row bid of ids starts at byte offset bid*nbi1; entry idx of that row selects the source buffer
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6636
+
6637
// src1 is advanced by bid*nb11 bytes and dst by bid*ne0 floats before forwarding
+ kernel_mul_mv_iq1_s_f32_impl(
6638
+ src0[id],
6639
+ (device const float *) (src1 + bid*nb11),
6640
+ dst + bid*ne0,
6641
+ ne00,
6642
+ ne01,
6643
+ ne02,
6644
+ ne10,
6645
+ ne12,
6646
+ ne0,
6647
+ ne1,
6648
+ r2,
6649
+ r3,
6650
+ tgpig,
6651
+ tiisg,
6652
+ sgitg);
6653
+ }
6654
+
6655
// Indirect IQ4_NL matrix-vector product: the src0 buffer is chosen at run time.
// The eight candidate buffers src00..src07 are gathered into an array, an index is read
// from ids, and the call is forwarded to kernel_mul_mv_iq4_nl_f32_impl together with the
// threadgroup buffer shared_values.
// nb00..nb02, ne11, nb10, nb12, nb1 and tiitg are part of the argument list but are not used here.
// NOTE(review): id is read from device memory and used unchecked as an index into the
// 8-entry src0 array — confirm the host guarantees 0 <= id < 8.
+ [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
6656
+ kernel void kernel_mul_mv_id_iq4_nl_f32(
6657
+ device const char * ids,
6658
+ device const char * src1,
6659
+ device float * dst,
6660
+ constant uint64_t & nbi1,
6661
+ constant int64_t & ne00,
6662
+ constant int64_t & ne01,
6663
+ constant int64_t & ne02,
6664
+ constant uint64_t & nb00,
6665
+ constant uint64_t & nb01,
6666
+ constant uint64_t & nb02,
6667
+ constant int64_t & ne10,
6668
+ constant int64_t & ne11,
6669
+ constant int64_t & ne12,
6670
+ constant int64_t & ne13,
6671
+ constant uint64_t & nb10,
6672
+ constant uint64_t & nb11,
6673
+ constant uint64_t & nb12,
6674
+ constant int64_t & ne0,
6675
+ constant int64_t & ne1,
6676
+ constant uint64_t & nb1,
6677
+ constant uint & r2,
6678
+ constant uint & r3,
6679
+ constant int & idx,
6680
+ device const char * src00,
6681
+ device const char * src01,
6682
+ device const char * src02,
6683
+ device const char * src03,
6684
+ device const char * src04,
6685
+ device const char * src05,
6686
+ device const char * src06,
6687
+ device const char * src07,
6688
+ threadgroup float * shared_values [[threadgroup(0)]],
6689
+ uint3 tgpig[[threadgroup_position_in_grid]],
6690
+ uint tiitg[[thread_index_in_threadgroup]],
6691
+ uint tiisg[[thread_index_in_simdgroup]],
6692
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6693
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6694
+
6695
// the z grid coordinate carries two indices: bid (quotient) and the remainder kept in tgpig.z
+ const int64_t bid = tgpig.z/(ne12*ne13);
6696
+
6697
+ tgpig.z = tgpig.z%(ne12*ne13);
6698
+
6699
// row bid of ids starts at byte offset bid*nbi1; entry idx of that row selects the source buffer
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6700
+
6701
// src1 is advanced by bid*nb11 bytes and dst by bid*ne0 floats before forwarding
+ kernel_mul_mv_iq4_nl_f32_impl(
6702
+ src0[id],
6703
+ (device const float *) (src1 + bid*nb11),
6704
+ dst + bid*ne0,
6705
+ ne00,
6706
+ ne01,
6707
+ ne02,
6708
+ ne10,
6709
+ ne12,
6710
+ ne0,
6711
+ ne1,
6712
+ r2,
6713
+ r3,
6714
+ shared_values,
6715
+ tgpig,
6716
+ tiisg,
6717
+ sgitg);
6718
+ }