llama_cpp 0.12.6 → 0.12.7

Sign up to get free protection for your applications and to get access to all the features.
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
351
351
  kernel void kernel_soft_max(
352
352
  device const float * src0,
353
353
  device const float * src1,
354
+ device const float * src2,
354
355
  device float * dst,
355
356
  constant int64_t & ne00,
356
357
  constant int64_t & ne01,
357
358
  constant int64_t & ne02,
358
359
  constant float & scale,
359
- threadgroup float * buf [[threadgroup(0)]],
360
+ constant float & max_bias,
361
+ constant float & m0,
362
+ constant float & m1,
363
+ constant uint32_t & n_head_log2,
364
+ threadgroup float * buf [[threadgroup(0)]],
360
365
  uint tgpig[[threadgroup_position_in_grid]],
361
366
  uint tpitg[[thread_position_in_threadgroup]],
362
367
  uint sgitg[[simdgroup_index_in_threadgroup]],
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
368
373
 
369
374
  device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
370
375
  device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
376
+ device const float * ppos = src2 != src0 ? src2 : nullptr;
371
377
  device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
372
378
 
379
+ float slope = 0.0f;
380
+
381
+ // ALiBi
382
+ if (max_bias > 0.0f) {
383
+ const int64_t h = i02;
384
+
385
+ const float base = h < n_head_log2 ? m0 : m1;
386
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
387
+
388
+ slope = pow(base, exp);
389
+ }
390
+
373
391
  // parallel max
374
392
  float lmax = -INFINITY;
375
393
 
376
394
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
377
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
395
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
378
396
  }
379
397
 
380
398
  // find the max value in the block
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
399
417
  // parallel sum
400
418
  float lsum = 0.0f;
401
419
  for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
402
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
420
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
403
421
  lsum += exp_psrc0;
404
422
  pdst[i00] = exp_psrc0;
405
423
  }
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
437
455
  kernel void kernel_soft_max_4(
438
456
  device const float * src0,
439
457
  device const float * src1,
458
+ device const float * src2,
440
459
  device float * dst,
441
460
  constant int64_t & ne00,
442
461
  constant int64_t & ne01,
443
462
  constant int64_t & ne02,
444
463
  constant float & scale,
445
- threadgroup float * buf [[threadgroup(0)]],
464
+ constant float & max_bias,
465
+ constant float & m0,
466
+ constant float & m1,
467
+ constant uint32_t & n_head_log2,
468
+ threadgroup float * buf [[threadgroup(0)]],
446
469
  uint tgpig[[threadgroup_position_in_grid]],
447
470
  uint tpitg[[thread_position_in_threadgroup]],
448
471
  uint sgitg[[simdgroup_index_in_threadgroup]],
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
454
477
 
455
478
  device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
456
479
  device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
480
+ device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
457
481
  device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
458
482
 
483
+ float slope = 0.0f;
484
+
485
+ if (max_bias > 0.0f) {
486
+ const int64_t h = i02;
487
+
488
+ const float base = h < n_head_log2 ? m0 : m1;
489
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
490
+
491
+ slope = pow(base, exp);
492
+ }
493
+
459
494
  // parallel max
460
495
  float4 lmax4 = -INFINITY;
461
496
 
462
497
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
463
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
498
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
464
499
  }
465
500
 
466
501
  const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
486
521
  // parallel sum
487
522
  float4 lsum4 = 0.0f;
488
523
  for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
489
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
524
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
490
525
  lsum4 += exp_psrc4;
491
526
  pdst4[i00] = exp_psrc4;
492
527
  }
@@ -2490,6 +2525,19 @@ typedef struct {
2490
2525
  } block_iq3_xxs;
2491
2526
  // 98 bytes / block for QK_K = 256, so 3.0625 bpw
2492
2527
 
2528
+ typedef struct {
2529
+ half d;
2530
+ uint8_t qs[QK_K/8];
2531
+ uint8_t scales[QK_K/16];
2532
+ } block_iq1_s;
2533
+
2534
+ // Non-linear quants
2535
+ #define QK4_NL 32
2536
+ typedef struct {
2537
+ half d;
2538
+ uint8_t qs[QK4_NL/2];
2539
+ } block_iq4_nl;
2540
+
2493
2541
  //====================================== dot products =========================
2494
2542
 
2495
2543
  void kernel_mul_mv_q2_K_f32_impl(
@@ -3747,6 +3795,137 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
3747
3795
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
3748
3796
  };
3749
3797
 
