llama_cpp 0.14.3 → 0.14.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -13,8 +13,8 @@ using namespace metal;
13
13
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
14
14
 
15
15
  enum ggml_sort_order {
16
- GGML_SORT_ASC,
17
- GGML_SORT_DESC,
16
+ GGML_SORT_ORDER_ASC,
17
+ GGML_SORT_ORDER_DESC,
18
18
  };
19
19
 
20
20
  // general-purpose kernel for addition, multiplication and division of two tensors
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
1973
1973
 
1974
1974
  // bitonic sort implementation following the CUDA kernels as reference
1975
1975
  typedef void (argsort_t)(
1976
- device const float * x,
1977
- device int32_t * dst,
1978
- constant int64_t & ncols,
1976
+ device const float * x,
1977
+ device int32_t * dst,
1978
+ constant int64_t & ncols,
1979
+ constant int64_t & ncols_pad,
1980
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
1979
1981
  uint3 tgpig[[threadgroup_position_in_grid]],
1980
1982
  uint3 tpitg[[thread_position_in_threadgroup]]);
1981
1983
 
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
1984
1986
  device const float * x,
1985
1987
  device int32_t * dst,
1986
1988
  constant int64_t & ncols,
1989
+ constant int64_t & ncols_pad,
1990
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
1987
1991
  uint3 tgpig[[threadgroup_position_in_grid]],
