llama_cpp 0.14.2 → 0.14.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +64 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +91 -21
- data/vendor/tmp/llama.cpp/ggml-alloc.c +14 -5
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +5 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +155 -125
- data/vendor/tmp/llama.cpp/ggml-backend.h +4 -4
- data/vendor/tmp/llama.cpp/ggml-common.h +25 -2
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +1779 -10762
- data/vendor/tmp/llama.cpp/ggml-cuda.h +6 -15
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +167 -124
- data/vendor/tmp/llama.cpp/ggml-metal.metal +603 -303
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +663 -56
- data/vendor/tmp/llama.cpp/ggml-quants.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +341 -469
- data/vendor/tmp/llama.cpp/ggml-sycl.h +19 -4
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +37199 -14939
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +335 -307
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -11
- data/vendor/tmp/llama.cpp/ggml.c +229 -107
- data/vendor/tmp/llama.cpp/ggml.h +11 -5
- data/vendor/tmp/llama.cpp/llama.cpp +2136 -464
- data/vendor/tmp/llama.cpp/llama.h +86 -23
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1651 -0
- data/vendor/tmp/llama.cpp/unicode-data.h +16 -0
- data/vendor/tmp/llama.cpp/unicode.cpp +8 -1403
- data/vendor/tmp/llama.cpp/unicode.h +2 -0
- metadata +5 -3
@@ -13,8 +13,8 @@ using namespace metal;
|
|
13
13
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
14
14
|
|
15
15
|
// Sort direction for the argsort kernel; used as the template argument
// of kernel_argsort_f32_i32 to pick ascending or descending order.
enum ggml_sort_order {
    GGML_SORT_ORDER_ASC,  // smallest value first
    GGML_SORT_ORDER_DESC, // largest value first
};
|
19
19
|
|
20
20
|
// general-purpose kernel for addition, multiplication and division of two tensors
|
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
|
|
1973
1973
|
|
1974
1974
|
// bitonic sort implementation following the CUDA kernels as reference
|
1975
1975
|
// Function type shared by the argsort kernel instantiations.
//   x             - input values, one row of ncols floats per threadgroup row
//   dst           - output indices, ncols entries per row
//   ncols         - number of real columns in a row
//   ncols_pad     - padded column count the bitonic sort runs over
//   shared_values - threadgroup scratch holding the indices being sorted
typedef void (argsort_t)(
        device const float   * x,
        device       int32_t * dst,
        constant     int64_t & ncols,
        constant     int64_t & ncols_pad,
        threadgroup  int32_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]);