3798
+ #define NGRID_IQ1S 512
3799
+ constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
3800
+ 0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
3801
+ 0xffffffff01ff00ff, 0xffffffff01ff0001, 0xffffffff0101ffff, 0xffffffff0101ff01,
3802
+ 0xffffff00ff000000, 0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100,
3803
+ 0xffffff0000010000, 0xffffff0001000000, 0xffffff01ffff00ff, 0xffffff01ff01ff00,
3804
+ 0xffffff01ff010100, 0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101ff0101,
3805
+ 0xffffff0101010100, 0xffff00ffff00ff01, 0xffff00ffff0000ff, 0xffff00ff00ff0100,
3806
+ 0xffff00ff0100ff00, 0xffff00ff010001ff, 0xffff0000ff0101ff, 0xffff000000ffff00,
3807
+ 0xffff000000000000, 0xffff00000001ff01, 0xffff000001000101, 0xffff0000010100ff,
3808
+ 0xffff0001ffff0100, 0xffff00010000ff00, 0xffff000100010101, 0xffff000101000000,
3809
+ 0xffff01ffffff0000, 0xffff01ffff01ffff, 0xffff01ffff010100, 0xffff01ff00000000,
3810
+ 0xffff01ff01ffffff, 0xffff01ff01ff0001, 0xffff01ff0101ffff, 0xffff01ff01010001,
3811
+ 0xffff0100ffffff01, 0xffff01000000ffff, 0xffff010000000100, 0xffff010001ff01ff,
3812
+ 0xffff010001000000, 0xffff0101ff000000, 0xffff0101000101ff, 0xffff010101ffff01,
3813
+ 0xffff01010101ff00, 0xff00ffffff000000, 0xff00ffff00ffff00, 0xff00ffff00000001,
3814
+ 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000, 0xff00ff00ff00ff00,
3815
+ 0xff00ff00ff0000ff, 0xff00ff00ff000100, 0xff00ff00ff010001, 0xff00ff0000ff0001,
3816
+ 0xff00ff000000ffff, 0xff00ff0000000000, 0xff00ff000001ff00, 0xff00ff0000010100,
3817
+ 0xff00ff0001ff0000, 0xff00ff000100ff00, 0xff00ff0001000100, 0xff00ff01ff000000,
3818
+ 0xff00ff0100ff0000, 0xff00ff01000001ff, 0xff00ff0101010001, 0xff0000ff00000000,
3819
+ 0xff0000ff0001ff00, 0xff0000ff00010100, 0xff000000ffff0101, 0xff000000ff000000,
3820
+ 0xff000000ff01ff00, 0xff00000000ff0000, 0xff0000000000ff00, 0xff000000000000ff,
3821
+ 0xff00000000000000, 0xff00000000000001, 0xff00000000000100, 0xff0000000001ffff,
3822
+ 0xff00000000010000, 0xff00000001000000, 0xff00000001010100, 0xff000001ff00ff01,
3823
+ 0xff000001ff0100ff, 0xff00000100000000, 0xff0000010001ff00, 0xff00000101ff0100,
3824
+ 0xff0000010100ff00, 0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff000100ff,
3825
+ 0xff0001ff01000000, 0xff000100ff0001ff, 0xff0001000000ff01, 0xff00010000000000,
3826
+ 0xff00010000010001, 0xff00010000010100, 0xff00010001ffff00, 0xff00010001ff0101,
3827
+ 0xff00010001010000, 0xff000101ffffffff, 0xff000101ff000101, 0xff00010101ff00ff,
3828
+ 0xff00010101000001, 0xff000101010100ff, 0xff01ffffff000101, 0xff01ffffff01ffff,
3829
+ 0xff01ffffff01ff01, 0xff01ffffff0101ff, 0xff01ffff00000000, 0xff01ffff01ff0001,
3830
+ 0xff01ffff0101ff01, 0xff01ff00ff000000, 0xff01ff0000ff0100, 0xff01ff000000ff01,
3831
+ 0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ff01ff00, 0xff01ff0100000101,
3832
+ 0xff0100ffffff0000, 0xff0100ffff010000, 0xff0100ff01ff00ff, 0xff0100ff01000100,
3833
+ 0xff0100ff010100ff, 0xff010000ffffff01, 0xff01000000000000, 0xff0100000101ff00,
3834
+ 0xff010001ffff00ff, 0xff010001ff000100, 0xff01000100ffff00, 0xff01000100010001,
3835
+ 0xff01000101ff0001, 0xff010001010001ff, 0xff0101ffffffffff, 0xff0101ffff01ffff,
3836
+ 0xff0101ffff010101, 0xff0101ff0000ff00, 0xff0101ff01010001, 0xff010100ff000000,
3837
+ 0xff010100ff01ff01, 0xff01010000ff0001, 0xff01010000000100, 0xff01010001000000,
3838
+ 0xff0101010100ffff, 0x00ffffff0000ff01, 0x00ffffff000000ff, 0x00ffffff00000100,
3839
+ 0x00ffffff00010000, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
3840
+ 0x00ffff0000000000, 0x00ffff0001000100, 0x00ffff0001010001, 0x00ffff01ff00ff01,
3841
+ 0x00ffff0100ff0100, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ffff0101ff00ff,
3842
+ 0x00ffff010101ff00, 0x00ff00ffffffffff, 0x00ff00ffffff01ff, 0x00ff00ffff000101,
3843
+ 0x00ff00ff00000000, 0x00ff00ff000101ff, 0x00ff00ff01010101, 0x00ff0000ff000000,
3844
+ 0x00ff0000ff01ffff, 0x00ff000000ff0000, 0x00ff00000000ff00, 0x00ff0000000000ff,
3845
+ 0x00ff000000000000, 0x00ff000000000001, 0x00ff000000000100, 0x00ff000000010000,
3846
+ 0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff000101, 0x00ff000100ffffff,
3847
+ 0x00ff000100000000, 0x00ff0001010001ff, 0x00ff01ffff000000, 0x00ff01ff0001ff00,
3848
+ 0x00ff01ff01ff0100, 0x00ff0100ff01ff01, 0x00ff010000ff00ff, 0x00ff010000ff0101,
3849
+ 0x00ff010000000000, 0x00ff010000010101, 0x00ff01000100ff00, 0x00ff010001010000,
3850
+ 0x00ff0101ffffff00, 0x00ff01010000ff01, 0x00ff010100000100, 0x00ff010101ff0000,
3851
+ 0x0000ffffffff0100, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
3852
+ 0x0000ffff00000000, 0x0000ffff00010101, 0x0000ffff01ffff01, 0x0000ffff01000100,
3853
+ 0x0000ff00ff000000, 0x0000ff00ff01ff00, 0x0000ff00ff0101ff, 0x0000ff0000ff0000,
3854
+ 0x0000ff000000ff00, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
3855
+ 0x0000ff0000000100, 0x0000ff0000010000, 0x0000ff0001ffffff, 0x0000ff0001ff01ff,
3856
+ 0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffff0101, 0x0000ff01ff010000,
3857
+ 0x0000ff0100000000, 0x0000ff0101000101, 0x000000ffffff0001, 0x000000ffff000000,
3858
+ 0x000000ff00ff0000, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
3859
+ 0x000000ff00000001, 0x000000ff00000100, 0x000000ff00010000, 0x000000ff01000000,
3860
+ 0x000000ff0101ff00, 0x00000000ffff0000, 0x00000000ff00ff00, 0x00000000ff0000ff,
3861
+ 0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff010000,
3862
+ 0x0000000000ffff00, 0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001,
3863
+ 0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
3864
+ 0x00000000000000ff, 0x0000000000000001, 0x00000000000001ff, 0x0000000000000100,
3865
+ 0x0000000000000101, 0x000000000001ff00, 0x00000000000100ff, 0x0000000000010000,
3866
+ 0x0000000000010001, 0x0000000000010100, 0x0000000001ff0000, 0x000000000100ff00,
3867
+ 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001, 0x0000000001000100,
3868
+ 0x0000000001010000, 0x00000001ffff01ff, 0x00000001ff000000, 0x0000000100ff0000,
3869
+ 0x000000010000ff00, 0x00000001000000ff, 0x0000000100000000, 0x0000000100000001,
3870
+ 0x0000000100000100, 0x0000000100010000, 0x0000000101000000, 0x000001ffff00ff00,
3871
+ 0x000001ffff010001, 0x000001ffff0101ff, 0x000001ff00ffff01, 0x000001ff0000ffff,
3872
+ 0x000001ff00000000, 0x000001ff010000ff, 0x000001ff01010100, 0x00000100ffff0100,
3873
+ 0x00000100ff000000, 0x0000010000ff0000, 0x000001000000ff00, 0x00000100000000ff,
3874
+ 0x0000010000000000, 0x0000010000000001, 0x0000010000000100, 0x0000010000010000,
3875
+ 0x0000010001000000, 0x000001000101ff01, 0x00000101ffff0001, 0x00000101ff01ffff,
3876
+ 0x0000010100000000, 0x0000010101010100, 0x0001ffffff000000, 0x0001ffff00ffffff,
3877
+ 0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01000000, 0x0001ff00ffffff00,
3878
+ 0x0001ff00ffff01ff, 0x0001ff00ff010000, 0x0001ff0000000000, 0x0001ff0000010001,
3879
+ 0x0001ff0001ff0000, 0x0001ff0001010100, 0x0001ff01ff0000ff, 0x0001ff01ff000001,
3880
+ 0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff01000101ff, 0x0001ff010100ff01,
3881
+ 0x000100ffff00ffff, 0x000100ffff00ff01, 0x000100ffff000100, 0x000100ff00000000,
3882
+ 0x000100ff000101ff, 0x000100ff01ff0101, 0x000100ff0100ffff, 0x000100ff01010101,
3883
+ 0x00010000ff000000, 0x00010000ff010100, 0x0001000000ff0000, 0x000100000000ff00,
3884
+ 0x00010000000000ff, 0x0001000000000000, 0x0001000000000001, 0x0001000000000100,
3885
+ 0x0001000000010000, 0x0001000001ffff01, 0x0001000001000000, 0x0001000100ff0101,
3886
+ 0x0001000100000000, 0x00010001010100ff, 0x000101ffffff01ff, 0x000101ffffff0101,
3887
+ 0x000101ff00010000, 0x000101ff01ff0000, 0x000101ff0100ff01, 0x00010100ffff0000,
3888
+ 0x0001010000000000, 0x000101000001ffff, 0x0001010000010101, 0x00010100010001ff,
3889
+ 0x00010101ff00ff00, 0x00010101ff010001, 0x0001010100ffffff, 0x0001010100ff01ff,
3890
+ 0x00010101000101ff, 0x0001010101ff0000, 0x000101010100ff01, 0x0001010101000101,
3891
+ 0x01ffffffffff0101, 0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff,
3892
+ 0x01ffffffff010101, 0x01ffffff00000000, 0x01ffffff01ff01ff, 0x01ffffff01000101,
3893
+ 0x01ffffff0101ff01, 0x01ffffff010100ff, 0x01ffff000000ff00, 0x01ffff0000000001,
3894
+ 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000, 0x01ffff01ffffffff,
3895
+ 0x01ffff01ffff01ff, 0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff0101ff,
3896
+ 0x01ffff010100ffff, 0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff00ffff01,
3897
+ 0x01ff0000ff0000ff, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ffffff,
3898
+ 0x01ff000001010100, 0x01ff0001ffffff01, 0x01ff0001ff010001, 0x01ff000101ff0100,
3899
+ 0x01ff000101000001, 0x01ff0001010100ff, 0x01ff01ffff00ffff, 0x01ff01ff00010001,
3900
+ 0x01ff01ff01000000, 0x01ff01ff010101ff, 0x01ff0100ff000001, 0x01ff010000ffff00,
3901
+ 0x01ff010000000100, 0x01ff010001ff01ff, 0x01ff01000101ffff, 0x01ff0101ffff00ff,
3902
+ 0x01ff0101ffff0101, 0x01ff0101ff0101ff, 0x01ff010100010000, 0x0100ffff00ff00ff,
3903
+ 0x0100ffff00ff0001, 0x0100ffff00000100, 0x0100ffff0100ff00, 0x0100ff00ffff0000,
3904
+ 0x0100ff00ff00ffff, 0x0100ff00ff00ff01, 0x0100ff00ff000100, 0x0100ff00ff010000,
3905
+ 0x0100ff0000000000, 0x0100ff00000100ff, 0x0100ff0001ff0101, 0x0100ff0001010101,
3906
+ 0x0100ff0100ff00ff, 0x0100ff0100ff0001, 0x0100ff0100000100, 0x0100ff0100010001,
3907
+ 0x0100ff0101000000, 0x010000ffff00ff00, 0x010000ff0000ffff, 0x010000ff00000000,
3908
+ 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00, 0x01000000ffff0101,
3909
+ 0x01000000ff000000, 0x01000000ff0100ff, 0x01000000ff010101, 0x0100000000ff0000,
3910
+ 0x010000000000ff00, 0x01000000000000ff, 0x0100000000000000, 0x0100000000000001,
3911
+ 0x0100000000000100, 0x0100000000010000, 0x0100000001000000, 0x0100000100000000,
3912
+ 0x01000001000101ff, 0x0100000101ffff01, 0x010001ffff000101, 0x010001ff00ff0100,
3913
+ 0x010001ff0000ff00, 0x010001ff000100ff, 0x010001ff01ffffff, 0x01000100ffff0000,
3914
+ 0x01000100ff0001ff, 0x0100010000000000, 0x010001000001ff00, 0x0100010001ff0000,
3915
+ 0x01000100010000ff, 0x0100010001000101, 0x01000101ff00ff01, 0x0100010100ff0100,
3916
+ 0x010001010000ffff, 0x0100010101010001, 0x0101ffffffff0101, 0x0101ffffff0001ff,
3917
+ 0x0101ffffff01ffff, 0x0101ffffff010101, 0x0101ffff00000000, 0x0101ffff0101ffff,
3918
+ 0x0101ffff010101ff, 0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00,
3919
+ 0x0101ff0000010000, 0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ff010101,
3920
+ 0x0101ff0100000000, 0x0101ff010101ff00, 0x010100ffffff0000, 0x010100ffff010000,
3921
+ 0x010100ff00ff01ff, 0x010100ff000000ff, 0x010100ff00000101, 0x010100ff01ffff00,
3922
+ 0x01010000ffffff01, 0x01010000ff000100, 0x01010000ff01ff01, 0x0101000000000000,
3923
+ 0x01010000000100ff, 0x010100000101ff01, 0x01010001ffff0000, 0x01010001ff00ffff,
3924
+ 0x01010001ff010000, 0x0101000101ffffff, 0x0101000101ff01ff, 0x0101000101010101,
3925
+ 0x010101ffff01ffff, 0x010101ff00000000, 0x010101ff0001ff01, 0x010101ff0101ffff,
3926
+ 0x010101ff010101ff, 0x01010100ffffffff, 0x01010100ff000001, 0x010101000000ff00,
3927
+ 0x0101010001010000, 0x0101010100ff0001, 0x010101010001ff01, 0x010101010101ffff,
3928
+ };
3750
3929
 
