llama_cpp 0.14.2 → 0.14.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -13,8 +13,8 @@ using namespace metal;
13
13
  #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
14
14
 
15
15
  enum ggml_sort_order {
16
- GGML_SORT_ASC,
17
- GGML_SORT_DESC,
16
+ GGML_SORT_ORDER_ASC,
17
+ GGML_SORT_ORDER_DESC,
18
18
  };
19
19
 
20
20
  // general-purpose kernel for addition, multiplication and division of two tensors
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
1973
1973
 
1974
1974
  // bitonic sort implementation following the CUDA kernels as reference
1975
1975
  typedef void (argsort_t)(
1976
- device const float * x,
1977
- device int32_t * dst,
1978
- constant int64_t & ncols,
1976
+ device const float * x,
1977
+ device int32_t * dst,
1978
+ constant int64_t & ncols,
1979
+ constant int64_t & ncols_pad,
1980
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
1979
1981
  uint3 tgpig[[threadgroup_position_in_grid]],
1980
1982
  uint3 tpitg[[thread_position_in_threadgroup]]);
1981
1983
 
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
1984
1986
  device const float * x,
1985
1987
  device int32_t * dst,
1986
1988
  constant int64_t & ncols,
1989
+ constant int64_t & ncols_pad,
1990
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
1987
1991
  uint3 tgpig[[threadgroup_position_in_grid]],
1988
1992
  uint3 tpitg[[thread_position_in_threadgroup]]) {
1989
1993
  // bitonic sort
1990
1994
  int col = tpitg[0];
1991
1995
  int row = tgpig[1];
1992
1996
 
1993
- if (col >= ncols) return;
1997
+ if (col >= ncols_pad) return;
1994
1998
 
1995
- device const float * x_row = x + row * ncols;
1996
- device int32_t * dst_row = dst + row * ncols;
1999
+ device const float * x_row = x + row * ncols;
2000
+ threadgroup int32_t * dst_row = shared_values;
1997
2001
 
1998
2002
  // initialize indices
1999
- if (col < ncols) {
2000
- dst_row[col] = col;
2001
- }
2003
+ dst_row[col] = col;
2004
+
2002
2005
  threadgroup_barrier(mem_flags::mem_threadgroup);
2003
2006
 
2004
- for (int k = 2; k <= ncols; k *= 2) {
2007
+ for (int k = 2; k <= ncols_pad; k *= 2) {
2005
2008
  for (int j = k / 2; j > 0; j /= 2) {
2006
2009
  int ixj = col ^ j;
2007
2010
  if (ixj > col) {
2008
2011
  if ((col & k) == 0) {
2009
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
2012
+ if (dst_row[col] >= ncols ||
2013
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
2014
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
2015
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
2016
+ ) {
2010
2017
  SWAP(dst_row[col], dst_row[ixj]);
2011
2018
  }
2012
2019
  } else {
2013
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
2020
+ if (dst_row[ixj] >= ncols ||
2021
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
2022
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
2023
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
2024
+ ) {
2014
2025
  SWAP(dst_row[col], dst_row[ixj]);
2015
2026
  }
2016
2027
  }
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
2018
2029
  threadgroup_barrier(mem_flags::mem_threadgroup);
2019
2030
  }
2020
2031
  }
2032
+
2033
+ // copy the result to dst without the padding
2034
+ if (col < ncols) {
2035
+ dst[row * ncols + col] = dst_row[col];
2036
+ }
2021
2037
  }
2022
2038
 
2023
- template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
2024
- template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
2039
+ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
2040
+ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
2025
2041
 
2026
2042
  kernel void kernel_leaky_relu_f32(
2027
2043
  device const float * src0,
@@ -2388,6 +2404,242 @@ kernel void kernel_cpy_f32_q4_1(
2388
2404
  }
2389
2405
  }
2390
2406
 
2407
+ kernel void kernel_cpy_f32_q5_0(
2408
+ device const float * src0,
2409
+ device void * dst,
2410
+ constant int64_t & ne00,
2411
+ constant int64_t & ne01,
2412
+ constant int64_t & ne02,
2413
+ constant int64_t & ne03,
2414
+ constant uint64_t & nb00,
2415
+ constant uint64_t & nb01,
2416
+ constant uint64_t & nb02,
2417
+ constant uint64_t & nb03,
2418
+ constant int64_t & ne0,
2419
+ constant int64_t & ne1,
2420
+ constant int64_t & ne2,
2421
+ constant int64_t & ne3,
2422
+ constant uint64_t & nb0,
2423
+ constant uint64_t & nb1,
2424
+ constant uint64_t & nb2,
2425
+ constant uint64_t & nb3,
2426
+ uint3 tgpig[[threadgroup_position_in_grid]],
2427
+ uint3 tpitg[[thread_position_in_threadgroup]],
2428
+ uint3 ntg[[threads_per_threadgroup]]) {
2429
+ const int64_t i03 = tgpig[2];
2430
+ const int64_t i02 = tgpig[1];
2431
+ const int64_t i01 = tgpig[0];
2432
+
2433
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2434
+
2435
+ const int64_t i3 = n / (ne2*ne1*ne0);
2436
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2437
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2438
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
2439
+
2440
+ device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2441
+
2442
+ for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
2443
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2444
+
2445
+ float amax = 0.0f; // absolute max
2446
+ float max = 0.0f;
2447
+
2448
+ for (int j = 0; j < QK5_0; j++) {
2449
+ const float v = src[j];
2450
+ if (amax < fabs(v)) {
2451
+ amax = fabs(v);
2452
+ max = v;
2453
+ }
2454
+ }
2455
+
2456
+ const float d = max / -16;
2457
+ const float id = d ? 1.0f/d : 0.0f;
2458
+
2459
+ dst_data[i00/QK5_0].d = d;
2460
+
2461
+ uint32_t qh = 0;
2462
+ for (int j = 0; j < QK5_0/2; ++j) {
2463
+ const float x0 = src[0 + j]*id;
2464
+ const float x1 = src[QK5_0/2 + j]*id;
2465
+
2466
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
2467
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
2468
+
2469
+ dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
2470
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
2471
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
2472
+ }
2473
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
2474
+ for (int j = 0; j < 4; ++j) {
2475
+ dst_data[i00/QK5_0].qh[j] = qh8[j];
2476
+ }
2477
+ }
2478
+ }
2479
+
2480
+ kernel void kernel_cpy_f32_q5_1(
2481
+ device const float * src0,
2482
+ device void * dst,
2483
+ constant int64_t & ne00,
2484
+ constant int64_t & ne01,
2485
+ constant int64_t & ne02,
2486
+ constant int64_t & ne03,
2487
+ constant uint64_t & nb00,
2488
+ constant uint64_t & nb01,
2489
+ constant uint64_t & nb02,
2490
+ constant uint64_t & nb03,
2491
+ constant int64_t & ne0,
2492
+ constant int64_t & ne1,
2493
+ constant int64_t & ne2,
2494
+ constant int64_t & ne3,
2495
+ constant uint64_t & nb0,
2496
+ constant uint64_t & nb1,
2497
+ constant uint64_t & nb2,
2498
+ constant uint64_t & nb3,
2499
+ uint3 tgpig[[threadgroup_position_in_grid]],
2500
+ uint3 tpitg[[thread_position_in_threadgroup]],
2501
+ uint3 ntg[[threads_per_threadgroup]]) {
2502
+ const int64_t i03 = tgpig[2];
2503
+ const int64_t i02 = tgpig[1];
2504
+ const int64_t i01 = tgpig[0];
2505
+
2506
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2507
+
2508
+ const int64_t i3 = n / (ne2*ne1*ne0);
2509
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2510
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2511
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
2512
+
2513
+ device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2514
+
2515
+ for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
2516
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2517
+
2518
+ float max = src[0];
2519
+ float min = src[0];
2520
+
2521
+ for (int j = 1; j < QK5_1; j++) {
2522
+ const float v = src[j];
2523
+ min = v < min ? v : min;
2524
+ max = v > max ? v : max;
2525
+ }
2526
+
2527
+ const float d = (max - min) / 31;
2528
+ const float id = d ? 1.0f/d : 0.0f;
2529
+
2530
+ dst_data[i00/QK5_1].d = d;
2531
+ dst_data[i00/QK5_1].m = min;
2532
+
2533
+ uint32_t qh = 0;
2534
+ for (int j = 0; j < QK5_1/2; ++j) {
2535
+ const float x0 = (src[0 + j] - min)*id;
2536
+ const float x1 = (src[QK5_1/2 + j] - min)*id;
2537
+
2538
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
2539
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
2540
+
2541
+ dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
2542
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
2543
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
2544
+ }
2545
+ thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
2546
+ for (int j = 0; j < 4; ++j) {
2547
+ dst_data[i00/QK5_1].qh[j] = qh8[j];
2548
+ }
2549
+ }
2550
+ }
2551
+
2552
+ static inline int best_index_int8(int n, constant float * val, float x) {
2553
+ if (x <= val[0]) return 0;
2554
+ if (x >= val[n-1]) return n-1;
2555
+ int ml = 0, mu = n-1;
2556
+ while (mu-ml > 1) {
2557
+ int mav = (ml+mu)/2;
2558
+ if (x < val[mav]) mu = mav; else ml = mav;
2559
+ }
2560
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
2561
+ }
2562
+
2563
+ constexpr constant static float kvalues_iq4nl_f[16] = {
2564
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
2565
+ };
2566
+
2567
+ kernel void kernel_cpy_f32_iq4_nl(
2568
+ device const float * src0,
2569
+ device void * dst,
2570
+ constant int64_t & ne00,
2571
+ constant int64_t & ne01,
2572
+ constant int64_t & ne02,
2573
+ constant int64_t & ne03,
2574
+ constant uint64_t & nb00,
2575
+ constant uint64_t & nb01,
2576
+ constant uint64_t & nb02,
2577
+ constant uint64_t & nb03,
2578
+ constant int64_t & ne0,
2579
+ constant int64_t & ne1,
2580
+ constant int64_t & ne2,
2581
+ constant int64_t & ne3,
2582
+ constant uint64_t & nb0,
2583
+ constant uint64_t & nb1,
2584
+ constant uint64_t & nb2,
2585
+ constant uint64_t & nb3,
2586
+ uint3 tgpig[[threadgroup_position_in_grid]],
2587
+ uint3 tpitg[[thread_position_in_threadgroup]],
2588
+ uint3 ntg[[threads_per_threadgroup]]) {
2589
+ const int64_t i03 = tgpig[2];
2590
+ const int64_t i02 = tgpig[1];
2591
+ const int64_t i01 = tgpig[0];
2592
+
2593
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
2594
+
2595
+ const int64_t i3 = n / (ne2*ne1*ne0);
2596
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
2597
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
2598
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
2599
+
2600
+ device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
2601
+
2602
+ for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
2603
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
2604
+
2605
+ float amax = 0.0f; // absolute max
2606
+ float max = 0.0f;
2607
+
2608
+ for (int j = 0; j < QK4_0; j++) {
2609
+ const float v = src[j];
2610
+ if (amax < fabs(v)) {
2611
+ amax = fabs(v);
2612
+ max = v;
2613
+ }
2614
+ }
2615
+
2616
+ const float d = max / kvalues_iq4nl_f[0];
2617
+ const float id = d ? 1.0f/d : 0.0f;
2618
+
2619
+ float sumqx = 0, sumq2 = 0;
2620
+ for (int j = 0; j < QK4_NL/2; ++j) {
2621
+ const float x0 = src[0 + j]*id;
2622
+ const float x1 = src[QK4_NL/2 + j]*id;
2623
+
2624
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
2625
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
2626
+
2627
+ dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
2628
+
2629
+ const float v0 = kvalues_iq4nl_f[xi0];
2630
+ const float v1 = kvalues_iq4nl_f[xi1];
2631
+ const float w0 = src[0 + j]*src[0 + j];
2632
+ const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
2633
+ sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
2634
+ sumq2 += w0*v0*v0 + w1*v1*v1;
2635
+
2636
+ }
2637
+
2638
+ dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
2639
+
2640
+ }
2641
+ }
2642
+
2391
2643
  kernel void kernel_concat(
2392
2644
  device const char * src0,
2393
2645
  device const char * src1,
@@ -4220,9 +4472,113 @@ void kernel_mul_mv_iq1_s_f32_impl(
4220
4472
  }
4221
4473
  }
