llama_cpp 0.14.2 → 0.14.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +64 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +91 -21
- data/vendor/tmp/llama.cpp/ggml-alloc.c +14 -5
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +5 -0
- data/vendor/tmp/llama.cpp/ggml-backend.c +155 -125
- data/vendor/tmp/llama.cpp/ggml-backend.h +4 -4
- data/vendor/tmp/llama.cpp/ggml-common.h +25 -2
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +1779 -10762
- data/vendor/tmp/llama.cpp/ggml-cuda.h +6 -15
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +167 -124
- data/vendor/tmp/llama.cpp/ggml-metal.metal +603 -303
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml-quants.c +663 -56
- data/vendor/tmp/llama.cpp/ggml-quants.h +3 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +341 -469
- data/vendor/tmp/llama.cpp/ggml-sycl.h +19 -4
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +37199 -14939
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +335 -307
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +0 -11
- data/vendor/tmp/llama.cpp/ggml.c +229 -107
- data/vendor/tmp/llama.cpp/ggml.h +11 -5
- data/vendor/tmp/llama.cpp/llama.cpp +2136 -464
- data/vendor/tmp/llama.cpp/llama.h +86 -23
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1651 -0
- data/vendor/tmp/llama.cpp/unicode-data.h +16 -0
- data/vendor/tmp/llama.cpp/unicode.cpp +8 -1403
- data/vendor/tmp/llama.cpp/unicode.h +2 -0
- metadata +5 -3
@@ -13,8 +13,8 @@ using namespace metal;
|
|
13
13
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
14
14
|
|
15
15
|
enum ggml_sort_order {
|
16
|
-
|
17
|
-
|
16
|
+
GGML_SORT_ORDER_ASC,
|
17
|
+
GGML_SORT_ORDER_DESC,
|
18
18
|
};
|
19
19
|
|
20
20
|
// general-purpose kernel for addition, multiplication and division of two tensors
|
@@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32(
|
|
1973
1973
|
|
1974
1974
|
// bitonic sort implementation following the CUDA kernels as reference
|
1975
1975
|
typedef void (argsort_t)(
|
1976
|
-
device const float
|
1977
|
-
device int32_t
|
1978
|
-
constant int64_t
|
1976
|
+
device const float * x,
|
1977
|
+
device int32_t * dst,
|
1978
|
+
constant int64_t & ncols,
|
1979
|
+
constant int64_t & ncols_pad,
|
1980
|
+
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
1979
1981
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1980
1982
|
uint3 tpitg[[thread_position_in_threadgroup]]);
|
1981
1983
|
|
@@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
|
|
1984
1986
|
device const float * x,
|
1985
1987
|
device int32_t * dst,
|
1986
1988
|
constant int64_t & ncols,
|
1989
|
+
constant int64_t & ncols_pad,
|
1990
|
+
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
1987
1991
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
1988
1992
|
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
1989
1993
|
// bitonic sort
|
1990
1994
|
int col = tpitg[0];
|
1991
1995
|
int row = tgpig[1];
|
1992
1996
|
|
1993
|
-
if (col >=
|
1997
|
+
if (col >= ncols_pad) return;
|
1994
1998
|
|
1995
|
-
device const float * x_row = x
|
1996
|
-
|
1999
|
+
device const float * x_row = x + row * ncols;
|
2000
|
+
threadgroup int32_t * dst_row = shared_values;
|
1997
2001
|
|
1998
2002
|
// initialize indices
|
1999
|
-
|
2000
|
-
|
2001
|
-
}
|
2003
|
+
dst_row[col] = col;
|
2004
|
+
|
2002
2005
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2003
2006
|
|
2004
|
-
for (int k = 2; k <=
|
2007
|
+
for (int k = 2; k <= ncols_pad; k *= 2) {
|
2005
2008
|
for (int j = k / 2; j > 0; j /= 2) {
|
2006
2009
|
int ixj = col ^ j;
|
2007
2010
|
if (ixj > col) {
|
2008
2011
|
if ((col & k) == 0) {
|
2009
|
-
if (
|
2012
|
+
if (dst_row[col] >= ncols ||
|
2013
|
+
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
2014
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
2015
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
2016
|
+
) {
|
2010
2017
|
SWAP(dst_row[col], dst_row[ixj]);
|
2011
2018
|
}
|
2012
2019
|
} else {
|
2013
|
-
if (
|
2020
|
+
if (dst_row[ixj] >= ncols ||
|
2021
|
+
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
2022
|
+
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
2023
|
+
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
2024
|
+
) {
|
2014
2025
|
SWAP(dst_row[col], dst_row[ixj]);
|
2015
2026
|
}
|
2016
2027
|
}
|
@@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
|
|
2018
2029
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
2019
2030
|
}
|
2020
2031
|
}
|
2032
|
+
|
2033
|
+
// copy the result to dst without the padding
|
2034
|
+
if (col < ncols) {
|
2035
|
+
dst[row * ncols + col] = dst_row[col];
|
2036
|
+
}
|
2021
2037
|
}
|
2022
2038
|
|
2023
|
-
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<
|
2024
|
-
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<
|
2039
|
+
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
2040
|
+
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
2025
2041
|
|
2026
2042
|
kernel void kernel_leaky_relu_f32(
|
2027
2043
|
device const float * src0,
|
@@ -2388,6 +2404,242 @@ kernel void kernel_cpy_f32_q4_1(
|
|
2388
2404
|
}
|
2389
2405
|
}
|
2390
2406
|
|
2407
|
+
// Copies f32 data into Q5_0 blocks, quantizing on the fly.
//
// The threadgroup position selects the source row (i01, i02, i03); the threads
// of the group walk that row in steps of QK5_0 floats, one output block each.
// For every block:
//   - d  is the signed value of largest magnitude divided by -16,
//   - each value is scaled by 1/d, offset by 16.5 and clamped to at most 31,
//   - the low 4 bits of each quant go to qs[] (two per byte), bit 4 goes to qh[].
kernel void kernel_cpy_f32_q5_0(
        device const float * src0,
        device       void  * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row handled by this threadgroup
    const int64_t i01 = tgpig[0];
    const int64_t i02 = tgpig[1];
    const int64_t i03 = tgpig[2];

    // flat element index of the first element of that row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unflatten n against the destination shape, peeling one dimension at a time
    const int64_t i3   = n / (ne2*ne1*ne0);
    const int64_t rem3 = n - i3*ne2*ne1*ne0;
    const int64_t i2   = rem3 / (ne1*ne0);
    const int64_t rem2 = rem3 - i2*ne1*ne0;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = (rem2 - i1*ne0)/QK5_0;

    device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        // find the element with the largest magnitude, keeping its sign
        float amax = 0.0f; // absolute max
        float max  = 0.0f;

        for (int k = 0; k < QK5_0; k++) {
            const float v  = src[k];
            const float av = fabs(v);
            if (av > amax) {
                amax = av;
                max  = v;
            }
        }

        const float d  = max / -16;
        const float id = (d != 0.0f) ? 1.0f/d : 0.0f;

        device block_q5_0 & blk = dst_data[i00/QK5_0];

        blk.d = d;

        // fifth bits: bit j for the first half of the block, bit j + QK5_0/2 for the second
        uint32_t high_bits = 0;
        for (int j = 0; j < QK5_0/2; ++j) {
            const float x0 = src[0       + j]*id;
            const float x1 = src[QK5_0/2 + j]*id;

            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));

            blk.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);

            high_bits |= ((xi0 & 0x10u) >> 4) << (j + 0);
            high_bits |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
        }

        // store the 32 high bits byte by byte, in memory order
        thread const uint8_t * high_bytes = (thread const uint8_t *)&high_bits;
        for (int b = 0; b < 4; ++b) {
            blk.qh[b] = high_bytes[b];
        }
    }
}
|
2479
|
+
|
2480
|
+
// Copies f32 data into Q5_1 blocks, quantizing on the fly.
//
// The threadgroup position selects the source row (i01, i02, i03); the threads
// of the group walk that row in steps of QK5_1 floats, one output block each.
// For every block:
//   - m is the minimum value and d = (max - min) / 31,
//   - each value is mapped to (v - m)/d, rounded by adding 0.5 and truncating,
//   - the low 4 bits of each quant go to qs[] (two per byte), bit 4 goes to qh[].
kernel void kernel_cpy_f32_q5_1(
        device const float * src0,
        device       void  * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row handled by this threadgroup
    const int64_t i01 = tgpig[0];
    const int64_t i02 = tgpig[1];
    const int64_t i03 = tgpig[2];

    // flat element index of the first element of that row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unflatten n against the destination shape, peeling one dimension at a time
    const int64_t i3   = n / (ne2*ne1*ne0);
    const int64_t rem3 = n - i3*ne2*ne1*ne0;
    const int64_t i2   = rem3 / (ne1*ne0);
    const int64_t rem2 = rem3 - i2*ne1*ne0;
    const int64_t i1   = rem2 / ne0;
    const int64_t i0   = (rem2 - i1*ne0)/QK5_1;

    device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        // value range of the block, seeded with its first element
        float max = src[0];
        float min = src[0];

        for (int k = 1; k < QK5_1; k++) {
            const float v = src[k];
            if (v < min) {
                min = v;
            }
            if (v > max) {
                max = v;
            }
        }

        const float d  = (max - min) / 31;
        const float id = (d != 0.0f) ? 1.0f/d : 0.0f;

        device block_q5_1 & blk = dst_data[i00/QK5_1];

        blk.d = d;
        blk.m = min;

        // fifth bits: bit j for the first half of the block, bit j + QK5_1/2 for the second
        uint32_t high_bits = 0;
        for (int j = 0; j < QK5_1/2; ++j) {
            const float x0 = (src[0       + j] - min)*id;
            const float x1 = (src[QK5_1/2 + j] - min)*id;

            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);

            blk.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);

            high_bits |= ((xi0 & 0x10u) >> 4) << (j + 0);
            high_bits |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
        }

        // store the 32 high bits byte by byte, in memory order
        thread const uint8_t * high_bytes = (thread const uint8_t *)&high_bits;
        for (int b = 0; b < 4; ++b) {
            blk.qh[b] = high_bytes[b];
        }
    }
}
|
2551
|
+
|
2552
|
+
// Returns the index of the entry of val[0..n-1] that is closest to x.
//
// Values at or below val[0] map to 0 and values at or above val[n-1] map to
// n-1. Otherwise the bracketing pair is located by bisection and the nearer
// neighbour wins; an exact tie resolves to the upper index.
// NOTE(review): the bisection assumes val is sorted in ascending order - confirm at call sites.
static inline int best_index_int8(int n, constant float * val, float x) {
    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[n-1]) {
        return n-1;
    }

    int lo = 0;
    int hi = n-1;
    while (hi - lo > 1) {
        const int mid = (lo + hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    const float dist_below = x - val[hi-1];
    const float dist_above = val[hi] - x;
    return dist_below < dist_above ? hi-1 : hi;
}
|
2562
|
+
|
2563
|
+
// The 16 quantization levels used by kernel_cpy_f32_iq4_nl, as floats.
// Listed in ascending order; the 4-bit quant stored per value is an index into this table.
constexpr constant static float kvalues_iq4nl_f[16] = {
    -127.f, -104.f, -83.f, -65.f,
     -49.f,  -35.f, -22.f, -10.f,
       1.f,   13.f,  25.f,  38.f,
      53.f,   69.f,  89.f, 113.f
};
|
2566
|
+
|
2567
|
+
// Copies f32 data into IQ4_NL blocks, quantizing on the fly.
//
// The threadgroup position selects the source row (i01, i02, i03); the threads
// of the group walk that row in steps of QK4_NL floats, one output block each.
// For every block:
//   - an initial scale d maps the value of largest magnitude onto kvalues_iq4nl_f[0],
//   - each value is scaled by 1/d and replaced by the index of the nearest table
//     entry (two 4-bit indices per byte in qs[]),
//   - the stored scale is then refit as sum(w*q*x)/sum(w*q*q) with w = x*x, which
//     minimises the weighted squared error for the chosen indices; d is kept as
//     the fallback when that denominator is not positive.
kernel void kernel_cpy_f32_iq4_nl(
        device const float * src0,
        device       void  * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // source row handled by this threadgroup
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // flat element index of the first element of that row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // unflatten n against the destination shape
    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;

    device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        float amax = 0.0f; // absolute max
        float max  = 0.0f;

        // scan exactly one IQ4_NL block: the bound is QK4_NL, the same block size
        // used for every other index in this kernel (not the Q4_0 block size)
        for (int j = 0; j < QK4_NL; j++) {
            const float v = src[j];
            if (amax < fabs(v)) {
                amax = fabs(v);
                max  = v;
            }
        }

        // initial scale: the largest-magnitude value lands on the first table entry
        const float d  = max / kvalues_iq4nl_f[0];
        const float id = d ? 1.0f/d : 0.0f;

        float sumqx = 0, sumq2 = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
            const float x0 = src[0        + j]*id;
            const float x1 = src[QK4_NL/2 + j]*id;

            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);

            // first half of the block in the low nibble, second half in the high nibble
            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);

            // accumulate the weighted least-squares terms for the scale refit
            const float v0 = kvalues_iq4nl_f[xi0];
            const float v1 = kvalues_iq4nl_f[xi1];
            const float w0 = src[0        + j]*src[0        + j];
            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
            sumq2 += w0*v0*v0 + w1*v1*v1;
        }

        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
    }
}
|
2642
|
+
|
2391
2643
|
kernel void kernel_concat(
|
2392
2644
|
device const char * src0,
|
2393
2645
|
device const char * src1,
|
@@ -4220,9 +4472,113 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4220
4472
|
}
|
4221
4473
|
}
|
4222
4474
|
|
4223
|
-
|
4224
|
-
|
4225
|
-
|
4475
|
+
// Matrix-vector product of an IQ1_M quantized matrix (src0) with f32 data (src1).
//
// Each SIMD group accumulates N_DST consecutive output rows starting at
// first_row. The lanes of the group split the row into 32-value sub-blocks
// (lane `tiisg` takes sub-blocks tiisg, tiisg + 32, ...), every lane keeps one
// partial sum per row, and the partial sums are combined with simd_sum(); lane 0
// writes the result.
void kernel_mul_mv_iq1_m_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    // number of quantized blocks per src0 row
    const int nb = ne00/QK_K;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    // first output row of this SIMD group and the block index where it starts
    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row    = first_row * nb;

    const uint i12 = im%ne12;
    const uint i13 = im/ne12;

    // block offset of the src0 matrix selected by (i12/r2, i13/r3)
    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);

    device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[32];
    float sumf[N_DST]={0.f};

    // number of 32-value sub-blocks per row
    const int nb32 = nb * (QK_K / 32);

    const int lane = tiisg;

    device const float * y4 = y + 32 * lane;