3751
3930
  constexpr constant static uint8_t ksigns_iq2xs[128] = {
3752
3931
  0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
@@ -3854,7 +4033,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3854
4033
  y4 += 32 * 32;
3855
4034
  }
3856
4035
  #else
3857
- // TODO
4036
+ (void) x;
4037
+ (void) y;
4038
+ (void) yl;
4039
+ (void) nb32;
3858
4040
  #endif
3859
4041
 
3860
4042
  for (int row = 0; row < N_DST; ++row) {
@@ -3997,7 +4179,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
3997
4179
  y4 += 32 * 32;
3998
4180
  }
3999
4181
  #else
4000
- // TODO
4182
+ (void) x;
4183
+ (void) y;
4184
+ (void) yl;
4185
+ (void) nb32;
4001
4186
  #endif
4002
4187
 
4003
4188
  for (int row = 0; row < N_DST; ++row) {
@@ -4133,7 +4318,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4133
4318
  y4 += 32 * 32;
4134
4319
  }
4135
4320
  #else
4136
- // TODO
4321
+ (void) x;
4322
+ (void) y;
4323
+ (void) yl;
4324
+ (void) nb32;
4137
4325
  #endif
4138
4326
 
4139
4327
  for (int row = 0; row < N_DST; ++row) {
@@ -4173,6 +4361,250 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
4173
4361
  kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4174
4362
  }