4222
4474
 
4223
- constexpr constant static float kvalues_iq4nl_f[16] = {
4224
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
4225
- };
4475
+ void kernel_mul_mv_iq1_m_f32_impl(
4476
+ device const void * src0,
4477
+ device const float * src1,
4478
+ device float * dst,
4479
+ constant int64_t & ne00,
4480
+ constant int64_t & ne01,
4481
+ constant int64_t & ne02,
4482
+ constant int64_t & ne10,
4483
+ constant int64_t & ne12,
4484
+ constant int64_t & ne0,
4485
+ constant int64_t & ne1,
4486
+ constant uint & r2,
4487
+ constant uint & r3,
4488
+ uint3 tgpig[[threadgroup_position_in_grid]],
4489
+ uint tiisg[[thread_index_in_simdgroup]],
4490
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4491
+
4492
+ const int nb = ne00/QK_K;
4493
+ const int r0 = tgpig.x;
4494
+ const int r1 = tgpig.y;
4495
+ const int im = tgpig.z;
4496
+
4497
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
4498
+ const int ib_row = first_row * nb;
4499
+
4500
+ const uint i12 = im%ne12;
4501
+ const uint i13 = im/ne12;
4502
+
4503
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
4504
+ device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
4505
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
4506
+
4507
+ float yl[32];
4508
+ float sumf[N_DST]={0.f}, all_sum;
4509
+
4510
+ const int nb32 = nb * (QK_K / 32);
4511
+
4512
+ const int ix = tiisg;
4513
+
4514
+ device const float * y4 = y + 32 * ix;
4515
+
4516
+ #if QK_K != 64
4517
+ iq1m_scale_t scale;
4518
+ #endif
4519
+
4520
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
4521
+
4522
+ float4 sumy = {0.f};
4523
+ for (int i = 0; i < 8; ++i) {
4524
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
4525
+ yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
4526
+ yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
4527
+ yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
4528
+ }
4529
+
4530
+ const int ibl = ib32 / (QK_K / 32);
4531
+ const int ib = ib32 % (QK_K / 32);
4532
+
4533
+ device const block_iq1_m * xr = x + ibl;
4534
+ device const uint8_t * qs = xr->qs + 4 * ib;
4535
+ device const uint8_t * qh = xr->qh + 2 * ib;
4536
+ device const uint16_t * sc = (device const uint16_t *)xr->scales;
4537
+
4538
+ for (int row = 0; row < N_DST; row++) {
4539
+
4540
+ #if QK_K != 64
4541
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
4542
+ #endif
4543
+
4544
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
4545
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
4546
+ constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
4547
+ constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
4548
+
4549
+ float2 sum = {0.f};
4550
+ for (int j = 0; j < 4; ++j) {
4551
+ sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
4552
+ + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
4553
+ sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
4554
+ + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
4555
+ }
4556
+ const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
4557
+ const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
4558
+ #if QK_K == 64
4559
+ const float d = (float) *((device const half *)(sc - 1));
4560
+ sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
4561
+ (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
4562
+ #else
4563
+ sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
4564
+ (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
4565
+ #endif
4566
+
4567
+ sc += nb*sizeof(block_iq1_m)/2;
4568
+ qs += nb*sizeof(block_iq1_m);
4569
+ qh += nb*sizeof(block_iq1_m);
4570
+ }
4571
+
4572
+ y4 += 32 * 32;
4573
+ }
4574
+
4575
+ for (int row = 0; row < N_DST; ++row) {
4576
+ all_sum = simd_sum(sumf[row]);
4577
+ if (tiisg == 0) {
4578
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4579
+ }
4580
+ }
4581
+ }
4226
4582
 
4227
4583
  void kernel_mul_mv_iq4_nl_f32_impl(
4228
4584
  device const void * src0,
@@ -4441,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
4441
4797
  kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4442
4798
  }
4443
4799
 
4800
+ [[host_name("kernel_mul_mv_iq1_m_f32")]]
4801
+ kernel void kernel_mul_mv_iq1_m_f32(
4802
+ device const void * src0,
4803
+ device const float * src1,
4804
+ device float * dst,
4805
+ constant int64_t & ne00,
4806
+ constant int64_t & ne01,
4807
+ constant int64_t & ne02,
4808
+ constant uint64_t & nb00,
4809
+ constant uint64_t & nb01,
4810
+ constant uint64_t & nb02,
4811
+ constant int64_t & ne10,
4812
+ constant int64_t & ne11,
4813
+ constant int64_t & ne12,
4814
+ constant uint64_t & nb10,
4815
+ constant uint64_t & nb11,
4816
+ constant uint64_t & nb12,
4817
+ constant int64_t & ne0,
4818
+ constant int64_t & ne1,
4819
+ constant uint & r2,
4820
+ constant uint & r3,
4821
+ uint3 tgpig[[threadgroup_position_in_grid]],
4822
+ uint tiisg[[thread_index_in_simdgroup]],
4823
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
4824
+
4825
+ kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
4826
+ }
4827
+
4444
4828
  [[host_name("kernel_mul_mv_iq4_nl_f32")]]
4445
4829
  kernel void kernel_mul_mv_iq4_nl_f32(
4446
4830
  device const void * src0,
@@ -4914,6 +5298,38 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
4914
5298
  }
4915
5299
  }
