llama_cpp 0.14.2 → 0.14.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,8 +13,8 @@ using namespace metal;
13
13
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
14
14
 
15
15
  enum ggml_sort_order {
16
- GGML_SORT_ASC,
17
- GGML_SORT_DESC,
16
+ GGML_SORT_ORDER_ASC,
17
+ GGML_SORT_ORDER_DESC,
18
18
  };
19
19
 
20
20
  // general-purpose kernel for addition, multiplication and division of two tensors
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
1973
1973
 
1974
1974
  // bitonic sort implementation following the CUDA kernels as reference
1975
1975
  typedef void (argsort_t)(
1976
- device const float * x,
1977
- device int32_t * dst,
1978
- constant int64_t & ncols,
1976
+ device const float * x,
1977
+ device int32_t * dst,
1978
+ constant int64_t & ncols,
1979
+ constant int64_t & ncols_pad,
1980
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
1979
1981
  uint3 tgpig[[threadgroup_position_in_grid]],
1980
1982
  uint3 tpitg[[thread_position_in_threadgroup]]);
1981
1983
 
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
1984
1986
  device const float * x,
1985
1987
  device int32_t * dst,
1986
1988
  constant int64_t & ncols,
1989
+ constant int64_t & ncols_pad,
1990
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
1987
1991
  uint3 tgpig[[threadgroup_position_in_grid]],
1988
1992
  uint3 tpitg[[thread_position_in_threadgroup]]) {
1989
1993
  // bitonic sort
1990
1994
  int col = tpitg[0];
1991
1995
  int row = tgpig[1];
1992
1996
 
1993
- if (col >= ncols) return;
1997
+ if (col >= ncols_pad) return;
1994
1998
 
1995
- device const float * x_row = x + row * ncols;
1996
- device int32_t * dst_row = dst + row * ncols;
1999
+ device const float * x_row = x + row * ncols;
2000
+ threadgroup int32_t * dst_row = shared_values;
1997
2001
 
1998
2002
  // initialize indices
1999
- if (col < ncols) {
2000
- dst_row[col] = col;
2001
- }
2003
+ dst_row[col] = col;
2004
+
2002
2005
  threadgroup_barrier(mem_flags::mem_threadgroup);
2003
2006
 
2004
- for (int k = 2; k <= ncols; k *= 2) {
2007
+ for (int k = 2; k <= ncols_pad; k *= 2) {
2005
2008
  for (int j = k / 2; j > 0; j /= 2) {
2006
2009
  int ixj = col ^ j;
2007
2010
  if (ixj > col) {
2008
2011
  if ((col & k) == 0) {
2009
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
2012
+ if (dst_row[col] >= ncols ||
2013
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
2014
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
2015
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
2016
+ ) {
2010
2017
  SWAP(dst_row[col], dst_row[ixj]);
2011
2018
  }
2012
2019
  } else {
2013
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
2020
+ if (dst_row[ixj] >= ncols ||
2021
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
2022
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
2023
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
2024
+ ) {
2014
2025
  SWAP(dst_row[col], dst_row[ixj]);
2015
2026
  }
2016
2027
  }
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
2018
2029
  threadgroup_barrier(mem_flags::mem_threadgroup);
2019
2030
  }
2020
2031
  }
2032
+
2033
+ // copy the result to dst without the padding
2034
+ if (col < ncols) {
2035
+ dst[row * ncols + col] = dst_row[col];
2036
+ }
2021
2037
  }
2022
2038
 
2023
- template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
2024
- template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
2039
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
2040
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
2025
2041
 
2026
2042
  kernel void kernel_leaky_relu_f32(
2027
2043
  device const float * src0,
@@ -2388,6 +2404,242 @@ kernel void kernel_cpy_f32_q4_1(
2388
2404
  }
2389
2405
  }
2390
2406
 
2407
+ kernel void kernel_cpy_f32_q5_0(
2408
+ device const float * src0,
2409
+ device void * dst,
2410
+ constant int64_t & ne00,
2411
+ constant int64_t & ne01,
2412
+ constant int64_t & ne02,
2413
+ constant int64_t & ne03,
2414
+ constant uint64_t & nb00,
2415
+ constant uint64_t & nb01,
2416
+ constant uint64_t & nb02,
2417
+ constant uint64_t & nb03,
2418
+ constant int64_t & ne0,
2419
+ constant int64_t & ne1,
2420
+ constant int64_t & ne2,
2421
+ constant int64_t & ne3,
2422
+ constant uint64_t & nb0,
2423
+ constant uint64_t & nb1,
2424
+ constant uint64_t & nb2,
2425
+ constant uint64_t & nb3,
2426
+ uint3 tgpig[[threadgroup_position_in_grid]],
2427
+ uint3 tpitg[[thread_position_in_threadgroup]],
2428
+ uint3 ntg[[threads_per_threadgroup]]) {
2429
+ const int64_t i03 = tgpig[2];
2430
+ const int64_t i02 = tgpig[1];
2431
+ const int64_t i01 = tgpig[0];
2432
+
2433
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2434
+
2435
+ const int64_t i3 = n / (ne2*ne1*ne0);
2436
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2437
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2438
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
2439
+
2440
+ device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2441
+
2442
+ for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
2443
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2444
+
2445
+ float amax = 0.0f; // absolute max
2446
+ float max = 0.0f;
2447
+
2448
+ for (int j = 0; j < QK5_0; j++) {
2449
+ const float v = src[j];
2450
+ if (amax < fabs(v)) {
2451
+ amax = fabs(v);
2452
+ max = v;
2453
+ }
2454
+ }
2455
+
2456
+ const float d = max / -16;
2457
+ const float id = d ? 1.0f/d : 0.0f;
2458
+
2459
+ dst_data[i00/QK5_0].d = d;
2460
+
2461
+ uint32_t qh = 0;
2462
+ for (int j = 0; j < QK5_0/2; ++j) {
2463
+ const float x0 = src[0 + j]*id;
2464
+ const float x1 = src[QK5_0/2 + j]*id;
2465
+
2466
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
2467
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
2468
+
2469
+ dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
2470
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
2471
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
2472
+ }
2473
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
2474
+ for (int j = 0; j < 4; ++j) {
2475
+ dst_data[i00/QK5_0].qh[j] = qh8[j];
2476
+ }
2477
+ }
2478
+ }
2479
+
2480
+ kernel void kernel_cpy_f32_q5_1(
2481
+ device const float * src0,
2482
+ device void * dst,
2483
+ constant int64_t & ne00,
2484
+ constant int64_t & ne01,
2485
+ constant int64_t & ne02,
2486
+ constant int64_t & ne03,
2487
+ constant uint64_t & nb00,
2488
+ constant uint64_t & nb01,
2489
+ constant uint64_t & nb02,
2490
+ constant uint64_t & nb03,
2491
+ constant int64_t & ne0,
2492
+ constant int64_t & ne1,
2493
+ constant int64_t & ne2,
2494
+ constant int64_t & ne3,
2495
+ constant uint64_t & nb0,
2496
+ constant uint64_t & nb1,
2497
+ constant uint64_t & nb2,
2498
+ constant uint64_t & nb3,
2499
+ uint3 tgpig[[threadgroup_position_in_grid]],
2500
+ uint3 tpitg[[thread_position_in_threadgroup]],
2501
+ uint3 ntg[[threads_per_threadgroup]]) {
2502
+ const int64_t i03 = tgpig[2];
2503
+ const int64_t i02 = tgpig[1];
2504
+ const int64_t i01 = tgpig[0];
2505
+
2506
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2507
+
2508
+ const int64_t i3 = n / (ne2*ne1*ne0);
2509
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2510
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2511
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
2512
+
2513
+ device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2514
+
2515
+ for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
2516
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2517
+
2518
+ float max = src[0];
2519
+ float min = src[0];
2520
+
2521
+ for (int j = 1; j < QK5_1; j++) {
2522
+ const float v = src[j];
2523
+ min = v < min ? v : min;
2524
+ max = v > max ? v : max;
2525
+ }
2526
+
2527
+ const float d = (max - min) / 31;
2528
+ const float id = d ? 1.0f/d : 0.0f;
2529
+
2530
+ dst_data[i00/QK5_1].d = d;
2531
+ dst_data[i00/QK5_1].m = min;
2532
+
2533
+ uint32_t qh = 0;
2534
+ for (int j = 0; j < QK5_1/2; ++j) {
2535
+ const float x0 = (src[0 + j] - min)*id;
2536
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
2537
+
2538
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
2539
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
2540
+
2541
+ dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
2542
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
2543
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
2544
+ }
2545
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
2546
+ for (int j = 0; j < 4; ++j) {
2547
+ dst_data[i00/QK5_1].qh[j] = qh8[j];
2548
+ }
2549
+ }
2550
+ }
2551
+
2552
+ static inline int best_index_int8(int n, constant float * val, float x) {
2553
+ if (x <= val[0]) return 0;
2554
+ if (x >= val[n-1]) return n-1;
2555
+ int ml = 0, mu = n-1;
2556
+ while (mu-ml > 1) {
2557
+ int mav = (ml+mu)/2;
2558
+ if (x < val[mav]) mu = mav; else ml = mav;
2559
+ }
2560
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
2561
+ }
2562
+
2563
+ constexpr constant static float kvalues_iq4nl_f[16] = {
2564
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
2565
+ };
2566
+
2567
+ kernel void kernel_cpy_f32_iq4_nl(
2568
+ device const float * src0,
2569
+ device void * dst,
2570
+ constant int64_t & ne00,
2571
+ constant int64_t & ne01,
2572
+ constant int64_t & ne02,
2573
+ constant int64_t & ne03,
2574
+ constant uint64_t & nb00,
2575
+ constant uint64_t & nb01,
2576
+ constant uint64_t & nb02,
2577
+ constant uint64_t & nb03,
2578
+ constant int64_t & ne0,
2579
+ constant int64_t & ne1,
2580
+ constant int64_t & ne2,
2581
+ constant int64_t & ne3,
2582
+ constant uint64_t & nb0,
2583
+ constant uint64_t & nb1,
2584
+ constant uint64_t & nb2,
2585
+ constant uint64_t & nb3,
2586
+ uint3 tgpig[[threadgroup_position_in_grid]],
2587
+ uint3 tpitg[[thread_position_in_threadgroup]],
2588
+ uint3 ntg[[threads_per_threadgroup]]) {
2589
+ const int64_t i03 = tgpig[2];
2590
+ const int64_t i02 = tgpig[1];
2591
+ const int64_t i01 = tgpig[0];
2592
+
2593
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2594
+
2595
+ const int64_t i3 = n / (ne2*ne1*ne0);
2596
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2597
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2598
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
2599
+
2600
+ device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2601
+
2602
+ for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
2603
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2604
+
2605
+ float amax = 0.0f; // absolute max
2606
+ float max = 0.0f;
2607
+
2608
+ for (int j = 0; j < QK4_0; j++) {
2609
+ const float v = src[j];
2610
+ if (amax < fabs(v)) {
2611
+ amax = fabs(v);
2612
+ max = v;
2613
+ }
2614
+ }
2615
+
2616
+ const float d = max / kvalues_iq4nl_f[0];
2617
+ const float id = d ? 1.0f/d : 0.0f;
2618
+
2619
+ float sumqx = 0, sumq2 = 0;
2620
+ for (int j = 0; j < QK4_NL/2; ++j) {
2621
+ const float x0 = src[0 + j]*id;
2622
+ const float x1 = src[QK4_NL/2 + j]*id;
2623
+
2624
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
2625
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
2626
+
2627
+ dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
2628
+
2629
+ const float v0 = kvalues_iq4nl_f[xi0];
2630
+ const float v1 = kvalues_iq4nl_f[xi1];
2631
+ const float w0 = src[0 + j]*src[0 + j];
2632
+ const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
2633
+ sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
2634
+ sumq2 += w0*v0*v0 + w1*v1*v1;
2635
+
2636
+ }
2637
+
2638
+ dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
2639
+
2640
+ }
2641
+ }
2642
+
2391
2643
  kernel void kernel_concat(
2392
2644
  device const char * src0,
2393
2645
  device const char * src1,
@@ -4220,9 +4472,113 @@ void kernel_mul_mv_iq1_s_f32_impl(
4220
4472
  }
4221
4473
  }