4175
4363
 
4364
+ void kernel_mul_mv_iq1_s_f32_impl(
4365
+ device const void * src0,
4366
+ device const float * src1,
4367
+ device float * dst,
4368
+ constant int64_t & ne00,
4369
+ constant int64_t & ne01,
4370
+ constant int64_t & ne02,
4371
+ constant int64_t & ne10,
4372
+ constant int64_t & ne12,
4373
+ constant int64_t & ne0,
4374
+ constant int64_t & ne1,
4375
+ constant uint & r2,
4376
+ constant uint & r3,
4377
+ uint3 tgpig[[threadgroup_position_in_grid]],
4378
+ uint tiisg[[thread_index_in_simdgroup]],
4379
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4380
+
4381
+ const int nb = ne00/QK_K;
4382
+ const int r0 = tgpig.x;
4383
+ const int r1 = tgpig.y;
4384
+ const int im = tgpig.z;
4385
+
4386
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4387
+ const int ib_row = first_row * nb;
4388
+
4389
+ const uint i12 = im%ne12;
4390
+ const uint i13 = im/ne12;
4391
+
4392
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4393
+ device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
4394
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4395
+
4396
+ float yl[16];
4397
+ float sumf[N_DST]={0.f}, all_sum;
4398
+
4399
+ const int nb32 = nb * (QK_K / 32);
4400
+
4401
+ #if QK_K == 256
4402
+ const int ix = tiisg/2;
4403
+ const int il = tiisg%2;
4404
+
4405
+ device const float * y4 = y + 32 * ix + 16 * il;
4406
+
4407
+ for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
4408
+
4409
+ for (int i = 0; i < 16; ++i) {
4410
+ yl[i] = y4[i];
4411
+ }
4412
+
4413
+ const int ibl = ib32 / (QK_K / 32);
4414
+ const int ib = ib32 % (QK_K / 32);
4415
+
4416
+ device const block_iq1_s * xr = x + ibl;
4417
+ device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
4418
+ device const uint8_t * sc = xr->scales + 2 * ib + il;
4419
+ device const half * dh = &xr->d;
4420
+
4421
+ for (int row = 0; row < N_DST; row++) {
4422
+
4423
+ constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
4424
+ constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
4425
+
4426
+ float2 sum = {0};
4427
+ for (int j = 0; j < 8; ++j) {
4428
+ sum[0] += yl[j+ 0] * grid1[j];
4429
+ sum[1] += yl[j+ 8] * grid2[j];
4430
+ }
4431
+ sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
4432
+
4433
+ dh += nb*sizeof(block_iq1_s)/2;
4434
+ qs += nb*sizeof(block_iq1_s);
4435
+ sc += nb*sizeof(block_iq1_s);
4436
+ }
4437
+
4438
+ y4 += 16 * 32;
4439
+ }
4440
+ #else
4441
+ (void) x;
4442
+ (void) y;
4443
+ (void) yl;
4444
+ (void) nb32;
4445
+ #endif
4446
+
4447
+ for (int row = 0; row < N_DST; ++row) {
4448
+ all_sum = simd_sum(sumf[row]);
4449
+ if (tiisg == 0) {
4450
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4451
+ }
4452
+ }
4453
+ }
4454
+
4455
+ constexpr constant static float kvalues_iq4nl_f[16] = {
4456
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
4457
+ };
4458
+
4459
+ void kernel_mul_mv_iq4_nl_f32_impl(
4460
+ device const void * src0,
4461
+ device const float * src1,
4462
+ device float * dst,
4463
+ constant int64_t & ne00,
4464
+ constant int64_t & ne01,
4465
+ constant int64_t & ne02,
4466
+ constant int64_t & ne10,
4467
+ constant int64_t & ne12,
4468
+ constant int64_t & ne0,
4469
+ constant int64_t & ne1,
4470
+ constant uint & r2,
4471
+ constant uint & r3,
4472
+ threadgroup float * shared_values [[threadgroup(0)]],
4473
+ uint3 tgpig[[threadgroup_position_in_grid]],
4474
+ uint tiisg[[thread_index_in_simdgroup]],
4475
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4476
+
4477
+ const int nb = ne00/QK4_NL;
4478
+ const int r0 = tgpig.x;
4479
+ const int r1 = tgpig.y;
4480
+ const int im = tgpig.z;
4481
+ const int first_row = (r0 * 2 + sgitg) * 2;
4482
+ const int ib_row = first_row * nb;
4483
+
4484
+ const uint i12 = im%ne12;
4485
+ const uint i13 = im/ne12;
4486
+
4487
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4488
+ device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
4489
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4490
+
4491
+ const int ix = tiisg/2; // 0...15
4492
+ const int it = tiisg%2; // 0 or 1
4493
+
4494
+ shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
4495
+ threadgroup_barrier(mem_flags::mem_threadgroup);
4496
+
4497
+ float4 yl[4];
4498
+ float sumf[2]={0.f}, all_sum;
4499
+
4500
+ device const float * yb = y + ix * QK4_NL + it * 8;
4501
+
4502
+ uint32_t aux32[2];
4503
+ thread const uint8_t * q8 = (thread const uint8_t *)aux32;
4504
+
4505
+ float4 qf1, qf2;
4506
+
4507
+ for (int ib = ix; ib < nb; ib += 16) {
4508
+
4509
+ device const float4 * y4 = (device const float4 *)yb;
4510
+ yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
4511
+
4512
+ for (int row = 0; row < 2; ++row) {
4513
+
4514
+ device const block_iq4_nl & xb = x[row*nb + ib];
4515
+ device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
4516
+
4517
+ float4 acc1 = {0.f}, acc2 = {0.f};
4518
+
4519
+ aux32[0] = q4[0] | (q4[1] << 16);
4520
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
4521
+ aux32[0] &= 0x0f0f0f0f;
4522
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
4523
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
4524
+ acc1 += yl[0] * qf1;
4525
+ acc2 += yl[1] * qf2;
4526
+
4527
+ aux32[0] = q4[2] | (q4[3] << 16);
4528
+ aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
4529
+ aux32[0] &= 0x0f0f0f0f;
4530
+ qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
4531
+ qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
4532
+ acc1 += yl[2] * qf1;
4533
+ acc2 += yl[3] * qf2;
4534
+
4535
+ acc1 += acc2;
4536
+
4537
+ sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
4538
+
4539
+ }
4540
+
4541
+ yb += 16 * QK4_NL;
4542
+ }
4543
+
4544
+ for (int row = 0; row < 2; ++row) {
4545
+ all_sum = simd_sum(sumf[row]);
4546
+ if (tiisg == 0) {
4547
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4548
+ }
4549
+ }
4550
+ }
4551
+
4552
+ [[host_name("kernel_mul_mv_iq1_s_f32")]]
4553
+ kernel void kernel_mul_mv_iq1_s_f32(
4554
+ device const void * src0,
4555
+ device const float * src1,
4556
+ device float * dst,
4557
+ constant int64_t & ne00,
4558
+ constant int64_t & ne01,
4559
+ constant int64_t & ne02,
4560
+ constant uint64_t & nb00,
4561
+ constant uint64_t & nb01,
4562
+ constant uint64_t & nb02,
4563
+ constant int64_t & ne10,
4564
+ constant int64_t & ne11,
4565
+ constant int64_t & ne12,
4566
+ constant uint64_t & nb10,
4567
+ constant uint64_t & nb11,
4568
+ constant uint64_t & nb12,
4569
+ constant int64_t & ne0,
4570
+ constant int64_t & ne1,
4571
+ constant uint & r2,
4572
+ constant uint & r3,
4573
+ uint3 tgpig[[threadgroup_position_in_grid]],
4574
+ uint tiisg[[thread_index_in_simdgroup]],
4575
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4576
+
4577
+ kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4578
+ }
4579
+
4580
+ [[host_name("kernel_mul_mv_iq4_nl_f32")]]
4581
+ kernel void kernel_mul_mv_iq4_nl_f32(
4582
+ device const void * src0,
4583
+ device const float * src1,
4584
+ device float * dst,
4585
+ constant int64_t & ne00,
4586
+ constant int64_t & ne01,
4587
+ constant int64_t & ne02,
4588
+ constant uint64_t & nb00,
4589
+ constant uint64_t & nb01,
4590
+ constant uint64_t & nb02,
4591
+ constant int64_t & ne10,
4592
+ constant int64_t & ne11,
4593
+ constant int64_t & ne12,
4594
+ constant uint64_t & nb10,
4595
+ constant uint64_t & nb11,
4596
+ constant uint64_t & nb12,
4597
+ constant int64_t & ne0,
4598
+ constant int64_t & ne1,
4599
+ constant uint & r2,
4600
+ constant uint & r3,
4601
+ threadgroup float * shared_values [[threadgroup(0)]],
4602
+ uint3 tgpig[[threadgroup_position_in_grid]],
4603
+ uint tiisg[[thread_index_in_simdgroup]],
4604
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4605
+
4606
+ kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
4607
+ }
4176
4608
 