#if QK_K != 64
    iq1m_scale_t scale;
#endif

    for (int ib32 = lane; ib32 < nb32; ib32 += 32) {

        // cache this sub-block of y and the sum of each group of 8 values
        float4 sumy = {0.f};
        for (int i = 0; i < 8; ++i) {
            const float y_a = y4[i+ 0];
            yl[i+ 0] = y_a;
            sumy[0] += y_a;

            const float y_b = y4[i+ 8];
            yl[i+ 8] = y_b;
            sumy[1] += y_b;

            const float y_c = y4[i+16];
            yl[i+16] = y_c;
            sumy[2] += y_c;

            const float y_d = y4[i+24];
            yl[i+24] = y_d;
            sumy[3] += y_d;
        }

        // block within the row and 32-value sub-block within that block
        const int ibl = ib32 / (QK_K / 32);
        const int ib  = ib32 % (QK_K / 32);

        device const block_iq1_m * xr = x + ibl;

        device const uint8_t  * qs = xr->qs + 4 * ib;
        device const uint8_t  * qh = xr->qh + 2 * ib;
        device const uint16_t * sc = (device const uint16_t *)xr->scales;

        for (int row = 0; row < N_DST; row++) {

#if QK_K != 64
            // the 16-bit block scale is assembled from the top nibble of each of the four scale words
            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
#endif

            // grid entries: 8 low index bits from qs, 3 more from a nibble of qh
            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));

            float2 sum = {0.f};
            for (int j = 0; j < 4; ++j) {
                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
            }

            // per-group offset: bit 3 / bit 7 of each qh byte selects -1 - delta or -1 + delta
            const float shift0 = qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
            const float shift1 = qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
            const float shift2 = qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
            const float shift3 = qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;

            const float delta1 = sumy[0] * shift0 + sumy[1] * shift1;
            const float delta2 = sumy[2] * shift2 + sumy[3] * shift3;