4916
5300
 
5301
+ template <typename type4x4>
5302
+ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
5303
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
5304
+ const int ib32 = il/2;
5305
+ il = il%2;
5306
+ device const uint16_t * sc = (device const uint16_t *)xb->scales;
5307
+ #if QK_K == 64
5308
+ const float d = xb->d;
5309
+ #else
5310
+ iq1m_scale_t scale;
5311
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
5312
+ const float d = scale.f16;
5313
+ #endif
5314
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
5315
+ device const uint8_t * qh = xb->qh + 2*ib32 + il;
5316
+ #if QK_K == 64
5317
+ const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
5318
+ #else
5319
+ const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
5320
+ #endif
5321
+ const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5322
+ const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
5323
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
5324
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
5325
+ for (int i = 0; i < 4; ++i) {
5326
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
5327
+ reg[1][i] = dl * (grid1[i] >> 4) + ml1;
5328
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
5329
+ reg[3][i] = dl * (grid2[i] >> 4) + ml2;
5330
+ }
5331
+ }
5332
+
4917
5333
  template <typename type4x4>
4918
5334
  void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
4919
5335
  device const uint16_t * q4 = (device const uint16_t *)xb->qs;
@@ -5385,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
5385
5801
 
5386
5802
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5387
5803
  kernel void kernel_mul_mm_id(
5388
- device const uchar * ids,
5804
+ device const uchar * src0s,
5389
5805
  device const uchar * src1,
5390
5806
  device float * dst,
5807
+ device const uchar * ids,
5391
5808
  constant uint64_t & nbi1,
5392
5809
  constant int64_t & ne00,
5393
5810
  constant int64_t & ne02,
@@ -5404,22 +5821,14 @@ kernel void kernel_mul_mm_id(
5404
5821
  constant uint & r2,
5405
5822
  constant uint & r3,
5406
5823
  constant int & idx,
5407
- device const uchar * src00,
5408
- device const uchar * src01,
5409
- device const uchar * src02,
5410
- device const uchar * src03,
5411
- device const uchar * src04,
5412
- device const uchar * src05,
5413
- device const uchar * src06,
5414
- device const uchar * src07,
5415
5824
  threadgroup uchar * shared_memory [[threadgroup(0)]],
5416
5825
  uint3 tgpig[[threadgroup_position_in_grid]],
5417
5826
  uint tiitg[[thread_index_in_threadgroup]],
5418
5827
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5419
- device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5420
5828
 
5421
5829
  // expert id
5422
5830
  const int32_t id = tgpig.z/(ne12*ne13);
5831
+ device const uchar * src0 = src0s + id*nb02;
5423
5832
 
5424
5833
  tgpig.z = tgpig.z%(ne12*ne13);
5425
5834
 
@@ -5434,7 +5843,7 @@ kernel void kernel_mul_mm_id(
5434
5843
  }
5435
5844
 
5436
5845
  kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5437
- src0s[id],
5846
+ src0,
5438
5847
  src1,
5439
5848
  src1ids,
5440
5849
  dst,
@@ -5498,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
5498
5907
  template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
5499
5908
  template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
5500
5909
  template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
5910
+ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
5501
5911
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
5502
5912
  #if QK_K == 64
5503
5913
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
@@ -5528,24 +5938,25 @@ typedef void (mat_mm_t)(
5528
5938
  threadgroup uchar *,
5529
5939
  uint3, uint, uint);
5530
5940
 
5531
- template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
5532
- template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
5533
- template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
5534
- template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
5535
- template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
5536
- template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
5537
- template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
5538
- template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
5539
- template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
5540
- template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
5541
- template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
5542
- template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
5941
+ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
5942
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
5943
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
5944
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
5945
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
5946
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
5947
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
5948
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
5949
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
5950
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
5951
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
5952
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
5543
5953
  template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5544
5954
  template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5545
5955
  template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5546
5956
  template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
5547
5957
  template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
5548
5958
  template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
5959
+ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
5549
5960
  template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
5550
5961
  #if QK_K == 64
5551
5962
  template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
@@ -5558,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
5558
5969
  //
5559
5970
 
5560
5971
  typedef void (mat_mm_id_t)(
5561
- device const uchar * ids,
5972
+ device const uchar * src0s,
5562
5973
  device const uchar * src1,
5563
5974
  device float * dst,
5975
+ device const uchar * ids,
5564
5976
  constant uint64_t & nbi1,
5565
5977
  constant int64_t & ne00,
5566
5978
  constant int64_t & ne02,
@@ -5577,35 +5989,28 @@ typedef void (mat_mm_id_t)(
5577
5989
  constant uint & r2,
5578
5990
  constant uint & r3,
5579
5991
  constant int & idx,
5580
- device const uchar * src00,
5581
- device const uchar * src01,
5582
- device const uchar * src02,
5583
- device const uchar * src03,
5584
- device const uchar * src04,
5585
- device const uchar * src05,
5586
- device const uchar * src06,
5587
- device const uchar * src07,
5588
5992
  threadgroup uchar *,
5589
5993
  uint3, uint, uint);
5590
5994
 
5591
- template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
5592
- template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
5593
- template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
5594
- template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
5595
- template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
5596
- template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
5597
- template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
5598
- template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
5599
- template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
5600
- template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
5601
- template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
5602
- template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
5995
+ template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
5996
+ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
5997
+ template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
5998
+ template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
5999
+ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
6000
+ template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
6001
+ template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
6002
+ template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
6003
+ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
6004
+ template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
6005
+ template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
6006
+ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
5603
6007
  template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
5604
6008
  template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
5605
6009
  template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
5606
6010
  template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
5607
6011
  template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
5608
6012
  template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
6013
+ template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
5609
6014
  template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
5610
6015
  #if QK_K == 64
5611
6016
  template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
@@ -5619,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
5619
6024
 
5620
6025
  [[host_name("kernel_mul_mv_id_f32_f32")]]
5621
6026
  kernel void kernel_mul_mv_id_f32_f32(
5622
- device const char * ids,
6027
+ device const char * src0s,
5623
6028
  device const char * src1,
5624
6029
  device float * dst,
6030
+ device const char * ids,
5625
6031
  constant uint64_t & nbi1,
5626
6032
  constant int64_t & ne00,
5627
6033
  constant int64_t & ne01,
@@ -5642,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
5642
6048
  constant uint & r2,
5643
6049
  constant uint & r3,
5644
6050
  constant int & idx,
5645
- device const char * src00,
5646
- device const char * src01,
5647
- device const char * src02,
5648
- device const char * src03,
5649
- device const char * src04,
5650
- device const char * src05,
5651
- device const char * src06,
5652
- device const char * src07,
5653
6051
  uint3 tgpig[[threadgroup_position_in_grid]],
5654
6052
  uint tiitg[[thread_index_in_threadgroup]],
5655
6053
  uint tiisg[[thread_index_in_simdgroup]],
5656
6054
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5657
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5658
-
5659
6055
  const int64_t bid = tgpig.z/(ne12*ne13);
5660
6056
 
5661
6057
  tgpig.z = tgpig.z%(ne12*ne13);
5662
6058
 
5663
6059
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6060
+ device const char * src0 = src0s + id*nb02;
5664
6061
 
5665
6062
  kernel_mul_mv_f32_f32_impl(
5666
- src0[id],
6063
+ src0,
5667
6064
  src1 + bid*nb11,
5668
6065
  dst + bid*ne0,
5669
6066
  ne00,
@@ -5688,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
5688
6085
 
5689
6086
  [[host_name("kernel_mul_mv_id_f16_f32")]]
5690
6087
  kernel void kernel_mul_mv_id_f16_f32(
5691
- device const char * ids,
6088
+ device const char * src0s,
5692
6089
  device const char * src1,
5693
6090
  device float * dst,
6091
+ device const char * ids,
5694
6092
  constant uint64_t & nbi1,
5695
6093
  constant int64_t & ne00,
5696
6094
  constant int64_t & ne01,
@@ -5711,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
5711
6109
  constant uint & r2,
5712
6110
  constant uint & r3,
5713
6111
  constant int & idx,
5714
- device const char * src00,
5715
- device const char * src01,
5716
- device const char * src02,
5717
- device const char * src03,
5718
- device const char * src04,
5719
- device const char * src05,
5720
- device const char * src06,
5721
- device const char * src07,
5722
6112
  uint3 tgpig[[threadgroup_position_in_grid]],
5723
6113
  uint tiitg[[thread_index_in_threadgroup]],
5724
6114
  uint tiisg[[thread_index_in_simdgroup]],
5725
6115
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5726
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5727
-
5728
6116
  const int64_t bid = tgpig.z/(ne12*ne13);
5729
6117
 
5730
6118
  tgpig.z = tgpig.z%(ne12*ne13);
5731
6119
 
5732
6120
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6121
+ device const char * src0 = src0s + id*nb02;
5733
6122
 
5734
6123
  kernel_mul_mv_f16_f32_impl(
5735
- src0[id],
6124
+ src0,
5736
6125
  src1 + bid*nb11,
5737
6126
  dst + bid*ne0,
5738
6127
  ne00,
@@ -5757,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
5757
6146
 
5758
6147
  [[host_name("kernel_mul_mv_id_q8_0_f32")]]
5759
6148
  kernel void kernel_mul_mv_id_q8_0_f32(
5760
- device const char * ids,
6149
+ device const char * src0s,
5761
6150
  device const char * src1,
5762
6151
  device float * dst,
6152
+ device const char * ids,
5763
6153
  constant uint64_t & nbi1,
5764
6154
  constant int64_t & ne00,
5765
6155
  constant int64_t & ne01,
@@ -5780,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
5780
6170
  constant uint & r2,
5781
6171
  constant uint & r3,
5782
6172
  constant int & idx,
5783
- device const char * src00,
5784
- device const char * src01,
5785
- device const char * src02,
5786
- device const char * src03,
5787
- device const char * src04,
5788
- device const char * src05,
5789
- device const char * src06,
5790
- device const char * src07,
5791
6173
  uint3 tgpig[[threadgroup_position_in_grid]],
5792
6174
  uint tiitg[[thread_index_in_threadgroup]],
5793
6175
  uint tiisg[[thread_index_in_simdgroup]],
5794
6176
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5795
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5796
-
5797
6177
  const int64_t bid = tgpig.z/(ne12*ne13);
5798
6178
 
5799
6179
  tgpig.z = tgpig.z%(ne12*ne13);
5800
6180
 
5801
6181
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6182
+ device const char * src0 = src0s + id*nb02;
5802
6183
 
5803
6184
  kernel_mul_mv_q8_0_f32_impl(
5804
- src0[id],
6185
+ src0,
5805
6186
  (device const float *) (src1 + bid*nb11),
5806
6187
  dst + bid*ne0,
5807
6188
  ne00,
@@ -5820,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
5820
6201
 
5821
6202
  [[host_name("kernel_mul_mv_id_q4_0_f32")]]
5822
6203
  kernel void kernel_mul_mv_id_q4_0_f32(
5823
- device const char * ids,
6204
+ device const char * src0s,
5824
6205
  device const char * src1,
5825
6206
  device float * dst,
6207
+ device const char * ids,
5826
6208
  constant uint64_t & nbi1,
5827
6209
  constant int64_t & ne00,
5828
6210
  constant int64_t & ne01,
@@ -5843,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
5843
6225
  constant uint & r2,
5844
6226
  constant uint & r3,
5845
6227
  constant int & idx,
5846
- device const char * src00,
5847
- device const char * src01,
5848
- device const char * src02,
5849
- device const char * src03,
5850
- device const char * src04,
5851
- device const char * src05,
5852
- device const char * src06,
5853
- device const char * src07,
5854
6228
  uint3 tgpig[[threadgroup_position_in_grid]],
5855
6229
  uint tiitg[[thread_index_in_threadgroup]],
5856
6230
  uint tiisg[[thread_index_in_simdgroup]],
5857
6231
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5858
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5859
-
5860
6232
  const int64_t bid = tgpig.z/(ne12*ne13);
5861
6233
 
5862
6234
  tgpig.z = tgpig.z%(ne12*ne13);
5863
6235
 
5864
6236
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6237
+ device const char * src0 = src0s + id*nb02;
5865
6238
 
5866
6239
  mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5867
- src0[id],
6240
+ src0,
5868
6241
  (device const float *) (src1 + bid*nb11),
5869
6242
  dst + bid*ne0,
5870
6243
  ne00,
@@ -5883,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
5883
6256
 
5884
6257
  [[host_name("kernel_mul_mv_id_q4_1_f32")]]
5885
6258
  kernel void kernel_mul_mv_id_q4_1_f32(
5886
- device const char * ids,
6259
+ device const char * src0s,
5887
6260
  device const char * src1,
5888
6261
  device float * dst,
6262
+ device const char * ids,
5889
6263
  constant uint64_t & nbi1,
5890
6264
  constant int64_t & ne00,
5891
6265
  constant int64_t & ne01,
@@ -5906,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
5906
6280
  constant uint & r2,
5907
6281
  constant uint & r3,
5908
6282
  constant int & idx,
5909
- device const char * src00,
5910
- device const char * src01,
5911
- device const char * src02,
5912
- device const char * src03,
5913
- device const char * src04,
5914
- device const char * src05,
5915
- device const char * src06,
5916
- device const char * src07,
5917
6283
  uint3 tgpig[[threadgroup_position_in_grid]],
5918
6284
  uint tiitg[[thread_index_in_threadgroup]],
5919
6285
  uint tiisg[[thread_index_in_simdgroup]],
5920
6286
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5921
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5922
-
5923
6287
  const int64_t bid = tgpig.z/(ne12*ne13);
5924
6288
 
5925
6289
  tgpig.z = tgpig.z%(ne12*ne13);
5926
6290
 
5927
6291
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6292
+ device const char * src0 = src0s + id*nb02;
5928
6293
 
5929
6294
  mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5930
- src0[id],
6295
+ src0,
5931
6296
  (device const float *) (src1 + bid*nb11),
5932
6297
  dst + bid*ne0,
5933
6298
  ne00,
@@ -5946,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
5946
6311
 
5947
6312
  [[host_name("kernel_mul_mv_id_q5_0_f32")]]
5948
6313
  kernel void kernel_mul_mv_id_q5_0_f32(
5949
- device const char * ids,
6314
+ device const char * src0s,
5950
6315
  device const char * src1,
5951
6316
  device float * dst,
6317
+ device const char * ids,
5952
6318
  constant uint64_t & nbi1,
5953
6319
  constant int64_t & ne00,
5954
6320
  constant int64_t & ne01,
@@ -5969,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
5969
6335
  constant uint & r2,
5970
6336
  constant uint & r3,
5971
6337
  constant int & idx,
5972
- device const char * src00,
5973
- device const char * src01,
5974
- device const char * src02,
5975
- device const char * src03,
5976
- device const char * src04,
5977
- device const char * src05,
5978
- device const char * src06,
5979
- device const char * src07,
5980
6338
  uint3 tgpig[[threadgroup_position_in_grid]],
5981
6339
  uint tiitg[[thread_index_in_threadgroup]],
5982
6340
  uint tiisg[[thread_index_in_simdgroup]],
5983
6341
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5984
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
5985
-
5986
6342
  const int64_t bid = tgpig.z/(ne12*ne13);
5987
6343
 
5988
6344
  tgpig.z = tgpig.z%(ne12*ne13);
5989
6345
 
5990
6346
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6347
+ device const char * src0 = src0s + id*nb02;
5991
6348
 
5992
6349
  mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
5993
- src0[id],
6350
+ src0,
5994
6351
  (device const float *) (src1 + bid*nb11),
5995
6352
  dst + bid*ne0,
5996
6353
  ne00,
@@ -6009,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
6009
6366
 
6010
6367
  [[host_name("kernel_mul_mv_id_q5_1_f32")]]
6011
6368
  kernel void kernel_mul_mv_id_q5_1_f32(
6012
- device const char * ids,
6369
+ device const char * src0s,
6013
6370
  device const char * src1,
6014
6371
  device float * dst,
6372
+ device const char * ids,
6015
6373
  constant uint64_t & nbi1,
6016
6374
  constant int64_t & ne00,
6017
6375
  constant int64_t & ne01,
@@ -6032,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
6032
6390
  constant uint & r2,
6033
6391
  constant uint & r3,
6034
6392
  constant int & idx,
6035
- device const char * src00,
6036
- device const char * src01,
6037
- device const char * src02,
6038
- device const char * src03,
6039
- device const char * src04,
6040
- device const char * src05,
6041
- device const char * src06,
6042
- device const char * src07,
6043
6393
  uint3 tgpig[[threadgroup_position_in_grid]],
6044
6394
  uint tiitg[[thread_index_in_threadgroup]],
6045
6395
  uint tiisg[[thread_index_in_simdgroup]],
6046
6396
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6047
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6048
-
6049
6397
  const int64_t bid = tgpig.z/(ne12*ne13);
6050
6398
 
6051
6399
  tgpig.z = tgpig.z%(ne12*ne13);
6052
6400
 
6053
6401
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6402
+ device const char * src0 = src0s + id*nb02;
6054
6403
 
6055
6404
  mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
6056
- src0[id],
6405
+ src0,
6057
6406
  (device const float *) (src1 + bid*nb11),
6058
6407
  dst + bid*ne0,
6059
6408
  ne00,
@@ -6072,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
6072
6421
 
6073
6422
  [[host_name("kernel_mul_mv_id_q2_K_f32")]]
6074
6423
  kernel void kernel_mul_mv_id_q2_K_f32(
6075
- device const char * ids,
6424
+ device const char * src0s,
6076
6425
  device const char * src1,
6077
6426
  device float * dst,
6427
+ device const char * ids,
6078
6428
  constant uint64_t & nbi1,
6079
6429
  constant int64_t & ne00,
6080
6430
  constant int64_t & ne01,
@@ -6095,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
6095
6445
  constant uint & r2,
6096
6446
  constant uint & r3,
6097
6447
  constant int & idx,
6098
- device const char * src00,
6099
- device const char * src01,
6100
- device const char * src02,
6101
- device const char * src03,
6102
- device const char * src04,
6103
- device const char * src05,
6104
- device const char * src06,
6105
- device const char * src07,
6106
6448
  uint3 tgpig[[threadgroup_position_in_grid]],
6107
6449
  uint tiitg[[thread_index_in_threadgroup]],
6108
6450
  uint tiisg[[thread_index_in_simdgroup]],
6109
6451
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6110
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6111
-
6112
6452
  const int64_t bid = tgpig.z/(ne12*ne13);
6113
6453
 
6114
6454
  tgpig.z = tgpig.z%(ne12*ne13);
6115
6455
 
6116
6456
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6457
+ device const char * src0 = src0s + id*nb02;
6117
6458
 
6118
6459
  kernel_mul_mv_q2_K_f32_impl(
6119
- src0[id],
6460
+ src0,
6120
6461
  (device const float *) (src1 + bid*nb11),
6121
6462
  dst + bid*ne0,
6122
6463
  ne00,
@@ -6135,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
6135
6476
 
6136
6477
  [[host_name("kernel_mul_mv_id_q3_K_f32")]]
6137
6478
  kernel void kernel_mul_mv_id_q3_K_f32(
6138
- device const char * ids,
6479
+ device const char * src0s,
6139
6480
  device const char * src1,
6140
6481
  device float * dst,
6482
+ device const char * ids,
6141
6483
  constant uint64_t & nbi1,
6142
6484
  constant int64_t & ne00,
6143
6485
  constant int64_t & ne01,
@@ -6158,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
6158
6500
  constant uint & r2,
6159
6501
  constant uint & r3,
6160
6502
  constant int & idx,
6161
- device const char * src00,
6162
- device const char * src01,
6163
- device const char * src02,
6164
- device const char * src03,
6165
- device const char * src04,
6166
- device const char * src05,
6167
- device const char * src06,
6168
- device const char * src07,
6169
6503
  uint3 tgpig[[threadgroup_position_in_grid]],
6170
6504
  uint tiitg[[thread_index_in_threadgroup]],
6171
6505
  uint tiisg[[thread_index_in_simdgroup]],
6172
6506
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6173
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6174
-
6175
6507
  const int64_t bid = tgpig.z/(ne12*ne13);
6176
6508
 
6177
6509
  tgpig.z = tgpig.z%(ne12*ne13);
6178
6510
 
6179
6511
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6512
+ device const char * src0 = src0s + id*nb02;
6180
6513
 
6181
6514
  kernel_mul_mv_q3_K_f32_impl(
6182
- src0[id],
6515
+ src0,
6183
6516
  (device const float *) (src1 + bid*nb11),
6184
6517
  dst + bid*ne0,
6185
6518
  ne00,
@@ -6198,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
6198
6531
 
6199
6532
  [[host_name("kernel_mul_mv_id_q4_K_f32")]]
6200
6533
  kernel void kernel_mul_mv_id_q4_K_f32(
6201
- device const char * ids,
6534
+ device const char * src0s,
6202
6535
  device const char * src1,
6203
6536
  device float * dst,
6537
+ device const char * ids,
6204
6538
  constant uint64_t & nbi1,
6205
6539
  constant int64_t & ne00,
6206
6540
  constant int64_t & ne01,
@@ -6221,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
6221
6555
  constant uint & r2,
6222
6556
  constant uint & r3,
6223
6557
  constant int & idx,
6224
- device const char * src00,
6225
- device const char * src01,
6226
- device const char * src02,
6227
- device const char * src03,
6228
- device const char * src04,
6229
- device const char * src05,
6230
- device const char * src06,
6231
- device const char * src07,
6232
6558
  uint3 tgpig[[threadgroup_position_in_grid]],
6233
6559
  uint tiitg[[thread_index_in_threadgroup]],
6234
6560
  uint tiisg[[thread_index_in_simdgroup]],
6235
6561
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6236
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6237
-
6238
6562
  const int64_t bid = tgpig.z/(ne12*ne13);
6239
6563
 
6240
6564
  tgpig.z = tgpig.z%(ne12*ne13);
6241
6565
 
6242
6566
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6567
+ device const char * src0 = src0s + id*nb02;
6243
6568
 
6244
6569
  kernel_mul_mv_q4_K_f32_impl(
6245
- src0[id],
6570
+ src0,
6246
6571
  (device const float *) (src1 + bid*nb11),
6247
6572
  dst + bid*ne0,
6248
6573
  ne00,
@@ -6261,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
6261
6586
 
6262
6587
  [[host_name("kernel_mul_mv_id_q5_K_f32")]]
6263
6588
  kernel void kernel_mul_mv_id_q5_K_f32(
6264
- device const char * ids,
6589
+ device const char * src0s,
6265
6590
  device const char * src1,
6266
6591
  device float * dst,
6592
+ device const char * ids,
6267
6593
  constant uint64_t & nbi1,
6268
6594
  constant int64_t & ne00,
6269
6595
  constant int64_t & ne01,
@@ -6284,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
6284
6610
  constant uint & r2,
6285
6611
  constant uint & r3,
6286
6612
  constant int & idx,
6287
- device const char * src00,
6288
- device const char * src01,
6289
- device const char * src02,
6290
- device const char * src03,
6291
- device const char * src04,
6292
- device const char * src05,
6293
- device const char * src06,
6294
- device const char * src07,
6295
6613
  uint3 tgpig[[threadgroup_position_in_grid]],
6296
6614
  uint tiitg[[thread_index_in_threadgroup]],
6297
6615
  uint tiisg[[thread_index_in_simdgroup]],
6298
6616
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6299
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6300
-
6301
6617
  const int64_t bid = tgpig.z/(ne12*ne13);
6302
6618
 
6303
6619
  tgpig.z = tgpig.z%(ne12*ne13);
6304
6620
 
6305
6621
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6622
+ device const char * src0 = src0s + id*nb02;
6306
6623
 
6307
6624
  kernel_mul_mv_q5_K_f32_impl(
6308
- src0[id],
6625
+ src0,
6309
6626
  (device const float *) (src1 + bid*nb11),
6310
6627
  dst + bid*ne0,
6311
6628
  ne00,
@@ -6324,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
6324
6641
 
6325
6642
  [[host_name("kernel_mul_mv_id_q6_K_f32")]]
6326
6643
  kernel void kernel_mul_mv_id_q6_K_f32(
6327
- device const char * ids,
6644
+ device const char * src0s,
6328
6645
  device const char * src1,
6329
6646
  device float * dst,
6647
+ device const char * ids,
6330
6648
  constant uint64_t & nbi1,
6331
6649
  constant int64_t & ne00,
6332
6650
  constant int64_t & ne01,
@@ -6347,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6347
6665
  constant uint & r2,
6348
6666
  constant uint & r3,
6349
6667
  constant int & idx,
6350
- device const char * src00,
6351
- device const char * src01,
6352
- device const char * src02,
6353
- device const char * src03,
6354
- device const char * src04,
6355
- device const char * src05,
6356
- device const char * src06,
6357
- device const char * src07,
6358
6668
  uint3 tgpig[[threadgroup_position_in_grid]],
6359
6669
  uint tiitg[[thread_index_in_threadgroup]],
6360
6670
  uint tiisg[[thread_index_in_simdgroup]],
6361
6671
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6362
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6363
-
6364
6672
  const int64_t bid = tgpig.z/(ne12*ne13);
6365
6673
 
6366
6674
  tgpig.z = tgpig.z%(ne12*ne13);
6367
6675
 
6368
6676
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6677
+ device const char * src0 = src0s + id*nb02;
6369
6678
 
6370
6679
  kernel_mul_mv_q6_K_f32_impl(
6371
- src0[id],
6680
+ src0,
6372
6681
  (device const float *) (src1 + bid*nb11),
6373
6682
  dst + bid*ne0,
6374
6683
  ne00,
@@ -6387,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
6387
6696
 
6388
6697
  [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
6389
6698
  kernel void kernel_mul_mv_id_iq2_xxs_f32(
6390
- device const char * ids,
6699
+ device const char * src0s,
6391
6700
  device const char * src1,
6392
6701
  device float * dst,
6702
+ device const char * ids,
6393
6703
  constant uint64_t & nbi1,
6394
6704
  constant int64_t & ne00,
6395
6705
  constant int64_t & ne01,
@@ -6410,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6410
6720
  constant uint & r2,
6411
6721
  constant uint & r3,
6412
6722
  constant int & idx,
6413
- device const char * src00,
6414
- device const char * src01,
6415
- device const char * src02,
6416
- device const char * src03,
6417
- device const char * src04,
6418
- device const char * src05,
6419
- device const char * src06,
6420
- device const char * src07,
6421
6723
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6422
6724
  uint3 tgpig[[threadgroup_position_in_grid]],
6423
6725
  uint tiitg[[thread_index_in_threadgroup]],
6424
6726
  uint tiisg[[thread_index_in_simdgroup]],
6425
6727
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6426
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6427
-
6428
6728
  const int64_t bid = tgpig.z/(ne12*ne13);
6429
6729
 
6430
6730
  tgpig.z = tgpig.z%(ne12*ne13);
6431
6731
 
6432
6732
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6733
+ device const char * src0 = src0s + id*nb02;
6433
6734
 
6434
6735
  kernel_mul_mv_iq2_xxs_f32_impl(
6435
- src0[id],
6736
+ src0,
6436
6737
  (device const float *) (src1 + bid*nb11),
6437
6738
  dst + bid*ne0,
6438
6739
  ne00,
@@ -6452,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
6452
6753
 
6453
6754
  [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
6454
6755
  kernel void kernel_mul_mv_id_iq2_xs_f32(
6455
- device const char * ids,
6756
+ device const char * src0s,
6456
6757
  device const char * src1,
6457
6758
  device float * dst,
6759
+ device const char * ids,
6458
6760
  constant uint64_t & nbi1,
6459
6761
  constant int64_t & ne00,
6460
6762
  constant int64_t & ne01,
@@ -6475,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6475
6777
  constant uint & r2,
6476
6778
  constant uint & r3,
6477
6779
  constant int & idx,
6478
- device const char * src00,
6479
- device const char * src01,
6480
- device const char * src02,
6481
- device const char * src03,
6482
- device const char * src04,
6483
- device const char * src05,
6484
- device const char * src06,
6485
- device const char * src07,
6486
6780
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6487
6781
  uint3 tgpig[[threadgroup_position_in_grid]],
6488
6782
  uint tiitg[[thread_index_in_threadgroup]],
6489
6783
  uint tiisg[[thread_index_in_simdgroup]],
6490
6784
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6491
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6492
-
6493
6785
  const int64_t bid = tgpig.z/(ne12*ne13);
6494
6786
 
6495
6787
  tgpig.z = tgpig.z%(ne12*ne13);
6496
6788
 
6497
6789
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6790
+ device const char * src0 = src0s + id*nb02;
6498
6791
 
6499
6792
  kernel_mul_mv_iq2_xs_f32_impl(
6500
- src0[id],
6793
+ src0,
6501
6794
  (device const float *) (src1 + bid*nb11),
6502
6795
  dst + bid*ne0,
6503
6796
  ne00,
@@ -6517,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
6517
6810
 
6518
6811
  [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
6519
6812
  kernel void kernel_mul_mv_id_iq3_xxs_f32(
6520
- device const char * ids,
6813
+ device const char * src0s,
6521
6814
  device const char * src1,
6522
6815
  device float * dst,
6816
+ device const char * ids,
6523
6817
  constant uint64_t & nbi1,
6524
6818
  constant int64_t & ne00,
6525
6819
  constant int64_t & ne01,
@@ -6540,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6540
6834
  constant uint & r2,
6541
6835
  constant uint & r3,
6542
6836
  constant int & idx,
6543
- device const char * src00,
6544
- device const char * src01,
6545
- device const char * src02,
6546
- device const char * src03,
6547
- device const char * src04,
6548
- device const char * src05,
6549
- device const char * src06,
6550
- device const char * src07,
6551
6837
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6552
6838
  uint3 tgpig[[threadgroup_position_in_grid]],
6553
6839
  uint tiitg[[thread_index_in_threadgroup]],
6554
6840
  uint tiisg[[thread_index_in_simdgroup]],
6555
6841
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6556
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6557
-
6558
6842
  const int64_t bid = tgpig.z/(ne12*ne13);
6559
6843
 
6560
6844
  tgpig.z = tgpig.z%(ne12*ne13);
6561
6845
 
6562
6846
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6847
+ device const char * src0 = src0s + id*nb02;
6563
6848
 
6564
6849
  kernel_mul_mv_iq3_xxs_f32_impl(
6565
- src0[id],
6850
+ src0,
6566
6851
  (device const float *) (src1 + bid*nb11),
6567
6852
  dst + bid*ne0,
6568
6853
  ne00,
@@ -6582,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
6582
6867
 
6583
6868
  [[host_name("kernel_mul_mv_id_iq3_s_f32")]]
6584
6869
  kernel void kernel_mul_mv_id_iq3_s_f32(
6585
- device const char * ids,
6870
+ device const char * src0s,
6586
6871
  device const char * src1,
6587
6872
  device float * dst,
6873
+ device const char * ids,
6588
6874
  constant uint64_t & nbi1,
6589
6875
  constant int64_t & ne00,
6590
6876
  constant int64_t & ne01,
@@ -6605,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
6605
6891
  constant uint & r2,
6606
6892
  constant uint & r3,
6607
6893
  constant int & idx,
6608
- device const char * src00,
6609
- device const char * src01,
6610
- device const char * src02,
6611
- device const char * src03,
6612
- device const char * src04,
6613
- device const char * src05,
6614
- device const char * src06,
6615
- device const char * src07,
6616
6894
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6617
6895
  uint3 tgpig[[threadgroup_position_in_grid]],
6618
6896
  uint tiitg[[thread_index_in_threadgroup]],
6619
6897
  uint tiisg[[thread_index_in_simdgroup]],
6620
6898
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6621
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6622
-
6623
6899
  const int64_t bid = tgpig.z/(ne12*ne13);
6624
6900
 
6625
6901
  tgpig.z = tgpig.z%(ne12*ne13);
6626
6902
 
6627
6903
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6904
+ device const char * src0 = src0s + id*nb02;
6628
6905
 
6629
6906
  kernel_mul_mv_iq3_s_f32_impl(
6630
- src0[id],
6907
+ src0,
6631
6908
  (device const float *) (src1 + bid*nb11),
6632
6909
  dst + bid*ne0,
6633
6910
  ne00,
@@ -6647,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
6647
6924
 
6648
6925
  [[host_name("kernel_mul_mv_id_iq2_s_f32")]]
6649
6926
  kernel void kernel_mul_mv_id_iq2_s_f32(
6650
- device const char * ids,
6927
+ device const char * src0s,
6651
6928
  device const char * src1,
6652
6929
  device float * dst,
6930
+ device const char * ids,
6653
6931
  constant uint64_t & nbi1,
6654
6932
  constant int64_t & ne00,
6655
6933
  constant int64_t & ne01,
@@ -6670,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
6670
6948
  constant uint & r2,
6671
6949
  constant uint & r3,
6672
6950
  constant int & idx,
6673
- device const char * src00,
6674
- device const char * src01,
6675
- device const char * src02,
6676
- device const char * src03,
6677
- device const char * src04,
6678
- device const char * src05,
6679
- device const char * src06,
6680
- device const char * src07,
6681
6951
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6682
6952
  uint3 tgpig[[threadgroup_position_in_grid]],
6683
6953
  uint tiitg[[thread_index_in_threadgroup]],
6684
6954
  uint tiisg[[thread_index_in_simdgroup]],
6685
6955
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6686
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6687
-
6688
6956
  const int64_t bid = tgpig.z/(ne12*ne13);
6689
6957
 
6690
6958
  tgpig.z = tgpig.z%(ne12*ne13);
6691
6959
 
6692
6960
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6961
+ device const char * src0 = src0s + id*nb02;
6693
6962
 
6694
6963
  kernel_mul_mv_iq2_s_f32_impl(
6695
- src0[id],
6964
+ src0,
6696
6965
  (device const float *) (src1 + bid*nb11),
6697
6966
  dst + bid*ne0,
6698
6967
  ne00,
@@ -6712,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
6712
6981
 
6713
6982
  [[host_name("kernel_mul_mv_id_iq1_s_f32")]]
6714
6983
  kernel void kernel_mul_mv_id_iq1_s_f32(
6715
- device const char * ids,
6984
+ device const char * src0s,
6716
6985
  device const char * src1,
6717
6986
  device float * dst,
6987
+ device const char * ids,
6718
6988
  constant uint64_t & nbi1,
6719
6989
  constant int64_t & ne00,
6720
6990
  constant int64_t & ne01,
@@ -6735,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
6735
7005
  constant uint & r2,
6736
7006
  constant uint & r3,
6737
7007
  constant int & idx,
6738
- device const char * src00,
6739
- device const char * src01,
6740
- device const char * src02,
6741
- device const char * src03,
6742
- device const char * src04,
6743
- device const char * src05,
6744
- device const char * src06,
6745
- device const char * src07,
6746
7008
  uint3 tgpig[[threadgroup_position_in_grid]],
6747
7009
  uint tiitg[[thread_index_in_threadgroup]],
6748
7010
  uint tiisg[[thread_index_in_simdgroup]],
6749
7011
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6750
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6751
-
6752
7012
  const int64_t bid = tgpig.z/(ne12*ne13);
6753
7013
 
6754
7014
  tgpig.z = tgpig.z%(ne12*ne13);
6755
7015
 
6756
7016
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7017
+ device const char * src0 = src0s + id*nb02;
6757
7018
 
6758
7019
  kernel_mul_mv_iq1_s_f32_impl(
6759
- src0[id],
7020
+ src0,
7021
+ (device const float *) (src1 + bid*nb11),
7022
+ dst + bid*ne0,
7023
+ ne00,
7024
+ ne01,
7025
+ ne02,
7026
+ ne10,
7027
+ ne12,
7028
+ ne0,
7029
+ ne1,
7030
+ r2,
7031
+ r3,
7032
+ tgpig,
7033
+ tiisg,
7034
+ sgitg);
7035
+ }
7036
+
7037
+ [[host_name("kernel_mul_mv_id_iq1_m_f32")]]
7038
+ kernel void kernel_mul_mv_id_iq1_m_f32(
7039
+ device const char * src0s,
7040
+ device const char * src1,
7041
+ device float * dst,
7042
+ device const char * ids,
7043
+ constant uint64_t & nbi1,
7044
+ constant int64_t & ne00,
7045
+ constant int64_t & ne01,
7046
+ constant int64_t & ne02,
7047
+ constant uint64_t & nb00,
7048
+ constant uint64_t & nb01,
7049
+ constant uint64_t & nb02,
7050
+ constant int64_t & ne10,
7051
+ constant int64_t & ne11,
7052
+ constant int64_t & ne12,
7053
+ constant int64_t & ne13,
7054
+ constant uint64_t & nb10,
7055
+ constant uint64_t & nb11,
7056
+ constant uint64_t & nb12,
7057
+ constant int64_t & ne0,
7058
+ constant int64_t & ne1,
7059
+ constant uint64_t & nb1,
7060
+ constant uint & r2,
7061
+ constant uint & r3,
7062
+ constant int & idx,
7063
+ uint3 tgpig[[threadgroup_position_in_grid]],
7064
+ uint tiitg[[thread_index_in_threadgroup]],
7065
+ uint tiisg[[thread_index_in_simdgroup]],
7066
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
7067
+ const int64_t bid = tgpig.z/(ne12*ne13);
7068
+
7069
+ tgpig.z = tgpig.z%(ne12*ne13);
7070
+
7071
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7072
+ device const char * src0 = src0s + id*nb02;
7073
+
7074
+ kernel_mul_mv_iq1_m_f32_impl(
7075
+ src0,
6760
7076
  (device const float *) (src1 + bid*nb11),
6761
7077
  dst + bid*ne0,
6762
7078
  ne00,
@@ -6775,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
6775
7091
 
6776
7092
  [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
6777
7093
  kernel void kernel_mul_mv_id_iq4_nl_f32(
6778
- device const char * ids,
7094
+ device const char * src0s,
6779
7095
  device const char * src1,
6780
7096
  device float * dst,
7097
+ device const char * ids,
6781
7098
  constant uint64_t & nbi1,
6782
7099
  constant int64_t & ne00,
6783
7100
  constant int64_t & ne01,
@@ -6798,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
6798
7115
  constant uint & r2,
6799
7116
  constant uint & r3,
6800
7117
  constant int & idx,
6801
- device const char * src00,
6802
- device const char * src01,
6803
- device const char * src02,
6804
- device const char * src03,
6805
- device const char * src04,
6806
- device const char * src05,
6807
- device const char * src06,
6808
- device const char * src07,
6809
7118
  threadgroup float * shared_values [[threadgroup(0)]],
6810
7119
  uint3 tgpig[[threadgroup_position_in_grid]],
6811
7120
  uint tiitg[[thread_index_in_threadgroup]],
6812
7121
  uint tiisg[[thread_index_in_simdgroup]],
6813
7122
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6814
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6815
-
6816
7123
  const int64_t bid = tgpig.z/(ne12*ne13);
6817
7124
 
6818
7125
  tgpig.z = tgpig.z%(ne12*ne13);
6819
7126
 
6820
7127
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7128
+ device const char * src0 = src0s + id*nb02;
6821
7129
 
6822
7130
  kernel_mul_mv_iq4_nl_f32_impl(
6823
- src0[id],
7131
+ src0,
6824
7132
  (device const float *) (src1 + bid*nb11),
6825
7133
  dst + bid*ne0,
6826
7134
  ne00,
@@ -6840,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
6840
7148
 
6841
7149
  [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
6842
7150
  kernel void kernel_mul_mv_id_iq4_xs_f32(
6843
- device const char * ids,
7151
+ device const char * src0s,
6844
7152
  device const char * src1,
6845
7153
  device float * dst,
7154
+ device const char * ids,
6846
7155
  constant uint64_t & nbi1,
6847
7156
  constant int64_t & ne00,
6848
7157
  constant int64_t & ne01,
@@ -6863,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
6863
7172
  constant uint & r2,
6864
7173
  constant uint & r3,
6865
7174
  constant int & idx,
6866
- device const char * src00,
6867
- device const char * src01,
6868
- device const char * src02,
6869
- device const char * src03,
6870
- device const char * src04,
6871
- device const char * src05,
6872
- device const char * src06,
6873
- device const char * src07,
6874
7175
  threadgroup float * shared_values [[threadgroup(0)]],
6875
7176
  uint3 tgpig[[threadgroup_position_in_grid]],
6876
7177
  uint tiitg[[thread_index_in_threadgroup]],
6877
7178
  uint tiisg[[thread_index_in_simdgroup]],
6878
7179
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6879
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
6880
-
6881
7180
  const int64_t bid = tgpig.z/(ne12*ne13);
6882
7181
 
6883
7182
  tgpig.z = tgpig.z%(ne12*ne13);
6884
7183
 
6885
7184
  const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
7185
+ device const char * src0 = src0s + id*nb02;
6886
7186
 
6887
7187
  #if QK_K == 64
6888
7188
  kernel_mul_mv_iq4_nl_f32_impl(
6889
7189
  #else
6890
7190
  kernel_mul_mv_iq4_xs_f32_impl(
6891
7191
  #endif
6892
- src0[id],
7192
+ src0,
6893
7193
  (device const float *) (src1 + bid*nb11),
6894
7194
  dst + bid*ne0,
6895
7195
  ne00,