1988
1992
  uint3 tpitg[[thread_position_in_threadgroup]]) {
1989
1993
  // bitonic sort
1990
1994
  int col = tpitg[0];
1991
1995
  int row = tgpig[1];
1992
1996
 
1993
- if (col >= ncols) return;
1997
+ if (col >= ncols_pad) return;
1994
1998
 
1995
- device const float * x_row = x + row * ncols;
1996
- device int32_t * dst_row = dst + row * ncols;
1999
+ device const float * x_row = x + row * ncols;
2000
+ threadgroup int32_t * dst_row = shared_values;
1997
2001
 
1998
2002
  // initialize indices
1999
- if (col < ncols) {
2000
- dst_row[col] = col;
2001
- }
2003
+ dst_row[col] = col;
2004
+
2002
2005
  threadgroup_barrier(mem_flags::mem_threadgroup);
2003
2006
 
2004
- for (int k = 2; k <= ncols; k *= 2) {
2007
+ for (int k = 2; k <= ncols_pad; k *= 2) {
2005
2008
  for (int j = k / 2; j > 0; j /= 2) {
2006
2009
  int ixj = col ^ j;
2007
2010
  if (ixj > col) {
2008
2011
  if ((col & k) == 0) {
2009
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
2012
+ if (dst_row[col] >= ncols ||
2013
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
2014
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
2015
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
2016
+ ) {
2010
2017
  SWAP(dst_row[col], dst_row[ixj]);
2011
2018
  }
2012
2019
  } else {
2013
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
2020
+ if (dst_row[ixj] >= ncols ||
2021
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
2022
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
2023
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
2024
+ ) {
2014
2025
  SWAP(dst_row[col], dst_row[ixj]);
2015
2026
  }
2016
2027
  }
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
2018
2029
  threadgroup_barrier(mem_flags::mem_threadgroup);
2019
2030
  }
2020
2031
  }
2032
+
2033
+ // copy the result to dst without the padding
2034
+ if (col < ncols) {
2035
+ dst[row * ncols + col] = dst_row[col];
2036
+ }
2021
2037
  }
2022
2038
 
2023
- template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
2024
- template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
2039
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
2040
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
2025
2041
 
2026
2042
  kernel void kernel_leaky_relu_f32(
2027
2043
  device const float * src0,
@@ -4456,6 +4472,114 @@ void kernel_mul_mv_iq1_s_f32_impl(
4456
4472
  }
4457
4473
  }
4458
4474
 
4475
+ void kernel_mul_mv_iq1_m_f32_impl(
4476
+ device const void * src0,
4477
+ device const float * src1,
4478
+ device float * dst,
4479
+ constant int64_t & ne00,
4480
+ constant int64_t & ne01,
4481
+ constant int64_t & ne02,
4482
+ constant int64_t & ne10,
4483
+ constant int64_t & ne12,
4484
+ constant int64_t & ne0,
4485
+ constant int64_t & ne1,
4486
+ constant uint & r2,
4487
+ constant uint & r3,
4488
+ uint3 tgpig[[threadgroup_position_in_grid]],
4489
+ uint tiisg[[thread_index_in_simdgroup]],
4490
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4491
+
4492
+ const int nb = ne00/QK_K;
4493
+ const int r0 = tgpig.x;
4494
+ const int r1 = tgpig.y;
4495
+ const int im = tgpig.z;
4496
+
4497
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4498
+ const int ib_row = first_row * nb;
4499
+
4500
+ const uint i12 = im%ne12;
4501
+ const uint i13 = im/ne12;
4502
+
4503
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4504
+ device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
4505
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4506
+
4507
+ float yl[32];
4508
+ float sumf[N_DST]={0.f}, all_sum;
4509
+
4510
+ const int nb32 = nb * (QK_K / 32);
4511
+
4512
+ const int ix = tiisg;
4513
+
4514
+ device const float * y4 = y + 32 * ix;
4515
+
4516
+ #if QK_K != 64
4517
+ iq1m_scale_t scale;
4518
+ #endif
4519
+
4520
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4521
+
4522
+ float4 sumy = {0.f};
4523
+ for (int i = 0; i < 8; ++i) {
4524
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
4525
+ yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
4526
+ yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
4527
+ yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
4528
+ }
4529
+
4530
+ const int ibl = ib32 / (QK_K / 32);
4531
+ const int ib = ib32 % (QK_K / 32);
4532
+
4533
+ device const block_iq1_m * xr = x + ibl;
4534
+ device const uint8_t * qs = xr->qs + 4 * ib;
4535
+ device const uint8_t * qh = xr->qh + 2 * ib;
4536
+ device const uint16_t * sc = (device const uint16_t *)xr->scales;
4537
+
4538
+ for (int row = 0; row < N_DST; row++) {
4539
+
4540
+ #if QK_K != 64
4541
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
4542
+ #endif
4543
+
4544
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
4545
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
4546
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
4547
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
4548
+
4549
+ float2 sum = {0.f};
4550
+ for (int j = 0; j < 4; ++j) {
4551
+ sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
4552
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
4553
+ sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
4554
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
4555
+ }
4556
+ const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
4557
+ const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
4558
+ #if QK_K == 64
4559
+ const float d = (float) *((device const half *)(sc - 1));
4560
+ sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
4561
+ (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
4562
+ #else
4563
+ sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
4564
+ (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
4565
+ #endif
4566
+
4567
+ sc += nb*sizeof(block_iq1_m)/2;
4568
+ qs += nb*sizeof(block_iq1_m);
4569
+ qh += nb*sizeof(block_iq1_m);
4570
+ }
4571
+
4572
+ y4 += 32 * 32;
4573
+ }
4574
+
4575
+ for (int row = 0; row < N_DST; ++row) {
4576
+ all_sum = simd_sum(sumf[row]);
4577
+ if (tiisg == 0) {
4578
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4579
+ }
4580
+ }
4581
+ }
4582
+
4459
4583
  void kernel_mul_mv_iq4_nl_f32_impl(
4460
4584
  device const void * src0,
4461
4585
  device const float * src1,
@@ -4673,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
4673
4797
  kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4674
4798
  }
4675
4799
 
4800
+ [[host_name("kernel_mul_mv_iq1_m_f32")]]
4801
+ kernel void kernel_mul_mv_iq1_m_f32(
4802
+ device const void * src0,
4803
+ device const float * src1,
4804
+ device float * dst,
4805
+ constant int64_t & ne00,
4806
+ constant int64_t & ne01,
4807
+ constant int64_t & ne02,
4808
+ constant uint64_t & nb00,
4809
+ constant uint64_t & nb01,
4810
+ constant uint64_t & nb02,
4811
+ constant int64_t & ne10,
4812
+ constant int64_t & ne11,
4813
+ constant int64_t & ne12,
4814
+ constant uint64_t & nb10,
4815
+ constant uint64_t & nb11,
4816
+ constant uint64_t & nb12,
4817
+ constant int64_t & ne0,
4818
+ constant int64_t & ne1,
4819
+ constant uint & r2,
4820
+ constant uint & r3,
4821
+ uint3 tgpig[[threadgroup_position_in_grid]],
4822
+ uint tiisg[[thread_index_in_simdgroup]],
4823
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4824
+
4825
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4826
+ }
4827
+
4676
4828
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
4677
4829
  kernel void kernel_mul_mv_iq4_nl_f32(
4678
4830
  device const void * src0,
@@ -5146,6 +5298,38 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
5146
5298
  }
5147
5299
  }
5148
5300
 
5301
+ template <typename type4x4>
5302
+ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
5303
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5304
+ const int ib32 = il/2;
5305
+ il = il%2;
5306
+ device const uint16_t * sc = (device const uint16_t *)xb->scales;
5307
+ #if QK_K == 64
5308
+ const float d = xb->d;
5309
+ #else
5310
+ iq1m_scale_t scale;
5311
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5312
+ const float d = scale.f16;
5313
+ #endif
5314
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5315
+ device const uint8_t * qh = xb->qh + 2*ib32 + il;
5316
+ #if QK_K == 64
5317
+ const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
5318
+ #else
5319
+ const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
5320
+ #endif
5321
+ const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5322
+ const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5323
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5324
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
5325
+ for (int i = 0; i < 4; ++i) {
5326
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
5327
+ reg[1][i] = dl * (grid1[i] >> 4) + ml1;
5328
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
5329
+ reg[3][i] = dl * (grid2[i] >> 4) + ml2;
5330
+ }
5331
+ }
5332
+
5149
5333
  template <typename type4x4>
5150
5334
  void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
5151
5335
  device const uint16_t * q4 = (device const uint16_t *)xb->qs;
@@ -5617,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
5617
5801
 
5618
5802
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5619
5803
  kernel void kernel_mul_mm_id(
5620
- device const uchar * ids,
5804
+ device const uchar * src0s,
5621
5805
  device const uchar * src1,
5622
5806
  device float * dst,
5807
+ device const uchar * ids,
5623
5808
  constant uint64_t & nbi1,
5624
5809
  constant int64_t & ne00,
5625
5810
  constant int64_t & ne02,
@@ -5636,22 +5821,14 @@ kernel void kernel_mul_mm_id(
5636
5821
  constant uint & r2,
5637
5822
  constant uint & r3,
5638
5823
  constant int & idx,
5639
- device const uchar * src00,
5640
- device const uchar * src01,
5641
- device const uchar * src02,
5642
- device const uchar * src03,
5643
- device const uchar * src04,
5644
- device const uchar * src05,
5645
- device const uchar * src06,
5646
- device const uchar * src07,
5647
5824
  threadgroup uchar * shared_memory [[threadgroup(0)]],
5648
5825
  uint3 tgpig[[threadgroup_position_in_grid]],
5649
5826
  uint tiitg[[thread_index_in_threadgroup]],
5650
5827
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5651
- device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5652
5828
 
5653
5829
  // expert id
5654
5830
  const int32_t id = tgpig.z/(ne12*ne13);
5831
+ device const uchar * src0 = src0s + id*nb02;
5655
5832
 
5656
5833
  tgpig.z = tgpig.z%(ne12*ne13);
5657
5834
 
@@ -5666,7 +5843,7 @@ kernel void kernel_mul_mm_id(
5666
5843
  }
5667
5844
 
5668
5845
  kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5669
- src0s[id],
5846
+ src0,
5670
5847
  src1,
5671
5848
  src1ids,
5672
5849
  dst,
@@ -5730,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
5730
5907
  template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
5731
5908
  template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
5732
5909
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5910
+ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
5733
5911
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
5734
5912
  #if QK_K == 64
5735
5913
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
@@ -5778,6 +5956,7 @@ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_m
5778
5956
  template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
5779
5957
  template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
5780
5958
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5959
+ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
5781
5960
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
5782
5961
  #if QK_K == 64
5783
5962
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
@@ -5790,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
5790
5969
  //
5791
5970
 
5792
5971
  typedef void (mat_mm_id_t)(
5793
- device const uchar * ids,
5972
+ device const uchar * src0s,
5794
5973
  device const uchar * src1,
5795
5974
  device float * dst,
5975
+ device const uchar * ids,
5796
5976
  constant uint64_t & nbi1,
5797
5977
  constant int64_t & ne00,
5798
5978
  constant int64_t & ne02,
@@ -5809,14 +5989,6 @@ typedef void (mat_mm_id_t)(
5809
5989
  constant uint & r2,
5810
5990
  constant uint & r3,
5811
5991
  constant int & idx,
5812
- device const uchar * src00,
5813
- device const uchar * src01,
5814
- device const uchar * src02,
5815
- device const uchar * src03,
5816
- device const uchar * src04,
5817
- device const uchar * src05,
5818
- device const uchar * src06,
5819
- device const uchar * src07,
5820
5992
  threadgroup uchar *,
5821
5993
  uint3, uint, uint);
5822
5994
 
@@ -5838,6 +6010,7 @@ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel
5838
6010
  template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
5839
6011
  template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
5840
6012
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6013
+ template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
5841
6014
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
5842
6015
  #if QK_K == 64
5843
6016
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
@@ -5851,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
5851
6024
 
5852
6025
  [[host_name("kernel_mul_mv_id_f32_f32")]]
5853
6026
  kernel void kernel_mul_mv_id_f32_f32(
5854
- device const char * ids,
6027
+ device const char * src0s,
5855
6028
  device const char * src1,
5856
6029
  device float * dst,
6030
+ device const char * ids,
5857
6031
  constant uint64_t & nbi1,
5858
6032
  constant int64_t & ne00,
5859
6033
  constant int64_t & ne01,
@@ -5874,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
5874
6048
  constant uint & r2,
5875
6049
  constant uint & r3,
5876
6050
  constant int & idx,
5877
- device const char * src00,
5878
- device const char * src01,
5879
- device const char * src02,
5880
- device const char * src03,
5881
- device const char * src04,
5882
- device const char * src05,
5883
- device const char * src06,
5884
- device const char * src07,
5885
6051
  uint3 tgpig[[threadgroup_position_in_grid]],
5886
6052
  uint tiitg[[thread_index_in_threadgroup]],
5887
6053
  uint tiisg[[thread_index_in_simdgroup]],
5888
6054
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5889
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5890
-
5891
6055
  const int64_t bid = tgpig.z/(ne12*ne13);
5892
6056
 
5893
6057
  tgpig.z = tgpig.z%(ne12*ne13);
5894
6058
 
5895
6059
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6060
+ device const char * src0 = src0s + id*nb02;
5896
6061
 
5897
6062
  kernel_mul_mv_f32_f32_impl(
5898
- src0[id],
6063
+ src0,
5899
6064
  src1 + bid*nb11,
5900
6065
  dst + bid*ne0,
5901
6066
  ne00,
@@ -5920,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
5920
6085
 
5921
6086
  [[host_name("kernel_mul_mv_id_f16_f32")]]
5922
6087
  kernel void kernel_mul_mv_id_f16_f32(
5923
- device const char * ids,
6088
+ device const char * src0s,
5924
6089
  device const char * src1,
5925
6090
  device float * dst,
6091
+ device const char * ids,
5926
6092
  constant uint64_t & nbi1,
5927
6093
  constant int64_t & ne00,
5928
6094
  constant int64_t & ne01,
@@ -5943,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
5943
6109
  constant uint & r2,
5944
6110
  constant uint & r3,
5945
6111
  constant int & idx,
5946
- device const char * src00,
5947
- device const char * src01,
5948
- device const char * src02,
5949
- device const char * src03,
5950
- device const char * src04,
5951
- device const char * src05,
5952
- device const char * src06,
5953
- device const char * src07,
5954
6112
  uint3 tgpig[[threadgroup_position_in_grid]],
5955
6113
  uint tiitg[[thread_index_in_threadgroup]],
5956
6114
  uint tiisg[[thread_index_in_simdgroup]],
5957
6115
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5958
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5959
-
5960
6116
  const int64_t bid = tgpig.z/(ne12*ne13);
5961
6117
 
5962
6118
  tgpig.z = tgpig.z%(ne12*ne13);
5963
6119
 
5964
6120
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6121
+ device const char * src0 = src0s + id*nb02;
5965
6122
 
5966
6123
  kernel_mul_mv_f16_f32_impl(
5967
- src0[id],
6124
+ src0,
5968
6125
  src1 + bid*nb11,
5969
6126
  dst + bid*ne0,
5970
6127
  ne00,
@@ -5989,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
5989
6146
 
5990
6147
  [[host_name("kernel_mul_mv_id_q8_0_f32")]]
5991
6148
  kernel void kernel_mul_mv_id_q8_0_f32(
5992
- device const char * ids,
6149
+ device const char * src0s,
5993
6150
  device const char * src1,
5994
6151
  device float * dst,
6152
+ device const char * ids,
5995
6153
  constant uint64_t & nbi1,
5996
6154
  constant int64_t & ne00,
5997
6155
  constant int64_t & ne01,
@@ -6012,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
6012
6170
  constant uint & r2,
6013
6171
  constant uint & r3,
6014
6172
  constant int & idx,
6015
- device const char * src00,
6016
- device const char * src01,
6017
- device const char * src02,
6018
- device const char * src03,
6019
- device const char * src04,
6020
- device const char * src05,
6021
- device const char * src06,
6022
- device const char * src07,
6023
6173
  uint3 tgpig[[threadgroup_position_in_grid]],
6024
6174
  uint tiitg[[thread_index_in_threadgroup]],
6025
6175
  uint tiisg[[thread_index_in_simdgroup]],
6026
6176
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6027
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6028
-
6029
6177
  const int64_t bid = tgpig.z/(ne12*ne13);
6030
6178
 
6031
6179
  tgpig.z = tgpig.z%(ne12*ne13);
6032
6180
 
6033
6181
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6182
+ device const char * src0 = src0s + id*nb02;
6034
6183
 
6035
6184
  kernel_mul_mv_q8_0_f32_impl(
6036
- src0[id],
6185
+ src0,
6037
6186
  (device const float *) (src1 + bid*nb11),
6038
6187
  dst + bid*ne0,
6039
6188
  ne00,
@@ -6052,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
6052
6201
 
6053
6202
  [[host_name("kernel_mul_mv_id_q4_0_f32")]]
6054
6203
  kernel void kernel_mul_mv_id_q4_0_f32(
6055
- device const char * ids,
6204
+ device const char * src0s,
6056
6205
  device const char * src1,
6057
6206
  device float * dst,
6207
+ device const char * ids,
6058
6208
  constant uint64_t & nbi1,
6059
6209
  constant int64_t & ne00,
6060
6210
  constant int64_t & ne01,
@@ -6075,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
6075
6225
  constant uint & r2,
6076
6226
  constant uint & r3,
6077
6227
  constant int & idx,
6078
- device const char * src00,
6079
- device const char * src01,
6080
- device const char * src02,
6081
- device const char * src03,
6082
- device const char * src04,
6083
- device const char * src05,
6084
- device const char * src06,
6085
- device const char * src07,
6086
6228
  uint3 tgpig[[threadgroup_position_in_grid]],
6087
6229
  uint tiitg[[thread_index_in_threadgroup]],
6088
6230
  uint tiisg[[thread_index_in_simdgroup]],
6089
6231
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6090
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6091
-
6092
6232
  const int64_t bid = tgpig.z/(ne12*ne13);
6093
6233
 
6094
6234
  tgpig.z = tgpig.z%(ne12*ne13);
6095
6235
 
6096
6236
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6237
+ device const char * src0 = src0s + id*nb02;
6097
6238
 
6098
6239
  mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6099
- src0[id],
6240
+ src0,
6100
6241
  (device const float *) (src1 + bid*nb11),
6101
6242
  dst + bid*ne0,
6102
6243
  ne00,
@@ -6115,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
6115
6256
 
6116
6257
  [[host_name("kernel_mul_mv_id_q4_1_f32")]]
6117
6258
  kernel void kernel_mul_mv_id_q4_1_f32(
6118
- device const char * ids,
6259
+ device const char * src0s,
6119
6260
  device const char * src1,
6120
6261
  device float * dst,
6262
+ device const char * ids,
6121
6263
  constant uint64_t & nbi1,
6122
6264
  constant int64_t & ne00,
6123
6265
  constant int64_t & ne01,
@@ -6138,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
6138
6280
  constant uint & r2,
6139
6281
  constant uint & r3,
6140
6282
  constant int & idx,
6141
- device const char * src00,
6142
- device const char * src01,
6143
- device const char * src02,
6144
- device const char * src03,
6145
- device const char * src04,
6146
- device const char * src05,
6147
- device const char * src06,
6148
- device const char * src07,
6149
6283
  uint3 tgpig[[threadgroup_position_in_grid]],
6150
6284
  uint tiitg[[thread_index_in_threadgroup]],
6151
6285
  uint tiisg[[thread_index_in_simdgroup]],
6152
6286
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6153
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6154
-
6155
6287
  const int64_t bid = tgpig.z/(ne12*ne13);
6156
6288
 
6157
6289
  tgpig.z = tgpig.z%(ne12*ne13);
6158
6290
 
6159
6291
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6292
+ device const char * src0 = src0s + id*nb02;
6160
6293
 
6161
6294
  mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6162
- src0[id],
6295
+ src0,
6163
6296
  (device const float *) (src1 + bid*nb11),
6164
6297
  dst + bid*ne0,
6165
6298
  ne00,
@@ -6178,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
6178
6311
 
6179
6312
  [[host_name("kernel_mul_mv_id_q5_0_f32")]]
6180
6313
  kernel void kernel_mul_mv_id_q5_0_f32(
6181
- device const char * ids,
6314
+ device const char * src0s,
6182
6315
  device const char * src1,
6183
6316
  device float * dst,
6317
+ device const char * ids,
6184
6318
  constant uint64_t & nbi1,
6185
6319
  constant int64_t & ne00,
6186
6320
  constant int64_t & ne01,
@@ -6201,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
6201
6335
  constant uint & r2,
6202
6336
  constant uint & r3,
6203
6337
  constant int & idx,
6204
- device const char * src00,
6205
- device const char * src01,
6206
- device const char * src02,
6207
- device const char * src03,
6208
- device const char * src04,
6209
- device const char * src05,
6210
- device const char * src06,
6211
- device const char * src07,
6212
6338
  uint3 tgpig[[threadgroup_position_in_grid]],
6213
6339
  uint tiitg[[thread_index_in_threadgroup]],
6214
6340
  uint tiisg[[thread_index_in_simdgroup]],
6215
6341
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6216
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6217
-
6218
6342
  const int64_t bid = tgpig.z/(ne12*ne13);
6219
6343
 
6220
6344
  tgpig.z = tgpig.z%(ne12*ne13);
6221
6345
 
6222
6346
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6347
+ device const char * src0 = src0s + id*nb02;
6223
6348
 
6224
6349
  mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6225
- src0[id],
6350
+ src0,
6226
6351
  (device const float *) (src1 + bid*nb11),
6227
6352
  dst + bid*ne0,
6228
6353
  ne00,
@@ -6241,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
6241
6366
 
6242
6367
  [[host_name("kernel_mul_mv_id_q5_1_f32")]]
6243
6368
  kernel void kernel_mul_mv_id_q5_1_f32(
6244
- device const char * ids,
6369
+ device const char * src0s,
6245
6370
  device const char * src1,
6246
6371
  device float * dst,
6372
+ device const char * ids,
6247
6373
  constant uint64_t & nbi1,
6248
6374
  constant int64_t & ne00,
6249
6375
  constant int64_t & ne01,
@@ -6264,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
6264
6390
  constant uint & r2,
6265
6391
  constant uint & r3,
6266
6392
  constant int & idx,
6267
- device const char * src00,
6268
- device const char * src01,
6269
- device const char * src02,
6270
- device const char * src03,
6271
- device const char * src04,
6272
- device const char * src05,
6273
- device const char * src06,
6274
- device const char * src07,
6275
6393
  uint3 tgpig[[threadgroup_position_in_grid]],
6276
6394
  uint tiitg[[thread_index_in_threadgroup]],
6277
6395
  uint tiisg[[thread_index_in_simdgroup]],
6278
6396
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6279
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6280
-
6281
6397
  const int64_t bid = tgpig.z/(ne12*ne13);
6282
6398
 
6283
6399
  tgpig.z = tgpig.z%(ne12*ne13);
6284
6400
 
6285
6401
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6402
+ device const char * src0 = src0s + id*nb02;
6286
6403
 
6287
6404
  mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6288
- src0[id],
6405
+ src0,
6289
6406
  (device const float *) (src1 + bid*nb11),
6290
6407
  dst + bid*ne0,
6291
6408
  ne00,
@@ -6304,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
6304
6421
 
6305
6422
  [[host_name("kernel_mul_mv_id_q2_K_f32")]]
6306
6423
  kernel void kernel_mul_mv_id_q2_K_f32(
6307
- device const char * ids,
6424
+ device const char * src0s,
6308
6425
  device const char * src1,
6309
6426
  device float * dst,
6427
+ device const char * ids,
6310
6428
  constant uint64_t & nbi1,
6311
6429
  constant int64_t & ne00,
6312
6430
  constant int64_t & ne01,
@@ -6327,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
6327
6445
  constant uint & r2,
6328
6446
  constant uint & r3,
6329
6447
  constant int & idx,
6330
- device const char * src00,
6331
- device const char * src01,
6332
- device const char * src02,
6333
- device const char * src03,
6334
- device const char * src04,
6335
- device const char * src05,
6336
- device const char * src06,
6337
- device const char * src07,
6338
6448
  uint3 tgpig[[threadgroup_position_in_grid]],
6339
6449
  uint tiitg[[thread_index_in_threadgroup]],
6340
6450
  uint tiisg[[thread_index_in_simdgroup]],
6341
6451
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6342
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6343
-
6344
6452
  const int64_t bid = tgpig.z/(ne12*ne13);
6345
6453
 
6346
6454
  tgpig.z = tgpig.z%(ne12*ne13);
6347
6455
 
6348
6456
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6457
+ device const char * src0 = src0s + id*nb02;
6349
6458
 
6350
6459
  kernel_mul_mv_q2_K_f32_impl(
6351
- src0[id],
6460
+ src0,
6352
6461
  (device const float *) (src1 + bid*nb11),
6353
6462
  dst + bid*ne0,
6354
6463
  ne00,
@@ -6367,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
6367
6476
 
6368
6477
  [[host_name("kernel_mul_mv_id_q3_K_f32")]]
6369
6478
  kernel void kernel_mul_mv_id_q3_K_f32(
6370
- device const char * ids,
6479
+ device const char * src0s,
6371
6480
  device const char * src1,
6372
6481
  device float * dst,
6482
+ device const char * ids,
6373
6483
  constant uint64_t & nbi1,
6374
6484
  constant int64_t & ne00,
6375
6485
  constant int64_t & ne01,
@@ -6390,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
6390
6500
  constant uint & r2,
6391
6501
  constant uint & r3,
6392
6502
  constant int & idx,
6393
- device const char * src00,
6394
- device const char * src01,
6395
- device const char * src02,
6396
- device const char * src03,
6397
- device const char * src04,
6398
- device const char * src05,
6399
- device const char * src06,
6400
- device const char * src07,
6401
6503
  uint3 tgpig[[threadgroup_position_in_grid]],
6402
6504
  uint tiitg[[thread_index_in_threadgroup]],
6403
6505
  uint tiisg[[thread_index_in_simdgroup]],
6404
6506
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6405
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6406
-
6407
6507
  const int64_t bid = tgpig.z/(ne12*ne13);
6408
6508
 
6409
6509
  tgpig.z = tgpig.z%(ne12*ne13);
6410
6510
 
6411
6511
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6512
+ device const char * src0 = src0s + id*nb02;
6412
6513
 
6413
6514
  kernel_mul_mv_q3_K_f32_impl(
6414
- src0[id],
6515
+ src0,
6415
6516
  (device const float *) (src1 + bid*nb11),
6416
6517
  dst + bid*ne0,
6417
6518
  ne00,
@@ -6430,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
6430
6531
 
6431
6532
  [[host_name("kernel_mul_mv_id_q4_K_f32")]]
6432
6533
  kernel void kernel_mul_mv_id_q4_K_f32(
6433
- device const char * ids,
6534
+ device const char * src0s,
6434
6535
  device const char * src1,
6435
6536
  device float * dst,
6537
+ device const char * ids,
6436
6538
  constant uint64_t & nbi1,
6437
6539
  constant int64_t & ne00,
6438
6540
  constant int64_t & ne01,
@@ -6453,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
6453
6555
  constant uint & r2,
6454
6556
  constant uint & r3,
6455
6557
  constant int & idx,
6456
- device const char * src00,
6457
- device const char * src01,
6458
- device const char * src02,
6459
- device const char * src03,
6460
- device const char * src04,
6461
- device const char * src05,
6462
- device const char * src06,
6463
- device const char * src07,
6464
6558
  uint3 tgpig[[threadgroup_position_in_grid]],
6465
6559
  uint tiitg[[thread_index_in_threadgroup]],
6466
6560
  uint tiisg[[thread_index_in_simdgroup]],
6467
6561
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6468
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6469
-
6470
6562
  const int64_t bid = tgpig.z/(ne12*ne13);
6471
6563
 
6472
6564
  tgpig.z = tgpig.z%(ne12*ne13);
6473
6565
 
6474
6566
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6567
+ device const char * src0 = src0s + id*nb02;
6475
6568
 
6476
6569
  kernel_mul_mv_q4_K_f32_impl(
6477
- src0[id],
6570
+ src0,
6478
6571
  (device const float *) (src1 + bid*nb11),
6479
6572
  dst + bid*ne0,
6480
6573
  ne00,
@@ -6493,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
6493
6586
 
6494
6587
  [[host_name("kernel_mul_mv_id_q5_K_f32")]]
6495
6588
  kernel void kernel_mul_mv_id_q5_K_f32(
6496
- device const char * ids,
6589
+ device const char * src0s,
6497
6590
  device const char * src1,
6498
6591
  device float * dst,
6592
+ device const char * ids,
6499
6593
  constant uint64_t & nbi1,
6500
6594
  constant int64_t & ne00,
6501
6595
  constant int64_t & ne01,
@@ -6516,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
6516
6610
  constant uint & r2,
6517
6611
  constant uint & r3,
6518
6612
  constant int & idx,
6519
- device const char * src00,
6520
- device const char * src01,
6521
- device const char * src02,
6522
- device const char * src03,
6523
- device const char * src04,
6524
- device const char * src05,
6525
- device const char * src06,
6526
- device const char * src07,
6527
6613
  uint3 tgpig[[threadgroup_position_in_grid]],
6528
6614
  uint tiitg[[thread_index_in_threadgroup]],
6529
6615
  uint tiisg[[thread_index_in_simdgroup]],
6530
6616
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6531
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6532
-
6533
6617
  const int64_t bid = tgpig.z/(ne12*ne13);
6534
6618
 
6535
6619
  tgpig.z = tgpig.z%(ne12*ne13);
6536
6620
 
6537
6621
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6622
+ device const char * src0 = src0s + id*nb02;
6538
6623
 
6539
6624
  kernel_mul_mv_q5_K_f32_impl(
6540
- src0[id],
6625
+ src0,
6541
6626
  (device const float *) (src1 + bid*nb11),
6542
6627
  dst + bid*ne0,
6543
6628
  ne00,
@@ -6556,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
6556
6641
 
6557
6642
  [[host_name("kernel_mul_mv_id_q6_K_f32")]]
6558
6643
  kernel void kernel_mul_mv_id_q6_K_f32(
6559
- device const char * ids,
6644
+ device const char * src0s,
6560
6645
  device const char * src1,
6561
6646
  device float * dst,
6647
+ device const char * ids,
6562
6648
  constant uint64_t & nbi1,
6563
6649
  constant int64_t & ne00,
6564
6650
  constant int64_t & ne01,
@@ -6579,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6579
6665
  constant uint & r2,
6580
6666
  constant uint & r3,
6581
6667
  constant int & idx,
6582
- device const char * src00,
6583
- device const char * src01,
6584
- device const char * src02,
6585
- device const char * src03,
6586
- device const char * src04,
6587
- device const char * src05,
6588
- device const char * src06,
6589
- device const char * src07,
6590
6668
  uint3 tgpig[[threadgroup_position_in_grid]],
6591
6669
  uint tiitg[[thread_index_in_threadgroup]],
6592
6670
  uint tiisg[[thread_index_in_simdgroup]],
6593
6671
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6594
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6595
-
6596
6672
  const int64_t bid = tgpig.z/(ne12*ne13);
6597
6673
 
6598
6674
  tgpig.z = tgpig.z%(ne12*ne13);
6599
6675
 
6600
6676
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6677
+ device const char * src0 = src0s + id*nb02;
6601
6678
 
6602
6679
  kernel_mul_mv_q6_K_f32_impl(
6603
- src0[id],
6680
+ src0,
6604
6681
  (device const float *) (src1 + bid*nb11),
6605
6682
  dst + bid*ne0,
6606
6683
  ne00,
@@ -6619,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6619
6696
 
6620
6697
  [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
6621
6698
  kernel void kernel_mul_mv_id_iq2_xxs_f32(
6622
- device const char * ids,
6699
+ device const char * src0s,
6623
6700
  device const char * src1,
6624
6701
  device float * dst,
6702
+ device const char * ids,
6625
6703
  constant uint64_t & nbi1,
6626
6704
  constant int64_t & ne00,
6627
6705
  constant int64_t & ne01,
@@ -6642,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6642
6720
  constant uint & r2,
6643
6721
  constant uint & r3,
6644
6722
  constant int & idx,
6645
- device const char * src00,
6646
- device const char * src01,
6647
- device const char * src02,
6648
- device const char * src03,
6649
- device const char * src04,
6650
- device const char * src05,
6651
- device const char * src06,
6652
- device const char * src07,
6653
6723
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6654
6724
  uint3 tgpig[[threadgroup_position_in_grid]],
6655
6725
  uint tiitg[[thread_index_in_threadgroup]],
6656
6726
  uint tiisg[[thread_index_in_simdgroup]],
6657
6727
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6658
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6659
-
6660
6728
  const int64_t bid = tgpig.z/(ne12*ne13);
6661
6729
 
6662
6730
  tgpig.z = tgpig.z%(ne12*ne13);
6663
6731
 
6664
6732
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6733
+ device const char * src0 = src0s + id*nb02;
6665
6734
 
6666
6735
  kernel_mul_mv_iq2_xxs_f32_impl(
6667
- src0[id],
6736
+ src0,
6668
6737
  (device const float *) (src1 + bid*nb11),
6669
6738
  dst + bid*ne0,
6670
6739
  ne00,
@@ -6684,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6684
6753
 
6685
6754
  [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
6686
6755
  kernel void kernel_mul_mv_id_iq2_xs_f32(
6687
- device const char * ids,
6756
+ device const char * src0s,
6688
6757
  device const char * src1,
6689
6758
  device float * dst,
6759
+ device const char * ids,
6690
6760
  constant uint64_t & nbi1,
6691
6761
  constant int64_t & ne00,
6692
6762
  constant int64_t & ne01,
@@ -6707,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6707
6777
  constant uint & r2,
6708
6778
  constant uint & r3,
6709
6779
  constant int & idx,
6710
- device const char * src00,
6711
- device const char * src01,
6712
- device const char * src02,
6713
- device const char * src03,
6714
- device const char * src04,
6715
- device const char * src05,
6716
- device const char * src06,
6717
- device const char * src07,
6718
6780
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6719
6781
  uint3 tgpig[[threadgroup_position_in_grid]],
6720
6782
  uint tiitg[[thread_index_in_threadgroup]],
6721
6783
  uint tiisg[[thread_index_in_simdgroup]],
6722
6784
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6723
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6724
-
6725
6785
  const int64_t bid = tgpig.z/(ne12*ne13);
6726
6786
 
6727
6787
  tgpig.z = tgpig.z%(ne12*ne13);
6728
6788
 
6729
6789
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6790
+ device const char * src0 = src0s + id*nb02;
6730
6791
 
6731
6792
  kernel_mul_mv_iq2_xs_f32_impl(
6732
- src0[id],
6793
+ src0,
6733
6794
  (device const float *) (src1 + bid*nb11),
6734
6795
  dst + bid*ne0,
6735
6796
  ne00,
@@ -6749,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6749
6810
 
6750
6811
  [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6751
6812
  kernel void kernel_mul_mv_id_iq3_xxs_f32(
6752
- device const char * ids,
6813
+ device const char * src0s,
6753
6814
  device const char * src1,
6754
6815
  device float * dst,
6816
+ device const char * ids,
6755
6817
  constant uint64_t & nbi1,
6756
6818
  constant int64_t & ne00,
6757
6819
  constant int64_t & ne01,
@@ -6772,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6772
6834
  constant uint & r2,
6773
6835
  constant uint & r3,
6774
6836
  constant int & idx,
6775
- device const char * src00,
6776
- device const char * src01,
6777
- device const char * src02,
6778
- device const char * src03,
6779
- device const char * src04,
6780
- device const char * src05,
6781
- device const char * src06,
6782
- device const char * src07,
6783
6837
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6784
6838
  uint3 tgpig[[threadgroup_position_in_grid]],
6785
6839
  uint tiitg[[thread_index_in_threadgroup]],
6786
6840
  uint tiisg[[thread_index_in_simdgroup]],
6787
6841
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6788
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6789
-
6790
6842
  const int64_t bid = tgpig.z/(ne12*ne13);
6791
6843
 
6792
6844
  tgpig.z = tgpig.z%(ne12*ne13);
6793
6845
 
6794
6846
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6847
+ device const char * src0 = src0s + id*nb02;
6795
6848
 
6796
6849
  kernel_mul_mv_iq3_xxs_f32_impl(
6797
- src0[id],
6850
+ src0,
6798
6851
  (device const float *) (src1 + bid*nb11),
6799
6852
  dst + bid*ne0,
6800
6853
  ne00,
@@ -6814,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6814
6867
 
6815
6868
  [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
6816
6869
  kernel void kernel_mul_mv_id_iq3_s_f32(
6817
- device const char * ids,
6870
+ device const char * src0s,
6818
6871
  device const char * src1,
6819
6872
  device float * dst,
6873
+ device const char * ids,
6820
6874
  constant uint64_t & nbi1,
6821
6875
  constant int64_t & ne00,
6822
6876
  constant int64_t & ne01,
@@ -6837,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
6837
6891
  constant uint & r2,
6838
6892
  constant uint & r3,
6839
6893
  constant int & idx,
6840
- device const char * src00,
6841
- device const char * src01,
6842
- device const char * src02,
6843
- device const char * src03,
6844
- device const char * src04,
6845
- device const char * src05,
6846
- device const char * src06,
6847
- device const char * src07,
6848
6894
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6849
6895
  uint3 tgpig[[threadgroup_position_in_grid]],
6850
6896
  uint tiitg[[thread_index_in_threadgroup]],
6851
6897
  uint tiisg[[thread_index_in_simdgroup]],
6852
6898
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6853
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6854
-
6855
6899
  const int64_t bid = tgpig.z/(ne12*ne13);
6856
6900
 
6857
6901
  tgpig.z = tgpig.z%(ne12*ne13);
6858
6902
 
6859
6903
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6904
+ device const char * src0 = src0s + id*nb02;
6860
6905
 
6861
6906
  kernel_mul_mv_iq3_s_f32_impl(
6862
- src0[id],
6907
+ src0,
6863
6908
  (device const float *) (src1 + bid*nb11),
6864
6909
  dst + bid*ne0,
6865
6910
  ne00,
@@ -6879,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
6879
6924
 
6880
6925
  [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
6881
6926
  kernel void kernel_mul_mv_id_iq2_s_f32(
6882
- device const char * ids,
6927
+ device const char * src0s,
6883
6928
  device const char * src1,
6884
6929
  device float * dst,
6930
+ device const char * ids,
6885
6931
  constant uint64_t & nbi1,
6886
6932
  constant int64_t & ne00,
6887
6933
  constant int64_t & ne01,
@@ -6902,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
6902
6948
  constant uint & r2,
6903
6949
  constant uint & r3,
6904
6950
  constant int & idx,
6905
- device const char * src00,
6906
- device const char * src01,
6907
- device const char * src02,
6908
- device const char * src03,
6909
- device const char * src04,
6910
- device const char * src05,
6911
- device const char * src06,
6912
- device const char * src07,
6913
6951
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6914
6952
  uint3 tgpig[[threadgroup_position_in_grid]],
6915
6953
  uint tiitg[[thread_index_in_threadgroup]],
6916
6954
  uint tiisg[[thread_index_in_simdgroup]],
6917
6955
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6918
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6919
-
6920
6956
  const int64_t bid = tgpig.z/(ne12*ne13);
6921
6957
 
6922
6958
  tgpig.z = tgpig.z%(ne12*ne13);
6923
6959
 
6924
6960
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6961
+ device const char * src0 = src0s + id*nb02;
6925
6962
 
6926
6963
  kernel_mul_mv_iq2_s_f32_impl(
6927
- src0[id],
6964
+ src0,
6928
6965
  (device const float *) (src1 + bid*nb11),
6929
6966
  dst + bid*ne0,
6930
6967
  ne00,
@@ -6944,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
6944
6981
 
6945
6982
  [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6946
6983
  kernel void kernel_mul_mv_id_iq1_s_f32(
6947
- device const char * ids,
6984
+ device const char * src0s,
6948
6985
  device const char * src1,
6949
6986
  device float * dst,
6987
+ device const char * ids,
6950
6988
  constant uint64_t & nbi1,
6951
6989
  constant int64_t & ne00,
6952
6990
  constant int64_t & ne01,
@@ -6967,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
6967
7005
  constant uint & r2,
6968
7006
  constant uint & r3,
6969
7007
  constant int & idx,
6970
- device const char * src00,
6971
- device const char * src01,
6972
- device const char * src02,
6973
- device const char * src03,
6974
- device const char * src04,
6975
- device const char * src05,
6976
- device const char * src06,
6977
- device const char * src07,
6978
7008
  uint3 tgpig[[threadgroup_position_in_grid]],
6979
7009
  uint tiitg[[thread_index_in_threadgroup]],
6980
7010
  uint tiisg[[thread_index_in_simdgroup]],
6981
7011
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6982
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6983
-
6984
7012
  const int64_t bid = tgpig.z/(ne12*ne13);
6985
7013
 
6986
7014
  tgpig.z = tgpig.z%(ne12*ne13);
6987
7015
 
6988
7016
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7017
+ device const char * src0 = src0s + id*nb02;
6989
7018
 
6990
7019
  kernel_mul_mv_iq1_s_f32_impl(
6991
- src0[id],
7020
+ src0,
7021
+ (device const float *) (src1 + bid*nb11),
7022
+ dst + bid*ne0,
7023
+ ne00,
7024
+ ne01,
7025
+ ne02,
7026
+ ne10,
7027
+ ne12,
7028
+ ne0,
7029
+ ne1,
7030
+ r2,
7031
+ r3,
7032
+ tgpig,
7033
+ tiisg,
7034
+ sgitg);
7035
+ }
7036
+
7037
+ [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
7038
+ kernel void kernel_mul_mv_id_iq1_m_f32(
7039
+ device const char * src0s,
7040
+ device const char * src1,
7041
+ device float * dst,
7042
+ device const char * ids,
7043
+ constant uint64_t & nbi1,
7044
+ constant int64_t & ne00,
7045
+ constant int64_t & ne01,
7046
+ constant int64_t & ne02,
7047
+ constant uint64_t & nb00,
7048
+ constant uint64_t & nb01,
7049
+ constant uint64_t & nb02,
7050
+ constant int64_t & ne10,
7051
+ constant int64_t & ne11,
7052
+ constant int64_t & ne12,
7053
+ constant int64_t & ne13,
7054
+ constant uint64_t & nb10,
7055
+ constant uint64_t & nb11,
7056
+ constant uint64_t & nb12,
7057
+ constant int64_t & ne0,
7058
+ constant int64_t & ne1,
7059
+ constant uint64_t & nb1,
7060
+ constant uint & r2,
7061
+ constant uint & r3,
7062
+ constant int & idx,
7063
+ uint3 tgpig[[threadgroup_position_in_grid]],
7064
+ uint tiitg[[thread_index_in_threadgroup]],
7065
+ uint tiisg[[thread_index_in_simdgroup]],
7066
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
7067
+ const int64_t bid = tgpig.z/(ne12*ne13);
7068
+
7069
+ tgpig.z = tgpig.z%(ne12*ne13);
7070
+
7071
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7072
+ device const char * src0 = src0s + id*nb02;
7073
+
7074
+ kernel_mul_mv_iq1_m_f32_impl(
7075
+ src0,
6992
7076
  (device const float *) (src1 + bid*nb11),
6993
7077
  dst + bid*ne0,
6994
7078
  ne00,
@@ -7007,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
7007
7091
 
7008
7092
  [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
7009
7093
  kernel void kernel_mul_mv_id_iq4_nl_f32(
7010
- device const char * ids,
7094
+ device const char * src0s,
7011
7095
  device const char * src1,
7012
7096
  device float * dst,
7097
+ device const char * ids,
7013
7098
  constant uint64_t & nbi1,
7014
7099
  constant int64_t & ne00,
7015
7100
  constant int64_t & ne01,
@@ -7030,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
7030
7115
  constant uint & r2,
7031
7116
  constant uint & r3,
7032
7117
  constant int & idx,
7033
- device const char * src00,
7034
- device const char * src01,
7035
- device const char * src02,
7036
- device const char * src03,
7037
- device const char * src04,
7038
- device const char * src05,
7039
- device const char * src06,
7040
- device const char * src07,
7041
7118
  threadgroup float * shared_values [[threadgroup(0)]],
7042
7119
  uint3 tgpig[[threadgroup_position_in_grid]],
7043
7120
  uint tiitg[[thread_index_in_threadgroup]],
7044
7121
  uint tiisg[[thread_index_in_simdgroup]],
7045
7122
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7046
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7047
-
7048
7123
  const int64_t bid = tgpig.z/(ne12*ne13);
7049
7124
 
7050
7125
  tgpig.z = tgpig.z%(ne12*ne13);
7051
7126
 
7052
7127
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7128
+ device const char * src0 = src0s + id*nb02;
7053
7129
 
7054
7130
  kernel_mul_mv_iq4_nl_f32_impl(
7055
- src0[id],
7131
+ src0,
7056
7132
  (device const float *) (src1 + bid*nb11),
7057
7133
  dst + bid*ne0,
7058
7134
  ne00,
@@ -7072,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
7072
7148
 
7073
7149
  [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
7074
7150
  kernel void kernel_mul_mv_id_iq4_xs_f32(
7075
- device const char * ids,
7151
+ device const char * src0s,
7076
7152
  device const char * src1,
7077
7153
  device float * dst,
7154
+ device const char * ids,
7078
7155
  constant uint64_t & nbi1,
7079
7156
  constant int64_t & ne00,
7080
7157
  constant int64_t & ne01,
@@ -7095,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
7095
7172
  constant uint & r2,
7096
7173
  constant uint & r3,
7097
7174
  constant int & idx,
7098
- device const char * src00,
7099
- device const char * src01,
7100
- device const char * src02,
7101
- device const char * src03,
7102
- device const char * src04,
7103
- device const char * src05,
7104
- device const char * src06,
7105
- device const char * src07,
7106
7175
  threadgroup float * shared_values [[threadgroup(0)]],
7107
7176
  uint3 tgpig[[threadgroup_position_in_grid]],
7108
7177
  uint tiitg[[thread_index_in_threadgroup]],
7109
7178
  uint tiisg[[thread_index_in_simdgroup]],
7110
7179
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7111
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7112
-
7113
7180
  const int64_t bid = tgpig.z/(ne12*ne13);
7114
7181
 
7115
7182
  tgpig.z = tgpig.z%(ne12*ne13);
7116
7183
 
7117
7184
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7185
+ device const char * src0 = src0s + id*nb02;
7118
7186
 
7119
7187
  #if QK_K == 64
7120
7188
  kernel_mul_mv_iq4_nl_f32_impl(
7121
7189
  #else
7122
7190
  kernel_mul_mv_iq4_xs_f32_impl(
7123
7191
  #endif
7124
- src0[id],
7192
+ src0,
7125
7193
  (device const float *) (src1 + bid*nb11),
7126
7194
  dst + bid*ne0,
7127
7195
  ne00,