4177
4609
  //============================= templates and their specializations =============================
4178
4610
 
@@ -4369,6 +4801,8 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
4369
4801
  const float dl = d * sc[0];
4370
4802
  const float ml = min * sc[1];
4371
4803
  #else
4804
+ (void) get_scale_min_k4_just2;
4805
+
4372
4806
  q = q + 16 * (il&1);
4373
4807
  device const uint8_t * s = xb->scales;
4374
4808
  device const half2 * dh = (device const half2 *)xb->d;
@@ -4518,6 +4952,37 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
4518
4952
  }
4519
4953
  }
4520
4954
 
4955
+ template <typename type4x4>
4956
+ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
4957
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
4958
+ const float d = xb->d;
4959
+ device const uint8_t * qs = xb->qs + 2*il;
4960
+ device const uint8_t * sc = xb->scales + il;
4961
+ const float dl1 = d * (2*(sc[0] & 7) + 1);
4962
+ const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1);
4963
+ constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
4964
+ constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
4965
+ for (int i = 0; i < 8; ++i) {
4966
+ reg[i/4+0][i%4] = dl1 * grid1[i];
4967
+ reg[i/4+2][i%4] = dl2 * grid2[i];
4968
+ }
4969
+ }
4970
+
4971
+ template <typename type4x4>
4972
+ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
4973
+ device const uint16_t * q4 = (device const uint16_t *)xb->qs;
4974
+ const float d = xb->d;
4975
+ uint32_t aux32;
4976
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
4977
+ for (int i = 0; i < 4; ++i) {
4978
+ aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
4979
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
4980
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
4981
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
4982
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
4983
+ }
4984
+ }
4985
+
4521
4986
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
4522
4987
  kernel void kernel_get_rows(
4523
4988
  device const void * src0,
@@ -5060,6 +5525,8 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
5060
5525
  template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5061
5526
  template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5062
5527
  template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5528
+ template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5529
+ template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
5063
5530
 
5064
5531
  //
5065
5532
  // matrix-matrix multiplication
@@ -5099,6 +5566,8 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
5099
5566
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5100
5567
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5101
5568
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5569
+ template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5570
+ template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
5102
5571
 
5103
5572
  //
5104
5573
  // indirect matrix-matrix multiplication
@@ -5150,6 +5619,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
5150
5619
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5151
5620
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5152
5621
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5622
+ template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
5623
+ template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
5153
5624
 
5154
5625
  //
5155
5626
  // matrix-vector multiplication
@@ -6117,3 +6588,131 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6117
6588
  tiisg,
6118
6589
  sgitg);