4222
4474
 
4223
- constexpr constant static float kvalues_iq4nl_f[16] = {
4224
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
4225
- };
4475
+ void kernel_mul_mv_iq1_m_f32_impl(
4476
+ device const void * src0,
4477
+ device const float * src1,
4478
+ device float * dst,
4479
+ constant int64_t & ne00,
4480
+ constant int64_t & ne01,
4481
+ constant int64_t & ne02,
4482
+ constant int64_t & ne10,
4483
+ constant int64_t & ne12,
4484
+ constant int64_t & ne0,
4485
+ constant int64_t & ne1,
4486
+ constant uint & r2,
4487
+ constant uint & r3,
4488
+ uint3 tgpig[[threadgroup_position_in_grid]],
4489
+ uint tiisg[[thread_index_in_simdgroup]],
4490
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4491
+
4492
+ const int nb = ne00/QK_K;
4493
+ const int r0 = tgpig.x;
4494
+ const int r1 = tgpig.y;
4495
+ const int im = tgpig.z;
4496
+
4497
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4498
+ const int ib_row = first_row * nb;
4499
+
4500
+ const uint i12 = im%ne12;
4501
+ const uint i13 = im/ne12;
4502
+
4503
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4504
+ device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
4505
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4506
+
4507
+ float yl[32];
4508
+ float sumf[N_DST]={0.f}, all_sum;
4509
+
4510
+ const int nb32 = nb * (QK_K / 32);
4511
+
4512
+ const int ix = tiisg;
4513
+
4514
+ device const float * y4 = y + 32 * ix;
4515
+
4516
+ #if QK_K != 64
4517
+ iq1m_scale_t scale;
4518
+ #endif
4519
+
4520
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4521
+
4522
+ float4 sumy = {0.f};
4523
+ for (int i = 0; i < 8; ++i) {
4524
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
4525
+ yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
4526
+ yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
4527
+ yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
4528
+ }
4529
+
4530
+ const int ibl = ib32 / (QK_K / 32);
4531
+ const int ib = ib32 % (QK_K / 32);
4532
+
4533
+ device const block_iq1_m * xr = x + ibl;
4534
+ device const uint8_t * qs = xr->qs + 4 * ib;
4535
+ device const uint8_t * qh = xr->qh + 2 * ib;
4536
+ device const uint16_t * sc = (device const uint16_t *)xr->scales;
4537
+
4538
+ for (int row = 0; row < N_DST; row++) {
4539
+
4540
+ #if QK_K != 64
4541
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
4542
+ #endif
4543
+
4544
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
4545
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
4546
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
4547
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
4548
+
4549
+ float2 sum = {0.f};
4550
+ for (int j = 0; j < 4; ++j) {
4551
+ sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
4552
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
4553
+ sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
4554
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
4555
+ }
4556
+ const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
4557
+ const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
4558
+ #if QK_K == 64
4559
+ const float d = (float) *((device const half *)(sc - 1));
4560
+ sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
4561
+ (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
4562
+ #else
4563
+ sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
4564
+ (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
4565
+ #endif
4566
+
4567
+ sc += nb*sizeof(block_iq1_m)/2;
4568
+ qs += nb*sizeof(block_iq1_m);
4569
+ qh += nb*sizeof(block_iq1_m);
4570
+ }
4571
+
4572
+ y4 += 32 * 32;
4573
+ }
4574
+
4575
+ for (int row = 0; row < N_DST; ++row) {
4576
+ all_sum = simd_sum(sumf[row]);
4577
+ if (tiisg == 0) {
4578
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4579
+ }
4580
+ }
4581
+ }
4226
4582
 
4227
4583
  void kernel_mul_mv_iq4_nl_f32_impl(
4228
4584
  device const void * src0,
@@ -4441,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
4441
4797
  kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4442
4798
  }
4443
4799
 
4800
+ [[host_name("kernel_mul_mv_iq1_m_f32")]]
4801
+ kernel void kernel_mul_mv_iq1_m_f32(
4802
+ device const void * src0,
4803
+ device const float * src1,
4804
+ device float * dst,
4805
+ constant int64_t & ne00,
4806
+ constant int64_t & ne01,
4807
+ constant int64_t & ne02,
4808
+ constant uint64_t & nb00,
4809
+ constant uint64_t & nb01,
4810
+ constant uint64_t & nb02,
4811
+ constant int64_t & ne10,
4812
+ constant int64_t & ne11,
4813
+ constant int64_t & ne12,
4814
+ constant uint64_t & nb10,
4815
+ constant uint64_t & nb11,
4816
+ constant uint64_t & nb12,
4817
+ constant int64_t & ne0,
4818
+ constant int64_t & ne1,
4819
+ constant uint & r2,
4820
+ constant uint & r3,
4821
+ uint3 tgpig[[threadgroup_position_in_grid]],
4822
+ uint tiisg[[thread_index_in_simdgroup]],
4823
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4824
+
4825
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4826
+ }
4827
+
4444
4828
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
4445
4829
  kernel void kernel_mul_mv_iq4_nl_f32(
4446
4830
  device const void * src0,
@@ -4914,6 +5298,38 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
4914
5298
  }
4915
5299
  }
