llama_cpp 0.14.3 → 0.14.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,8 +13,8 @@ using namespace metal;
13
13
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
14
14
 
15
15
  enum ggml_sort_order {
16
- GGML_SORT_ASC,
17
- GGML_SORT_DESC,
16
+ GGML_SORT_ORDER_ASC,
17
+ GGML_SORT_ORDER_DESC,
18
18
  };
19
19
 
20
20
  // general-purpose kernel for addition, multiplication and division of two tensors
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
1973
1973
 
1974
1974
  // bitonic sort implementation following the CUDA kernels as reference
1975
1975
  typedef void (argsort_t)(
1976
- device const float * x,
1977
- device int32_t * dst,
1978
- constant int64_t & ncols,
1976
+ device const float * x,
1977
+ device int32_t * dst,
1978
+ constant int64_t & ncols,
1979
+ constant int64_t & ncols_pad,
1980
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
1979
1981
  uint3 tgpig[[threadgroup_position_in_grid]],
1980
1982
  uint3 tpitg[[thread_position_in_threadgroup]]);
1981
1983
 
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
1984
1986
  device const float * x,
1985
1987
  device int32_t * dst,
1986
1988
  constant int64_t & ncols,
1989
+ constant int64_t & ncols_pad,
1990
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
1987
1991
  uint3 tgpig[[threadgroup_position_in_grid]],