6119
6590
  }
6591
+
6592
+ [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6593
+ kernel void kernel_mul_mv_id_iq1_s_f32(
6594
+ device const char * ids,
6595
+ device const char * src1,
6596
+ device float * dst,
6597
+ constant uint64_t & nbi1,
6598
+ constant int64_t & ne00,
6599
+ constant int64_t & ne01,
6600
+ constant int64_t & ne02,
6601
+ constant uint64_t & nb00,
6602
+ constant uint64_t & nb01,
6603
+ constant uint64_t & nb02,
6604
+ constant int64_t & ne10,
6605
+ constant int64_t & ne11,
6606
+ constant int64_t & ne12,
6607
+ constant int64_t & ne13,
6608
+ constant uint64_t & nb10,
6609
+ constant uint64_t & nb11,
6610
+ constant uint64_t & nb12,
6611
+ constant int64_t & ne0,
6612
+ constant int64_t & ne1,
6613
+ constant uint64_t & nb1,
6614
+ constant uint & r2,
6615
+ constant uint & r3,
6616
+ constant int & idx,
6617
+ device const char * src00,
6618
+ device const char * src01,
6619
+ device const char * src02,
6620
+ device const char * src03,
6621
+ device const char * src04,
6622
+ device const char * src05,
6623
+ device const char * src06,
6624
+ device const char * src07,
6625
+ uint3 tgpig[[threadgroup_position_in_grid]],
6626
+ uint tiitg[[thread_index_in_threadgroup]],
6627
+ uint tiisg[[thread_index_in_simdgroup]],
6628
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6629
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6630
+
6631
+ const int64_t bid = tgpig.z/(ne12*ne13);
6632
+
6633
+ tgpig.z = tgpig.z%(ne12*ne13);
6634
+
6635
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6636
+
6637
+ kernel_mul_mv_iq1_s_f32_impl(
6638
+ src0[id],
6639
+ (device const float *) (src1 + bid*nb11),
6640
+ dst + bid*ne0,
6641
+ ne00,
6642
+ ne01,
6643
+ ne02,
6644
+ ne10,
6645
+ ne12,
6646
+ ne0,
6647
+ ne1,
6648
+ r2,
6649
+ r3,
6650
+ tgpig,
6651
+ tiisg,
6652
+ sgitg);
6653
+ }
6654
+
6655
+ [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
6656
+ kernel void kernel_mul_mv_id_iq4_nl_f32(
6657
+ device const char * ids,
6658
+ device const char * src1,
6659
+ device float * dst,
6660
+ constant uint64_t & nbi1,
6661
+ constant int64_t & ne00,
6662
+ constant int64_t & ne01,
6663
+ constant int64_t & ne02,
6664
+ constant uint64_t & nb00,
6665
+ constant uint64_t & nb01,
6666
+ constant uint64_t & nb02,
6667
+ constant int64_t & ne10,
6668
+ constant int64_t & ne11,
6669
+ constant int64_t & ne12,
6670
+ constant int64_t & ne13,
6671
+ constant uint64_t & nb10,
6672
+ constant uint64_t & nb11,
6673
+ constant uint64_t & nb12,
6674
+ constant int64_t & ne0,
6675
+ constant int64_t & ne1,
6676
+ constant uint64_t & nb1,
6677
+ constant uint & r2,
6678
+ constant uint & r3,
6679
+ constant int & idx,
6680
+ device const char * src00,
6681
+ device const char * src01,
6682
+ device const char * src02,
6683
+ device const char * src03,
6684
+ device const char * src04,
6685
+ device const char * src05,
6686
+ device const char * src06,
6687
+ device const char * src07,
6688
+ threadgroup float * shared_values [[threadgroup(0)]],
6689
+ uint3 tgpig[[threadgroup_position_in_grid]],
6690
+ uint tiitg[[thread_index_in_threadgroup]],
6691
+ uint tiisg[[thread_index_in_simdgroup]],
6692
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
6693
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6694
+
6695
+ const int64_t bid = tgpig.z/(ne12*ne13);
6696
+
6697
+ tgpig.z = tgpig.z%(ne12*ne13);
6698
+
6699
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6700
+
6701
+ kernel_mul_mv_iq4_nl_f32_impl(
6702
+ src0[id],
6703
+ (device const float *) (src1 + bid*nb11),
6704
+ dst + bid*ne0,
6705
+ ne00,
6706
+ ne01,
6707
+ ne02,
6708
+ ne10,
6709
+ ne12,
6710
+ ne0,
6711
+ ne1,
6712
+ r2,
6713
+ r3,
6714
+ shared_values,
6715
+ tgpig,
6716
+ tiisg,
6717
+ sgitg);
6718
+ }