4916
5300
 
5301
+ template <typename type4x4>
5302
+ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
5303
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5304
+ const int ib32 = il/2;
5305
+ il = il%2;
5306
+ device const uint16_t * sc = (device const uint16_t *)xb->scales;
5307
+ #if QK_K == 64
5308
+ const float d = xb->d;
5309
+ #else
5310
+ iq1m_scale_t scale;
5311
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5312
+ const float d = scale.f16;
5313
+ #endif
5314
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5315
+ device const uint8_t * qh = xb->qh + 2*ib32 + il;
5316
+ #if QK_K == 64
5317
+ const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
5318
+ #else
5319
+ const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
5320
+ #endif
5321
+ const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5322
+ const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5323
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5324
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
5325
+ for (int i = 0; i < 4; ++i) {
5326
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
5327
+ reg[1][i] = dl * (grid1[i] >> 4) + ml1;
5328
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
5329
+ reg[3][i] = dl * (grid2[i] >> 4) + ml2;
5330
+ }
5331
+ }
5332
+
4917
5333
  template <typename type4x4>
4918
5334
  void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
4919
5335
  device const uint16_t * q4 = (device const uint16_t *)xb->qs;
@@ -5385,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
5385
5801
 
5386
5802
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5387
5803
  kernel void kernel_mul_mm_id(
5388
- device const uchar * ids,
5804
+ device const uchar * src0s,
5389
5805
  device const uchar * src1,
5390
5806
  device float * dst,
5807
+ device const uchar * ids,
5391
5808
  constant uint64_t & nbi1,
5392
5809
  constant int64_t & ne00,
5393
5810
  constant int64_t & ne02,
@@ -5404,22 +5821,14 @@ kernel void kernel_mul_mm_id(
5404
5821
  constant uint & r2,
5405
5822
  constant uint & r3,
5406
5823
  constant int & idx,
5407
- device const uchar * src00,
5408
- device const uchar * src01,
5409
- device const uchar * src02,
5410
- device const uchar * src03,
5411
- device const uchar * src04,
5412
- device const uchar * src05,
5413
- device const uchar * src06,
5414
- device const uchar * src07,
5415
5824
  threadgroup uchar * shared_memory [[threadgroup(0)]],
5416
5825
  uint3 tgpig[[threadgroup_position_in_grid]],
5417
5826
  uint tiitg[[thread_index_in_threadgroup]],
5418
5827
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5419
- device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5420
5828
 
5421
5829
  // expert id
5422
5830
  const int32_t id = tgpig.z/(ne12*ne13);
5831
+ device const uchar * src0 = src0s + id*nb02;
5423
5832
 
5424
5833
  tgpig.z = tgpig.z%(ne12*ne13);
5425
5834
 
@@ -5434,7 +5843,7 @@ kernel void kernel_mul_mm_id(
5434
5843
  }
5435
5844
 
5436
5845
  kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5437
- src0s[id],
5846
+ src0,
5438
5847
  src1,
5439
5848
  src1ids,
5440
5849
  dst,
@@ -5498,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
5498
5907
  template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
5499
5908
  template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
5500
5909
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5910
+ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
5501
5911
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
5502
5912
  #if QK_K == 64
5503
5913
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
@@ -5528,24 +5938,25 @@ typedef void (mat_mm_t)(
5528
5938
  threadgroup uchar *,
5529
5939
  uint3, uint, uint);
5530
5940
 
5531
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
5532
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
5533
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
5534
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
5535
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
5536
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
5537
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
5538
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
5539
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
5540
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
5541
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
5542
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
5941
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
5942
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
5943
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
5944
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
5945
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
5946
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
5947
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
5948
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
5949
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
5950
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
5951
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
5952
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
5543
5953
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5544
5954
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5545
5955
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5546
5956
  template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
5547
5957
  template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
5548
5958
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5959
+ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
5549
5960
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
5550
5961
  #if QK_K == 64
5551
5962
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
@@ -5558,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
5558
5969
  //
5559
5970
 
5560
5971
  typedef void (mat_mm_id_t)(
5561
- device const uchar * ids,
5972
+ device const uchar * src0s,
5562
5973
  device const uchar * src1,
5563
5974
  device float * dst,
5975
+ device const uchar * ids,
5564
5976
  constant uint64_t & nbi1,
5565
5977
  constant int64_t & ne00,
5566
5978
  constant int64_t & ne02,
@@ -5577,35 +5989,28 @@ typedef void (mat_mm_id_t)(
5577
5989
  constant uint & r2,
5578
5990
  constant uint & r3,
5579
5991
  constant int & idx,
5580
- device const uchar * src00,
5581
- device const uchar * src01,
5582
- device const uchar * src02,
5583
- device const uchar * src03,
5584
- device const uchar * src04,
5585
- device const uchar * src05,
5586
- device const uchar * src06,
5587
- device const uchar * src07,
5588
5992
  threadgroup uchar *,
5589
5993
  uint3, uint, uint);
5590
5994
 
5591
- template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
5592
- template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
5593
- template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
5594
- template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
5595
- template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
5596
- template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
5597
- template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
5598
- template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
5599
- template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
5600
- template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
5601
- template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
5602
- template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
5995
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
5996
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
5997
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
5998
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
5999
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
6000
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
6001
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
6002
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
6003
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
6004
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
6005
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
6006
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
5603
6007
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5604
6008
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5605
6009
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5606
6010
  template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
5607
6011
  template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
5608
6012
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6013
+ template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
5609
6014
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
5610
6015
  #if QK_K == 64
5611
6016
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
@@ -5619,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
5619
6024
 
5620
6025
  [[host_name("kernel_mul_mv_id_f32_f32")]]
5621
6026
  kernel void kernel_mul_mv_id_f32_f32(
5622
- device const char * ids,
6027
+ device const char * src0s,
5623
6028
  device const char * src1,
5624
6029
  device float * dst,
6030
+ device const char * ids,
5625
6031
  constant uint64_t & nbi1,
5626
6032
  constant int64_t & ne00,
5627
6033
  constant int64_t & ne01,
@@ -5642,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
5642
6048
  constant uint & r2,
5643
6049
  constant uint & r3,
5644
6050
  constant int & idx,
5645
- device const char * src00,
5646
- device const char * src01,
5647
- device const char * src02,
5648
- device const char * src03,
5649
- device const char * src04,
5650
- device const char * src05,
5651
- device const char * src06,
5652
- device const char * src07,
5653
6051
  uint3 tgpig[[threadgroup_position_in_grid]],
5654
6052
  uint tiitg[[thread_index_in_threadgroup]],
5655
6053
  uint tiisg[[thread_index_in_simdgroup]],
5656
6054
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5657
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5658
-
5659
6055
  const int64_t bid = tgpig.z/(ne12*ne13);
5660
6056
 
5661
6057
  tgpig.z = tgpig.z%(ne12*ne13);
5662
6058
 
5663
6059
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6060
+ device const char * src0 = src0s + id*nb02;
5664
6061
 
5665
6062
  kernel_mul_mv_f32_f32_impl(
5666
- src0[id],
6063
+ src0,
5667
6064
  src1 + bid*nb11,
5668
6065
  dst + bid*ne0,
5669
6066
  ne00,
@@ -5688,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
5688
6085
 
5689
6086
  [[host_name("kernel_mul_mv_id_f16_f32")]]
5690
6087
  kernel void kernel_mul_mv_id_f16_f32(
5691
- device const char * ids,
6088
+ device const char * src0s,
5692
6089
  device const char * src1,
5693
6090
  device float * dst,
6091
+ device const char * ids,
5694
6092
  constant uint64_t & nbi1,
5695
6093
  constant int64_t & ne00,
5696
6094
  constant int64_t & ne01,
@@ -5711,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
5711
6109
  constant uint & r2,
5712
6110
  constant uint & r3,
5713
6111
  constant int & idx,
5714
- device const char * src00,
5715
- device const char * src01,
5716
- device const char * src02,
5717
- device const char * src03,
5718
- device const char * src04,
5719
- device const char * src05,
5720
- device const char * src06,
5721
- device const char * src07,
5722
6112
  uint3 tgpig[[threadgroup_position_in_grid]],
5723
6113
  uint tiitg[[thread_index_in_threadgroup]],
5724
6114
  uint tiisg[[thread_index_in_simdgroup]],
5725
6115
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5726
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5727
-
5728
6116
  const int64_t bid = tgpig.z/(ne12*ne13);
5729
6117
 
5730
6118
  tgpig.z = tgpig.z%(ne12*ne13);
5731
6119
 
5732
6120
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6121
+ device const char * src0 = src0s + id*nb02;
5733
6122
 
5734
6123
  kernel_mul_mv_f16_f32_impl(
5735
- src0[id],
6124
+ src0,
5736
6125
  src1 + bid*nb11,
5737
6126
  dst + bid*ne0,
5738
6127
  ne00,
@@ -5757,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
5757
6146
 
5758
6147
  [[host_name("kernel_mul_mv_id_q8_0_f32")]]
5759
6148
  kernel void kernel_mul_mv_id_q8_0_f32(
5760
- device const char * ids,
6149
+ device const char * src0s,
5761
6150
  device const char * src1,
5762
6151
  device float * dst,
6152
+ device const char * ids,
5763
6153
  constant uint64_t & nbi1,
5764
6154
  constant int64_t & ne00,
5765
6155
  constant int64_t & ne01,
@@ -5780,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
5780
6170
  constant uint & r2,
5781
6171
  constant uint & r3,
5782
6172
  constant int & idx,
5783
- device const char * src00,
5784
- device const char * src01,
5785
- device const char * src02,
5786
- device const char * src03,
5787
- device const char * src04,
5788
- device const char * src05,
5789
- device const char * src06,
5790
- device const char * src07,
5791
6173
  uint3 tgpig[[threadgroup_position_in_grid]],
5792
6174
  uint tiitg[[thread_index_in_threadgroup]],
5793
6175
  uint tiisg[[thread_index_in_simdgroup]],
5794
6176
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5795
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5796
-
5797
6177
  const int64_t bid = tgpig.z/(ne12*ne13);
5798
6178
 
5799
6179
  tgpig.z = tgpig.z%(ne12*ne13);
5800
6180
 
5801
6181
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6182
+ device const char * src0 = src0s + id*nb02;
5802
6183
 
5803
6184
  kernel_mul_mv_q8_0_f32_impl(
5804
- src0[id],
6185
+ src0,
5805
6186
  (device const float *) (src1 + bid*nb11),
5806
6187
  dst + bid*ne0,
5807
6188
  ne00,
@@ -5820,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
5820
6201
 
5821
6202
  [[host_name("kernel_mul_mv_id_q4_0_f32")]]
5822
6203
  kernel void kernel_mul_mv_id_q4_0_f32(
5823
- device const char * ids,
6204
+ device const char * src0s,
5824
6205
  device const char * src1,
5825
6206
  device float * dst,
6207
+ device const char * ids,
5826
6208
  constant uint64_t & nbi1,
5827
6209
  constant int64_t & ne00,
5828
6210
  constant int64_t & ne01,
@@ -5843,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
5843
6225
  constant uint & r2,
5844
6226
  constant uint & r3,
5845
6227
  constant int & idx,
5846
- device const char * src00,
5847
- device const char * src01,
5848
- device const char * src02,
5849
- device const char * src03,
5850
- device const char * src04,
5851
- device const char * src05,
5852
- device const char * src06,
5853
- device const char * src07,
5854
6228
  uint3 tgpig[[threadgroup_position_in_grid]],
5855
6229
  uint tiitg[[thread_index_in_threadgroup]],
5856
6230
  uint tiisg[[thread_index_in_simdgroup]],
5857
6231
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5858
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5859
-
5860
6232
  const int64_t bid = tgpig.z/(ne12*ne13);
5861
6233
 
5862
6234
  tgpig.z = tgpig.z%(ne12*ne13);
5863
6235
 
5864
6236
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6237
+ device const char * src0 = src0s + id*nb02;
5865
6238
 
5866
6239
  mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5867
- src0[id],
6240
+ src0,
5868
6241
  (device const float *) (src1 + bid*nb11),
5869
6242
  dst + bid*ne0,
5870
6243
  ne00,
@@ -5883,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
5883
6256
 
5884
6257
  [[host_name("kernel_mul_mv_id_q4_1_f32")]]
5885
6258
  kernel void kernel_mul_mv_id_q4_1_f32(
5886
- device const char * ids,
6259
+ device const char * src0s,
5887
6260
  device const char * src1,
5888
6261
  device float * dst,
6262
+ device const char * ids,
5889
6263
  constant uint64_t & nbi1,
5890
6264
  constant int64_t & ne00,
5891
6265
  constant int64_t & ne01,
@@ -5906,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
5906
6280
  constant uint & r2,
5907
6281
  constant uint & r3,
5908
6282
  constant int & idx,
5909
- device const char * src00,
5910
- device const char * src01,
5911
- device const char * src02,
5912
- device const char * src03,
5913
- device const char * src04,
5914
- device const char * src05,
5915
- device const char * src06,
5916
- device const char * src07,
5917
6283
  uint3 tgpig[[threadgroup_position_in_grid]],
5918
6284
  uint tiitg[[thread_index_in_threadgroup]],
5919
6285
  uint tiisg[[thread_index_in_simdgroup]],
5920
6286
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5921
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5922
-
5923
6287
  const int64_t bid = tgpig.z/(ne12*ne13);
5924
6288
 
5925
6289
  tgpig.z = tgpig.z%(ne12*ne13);
5926
6290
 
5927
6291
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6292
+ device const char * src0 = src0s + id*nb02;
5928
6293
 
5929
6294
  mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5930
- src0[id],
6295
+ src0,
5931
6296
  (device const float *) (src1 + bid*nb11),
5932
6297
  dst + bid*ne0,
5933
6298
  ne00,
@@ -5946,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
5946
6311
 
5947
6312
  [[host_name("kernel_mul_mv_id_q5_0_f32")]]
5948
6313
  kernel void kernel_mul_mv_id_q5_0_f32(
5949
- device const char * ids,
6314
+ device const char * src0s,
5950
6315
  device const char * src1,
5951
6316
  device float * dst,
6317
+ device const char * ids,
5952
6318
  constant uint64_t & nbi1,
5953
6319
  constant int64_t & ne00,
5954
6320
  constant int64_t & ne01,
@@ -5969,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
5969
6335
  constant uint & r2,
5970
6336
  constant uint & r3,
5971
6337
  constant int & idx,
5972
- device const char * src00,
5973
- device const char * src01,
5974
- device const char * src02,
5975
- device const char * src03,
5976
- device const char * src04,
5977
- device const char * src05,
5978
- device const char * src06,
5979
- device const char * src07,
5980
6338
  uint3 tgpig[[threadgroup_position_in_grid]],
5981
6339
  uint tiitg[[thread_index_in_threadgroup]],
5982
6340
  uint tiisg[[thread_index_in_simdgroup]],
5983
6341
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5984
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5985
-
5986
6342
  const int64_t bid = tgpig.z/(ne12*ne13);
5987
6343
 
5988
6344
  tgpig.z = tgpig.z%(ne12*ne13);
5989
6345
 
5990
6346
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6347
+ device const char * src0 = src0s + id*nb02;
5991
6348
 
5992
6349
  mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5993
- src0[id],
6350
+ src0,
5994
6351
  (device const float *) (src1 + bid*nb11),
5995
6352
  dst + bid*ne0,
5996
6353
  ne00,
@@ -6009,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
6009
6366
 
6010
6367
  [[host_name("kernel_mul_mv_id_q5_1_f32")]]
6011
6368
  kernel void kernel_mul_mv_id_q5_1_f32(
6012
- device const char * ids,
6369
+ device const char * src0s,
6013
6370
  device const char * src1,
6014
6371
  device float * dst,
6372
+ device const char * ids,
6015
6373
  constant uint64_t & nbi1,
6016
6374
  constant int64_t & ne00,
6017
6375
  constant int64_t & ne01,
@@ -6032,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
6032
6390
  constant uint & r2,
6033
6391
  constant uint & r3,
6034
6392
  constant int & idx,
6035
- device const char * src00,
6036
- device const char * src01,
6037
- device const char * src02,
6038
- device const char * src03,
6039
- device const char * src04,
6040
- device const char * src05,
6041
- device const char * src06,
6042
- device const char * src07,
6043
6393
  uint3 tgpig[[threadgroup_position_in_grid]],
6044
6394
  uint tiitg[[thread_index_in_threadgroup]],
6045
6395
  uint tiisg[[thread_index_in_simdgroup]],
6046
6396
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6047
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6048
-
6049
6397
  const int64_t bid = tgpig.z/(ne12*ne13);
6050
6398
 
6051
6399
  tgpig.z = tgpig.z%(ne12*ne13);
6052
6400
 
6053
6401
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6402
+ device const char * src0 = src0s + id*nb02;
6054
6403
 
6055
6404
  mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6056
- src0[id],
6405
+ src0,
6057
6406
  (device const float *) (src1 + bid*nb11),
6058
6407
  dst + bid*ne0,
6059
6408
  ne00,
@@ -6072,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
6072
6421
 
6073
6422
  [[host_name("kernel_mul_mv_id_q2_K_f32")]]
6074
6423
  kernel void kernel_mul_mv_id_q2_K_f32(
6075
- device const char * ids,
6424
+ device const char * src0s,
6076
6425
  device const char * src1,
6077
6426
  device float * dst,
6427
+ device const char * ids,
6078
6428
  constant uint64_t & nbi1,
6079
6429
  constant int64_t & ne00,
6080
6430
  constant int64_t & ne01,
@@ -6095,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
6095
6445
  constant uint & r2,
6096
6446
  constant uint & r3,
6097
6447
  constant int & idx,
6098
- device const char * src00,
6099
- device const char * src01,
6100
- device const char * src02,
6101
- device const char * src03,
6102
- device const char * src04,
6103
- device const char * src05,
6104
- device const char * src06,
6105
- device const char * src07,
6106
6448
  uint3 tgpig[[threadgroup_position_in_grid]],
6107
6449
  uint tiitg[[thread_index_in_threadgroup]],
6108
6450
  uint tiisg[[thread_index_in_simdgroup]],
6109
6451
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6110
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6111
-
6112
6452
  const int64_t bid = tgpig.z/(ne12*ne13);
6113
6453
 
6114
6454
  tgpig.z = tgpig.z%(ne12*ne13);
6115
6455
 
6116
6456
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6457
+ device const char * src0 = src0s + id*nb02;
6117
6458
 
6118
6459
  kernel_mul_mv_q2_K_f32_impl(
6119
- src0[id],
6460
+ src0,
6120
6461
  (device const float *) (src1 + bid*nb11),
6121
6462
  dst + bid*ne0,
6122
6463
  ne00,
@@ -6135,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
6135
6476
 
6136
6477
  [[host_name("kernel_mul_mv_id_q3_K_f32")]]
6137
6478
  kernel void kernel_mul_mv_id_q3_K_f32(
6138
- device const char * ids,
6479
+ device const char * src0s,
6139
6480
  device const char * src1,
6140
6481
  device float * dst,
6482
+ device const char * ids,
6141
6483
  constant uint64_t & nbi1,
6142
6484
  constant int64_t & ne00,
6143
6485
  constant int64_t & ne01,
@@ -6158,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
6158
6500
  constant uint & r2,
6159
6501
  constant uint & r3,
6160
6502
  constant int & idx,
6161
- device const char * src00,
6162
- device const char * src01,
6163
- device const char * src02,
6164
- device const char * src03,
6165
- device const char * src04,
6166
- device const char * src05,
6167
- device const char * src06,
6168
- device const char * src07,
6169
6503
  uint3 tgpig[[threadgroup_position_in_grid]],
6170
6504
  uint tiitg[[thread_index_in_threadgroup]],
6171
6505
  uint tiisg[[thread_index_in_simdgroup]],
6172
6506
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6173
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6174
-
6175
6507
  const int64_t bid = tgpig.z/(ne12*ne13);
6176
6508
 
6177
6509
  tgpig.z = tgpig.z%(ne12*ne13);
6178
6510
 
6179
6511
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6512
+ device const char * src0 = src0s + id*nb02;
6180
6513
 
6181
6514
  kernel_mul_mv_q3_K_f32_impl(
6182
- src0[id],
6515
+ src0,
6183
6516
  (device const float *) (src1 + bid*nb11),
6184
6517
  dst + bid*ne0,
6185
6518
  ne00,
@@ -6198,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
6198
6531
 
6199
6532
  [[host_name("kernel_mul_mv_id_q4_K_f32")]]
6200
6533
  kernel void kernel_mul_mv_id_q4_K_f32(
6201
- device const char * ids,
6534
+ device const char * src0s,
6202
6535
  device const char * src1,
6203
6536
  device float * dst,
6537
+ device const char * ids,
6204
6538
  constant uint64_t & nbi1,
6205
6539
  constant int64_t & ne00,
6206
6540
  constant int64_t & ne01,
@@ -6221,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
6221
6555
  constant uint & r2,
6222
6556
  constant uint & r3,
6223
6557
  constant int & idx,
6224
- device const char * src00,
6225
- device const char * src01,
6226
- device const char * src02,
6227
- device const char * src03,
6228
- device const char * src04,
6229
- device const char * src05,
6230
- device const char * src06,
6231
- device const char * src07,
6232
6558
  uint3 tgpig[[threadgroup_position_in_grid]],
6233
6559
  uint tiitg[[thread_index_in_threadgroup]],
6234
6560
  uint tiisg[[thread_index_in_simdgroup]],
6235
6561
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6236
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6237
-
6238
6562
  const int64_t bid = tgpig.z/(ne12*ne13);
6239
6563
 
6240
6564
  tgpig.z = tgpig.z%(ne12*ne13);
6241
6565
 
6242
6566
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6567
+ device const char * src0 = src0s + id*nb02;
6243
6568
 
6244
6569
  kernel_mul_mv_q4_K_f32_impl(
6245
- src0[id],
6570
+ src0,
6246
6571
  (device const float *) (src1 + bid*nb11),
6247
6572
  dst + bid*ne0,
6248
6573
  ne00,
@@ -6261,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
6261
6586
 
6262
6587
  [[host_name("kernel_mul_mv_id_q5_K_f32")]]
6263
6588
  kernel void kernel_mul_mv_id_q5_K_f32(
6264
- device const char * ids,
6589
+ device const char * src0s,
6265
6590
  device const char * src1,
6266
6591
  device float * dst,
6592
+ device const char * ids,
6267
6593
  constant uint64_t & nbi1,
6268
6594
  constant int64_t & ne00,
6269
6595
  constant int64_t & ne01,
@@ -6284,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
6284
6610
  constant uint & r2,
6285
6611
  constant uint & r3,
6286
6612
  constant int & idx,
6287
- device const char * src00,
6288
- device const char * src01,
6289
- device const char * src02,
6290
- device const char * src03,
6291
- device const char * src04,
6292
- device const char * src05,
6293
- device const char * src06,
6294
- device const char * src07,
6295
6613
  uint3 tgpig[[threadgroup_position_in_grid]],
6296
6614
  uint tiitg[[thread_index_in_threadgroup]],
6297
6615
  uint tiisg[[thread_index_in_simdgroup]],
6298
6616
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6299
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6300
-
6301
6617
  const int64_t bid = tgpig.z/(ne12*ne13);
6302
6618
 
6303
6619
  tgpig.z = tgpig.z%(ne12*ne13);
6304
6620
 
6305
6621
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6622
+ device const char * src0 = src0s + id*nb02;
6306
6623
 
6307
6624
  kernel_mul_mv_q5_K_f32_impl(
6308
- src0[id],
6625
+ src0,
6309
6626
  (device const float *) (src1 + bid*nb11),
6310
6627
  dst + bid*ne0,
6311
6628
  ne00,
@@ -6324,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
6324
6641
 
6325
6642
  [[host_name("kernel_mul_mv_id_q6_K_f32")]]
6326
6643
  kernel void kernel_mul_mv_id_q6_K_f32(
6327
- device const char * ids,
6644
+ device const char * src0s,
6328
6645
  device const char * src1,
6329
6646
  device float * dst,
6647
+ device const char * ids,
6330
6648
  constant uint64_t & nbi1,
6331
6649
  constant int64_t & ne00,
6332
6650
  constant int64_t & ne01,
@@ -6347,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6347
6665
  constant uint & r2,
6348
6666
  constant uint & r3,
6349
6667
  constant int & idx,
6350
- device const char * src00,
6351
- device const char * src01,
6352
- device const char * src02,
6353
- device const char * src03,
6354
- device const char * src04,
6355
- device const char * src05,
6356
- device const char * src06,
6357
- device const char * src07,
6358
6668
  uint3 tgpig[[threadgroup_position_in_grid]],
6359
6669
  uint tiitg[[thread_index_in_threadgroup]],
6360
6670
  uint tiisg[[thread_index_in_simdgroup]],
6361
6671
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6362
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6363
-
6364
6672
  const int64_t bid = tgpig.z/(ne12*ne13);
6365
6673
 
6366
6674
  tgpig.z = tgpig.z%(ne12*ne13);
6367
6675
 
6368
6676
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6677
+ device const char * src0 = src0s + id*nb02;
6369
6678
 
6370
6679
  kernel_mul_mv_q6_K_f32_impl(
6371
- src0[id],
6680
+ src0,
6372
6681
  (device const float *) (src1 + bid*nb11),
6373
6682
  dst + bid*ne0,
6374
6683
  ne00,
@@ -6387,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6387
6696
 
6388
6697
  [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
6389
6698
  kernel void kernel_mul_mv_id_iq2_xxs_f32(
6390
- device const char * ids,
6699
+ device const char * src0s,
6391
6700
  device const char * src1,
6392
6701
  device float * dst,
6702
+ device const char * ids,
6393
6703
  constant uint64_t & nbi1,
6394
6704
  constant int64_t & ne00,
6395
6705
  constant int64_t & ne01,
@@ -6410,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6410
6720
  constant uint & r2,
6411
6721
  constant uint & r3,
6412
6722
  constant int & idx,
6413
- device const char * src00,
6414
- device const char * src01,
6415
- device const char * src02,
6416
- device const char * src03,
6417
- device const char * src04,
6418
- device const char * src05,
6419
- device const char * src06,
6420
- device const char * src07,
6421
6723
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6422
6724
  uint3 tgpig[[threadgroup_position_in_grid]],
6423
6725
  uint tiitg[[thread_index_in_threadgroup]],
6424
6726
  uint tiisg[[thread_index_in_simdgroup]],
6425
6727
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6426
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6427
-
6428
6728
  const int64_t bid = tgpig.z/(ne12*ne13);
6429
6729
 
6430
6730
  tgpig.z = tgpig.z%(ne12*ne13);
6431
6731
 
6432
6732
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6733
+ device const char * src0 = src0s + id*nb02;
6433
6734
 
6434
6735
  kernel_mul_mv_iq2_xxs_f32_impl(
6435
- src0[id],
6736
+ src0,
6436
6737
  (device const float *) (src1 + bid*nb11),
6437
6738
  dst + bid*ne0,
6438
6739
  ne00,
@@ -6452,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6452
6753
 
6453
6754
  [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
6454
6755
  kernel void kernel_mul_mv_id_iq2_xs_f32(
6455
- device const char * ids,
6756
+ device const char * src0s,
6456
6757
  device const char * src1,
6457
6758
  device float * dst,
6759
+ device const char * ids,
6458
6760
  constant uint64_t & nbi1,
6459
6761
  constant int64_t & ne00,
6460
6762
  constant int64_t & ne01,
@@ -6475,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6475
6777
  constant uint & r2,
6476
6778
  constant uint & r3,
6477
6779
  constant int & idx,
6478
- device const char * src00,
6479
- device const char * src01,
6480
- device const char * src02,
6481
- device const char * src03,
6482
- device const char * src04,
6483
- device const char * src05,
6484
- device const char * src06,
6485
- device const char * src07,
6486
6780
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6487
6781
  uint3 tgpig[[threadgroup_position_in_grid]],
6488
6782
  uint tiitg[[thread_index_in_threadgroup]],
6489
6783
  uint tiisg[[thread_index_in_simdgroup]],
6490
6784
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6491
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6492
-
6493
6785
  const int64_t bid = tgpig.z/(ne12*ne13);
6494
6786
 
6495
6787
  tgpig.z = tgpig.z%(ne12*ne13);
6496
6788
 
6497
6789
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6790
+ device const char * src0 = src0s + id*nb02;
6498
6791
 
6499
6792
  kernel_mul_mv_iq2_xs_f32_impl(
6500
- src0[id],
6793
+ src0,
6501
6794
  (device const float *) (src1 + bid*nb11),
6502
6795
  dst + bid*ne0,
6503
6796
  ne00,
@@ -6517,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6517
6810
 
6518
6811
  [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6519
6812
  kernel void kernel_mul_mv_id_iq3_xxs_f32(
6520
- device const char * ids,
6813
+ device const char * src0s,
6521
6814
  device const char * src1,
6522
6815
  device float * dst,
6816
+ device const char * ids,
6523
6817
  constant uint64_t & nbi1,
6524
6818
  constant int64_t & ne00,
6525
6819
  constant int64_t & ne01,
@@ -6540,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6540
6834
  constant uint & r2,
6541
6835
  constant uint & r3,
6542
6836
  constant int & idx,
6543
- device const char * src00,
6544
- device const char * src01,
6545
- device const char * src02,
6546
- device const char * src03,
6547
- device const char * src04,
6548
- device const char * src05,
6549
- device const char * src06,
6550
- device const char * src07,
6551
6837
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6552
6838
  uint3 tgpig[[threadgroup_position_in_grid]],
6553
6839
  uint tiitg[[thread_index_in_threadgroup]],
6554
6840
  uint tiisg[[thread_index_in_simdgroup]],
6555
6841
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6556
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6557
-
6558
6842
  const int64_t bid = tgpig.z/(ne12*ne13);
6559
6843
 
6560
6844
  tgpig.z = tgpig.z%(ne12*ne13);
6561
6845
 
6562
6846
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6847
+ device const char * src0 = src0s + id*nb02;
6563
6848
 
6564
6849
  kernel_mul_mv_iq3_xxs_f32_impl(
6565
- src0[id],
6850
+ src0,
6566
6851
  (device const float *) (src1 + bid*nb11),
6567
6852
  dst + bid*ne0,
6568
6853
  ne00,
@@ -6582,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6582
6867
 
6583
6868
  [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
6584
6869
  kernel void kernel_mul_mv_id_iq3_s_f32(
6585
- device const char * ids,
6870
+ device const char * src0s,
6586
6871
  device const char * src1,
6587
6872
  device float * dst,
6873
+ device const char * ids,
6588
6874
  constant uint64_t & nbi1,
6589
6875
  constant int64_t & ne00,
6590
6876
  constant int64_t & ne01,
@@ -6605,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
6605
6891
  constant uint & r2,
6606
6892
  constant uint & r3,
6607
6893
  constant int & idx,
6608
- device const char * src00,
6609
- device const char * src01,
6610
- device const char * src02,
6611
- device const char * src03,
6612
- device const char * src04,
6613
- device const char * src05,
6614
- device const char * src06,
6615
- device const char * src07,
6616
6894
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6617
6895
  uint3 tgpig[[threadgroup_position_in_grid]],
6618
6896
  uint tiitg[[thread_index_in_threadgroup]],
6619
6897
  uint tiisg[[thread_index_in_simdgroup]],
6620
6898
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6621
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6622
-
6623
6899
  const int64_t bid = tgpig.z/(ne12*ne13);
6624
6900
 
6625
6901
  tgpig.z = tgpig.z%(ne12*ne13);
6626
6902
 
6627
6903
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6904
+ device const char * src0 = src0s + id*nb02;
6628
6905
 
6629
6906
  kernel_mul_mv_iq3_s_f32_impl(
6630
- src0[id],
6907
+ src0,
6631
6908
  (device const float *) (src1 + bid*nb11),
6632
6909
  dst + bid*ne0,
6633
6910
  ne00,
@@ -6647,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
6647
6924
 
6648
6925
  [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
6649
6926
  kernel void kernel_mul_mv_id_iq2_s_f32(
6650
- device const char * ids,
6927
+ device const char * src0s,
6651
6928
  device const char * src1,
6652
6929
  device float * dst,
6930
+ device const char * ids,
6653
6931
  constant uint64_t & nbi1,
6654
6932
  constant int64_t & ne00,
6655
6933
  constant int64_t & ne01,
@@ -6670,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
6670
6948
  constant uint & r2,
6671
6949
  constant uint & r3,
6672
6950
  constant int & idx,
6673
- device const char * src00,
6674
- device const char * src01,
6675
- device const char * src02,
6676
- device const char * src03,
6677
- device const char * src04,
6678
- device const char * src05,
6679
- device const char * src06,
6680
- device const char * src07,
6681
6951
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6682
6952
  uint3 tgpig[[threadgroup_position_in_grid]],
6683
6953
  uint tiitg[[thread_index_in_threadgroup]],
6684
6954
  uint tiisg[[thread_index_in_simdgroup]],
6685
6955
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6686
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6687
-
6688
6956
  const int64_t bid = tgpig.z/(ne12*ne13);
6689
6957
 
6690
6958
  tgpig.z = tgpig.z%(ne12*ne13);
6691
6959
 
6692
6960
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6961
+ device const char * src0 = src0s + id*nb02;
6693
6962
 
6694
6963
  kernel_mul_mv_iq2_s_f32_impl(
6695
- src0[id],
6964
+ src0,
6696
6965
  (device const float *) (src1 + bid*nb11),
6697
6966
  dst + bid*ne0,
6698
6967
  ne00,
@@ -6712,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
6712
6981
 
6713
6982
  [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6714
6983
  kernel void kernel_mul_mv_id_iq1_s_f32(
6715
- device const char * ids,
6984
+ device const char * src0s,
6716
6985
  device const char * src1,
6717
6986
  device float * dst,
6987
+ device const char * ids,
6718
6988
  constant uint64_t & nbi1,
6719
6989
  constant int64_t & ne00,
6720
6990
  constant int64_t & ne01,
@@ -6735,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
6735
7005
  constant uint & r2,
6736
7006
  constant uint & r3,
6737
7007
  constant int & idx,
6738
- device const char * src00,
6739
- device const char * src01,
6740
- device const char * src02,
6741
- device const char * src03,
6742
- device const char * src04,
6743
- device const char * src05,
6744
- device const char * src06,
6745
- device const char * src07,
6746
7008
  uint3 tgpig[[threadgroup_position_in_grid]],
6747
7009
  uint tiitg[[thread_index_in_threadgroup]],
6748
7010
  uint tiisg[[thread_index_in_simdgroup]],
6749
7011
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6750
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6751
-
6752
7012
  const int64_t bid = tgpig.z/(ne12*ne13);
6753
7013
 
6754
7014
  tgpig.z = tgpig.z%(ne12*ne13);
6755
7015
 
6756
7016
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7017
+ device const char * src0 = src0s + id*nb02;
6757
7018
 
6758
7019
  kernel_mul_mv_iq1_s_f32_impl(
6759
- src0[id],
7020
+ src0,
7021
+ (device const float *) (src1 + bid*nb11),
7022
+ dst + bid*ne0,
7023
+ ne00,
7024
+ ne01,
7025
+ ne02,
7026
+ ne10,
7027
+ ne12,
7028
+ ne0,
7029
+ ne1,
7030
+ r2,
7031
+ r3,
7032
+ tgpig,
7033
+ tiisg,
7034
+ sgitg);
7035
+ }
7036
+
7037
+ [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
7038
+ kernel void kernel_mul_mv_id_iq1_m_f32(
7039
+ device const char * src0s,
7040
+ device const char * src1,
7041
+ device float * dst,
7042
+ device const char * ids,
7043
+ constant uint64_t & nbi1,
7044
+ constant int64_t & ne00,
7045
+ constant int64_t & ne01,
7046
+ constant int64_t & ne02,
7047
+ constant uint64_t & nb00,
7048
+ constant uint64_t & nb01,
7049
+ constant uint64_t & nb02,
7050
+ constant int64_t & ne10,
7051
+ constant int64_t & ne11,
7052
+ constant int64_t & ne12,
7053
+ constant int64_t & ne13,
7054
+ constant uint64_t & nb10,
7055
+ constant uint64_t & nb11,
7056
+ constant uint64_t & nb12,
7057
+ constant int64_t & ne0,
7058
+ constant int64_t & ne1,
7059
+ constant uint64_t & nb1,
7060
+ constant uint & r2,
7061
+ constant uint & r3,
7062
+ constant int & idx,
7063
+ uint3 tgpig[[threadgroup_position_in_grid]],
7064
+ uint tiitg[[thread_index_in_threadgroup]],
7065
+ uint tiisg[[thread_index_in_simdgroup]],
7066
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
7067
+ const int64_t bid = tgpig.z/(ne12*ne13);
7068
+
7069
+ tgpig.z = tgpig.z%(ne12*ne13);
7070
+
7071
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7072
+ device const char * src0 = src0s + id*nb02;
7073
+
7074
+ kernel_mul_mv_iq1_m_f32_impl(
7075
+ src0,
6760
7076
  (device const float *) (src1 + bid*nb11),
6761
7077
  dst + bid*ne0,
6762
7078
  ne00,
@@ -6775,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
6775
7091
 
6776
7092
  [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
6777
7093
  kernel void kernel_mul_mv_id_iq4_nl_f32(
6778
- device const char * ids,
7094
+ device const char * src0s,
6779
7095
  device const char * src1,
6780
7096
  device float * dst,
7097
+ device const char * ids,
6781
7098
  constant uint64_t & nbi1,
6782
7099
  constant int64_t & ne00,
6783
7100
  constant int64_t & ne01,
@@ -6798,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
6798
7115
  constant uint & r2,
6799
7116
  constant uint & r3,
6800
7117
  constant int & idx,
6801
- device const char * src00,
6802
- device const char * src01,
6803
- device const char * src02,
6804
- device const char * src03,
6805
- device const char * src04,
6806
- device const char * src05,
6807
- device const char * src06,
6808
- device const char * src07,
6809
7118
  threadgroup float * shared_values [[threadgroup(0)]],
6810
7119
  uint3 tgpig[[threadgroup_position_in_grid]],
6811
7120
  uint tiitg[[thread_index_in_threadgroup]],
6812
7121
  uint tiisg[[thread_index_in_simdgroup]],
6813
7122
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6814
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6815
-
6816
7123
  const int64_t bid = tgpig.z/(ne12*ne13);
6817
7124
 
6818
7125
  tgpig.z = tgpig.z%(ne12*ne13);
6819
7126
 
6820
7127
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7128
+ device const char * src0 = src0s + id*nb02;
6821
7129
 
6822
7130
  kernel_mul_mv_iq4_nl_f32_impl(
6823
- src0[id],
7131
+ src0,
6824
7132
  (device const float *) (src1 + bid*nb11),
6825
7133
  dst + bid*ne0,
6826
7134
  ne00,
@@ -6840,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
6840
7148
 
6841
7149
  [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
6842
7150
  kernel void kernel_mul_mv_id_iq4_xs_f32(
6843
- device const char * ids,
7151
+ device const char * src0s,
6844
7152
  device const char * src1,
6845
7153
  device float * dst,
7154
+ device const char * ids,
6846
7155
  constant uint64_t & nbi1,
6847
7156
  constant int64_t & ne00,
6848
7157
  constant int64_t & ne01,
@@ -6863,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
6863
7172
  constant uint & r2,
6864
7173
  constant uint & r3,
6865
7174
  constant int & idx,
6866
- device const char * src00,
6867
- device const char * src01,
6868
- device const char * src02,
6869
- device const char * src03,
6870
- device const char * src04,
6871
- device const char * src05,
6872
- device const char * src06,
6873
- device const char * src07,
6874
7175
  threadgroup float * shared_values [[threadgroup(0)]],
6875
7176
  uint3 tgpig[[threadgroup_position_in_grid]],
6876
7177
  uint tiitg[[thread_index_in_threadgroup]],
6877
7178
  uint tiisg[[thread_index_in_simdgroup]],
6878
7179
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6879
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6880
-
6881
7180
  const int64_t bid = tgpig.z/(ne12*ne13);
6882
7181
 
6883
7182
  tgpig.z = tgpig.z%(ne12*ne13);
6884
7183
 
6885
7184
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7185
+ device const char * src0 = src0s + id*nb02;
6886
7186
 
6887
7187
  #if QK_K == 64
6888
7188
  kernel_mul_mv_iq4_nl_f32_impl(
6889
7189
  #else
6890
7190
  kernel_mul_mv_iq4_xs_f32_impl(
6891
7191
  #endif
6892
- src0[id],
7192
+ src0,
6893
7193
  (device const float *) (src1 + bid*nb11),
6894
7194
  dst + bid*ne0,
6895
7195
  ne00,