#if QK_K == 64
            const float d = (float) *((device const half *)(sc - 1));
            sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
                              (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
#else
            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
#endif

            // step all three pointers to the same sub-block of the next row
            sc += nb*sizeof(block_iq1_m)/2;
            qs += nb*sizeof(block_iq1_m);
            qh += nb*sizeof(block_iq1_m);
        }

        // this lane's next sub-block is 32 sub-blocks of 32 values further on
        y4 += 32 * 32;
    }

    for (int row = 0; row < N_DST; ++row) {
        const float all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}
|
4226
4582
|
|
4227
4583
|
void kernel_mul_mv_iq4_nl_f32_impl(
|
4228
4584
|
device const void * src0,
|
@@ -4441,6 +4797,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
4441
4797
|
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
4442
4798
|
}
|
4443
4799
|
|
4800
|
+
// Kernel entry point for the IQ1_M x f32 matrix-vector product.
// Forwards to kernel_mul_mv_iq1_m_f32_impl; the byte strides (nb0x, nb1x) and
// ne11 are part of the argument list but are not passed on.
[[host_name("kernel_mul_mv_iq1_m_f32")]]
kernel void kernel_mul_mv_iq1_m_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq1_m_f32_impl(
            src0, src1, dst,
            ne00, ne01, ne02,
            ne10, ne12,
            ne0, ne1,
            r2, r3,
            tgpig, tiisg, sgitg);
}
|
4827
|
+
|
4444
4828
|
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
4445
4829
|
kernel void kernel_mul_mv_iq4_nl_f32(
|
4446
4830
|
device const void * src0,
|
@@ -4914,6 +5298,38 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
|
|
4914
5298
|
}
|
4915
5299
|
}
|
4916
5300
|
|
5301
|
+
// Dequantizes 16 values of an IQ1_M block into a 4x4 register tile.
//
// il selects the 16-value slice: il/2 is the 32-value sub-block and il%2 the
// half of it (il is 0...15 for QK_K = 256). Rows 0-1 of reg come from the first
// grid entry of the slice, rows 2-3 from the second.
template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
    const int ib32 = il/2;
    const int half_sel = il%2;

    device const uint16_t * sc = (device const uint16_t *)xb->scales;

#if QK_K == 64
    const float d = xb->d;
#else
    // the 16-bit block scale is assembled from the top nibble of each of the four scale words
    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const float d = scale.f16;
#endif

    device const uint8_t * qs = xb->qs + 4*ib32 + 2*half_sel;
    device const uint8_t * qh = xb->qh + 2*ib32 + half_sel;

    // slice scale: block scale times (2*s + 1), s being the packed sub-scale of this slice
#if QK_K == 64
    const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*half_sel)) & 0xf) + 1);
#else
    const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*half_sel)) & 7) + 1);