|
1981
1983
|
|
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
|
|
1984
1986
|
device const float * x,
|
1985
1987
|
device int32_t * dst,
|
1986
1988
|
constant int64_t & ncols,
|
1989
|
+
constant int64_t & ncols_pad,
|
1990
|
+
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
1987
1991
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1988
1992
|
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
1989
1993
|
// bitonic sort
|
1990
1994
|
int col = tpitg[0];
|
1991
1995
|
int row = tgpig[1];
|
1992
1996
|
|
1993
|
-
if (col >=
|
1997
|
+
if (col >= ncols_pad) return;
|
1994
1998
|
|
1995
|
-
device const float * x_row = x
|
1996
|
-
|
1999
|
+
device const float * x_row = x + row * ncols;
|
2000
|
+
threadgroup int32_t * dst_row = shared_values;
|
1997
2001
|
|
1998
2002
|
// initialize indices
|
1999
|
-
|
2000
|
-
|
2001
|
-
}
|
2003
|
+
dst_row[col] = col;
|
2004
|
+
|
2002
2005
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2003
2006
|
|
2004
|
-
for (int k = 2; k <=
|
2007
|
+
for (int k = 2; k <= ncols_pad; k *= 2) {
|
2005
2008
|
for (int j = k / 2; j > 0; j /= 2) {
|
2006
2009
|
int ixj = col ^ j;
|
2007
2010
|
if (ixj > col) {
|
2008
2011
|
if ((col & k) == 0) {
|
2009
|
-
if (
|
2012
|
+
if (dst_row[col] >= ncols ||
|
2013
|
+
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
2014
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
2015
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
2016
|
+
) {
|
2010
2017
|
SWAP(dst_row[col], dst_row[ixj]);
|
2011
2018
|
}
|
2012
2019
|
} else {
|
2013
|
-
if (
|
2020
|
+
if (dst_row[ixj] >= ncols ||
|
2021
|
+
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
2022
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
2023
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
2024
|
+
) {
|
2014
2025
|
SWAP(dst_row[col], dst_row[ixj]);
|
2015
2026
|
}
|
2016
2027
|
}
|
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
|
|
2018
2029
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2019
2030
|
}
|
2020
2031
|
}
|
2032
|
+
|
2033
|
+
// copy the result to dst without the padding
|
2034
|
+
if (col < ncols) {
|
2035
|
+
dst[row * ncols + col] = dst_row[col];
|
2036
|
+
}
|
2021
2037
|
}
|
2022
2038
|
|
2023
|
-
// Explicit instantiations of the argsort kernel, one per sort direction.
template [[host_name("kernel_argsort_f32_i32_asc")]]
kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;

template [[host_name("kernel_argsort_f32_i32_desc")]]
kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
2025
2041
|
|
2026
2042
|
kernel void kernel_leaky_relu_f32(
|
2027
2043
|
device const float * src0,
|
@@ -2388,6 +2404,242 @@ kernel void kernel_cpy_f32_q4_1(
|
|
2388
2404
|
}
|
2389
2405
|
}
|
2390
2406
|
|
2407
|
+
// Copy f32 data into a q5_0-quantized destination.
//
// Each threadgroup handles the source row selected by
// (tgpig[0], tgpig[1], tgpig[2]); the threads of the group walk that row in
// chunks of QK5_0 floats, stepping by ntg.x chunks. For every chunk:
//   - the value with the largest magnitude (sign kept) defines the scale
//     d = value / -16,
//   - each float is scaled by 1/d, offset by 16.5, truncated and clamped to 31,
//   - the low 4 bits go to qs (two values per byte), the 5th bit goes to qh.
//
// ne0x/nb0x describe the source shape/strides, nex/nbx the destination.
kernel void kernel_cpy_f32_q5_0(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates come straight from the threadgroup position
    const int64_t i01 = tgpig[0];
    const int64_t i02 = tgpig[1];
    const int64_t i03 = tgpig[2];

    // flat element index of the first element of this source row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unravel the flat index into destination coordinates
    const int64_t i3   = n / (ne2*ne1*ne0);
    const int64_t rem3 = n - i3*ne2*ne1*ne0;
    const int64_t i2   = rem3 / (ne1*ne0);
    const int64_t rem2 = rem3 - i2*ne1*ne0;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = (rem2 - i1*ne0)/QK5_0;

    device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        device block_q5_0 & blk = dst_data[i00/QK5_0];

        // find the element with the largest magnitude, keeping its sign
        float amax = 0.0f;
        float peak = 0.0f;
        for (int j = 0; j < QK5_0; j++) {
            const float v  = src[j];
            const float av = fabs(v);
            if (amax < av) {
                amax = av;
                peak = v;
            }
        }

        const float d  = peak / -16;
        const float id = d ? 1.0f/d : 0.0f;

        blk.d = d;

        // quantize pairs (j, j + QK5_0/2); the 5th bits are gathered in qh
        uint32_t qh = 0;
        for (int j = 0; j < QK5_0/2; ++j) {
            const float x0 = src[0       + j]*id;
            const float x1 = src[QK5_0/2 + j]*id;

            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));

            blk.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);

            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
        }

        // store the packed high bits byte by byte
        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
        for (int b = 0; b < 4; ++b) {
            blk.qh[b] = qh8[b];
        }
    }
}
|
2479
|
+
|
2480
|
+
// Copy f32 data into a q5_1-quantized destination.
//
// Each threadgroup handles the source row selected by
// (tgpig[0], tgpig[1], tgpig[2]); the threads of the group walk that row in
// chunks of QK5_1 floats, stepping by ntg.x chunks. For every chunk:
//   - the minimum is stored in m and the scale is d = (max - min) / 31,
//   - each float is mapped to (v - min)/d, rounded by adding 0.5 and truncating,
//   - the low 4 bits go to qs (two values per byte), the 5th bit goes to qh.
//
// ne0x/nb0x describe the source shape/strides, nex/nbx the destination.
kernel void kernel_cpy_f32_q5_1(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row coordinates come straight from the threadgroup position
    const int64_t i01 = tgpig[0];
    const int64_t i02 = tgpig[1];
    const int64_t i03 = tgpig[2];

    // flat element index of the first element of this source row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unravel the flat index into destination coordinates
    const int64_t i3   = n / (ne2*ne1*ne0);
    const int64_t rem3 = n - i3*ne2*ne1*ne0;
    const int64_t i2   = rem3 / (ne1*ne0);
    const int64_t rem2 = rem3 - i2*ne1*ne0;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = (rem2 - i1*ne0)/QK5_1;

    device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        device block_q5_1 & blk = dst_data[i00/QK5_1];

        // value range of the chunk
        float hi = src[0];
        float lo = src[0];
        for (int j = 1; j < QK5_1; j++) {
            const float v = src[j];
            if (v < lo) {
                lo = v;
            }
            if (v > hi) {
                hi = v;
            }
        }

        const float d  = (hi - lo) / 31;
        const float id = d ? 1.0f/d : 0.0f;

        blk.d = d;
        blk.m = lo;

        // quantize pairs (j, j + QK5_1/2); the 5th bits are gathered in qh
        uint32_t qh = 0;
        for (int j = 0; j < QK5_1/2; ++j) {
            const float x0 = (src[0       + j] - lo)*id;
            const float x1 = (src[QK5_1/2 + j] - lo)*id;

            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);

            blk.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);

            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
        }

        // store the packed high bits byte by byte
        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
        for (int b = 0; b < 4; ++b) {
            blk.qh[b] = qh8[b];
        }
    }
}
|
2551
|
+
|
2552
|
+
// Return the index of the entry of val[0..n-1] that is closest to x.
// The bisection below only works if val is sorted in ascending order.
// Values at or below val[0] map to 0, values at or above val[n-1] map to n-1.
static inline int best_index_int8(int n, constant float * val, float x) {
    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[n-1]) {
        return n-1;
    }

    // bisect until lo and hi are adjacent entries bracketing x
    int lo = 0;
    int hi = n-1;
    while (hi-lo > 1) {
        const int mid = (lo+hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // pick whichever neighbour is nearer; ties go to the upper one
    if (x - val[hi-1] < val[hi] - x) {
        return hi-1;
    }
    return hi;
}
|
2562
|
+
|
2563
|
+
// Codebook of the 16 levels used by the iq4_nl copy kernel (float copy,
// searched with best_index_int8).
constexpr constant static float kvalues_iq4nl_f[16] = {
    -127.f, -104.f, -83.f, -65.f,
     -49.f,  -35.f, -22.f, -10.f,
       1.f,   13.f,  25.f,  38.f,
      53.f,   69.f,  89.f, 113.f
};
|
2566
|
+
|
2567
|
+
// Copy f32 data into an iq4_nl-quantized destination.
//
// Each threadgroup handles the source row selected by
// (tgpig[0], tgpig[1], tgpig[2]); the threads of the group walk that row in
// chunks of QK4_NL floats, stepping by ntg.x chunks. For every chunk:
//   - the value with the largest magnitude (sign kept) gives a first scale
//     d = value / kvalues_iq4nl_f[0],
//   - each float is scaled by 1/d and mapped to the nearest codebook entry,
//     two 4-bit indices per byte of qs,
//   - the stored scale is the weighted least-squares fit sumqx/sumq2
//     (weights are the squared inputs), falling back to d when sumq2 is 0.
//
// ne0x/nb0x describe the source shape/strides, nex/nbx the destination.
kernel void kernel_cpy_f32_iq4_nl(
        device const float * src0,
        device        void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // flat element index of the first element of this source row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unravel the flat index into destination coordinates
    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;

    device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        float amax = 0.0f; // absolute max
        float max  = 0.0f; // signed value that attains amax

        // scan the whole iq4_nl block; the bound is QK4_NL (not QK4_0) so it
        // stays in sync with the block size used everywhere else in this kernel
        for (int j = 0; j < QK4_NL; j++) {
            const float v = src[j];
            if (amax < fabs(v)) {
                amax = fabs(v);
                max  = v;
            }
        }

        const float d  = max / kvalues_iq4nl_f[0];
        const float id = d ? 1.0f/d : 0.0f;

        float sumqx = 0, sumq2 = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
            const float x0 = src[0        + j]*id;
            const float x1 = src[QK4_NL/2 + j]*id;

            // nearest codebook entries for the two values of this pair
            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);

            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);

            // accumulate the weighted least-squares terms for the final scale
            const float v0 = kvalues_iq4nl_f[xi0];
            const float v1 = kvalues_iq4nl_f[xi1];
            const float w0 = src[0        + j]*src[0        + j];
            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
            sumq2 += w0*v0*v0 + w1*v1*v1;
        }

        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
    }
}
|
2642
|
+
|
2391
2643
|
kernel void kernel_concat(
|
2392
2644
|
device const char * src0,
|
2393
2645
|
device const char * src1,
|
@@ -4220,9 +4472,113 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4220
4472
|
}
|
4221
4473
|
}
|
4222
4474
|
|
4223
|
-
|
4224
|
-
|
4225
|
-
|
4475
|
+
// Matrix-vector product with an iq1_m-quantized src0 and f32 src1.
//
// Each simdgroup produces N_DST consecutive output rows starting at
// first_row. Every lane works on 32-value sub-blocks (lane index, then
// stepping by 32 sub-blocks), accumulates per-row partial sums in sumf, and
// the partials are combined with simd_sum; lane 0 writes the result.
void kernel_mul_mv_iq1_m_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K; // super-blocks per row
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row    = first_row * nb;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[32];
    float sumf[N_DST] = {0.f};

    const int nb32 = nb * (QK_K / 32); // 32-value sub-blocks per row

    const int lane = tiisg;

    device const float * y4 = y + 32 * lane;

#if QK_K != 64
    iq1m_scale_t scale;
#endif

    for (int ib32 = lane; ib32 < nb32; ib32 += 32) {

        // cache 32 activations and their sums over each group of 8
        float4 sumy = {0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
        }

        const int ibl = ib32 / (QK_K / 32); // super-block index
        const int ib  = ib32 % (QK_K / 32); // sub-block inside the super-block

        device const block_iq1_m * xr = x + ibl;

        device const uint8_t  * qs = xr->qs + 4 * ib;
        device const uint8_t  * qh = xr->qh + 2 * ib;
        device const uint16_t * sc = (device const uint16_t *)xr->scales;

        for (int row = 0; row < N_DST; row++) {

#if QK_K != 64
            // the f16 super-block scale is spread over the top nibbles of the four scale words
            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
#endif

            // grid index = 8 low bits from qs plus 3 high bits from a qh nibble
            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));

            float2 sum = {0.f};
            for (int j = 0; j < 4; ++j) {
                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
            }

            // bit 3 of each qh nibble selects the sign of the IQ1M_DELTA shift
            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);

#if QK_K == 64
            const float d = (float) *((device const half *)(sc - 1));
            sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
                              (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
#else
            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
#endif

            // advance all three cursors by one row of blocks (sc counts 16-bit words)
            sc += nb*sizeof(block_iq1_m)/2;
            qs += nb*sizeof(block_iq1_m);
            qh += nb*sizeof(block_iq1_m);
        }

        y4 += 32 * 32;
    }

    // reduce the per-lane partials; lane 0 stores each row's result
    for (int row = 0; row < N_DST; ++row) {
        const float all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}