1988
1992
  uint3 tpitg[[thread_position_in_threadgroup]]) {
1989
1993
  // bitonic sort
1990
1994
  int col = tpitg[0];
1991
1995
  int row = tgpig[1];
1992
1996
 
1993
- if (col >= ncols) return;
1997
+ if (col >= ncols_pad) return;
1994
1998
 
1995
- device const float * x_row = x + row * ncols;
1996
- device int32_t * dst_row = dst + row * ncols;
1999
+ device const float * x_row = x + row * ncols;
2000
+ threadgroup int32_t * dst_row = shared_values;
1997
2001
 
1998
2002
  // initialize indices
1999
- if (col < ncols) {
2000
- dst_row[col] = col;
2001
- }
2003
+ dst_row[col] = col;
2004
+
2002
2005
  threadgroup_barrier(mem_flags::mem_threadgroup);
2003
2006
 
2004
- for (int k = 2; k <= ncols; k *= 2) {
2007
+ for (int k = 2; k <= ncols_pad; k *= 2) {
2005
2008
  for (int j = k / 2; j > 0; j /= 2) {
2006
2009
  int ixj = col ^ j;
2007
2010
  if (ixj > col) {
2008
2011
  if ((col & k) == 0) {
2009
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
2012
+ if (dst_row[col] >= ncols ||
2013
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
2014
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
2015
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
2016
+ ) {
2010
2017
  SWAP(dst_row[col], dst_row[ixj]);
2011
2018
  }
2012
2019
  } else {
2013
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
2020
+ if (dst_row[ixj] >= ncols ||
2021
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
2022
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
2023
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
2024
+ ) {
2014
2025
  SWAP(dst_row[col], dst_row[ixj]);
2015
2026
  }
2016
2027
  }
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
2018
2029
  threadgroup_barrier(mem_flags::mem_threadgroup);
2019
2030
  }
2020
2031
  }
2032
+
2033
+ // copy the result to dst without the padding
2034
+ if (col < ncols) {
2035
+ dst[row * ncols + col] = dst_row[col];
2036
+ }
2021
2037
  }
2022
2038
 
2023
- template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
2024
- template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
2039
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
2040
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
2025
2041
 
2026
2042
  kernel void kernel_leaky_relu_f32(
2027
2043
  device const float * src0,
@@ -4456,6 +4472,114 @@ void kernel_mul_mv_iq1_s_f32_impl(
4456
4472
  }
4457
4473
  }
4458
4474
 
4475
+ void kernel_mul_mv_iq1_m_f32_impl(
4476
+ device const void * src0,
4477
+ device const float * src1,
4478
+ device float * dst,
4479
+ constant int64_t & ne00,
4480
+ constant int64_t & ne01,
4481
+ constant int64_t & ne02,
4482
+ constant int64_t & ne10,
4483
+ constant int64_t & ne12,
4484
+ constant int64_t & ne0,
4485
+ constant int64_t & ne1,
4486
+ constant uint & r2,
4487
+ constant uint & r3,
4488
+ uint3 tgpig[[threadgroup_position_in_grid]],
4489
+ uint tiisg[[thread_index_in_simdgroup]],
4490
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4491
+
4492
+ const int nb = ne00/QK_K;
4493
+ const int r0 = tgpig.x;
4494
+ const int r1 = tgpig.y;
4495
+ const int im = tgpig.z;
4496
+
4497
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4498
+ const int ib_row = first_row * nb;
4499
+
4500
+ const uint i12 = im%ne12;
4501
+ const uint i13 = im/ne12;
4502
+
4503
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4504
+ device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
4505
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4506
+
4507
+ float yl[32];
4508
+ float sumf[N_DST]={0.f}, all_sum;
4509
+
4510
+ const int nb32 = nb * (QK_K / 32);
4511
+
4512
+ const int ix = tiisg;
4513
+
4514
+ device const float * y4 = y + 32 * ix;
4515
+
4516
+ #if QK_K != 64
4517
+ iq1m_scale_t scale;
4518
+ #endif
4519
+
4520
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4521
+
4522
+ float4 sumy = {0.f};
4523
+ for (int i = 0; i < 8; ++i) {
4524
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
4525
+ yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
4526
+ yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
4527
+ yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
4528
+ }
4529
+
4530
+ const int ibl = ib32 / (QK_K / 32);
4531
+ const int ib = ib32 % (QK_K / 32);
4532
+
4533
+ device const block_iq1_m * xr = x + ibl;
4534
+ device const uint8_t * qs = xr->qs + 4 * ib;
4535
+ device const uint8_t * qh = xr->qh + 2 * ib;
4536
+ device const uint16_t * sc = (device const uint16_t *)xr->scales;
4537
+
4538
+ for (int row = 0; row < N_DST; row++) {
4539
+
4540
+ #if QK_K != 64
4541
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
4542
+ #endif
4543
+
4544
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
4545
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
4546
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
4547
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
4548
+
4549
+ float2 sum = {0.f};
4550
+ for (int j = 0; j < 4; ++j) {
4551
+ sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
4552
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
4553
+ sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
4554
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
4555
+ }
4556
+ const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
4557
+ const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
4558
+ #if QK_K == 64
4559
+ const float d = (float) *((device const half *)(sc - 1));
4560
+ sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
4561
+ (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
4562
+ #else
4563
+ sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
4564
+ (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
4565
+ #endif
4566
+
4567
+ sc += nb*sizeof(block_iq1_m)/2;
4568
+ qs += nb*sizeof(block_iq1_m);
4569
+ qh += nb*sizeof(block_iq1_m);
4570
+ }
4571
+
4572
+ y4 += 32 * 32;
4573
+ }
4574
+
4575
+ for (int row = 0; row < N_DST; ++row) {
4576
+ all_sum = simd_sum(sumf[row]);
4577
+ if (tiisg == 0) {
4578
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4579
+ }
4580
+ }
4581
+ }
4582
+
4459
4583
  void kernel_mul_mv_iq4_nl_f32_impl(
4460
4584
  device const void * src0,
4461
4585
  device const float * src1,
@@ -4673,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
4673
4797
  kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4674
4798
  }
4675
4799
 
4800
+ [[host_name("kernel_mul_mv_iq1_m_f32")]]
4801
+ kernel void kernel_mul_mv_iq1_m_f32(
4802
+ device const void * src0,
4803
+ device const float * src1,
4804
+ device float * dst,
4805
+ constant int64_t & ne00,
4806
+ constant int64_t & ne01,
4807
+ constant int64_t & ne02,
4808
+ constant uint64_t & nb00,
4809
+ constant uint64_t & nb01,
4810
+ constant uint64_t & nb02,
4811
+ constant int64_t & ne10,
4812
+ constant int64_t & ne11,
4813
+ constant int64_t & ne12,
4814
+ constant uint64_t & nb10,
4815
+ constant uint64_t & nb11,
4816
+ constant uint64_t & nb12,
4817
+ constant int64_t & ne0,
4818
+ constant int64_t & ne1,
4819
+ constant uint & r2,
4820
+ constant uint & r3,
4821
+ uint3 tgpig[[threadgroup_position_in_grid]],
4822
+ uint tiisg[[thread_index_in_simdgroup]],
4823
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4824
+
4825
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4826
+ }
4827
+
4676
4828
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
4677
4829
  kernel void kernel_mul_mv_iq4_nl_f32(
4678
4830
  device const void * src0,
@@ -5146,6 +5298,38 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
5146
5298
  }
5147
5299
  }
5148
5300
 
5301
+ template <typename type4x4>
5302
+ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
5303
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5304
+ const int ib32 = il/2;
5305
+ il = il%2;
5306
+ device const uint16_t * sc = (device const uint16_t *)xb->scales;
5307
+ #if QK_K == 64
5308
+ const float d = xb->d;
5309
+ #else
5310
+ iq1m_scale_t scale;
5311
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5312
+ const float d = scale.f16;
5313
+ #endif
5314
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5315
+ device const uint8_t * qh = xb->qh + 2*ib32 + il;
5316
+ #if QK_K == 64
5317
+ const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
5318
+ #else
5319
+ const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
5320
+ #endif
5321
+ const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5322
+ const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5323
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5324
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
5325
+ for (int i = 0; i < 4; ++i) {
5326
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
5327
+ reg[1][i] = dl * (grid1[i] >> 4) + ml1;
5328
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
5329
+ reg[3][i] = dl * (grid2[i] >> 4) + ml2;
5330
+ }
5331
+ }
5332
+
5149
5333
  template <typename type4x4>
5150
5334
  void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
5151
5335
  device const uint16_t * q4 = (device const uint16_t *)xb->qs;
@@ -5617,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
5617
5801
 
5618
5802
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5619
5803
  kernel void kernel_mul_mm_id(
5620
- device const uchar * ids,
5804
+ device const uchar * src0s,
5621
5805
  device const uchar * src1,
5622
5806
  device float * dst,
5807
+ device const uchar * ids,
5623
5808
  constant uint64_t & nbi1,
5624
5809
  constant int64_t & ne00,
5625
5810
  constant int64_t & ne02,
@@ -5636,22 +5821,14 @@ kernel void kernel_mul_mm_id(
5636
5821
  constant uint & r2,
5637
5822
  constant uint & r3,
5638
5823
  constant int & idx,
5639
- device const uchar * src00,
5640
- device const uchar * src01,
5641
- device const uchar * src02,
5642
- device const uchar * src03,
5643
- device const uchar * src04,
5644
- device const uchar * src05,
5645
- device const uchar * src06,
5646
- device const uchar * src07,
5647
5824
  threadgroup uchar * shared_memory [[threadgroup(0)]],
5648
5825
  uint3 tgpig[[threadgroup_position_in_grid]],
5649
5826
  uint tiitg[[thread_index_in_threadgroup]],
5650
5827
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5651
- device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5652
5828
 
5653
5829
  // expert id
5654
5830
  const int32_t id = tgpig.z/(ne12*ne13);
5831
+ device const uchar * src0 = src0s + id*nb02;
5655
5832
 
5656
5833
  tgpig.z = tgpig.z%(ne12*ne13);
5657
5834
 
@@ -5666,7 +5843,7 @@ kernel void kernel_mul_mm_id(
5666
5843
  }
5667
5844
 
5668
5845
  kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5669
- src0s[id],
5846
+ src0,
5670
5847
  src1,
5671
5848
  src1ids,
5672
5849
  dst,
@@ -5730,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
5730
5907
  template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
5731
5908
  template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
5732
5909
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5910
+ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
5733
5911
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
5734
5912
  #if QK_K == 64
5735
5913
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
@@ -5778,6 +5956,7 @@ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_m
5778
5956
  template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
5779
5957
  template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
5780
5958
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5959
+ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
5781
5960
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
5782
5961
  #if QK_K == 64
5783
5962
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
@@ -5790,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
5790
5969
  //
5791
5970
 
5792
5971
  typedef void (mat_mm_id_t)(
5793
- device const uchar * ids,
5972
+ device const uchar * src0s,
5794
5973
  device const uchar * src1,
5795
5974
  device float * dst,
5975
+ device const uchar * ids,
5796
5976
  constant uint64_t & nbi1,
5797
5977
  constant int64_t & ne00,
5798
5978
  constant int64_t & ne02,
@@ -5809,14 +5989,6 @@ typedef void (mat_mm_id_t)(
5809
5989
  constant uint & r2,
5810
5990
  constant uint & r3,
5811
5991
  constant int & idx,
5812
- device const uchar * src00,
5813
- device const uchar * src01,
5814
- device const uchar * src02,
5815
- device const uchar * src03,
5816
- device const uchar * src04,
5817
- device const uchar * src05,
5818
- device const uchar * src06,
5819
- device const uchar * src07,
5820
5992
  threadgroup uchar *,
5821
5993
  uint3, uint, uint);
5822
5994
 
@@ -5838,6 +6010,7 @@ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel
5838
6010
  template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
5839
6011
  template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
5840
6012
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6013
+ template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
5841
6014
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
5842
6015
  #if QK_K == 64
5843
6016
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
@@ -5851,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
5851
6024
 
5852
6025
  [[host_name("kernel_mul_mv_id_f32_f32")]]
5853
6026
  kernel void kernel_mul_mv_id_f32_f32(
5854
- device const char * ids,
6027
+ device const char * src0s,
5855
6028
  device const char * src1,
5856
6029
  device float * dst,
6030
+ device const char * ids,
5857
6031
  constant uint64_t & nbi1,
5858
6032
  constant int64_t & ne00,
5859
6033
  constant int64_t & ne01,
@@ -5874,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
5874
6048
  constant uint & r2,
5875
6049
  constant uint & r3,
5876
6050
  constant int & idx,
5877
- device const char * src00,
5878
- device const char * src01,
5879
- device const char * src02,
5880
- device const char * src03,
5881
- device const char * src04,
5882
- device const char * src05,
5883
- device const char * src06,
5884
- device const char * src07,
5885
6051
  uint3 tgpig[[threadgroup_position_in_grid]],
5886
6052
  uint tiitg[[thread_index_in_threadgroup]],
5887
6053
  uint tiisg[[thread_index_in_simdgroup]],
5888
6054
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5889
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5890
-
5891
6055
  const int64_t bid = tgpig.z/(ne12*ne13);
5892
6056
 
5893
6057
  tgpig.z = tgpig.z%(ne12*ne13);
5894
6058
 
5895
6059
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6060
+ device const char * src0 = src0s + id*nb02;
5896
6061
 
5897
6062
  kernel_mul_mv_f32_f32_impl(
5898
- src0[id],
6063
+ src0,
5899
6064
  src1 + bid*nb11,
5900
6065
  dst + bid*ne0,
5901
6066
  ne00,
@@ -5920,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
5920
6085
 
5921
6086
  [[host_name("kernel_mul_mv_id_f16_f32")]]
5922
6087
  kernel void kernel_mul_mv_id_f16_f32(
5923
- device const char * ids,
6088
+ device const char * src0s,
5924
6089
  device const char * src1,
5925
6090
  device float * dst,
6091
+ device const char * ids,
5926
6092
  constant uint64_t & nbi1,
5927
6093
  constant int64_t & ne00,
5928
6094
  constant int64_t & ne01,
@@ -5943,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
5943
6109
  constant uint & r2,
5944
6110
  constant uint & r3,
5945
6111
  constant int & idx,
5946
- device const char * src00,
5947
- device const char * src01,
5948
- device const char * src02,
5949
- device const char * src03,
5950
- device const char * src04,
5951
- device const char * src05,
5952
- device const char * src06,
5953
- device const char * src07,
5954
6112
  uint3 tgpig[[threadgroup_position_in_grid]],
5955
6113
  uint tiitg[[thread_index_in_threadgroup]],
5956
6114
  uint tiisg[[thread_index_in_simdgroup]],
5957
6115
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5958
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5959
-
5960
6116
  const int64_t bid = tgpig.z/(ne12*ne13);
5961
6117
 
5962
6118
  tgpig.z = tgpig.z%(ne12*ne13);
5963
6119
 
5964
6120
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6121
+ device const char * src0 = src0s + id*nb02;
5965
6122
 
5966
6123
  kernel_mul_mv_f16_f32_impl(
5967
- src0[id],
6124
+ src0,
5968
6125
  src1 + bid*nb11,
5969
6126
  dst + bid*ne0,
5970
6127
  ne00,
@@ -5989,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
5989
6146
 
5990
6147
  [[host_name("kernel_mul_mv_id_q8_0_f32")]]
5991
6148
  kernel void kernel_mul_mv_id_q8_0_f32(
5992
- device const char * ids,
6149
+ device const char * src0s,
5993
6150
  device const char * src1,
5994
6151
  device float * dst,
6152
+ device const char * ids,
5995
6153
  constant uint64_t & nbi1,
5996
6154
  constant int64_t & ne00,
5997
6155
  constant int64_t & ne01,
@@ -6012,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
6012
6170
  constant uint & r2,
6013
6171
  constant uint & r3,
6014
6172
  constant int & idx,
6015
- device const char * src00,
6016
- device const char * src01,
6017
- device const char * src02,
6018
- device const char * src03,
6019
- device const char * src04,
6020
- device const char * src05,
6021
- device const char * src06,
6022
- device const char * src07,
6023
6173
  uint3 tgpig[[threadgroup_position_in_grid]],
6024
6174
  uint tiitg[[thread_index_in_threadgroup]],
6025
6175
  uint tiisg[[thread_index_in_simdgroup]],
6026
6176
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6027
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6028
-
6029
6177
  const int64_t bid = tgpig.z/(ne12*ne13);
6030
6178
 
6031
6179
  tgpig.z = tgpig.z%(ne12*ne13);
6032
6180
 
6033
6181
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6182
+ device const char * src0 = src0s + id*nb02;
6034
6183
 
6035
6184
  kernel_mul_mv_q8_0_f32_impl(
6036
- src0[id],
6185
+ src0,
6037
6186
  (device const float *) (src1 + bid*nb11),
6038
6187
  dst + bid*ne0,
6039
6188
  ne00,
@@ -6052,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
6052
6201
 
6053
6202
  [[host_name("kernel_mul_mv_id_q4_0_f32")]]
6054
6203
  kernel void kernel_mul_mv_id_q4_0_f32(
6055
- device const char * ids,
6204
+ device const char * src0s,
6056
6205
  device const char * src1,
6057
6206
  device float * dst,
6207
+ device const char * ids,
6058
6208
  constant uint64_t & nbi1,
6059
6209
  constant int64_t & ne00,
6060
6210
  constant int64_t & ne01,
@@ -6075,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
6075
6225
  constant uint & r2,
6076
6226
  constant uint & r3,
6077
6227
  constant int & idx,
6078
- device const char * src00,
6079
- device const char * src01,
6080
- device const char * src02,
6081
- device const char * src03,
6082
- device const char * src04,
6083
- device const char * src05,
6084
- device const char * src06,
6085
- device const char * src07,
6086
6228
  uint3 tgpig[[threadgroup_position_in_grid]],
6087
6229
  uint tiitg[[thread_index_in_threadgroup]],
6088
6230
  uint tiisg[[thread_index_in_simdgroup]],
6089
6231
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6090
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6091
-
6092
6232
  const int64_t bid = tgpig.z/(ne12*ne13);
6093
6233
 
6094
6234
  tgpig.z = tgpig.z%(ne12*ne13);
6095
6235
 
6096
6236
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6237
+ device const char * src0 = src0s + id*nb02;
6097
6238
 
6098
6239
  mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6099
- src0[id],
6240
+ src0,
6100
6241
  (device const float *) (src1 + bid*nb11),
6101
6242
  dst + bid*ne0,
6102
6243
  ne00,
@@ -6115,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
6115
6256
 
6116
6257
  [[host_name("kernel_mul_mv_id_q4_1_f32")]]
6117
6258
  kernel void kernel_mul_mv_id_q4_1_f32(
6118
- device const char * ids,
6259
+ device const char * src0s,
6119
6260
  device const char * src1,
6120
6261
  device float * dst,
6262
+ device const char * ids,
6121
6263
  constant uint64_t & nbi1,
6122
6264
  constant int64_t & ne00,
6123
6265
  constant int64_t & ne01,
@@ -6138,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
6138
6280
  constant uint & r2,
6139
6281
  constant uint & r3,
6140
6282
  constant int & idx,
6141
- device const char * src00,
6142
- device const char * src01,
6143
- device const char * src02,
6144
- device const char * src03,
6145
- device const char * src04,
6146
- device const char * src05,
6147
- device const char * src06,
6148
- device const char * src07,
6149
6283
  uint3 tgpig[[threadgroup_position_in_grid]],
6150
6284
  uint tiitg[[thread_index_in_threadgroup]],
6151
6285
  uint tiisg[[thread_index_in_simdgroup]],
6152
6286
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6153
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6154
-
6155
6287
  const int64_t bid = tgpig.z/(ne12*ne13);
6156
6288
 
6157
6289
  tgpig.z = tgpig.z%(ne12*ne13);
6158
6290
 
6159
6291
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6292
+ device const char * src0 = src0s + id*nb02;
6160
6293
 
6161
6294
  mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6162
- src0[id],
6295
+ src0,
6163
6296
  (device const float *) (src1 + bid*nb11),
6164
6297
  dst + bid*ne0,
6165
6298
  ne00,
@@ -6178,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
6178
6311
 
6179
6312
  [[host_name("kernel_mul_mv_id_q5_0_f32")]]
6180
6313
  kernel void kernel_mul_mv_id_q5_0_f32(
6181
- device const char * ids,
6314
+ device const char * src0s,
6182
6315
  device const char * src1,
6183
6316
  device float * dst,
6317
+ device const char * ids,
6184
6318
  constant uint64_t & nbi1,
6185
6319
  constant int64_t & ne00,
6186
6320
  constant int64_t & ne01,
@@ -6201,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
6201
6335
  constant uint & r2,
6202
6336
  constant uint & r3,
6203
6337
  constant int & idx,
6204
- device const char * src00,
6205
- device const char * src01,
6206
- device const char * src02,
6207
- device const char * src03,
6208
- device const char * src04,
6209
- device const char * src05,
6210
- device const char * src06,
6211
- device const char * src07,
6212
6338
  uint3 tgpig[[threadgroup_position_in_grid]],
6213
6339
  uint tiitg[[thread_index_in_threadgroup]],
6214
6340
  uint tiisg[[thread_index_in_simdgroup]],
6215
6341
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6216
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6217
-
6218
6342
  const int64_t bid = tgpig.z/(ne12*ne13);
6219
6343
 
6220
6344
  tgpig.z = tgpig.z%(ne12*ne13);
6221
6345
 
6222
6346
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6347
+ device const char * src0 = src0s + id*nb02;
6223
6348
 
6224
6349
  mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6225
- src0[id],
6350
+ src0,
6226
6351
  (device const float *) (src1 + bid*nb11),
6227
6352
  dst + bid*ne0,
6228
6353
  ne00,
@@ -6241,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
6241
6366
 
6242
6367
  [[host_name("kernel_mul_mv_id_q5_1_f32")]]
6243
6368
  kernel void kernel_mul_mv_id_q5_1_f32(
6244
- device const char * ids,
6369
+ device const char * src0s,
6245
6370
  device const char * src1,
6246
6371
  device float * dst,
6372
+ device const char * ids,
6247
6373
  constant uint64_t & nbi1,
6248
6374
  constant int64_t & ne00,
6249
6375
  constant int64_t & ne01,
@@ -6264,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
6264
6390
  constant uint & r2,
6265
6391
  constant uint & r3,
6266
6392
  constant int & idx,
6267
- device const char * src00,
6268
- device const char * src01,
6269
- device const char * src02,
6270
- device const char * src03,
6271
- device const char * src04,
6272
- device const char * src05,
6273
- device const char * src06,
6274
- device const char * src07,
6275
6393
  uint3 tgpig[[threadgroup_position_in_grid]],
6276
6394
  uint tiitg[[thread_index_in_threadgroup]],
6277
6395
  uint tiisg[[thread_index_in_simdgroup]],
6278
6396
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6279
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6280
-
6281
6397
  const int64_t bid = tgpig.z/(ne12*ne13);
6282
6398
 
6283
6399
  tgpig.z = tgpig.z%(ne12*ne13);
6284
6400
 
6285
6401
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6402
+ device const char * src0 = src0s + id*nb02;
6286
6403
 
6287
6404
  mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6288
- src0[id],
6405
+ src0,
6289
6406
  (device const float *) (src1 + bid*nb11),
6290
6407
  dst + bid*ne0,
6291
6408
  ne00,
@@ -6304,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
6304
6421
 
6305
6422
  [[host_name("kernel_mul_mv_id_q2_K_f32")]]
6306
6423
  kernel void kernel_mul_mv_id_q2_K_f32(
6307
- device const char * ids,
6424
+ device const char * src0s,
6308
6425
  device const char * src1,
6309
6426
  device float * dst,
6427
+ device const char * ids,
6310
6428
  constant uint64_t & nbi1,
6311
6429
  constant int64_t & ne00,
6312
6430
  constant int64_t & ne01,
@@ -6327,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
6327
6445
  constant uint & r2,
6328
6446
  constant uint & r3,
6329
6447
  constant int & idx,
6330
- device const char * src00,
6331
- device const char * src01,
6332
- device const char * src02,
6333
- device const char * src03,
6334
- device const char * src04,
6335
- device const char * src05,
6336
- device const char * src06,
6337
- device const char * src07,
6338
6448
  uint3 tgpig[[threadgroup_position_in_grid]],
6339
6449
  uint tiitg[[thread_index_in_threadgroup]],
6340
6450
  uint tiisg[[thread_index_in_simdgroup]],
6341
6451
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6342
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6343
-
6344
6452
  const int64_t bid = tgpig.z/(ne12*ne13);
6345
6453
 
6346
6454
  tgpig.z = tgpig.z%(ne12*ne13);
6347
6455
 
6348
6456
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6457
+ device const char * src0 = src0s + id*nb02;
6349
6458
 
6350
6459
  kernel_mul_mv_q2_K_f32_impl(
6351
- src0[id],
6460
+ src0,
6352
6461
  (device const float *) (src1 + bid*nb11),
6353
6462
  dst + bid*ne0,
6354
6463
  ne00,
@@ -6367,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
6367
6476
 
6368
6477
  [[host_name("kernel_mul_mv_id_q3_K_f32")]]
6369
6478
  kernel void kernel_mul_mv_id_q3_K_f32(
6370
- device const char * ids,
6479
+ device const char * src0s,
6371
6480
  device const char * src1,
6372
6481
  device float * dst,
6482
+ device const char * ids,
6373
6483
  constant uint64_t & nbi1,
6374
6484
  constant int64_t & ne00,
6375
6485
  constant int64_t & ne01,
@@ -6390,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
6390
6500
  constant uint & r2,
6391
6501
  constant uint & r3,
6392
6502
  constant int & idx,
6393
- device const char * src00,
6394
- device const char * src01,
6395
- device const char * src02,
6396
- device const char * src03,
6397
- device const char * src04,
6398
- device const char * src05,
6399
- device const char * src06,
6400
- device const char * src07,
6401
6503
  uint3 tgpig[[threadgroup_position_in_grid]],
6402
6504
  uint tiitg[[thread_index_in_threadgroup]],
6403
6505
  uint tiisg[[thread_index_in_simdgroup]],
6404
6506
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6405
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6406
-
6407
6507
  const int64_t bid = tgpig.z/(ne12*ne13);
6408
6508
 
6409
6509
  tgpig.z = tgpig.z%(ne12*ne13);
6410
6510
 
6411
6511
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6512
+ device const char * src0 = src0s + id*nb02;
6412
6513
 
6413
6514
  kernel_mul_mv_q3_K_f32_impl(
6414
- src0[id],
6515
+ src0,
6415
6516
  (device const float *) (src1 + bid*nb11),
6416
6517
  dst + bid*ne0,
6417
6518
  ne00,
@@ -6430,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
6430
6531
 
6431
6532
  [[host_name("kernel_mul_mv_id_q4_K_f32")]]
6432
6533
  kernel void kernel_mul_mv_id_q4_K_f32(
6433
- device const char * ids,
6534
+ device const char * src0s,
6434
6535
  device const char * src1,
6435
6536
  device float * dst,
6537
+ device const char * ids,
6436
6538
  constant uint64_t & nbi1,
6437
6539
  constant int64_t & ne00,
6438
6540
  constant int64_t & ne01,
@@ -6453,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
6453
6555
  constant uint & r2,
6454
6556
  constant uint & r3,
6455
6557
  constant int & idx,
6456
- device const char * src00,
6457
- device const char * src01,
6458
- device const char * src02,
6459
- device const char * src03,
6460
- device const char * src04,
6461
- device const char * src05,
6462
- device const char * src06,
6463
- device const char * src07,
6464
6558
  uint3 tgpig[[threadgroup_position_in_grid]],
6465
6559
  uint tiitg[[thread_index_in_threadgroup]],
6466
6560
  uint tiisg[[thread_index_in_simdgroup]],
6467
6561
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6468
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6469
-
6470
6562
  const int64_t bid = tgpig.z/(ne12*ne13);
6471
6563
 
6472
6564
  tgpig.z = tgpig.z%(ne12*ne13);
6473
6565
 
6474
6566
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6567
+ device const char * src0 = src0s + id*nb02;
6475
6568
 
6476
6569
  kernel_mul_mv_q4_K_f32_impl(
6477
- src0[id],
6570
+ src0,
6478
6571
  (device const float *) (src1 + bid*nb11),
6479
6572
  dst + bid*ne0,
6480
6573
  ne00,
@@ -6493,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
6493
6586
 
6494
6587
  [[host_name("kernel_mul_mv_id_q5_K_f32")]]
6495
6588
  kernel void kernel_mul_mv_id_q5_K_f32(
6496
- device const char * ids,
6589
+ device const char * src0s,
6497
6590
  device const char * src1,
6498
6591
  device float * dst,
6592
+ device const char * ids,
6499
6593
  constant uint64_t & nbi1,
6500
6594
  constant int64_t & ne00,
6501
6595
  constant int64_t & ne01,
@@ -6516,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
6516
6610
  constant uint & r2,
6517
6611
  constant uint & r3,
6518
6612
  constant int & idx,
6519
- device const char * src00,
6520
- device const char * src01,
6521
- device const char * src02,
6522
- device const char * src03,
6523
- device const char * src04,
6524
- device const char * src05,
6525
- device const char * src06,
6526
- device const char * src07,
6527
6613
  uint3 tgpig[[threadgroup_position_in_grid]],
6528
6614
  uint tiitg[[thread_index_in_threadgroup]],
6529
6615
  uint tiisg[[thread_index_in_simdgroup]],
6530
6616
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6531
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6532
-
6533
6617
  const int64_t bid = tgpig.z/(ne12*ne13);
6534
6618
 
6535
6619
  tgpig.z = tgpig.z%(ne12*ne13);
6536
6620
 
6537
6621
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6622
+ device const char * src0 = src0s + id*nb02;
6538
6623
 
6539
6624
  kernel_mul_mv_q5_K_f32_impl(
6540
- src0[id],
6625
+ src0,
6541
6626
  (device const float *) (src1 + bid*nb11),
6542
6627
  dst + bid*ne0,
6543
6628
  ne00,
@@ -6556,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
6556
6641
 
6557
6642
  [[host_name("kernel_mul_mv_id_q6_K_f32")]]
6558
6643
  kernel void kernel_mul_mv_id_q6_K_f32(
6559
- device const char * ids,
6644
+ device const char * src0s,
6560
6645
  device const char * src1,
6561
6646
  device float * dst,
6647
+ device const char * ids,
6562
6648
  constant uint64_t & nbi1,
6563
6649
  constant int64_t & ne00,
6564
6650
  constant int64_t & ne01,
@@ -6579,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6579
6665
  constant uint & r2,
6580
6666
  constant uint & r3,
6581
6667
  constant int & idx,
6582
- device const char * src00,
6583
- device const char * src01,
6584
- device const char * src02,
6585
- device const char * src03,
6586
- device const char * src04,
6587
- device const char * src05,
6588
- device const char * src06,
6589
- device const char * src07,
6590
6668
  uint3 tgpig[[threadgroup_position_in_grid]],
6591
6669
  uint tiitg[[thread_index_in_threadgroup]],
6592
6670
  uint tiisg[[thread_index_in_simdgroup]],
6593
6671
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6594
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6595
-
6596
6672
  const int64_t bid = tgpig.z/(ne12*ne13);
6597
6673
 
6598
6674
  tgpig.z = tgpig.z%(ne12*ne13);
6599
6675
 
6600
6676
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6677
+ device const char * src0 = src0s + id*nb02;
6601
6678
 
6602
6679
  kernel_mul_mv_q6_K_f32_impl(
6603
- src0[id],
6680
+ src0,
6604
6681
  (device const float *) (src1 + bid*nb11),
6605
6682
  dst + bid*ne0,
6606
6683
  ne00,
@@ -6619,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6619
6696
 
6620
6697
  [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
6621
6698
  kernel void kernel_mul_mv_id_iq2_xxs_f32(
6622
- device const char * ids,
6699
+ device const char * src0s,
6623
6700
  device const char * src1,
6624
6701
  device float * dst,
6702
+ device const char * ids,
6625
6703
  constant uint64_t & nbi1,
6626
6704
  constant int64_t & ne00,
6627
6705
  constant int64_t & ne01,
@@ -6642,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6642
6720
  constant uint & r2,
6643
6721
  constant uint & r3,
6644
6722
  constant int & idx,
6645
- device const char * src00,
6646
- device const char * src01,
6647
- device const char * src02,
6648
- device const char * src03,
6649
- device const char * src04,
6650
- device const char * src05,
6651
- device const char * src06,
6652
- device const char * src07,
6653
6723
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6654
6724
  uint3 tgpig[[threadgroup_position_in_grid]],
6655
6725
  uint tiitg[[thread_index_in_threadgroup]],
6656
6726
  uint tiisg[[thread_index_in_simdgroup]],
6657
6727
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6658
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6659
-
6660
6728
  const int64_t bid = tgpig.z/(ne12*ne13);
6661
6729
 
6662
6730
  tgpig.z = tgpig.z%(ne12*ne13);
6663
6731
 
6664
6732
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6733
+ device const char * src0 = src0s + id*nb02;
6665
6734
 
6666
6735
  kernel_mul_mv_iq2_xxs_f32_impl(
6667
- src0[id],
6736
+ src0,
6668
6737
  (device const float *) (src1 + bid*nb11),
6669
6738
  dst + bid*ne0,
6670
6739
  ne00,
@@ -6684,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6684
6753
 
6685
6754
  [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
6686
6755
  kernel void kernel_mul_mv_id_iq2_xs_f32(
6687
- device const char * ids,
6756
+ device const char * src0s,
6688
6757
  device const char * src1,
6689
6758
  device float * dst,
6759
+ device const char * ids,
6690
6760
  constant uint64_t & nbi1,
6691
6761
  constant int64_t & ne00,
6692
6762
  constant int64_t & ne01,
@@ -6707,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6707
6777
  constant uint & r2,
6708
6778
  constant uint & r3,
6709
6779
  constant int & idx,
6710
- device const char * src00,
6711
- device const char * src01,
6712
- device const char * src02,
6713
- device const char * src03,
6714
- device const char * src04,
6715
- device const char * src05,
6716
- device const char * src06,
6717
- device const char * src07,
6718
6780
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6719
6781
  uint3 tgpig[[threadgroup_position_in_grid]],
6720
6782
  uint tiitg[[thread_index_in_threadgroup]],
6721
6783
  uint tiisg[[thread_index_in_simdgroup]],
6722
6784
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6723
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6724
-
6725
6785
  const int64_t bid = tgpig.z/(ne12*ne13);
6726
6786
 
6727
6787
  tgpig.z = tgpig.z%(ne12*ne13);
6728
6788
 
6729
6789
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6790
+ device const char * src0 = src0s + id*nb02;
6730
6791
 
6731
6792
  kernel_mul_mv_iq2_xs_f32_impl(
6732
- src0[id],
6793
+ src0,
6733
6794
  (device const float *) (src1 + bid*nb11),
6734
6795
  dst + bid*ne0,
6735
6796
  ne00,
@@ -6749,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6749
6810
 
6750
6811
  [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6751
6812
  kernel void kernel_mul_mv_id_iq3_xxs_f32(
6752
- device const char * ids,
6813
+ device const char * src0s,
6753
6814
  device const char * src1,
6754
6815
  device float * dst,
6816
+ device const char * ids,
6755
6817
  constant uint64_t & nbi1,
6756
6818
  constant int64_t & ne00,
6757
6819
  constant int64_t & ne01,
@@ -6772,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6772
6834
  constant uint & r2,
6773
6835
  constant uint & r3,
6774
6836
  constant int & idx,
6775
- device const char * src00,
6776
- device const char * src01,
6777
- device const char * src02,
6778
- device const char * src03,
6779
- device const char * src04,
6780
- device const char * src05,
6781
- device const char * src06,
6782
- device const char * src07,
6783
6837
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6784
6838
  uint3 tgpig[[threadgroup_position_in_grid]],
6785
6839
  uint tiitg[[thread_index_in_threadgroup]],
6786
6840
  uint tiisg[[thread_index_in_simdgroup]],
6787
6841
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6788
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6789
-
6790
6842
  const int64_t bid = tgpig.z/(ne12*ne13);
6791
6843
 
6792
6844
  tgpig.z = tgpig.z%(ne12*ne13);
6793
6845
 
6794
6846
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6847
+ device const char * src0 = src0s + id*nb02;
6795
6848
 
6796
6849
  kernel_mul_mv_iq3_xxs_f32_impl(
6797
- src0[id],
6850
+ src0,
6798
6851
  (device const float *) (src1 + bid*nb11),
6799
6852
  dst + bid*ne0,
6800
6853
  ne00,
@@ -6814,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6814
6867
 
6815
6868
  [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
6816
6869
  kernel void kernel_mul_mv_id_iq3_s_f32(
6817
- device const char * ids,
6870
+ device const char * src0s,
6818
6871
  device const char * src1,
6819
6872
  device float * dst,
6873
+ device const char * ids,
6820
6874
  constant uint64_t & nbi1,
6821
6875
  constant int64_t & ne00,
6822
6876
  constant int64_t & ne01,
@@ -6837,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
6837
6891
  constant uint & r2,
6838
6892
  constant uint & r3,
6839
6893
  constant int & idx,
6840
- device const char * src00,
6841
- device const char * src01,
6842
- device const char * src02,
6843
- device const char * src03,
6844
- device const char * src04,
6845
- device const char * src05,
6846
- device const char * src06,
6847
- device const char * src07,
6848
6894
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6849
6895
  uint3 tgpig[[threadgroup_position_in_grid]],
6850
6896
  uint tiitg[[thread_index_in_threadgroup]],
6851
6897
  uint tiisg[[thread_index_in_simdgroup]],
6852
6898
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6853
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6854
-
6855
6899
  const int64_t bid = tgpig.z/(ne12*ne13);
6856
6900
 
6857
6901
  tgpig.z = tgpig.z%(ne12*ne13);
6858
6902
 
6859
6903
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6904
+ device const char * src0 = src0s + id*nb02;
6860
6905
 
6861
6906
  kernel_mul_mv_iq3_s_f32_impl(
6862
- src0[id],
6907
+ src0,
6863
6908
  (device const float *) (src1 + bid*nb11),
6864
6909
  dst + bid*ne0,
6865
6910
  ne00,
@@ -6879,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
6879
6924
 
6880
6925
  [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
6881
6926
  kernel void kernel_mul_mv_id_iq2_s_f32(
6882
- device const char * ids,
6927
+ device const char * src0s,
6883
6928
  device const char * src1,
6884
6929
  device float * dst,
6930
+ device const char * ids,
6885
6931
  constant uint64_t & nbi1,
6886
6932
  constant int64_t & ne00,
6887
6933
  constant int64_t & ne01,
@@ -6902,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
6902
6948
  constant uint & r2,
6903
6949
  constant uint & r3,
6904
6950
  constant int & idx,
6905
- device const char * src00,
6906
- device const char * src01,
6907
- device const char * src02,
6908
- device const char * src03,
6909
- device const char * src04,
6910
- device const char * src05,
6911
- device const char * src06,
6912
- device const char * src07,
6913
6951
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6914
6952
  uint3 tgpig[[threadgroup_position_in_grid]],
6915
6953
  uint tiitg[[thread_index_in_threadgroup]],
6916
6954
  uint tiisg[[thread_index_in_simdgroup]],
6917
6955
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6918
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6919
-
6920
6956
  const int64_t bid = tgpig.z/(ne12*ne13);
6921
6957
 
6922
6958
  tgpig.z = tgpig.z%(ne12*ne13);
6923
6959
 
6924
6960
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6961
+ device const char * src0 = src0s + id*nb02;
6925
6962
 
6926
6963
  kernel_mul_mv_iq2_s_f32_impl(
6927
- src0[id],
6964
+ src0,
6928
6965
  (device const float *) (src1 + bid*nb11),
6929
6966
  dst + bid*ne0,
6930
6967
  ne00,
@@ -6944,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
6944
6981
 
6945
6982
  [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6946
6983
  kernel void kernel_mul_mv_id_iq1_s_f32(
6947
- device const char * ids,
6984
+ device const char * src0s,
6948
6985
  device const char * src1,
6949
6986
  device float * dst,
6987
+ device const char * ids,
6950
6988
  constant uint64_t & nbi1,
6951
6989
  constant int64_t & ne00,
6952
6990
  constant int64_t & ne01,
@@ -6967,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
6967
7005
  constant uint & r2,
6968
7006
  constant uint & r3,
6969
7007
  constant int & idx,
6970
- device const char * src00,
6971
- device const char * src01,
6972
- device const char * src02,
6973
- device const char * src03,
6974
- device const char * src04,
6975
- device const char * src05,
6976
- device const char * src06,
6977
- device const char * src07,
6978
7008
  uint3 tgpig[[threadgroup_position_in_grid]],
6979
7009
  uint tiitg[[thread_index_in_threadgroup]],
6980
7010
  uint tiisg[[thread_index_in_simdgroup]],
6981
7011
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6982
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6983
-
6984
7012
  const int64_t bid = tgpig.z/(ne12*ne13);
6985
7013
 
6986
7014
  tgpig.z = tgpig.z%(ne12*ne13);
6987
7015
 
6988
7016
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7017
+ device const char * src0 = src0s + id*nb02;
6989
7018
 
6990
7019
  kernel_mul_mv_iq1_s_f32_impl(
6991
- src0[id],
7020
+ src0,
7021
+ (device const float *) (src1 + bid*nb11),
7022
+ dst + bid*ne0,
7023
+ ne00,
7024
+ ne01,
7025
+ ne02,
7026
+ ne10,
7027
+ ne12,
7028
+ ne0,
7029
+ ne1,
7030
+ r2,
7031
+ r3,
7032
+ tgpig,
7033
+ tiisg,
7034
+ sgitg);
7035
+ }
7036
+
7037
+ [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
7038
+ kernel void kernel_mul_mv_id_iq1_m_f32(
7039
+ device const char * src0s,
7040
+ device const char * src1,
7041
+ device float * dst,
7042
+ device const char * ids,
7043
+ constant uint64_t & nbi1,
7044
+ constant int64_t & ne00,
7045
+ constant int64_t & ne01,
7046
+ constant int64_t & ne02,
7047
+ constant uint64_t & nb00,
7048
+ constant uint64_t & nb01,
7049
+ constant uint64_t & nb02,
7050
+ constant int64_t & ne10,
7051
+ constant int64_t & ne11,
7052
+ constant int64_t & ne12,
7053
+ constant int64_t & ne13,
7054
+ constant uint64_t & nb10,
7055
+ constant uint64_t & nb11,
7056
+ constant uint64_t & nb12,
7057
+ constant int64_t & ne0,
7058
+ constant int64_t & ne1,
7059
+ constant uint64_t & nb1,
7060
+ constant uint & r2,
7061
+ constant uint & r3,
7062
+ constant int & idx,
7063
+ uint3 tgpig[[threadgroup_position_in_grid]],
7064
+ uint tiitg[[thread_index_in_threadgroup]],
7065
+ uint tiisg[[thread_index_in_simdgroup]],
7066
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
7067
+ const int64_t bid = tgpig.z/(ne12*ne13);
7068
+
7069
+ tgpig.z = tgpig.z%(ne12*ne13);
7070
+
7071
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7072
+ device const char * src0 = src0s + id*nb02;
7073
+
7074
+ kernel_mul_mv_iq1_m_f32_impl(
7075
+ src0,
6992
7076
  (device const float *) (src1 + bid*nb11),
6993
7077
  dst + bid*ne0,
6994
7078
  ne00,
@@ -7007,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
7007
7091
 
7008
7092
  [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
7009
7093
  kernel void kernel_mul_mv_id_iq4_nl_f32(
7010
- device const char * ids,
7094
+ device const char * src0s,
7011
7095
  device const char * src1,
7012
7096
  device float * dst,
7097
+ device const char * ids,
7013
7098
  constant uint64_t & nbi1,
7014
7099
  constant int64_t & ne00,
7015
7100
  constant int64_t & ne01,
@@ -7030,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
7030
7115
  constant uint & r2,
7031
7116
  constant uint & r3,
7032
7117
  constant int & idx,
7033
- device const char * src00,
7034
- device const char * src01,
7035
- device const char * src02,
7036
- device const char * src03,
7037
- device const char * src04,
7038
- device const char * src05,
7039
- device const char * src06,
7040
- device const char * src07,
7041
7118
  threadgroup float * shared_values [[threadgroup(0)]],
7042
7119
  uint3 tgpig[[threadgroup_position_in_grid]],
7043
7120
  uint tiitg[[thread_index_in_threadgroup]],
7044
7121
  uint tiisg[[thread_index_in_simdgroup]],
7045
7122
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7046
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7047
-
7048
7123
  const int64_t bid = tgpig.z/(ne12*ne13);
7049
7124
 
7050
7125
  tgpig.z = tgpig.z%(ne12*ne13);
7051
7126
 
7052
7127
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7128
+ device const char * src0 = src0s + id*nb02;
7053
7129
 
7054
7130
  kernel_mul_mv_iq4_nl_f32_impl(
7055
- src0[id],
7131
+ src0,
7056
7132
  (device const float *) (src1 + bid*nb11),
7057
7133
  dst + bid*ne0,
7058
7134
  ne00,
@@ -7072,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
7072
7148
 
7073
7149
  [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
7074
7150
  kernel void kernel_mul_mv_id_iq4_xs_f32(
7075
- device const char * ids,
7151
+ device const char * src0s,
7076
7152
  device const char * src1,
7077
7153
  device float * dst,
7154
+ device const char * ids,
7078
7155
  constant uint64_t & nbi1,
7079
7156
  constant int64_t & ne00,
7080
7157
  constant int64_t & ne01,
@@ -7095,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
7095
7172
  constant uint & r2,
7096
7173
  constant uint & r3,
7097
7174
  constant int & idx,
7098
- device const char * src00,
7099
- device const char * src01,
7100
- device const char * src02,
7101
- device const char * src03,
7102
- device const char * src04,
7103
- device const char * src05,
7104
- device const char * src06,
7105
- device const char * src07,
7106
7175
  threadgroup float * shared_values [[threadgroup(0)]],
7107
7176
  uint3 tgpig[[threadgroup_position_in_grid]],
7108
7177
  uint tiitg[[thread_index_in_threadgroup]],
7109
7178
  uint tiisg[[thread_index_in_simdgroup]],
7110
7179
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
7111
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
7112
-
7113
7180
  const int64_t bid = tgpig.z/(ne12*ne13);
7114
7181
 
7115
7182
  tgpig.z = tgpig.z%(ne12*ne13);
7116
7183
 
7117
7184
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7185
+ device const char * src0 = src0s + id*nb02;
7118
7186
 
7119
7187
  #if QK_K == 64
7120
7188
  kernel_mul_mv_iq4_nl_f32_impl(
7121
7189
  #else
7122
7190
  kernel_mul_mv_iq4_xs_f32_impl(
7123
7191
  #endif
7124
- src0[id],
7192
+ src0,
7125
7193
  (device const float *) (src1 + bid*nb11),
7126
7194
  dst + bid*ne0,
7127
7195
  ne00,