#endif

    // per-group offset: bit 3 / bit 7 of qh[0] selects -1 - delta or -1 + delta
    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);

    // grid entries: 8 low index bits from qs, 3 more from a nibble of qh[0]
    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));

    for (int i = 0; i < 4; ++i) {
        const uint8_t g1 = grid1[i];
        const uint8_t g2 = grid2[i];

        reg[0][i] = dl * (g1 & 0xf) + ml1;
        reg[1][i] = dl * (g1 >>  4) + ml1;
        reg[2][i] = dl * (g2 & 0xf) + ml2;
        reg[3][i] = dl * (g2 >>  4) + ml2;
    }
}
|
5332
|
+
|
4917
5333
|
template <typename type4x4>
|
4918
5334
|
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
4919
5335
|
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
@@ -5385,9 +5801,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|
5385
5801
|
|
5386
5802
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
5387
5803
|
kernel void kernel_mul_mm_id(
|
5388
|
-
device const uchar *
|
5804
|
+
device const uchar * src0s,
|
5389
5805
|
device const uchar * src1,
|
5390
5806
|
device float * dst,
|
5807
|
+
device const uchar * ids,
|
5391
5808
|
constant uint64_t & nbi1,
|
5392
5809
|
constant int64_t & ne00,
|
5393
5810
|
constant int64_t & ne02,
|
@@ -5404,22 +5821,14 @@ kernel void kernel_mul_mm_id(
|
|
5404
5821
|
constant uint & r2,
|
5405
5822
|
constant uint & r3,
|
5406
5823
|
constant int & idx,
|
5407
|
-
device const uchar * src00,
|
5408
|
-
device const uchar * src01,
|
5409
|
-
device const uchar * src02,
|
5410
|
-
device const uchar * src03,
|
5411
|
-
device const uchar * src04,
|
5412
|
-
device const uchar * src05,
|
5413
|
-
device const uchar * src06,
|
5414
|
-
device const uchar * src07,
|
5415
5824
|
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
5416
5825
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5417
5826
|
uint tiitg[[thread_index_in_threadgroup]],
|
5418
5827
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5419
|
-
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5420
5828
|
|
5421
5829
|
// expert id
|
5422
5830
|
const int32_t id = tgpig.z/(ne12*ne13);
|
5831
|
+
device const uchar * src0 = src0s + id*nb02;
|
5423
5832
|
|
5424
5833
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5425
5834
|
|
@@ -5434,7 +5843,7 @@ kernel void kernel_mul_mm_id(
|
|
5434
5843
|
}
|
5435
5844
|
|
5436
5845
|
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
5437
|
-
|
5846
|
+
src0,
|
5438
5847
|
src1,
|
5439
5848
|
src1ids,
|
5440
5849
|
dst,
|
@@ -5498,6 +5907,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
|
|
5498
5907
|
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5499
5908
|
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5500
5909
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5910
|
+
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5501
5911
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5502
5912
|
#if QK_K == 64
|
5503
5913
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
@@ -5528,24 +5938,25 @@ typedef void (mat_mm_t)(
|
|
5528
5938
|
threadgroup uchar *,
|
5529
5939
|
uint3, uint, uint);
|
5530
5940
|
|
5531
|
-
template [[host_name("kernel_mul_mm_f32_f32")]]
|
5532
|
-
template [[host_name("kernel_mul_mm_f16_f32")]]
|
5533
|
-
template [[host_name("kernel_mul_mm_q4_0_f32")]]
|
5534
|
-
template [[host_name("kernel_mul_mm_q4_1_f32")]]
|
5535
|
-
template [[host_name("kernel_mul_mm_q5_0_f32")]]
|
5536
|
-
template [[host_name("kernel_mul_mm_q5_1_f32")]]
|
5537
|
-
template [[host_name("kernel_mul_mm_q8_0_f32")]]
|
5538
|
-
template [[host_name("kernel_mul_mm_q2_K_f32")]]
|
5539
|
-
template [[host_name("kernel_mul_mm_q3_K_f32")]]
|
5540
|
-
template [[host_name("kernel_mul_mm_q4_K_f32")]]
|
5541
|
-
template [[host_name("kernel_mul_mm_q5_K_f32")]]
|
5542
|
-
template [[host_name("kernel_mul_mm_q6_K_f32")]]
|
5941
|
+
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
5942
|
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
5943
|
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
5944
|
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
5945
|
+
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
5946
|
+
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
5947
|
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
5948
|
+
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
5949
|
+
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
5950
|
+
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
5951
|
+
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
5952
|
+
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
5543
5953
|
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5544
5954
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5545
5955
|
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
5546
5956
|
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5547
5957
|
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5548
5958
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5959
|
+
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5549
5960
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5550
5961
|
#if QK_K == 64
|
5551
5962
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
@@ -5558,9 +5969,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
5558
5969
|
//
|
5559
5970
|
|
5560
5971
|
typedef void (mat_mm_id_t)(
|
5561
|
-
device const uchar *
|
5972
|
+
device const uchar * src0s,
|
5562
5973
|
device const uchar * src1,
|
5563
5974
|
device float * dst,
|
5975
|
+
device const uchar * ids,
|
5564
5976
|
constant uint64_t & nbi1,
|
5565
5977
|
constant int64_t & ne00,
|
5566
5978
|
constant int64_t & ne02,
|
@@ -5577,35 +5989,28 @@ typedef void (mat_mm_id_t)(
|
|
5577
5989
|
constant uint & r2,
|
5578
5990
|
constant uint & r3,
|
5579
5991
|
constant int & idx,
|
5580
|
-
device const uchar * src00,
|
5581
|
-
device const uchar * src01,
|
5582
|
-
device const uchar * src02,
|
5583
|
-
device const uchar * src03,
|
5584
|
-
device const uchar * src04,
|
5585
|
-
device const uchar * src05,
|
5586
|
-
device const uchar * src06,
|
5587
|
-
device const uchar * src07,
|
5588
5992
|
threadgroup uchar *,
|
5589
5993
|
uint3, uint, uint);
|
5590
5994
|
|
5591
|
-
template [[host_name("kernel_mul_mm_id_f32_f32")]]
|
5592
|
-
template [[host_name("kernel_mul_mm_id_f16_f32")]]
|
5593
|
-
template [[host_name("kernel_mul_mm_id_q4_0_f32")]]
|
5594
|
-
template [[host_name("kernel_mul_mm_id_q4_1_f32")]]
|
5595
|
-
template [[host_name("kernel_mul_mm_id_q5_0_f32")]]
|
5596
|
-
template [[host_name("kernel_mul_mm_id_q5_1_f32")]]
|
5597
|
-
template [[host_name("kernel_mul_mm_id_q8_0_f32")]]
|
5598
|
-
template [[host_name("kernel_mul_mm_id_q2_K_f32")]]
|
5599
|
-
template [[host_name("kernel_mul_mm_id_q3_K_f32")]]
|
5600
|
-
template [[host_name("kernel_mul_mm_id_q4_K_f32")]]
|
5601
|
-
template [[host_name("kernel_mul_mm_id_q5_K_f32")]]
|
5602
|
-
template [[host_name("kernel_mul_mm_id_q6_K_f32")]]
|
5995
|
+
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
5996
|
+
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
5997
|
+
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
5998
|
+
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
5999
|
+
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
6000
|
+
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
6001
|
+
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
6002
|
+
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
6003
|
+
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
6004
|
+
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
6005
|
+
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
6006
|
+
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
5603
6007
|
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5604
6008
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5605
6009
|
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
5606
6010
|
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
5607
6011
|
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5608
6012
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
6013
|
+
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
5609
6014
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
5610
6015
|
#if QK_K == 64
|
5611
6016
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
@@ -5619,9 +6024,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|
5619
6024
|
|
5620
6025
|
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
5621
6026
|
kernel void kernel_mul_mv_id_f32_f32(
|
5622
|
-
device const char *
|
6027
|
+
device const char * src0s,
|
5623
6028
|
device const char * src1,
|
5624
6029
|
device float * dst,
|
6030
|
+
device const char * ids,
|
5625
6031
|
constant uint64_t & nbi1,
|
5626
6032
|
constant int64_t & ne00,
|
5627
6033
|
constant int64_t & ne01,
|
@@ -5642,28 +6048,19 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
5642
6048
|
constant uint & r2,
|
5643
6049
|
constant uint & r3,
|
5644
6050
|
constant int & idx,
|
5645
|
-
device const char * src00,
|
5646
|
-
device const char * src01,
|
5647
|
-
device const char * src02,
|
5648
|
-
device const char * src03,
|
5649
|
-
device const char * src04,
|
5650
|
-
device const char * src05,
|
5651
|
-
device const char * src06,
|
5652
|
-
device const char * src07,
|
5653
6051
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5654
6052
|
uint tiitg[[thread_index_in_threadgroup]],
|
5655
6053
|
uint tiisg[[thread_index_in_simdgroup]],
|
5656
6054
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5657
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5658
|
-
|
5659
6055
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5660
6056
|
|
5661
6057
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5662
6058
|
|
5663
6059
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6060
|
+
device const char * src0 = src0s + id*nb02;
|
5664
6061
|
|
5665
6062
|
kernel_mul_mv_f32_f32_impl(
|
5666
|
-
src0
|
6063
|
+
src0,
|
5667
6064
|
src1 + bid*nb11,
|
5668
6065
|
dst + bid*ne0,
|
5669
6066
|
ne00,
|
@@ -5688,9 +6085,10 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|
5688
6085
|
|
5689
6086
|
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
5690
6087
|
kernel void kernel_mul_mv_id_f16_f32(
|
5691
|
-
device const char *
|
6088
|
+
device const char * src0s,
|
5692
6089
|
device const char * src1,
|
5693
6090
|
device float * dst,
|
6091
|
+
device const char * ids,
|
5694
6092
|
constant uint64_t & nbi1,
|
5695
6093
|
constant int64_t & ne00,
|
5696
6094
|
constant int64_t & ne01,
|
@@ -5711,28 +6109,19 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
5711
6109
|
constant uint & r2,
|
5712
6110
|
constant uint & r3,
|
5713
6111
|
constant int & idx,
|
5714
|
-
device const char * src00,
|
5715
|
-
device const char * src01,
|
5716
|
-
device const char * src02,
|
5717
|
-
device const char * src03,
|
5718
|
-
device const char * src04,
|
5719
|
-
device const char * src05,
|
5720
|
-
device const char * src06,
|
5721
|
-
device const char * src07,
|
5722
6112
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5723
6113
|
uint tiitg[[thread_index_in_threadgroup]],
|
5724
6114
|
uint tiisg[[thread_index_in_simdgroup]],
|
5725
6115
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5726
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5727
|
-
|
5728
6116
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5729
6117
|
|
5730
6118
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5731
6119
|
|
5732
6120
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6121
|
+
device const char * src0 = src0s + id*nb02;
|
5733
6122
|
|
5734
6123
|
kernel_mul_mv_f16_f32_impl(
|
5735
|
-
src0
|
6124
|
+
src0,
|
5736
6125
|
src1 + bid*nb11,
|
5737
6126
|
dst + bid*ne0,
|
5738
6127
|
ne00,
|
@@ -5757,9 +6146,10 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|
5757
6146
|
|
5758
6147
|
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
5759
6148
|
kernel void kernel_mul_mv_id_q8_0_f32(
|
5760
|
-
device const char *
|
6149
|
+
device const char * src0s,
|
5761
6150
|
device const char * src1,
|
5762
6151
|
device float * dst,
|
6152
|
+
device const char * ids,
|
5763
6153
|
constant uint64_t & nbi1,
|
5764
6154
|
constant int64_t & ne00,
|
5765
6155
|
constant int64_t & ne01,
|
@@ -5780,28 +6170,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
5780
6170
|
constant uint & r2,
|
5781
6171
|
constant uint & r3,
|
5782
6172
|
constant int & idx,
|
5783
|
-
device const char * src00,
|
5784
|
-
device const char * src01,
|
5785
|
-
device const char * src02,
|
5786
|
-
device const char * src03,
|
5787
|
-
device const char * src04,
|
5788
|
-
device const char * src05,
|
5789
|
-
device const char * src06,
|
5790
|
-
device const char * src07,
|
5791
6173
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5792
6174
|
uint tiitg[[thread_index_in_threadgroup]],
|
5793
6175
|
uint tiisg[[thread_index_in_simdgroup]],
|
5794
6176
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5795
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5796
|
-
|
5797
6177
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5798
6178
|
|
5799
6179
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5800
6180
|
|
5801
6181
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6182
|
+
device const char * src0 = src0s + id*nb02;
|
5802
6183
|
|
5803
6184
|
kernel_mul_mv_q8_0_f32_impl(
|
5804
|
-
src0
|
6185
|
+
src0,
|
5805
6186
|
(device const float *) (src1 + bid*nb11),
|
5806
6187
|
dst + bid*ne0,
|
5807
6188
|
ne00,
|
@@ -5820,9 +6201,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|
5820
6201
|
|
5821
6202
|
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
5822
6203
|
kernel void kernel_mul_mv_id_q4_0_f32(
|
5823
|
-
device const char *
|
6204
|
+
device const char * src0s,
|
5824
6205
|
device const char * src1,
|
5825
6206
|
device float * dst,
|
6207
|
+
device const char * ids,
|
5826
6208
|
constant uint64_t & nbi1,
|
5827
6209
|
constant int64_t & ne00,
|
5828
6210
|
constant int64_t & ne01,
|
@@ -5843,28 +6225,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
5843
6225
|
constant uint & r2,
|
5844
6226
|
constant uint & r3,
|
5845
6227
|
constant int & idx,
|
5846
|
-
device const char * src00,
|
5847
|
-
device const char * src01,
|
5848
|
-
device const char * src02,
|
5849
|
-
device const char * src03,
|
5850
|
-
device const char * src04,
|
5851
|
-
device const char * src05,
|
5852
|
-
device const char * src06,
|
5853
|
-
device const char * src07,
|
5854
6228
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5855
6229
|
uint tiitg[[thread_index_in_threadgroup]],
|
5856
6230
|
uint tiisg[[thread_index_in_simdgroup]],
|
5857
6231
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5858
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5859
|
-
|
5860
6232
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5861
6233
|
|
5862
6234
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5863
6235
|
|
5864
6236
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6237
|
+
device const char * src0 = src0s + id*nb02;
|
5865
6238
|
|
5866
6239
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
5867
|
-
src0
|
6240
|
+
src0,
|
5868
6241
|
(device const float *) (src1 + bid*nb11),
|
5869
6242
|
dst + bid*ne0,
|
5870
6243
|
ne00,
|
@@ -5883,9 +6256,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|
5883
6256
|
|
5884
6257
|
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
5885
6258
|
kernel void kernel_mul_mv_id_q4_1_f32(
|
5886
|
-
device const char *
|
6259
|
+
device const char * src0s,
|
5887
6260
|
device const char * src1,
|
5888
6261
|
device float * dst,
|
6262
|
+
device const char * ids,
|
5889
6263
|
constant uint64_t & nbi1,
|
5890
6264
|
constant int64_t & ne00,
|
5891
6265
|
constant int64_t & ne01,
|
@@ -5906,28 +6280,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
5906
6280
|
constant uint & r2,
|
5907
6281
|
constant uint & r3,
|
5908
6282
|
constant int & idx,
|
5909
|
-
device const char * src00,
|
5910
|
-
device const char * src01,
|
5911
|
-
device const char * src02,
|
5912
|
-
device const char * src03,
|
5913
|
-
device const char * src04,
|
5914
|
-
device const char * src05,
|
5915
|
-
device const char * src06,
|
5916
|
-
device const char * src07,
|
5917
6283
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5918
6284
|
uint tiitg[[thread_index_in_threadgroup]],
|
5919
6285
|
uint tiisg[[thread_index_in_simdgroup]],
|
5920
6286
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5921
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5922
|
-
|
5923
6287
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5924
6288
|
|
5925
6289
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5926
6290
|
|
5927
6291
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6292
|
+
device const char * src0 = src0s + id*nb02;
|
5928
6293
|
|
5929
6294
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
5930
|
-
src0
|
6295
|
+
src0,
|
5931
6296
|
(device const float *) (src1 + bid*nb11),
|
5932
6297
|
dst + bid*ne0,
|
5933
6298
|
ne00,
|
@@ -5946,9 +6311,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|
5946
6311
|
|
5947
6312
|
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
5948
6313
|
kernel void kernel_mul_mv_id_q5_0_f32(
|
5949
|
-
device const char *
|
6314
|
+
device const char * src0s,
|
5950
6315
|
device const char * src1,
|
5951
6316
|
device float * dst,
|
6317
|
+
device const char * ids,
|
5952
6318
|
constant uint64_t & nbi1,
|
5953
6319
|
constant int64_t & ne00,
|
5954
6320
|
constant int64_t & ne01,
|
@@ -5969,28 +6335,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
5969
6335
|
constant uint & r2,
|
5970
6336
|
constant uint & r3,
|
5971
6337
|
constant int & idx,
|
5972
|
-
device const char * src00,
|
5973
|
-
device const char * src01,
|
5974
|
-
device const char * src02,
|
5975
|
-
device const char * src03,
|
5976
|
-
device const char * src04,
|
5977
|
-
device const char * src05,
|
5978
|
-
device const char * src06,
|
5979
|
-
device const char * src07,
|
5980
6338
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
5981
6339
|
uint tiitg[[thread_index_in_threadgroup]],
|
5982
6340
|
uint tiisg[[thread_index_in_simdgroup]],
|
5983
6341
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5984
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
5985
|
-
|
5986
6342
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
5987
6343
|
|
5988
6344
|
tgpig.z = tgpig.z%(ne12*ne13);
|
5989
6345
|
|
5990
6346
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6347
|
+
device const char * src0 = src0s + id*nb02;
|
5991
6348
|
|
5992
6349
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
5993
|
-
src0
|
6350
|
+
src0,
|
5994
6351
|
(device const float *) (src1 + bid*nb11),
|
5995
6352
|
dst + bid*ne0,
|
5996
6353
|
ne00,
|
@@ -6009,9 +6366,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|
6009
6366
|
|
6010
6367
|
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
6011
6368
|
kernel void kernel_mul_mv_id_q5_1_f32(
|
6012
|
-
device const char *
|
6369
|
+
device const char * src0s,
|
6013
6370
|
device const char * src1,
|
6014
6371
|
device float * dst,
|
6372
|
+
device const char * ids,
|
6015
6373
|
constant uint64_t & nbi1,
|
6016
6374
|
constant int64_t & ne00,
|
6017
6375
|
constant int64_t & ne01,
|
@@ -6032,28 +6390,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
6032
6390
|
constant uint & r2,
|
6033
6391
|
constant uint & r3,
|
6034
6392
|
constant int & idx,
|
6035
|
-
device const char * src00,
|
6036
|
-
device const char * src01,
|
6037
|
-
device const char * src02,
|
6038
|
-
device const char * src03,
|
6039
|
-
device const char * src04,
|
6040
|
-
device const char * src05,
|
6041
|
-
device const char * src06,
|
6042
|
-
device const char * src07,
|
6043
6393
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6044
6394
|
uint tiitg[[thread_index_in_threadgroup]],
|
6045
6395
|
uint tiisg[[thread_index_in_simdgroup]],
|
6046
6396
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6047
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6048
|
-
|
6049
6397
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6050
6398
|
|
6051
6399
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6052
6400
|
|
6053
6401
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6402
|
+
device const char * src0 = src0s + id*nb02;
|
6054
6403
|
|
6055
6404
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
6056
|
-
src0
|
6405
|
+
src0,
|
6057
6406
|
(device const float *) (src1 + bid*nb11),
|
6058
6407
|
dst + bid*ne0,
|
6059
6408
|
ne00,
|
@@ -6072,9 +6421,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|
6072
6421
|
|
6073
6422
|
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
6074
6423
|
kernel void kernel_mul_mv_id_q2_K_f32(
|
6075
|
-
device const char *
|
6424
|
+
device const char * src0s,
|
6076
6425
|
device const char * src1,
|
6077
6426
|
device float * dst,
|
6427
|
+
device const char * ids,
|
6078
6428
|
constant uint64_t & nbi1,
|
6079
6429
|
constant int64_t & ne00,
|
6080
6430
|
constant int64_t & ne01,
|
@@ -6095,28 +6445,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
6095
6445
|
constant uint & r2,
|
6096
6446
|
constant uint & r3,
|
6097
6447
|
constant int & idx,
|
6098
|
-
device const char * src00,
|
6099
|
-
device const char * src01,
|
6100
|
-
device const char * src02,
|
6101
|
-
device const char * src03,
|
6102
|
-
device const char * src04,
|
6103
|
-
device const char * src05,
|
6104
|
-
device const char * src06,
|
6105
|
-
device const char * src07,
|
6106
6448
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6107
6449
|
uint tiitg[[thread_index_in_threadgroup]],
|
6108
6450
|
uint tiisg[[thread_index_in_simdgroup]],
|
6109
6451
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6110
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6111
|
-
|
6112
6452
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6113
6453
|
|
6114
6454
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6115
6455
|
|
6116
6456
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6457
|
+
device const char * src0 = src0s + id*nb02;
|
6117
6458
|
|
6118
6459
|
kernel_mul_mv_q2_K_f32_impl(
|
6119
|
-
src0
|
6460
|
+
src0,
|
6120
6461
|
(device const float *) (src1 + bid*nb11),
|
6121
6462
|
dst + bid*ne0,
|
6122
6463
|
ne00,
|
@@ -6135,9 +6476,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|
6135
6476
|
|
6136
6477
|
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
6137
6478
|
kernel void kernel_mul_mv_id_q3_K_f32(
|
6138
|
-
device const char *
|
6479
|
+
device const char * src0s,
|
6139
6480
|
device const char * src1,
|
6140
6481
|
device float * dst,
|
6482
|
+
device const char * ids,
|
6141
6483
|
constant uint64_t & nbi1,
|
6142
6484
|
constant int64_t & ne00,
|
6143
6485
|
constant int64_t & ne01,
|
@@ -6158,28 +6500,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
6158
6500
|
constant uint & r2,
|
6159
6501
|
constant uint & r3,
|
6160
6502
|
constant int & idx,
|
6161
|
-
device const char * src00,
|
6162
|
-
device const char * src01,
|
6163
|
-
device const char * src02,
|
6164
|
-
device const char * src03,
|
6165
|
-
device const char * src04,
|
6166
|
-
device const char * src05,
|
6167
|
-
device const char * src06,
|
6168
|
-
device const char * src07,
|
6169
6503
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6170
6504
|
uint tiitg[[thread_index_in_threadgroup]],
|
6171
6505
|
uint tiisg[[thread_index_in_simdgroup]],
|
6172
6506
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6173
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6174
|
-
|
6175
6507
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6176
6508
|
|
6177
6509
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6178
6510
|
|
6179
6511
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6512
|
+
device const char * src0 = src0s + id*nb02;
|
6180
6513
|
|
6181
6514
|
kernel_mul_mv_q3_K_f32_impl(
|
6182
|
-
src0
|
6515
|
+
src0,
|
6183
6516
|
(device const float *) (src1 + bid*nb11),
|
6184
6517
|
dst + bid*ne0,
|
6185
6518
|
ne00,
|
@@ -6198,9 +6531,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|
6198
6531
|
|
6199
6532
|
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
6200
6533
|
kernel void kernel_mul_mv_id_q4_K_f32(
|
6201
|
-
device const char *
|
6534
|
+
device const char * src0s,
|
6202
6535
|
device const char * src1,
|
6203
6536
|
device float * dst,
|
6537
|
+
device const char * ids,
|
6204
6538
|
constant uint64_t & nbi1,
|
6205
6539
|
constant int64_t & ne00,
|
6206
6540
|
constant int64_t & ne01,
|
@@ -6221,28 +6555,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
6221
6555
|
constant uint & r2,
|
6222
6556
|
constant uint & r3,
|
6223
6557
|
constant int & idx,
|
6224
|
-
device const char * src00,
|
6225
|
-
device const char * src01,
|
6226
|
-
device const char * src02,
|
6227
|
-
device const char * src03,
|
6228
|
-
device const char * src04,
|
6229
|
-
device const char * src05,
|
6230
|
-
device const char * src06,
|
6231
|
-
device const char * src07,
|
6232
6558
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6233
6559
|
uint tiitg[[thread_index_in_threadgroup]],
|
6234
6560
|
uint tiisg[[thread_index_in_simdgroup]],
|
6235
6561
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6236
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6237
|
-
|
6238
6562
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6239
6563
|
|
6240
6564
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6241
6565
|
|
6242
6566
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6567
|
+
device const char * src0 = src0s + id*nb02;
|
6243
6568
|
|
6244
6569
|
kernel_mul_mv_q4_K_f32_impl(
|
6245
|
-
src0
|
6570
|
+
src0,
|
6246
6571
|
(device const float *) (src1 + bid*nb11),
|
6247
6572
|
dst + bid*ne0,
|
6248
6573
|
ne00,
|
@@ -6261,9 +6586,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|
6261
6586
|
|
6262
6587
|
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
6263
6588
|
kernel void kernel_mul_mv_id_q5_K_f32(
|
6264
|
-
device const char *
|
6589
|
+
device const char * src0s,
|
6265
6590
|
device const char * src1,
|
6266
6591
|
device float * dst,
|
6592
|
+
device const char * ids,
|
6267
6593
|
constant uint64_t & nbi1,
|
6268
6594
|
constant int64_t & ne00,
|
6269
6595
|
constant int64_t & ne01,
|
@@ -6284,28 +6610,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
6284
6610
|
constant uint & r2,
|
6285
6611
|
constant uint & r3,
|
6286
6612
|
constant int & idx,
|
6287
|
-
device const char * src00,
|
6288
|
-
device const char * src01,
|
6289
|
-
device const char * src02,
|
6290
|
-
device const char * src03,
|
6291
|
-
device const char * src04,
|
6292
|
-
device const char * src05,
|
6293
|
-
device const char * src06,
|
6294
|
-
device const char * src07,
|
6295
6613
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6296
6614
|
uint tiitg[[thread_index_in_threadgroup]],
|
6297
6615
|
uint tiisg[[thread_index_in_simdgroup]],
|
6298
6616
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6299
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6300
|
-
|
6301
6617
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6302
6618
|
|
6303
6619
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6304
6620
|
|
6305
6621
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6622
|
+
device const char * src0 = src0s + id*nb02;
|
6306
6623
|
|
6307
6624
|
kernel_mul_mv_q5_K_f32_impl(
|
6308
|
-
src0
|
6625
|
+
src0,
|
6309
6626
|
(device const float *) (src1 + bid*nb11),
|
6310
6627
|
dst + bid*ne0,
|
6311
6628
|
ne00,
|
@@ -6324,9 +6641,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|
6324
6641
|
|
6325
6642
|
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
6326
6643
|
kernel void kernel_mul_mv_id_q6_K_f32(
|
6327
|
-
device const char *
|
6644
|
+
device const char * src0s,
|
6328
6645
|
device const char * src1,
|
6329
6646
|
device float * dst,
|
6647
|
+
device const char * ids,
|
6330
6648
|
constant uint64_t & nbi1,
|
6331
6649
|
constant int64_t & ne00,
|
6332
6650
|
constant int64_t & ne01,
|
@@ -6347,28 +6665,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
6347
6665
|
constant uint & r2,
|
6348
6666
|
constant uint & r3,
|
6349
6667
|
constant int & idx,
|
6350
|
-
device const char * src00,
|
6351
|
-
device const char * src01,
|
6352
|
-
device const char * src02,
|
6353
|
-
device const char * src03,
|
6354
|
-
device const char * src04,
|
6355
|
-
device const char * src05,
|
6356
|
-
device const char * src06,
|
6357
|
-
device const char * src07,
|
6358
6668
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6359
6669
|
uint tiitg[[thread_index_in_threadgroup]],
|
6360
6670
|
uint tiisg[[thread_index_in_simdgroup]],
|
6361
6671
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6362
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6363
|
-
|
6364
6672
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6365
6673
|
|
6366
6674
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6367
6675
|
|
6368
6676
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6677
|
+
device const char * src0 = src0s + id*nb02;
|
6369
6678
|
|
6370
6679
|
kernel_mul_mv_q6_K_f32_impl(
|
6371
|
-
src0
|
6680
|
+
src0,
|
6372
6681
|
(device const float *) (src1 + bid*nb11),
|
6373
6682
|
dst + bid*ne0,
|
6374
6683
|
ne00,
|
@@ -6387,9 +6696,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|
6387
6696
|
|
6388
6697
|
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
6389
6698
|
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
6390
|
-
device const char *
|
6699
|
+
device const char * src0s,
|
6391
6700
|
device const char * src1,
|
6392
6701
|
device float * dst,
|
6702
|
+
device const char * ids,
|
6393
6703
|
constant uint64_t & nbi1,
|
6394
6704
|
constant int64_t & ne00,
|
6395
6705
|
constant int64_t & ne01,
|
@@ -6410,29 +6720,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
6410
6720
|
constant uint & r2,
|
6411
6721
|
constant uint & r3,
|
6412
6722
|
constant int & idx,
|
6413
|
-
device const char * src00,
|
6414
|
-
device const char * src01,
|
6415
|
-
device const char * src02,
|
6416
|
-
device const char * src03,
|
6417
|
-
device const char * src04,
|
6418
|
-
device const char * src05,
|
6419
|
-
device const char * src06,
|
6420
|
-
device const char * src07,
|
6421
6723
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6422
6724
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6423
6725
|
uint tiitg[[thread_index_in_threadgroup]],
|
6424
6726
|
uint tiisg[[thread_index_in_simdgroup]],
|
6425
6727
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6426
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6427
|
-
|
6428
6728
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6429
6729
|
|
6430
6730
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6431
6731
|
|
6432
6732
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6733
|
+
device const char * src0 = src0s + id*nb02;
|
6433
6734
|
|
6434
6735
|
kernel_mul_mv_iq2_xxs_f32_impl(
|
6435
|
-
src0
|
6736
|
+
src0,
|
6436
6737
|
(device const float *) (src1 + bid*nb11),
|
6437
6738
|
dst + bid*ne0,
|
6438
6739
|
ne00,
|
@@ -6452,9 +6753,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|
6452
6753
|
|
6453
6754
|
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
6454
6755
|
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
6455
|
-
device const char *
|
6756
|
+
device const char * src0s,
|
6456
6757
|
device const char * src1,
|
6457
6758
|
device float * dst,
|
6759
|
+
device const char * ids,
|
6458
6760
|
constant uint64_t & nbi1,
|
6459
6761
|
constant int64_t & ne00,
|
6460
6762
|
constant int64_t & ne01,
|
@@ -6475,29 +6777,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
6475
6777
|
constant uint & r2,
|
6476
6778
|
constant uint & r3,
|
6477
6779
|
constant int & idx,
|
6478
|
-
device const char * src00,
|
6479
|
-
device const char * src01,
|
6480
|
-
device const char * src02,
|
6481
|
-
device const char * src03,
|
6482
|
-
device const char * src04,
|
6483
|
-
device const char * src05,
|
6484
|
-
device const char * src06,
|
6485
|
-
device const char * src07,
|
6486
6780
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6487
6781
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6488
6782
|
uint tiitg[[thread_index_in_threadgroup]],
|
6489
6783
|
uint tiisg[[thread_index_in_simdgroup]],
|
6490
6784
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6491
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6492
|
-
|
6493
6785
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6494
6786
|
|
6495
6787
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6496
6788
|
|
6497
6789
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6790
|
+
device const char * src0 = src0s + id*nb02;
|
6498
6791
|
|
6499
6792
|
kernel_mul_mv_iq2_xs_f32_impl(
|
6500
|
-
src0
|
6793
|
+
src0,
|
6501
6794
|
(device const float *) (src1 + bid*nb11),
|
6502
6795
|
dst + bid*ne0,
|
6503
6796
|
ne00,
|
@@ -6517,9 +6810,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|
6517
6810
|
|
6518
6811
|
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
6519
6812
|
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
6520
|
-
device const char *
|
6813
|
+
device const char * src0s,
|
6521
6814
|
device const char * src1,
|
6522
6815
|
device float * dst,
|
6816
|
+
device const char * ids,
|
6523
6817
|
constant uint64_t & nbi1,
|
6524
6818
|
constant int64_t & ne00,
|
6525
6819
|
constant int64_t & ne01,
|
@@ -6540,29 +6834,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6540
6834
|
constant uint & r2,
|
6541
6835
|
constant uint & r3,
|
6542
6836
|
constant int & idx,
|
6543
|
-
device const char * src00,
|
6544
|
-
device const char * src01,
|
6545
|
-
device const char * src02,
|
6546
|
-
device const char * src03,
|
6547
|
-
device const char * src04,
|
6548
|
-
device const char * src05,
|
6549
|
-
device const char * src06,
|
6550
|
-
device const char * src07,
|
6551
6837
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6552
6838
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6553
6839
|
uint tiitg[[thread_index_in_threadgroup]],
|
6554
6840
|
uint tiisg[[thread_index_in_simdgroup]],
|
6555
6841
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6556
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6557
|
-
|
6558
6842
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6559
6843
|
|
6560
6844
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6561
6845
|
|
6562
6846
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6847
|
+
device const char * src0 = src0s + id*nb02;
|
6563
6848
|
|
6564
6849
|
kernel_mul_mv_iq3_xxs_f32_impl(
|
6565
|
-
src0
|
6850
|
+
src0,
|
6566
6851
|
(device const float *) (src1 + bid*nb11),
|
6567
6852
|
dst + bid*ne0,
|
6568
6853
|
ne00,
|
@@ -6582,9 +6867,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6582
6867
|
|
6583
6868
|
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
6584
6869
|
kernel void kernel_mul_mv_id_iq3_s_f32(
|
6585
|
-
device const char *
|
6870
|
+
device const char * src0s,
|
6586
6871
|
device const char * src1,
|
6587
6872
|
device float * dst,
|
6873
|
+
device const char * ids,
|
6588
6874
|
constant uint64_t & nbi1,
|
6589
6875
|
constant int64_t & ne00,
|
6590
6876
|
constant int64_t & ne01,
|
@@ -6605,29 +6891,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
6605
6891
|
constant uint & r2,
|
6606
6892
|
constant uint & r3,
|
6607
6893
|
constant int & idx,
|
6608
|
-
device const char * src00,
|
6609
|
-
device const char * src01,
|
6610
|
-
device const char * src02,
|
6611
|
-
device const char * src03,
|
6612
|
-
device const char * src04,
|
6613
|
-
device const char * src05,
|
6614
|
-
device const char * src06,
|
6615
|
-
device const char * src07,
|
6616
6894
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6617
6895
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6618
6896
|
uint tiitg[[thread_index_in_threadgroup]],
|
6619
6897
|
uint tiisg[[thread_index_in_simdgroup]],
|
6620
6898
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6621
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6622
|
-
|
6623
6899
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6624
6900
|
|
6625
6901
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6626
6902
|
|
6627
6903
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6904
|
+
device const char * src0 = src0s + id*nb02;
|
6628
6905
|
|
6629
6906
|
kernel_mul_mv_iq3_s_f32_impl(
|
6630
|
-
src0
|
6907
|
+
src0,
|
6631
6908
|
(device const float *) (src1 + bid*nb11),
|
6632
6909
|
dst + bid*ne0,
|
6633
6910
|
ne00,
|
@@ -6647,9 +6924,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|
6647
6924
|
|
6648
6925
|
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
6649
6926
|
kernel void kernel_mul_mv_id_iq2_s_f32(
|
6650
|
-
device const char *
|
6927
|
+
device const char * src0s,
|
6651
6928
|
device const char * src1,
|
6652
6929
|
device float * dst,
|
6930
|
+
device const char * ids,
|
6653
6931
|
constant uint64_t & nbi1,
|
6654
6932
|
constant int64_t & ne00,
|
6655
6933
|
constant int64_t & ne01,
|
@@ -6670,29 +6948,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
6670
6948
|
constant uint & r2,
|
6671
6949
|
constant uint & r3,
|
6672
6950
|
constant int & idx,
|
6673
|
-
device const char * src00,
|
6674
|
-
device const char * src01,
|
6675
|
-
device const char * src02,
|
6676
|
-
device const char * src03,
|
6677
|
-
device const char * src04,
|
6678
|
-
device const char * src05,
|
6679
|
-
device const char * src06,
|
6680
|
-
device const char * src07,
|
6681
6951
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
6682
6952
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6683
6953
|
uint tiitg[[thread_index_in_threadgroup]],
|
6684
6954
|
uint tiisg[[thread_index_in_simdgroup]],
|
6685
6955
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6686
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6687
|
-
|
6688
6956
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6689
6957
|
|
6690
6958
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6691
6959
|
|
6692
6960
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
6961
|
+
device const char * src0 = src0s + id*nb02;
|
6693
6962
|
|
6694
6963
|
kernel_mul_mv_iq2_s_f32_impl(
|
6695
|
-
src0
|
6964
|
+
src0,
|
6696
6965
|
(device const float *) (src1 + bid*nb11),
|
6697
6966
|
dst + bid*ne0,
|
6698
6967
|
ne00,
|
@@ -6712,9 +6981,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|
6712
6981
|
|
6713
6982
|
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
6714
6983
|
kernel void kernel_mul_mv_id_iq1_s_f32(
|
6715
|
-
device const char *
|
6984
|
+
device const char * src0s,
|
6716
6985
|
device const char * src1,
|
6717
6986
|
device float * dst,
|
6987
|
+
device const char * ids,
|
6718
6988
|
constant uint64_t & nbi1,
|
6719
6989
|
constant int64_t & ne00,
|
6720
6990
|
constant int64_t & ne01,
|
@@ -6735,28 +7005,74 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
6735
7005
|
constant uint & r2,
|
6736
7006
|
constant uint & r3,
|
6737
7007
|
constant int & idx,
|
6738
|
-
device const char * src00,
|
6739
|
-
device const char * src01,
|
6740
|
-
device const char * src02,
|
6741
|
-
device const char * src03,
|
6742
|
-
device const char * src04,
|
6743
|
-
device const char * src05,
|
6744
|
-
device const char * src06,
|
6745
|
-
device const char * src07,
|
6746
7008
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6747
7009
|
uint tiitg[[thread_index_in_threadgroup]],
|
6748
7010
|
uint tiisg[[thread_index_in_simdgroup]],
|
6749
7011
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6750
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6751
|
-
|
6752
7012
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6753
7013
|
|
6754
7014
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6755
7015
|
|
6756
7016
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7017
|
+
device const char * src0 = src0s + id*nb02;
|
6757
7018
|
|
6758
7019
|
kernel_mul_mv_iq1_s_f32_impl(
|
6759
|
-
src0
|
7020
|
+
src0,
|
7021
|
+
(device const float *) (src1 + bid*nb11),
|
7022
|
+
dst + bid*ne0,
|
7023
|
+
ne00,
|
7024
|
+
ne01,
|
7025
|
+
ne02,
|
7026
|
+
ne10,
|
7027
|
+
ne12,
|
7028
|
+
ne0,
|
7029
|
+
ne1,
|
7030
|
+
r2,
|
7031
|
+
r3,
|
7032
|
+
tgpig,
|
7033
|
+
tiisg,
|
7034
|
+
sgitg);
|
7035
|
+
}
|
7036
|
+
|
7037
|
+
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
7038
|
+
kernel void kernel_mul_mv_id_iq1_m_f32(
|
7039
|
+
device const char * src0s,
|
7040
|
+
device const char * src1,
|
7041
|
+
device float * dst,
|
7042
|
+
device const char * ids,
|
7043
|
+
constant uint64_t & nbi1,
|
7044
|
+
constant int64_t & ne00,
|
7045
|
+
constant int64_t & ne01,
|
7046
|
+
constant int64_t & ne02,
|
7047
|
+
constant uint64_t & nb00,
|
7048
|
+
constant uint64_t & nb01,
|
7049
|
+
constant uint64_t & nb02,
|
7050
|
+
constant int64_t & ne10,
|
7051
|
+
constant int64_t & ne11,
|
7052
|
+
constant int64_t & ne12,
|
7053
|
+
constant int64_t & ne13,
|
7054
|
+
constant uint64_t & nb10,
|
7055
|
+
constant uint64_t & nb11,
|
7056
|
+
constant uint64_t & nb12,
|
7057
|
+
constant int64_t & ne0,
|
7058
|
+
constant int64_t & ne1,
|
7059
|
+
constant uint64_t & nb1,
|
7060
|
+
constant uint & r2,
|
7061
|
+
constant uint & r3,
|
7062
|
+
constant int & idx,
|
7063
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
7064
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
7065
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
7066
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
7067
|
+
const int64_t bid = tgpig.z/(ne12*ne13);
|
7068
|
+
|
7069
|
+
tgpig.z = tgpig.z%(ne12*ne13);
|
7070
|
+
|
7071
|
+
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7072
|
+
device const char * src0 = src0s + id*nb02;
|
7073
|
+
|
7074
|
+
kernel_mul_mv_iq1_m_f32_impl(
|
7075
|
+
src0,
|
6760
7076
|
(device const float *) (src1 + bid*nb11),
|
6761
7077
|
dst + bid*ne0,
|
6762
7078
|
ne00,
|
@@ -6775,9 +7091,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|
6775
7091
|
|
6776
7092
|
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
6777
7093
|
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
6778
|
-
device const char *
|
7094
|
+
device const char * src0s,
|
6779
7095
|
device const char * src1,
|
6780
7096
|
device float * dst,
|
7097
|
+
device const char * ids,
|
6781
7098
|
constant uint64_t & nbi1,
|
6782
7099
|
constant int64_t & ne00,
|
6783
7100
|
constant int64_t & ne01,
|
@@ -6798,29 +7115,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
6798
7115
|
constant uint & r2,
|
6799
7116
|
constant uint & r3,
|
6800
7117
|
constant int & idx,
|
6801
|
-
device const char * src00,
|
6802
|
-
device const char * src01,
|
6803
|
-
device const char * src02,
|
6804
|
-
device const char * src03,
|
6805
|
-
device const char * src04,
|
6806
|
-
device const char * src05,
|
6807
|
-
device const char * src06,
|
6808
|
-
device const char * src07,
|
6809
7118
|
threadgroup float * shared_values [[threadgroup(0)]],
|
6810
7119
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6811
7120
|
uint tiitg[[thread_index_in_threadgroup]],
|
6812
7121
|
uint tiisg[[thread_index_in_simdgroup]],
|
6813
7122
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6814
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6815
|
-
|
6816
7123
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6817
7124
|
|
6818
7125
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6819
7126
|
|
6820
7127
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7128
|
+
device const char * src0 = src0s + id*nb02;
|
6821
7129
|
|
6822
7130
|
kernel_mul_mv_iq4_nl_f32_impl(
|
6823
|
-
src0
|
7131
|
+
src0,
|
6824
7132
|
(device const float *) (src1 + bid*nb11),
|
6825
7133
|
dst + bid*ne0,
|
6826
7134
|
ne00,
|
@@ -6840,9 +7148,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
6840
7148
|
|
6841
7149
|
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
6842
7150
|
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
6843
|
-
device const char *
|
7151
|
+
device const char * src0s,
|
6844
7152
|
device const char * src1,
|
6845
7153
|
device float * dst,
|
7154
|
+
device const char * ids,
|
6846
7155
|
constant uint64_t & nbi1,
|
6847
7156
|
constant int64_t & ne00,
|
6848
7157
|
constant int64_t & ne01,
|
@@ -6863,33 +7172,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
|
|
6863
7172
|
constant uint & r2,
|
6864
7173
|
constant uint & r3,
|
6865
7174
|
constant int & idx,
|
6866
|
-
device const char * src00,
|
6867
|
-
device const char * src01,
|
6868
|
-
device const char * src02,
|
6869
|
-
device const char * src03,
|
6870
|
-
device const char * src04,
|
6871
|
-
device const char * src05,
|
6872
|
-
device const char * src06,
|
6873
|
-
device const char * src07,
|
6874
7175
|
threadgroup float * shared_values [[threadgroup(0)]],
|
6875
7176
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
6876
7177
|
uint tiitg[[thread_index_in_threadgroup]],
|
6877
7178
|
uint tiisg[[thread_index_in_simdgroup]],
|
6878
7179
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
6879
|
-
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
6880
|
-
|
6881
7180
|
const int64_t bid = tgpig.z/(ne12*ne13);
|
6882
7181
|
|
6883
7182
|
tgpig.z = tgpig.z%(ne12*ne13);
|
6884
7183
|
|
6885
7184
|
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
7185
|
+
device const char * src0 = src0s + id*nb02;
|
6886
7186
|
|
6887
7187
|
#if QK_K == 64
|
6888
7188
|
kernel_mul_mv_iq4_nl_f32_impl(
|
6889
7189
|
#else
|
6890
7190
|
kernel_mul_mv_iq4_xs_f32_impl(
|
6891
7191
|
#endif
|
6892
|
-
src0
|
7192
|
+
src0,
|
6893
7193
|
(device const float *) (src1 + bid*nb11),
|
6894
7194
|
dst + bid*ne0,
|
6895
7195
|
ne00,
|