|
4226
4582
|
|
4227
4583
|
void kernel_mul_mv_iq4_nl_f32_impl(
|
4228
4584
|
device const void * src0,
|
@@ -4441,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
4441
4797
|
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4442
4798
|
}
|
4443
4799
|
|
4800
|
+
// Kernel entry point for the iq1_m x f32 matrix-vector product.
// Forwards to kernel_mul_mv_iq1_m_f32_impl; the nb0x/nb1x strides and ne11
// are part of the signature but are not passed on.
[[host_name("kernel_mul_mv_iq1_m_f32")]]
kernel void kernel_mul_mv_iq1_m_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq1_m_f32_impl(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        tgpig, tiisg, sgitg);
}
|
4827
|
+
|
4444
4828
|
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
4445
4829
|
kernel void kernel_mul_mv_iq4_nl_f32(
|
4446
4830
|
device const void * src0,
|
@@ -4914,6 +5298,38 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
|
|
4914
5298
|
}
|
4915
5299
|
}
|
4916
5300
|
|
5301
|
+
// Dequantize 16 values of an iq1_m block into a 4x4 register tile.
//   xb  - source block
//   il  - selects which 16 values: il/2 is the 32-value sub-block,
//         il%2 the half of it
//   reg - output tile; rows 0/1 come from the first grid entry, rows 2/3
//         from the second
template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
    const int   ib32 = il/2;
    const short half_idx = il%2;

    device const uint16_t * sc = (device const uint16_t *)xb->scales;
    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*half_idx;
    device const uint8_t  * qh = xb->qh + 2*ib32 + half_idx;

#if QK_K == 64
    const float d  = xb->d;
    const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*half_idx)) & 0xf) + 1);
#else
    // the f16 super-block scale is spread over the top nibbles of the four scale words
    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const float d  = scale.f16;
    const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*half_idx)) & 7) + 1);
#endif

    // bit 3 of each qh nibble selects the sign of the IQ1M_DELTA shift
    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);

    // grid index = 8 low bits from qs plus 3 high bits from a qh nibble
    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));

    for (int k = 0; k < 4; ++k) {
        const uint8_t g1 = grid1[k];
        const uint8_t g2 = grid2[k];

        reg[0][k] = dl * (g1 & 0xf) + ml1;
        reg[1][k] = dl * (g1 >>  4) + ml1;
        reg[2][k] = dl * (g2 & 0xf) + ml2;
        reg[3][k] = dl * (g2 >>  4) + ml2;
    }
}
|
5332
|
+
|
4917
5333
|
template <typename type4x4>
|
4918
5334
|
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
4919
5335
|
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
@@ -5385,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
5385
5801
|
|
5386
5802
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
5387
5803
|
kernel void kernel_mul_mm_id(
|
5388
|
-
device const uchar *
|
5804
|
+
device const uchar * src0s,
|
5389
5805
|
device const uchar * src1,
|
5390
5806
|
device float * dst,
|
5807
|
+
device const uchar * ids,
|
5391
5808
|
constant uint64_t & nbi1,
|
5392
5809
|
constant int64_t & ne00,
|
5393
5810
|
constant int64_t & ne02,
|
@@ -5404,22 +5821,14 @@ kernel void kernel_mul_mm_id(
|
|
5404
5821
|
constant uint & r2,
|
5405
5822
|
constant uint & r3,
|
5406
5823
|
constant int & idx,
|
5407
|
-
device const uchar * src00,
|
5408
|
-
device const uchar * src01,
|
5409
|
-
device const uchar * src02,
|
5410
|
-
device const uchar * src03,
|
5411
|
-
device const uchar * src04,
|
5412
|
-
device const uchar * src05,
|
5413
|
-
device const uchar * src06,
|
5414
|
-
device const uchar * src07,
|
5415
5824
|
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
5416
5825
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5417
5826
|
uint tiitg[[thread_index_in_threadgroup]],
|
5418
5827
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5419
|
-
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5420
5828
|
|
5421
5829
|
// expert id
|
5422
5830
|
const int32_t id = tgpig.z/(ne12*ne13);
|
5831
|
+
device const uchar * src0 = src0s + id*nb02;
|
5423
5832
|
|
5424
5833
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5425
5834
|
|
@@ -5434,7 +5843,7 @@ kernel void kernel_mul_mm_id(
|
|
5434
5843
|
}
|
5435
5844
|
|
5436
5845
|
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
5437
|
-
|
5846
|
+
src0,
|
5438
5847
|
src1,
|
5439
5848
|
src1ids,
|
5440
5849
|
dst,
|
@@ -5498,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
|
|
5498
5907
|
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5499
5908
|
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5500
5909
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5910
|
+
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5501
5911
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5502
5912
|
#if QK_K == 64
|
5503
5913
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
@@ -5528,24 +5938,25 @@ typedef void (mat_mm_t)(
|
|
5528
5938
|
threadgroup uchar *,
|
5529
5939
|
uint3, uint, uint);
|
5530
5940
|
|
5531
|
-
template [[host_name("kernel_mul_mm_f32_f32")]]
|
5532
|
-
template [[host_name("kernel_mul_mm_f16_f32")]]
|
5533
|
-
template [[host_name("kernel_mul_mm_q4_0_f32")]]
|
5534
|
-
template [[host_name("kernel_mul_mm_q4_1_f32")]]
|
5535
|
-
template [[host_name("kernel_mul_mm_q5_0_f32")]]
|
5536
|
-
template [[host_name("kernel_mul_mm_q5_1_f32")]]
|
5537
|
-
template [[host_name("kernel_mul_mm_q8_0_f32")]]
|
5538
|
-
template [[host_name("kernel_mul_mm_q2_K_f32")]]
|
5539
|
-
template [[host_name("kernel_mul_mm_q3_K_f32")]]
|
5540
|
-
template [[host_name("kernel_mul_mm_q4_K_f32")]]
|
5541
|
-
template [[host_name("kernel_mul_mm_q5_K_f32")]]
|
5542
|
-
template [[host_name("kernel_mul_mm_q6_K_f32")]]
|
5941
|
+
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
5942
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
5943
|
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
5944
|
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
5945
|
+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
5946
|
+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
5947
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
5948
|
+
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
5949
|
+
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
5950
|
+
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
5951
|
+
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
5952
|
+
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
5543
5953
|
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5544
5954
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5545
5955
|
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
5546
5956
|
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5547
5957
|
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5548
5958
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5959
|
+
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5549
5960
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5550
5961
|
#if QK_K == 64
|
5551
5962
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
@@ -5558,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
5558
5969
|
//
|
5559
5970
|
|
5560
5971
|
typedef void (mat_mm_id_t)(
|
5561
|
-
device const uchar *
|
5972
|
+
device const uchar * src0s,
|
5562
5973
|
device const uchar * src1,
|
5563
5974
|
device float * dst,
|
5975
|
+
device const uchar * ids,
|
5564
5976
|
constant uint64_t & nbi1,
|
5565
5977
|
constant int64_t & ne00,
|
5566
5978
|
constant int64_t & ne02,
|
@@ -5577,35 +5989,28 @@ typedef void (mat_mm_id_t)(
|
|
5577
5989
|
constant uint & r2,
|
5578
5990
|
constant uint & r3,
|
5579
5991
|
constant int & idx,
|
5580
|
-
device const uchar * src00,
|
5581
|
-
device const uchar * src01,
|
5582
|
-
device const uchar * src02,
|
5583
|
-
device const uchar * src03,
|
5584
|
-
device const uchar * src04,
|
5585
|
-
device const uchar * src05,
|
5586
|
-
device const uchar * src06,
|
5587
|
-
device const uchar * src07,
|
5588
5992
|
threadgroup uchar *,
|
5589
5993
|
uint3, uint, uint);
|
5590
5994
|
|
5591
|
-
template [[host_name("kernel_mul_mm_id_f32_f32")]]
|
5592
|
-
template [[host_name("kernel_mul_mm_id_f16_f32")]]
|
5593
|
-
template [[host_name("kernel_mul_mm_id_q4_0_f32")]]
|
5594
|
-
template [[host_name("kernel_mul_mm_id_q4_1_f32")]]
|
5595
|
-
template [[host_name("kernel_mul_mm_id_q5_0_f32")]]
|
5596
|
-
template [[host_name("kernel_mul_mm_id_q5_1_f32")]]
|
5597
|
-
template [[host_name("kernel_mul_mm_id_q8_0_f32")]]
|
5598
|
-
template [[host_name("kernel_mul_mm_id_q2_K_f32")]]
|
5599
|
-
template [[host_name("kernel_mul_mm_id_q3_K_f32")]]
|
5600
|
-
template [[host_name("kernel_mul_mm_id_q4_K_f32")]]
|
5601
|
-
template [[host_name("kernel_mul_mm_id_q5_K_f32")]]
|
5602
|
-
template [[host_name("kernel_mul_mm_id_q6_K_f32")]]
|
5995
|
+
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
5996
|
+
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
5997
|
+
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
5998
|
+
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
5999
|
+
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
6000
|
+
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
6001
|
+
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
6002
|
+
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
6003
|
+
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
6004
|
+
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
6005
|
+
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
6006
|
+
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
5603
6007
|
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5604
6008
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5605
6009
|
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
5606
6010
|
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5607
6011
|
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5608
6012
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6013
|
+
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5609
6014
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5610
6015
|
#if QK_K == 64
|
5611
6016
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
@@ -5619,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|
5619
6024
|
|
5620
6025
|
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
5621
6026
|
kernel void kernel_mul_mv_id_f32_f32(
|
5622
|
-
device const char *
|
6027
|
+
device const char * src0s,
|
5623
6028
|
device const char * src1,
|
5624
6029
|
device float * dst,
|
6030
|
+
device const char * ids,
|
5625
6031
|
constant uint64_t & nbi1,
|
5626
6032
|
constant int64_t & ne00,
|
5627
6033
|
constant int64_t & ne01,
|
@@ -5642,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
5642
6048
|
constant uint & r2,
|
5643
6049
|
constant uint & r3,
|
5644
6050
|
constant int & idx,
|
5645
|
-
device const char * src00,
|
5646
|
-
device const char * src01,
|
5647
|
-
device const char * src02,
|
5648
|
-
device const char * src03,
|
5649
|
-
device const char * src04,
|
5650
|
-
device const char * src05,
|
5651
|
-
device const char * src06,
|
5652
|
-
device const char * src07,
|
5653
6051
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5654
6052
|
uint tiitg[[thread_index_in_threadgroup]],
|
5655
6053
|
uint tiisg[[thread_index_in_simdgroup]],
|
5656
6054
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5657
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5658
|
-
|
5659
6055
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5660
6056
|
|
5661
6057
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5662
6058
|
|
5663
6059
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6060
|
+
device const char * src0 = src0s + id*nb02;
|
5664
6061
|
|
5665
6062
|
kernel_mul_mv_f32_f32_impl(
|
5666
|
-
src0
|
6063
|
+
src0,
|
5667
6064
|
src1 + bid*nb11,
|
5668
6065
|
dst + bid*ne0,
|
5669
6066
|
ne00,
|
@@ -5688,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
5688
6085
|
|
5689
6086
|
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
5690
6087
|
kernel void kernel_mul_mv_id_f16_f32(
|
5691
|
-
device const char *
|
6088
|
+
device const char * src0s,
|
5692
6089
|
device const char * src1,
|
5693
6090
|
device float * dst,
|
6091
|
+
device const char * ids,
|
5694
6092
|
constant uint64_t & nbi1,
|
5695
6093
|
constant int64_t & ne00,
|
5696
6094
|
constant int64_t & ne01,
|
@@ -5711,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
5711
6109
|
constant uint & r2,
|
5712
6110
|
constant uint & r3,
|
5713
6111
|
constant int & idx,
|
5714
|
-
device const char * src00,
|
5715
|
-
device const char * src01,
|
5716
|
-
device const char * src02,
|
5717
|
-
device const char * src03,
|
5718
|
-
device const char * src04,
|
5719
|
-
device const char * src05,
|
5720
|
-
device const char * src06,
|
5721
|
-
device const char * src07,
|
5722
6112
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5723
6113
|
uint tiitg[[thread_index_in_threadgroup]],
|
5724
6114
|
uint tiisg[[thread_index_in_simdgroup]],
|
5725
6115
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5726
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5727
|
-
|
5728
6116
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5729
6117
|
|
5730
6118
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5731
6119
|
|
5732
6120
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6121
|
+
device const char * src0 = src0s + id*nb02;
|
5733
6122
|
|
5734
6123
|
kernel_mul_mv_f16_f32_impl(
|
5735
|
-
src0
|
6124
|
+
src0,
|
5736
6125
|
src1 + bid*nb11,
|
5737
6126
|
dst + bid*ne0,
|
5738
6127
|
ne00,
|
@@ -5757,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
5757
6146
|
|
5758
6147
|
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
5759
6148
|
kernel void kernel_mul_mv_id_q8_0_f32(
|
5760
|
-
device const char *
|
6149
|
+
device const char * src0s,
|
5761
6150
|
device const char * src1,
|
5762
6151
|
device float * dst,
|
6152
|
+
device const char * ids,
|
5763
6153
|
constant uint64_t & nbi1,
|
5764
6154
|
constant int64_t & ne00,
|
5765
6155
|
constant int64_t & ne01,
|
@@ -5780,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
5780
6170
|
constant uint & r2,
|
5781
6171
|
constant uint & r3,
|
5782
6172
|
constant int & idx,
|
5783
|
-
device const char * src00,
|
5784
|
-
device const char * src01,
|
5785
|
-
device const char * src02,
|
5786
|
-
device const char * src03,
|
5787
|
-
device const char * src04,
|
5788
|
-
device const char * src05,
|
5789
|
-
device const char * src06,
|
5790
|
-
device const char * src07,
|
5791
6173
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5792
6174
|
uint tiitg[[thread_index_in_threadgroup]],
|
5793
6175
|
uint tiisg[[thread_index_in_simdgroup]],
|
5794
6176
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5795
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5796
|
-
|
5797
6177
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5798
6178
|
|
5799
6179
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5800
6180
|
|
5801
6181
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6182
|
+
device const char * src0 = src0s + id*nb02;
|
5802
6183
|
|
5803
6184
|
kernel_mul_mv_q8_0_f32_impl(
|
5804
|
-
src0
|
6185
|
+
src0,
|
5805
6186
|
(device const float *) (src1 + bid*nb11),
|
5806
6187
|
dst + bid*ne0,
|
5807
6188
|
ne00,
|
@@ -5820,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
5820
6201
|
|
5821
6202
|
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
5822
6203
|
kernel void kernel_mul_mv_id_q4_0_f32(
|
5823
|
-
device const char *
|
6204
|
+
device const char * src0s,
|
5824
6205
|
device const char * src1,
|
5825
6206
|
device float * dst,
|
6207
|
+
device const char * ids,
|
5826
6208
|
constant uint64_t & nbi1,
|
5827
6209
|
constant int64_t & ne00,
|
5828
6210
|
constant int64_t & ne01,
|
@@ -5843,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
5843
6225
|
constant uint & r2,
|
5844
6226
|
constant uint & r3,
|
5845
6227
|
constant int & idx,
|
5846
|
-
device const char * src00,
|
5847
|
-
device const char * src01,
|
5848
|
-
device const char * src02,
|
5849
|
-
device const char * src03,
|
5850
|
-
device const char * src04,
|
5851
|
-
device const char * src05,
|
5852
|
-
device const char * src06,
|
5853
|
-
device const char * src07,
|
5854
6228
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5855
6229
|
uint tiitg[[thread_index_in_threadgroup]],
|
5856
6230
|
uint tiisg[[thread_index_in_simdgroup]],
|
5857
6231
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5858
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5859
|
-
|
5860
6232
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5861
6233
|
|
5862
6234
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5863
6235
|
|
5864
6236
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6237
|
+
device const char * src0 = src0s + id*nb02;
|
5865
6238
|
|
5866
6239
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
5867
|
-
src0
|
6240
|
+
src0,
|
5868
6241
|
(device const float *) (src1 + bid*nb11),
|
5869
6242
|
dst + bid*ne0,
|
5870
6243
|
ne00,
|
@@ -5883,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
5883
6256
|
|
5884
6257
|
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
5885
6258
|
kernel void kernel_mul_mv_id_q4_1_f32(
|
5886
|
-
device const char *
|
6259
|
+
device const char * src0s,
|
5887
6260
|
device const char * src1,
|
5888
6261
|
device float * dst,
|
6262
|
+
device const char * ids,
|
5889
6263
|
constant uint64_t & nbi1,
|
5890
6264
|
constant int64_t & ne00,
|
5891
6265
|
constant int64_t & ne01,
|
@@ -5906,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
5906
6280
|
constant uint & r2,
|
5907
6281
|
constant uint & r3,
|
5908
6282
|
constant int & idx,
|
5909
|
-
device const char * src00,
|
5910
|
-
device const char * src01,
|
5911
|
-
device const char * src02,
|
5912
|
-
device const char * src03,
|
5913
|
-
device const char * src04,
|
5914
|
-
device const char * src05,
|
5915
|
-
device const char * src06,
|
5916
|
-
device const char * src07,
|
5917
6283
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5918
6284
|
uint tiitg[[thread_index_in_threadgroup]],
|
5919
6285
|
uint tiisg[[thread_index_in_simdgroup]],
|
5920
6286
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5921
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5922
|
-
|
5923
6287
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5924
6288
|
|
5925
6289
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5926
6290
|
|
5927
6291
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6292
|
+
device const char * src0 = src0s + id*nb02;
|
5928
6293
|
|
5929
6294
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
5930
|
-
src0
|
6295
|
+
src0,
|
5931
6296
|
(device const float *) (src1 + bid*nb11),
|
5932
6297
|
dst + bid*ne0,
|
5933
6298
|
ne00,
|
@@ -5946,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
5946
6311
|
|
5947
6312
|
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
5948
6313
|
kernel void kernel_mul_mv_id_q5_0_f32(
|
5949
|
-
device const char *
|
6314
|
+
device const char * src0s,
|
5950
6315
|
device const char * src1,
|
5951
6316
|
device float * dst,
|
6317
|
+
device const char * ids,
|
5952
6318
|
constant uint64_t & nbi1,
|
5953
6319
|
constant int64_t & ne00,
|
5954
6320
|
constant int64_t & ne01,
|
@@ -5969,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
5969
6335
|
constant uint & r2,
|
5970
6336
|
constant uint & r3,
|
5971
6337
|
constant int & idx,
|
5972
|
-
device const char * src00,
|
5973
|
-
device const char * src01,
|
5974
|
-
device const char * src02,
|
5975
|
-
device const char * src03,
|
5976
|
-
device const char * src04,
|
5977
|
-
device const char * src05,
|
5978
|
-
device const char * src06,
|
5979
|
-
device const char * src07,
|
5980
6338
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5981
6339
|
uint tiitg[[thread_index_in_threadgroup]],
|
5982
6340
|
uint tiisg[[thread_index_in_simdgroup]],
|
5983
6341
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5984
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5985
|
-
|
5986
6342
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5987
6343
|
|
5988
6344
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5989
6345
|
|
5990
6346
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6347
|
+
device const char * src0 = src0s + id*nb02;
|
5991
6348
|
|
5992
6349
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
5993
|
-
src0
|
6350
|
+
src0,
|
5994
6351
|
(device const float *) (src1 + bid*nb11),
|
5995
6352
|
dst + bid*ne0,
|
5996
6353
|
ne00,
|
@@ -6009,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
6009
6366
|
|
6010
6367
|
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
6011
6368
|
kernel void kernel_mul_mv_id_q5_1_f32(
|
6012
|
-
device const char *
|
6369
|
+
device const char * src0s,
|
6013
6370
|
device const char * src1,
|
6014
6371
|
device float * dst,
|
6372
|
+
device const char * ids,
|
6015
6373
|
constant uint64_t & nbi1,
|
6016
6374
|
constant int64_t & ne00,
|
6017
6375
|
constant int64_t & ne01,
|
@@ -6032,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
6032
6390
|
constant uint & r2,
|
6033
6391
|
constant uint & r3,
|
6034
6392
|
constant int & idx,
|
6035
|
-
device const char * src00,
|
6036
|
-
device const char * src01,
|
6037
|
-
device const char * src02,
|
6038
|
-
device const char * src03,
|
6039
|
-
device const char * src04,
|
6040
|
-
device const char * src05,
|
6041
|
-
device const char * src06,
|
6042
|
-
device const char * src07,
|
6043
6393
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6044
6394
|
uint tiitg[[thread_index_in_threadgroup]],
|
6045
6395
|
uint tiisg[[thread_index_in_simdgroup]],
|
6046
6396
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6047
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6048
|
-
|
6049
6397
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6050
6398
|
|
6051
6399
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6052
6400
|
|
6053
6401
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6402
|
+
device const char * src0 = src0s + id*nb02;
|
6054
6403
|
|
6055
6404
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6056
|
-
src0
|
6405
|
+
src0,
|
6057
6406
|
(device const float *) (src1 + bid*nb11),
|
6058
6407
|
dst + bid*ne0,
|
6059
6408
|
ne00,
|
@@ -6072,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
6072
6421
|
|
6073
6422
|
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
6074
6423
|
kernel void kernel_mul_mv_id_q2_K_f32(
|
6075
|
-
device const char *
|
6424
|
+
device const char * src0s,
|
6076
6425
|
device const char * src1,
|
6077
6426
|
device float * dst,
|
6427
|
+
device const char * ids,
|
6078
6428
|
constant uint64_t & nbi1,
|
6079
6429
|
constant int64_t & ne00,
|
6080
6430
|
constant int64_t & ne01,
|
@@ -6095,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
6095
6445
|
constant uint & r2,
|
6096
6446
|
constant uint & r3,
|
6097
6447
|
constant int & idx,
|
6098
|
-
device const char * src00,
|
6099
|
-
device const char * src01,
|
6100
|
-
device const char * src02,
|
6101
|
-
device const char * src03,
|
6102
|
-
device const char * src04,
|
6103
|
-
device const char * src05,
|
6104
|
-
device const char * src06,
|
6105
|
-
device const char * src07,
|
6106
6448
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6107
6449
|
uint tiitg[[thread_index_in_threadgroup]],
|
6108
6450
|
uint tiisg[[thread_index_in_simdgroup]],
|
6109
6451
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6110
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6111
|
-
|
6112
6452
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6113
6453
|
|
6114
6454
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6115
6455
|
|
6116
6456
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6457
|
+
device const char * src0 = src0s + id*nb02;
|
6117
6458
|
|
6118
6459
|
kernel_mul_mv_q2_K_f32_impl(
|
6119
|
-
src0
|
6460
|
+
src0,
|
6120
6461
|
(device const float *) (src1 + bid*nb11),
|
6121
6462
|
dst + bid*ne0,
|
6122
6463
|
ne00,
|
@@ -6135,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
6135
6476
|
|
6136
6477
|
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
6137
6478
|
kernel void kernel_mul_mv_id_q3_K_f32(
|
6138
|
-
device const char *
|
6479
|
+
device const char * src0s,
|
6139
6480
|
device const char * src1,
|
6140
6481
|
device float * dst,
|
6482
|
+
device const char * ids,
|
6141
6483
|
constant uint64_t & nbi1,
|
6142
6484
|
constant int64_t & ne00,
|
6143
6485
|
constant int64_t & ne01,
|
@@ -6158,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
6158
6500
|
constant uint & r2,
|
6159
6501
|
constant uint & r3,
|
6160
6502
|
constant int & idx,
|
6161
|
-
device const char * src00,
|
6162
|
-
device const char * src01,
|
6163
|
-
device const char * src02,
|
6164
|
-
device const char * src03,
|
6165
|
-
device const char * src04,
|
6166
|
-
device const char * src05,
|
6167
|
-
device const char * src06,
|
6168
|
-
device const char * src07,
|
6169
6503
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6170
6504
|
uint tiitg[[thread_index_in_threadgroup]],
|
6171
6505
|
uint tiisg[[thread_index_in_simdgroup]],
|
6172
6506
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6173
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6174
|
-
|
6175
6507
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6176
6508
|
|
6177
6509
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6178
6510
|
|
6179
6511
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6512
|
+
device const char * src0 = src0s + id*nb02;
|
6180
6513
|
|
6181
6514
|
kernel_mul_mv_q3_K_f32_impl(
|
6182
|
-
src0
|
6515
|
+
src0,
|
6183
6516
|
(device const float *) (src1 + bid*nb11),
|
6184
6517
|
dst + bid*ne0,
|
6185
6518
|
ne00,
|
@@ -6198,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
6198
6531
|
|
6199
6532
|
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
6200
6533
|
kernel void kernel_mul_mv_id_q4_K_f32(
|
6201
|
-
device const char *
|
6534
|
+
device const char * src0s,
|
6202
6535
|
device const char * src1,
|
6203
6536
|
device float * dst,
|
6537
|
+
device const char * ids,
|
6204
6538
|
constant uint64_t & nbi1,
|
6205
6539
|
constant int64_t & ne00,
|
6206
6540
|
constant int64_t & ne01,
|
@@ -6221,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
6221
6555
|
constant uint & r2,
|
6222
6556
|
constant uint & r3,
|
6223
6557
|
constant int & idx,
|
6224
|
-
device const char * src00,
|
6225
|
-
device const char * src01,
|
6226
|
-
device const char * src02,
|
6227
|
-
device const char * src03,
|
6228
|
-
device const char * src04,
|
6229
|
-
device const char * src05,
|
6230
|
-
device const char * src06,
|
6231
|
-
device const char * src07,
|
6232
6558
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6233
6559
|
uint tiitg[[thread_index_in_threadgroup]],
|
6234
6560
|
uint tiisg[[thread_index_in_simdgroup]],
|
6235
6561
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6236
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6237
|
-
|
6238
6562
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6239
6563
|
|
6240
6564
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6241
6565
|
|
6242
6566
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6567
|
+
device const char * src0 = src0s + id*nb02;
|
6243
6568
|
|
6244
6569
|
kernel_mul_mv_q4_K_f32_impl(
|
6245
|
-
src0
|
6570
|
+
src0,
|
6246
6571
|
(device const float *) (src1 + bid*nb11),
|
6247
6572
|
dst + bid*ne0,
|
6248
6573
|
ne00,
|
@@ -6261,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
6261
6586
|
|
6262
6587
|
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
6263
6588
|
kernel void kernel_mul_mv_id_q5_K_f32(
|
6264
|
-
device const char *
|
6589
|
+
device const char * src0s,
|
6265
6590
|
device const char * src1,
|
6266
6591
|
device float * dst,
|
6592
|
+
device const char * ids,
|
6267
6593
|
constant uint64_t & nbi1,
|
6268
6594
|
constant int64_t & ne00,
|
6269
6595
|
constant int64_t & ne01,
|
@@ -6284,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
6284
6610
|
constant uint & r2,
|
6285
6611
|
constant uint & r3,
|
6286
6612
|
constant int & idx,
|
6287
|
-
device const char * src00,
|
6288
|
-
device const char * src01,
|
6289
|
-
device const char * src02,
|
6290
|
-
device const char * src03,
|
6291
|
-
device const char * src04,
|
6292
|
-
device const char * src05,
|
6293
|
-
device const char * src06,
|
6294
|
-
device const char * src07,
|
6295
6613
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6296
6614
|
uint tiitg[[thread_index_in_threadgroup]],
|
6297
6615
|
uint tiisg[[thread_index_in_simdgroup]],
|
6298
6616
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6299
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6300
|
-
|
6301
6617
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6302
6618
|
|
6303
6619
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6304
6620
|
|
6305
6621
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6622
|
+
device const char * src0 = src0s + id*nb02;
|
6306
6623
|
|
6307
6624
|
kernel_mul_mv_q5_K_f32_impl(
|
6308
|
-
src0
|
6625
|
+
src0,
|
6309
6626
|
(device const float *) (src1 + bid*nb11),
|
6310
6627
|
dst + bid*ne0,
|
6311
6628
|
ne00,
|
@@ -6324,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
6324
6641
|
|
6325
6642
|
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
6326
6643
|
kernel void kernel_mul_mv_id_q6_K_f32(
|
6327
|
-
device const char *
|
6644
|
+
device const char * src0s,
|
6328
6645
|
device const char * src1,
|
6329
6646
|
device float * dst,
|
6647
|
+
device const char * ids,
|
6330
6648
|
constant uint64_t & nbi1,
|
6331
6649
|
constant int64_t & ne00,
|
6332
6650
|
constant int64_t & ne01,
|
@@ -6347,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
6347
6665
|
constant uint & r2,
|
6348
6666
|
constant uint & r3,
|
6349
6667
|
constant int & idx,
|
6350
|
-
device const char * src00,
|
6351
|
-
device const char * src01,
|
6352
|
-
device const char * src02,
|
6353
|
-
device const char * src03,
|
6354
|
-
device const char * src04,
|
6355
|
-
device const char * src05,
|
6356
|
-
device const char * src06,
|
6357
|
-
device const char * src07,
|
6358
6668
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6359
6669
|
uint tiitg[[thread_index_in_threadgroup]],
|
6360
6670
|
uint tiisg[[thread_index_in_simdgroup]],
|
6361
6671
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6362
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6363
|
-
|
6364
6672
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6365
6673
|
|
6366
6674
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6367
6675
|
|
6368
6676
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6677
|
+
device const char * src0 = src0s + id*nb02;
|
6369
6678
|
|
6370
6679
|
kernel_mul_mv_q6_K_f32_impl(
|
6371
|
-
src0
|
6680
|
+
src0,
|
6372
6681
|
(device const float *) (src1 + bid*nb11),
|
6373
6682
|
dst + bid*ne0,
|
6374
6683
|
ne00,
|
@@ -6387,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
6387
6696
|
|
6388
6697
|
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
6389
6698
|
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
6390
|
-
device const char *
|
6699
|
+
device const char * src0s,
|
6391
6700
|
device const char * src1,
|
6392
6701
|
device float * dst,
|
6702
|
+
device const char * ids,
|
6393
6703
|
constant uint64_t & nbi1,
|
6394
6704
|
constant int64_t & ne00,
|
6395
6705
|
constant int64_t & ne01,
|
@@ -6410,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
6410
6720
|
constant uint & r2,
|
6411
6721
|
constant uint & r3,
|
6412
6722
|
constant int & idx,
|
6413
|
-
device const char * src00,
|
6414
|
-
device const char * src01,
|
6415
|
-
device const char * src02,
|
6416
|
-
device const char * src03,
|
6417
|
-
device const char * src04,
|
6418
|
-
device const char * src05,
|
6419
|
-
device const char * src06,
|
6420
|
-
device const char * src07,
|
6421
6723
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6422
6724
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6423
6725
|
uint tiitg[[thread_index_in_threadgroup]],
|
6424
6726
|
uint tiisg[[thread_index_in_simdgroup]],
|
6425
6727
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6426
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6427
|
-
|
6428
6728
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6429
6729
|
|
6430
6730
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6431
6731
|
|
6432
6732
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6733
|
+
device const char * src0 = src0s + id*nb02;
|
6433
6734
|
|
6434
6735
|
kernel_mul_mv_iq2_xxs_f32_impl(
|
6435
|
-
src0
|
6736
|
+
src0,
|
6436
6737
|
(device const float *) (src1 + bid*nb11),
|
6437
6738
|
dst + bid*ne0,
|
6438
6739
|
ne00,
|
@@ -6452,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
6452
6753
|
|
6453
6754
|
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
6454
6755
|
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
6455
|
-
device const char *
|
6756
|
+
device const char * src0s,
|
6456
6757
|
device const char * src1,
|
6457
6758
|
device float * dst,
|
6759
|
+
device const char * ids,
|
6458
6760
|
constant uint64_t & nbi1,
|
6459
6761
|
constant int64_t & ne00,
|
6460
6762
|
constant int64_t & ne01,
|
@@ -6475,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
6475
6777
|
constant uint & r2,
|
6476
6778
|
constant uint & r3,
|
6477
6779
|
constant int & idx,
|
6478
|
-
device const char * src00,
|
6479
|
-
device const char * src01,
|
6480
|
-
device const char * src02,
|
6481
|
-
device const char * src03,
|
6482
|
-
device const char * src04,
|
6483
|
-
device const char * src05,
|
6484
|
-
device const char * src06,
|
6485
|
-
device const char * src07,
|
6486
6780
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6487
6781
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6488
6782
|
uint tiitg[[thread_index_in_threadgroup]],
|
6489
6783
|
uint tiisg[[thread_index_in_simdgroup]],
|
6490
6784
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6491
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6492
|
-
|
6493
6785
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6494
6786
|
|
6495
6787
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6496
6788
|
|
6497
6789
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6790
|
+
device const char * src0 = src0s + id*nb02;
|
6498
6791
|
|
6499
6792
|
kernel_mul_mv_iq2_xs_f32_impl(
|
6500
|
-
src0
|
6793
|
+
src0,
|
6501
6794
|
(device const float *) (src1 + bid*nb11),
|
6502
6795
|
dst + bid*ne0,
|
6503
6796
|
ne00,
|
@@ -6517,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
6517
6810
|
|
6518
6811
|
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
6519
6812
|
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
6520
|
-
device const char *
|
6813
|
+
device const char * src0s,
|
6521
6814
|
device const char * src1,
|
6522
6815
|
device float * dst,
|
6816
|
+
device const char * ids,
|
6523
6817
|
constant uint64_t & nbi1,
|
6524
6818
|
constant int64_t & ne00,
|
6525
6819
|
constant int64_t & ne01,
|
@@ -6540,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6540
6834
|
constant uint & r2,
|
6541
6835
|
constant uint & r3,
|
6542
6836
|
constant int & idx,
|
6543
|
-
device const char * src00,
|
6544
|
-
device const char * src01,
|
6545
|
-
device const char * src02,
|
6546
|
-
device const char * src03,
|
6547
|
-
device const char * src04,
|
6548
|
-
device const char * src05,
|
6549
|
-
device const char * src06,
|
6550
|
-
device const char * src07,
|
6551
6837
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6552
6838
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6553
6839
|
uint tiitg[[thread_index_in_threadgroup]],
|
6554
6840
|
uint tiisg[[thread_index_in_simdgroup]],
|
6555
6841
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6556
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6557
|
-
|
6558
6842
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6559
6843
|
|
6560
6844
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6561
6845
|
|
6562
6846
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6847
|
+
device const char * src0 = src0s + id*nb02;
|
6563
6848
|
|
6564
6849
|
kernel_mul_mv_iq3_xxs_f32_impl(
|
6565
|
-
src0
|
6850
|
+
src0,
|
6566
6851
|
(device const float *) (src1 + bid*nb11),
|
6567
6852
|
dst + bid*ne0,
|
6568
6853
|
ne00,
|
@@ -6582,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6582
6867
|
|
6583
6868
|
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
6584
6869
|
kernel void kernel_mul_mv_id_iq3_s_f32(
|
6585
|
-
device const char *
|
6870
|
+
device const char * src0s,
|
6586
6871
|
device const char * src1,
|
6587
6872
|
device float * dst,
|
6873
|
+
device const char * ids,
|
6588
6874
|
constant uint64_t & nbi1,
|
6589
6875
|
constant int64_t & ne00,
|
6590
6876
|
constant int64_t & ne01,
|
@@ -6605,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
6605
6891
|
constant uint & r2,
|
6606
6892
|
constant uint & r3,
|
6607
6893
|
constant int & idx,
|
6608
|
-
device const char * src00,
|
6609
|
-
device const char * src01,
|
6610
|
-
device const char * src02,
|
6611
|
-
device const char * src03,
|
6612
|
-
device const char * src04,
|
6613
|
-
device const char * src05,
|
6614
|
-
device const char * src06,
|
6615
|
-
device const char * src07,
|
6616
6894
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6617
6895
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6618
6896
|
uint tiitg[[thread_index_in_threadgroup]],
|
6619
6897
|
uint tiisg[[thread_index_in_simdgroup]],
|
6620
6898
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6621
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6622
|
-
|
6623
6899
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6624
6900
|
|
6625
6901
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6626
6902
|
|
6627
6903
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6904
|
+
device const char * src0 = src0s + id*nb02;
|
6628
6905
|
|
6629
6906
|
kernel_mul_mv_iq3_s_f32_impl(
|
6630
|
-
src0
|
6907
|
+
src0,
|
6631
6908
|
(device const float *) (src1 + bid*nb11),
|
6632
6909
|
dst + bid*ne0,
|
6633
6910
|
ne00,
|
@@ -6647,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
6647
6924
|
|
6648
6925
|
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
6649
6926
|
kernel void kernel_mul_mv_id_iq2_s_f32(
|
6650
|
-
device const char *
|
6927
|
+
device const char * src0s,
|
6651
6928
|
device const char * src1,
|
6652
6929
|
device float * dst,
|
6930
|
+
device const char * ids,
|
6653
6931
|
constant uint64_t & nbi1,
|
6654
6932
|
constant int64_t & ne00,
|
6655
6933
|
constant int64_t & ne01,
|
@@ -6670,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
6670
6948
|
constant uint & r2,
|
6671
6949
|
constant uint & r3,
|
6672
6950
|
constant int & idx,
|
6673
|
-
device const char * src00,
|
6674
|
-
device const char * src01,
|
6675
|
-
device const char * src02,
|
6676
|
-
device const char * src03,
|
6677
|
-
device const char * src04,
|
6678
|
-
device const char * src05,
|
6679
|
-
device const char * src06,
|
6680
|
-
device const char * src07,
|
6681
6951
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6682
6952
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6683
6953
|
uint tiitg[[thread_index_in_threadgroup]],
|
6684
6954
|
uint tiisg[[thread_index_in_simdgroup]],
|
6685
6955
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6686
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6687
|
-
|
6688
6956
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6689
6957
|
|
6690
6958
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6691
6959
|
|
6692
6960
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6961
|
+
device const char * src0 = src0s + id*nb02;
|
6693
6962
|
|
6694
6963
|
kernel_mul_mv_iq2_s_f32_impl(
|
6695
|
-
src0
|
6964
|
+
src0,
|
6696
6965
|
(device const float *) (src1 + bid*nb11),
|
6697
6966
|
dst + bid*ne0,
|
6698
6967
|
ne00,
|
@@ -6712,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
6712
6981
|
|
6713
6982
|
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
6714
6983
|
kernel void kernel_mul_mv_id_iq1_s_f32(
|
6715
|
-
device const char *
|
6984
|
+
device const char * src0s,
|
6716
6985
|
device const char * src1,
|
6717
6986
|
device float * dst,
|
6987
|
+
device const char * ids,
|
6718
6988
|
constant uint64_t & nbi1,
|
6719
6989
|
constant int64_t & ne00,
|
6720
6990
|
constant int64_t & ne01,
|
@@ -6735,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
6735
7005
|
constant uint & r2,
|
6736
7006
|
constant uint & r3,
|
6737
7007
|
constant int & idx,
|
6738
|
-
device const char * src00,
|
6739
|
-
device const char * src01,
|
6740
|
-
device const char * src02,
|
6741
|
-
device const char * src03,
|
6742
|
-
device const char * src04,
|
6743
|
-
device const char * src05,
|
6744
|
-
device const char * src06,
|
6745
|
-
device const char * src07,
|
6746
7008
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6747
7009
|
uint tiitg[[thread_index_in_threadgroup]],
|
6748
7010
|
uint tiisg[[thread_index_in_simdgroup]],
|
6749
7011
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6750
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6751
|
-
|
6752
7012
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6753
7013
|
|
6754
7014
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6755
7015
|
|
6756
7016
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7017
|
+
device const char * src0 = src0s + id*nb02;
|
6757
7018
|
|
6758
7019
|
kernel_mul_mv_iq1_s_f32_impl(
|
6759
|
-
src0
|
7020
|
+
src0,
|
7021
|
+
(device const float *) (src1 + bid*nb11),
|
7022
|
+
dst + bid*ne0,
|
7023
|
+
ne00,
|
7024
|
+
ne01,
|
7025
|
+
ne02,
|
7026
|
+
ne10,
|
7027
|
+
ne12,
|
7028
|
+
ne0,
|
7029
|
+
ne1,
|
7030
|
+
r2,
|
7031
|
+
r3,
|
7032
|
+
tgpig,
|
7033
|
+
tiisg,
|
7034
|
+
sgitg);
|
7035
|
+
}
|
7036
|
+
|
7037
|
+
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
7038
|
+
kernel void kernel_mul_mv_id_iq1_m_f32(
|
7039
|
+
device const char * src0s,
|
7040
|
+
device const char * src1,
|
7041
|
+
device float * dst,
|
7042
|
+
device const char * ids,
|
7043
|
+
constant uint64_t & nbi1,
|
7044
|
+
constant int64_t & ne00,
|
7045
|
+
constant int64_t & ne01,
|
7046
|
+
constant int64_t & ne02,
|
7047
|
+
constant uint64_t & nb00,
|
7048
|
+
constant uint64_t & nb01,
|
7049
|
+
constant uint64_t & nb02,
|
7050
|
+
constant int64_t & ne10,
|
7051
|
+
constant int64_t & ne11,
|
7052
|
+
constant int64_t & ne12,
|
7053
|
+
constant int64_t & ne13,
|
7054
|
+
constant uint64_t & nb10,
|
7055
|
+
constant uint64_t & nb11,
|
7056
|
+
constant uint64_t & nb12,
|
7057
|
+
constant int64_t & ne0,
|
7058
|
+
constant int64_t & ne1,
|
7059
|
+
constant uint64_t & nb1,
|
7060
|
+
constant uint & r2,
|
7061
|
+
constant uint & r3,
|
7062
|
+
constant int & idx,
|
7063
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
7064
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
7065
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
7066
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7067
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
7068
|
+
|
7069
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
7070
|
+
|
7071
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7072
|
+
device const char * src0 = src0s + id*nb02;
|
7073
|
+
|
7074
|
+
kernel_mul_mv_iq1_m_f32_impl(
|
7075
|
+
src0,
|
6760
7076
|
(device const float *) (src1 + bid*nb11),
|
6761
7077
|
dst + bid*ne0,
|
6762
7078
|
ne00,
|
@@ -6775,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
6775
7091
|
|
6776
7092
|
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
6777
7093
|
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
6778
|
-
device const char *
|
7094
|
+
device const char * src0s,
|
6779
7095
|
device const char * src1,
|
6780
7096
|
device float * dst,
|
7097
|
+
device const char * ids,
|
6781
7098
|
constant uint64_t & nbi1,
|
6782
7099
|
constant int64_t & ne00,
|
6783
7100
|
constant int64_t & ne01,
|
@@ -6798,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
6798
7115
|
constant uint & r2,
|
6799
7116
|
constant uint & r3,
|
6800
7117
|
constant int & idx,
|
6801
|
-
device const char * src00,
|
6802
|
-
device const char * src01,
|
6803
|
-
device const char * src02,
|
6804
|
-
device const char * src03,
|
6805
|
-
device const char * src04,
|
6806
|
-
device const char * src05,
|
6807
|
-
device const char * src06,
|
6808
|
-
device const char * src07,
|
6809
7118
|
threadgroup float * shared_values [[threadgroup(0)]],
|
6810
7119
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6811
7120
|
uint tiitg[[thread_index_in_threadgroup]],
|
6812
7121
|
uint tiisg[[thread_index_in_simdgroup]],
|
6813
7122
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6814
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6815
|
-
|
6816
7123
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6817
7124
|
|
6818
7125
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6819
7126
|
|
6820
7127
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7128
|
+
device const char * src0 = src0s + id*nb02;
|
6821
7129
|
|
6822
7130
|
kernel_mul_mv_iq4_nl_f32_impl(
|
6823
|
-
src0
|
7131
|
+
src0,
|
6824
7132
|
(device const float *) (src1 + bid*nb11),
|
6825
7133
|
dst + bid*ne0,
|
6826
7134
|
ne00,
|
@@ -6840,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
6840
7148
|
|
6841
7149
|
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
6842
7150
|
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
6843
|
-
device const char *
|
7151
|
+
device const char * src0s,
|
6844
7152
|
device const char * src1,
|
6845
7153
|
device float * dst,
|
7154
|
+
device const char * ids,
|
6846
7155
|
constant uint64_t & nbi1,
|
6847
7156
|
constant int64_t & ne00,
|
6848
7157
|
constant int64_t & ne01,
|
@@ -6863,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
|
|
6863
7172
|
constant uint & r2,
|
6864
7173
|
constant uint & r3,
|
6865
7174
|
constant int & idx,
|
6866
|
-
device const char * src00,
|
6867
|
-
device const char * src01,
|
6868
|
-
device const char * src02,
|
6869
|
-
device const char * src03,
|
6870
|
-
device const char * src04,
|
6871
|
-
device const char * src05,
|
6872
|
-
device const char * src06,
|
6873
|
-
device const char * src07,
|
6874
7175
|
threadgroup float * shared_values [[threadgroup(0)]],
|
6875
7176
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6876
7177
|
uint tiitg[[thread_index_in_threadgroup]],
|
6877
7178
|
uint tiisg[[thread_index_in_simdgroup]],
|
6878
7179
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6879
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6880
|
-
|
6881
7180
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6882
7181
|
|
6883
7182
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6884
7183
|
|
6885
7184
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7185
|
+
device const char * src0 = src0s + id*nb02;
|
6886
7186
|
|
6887
7187
|
#if QK_K == 64
|
6888
7188
|
kernel_mul_mv_iq4_nl_f32_impl(
|
6889
7189
|
#else
|
6890
7190
|
kernel_mul_mv_iq4_xs_f32_impl(
|
6891
7191
|
#endif
|
6892
|
-
src0
|
7192
|
+
src0,
|
6893
7193
|
(device const float *) (src1 + bid*nb11),
|
6894
7194
|
dst + bid*ne0,
|
6895
7195
|
ne00,
|