llama_cpp 0.10.0 → 0.10.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/llama_cpp.cpp +18 -1
- data/ext/llama_cpp/src/ggml-alloc.c +12 -4
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-backend-impl.h +12 -8
- data/ext/llama_cpp/src/ggml-backend.c +75 -5
- data/ext/llama_cpp/src/ggml-backend.h +7 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +952 -232
- data/ext/llama_cpp/src/ggml-metal.h +3 -0
- data/ext/llama_cpp/src/ggml-metal.m +725 -98
- data/ext/llama_cpp/src/ggml-metal.metal +1508 -171
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +554 -215
- data/ext/llama_cpp/src/ggml.h +58 -23
- data/ext/llama_cpp/src/llama.cpp +1157 -851
- data/ext/llama_cpp/src/llama.h +9 -4
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- metadata +2 -2
| @@ -79,6 +79,7 @@ kernel void kernel_add( | |
| 79 79 | 
             
                    constant  int64_t & nb1,
         | 
| 80 80 | 
             
                    constant  int64_t & nb2,
         | 
| 81 81 | 
             
                    constant  int64_t & nb3,
         | 
| 82 | 
            +
                    constant  int64_t & offs,
         | 
| 82 83 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 83 84 | 
             
                    uint3 tpitg[[thread_position_in_threadgroup]],
         | 
| 84 85 | 
             
                    uint3   ntg[[threads_per_threadgroup]]) {
         | 
| @@ -90,9 +91,9 @@ kernel void kernel_add( | |
| 90 91 | 
             
                const int64_t i12 = i02 % ne12;
         | 
| 91 92 | 
             
                const int64_t i11 = i01 % ne11;
         | 
| 92 93 |  | 
| 93 | 
            -
                device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
         | 
| 94 | 
            +
                device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
         | 
| 94 95 | 
             
                device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
         | 
| 95 | 
            -
                device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
         | 
| 96 | 
            +
                device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
         | 
| 96 97 |  | 
| 97 98 | 
             
                for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
         | 
| 98 99 | 
             
                    const int i10 = i0 % ne10;
         | 
| @@ -204,7 +205,7 @@ kernel void kernel_add_row( | |
| 204 205 | 
             
                    device const float4 * src0,
         | 
| 205 206 | 
             
                    device const float4 * src1,
         | 
| 206 207 | 
             
                    device       float4 * dst,
         | 
| 207 | 
            -
                    constant    int64_t & nb [[buffer( | 
| 208 | 
            +
                    constant    int64_t & nb [[buffer(28)]],
         | 
| 208 209 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 209 210 | 
             
                dst[tpig] = src0[tpig] + src1[tpig % nb];
         | 
| 210 211 | 
             
            }
         | 
| @@ -213,7 +214,7 @@ kernel void kernel_mul_row( | |
| 213 214 | 
             
                    device const float4 * src0,
         | 
| 214 215 | 
             
                    device const float4 * src1,
         | 
| 215 216 | 
             
                    device       float4 * dst,
         | 
| 216 | 
            -
                    constant    int64_t & nb  [[buffer( | 
| 217 | 
            +
                    constant    int64_t & nb  [[buffer(28)]],
         | 
| 217 218 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 218 219 | 
             
                dst[tpig] = src0[tpig] * src1[tpig % nb];
         | 
| 219 220 | 
             
            }
         | 
| @@ -222,7 +223,7 @@ kernel void kernel_div_row( | |
| 222 223 | 
             
                    device const float4 * src0,
         | 
| 223 224 | 
             
                    device const float4 * src1,
         | 
| 224 225 | 
             
                    device       float4 * dst,
         | 
| 225 | 
            -
                    constant    int64_t & nb  [[buffer( | 
| 226 | 
            +
                    constant    int64_t & nb  [[buffer(28)]],
         | 
| 226 227 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 227 228 | 
             
                dst[tpig] = src0[tpig] / src1[tpig % nb];
         | 
| 228 229 | 
             
            }
         | 
| @@ -243,19 +244,53 @@ kernel void kernel_scale_4( | |
| 243 244 | 
             
                dst[tpig] = src0[tpig] * scale;
         | 
| 244 245 | 
             
            }
         | 
| 245 246 |  | 
| 246 | 
            -
            kernel void  | 
| 247 | 
            -
                    device const  | 
| 248 | 
            -
                    device        | 
| 247 | 
            +
            kernel void kernel_relu(
         | 
| 248 | 
            +
                    device const float * src0,
         | 
| 249 | 
            +
                    device       float * dst,
         | 
| 249 250 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 250 | 
            -
                 | 
| 251 | 
            -
                dst[tpig] = x / (1.0f + exp(-x));
         | 
| 251 | 
            +
                dst[tpig] = max(0.0f, src0[tpig]);
         | 
| 252 252 | 
             
            }
         | 
| 253 253 |  | 
| 254 | 
            -
            kernel void  | 
| 254 | 
            +
            kernel void kernel_tanh(
         | 
| 255 255 | 
             
                    device const float * src0,
         | 
| 256 256 | 
             
                    device       float * dst,
         | 
| 257 257 | 
             
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 258 | 
            -
                 | 
| 258 | 
            +
                device const float & x = src0[tpig];
         | 
| 259 | 
            +
                dst[tpig] = precise::tanh(x);
         | 
| 260 | 
            +
            }
         | 
| 261 | 
            +
             | 
| 262 | 
            +
            constant float GELU_COEF_A     = 0.044715f;
         | 
| 263 | 
            +
            constant float GELU_QUICK_COEF = -1.702f;
         | 
| 264 | 
            +
            constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
         | 
| 265 | 
            +
             | 
| 266 | 
            +
            kernel void kernel_gelu(
         | 
| 267 | 
            +
                device const float4 * src0,
         | 
| 268 | 
            +
                device       float4 * dst,
         | 
| 269 | 
            +
                uint tpig[[thread_position_in_grid]]) {
         | 
| 270 | 
            +
                device const float4 & x = src0[tpig];
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                // BEWARE !!!
         | 
| 273 | 
            +
                // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
         | 
| 274 | 
            +
                // This was observed with Falcon 7B and 40B models
         | 
| 275 | 
            +
                //
         | 
| 276 | 
            +
                dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
         | 
| 277 | 
            +
            }
         | 
| 278 | 
            +
             | 
| 279 | 
            +
            kernel void kernel_gelu_quick(
         | 
| 280 | 
            +
                device const float4 * src0,
         | 
| 281 | 
            +
                device       float4 * dst,
         | 
| 282 | 
            +
                uint tpig[[thread_position_in_grid]]) {
         | 
| 283 | 
            +
                device const float4 & x = src0[tpig];
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
         | 
| 286 | 
            +
            }
         | 
| 287 | 
            +
             | 
| 288 | 
            +
            kernel void kernel_silu(
         | 
| 289 | 
            +
                    device const float4 * src0,
         | 
| 290 | 
            +
                    device       float4 * dst,
         | 
| 291 | 
            +
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 292 | 
            +
                device const float4 & x = src0[tpig];
         | 
| 293 | 
            +
                dst[tpig] = x / (1.0f + exp(-x));
         | 
| 259 294 | 
             
            }
         | 
| 260 295 |  | 
| 261 296 | 
             
            kernel void kernel_sqr(
         | 
| @@ -313,22 +348,6 @@ kernel void kernel_sum_rows( | |
| 313 348 | 
             
                dst_row[0] = row_sum;
         | 
| 314 349 | 
             
            }
         | 
| 315 350 |  | 
| 316 | 
            -
            constant float GELU_COEF_A    = 0.044715f;
         | 
| 317 | 
            -
            constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
         | 
| 318 | 
            -
             | 
| 319 | 
            -
            kernel void kernel_gelu(
         | 
| 320 | 
            -
                device const float4 * src0,
         | 
| 321 | 
            -
                device       float4 * dst,
         | 
| 322 | 
            -
                uint tpig[[thread_position_in_grid]]) {
         | 
| 323 | 
            -
                device const float4 & x = src0[tpig];
         | 
| 324 | 
            -
             | 
| 325 | 
            -
                // BEWARE !!!
         | 
| 326 | 
            -
                // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
         | 
| 327 | 
            -
                // This was observed with Falcon 7B and 40B models
         | 
| 328 | 
            -
                //
         | 
| 329 | 
            -
                dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
         | 
| 330 | 
            -
            }
         | 
| 331 | 
            -
             | 
| 332 351 | 
             
            kernel void kernel_soft_max(
         | 
| 333 352 | 
             
                    device const float * src0,
         | 
| 334 353 | 
             
                    device const float * src1,
         | 
| @@ -347,9 +366,9 @@ kernel void kernel_soft_max( | |
| 347 366 | 
             
                const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
         | 
| 348 367 | 
             
                const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
         | 
| 349 368 |  | 
| 350 | 
            -
                device const float * psrc0 = | 
| 351 | 
            -
                device const float * pmask = src1 ? src1 | 
| 352 | 
            -
                device       float * pdst  = | 
| 369 | 
            +
                device const float * psrc0 =         src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
         | 
| 370 | 
            +
                device const float * pmask = src1 != src0 ? src1                               + i01*ne00 : nullptr;
         | 
| 371 | 
            +
                device       float * pdst  =         dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
         | 
| 353 372 |  | 
| 354 373 | 
             
                // parallel max
         | 
| 355 374 | 
             
                float lmax = -INFINITY;
         | 
| @@ -385,7 +404,12 @@ kernel void kernel_soft_max( | |
| 385 404 | 
             
                    pdst[i00] = exp_psrc0;
         | 
| 386 405 | 
             
                }
         | 
| 387 406 |  | 
| 407 | 
            +
                // This barrier fixes a failing test
         | 
| 408 | 
            +
                // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
         | 
| 409 | 
            +
                threadgroup_barrier(mem_flags::mem_none);
         | 
| 410 | 
            +
             | 
| 388 411 | 
             
                float sum = simd_sum(lsum);
         | 
| 412 | 
            +
             | 
| 389 413 | 
             
                if (ntg > N_SIMDWIDTH) {
         | 
| 390 414 | 
             
                    if (sgitg == 0) {
         | 
| 391 415 | 
             
                        buf[tiisg] = 0.0f;
         | 
| @@ -428,9 +452,9 @@ kernel void kernel_soft_max_4( | |
| 428 452 | 
             
                const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
         | 
| 429 453 | 
             
                const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
         | 
| 430 454 |  | 
| 431 | 
            -
                device const float4 * psrc4 = | 
| 432 | 
            -
                device const float4 * pmask = src1 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
         | 
| 433 | 
            -
                device       float4 * pdst4 = | 
| 455 | 
            +
                device const float4 * psrc4 =                (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
         | 
| 456 | 
            +
                device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
         | 
| 457 | 
            +
                device       float4 * pdst4 =                (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
         | 
| 434 458 |  | 
| 435 459 | 
             
                // parallel max
         | 
| 436 460 | 
             
                float4 lmax4 = -INFINITY;
         | 
| @@ -468,7 +492,13 @@ kernel void kernel_soft_max_4( | |
| 468 492 | 
             
                }
         | 
| 469 493 |  | 
| 470 494 | 
             
                const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                // This barrier fixes a failing test
         | 
| 497 | 
            +
                // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
         | 
| 498 | 
            +
                threadgroup_barrier(mem_flags::mem_none);
         | 
| 499 | 
            +
             | 
| 471 500 | 
             
                float sum = simd_sum(lsum);
         | 
| 501 | 
            +
             | 
| 472 502 | 
             
                if (ntg > N_SIMDWIDTH) {
         | 
| 473 503 | 
             
                    if (sgitg == 0) {
         | 
| 474 504 | 
             
                        buf[tiisg] = 0.0f;
         | 
| @@ -639,6 +669,94 @@ kernel void kernel_rms_norm( | |
| 639 669 | 
             
                }
         | 
| 640 670 | 
             
            }
         | 
| 641 671 |  | 
| 672 | 
            +
            kernel void kernel_group_norm(
         | 
| 673 | 
            +
                    device const float * src0,
         | 
| 674 | 
            +
                    device       float * dst,
         | 
| 675 | 
            +
                    constant   int64_t & ne00,
         | 
| 676 | 
            +
                    constant   int64_t & ne01,
         | 
| 677 | 
            +
                    constant   int64_t & ne02,
         | 
| 678 | 
            +
                    constant  uint64_t & nb00,
         | 
| 679 | 
            +
                    constant  uint64_t & nb01,
         | 
| 680 | 
            +
                    constant  uint64_t & nb02,
         | 
| 681 | 
            +
                    constant   int32_t & n_groups,
         | 
| 682 | 
            +
                    constant     float & eps,
         | 
| 683 | 
            +
                    threadgroup float  * buf [[threadgroup(0)]],
         | 
| 684 | 
            +
                    uint tgpig[[threadgroup_position_in_grid]],
         | 
| 685 | 
            +
                    uint tpitg[[thread_position_in_threadgroup]],
         | 
| 686 | 
            +
                    uint sgitg[[simdgroup_index_in_threadgroup]],
         | 
| 687 | 
            +
                    uint tiisg[[thread_index_in_simdgroup]],
         | 
| 688 | 
            +
                    uint   ntg[[threads_per_threadgroup]]) {
         | 
| 689 | 
            +
                const int64_t ne = ne00*ne01*ne02;
         | 
| 690 | 
            +
                const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
         | 
| 691 | 
            +
             | 
| 692 | 
            +
                int start = tgpig * gs;
         | 
| 693 | 
            +
                int end   = start + gs;
         | 
| 694 | 
            +
             | 
| 695 | 
            +
                start += tpitg;
         | 
| 696 | 
            +
             | 
| 697 | 
            +
                if (end >= ne) {
         | 
| 698 | 
            +
                    end = ne;
         | 
| 699 | 
            +
                }
         | 
| 700 | 
            +
             | 
| 701 | 
            +
                float tmp = 0.0f; // partial sum for thread in warp
         | 
| 702 | 
            +
             | 
| 703 | 
            +
                for (int j = start; j < end; j += ntg) {
         | 
| 704 | 
            +
                    tmp += src0[j];
         | 
| 705 | 
            +
                }
         | 
| 706 | 
            +
             | 
| 707 | 
            +
                threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 708 | 
            +
                tmp = simd_sum(tmp);
         | 
| 709 | 
            +
                if (ntg > N_SIMDWIDTH) {
         | 
| 710 | 
            +
                    if (sgitg == 0) {
         | 
| 711 | 
            +
                        buf[tiisg] = 0.0f;
         | 
| 712 | 
            +
                    }
         | 
| 713 | 
            +
             | 
| 714 | 
            +
                    threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                    if (tiisg == 0) {
         | 
| 717 | 
            +
                        buf[sgitg] = tmp;
         | 
| 718 | 
            +
                    }
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                    threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 721 | 
            +
             | 
| 722 | 
            +
                    tmp = buf[tiisg];
         | 
| 723 | 
            +
                    tmp = simd_sum(tmp);
         | 
| 724 | 
            +
                }
         | 
| 725 | 
            +
             | 
| 726 | 
            +
                const float mean = tmp / gs;
         | 
| 727 | 
            +
                tmp = 0.0f;
         | 
| 728 | 
            +
             | 
| 729 | 
            +
                for (int j = start; j < end; j += ntg) {
         | 
| 730 | 
            +
                    float xi = src0[j] - mean;
         | 
| 731 | 
            +
                    dst[j] = xi;
         | 
| 732 | 
            +
                    tmp += xi * xi;
         | 
| 733 | 
            +
                }
         | 
| 734 | 
            +
             | 
| 735 | 
            +
                tmp = simd_sum(tmp);
         | 
| 736 | 
            +
                if (ntg > N_SIMDWIDTH) {
         | 
| 737 | 
            +
                    if (sgitg == 0) {
         | 
| 738 | 
            +
                        buf[tiisg] = 0.0f;
         | 
| 739 | 
            +
                    }
         | 
| 740 | 
            +
             | 
| 741 | 
            +
                    threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 742 | 
            +
             | 
| 743 | 
            +
                    if (tiisg == 0) {
         | 
| 744 | 
            +
                        buf[sgitg] = tmp;
         | 
| 745 | 
            +
                    }
         | 
| 746 | 
            +
             | 
| 747 | 
            +
                    threadgroup_barrier(mem_flags::mem_threadgroup);
         | 
| 748 | 
            +
             | 
| 749 | 
            +
                    tmp = buf[tiisg];
         | 
| 750 | 
            +
                    tmp = simd_sum(tmp);
         | 
| 751 | 
            +
                }
         | 
| 752 | 
            +
             | 
| 753 | 
            +
                const float variance = tmp / gs;
         | 
| 754 | 
            +
                const float scale = 1.0f/sqrt(variance + eps);
         | 
| 755 | 
            +
                for (int j = start; j < end; j += ntg) {
         | 
| 756 | 
            +
                    dst[j] *= scale;
         | 
| 757 | 
            +
                }
         | 
| 758 | 
            +
            }
         | 
| 759 | 
            +
             | 
| 642 760 | 
             
            // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
         | 
| 643 761 | 
             
            // il indicates where the q4 quants begin (0 or QK4_0/4)
         | 
| 644 762 | 
             
            // we assume that the yl's have been multiplied with the appropriate scale factor
         | 
| @@ -731,7 +849,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre | |
| 731 849 | 
             
            //      giard against the number of rows not being divisible by
         | 
| 732 850 | 
             
            //      N_DST, so this is another explicit assumption of the implementation.
         | 
| 733 851 | 
             
            template<typename block_q_type, int nr, int nsg, int nw>
         | 
| 734 | 
            -
            void  | 
| 852 | 
            +
            void mul_vec_q_n_f32_impl(
         | 
| 735 853 | 
             
                    device const void  * src0,
         | 
| 736 854 | 
             
                    device const float * src1,
         | 
| 737 855 | 
             
                    device       float * dst,
         | 
| @@ -813,7 +931,7 @@ kernel void kernel_mul_mv_q4_0_f32( | |
| 813 931 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 814 932 | 
             
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 815 933 | 
             
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 816 | 
            -
                 | 
| 934 | 
            +
                mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
         | 
| 817 935 | 
             
            }
         | 
| 818 936 |  | 
| 819 937 | 
             
            kernel void kernel_mul_mv_q4_1_f32(
         | 
| @@ -832,7 +950,7 @@ kernel void kernel_mul_mv_q4_1_f32( | |
| 832 950 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 833 951 | 
             
                    uint tiisg[[thread_index_in_simdgroup]],
         | 
| 834 952 | 
             
                    uint sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 835 | 
            -
                  | 
| 953 | 
            +
                 mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
         | 
| 836 954 | 
             
            }
         | 
| 837 955 |  | 
| 838 956 | 
             
            kernel void kernel_mul_mv_q5_0_f32(
         | 
| @@ -851,7 +969,7 @@ kernel void kernel_mul_mv_q5_0_f32( | |
| 851 969 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 852 970 | 
             
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 853 971 | 
             
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 854 | 
            -
                 | 
| 972 | 
            +
                mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
         | 
| 855 973 | 
             
            }
         | 
| 856 974 |  | 
| 857 975 | 
             
            kernel void kernel_mul_mv_q5_1_f32(
         | 
| @@ -870,28 +988,28 @@ kernel void kernel_mul_mv_q5_1_f32( | |
| 870 988 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 871 989 | 
             
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 872 990 | 
             
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 873 | 
            -
                 | 
| 991 | 
            +
                mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
         | 
| 874 992 | 
             
            }
         | 
| 875 993 |  | 
| 876 994 |  | 
| 877 995 | 
             
            #define NB_Q8_0 8
         | 
| 878 996 |  | 
| 879 | 
            -
             | 
| 997 | 
            +
            void kernel_mul_mv_q8_0_f32_impl(
         | 
| 880 998 | 
             
                    device const  void * src0,
         | 
| 881 999 | 
             
                    device const float * src1,
         | 
| 882 1000 | 
             
                    device       float * dst,
         | 
| 883 1001 | 
             
                    constant   int64_t & ne00,
         | 
| 884 | 
            -
                    constant   int64_t & ne01 | 
| 885 | 
            -
                    constant   int64_t & ne02 | 
| 886 | 
            -
                    constant   int64_t & ne10 | 
| 887 | 
            -
                    constant   int64_t & ne12 | 
| 888 | 
            -
                    constant   int64_t & ne0 | 
| 889 | 
            -
                    constant   int64_t & ne1 | 
| 890 | 
            -
                    constant   uint    & r2 | 
| 891 | 
            -
                    constant   uint    & r3 | 
| 1002 | 
            +
                    constant   int64_t & ne01,
         | 
| 1003 | 
            +
                    constant   int64_t & ne02,
         | 
| 1004 | 
            +
                    constant   int64_t & ne10,
         | 
| 1005 | 
            +
                    constant   int64_t & ne12,
         | 
| 1006 | 
            +
                    constant   int64_t & ne0,
         | 
| 1007 | 
            +
                    constant   int64_t & ne1,
         | 
| 1008 | 
            +
                    constant   uint    & r2,
         | 
| 1009 | 
            +
                    constant   uint    & r3,
         | 
| 892 1010 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 893 | 
            -
                    uint | 
| 894 | 
            -
                    uint | 
| 1011 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 1012 | 
            +
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 895 1013 | 
             
                const int nr  = N_DST;
         | 
| 896 1014 | 
             
                const int nsg = N_SIMDGROUP;
         | 
| 897 1015 | 
             
                const int nw  = N_SIMDWIDTH;
         | 
| @@ -945,9 +1063,29 @@ kernel void kernel_mul_mv_q8_0_f32( | |
| 945 1063 | 
             
                }
         | 
| 946 1064 | 
             
            }
         | 
| 947 1065 |  | 
| 1066 | 
            +
            [[host_name("kernel_mul_mv_q8_0_f32")]]
         | 
| 1067 | 
            +
            kernel void kernel_mul_mv_q8_0_f32(
         | 
| 1068 | 
            +
                    device const  void * src0,
         | 
| 1069 | 
            +
                    device const float * src1,
         | 
| 1070 | 
            +
                    device       float * dst,
         | 
| 1071 | 
            +
                    constant   int64_t & ne00,
         | 
| 1072 | 
            +
                    constant   int64_t & ne01,
         | 
| 1073 | 
            +
                    constant   int64_t & ne02,
         | 
| 1074 | 
            +
                    constant   int64_t & ne10,
         | 
| 1075 | 
            +
                    constant   int64_t & ne12,
         | 
| 1076 | 
            +
                    constant   int64_t & ne0,
         | 
| 1077 | 
            +
                    constant   int64_t & ne1,
         | 
| 1078 | 
            +
                    constant   uint    & r2   [[buffer(17)]],
         | 
| 1079 | 
            +
                    constant   uint    & r3   [[buffer(18)]],
         | 
| 1080 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1081 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 1082 | 
            +
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 1083 | 
            +
                kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
         | 
| 1084 | 
            +
            }
         | 
| 1085 | 
            +
             | 
| 948 1086 | 
             
            #define N_F32_F32 4
         | 
| 949 1087 |  | 
| 950 | 
            -
             | 
| 1088 | 
            +
            void kernel_mul_mv_f32_f32_impl(
         | 
| 951 1089 | 
             
                    device const  char * src0,
         | 
| 952 1090 | 
             
                    device const  char * src1,
         | 
| 953 1091 | 
             
                    device       float * dst,
         | 
| @@ -965,8 +1103,8 @@ kernel void kernel_mul_mv_f32_f32( | |
| 965 1103 | 
             
                    constant  uint64_t & nb12,
         | 
| 966 1104 | 
             
                    constant   int64_t & ne0,
         | 
| 967 1105 | 
             
                    constant   int64_t & ne1,
         | 
| 968 | 
            -
                    constant   uint    & r2 | 
| 969 | 
            -
                    constant   uint    & r3 | 
| 1106 | 
            +
                    constant   uint    & r2,
         | 
| 1107 | 
            +
                    constant   uint    & r3,
         | 
| 970 1108 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 971 1109 | 
             
                    uint  tiisg[[thread_index_in_simdgroup]]) {
         | 
| 972 1110 |  | 
| @@ -1025,6 +1163,32 @@ kernel void kernel_mul_mv_f32_f32( | |
| 1025 1163 | 
             
                }
         | 
| 1026 1164 | 
             
            }
         | 
| 1027 1165 |  | 
| 1166 | 
            +
            [[host_name("kernel_mul_mv_f32_f32")]]
         | 
| 1167 | 
            +
            kernel void kernel_mul_mv_f32_f32(
         | 
| 1168 | 
            +
                    device const  char * src0,
         | 
| 1169 | 
            +
                    device const  char * src1,
         | 
| 1170 | 
            +
                    device       float * dst,
         | 
| 1171 | 
            +
                    constant   int64_t & ne00,
         | 
| 1172 | 
            +
                    constant   int64_t & ne01,
         | 
| 1173 | 
            +
                    constant   int64_t & ne02,
         | 
| 1174 | 
            +
                    constant  uint64_t & nb00,
         | 
| 1175 | 
            +
                    constant  uint64_t & nb01,
         | 
| 1176 | 
            +
                    constant  uint64_t & nb02,
         | 
| 1177 | 
            +
                    constant   int64_t & ne10,
         | 
| 1178 | 
            +
                    constant   int64_t & ne11,
         | 
| 1179 | 
            +
                    constant   int64_t & ne12,
         | 
| 1180 | 
            +
                    constant  uint64_t & nb10,
         | 
| 1181 | 
            +
                    constant  uint64_t & nb11,
         | 
| 1182 | 
            +
                    constant  uint64_t & nb12,
         | 
| 1183 | 
            +
                    constant   int64_t & ne0,
         | 
| 1184 | 
            +
                    constant   int64_t & ne1,
         | 
| 1185 | 
            +
                    constant   uint    & r2   [[buffer(17)]],
         | 
| 1186 | 
            +
                    constant   uint    & r3   [[buffer(18)]],
         | 
| 1187 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1188 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]]) {
         | 
| 1189 | 
            +
                kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
         | 
| 1190 | 
            +
            }
         | 
| 1191 | 
            +
             | 
| 1028 1192 | 
             
            #define N_F16_F16 4
         | 
| 1029 1193 |  | 
| 1030 1194 | 
             
            kernel void kernel_mul_mv_f16_f16(
         | 
| @@ -1105,7 +1269,7 @@ kernel void kernel_mul_mv_f16_f16( | |
| 1105 1269 | 
             
                }
         | 
| 1106 1270 | 
             
            }
         | 
| 1107 1271 |  | 
| 1108 | 
            -
             | 
| 1272 | 
            +
            void kernel_mul_mv_f16_f32_1row_impl(
         | 
| 1109 1273 | 
             
                    device const  char * src0,
         | 
| 1110 1274 | 
             
                    device const  char * src1,
         | 
| 1111 1275 | 
             
                    device       float * dst,
         | 
| @@ -1123,8 +1287,8 @@ kernel void kernel_mul_mv_f16_f32_1row( | |
| 1123 1287 | 
             
                    constant  uint64_t & nb12,
         | 
| 1124 1288 | 
             
                    constant   int64_t & ne0,
         | 
| 1125 1289 | 
             
                    constant   int64_t & ne1,
         | 
| 1126 | 
            -
                    constant   uint    & r2 | 
| 1127 | 
            -
                    constant   uint    & r3 | 
| 1290 | 
            +
                    constant   uint    & r2,
         | 
| 1291 | 
            +
                    constant   uint    & r3,
         | 
| 1128 1292 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1129 1293 | 
             
                    uint  tiisg[[thread_index_in_simdgroup]]) {
         | 
| 1130 1294 |  | 
| @@ -1161,12 +1325,37 @@ kernel void kernel_mul_mv_f16_f32_1row( | |
| 1161 1325 | 
             
                        dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         | 
| 1162 1326 | 
             
                    }
         | 
| 1163 1327 | 
             
                }
         | 
| 1328 | 
            +
            }
         | 
| 1164 1329 |  | 
| 1330 | 
            +
            [[host_name("kernel_mul_mv_f16_f32_1row")]]
         | 
| 1331 | 
            +
            kernel void kernel_mul_mv_f16_f32_1row(
         | 
| 1332 | 
            +
                    device const  char * src0,
         | 
| 1333 | 
            +
                    device const  char * src1,
         | 
| 1334 | 
            +
                    device       float * dst,
         | 
| 1335 | 
            +
                    constant   int64_t & ne00,
         | 
| 1336 | 
            +
                    constant   int64_t & ne01,
         | 
| 1337 | 
            +
                    constant   int64_t & ne02,
         | 
| 1338 | 
            +
                    constant  uint64_t & nb00,
         | 
| 1339 | 
            +
                    constant  uint64_t & nb01,
         | 
| 1340 | 
            +
                    constant  uint64_t & nb02,
         | 
| 1341 | 
            +
                    constant   int64_t & ne10,
         | 
| 1342 | 
            +
                    constant   int64_t & ne11,
         | 
| 1343 | 
            +
                    constant   int64_t & ne12,
         | 
| 1344 | 
            +
                    constant  uint64_t & nb10,
         | 
| 1345 | 
            +
                    constant  uint64_t & nb11,
         | 
| 1346 | 
            +
                    constant  uint64_t & nb12,
         | 
| 1347 | 
            +
                    constant   int64_t & ne0,
         | 
| 1348 | 
            +
                    constant   int64_t & ne1,
         | 
| 1349 | 
            +
                    constant   uint    & r2   [[buffer(17)]],
         | 
| 1350 | 
            +
                    constant   uint    & r3   [[buffer(18)]],
         | 
| 1351 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1352 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]]) {
         | 
| 1353 | 
            +
                kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
         | 
| 1165 1354 | 
             
            }
         | 
| 1166 1355 |  | 
| 1167 1356 | 
             
            #define N_F16_F32 4
         | 
| 1168 1357 |  | 
| 1169 | 
            -
             | 
| 1358 | 
            +
            void kernel_mul_mv_f16_f32_impl(
         | 
| 1170 1359 | 
             
                    device const  char * src0,
         | 
| 1171 1360 | 
             
                    device const  char * src1,
         | 
| 1172 1361 | 
             
                    device       float * dst,
         | 
| @@ -1184,8 +1373,8 @@ kernel void kernel_mul_mv_f16_f32( | |
| 1184 1373 | 
             
                    constant  uint64_t & nb12,
         | 
| 1185 1374 | 
             
                    constant   int64_t & ne0,
         | 
| 1186 1375 | 
             
                    constant   int64_t & ne1,
         | 
| 1187 | 
            -
                    constant   uint    & r2 | 
| 1188 | 
            -
                    constant   uint    & r3 | 
| 1376 | 
            +
                    constant   uint    & r2,
         | 
| 1377 | 
            +
                    constant   uint    & r3,
         | 
| 1189 1378 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1190 1379 | 
             
                    uint tiisg[[thread_index_in_simdgroup]]) {
         | 
| 1191 1380 |  | 
| @@ -1244,6 +1433,32 @@ kernel void kernel_mul_mv_f16_f32( | |
| 1244 1433 | 
             
                }
         | 
| 1245 1434 | 
             
            }
         | 
| 1246 1435 |  | 
| 1436 | 
            +
            [[host_name("kernel_mul_mv_f16_f32")]]
         | 
| 1437 | 
            +
            kernel void kernel_mul_mv_f16_f32(
         | 
| 1438 | 
            +
                    device const  char * src0,
         | 
| 1439 | 
            +
                    device const  char * src1,
         | 
| 1440 | 
            +
                    device       float * dst,
         | 
| 1441 | 
            +
                    constant   int64_t & ne00,
         | 
| 1442 | 
            +
                    constant   int64_t & ne01,
         | 
| 1443 | 
            +
                    constant   int64_t & ne02,
         | 
| 1444 | 
            +
                    constant  uint64_t & nb00,
         | 
| 1445 | 
            +
                    constant  uint64_t & nb01,
         | 
| 1446 | 
            +
                    constant  uint64_t & nb02,
         | 
| 1447 | 
            +
                    constant   int64_t & ne10,
         | 
| 1448 | 
            +
                    constant   int64_t & ne11,
         | 
| 1449 | 
            +
                    constant   int64_t & ne12,
         | 
| 1450 | 
            +
                    constant  uint64_t & nb10,
         | 
| 1451 | 
            +
                    constant  uint64_t & nb11,
         | 
| 1452 | 
            +
                    constant  uint64_t & nb12,
         | 
| 1453 | 
            +
                    constant   int64_t & ne0,
         | 
| 1454 | 
            +
                    constant   int64_t & ne1,
         | 
| 1455 | 
            +
                    constant   uint    & r2   [[buffer(17)]],
         | 
| 1456 | 
            +
                    constant   uint    & r3   [[buffer(18)]],
         | 
| 1457 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1458 | 
            +
                    uint tiisg[[thread_index_in_simdgroup]]) {
         | 
| 1459 | 
            +
                kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
         | 
| 1460 | 
            +
            }
         | 
| 1461 | 
            +
             | 
| 1247 1462 | 
             
            // Assumes row size (ne00) is a multiple of 4
         | 
| 1248 1463 | 
             
            kernel void kernel_mul_mv_f16_f32_l4(
         | 
| 1249 1464 | 
             
                    device const  char * src0,
         | 
| @@ -1487,8 +1702,9 @@ kernel void kernel_rope( | |
| 1487 1702 | 
             
                        dst_data[1] = x0*sin_theta + x1*cos_theta;
         | 
| 1488 1703 | 
             
                    }
         | 
| 1489 1704 | 
             
                } else {
         | 
| 1490 | 
            -
                    for (int64_t  | 
| 1491 | 
            -
                         | 
| 1705 | 
            +
                    for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
         | 
| 1706 | 
            +
                        if (ic < n_dims) {
         | 
| 1707 | 
            +
                            const int64_t ib = 0;
         | 
| 1492 1708 |  | 
| 1493 1709 | 
             
                            // simplified from `(ib * n_dims + ic) * inv_ndims`
         | 
| 1494 1710 | 
             
                            const float cur_rot = inv_ndims*ic - ib;
         | 
| @@ -1507,6 +1723,14 @@ kernel void kernel_rope( | |
| 1507 1723 |  | 
| 1508 1724 | 
             
                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
         | 
| 1509 1725 | 
             
                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
         | 
| 1726 | 
            +
                        } else {
         | 
| 1727 | 
            +
                            const int64_t i0 = ic;
         | 
| 1728 | 
            +
             | 
| 1729 | 
            +
                            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
         | 
| 1730 | 
            +
                            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
         | 
| 1731 | 
            +
             | 
| 1732 | 
            +
                            dst_data[0] = src[0];
         | 
| 1733 | 
            +
                            dst_data[1] = src[1];
         | 
| 1510 1734 | 
             
                        }
         | 
| 1511 1735 | 
             
                    }
         | 
| 1512 1736 | 
             
                }
         | 
| @@ -1548,21 +1772,112 @@ kernel void kernel_im2col_f16( | |
| 1548 1772 | 
             
                }
         | 
| 1549 1773 | 
             
            }
         | 
| 1550 1774 |  | 
| 1551 | 
            -
             | 
| 1552 | 
            -
             | 
| 1553 | 
            -
                     | 
| 1554 | 
            -
             | 
| 1555 | 
            -
             | 
| 1556 | 
            -
             | 
| 1557 | 
            -
             | 
| 1558 | 
            -
             | 
| 1559 | 
            -
             | 
| 1560 | 
            -
             | 
| 1561 | 
            -
             | 
| 1562 | 
            -
             | 
| 1563 | 
            -
             | 
| 1564 | 
            -
             | 
| 1565 | 
            -
             | 
| 1775 | 
            +
            kernel void kernel_upscale_f32(
         | 
| 1776 | 
            +
                device  const char * src0,
         | 
| 1777 | 
            +
                device        char * dst,
         | 
| 1778 | 
            +
                constant   int64_t & ne00,
         | 
| 1779 | 
            +
                constant   int64_t & ne01,
         | 
| 1780 | 
            +
                constant   int64_t & ne02,
         | 
| 1781 | 
            +
                constant   int64_t & ne03,
         | 
| 1782 | 
            +
                constant  uint64_t & nb00,
         | 
| 1783 | 
            +
                constant  uint64_t & nb01,
         | 
| 1784 | 
            +
                constant  uint64_t & nb02,
         | 
| 1785 | 
            +
                constant  uint64_t & nb03,
         | 
| 1786 | 
            +
                constant   int64_t & ne0,
         | 
| 1787 | 
            +
                constant   int64_t & ne1,
         | 
| 1788 | 
            +
                constant   int64_t & ne2,
         | 
| 1789 | 
            +
                constant   int64_t & ne3,
         | 
| 1790 | 
            +
                constant  uint64_t & nb0,
         | 
| 1791 | 
            +
                constant  uint64_t & nb1,
         | 
| 1792 | 
            +
                constant  uint64_t & nb2,
         | 
| 1793 | 
            +
                constant  uint64_t & nb3,
         | 
| 1794 | 
            +
                constant   int32_t & sf,
         | 
| 1795 | 
            +
                uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1796 | 
            +
                uint3 tpitg[[thread_position_in_threadgroup]],
         | 
| 1797 | 
            +
                uint3   ntg[[threads_per_threadgroup]]) {
         | 
| 1798 | 
            +
             | 
| 1799 | 
            +
                const int64_t i3 = tgpig.z;
         | 
| 1800 | 
            +
                const int64_t i2 = tgpig.y;
         | 
| 1801 | 
            +
                const int64_t i1 = tgpig.x;
         | 
| 1802 | 
            +
             | 
| 1803 | 
            +
                const int64_t i03 = i3;
         | 
| 1804 | 
            +
                const int64_t i02 = i2;
         | 
| 1805 | 
            +
                const int64_t i01 = i1/sf;
         | 
| 1806 | 
            +
             | 
| 1807 | 
            +
                device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
         | 
| 1808 | 
            +
                device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
         | 
| 1809 | 
            +
             | 
| 1810 | 
            +
                for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
         | 
| 1811 | 
            +
                    dst_ptr[i0] = src0_ptr[i0/sf];
         | 
| 1812 | 
            +
                }
         | 
| 1813 | 
            +
            }
         | 
| 1814 | 
            +
             | 
| 1815 | 
            +
            kernel void kernel_pad_f32(
         | 
| 1816 | 
            +
                device  const char * src0,
         | 
| 1817 | 
            +
                device        char * dst,
         | 
| 1818 | 
            +
                constant   int64_t & ne00,
         | 
| 1819 | 
            +
                constant   int64_t & ne01,
         | 
| 1820 | 
            +
                constant   int64_t & ne02,
         | 
| 1821 | 
            +
                constant   int64_t & ne03,
         | 
| 1822 | 
            +
                constant  uint64_t & nb00,
         | 
| 1823 | 
            +
                constant  uint64_t & nb01,
         | 
| 1824 | 
            +
                constant  uint64_t & nb02,
         | 
| 1825 | 
            +
                constant  uint64_t & nb03,
         | 
| 1826 | 
            +
                constant   int64_t & ne0,
         | 
| 1827 | 
            +
                constant   int64_t & ne1,
         | 
| 1828 | 
            +
                constant   int64_t & ne2,
         | 
| 1829 | 
            +
                constant   int64_t & ne3,
         | 
| 1830 | 
            +
                constant  uint64_t & nb0,
         | 
| 1831 | 
            +
                constant  uint64_t & nb1,
         | 
| 1832 | 
            +
                constant  uint64_t & nb2,
         | 
| 1833 | 
            +
                constant  uint64_t & nb3,
         | 
| 1834 | 
            +
                uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1835 | 
            +
                uint3 tpitg[[thread_position_in_threadgroup]],
         | 
| 1836 | 
            +
                uint3   ntg[[threads_per_threadgroup]]) {
         | 
| 1837 | 
            +
             | 
| 1838 | 
            +
                const int64_t i3 = tgpig.z;
         | 
| 1839 | 
            +
                const int64_t i2 = tgpig.y;
         | 
| 1840 | 
            +
                const int64_t i1 = tgpig.x;
         | 
| 1841 | 
            +
             | 
| 1842 | 
            +
                const int64_t i03 = i3;
         | 
| 1843 | 
            +
                const int64_t i02 = i2;
         | 
| 1844 | 
            +
                const int64_t i01 = i1;
         | 
| 1845 | 
            +
             | 
| 1846 | 
            +
                device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
         | 
| 1847 | 
            +
                device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
         | 
| 1848 | 
            +
             | 
| 1849 | 
            +
                if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
         | 
| 1850 | 
            +
                    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
         | 
| 1851 | 
            +
                        if (i0 < ne00) {
         | 
| 1852 | 
            +
                            dst_ptr[i0] = src0_ptr[i0];
         | 
| 1853 | 
            +
                        } else {
         | 
| 1854 | 
            +
                            dst_ptr[i0] = 0.0f;
         | 
| 1855 | 
            +
                        }
         | 
| 1856 | 
            +
                    }
         | 
| 1857 | 
            +
             | 
| 1858 | 
            +
                    return;
         | 
| 1859 | 
            +
                }
         | 
| 1860 | 
            +
             | 
| 1861 | 
            +
                for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
         | 
| 1862 | 
            +
                    dst_ptr[i0] = 0.0f;
         | 
| 1863 | 
            +
                }
         | 
| 1864 | 
            +
            }
         | 
| 1865 | 
            +
             | 
| 1866 | 
            +
            // bitonic sort implementation following the CUDA kernels as reference
         | 
| 1867 | 
            +
            typedef void (argsort_t)(
         | 
| 1868 | 
            +
                    device const float * x,
         | 
| 1869 | 
            +
                    device     int32_t * dst,
         | 
| 1870 | 
            +
                    constant   int64_t & ncols,
         | 
| 1871 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1872 | 
            +
                    uint3 tpitg[[thread_position_in_threadgroup]]);
         | 
| 1873 | 
            +
             | 
| 1874 | 
            +
            template<ggml_sort_order order>
         | 
| 1875 | 
            +
            kernel void kernel_argsort_f32_i32(
         | 
| 1876 | 
            +
                    device const float   * x,
         | 
| 1877 | 
            +
                    device       int32_t * dst,
         | 
| 1878 | 
            +
                    constant     int64_t & ncols,
         | 
| 1879 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1880 | 
            +
                    uint3 tpitg[[thread_position_in_threadgroup]]) {
         | 
| 1566 1881 | 
             
                // bitonic sort
         | 
| 1567 1882 | 
             
                int col = tpitg[0];
         | 
| 1568 1883 | 
             
                int row = tgpig[1];
         | 
| @@ -1600,9 +1915,17 @@ kernel void kernel_argsort_f32_i32( | |
| 1600 1915 | 
             
            template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
         | 
| 1601 1916 | 
             
            template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
         | 
| 1602 1917 |  | 
| 1918 | 
            +
            kernel void kernel_leaky_relu_f32(
         | 
| 1919 | 
            +
                    device const float * src0,
         | 
| 1920 | 
            +
                    device       float * dst,
         | 
| 1921 | 
            +
                    constant     float & slope,
         | 
| 1922 | 
            +
                    uint tpig[[thread_position_in_grid]]) {
         | 
| 1923 | 
            +
                dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
         | 
| 1924 | 
            +
            }
         | 
| 1925 | 
            +
             | 
| 1603 1926 | 
             
            kernel void kernel_cpy_f16_f16(
         | 
| 1604 | 
            -
                    device | 
| 1605 | 
            -
                    device | 
| 1927 | 
            +
                    device  const half * src0,
         | 
| 1928 | 
            +
                    device        half * dst,
         | 
| 1606 1929 | 
             
                    constant   int64_t & ne00,
         | 
| 1607 1930 | 
             
                    constant   int64_t & ne01,
         | 
| 1608 1931 | 
             
                    constant   int64_t & ne02,
         | 
| @@ -1641,6 +1964,47 @@ kernel void kernel_cpy_f16_f16( | |
| 1641 1964 | 
             
                }
         | 
| 1642 1965 | 
             
            }
         | 
| 1643 1966 |  | 
| 1967 | 
            +
            kernel void kernel_cpy_f16_f32(
         | 
| 1968 | 
            +
                    device  const half * src0,
         | 
| 1969 | 
            +
                    device       float * dst,
         | 
| 1970 | 
            +
                    constant   int64_t & ne00,
         | 
| 1971 | 
            +
                    constant   int64_t & ne01,
         | 
| 1972 | 
            +
                    constant   int64_t & ne02,
         | 
| 1973 | 
            +
                    constant   int64_t & ne03,
         | 
| 1974 | 
            +
                    constant  uint64_t & nb00,
         | 
| 1975 | 
            +
                    constant  uint64_t & nb01,
         | 
| 1976 | 
            +
                    constant  uint64_t & nb02,
         | 
| 1977 | 
            +
                    constant  uint64_t & nb03,
         | 
| 1978 | 
            +
                    constant   int64_t & ne0,
         | 
| 1979 | 
            +
                    constant   int64_t & ne1,
         | 
| 1980 | 
            +
                    constant   int64_t & ne2,
         | 
| 1981 | 
            +
                    constant   int64_t & ne3,
         | 
| 1982 | 
            +
                    constant  uint64_t & nb0,
         | 
| 1983 | 
            +
                    constant  uint64_t & nb1,
         | 
| 1984 | 
            +
                    constant  uint64_t & nb2,
         | 
| 1985 | 
            +
                    constant  uint64_t & nb3,
         | 
| 1986 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 1987 | 
            +
                    uint3 tpitg[[thread_position_in_threadgroup]],
         | 
| 1988 | 
            +
                    uint3   ntg[[threads_per_threadgroup]]) {
         | 
| 1989 | 
            +
                const int64_t i03 = tgpig[2];
         | 
| 1990 | 
            +
                const int64_t i02 = tgpig[1];
         | 
| 1991 | 
            +
                const int64_t i01 = tgpig[0];
         | 
| 1992 | 
            +
             | 
| 1993 | 
            +
                const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
         | 
| 1994 | 
            +
             | 
| 1995 | 
            +
                const int64_t i3 = n / (ne2*ne1*ne0);
         | 
| 1996 | 
            +
                const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
         | 
| 1997 | 
            +
                const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
         | 
| 1998 | 
            +
                const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
         | 
| 1999 | 
            +
             | 
| 2000 | 
            +
                device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
         | 
| 2001 | 
            +
             | 
| 2002 | 
            +
                for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         | 
| 2003 | 
            +
                    device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
         | 
| 2004 | 
            +
                    dst_data[i00] = src[0];
         | 
| 2005 | 
            +
                }
         | 
| 2006 | 
            +
            }
         | 
| 2007 | 
            +
             | 
| 1644 2008 | 
             
            kernel void kernel_cpy_f32_f16(
         | 
| 1645 2009 | 
             
                    device const float * src0,
         | 
| 1646 2010 | 
             
                    device        half * dst,
         | 
| @@ -1917,9 +2281,9 @@ kernel void kernel_cpy_f32_q4_1( | |
| 1917 2281 | 
             
            }
         | 
| 1918 2282 |  | 
| 1919 2283 | 
             
            kernel void kernel_concat(
         | 
| 1920 | 
            -
                device | 
| 1921 | 
            -
                device | 
| 1922 | 
            -
                device | 
| 2284 | 
            +
                device  const char * src0,
         | 
| 2285 | 
            +
                device  const char * src1,
         | 
| 2286 | 
            +
                device        char * dst,
         | 
| 1923 2287 | 
             
                constant   int64_t & ne00,
         | 
| 1924 2288 | 
             
                constant   int64_t & ne01,
         | 
| 1925 2289 | 
             
                constant   int64_t & ne02,
         | 
| @@ -1956,7 +2320,7 @@ kernel void kernel_concat( | |
| 1956 2320 | 
             
                const int64_t i12 = i02 % ne12;
         | 
| 1957 2321 | 
             
                const int64_t i11 = i01 % ne11;
         | 
| 1958 2322 |  | 
| 1959 | 
            -
                device const char * src0_ptr = src0 + i03 | 
| 2323 | 
            +
                device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
         | 
| 1960 2324 | 
             
                device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
         | 
| 1961 2325 | 
             
                device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
         | 
| 1962 2326 |  | 
| @@ -2064,19 +2428,19 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { | |
| 2064 2428 |  | 
| 2065 2429 | 
             
            //====================================== dot products =========================
         | 
| 2066 2430 |  | 
| 2067 | 
            -
             | 
| 2431 | 
            +
            void kernel_mul_mv_q2_K_f32_impl(
         | 
| 2068 2432 | 
             
                    device const  void * src0,
         | 
| 2069 2433 | 
             
                    device const float * src1,
         | 
| 2070 2434 | 
             
                    device       float * dst,
         | 
| 2071 2435 | 
             
                    constant   int64_t & ne00,
         | 
| 2072 | 
            -
                    constant   int64_t & ne01 | 
| 2073 | 
            -
                    constant   int64_t & ne02 | 
| 2074 | 
            -
                    constant   int64_t & ne10 | 
| 2075 | 
            -
                    constant   int64_t & ne12 | 
| 2076 | 
            -
                    constant   int64_t & ne0 | 
| 2077 | 
            -
                    constant   int64_t & ne1 | 
| 2078 | 
            -
                    constant   uint    & r2 | 
| 2079 | 
            -
                    constant   uint    & r3 | 
| 2436 | 
            +
                    constant   int64_t & ne01,
         | 
| 2437 | 
            +
                    constant   int64_t & ne02,
         | 
| 2438 | 
            +
                    constant   int64_t & ne10,
         | 
| 2439 | 
            +
                    constant   int64_t & ne12,
         | 
| 2440 | 
            +
                    constant   int64_t & ne0,
         | 
| 2441 | 
            +
                    constant   int64_t & ne1,
         | 
| 2442 | 
            +
                    constant   uint    & r2,
         | 
| 2443 | 
            +
                    constant   uint    & r3,
         | 
| 2080 2444 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 2081 2445 | 
             
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 2082 2446 | 
             
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| @@ -2214,8 +2578,8 @@ kernel void kernel_mul_mv_q2_K_f32( | |
| 2214 2578 | 
             
                }
         | 
| 2215 2579 | 
             
            }
         | 
| 2216 2580 |  | 
| 2217 | 
            -
             | 
| 2218 | 
            -
            kernel void  | 
| 2581 | 
            +
            [[host_name("kernel_mul_mv_q2_K_f32")]]
         | 
| 2582 | 
            +
            kernel void kernel_mul_mv_q2_K_f32(
         | 
| 2219 2583 | 
             
                    device const  void * src0,
         | 
| 2220 2584 | 
             
                    device const float * src1,
         | 
| 2221 2585 | 
             
                    device       float * dst,
         | 
| @@ -2229,8 +2593,29 @@ kernel void kernel_mul_mv_q3_K_f32( | |
| 2229 2593 | 
             
                    constant   uint    & r2  [[buffer(17)]],
         | 
| 2230 2594 | 
             
                    constant   uint    & r3  [[buffer(18)]],
         | 
| 2231 2595 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 2232 | 
            -
                    uint | 
| 2233 | 
            -
                    uint | 
| 2596 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 2597 | 
            +
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 2598 | 
            +
             | 
| 2599 | 
            +
                kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
         | 
| 2600 | 
            +
            }
         | 
| 2601 | 
            +
             | 
| 2602 | 
            +
            #if QK_K == 256
         | 
| 2603 | 
            +
            void kernel_mul_mv_q3_K_f32_impl(
         | 
| 2604 | 
            +
                    device const  void * src0,
         | 
| 2605 | 
            +
                    device const float * src1,
         | 
| 2606 | 
            +
                    device       float * dst,
         | 
| 2607 | 
            +
                    constant   int64_t & ne00,
         | 
| 2608 | 
            +
                    constant   int64_t & ne01,
         | 
| 2609 | 
            +
                    constant   int64_t & ne02,
         | 
| 2610 | 
            +
                    constant   int64_t & ne10,
         | 
| 2611 | 
            +
                    constant   int64_t & ne12,
         | 
| 2612 | 
            +
                    constant   int64_t & ne0,
         | 
| 2613 | 
            +
                    constant   int64_t & ne1,
         | 
| 2614 | 
            +
                    constant   uint    & r2,
         | 
| 2615 | 
            +
                    constant   uint    & r3,
         | 
| 2616 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 2617 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 2618 | 
            +
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 2234 2619 |  | 
| 2235 2620 | 
             
                const int nb = ne00/QK_K;
         | 
| 2236 2621 |  | 
| @@ -2373,19 +2758,19 @@ kernel void kernel_mul_mv_q3_K_f32( | |
| 2373 2758 | 
             
                }
         | 
| 2374 2759 | 
             
            }
         | 
| 2375 2760 | 
             
            #else
         | 
| 2376 | 
            -
             | 
| 2761 | 
            +
            void kernel_mul_mv_q3_K_f32_impl(
         | 
| 2377 2762 | 
             
                    device const  void * src0,
         | 
| 2378 2763 | 
             
                    device const float * src1,
         | 
| 2379 2764 | 
             
                    device       float * dst,
         | 
| 2380 2765 | 
             
                    constant   int64_t & ne00,
         | 
| 2381 | 
            -
                    constant   int64_t & ne01 | 
| 2382 | 
            -
                    constant   int64_t & ne02 | 
| 2383 | 
            -
                    constant   int64_t & ne10 | 
| 2384 | 
            -
                    constant   int64_t & ne12 | 
| 2385 | 
            -
                    constant   int64_t & ne0 | 
| 2386 | 
            -
                    constant   int64_t & ne1 | 
| 2387 | 
            -
                    constant   uint    & r2 | 
| 2388 | 
            -
                    constant   uint    & r3 | 
| 2766 | 
            +
                    constant   int64_t & ne01,
         | 
| 2767 | 
            +
                    constant   int64_t & ne02,
         | 
| 2768 | 
            +
                    constant   int64_t & ne10,
         | 
| 2769 | 
            +
                    constant   int64_t & ne12,
         | 
| 2770 | 
            +
                    constant   int64_t & ne0,
         | 
| 2771 | 
            +
                    constant   int64_t & ne1,
         | 
| 2772 | 
            +
                    constant   uint    & r2,
         | 
| 2773 | 
            +
                    constant   uint    & r3,
         | 
| 2389 2774 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 2390 2775 | 
             
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 2391 2776 | 
             
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| @@ -2450,20 +2835,41 @@ kernel void kernel_mul_mv_q3_K_f32( | |
| 2450 2835 | 
             
            }
         | 
| 2451 2836 | 
             
            #endif
         | 
| 2452 2837 |  | 
| 2838 | 
            +
            [[host_name("kernel_mul_mv_q3_K_f32")]]
         | 
| 2839 | 
            +
            kernel void kernel_mul_mv_q3_K_f32(
         | 
| 2840 | 
            +
                    device const  void * src0,
         | 
| 2841 | 
            +
                    device const float * src1,
         | 
| 2842 | 
            +
                    device       float * dst,
         | 
| 2843 | 
            +
                    constant   int64_t & ne00,
         | 
| 2844 | 
            +
                    constant   int64_t & ne01[[buffer(4)]],
         | 
| 2845 | 
            +
                    constant   int64_t & ne02[[buffer(5)]],
         | 
| 2846 | 
            +
                    constant   int64_t & ne10[[buffer(9)]],
         | 
| 2847 | 
            +
                    constant   int64_t & ne12[[buffer(11)]],
         | 
| 2848 | 
            +
                    constant   int64_t & ne0 [[buffer(15)]],
         | 
| 2849 | 
            +
                    constant   int64_t & ne1 [[buffer(16)]],
         | 
| 2850 | 
            +
                    constant   uint    & r2  [[buffer(17)]],
         | 
| 2851 | 
            +
                    constant   uint    & r3  [[buffer(18)]],
         | 
| 2852 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 2853 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 2854 | 
            +
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 2855 | 
            +
             | 
| 2856 | 
            +
                kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
         | 
| 2857 | 
            +
            }
         | 
| 2858 | 
            +
             | 
| 2453 2859 | 
             
            #if QK_K == 256
         | 
| 2454 | 
            -
             | 
| 2860 | 
            +
            void kernel_mul_mv_q4_K_f32_impl(
         | 
| 2455 2861 | 
             
                    device const  void * src0,
         | 
| 2456 2862 | 
             
                    device const float * src1,
         | 
| 2457 2863 | 
             
                    device       float * dst,
         | 
| 2458 2864 | 
             
                    constant   int64_t & ne00,
         | 
| 2459 | 
            -
                    constant   int64_t & ne01 | 
| 2460 | 
            -
                    constant   int64_t & ne02 | 
| 2461 | 
            -
                    constant   int64_t & ne10 | 
| 2462 | 
            -
                    constant   int64_t & ne12 | 
| 2463 | 
            -
                    constant   int64_t & ne0 | 
| 2464 | 
            -
                    constant   int64_t & ne1 | 
| 2465 | 
            -
                    constant   uint    & r2 | 
| 2466 | 
            -
                    constant   uint    & r3 | 
| 2865 | 
            +
                    constant   int64_t & ne01,
         | 
| 2866 | 
            +
                    constant   int64_t & ne02,
         | 
| 2867 | 
            +
                    constant   int64_t & ne10,
         | 
| 2868 | 
            +
                    constant   int64_t & ne12,
         | 
| 2869 | 
            +
                    constant   int64_t & ne0,
         | 
| 2870 | 
            +
                    constant   int64_t & ne1,
         | 
| 2871 | 
            +
                    constant   uint    & r2,
         | 
| 2872 | 
            +
                    constant   uint    & r3,
         | 
| 2467 2873 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 2468 2874 | 
             
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 2469 2875 | 
             
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| @@ -2564,19 +2970,19 @@ kernel void kernel_mul_mv_q4_K_f32( | |
| 2564 2970 | 
             
                }
         | 
| 2565 2971 | 
             
            }
         | 
| 2566 2972 | 
             
            #else
         | 
| 2567 | 
            -
             | 
| 2973 | 
            +
            void kernel_mul_mv_q4_K_f32_impl(
         | 
| 2568 2974 | 
             
                    device const  void * src0,
         | 
| 2569 2975 | 
             
                    device const float * src1,
         | 
| 2570 2976 | 
             
                    device       float * dst,
         | 
| 2571 2977 | 
             
                    constant   int64_t & ne00,
         | 
| 2572 | 
            -
                    constant   int64_t & ne01 | 
| 2573 | 
            -
                    constant   int64_t & ne02 | 
| 2574 | 
            -
                    constant   int64_t & ne10 | 
| 2575 | 
            -
                    constant   int64_t & ne12 | 
| 2576 | 
            -
                    constant   int64_t & ne0 | 
| 2577 | 
            -
                    constant   int64_t & ne1 | 
| 2578 | 
            -
                    constant   uint    & r2 | 
| 2579 | 
            -
                    constant   uint    & r3 | 
| 2978 | 
            +
                    constant   int64_t & ne01,
         | 
| 2979 | 
            +
                    constant   int64_t & ne02,
         | 
| 2980 | 
            +
                    constant   int64_t & ne10,
         | 
| 2981 | 
            +
                    constant   int64_t & ne12,
         | 
| 2982 | 
            +
                    constant   int64_t & ne0,
         | 
| 2983 | 
            +
                    constant   int64_t & ne1,
         | 
| 2984 | 
            +
                    constant   uint    & r2,
         | 
| 2985 | 
            +
                    constant   uint    & r3,
         | 
| 2580 2986 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 2581 2987 | 
             
                    uint tiisg[[thread_index_in_simdgroup]],
         | 
| 2582 2988 | 
             
                    uint sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| @@ -2660,7 +3066,8 @@ kernel void kernel_mul_mv_q4_K_f32( | |
| 2660 3066 | 
             
            }
         | 
| 2661 3067 | 
             
            #endif
         | 
| 2662 3068 |  | 
| 2663 | 
            -
             | 
| 3069 | 
            +
            [[host_name("kernel_mul_mv_q4_K_f32")]]
         | 
| 3070 | 
            +
            kernel void kernel_mul_mv_q4_K_f32(
         | 
| 2664 3071 | 
             
                    device const  void * src0,
         | 
| 2665 3072 | 
             
                    device const float * src1,
         | 
| 2666 3073 | 
             
                    device       float * dst,
         | 
| @@ -2677,6 +3084,26 @@ kernel void kernel_mul_mv_q5_K_f32( | |
| 2677 3084 | 
             
                    uint tiisg[[thread_index_in_simdgroup]],
         | 
| 2678 3085 | 
             
                    uint sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 2679 3086 |  | 
| 3087 | 
            +
                kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
         | 
| 3088 | 
            +
            }
         | 
| 3089 | 
            +
             | 
| 3090 | 
            +
            void kernel_mul_mv_q5_K_f32_impl(
         | 
| 3091 | 
            +
                    device const  void * src0,
         | 
| 3092 | 
            +
                    device const float * src1,
         | 
| 3093 | 
            +
                    device       float * dst,
         | 
| 3094 | 
            +
                    constant   int64_t & ne00,
         | 
| 3095 | 
            +
                    constant   int64_t & ne01,
         | 
| 3096 | 
            +
                    constant   int64_t & ne02,
         | 
| 3097 | 
            +
                    constant   int64_t & ne10,
         | 
| 3098 | 
            +
                    constant   int64_t & ne12,
         | 
| 3099 | 
            +
                    constant   int64_t & ne0,
         | 
| 3100 | 
            +
                    constant   int64_t & ne1,
         | 
| 3101 | 
            +
                    constant   uint    & r2,
         | 
| 3102 | 
            +
                    constant   uint    & r3,
         | 
| 3103 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 3104 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 3105 | 
            +
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 3106 | 
            +
             | 
| 2680 3107 | 
             
                const int nb = ne00/QK_K;
         | 
| 2681 3108 |  | 
| 2682 3109 | 
             
                const int64_t r0 = tgpig.x;
         | 
| @@ -2836,10 +3263,10 @@ kernel void kernel_mul_mv_q5_K_f32( | |
| 2836 3263 | 
             
                        dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
         | 
| 2837 3264 | 
             
                    }
         | 
| 2838 3265 | 
             
                }
         | 
| 2839 | 
            -
             | 
| 2840 3266 | 
             
            }
         | 
| 2841 3267 |  | 
| 2842 | 
            -
             | 
| 3268 | 
            +
            [[host_name("kernel_mul_mv_q5_K_f32")]]
         | 
| 3269 | 
            +
            kernel void kernel_mul_mv_q5_K_f32(
         | 
| 2843 3270 | 
             
                    device const  void * src0,
         | 
| 2844 3271 | 
             
                    device const float * src1,
         | 
| 2845 3272 | 
             
                    device       float * dst,
         | 
| @@ -2853,21 +3280,41 @@ kernel void kernel_mul_mv_q6_K_f32( | |
| 2853 3280 | 
             
                    constant   uint    & r2  [[buffer(17)]],
         | 
| 2854 3281 | 
             
                    constant   uint    & r3  [[buffer(18)]],
         | 
| 2855 3282 | 
             
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 2856 | 
            -
                    uint | 
| 2857 | 
            -
                    uint | 
| 2858 | 
            -
             | 
| 2859 | 
            -
                const uint8_t kmask1 = 0x03;
         | 
| 2860 | 
            -
                const uint8_t kmask2 = 0x0C;
         | 
| 2861 | 
            -
                const uint8_t kmask3 = 0x30;
         | 
| 2862 | 
            -
                const uint8_t kmask4 = 0xC0;
         | 
| 2863 | 
            -
             | 
| 2864 | 
            -
                const int nb = ne00/QK_K;
         | 
| 2865 | 
            -
             | 
| 2866 | 
            -
                const int64_t r0 = tgpig.x;
         | 
| 2867 | 
            -
                const int64_t r1 = tgpig.y;
         | 
| 2868 | 
            -
                const int     im = tgpig.z;
         | 
| 3283 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 3284 | 
            +
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 2869 3285 |  | 
| 2870 | 
            -
                 | 
| 3286 | 
            +
                kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
         | 
| 3287 | 
            +
            }
         | 
| 3288 | 
            +
             | 
| 3289 | 
            +
            void kernel_mul_mv_q6_K_f32_impl(
         | 
| 3290 | 
            +
                    device const  void * src0,
         | 
| 3291 | 
            +
                    device const float * src1,
         | 
| 3292 | 
            +
                    device       float * dst,
         | 
| 3293 | 
            +
                    constant   int64_t & ne00,
         | 
| 3294 | 
            +
                    constant   int64_t & ne01,
         | 
| 3295 | 
            +
                    constant   int64_t & ne02,
         | 
| 3296 | 
            +
                    constant   int64_t & ne10,
         | 
| 3297 | 
            +
                    constant   int64_t & ne12,
         | 
| 3298 | 
            +
                    constant   int64_t & ne0,
         | 
| 3299 | 
            +
                    constant   int64_t & ne1,
         | 
| 3300 | 
            +
                    constant   uint    & r2,
         | 
| 3301 | 
            +
                    constant   uint    & r3,
         | 
| 3302 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 3303 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 3304 | 
            +
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 3305 | 
            +
             | 
| 3306 | 
            +
                const uint8_t kmask1 = 0x03;
         | 
| 3307 | 
            +
                const uint8_t kmask2 = 0x0C;
         | 
| 3308 | 
            +
                const uint8_t kmask3 = 0x30;
         | 
| 3309 | 
            +
                const uint8_t kmask4 = 0xC0;
         | 
| 3310 | 
            +
             | 
| 3311 | 
            +
                const int nb = ne00/QK_K;
         | 
| 3312 | 
            +
             | 
| 3313 | 
            +
                const int64_t r0 = tgpig.x;
         | 
| 3314 | 
            +
                const int64_t r1 = tgpig.y;
         | 
| 3315 | 
            +
                const int     im = tgpig.z;
         | 
| 3316 | 
            +
             | 
| 3317 | 
            +
                const int row = 2 * r0 + sgitg;
         | 
| 2871 3318 |  | 
| 2872 3319 | 
             
                const uint i12 = im%ne12;
         | 
| 2873 3320 | 
             
                const uint i13 = im/ne12;
         | 
| @@ -2945,6 +3392,27 @@ kernel void kernel_mul_mv_q6_K_f32( | |
| 2945 3392 | 
             
                }
         | 
| 2946 3393 | 
             
            }
         | 
| 2947 3394 |  | 
| 3395 | 
            +
            [[host_name("kernel_mul_mv_q6_K_f32")]]
         | 
| 3396 | 
            +
            kernel void kernel_mul_mv_q6_K_f32(
         | 
| 3397 | 
            +
                    device const  void * src0,
         | 
| 3398 | 
            +
                    device const float * src1,
         | 
| 3399 | 
            +
                    device       float * dst,
         | 
| 3400 | 
            +
                    constant   int64_t & ne00,
         | 
| 3401 | 
            +
                    constant   int64_t & ne01[[buffer(4)]],
         | 
| 3402 | 
            +
                    constant   int64_t & ne02[[buffer(5)]],
         | 
| 3403 | 
            +
                    constant   int64_t & ne10[[buffer(9)]],
         | 
| 3404 | 
            +
                    constant   int64_t & ne12[[buffer(11)]],
         | 
| 3405 | 
            +
                    constant   int64_t & ne0 [[buffer(15)]],
         | 
| 3406 | 
            +
                    constant   int64_t & ne1 [[buffer(16)]],
         | 
| 3407 | 
            +
                    constant   uint    & r2  [[buffer(17)]],
         | 
| 3408 | 
            +
                    constant   uint    & r3  [[buffer(18)]],
         | 
| 3409 | 
            +
                    uint3 tgpig[[threadgroup_position_in_grid]],
         | 
| 3410 | 
            +
                    uint  tiisg[[thread_index_in_simdgroup]],
         | 
| 3411 | 
            +
                    uint  sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 3412 | 
            +
             | 
| 3413 | 
            +
                kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
         | 
| 3414 | 
            +
            }
         | 
| 3415 | 
            +
             | 
| 2948 3416 | 
             
            //============================= templates and their specializations =============================
         | 
| 2949 3417 |  | 
| 2950 3418 | 
             
            // NOTE: this is not dequantizing - we are simply fitting the template
         | 
| @@ -3062,10 +3530,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg | |
| 3062 3530 |  | 
| 3063 3531 | 
             
            template <typename type4x4>
         | 
| 3064 3532 | 
             
            void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
         | 
| 3065 | 
            -
                const  | 
| 3066 | 
            -
                const  | 
| 3533 | 
            +
                const float d = xb->d;
         | 
| 3534 | 
            +
                const float min = xb->dmin;
         | 
| 3067 3535 | 
             
                device const uint8_t * q = (device const uint8_t *)xb->qs;
         | 
| 3068 | 
            -
                 | 
| 3536 | 
            +
                float dl, ml;
         | 
| 3069 3537 | 
             
                uint8_t sc = xb->scales[il];
         | 
| 3070 3538 |  | 
| 3071 3539 | 
             
            #if QK_K == 256
         | 
| @@ -3135,10 +3603,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg | |
| 3135 3603 | 
             
                q = q + (il/4) * 32 + 16 * (il&1);
         | 
| 3136 3604 | 
             
                il = il & 3;
         | 
| 3137 3605 | 
             
                const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
         | 
| 3138 | 
            -
                const  | 
| 3139 | 
            -
                const  | 
| 3140 | 
            -
                const  | 
| 3141 | 
            -
                const  | 
| 3606 | 
            +
                const float d   = il < 2 ? xb->d : xb->d / 16.h;
         | 
| 3607 | 
            +
                const float min = xb->dmin;
         | 
| 3608 | 
            +
                const float dl = d * sc[0];
         | 
| 3609 | 
            +
                const float ml = min * sc[1];
         | 
| 3142 3610 | 
             
            #else
         | 
| 3143 3611 | 
             
                q = q + 16 * (il&1);
         | 
| 3144 3612 | 
             
                device const uint8_t * s = xb->scales;
         | 
| @@ -3165,13 +3633,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg | |
| 3165 3633 | 
             
                uint8_t ul = 1 << (il/2);
         | 
| 3166 3634 | 
             
                il = il & 3;
         | 
| 3167 3635 | 
             
                const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
         | 
| 3168 | 
            -
                const  | 
| 3169 | 
            -
                const  | 
| 3170 | 
            -
                const  | 
| 3171 | 
            -
                const  | 
| 3636 | 
            +
                const float d = il < 2 ? xb->d : xb->d / 16.h;
         | 
| 3637 | 
            +
                const float min = xb->dmin;
         | 
| 3638 | 
            +
                const float dl = d * sc[0];
         | 
| 3639 | 
            +
                const float ml = min * sc[1];
         | 
| 3172 3640 |  | 
| 3173 | 
            -
                const ushort mask | 
| 3174 | 
            -
                const  | 
| 3641 | 
            +
                const ushort mask  = il<2 ? 0x0F : 0xF0;
         | 
| 3642 | 
            +
                const float qh_val = il<2 ? 16.f : 256.f;
         | 
| 3175 3643 | 
             
                for (int i = 0; i < 16; ++i) {
         | 
| 3176 3644 | 
             
                    reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
         | 
| 3177 3645 | 
             
                }
         | 
| @@ -3219,22 +3687,90 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg | |
| 3219 3687 | 
             
            template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
         | 
| 3220 3688 | 
             
            kernel void kernel_get_rows(
         | 
| 3221 3689 | 
             
                    device const  void * src0,
         | 
| 3222 | 
            -
                    device const | 
| 3690 | 
            +
                    device const  char * src1,
         | 
| 3223 3691 | 
             
                    device       float * dst,
         | 
| 3224 3692 | 
             
                    constant   int64_t & ne00,
         | 
| 3225 3693 | 
             
                    constant  uint64_t & nb01,
         | 
| 3694 | 
            +
                    constant  uint64_t & nb02,
         | 
| 3695 | 
            +
                    constant   int64_t & ne10,
         | 
| 3696 | 
            +
                    constant  uint64_t & nb10,
         | 
| 3697 | 
            +
                    constant  uint64_t & nb11,
         | 
| 3226 3698 | 
             
                    constant  uint64_t & nb1,
         | 
| 3227 | 
            -
                     | 
| 3699 | 
            +
                    constant  uint64_t & nb2,
         | 
| 3700 | 
            +
                    uint3                tgpig[[threadgroup_position_in_grid]],
         | 
| 3228 3701 | 
             
                    uint                 tiitg[[thread_index_in_threadgroup]],
         | 
| 3229 | 
            -
                     | 
| 3230 | 
            -
                const  | 
| 3231 | 
            -
                const  | 
| 3702 | 
            +
                    uint3                tptg [[threads_per_threadgroup]]) {
         | 
| 3703 | 
            +
                //const int64_t i = tgpig;
         | 
| 3704 | 
            +
                //const int64_t r = ((device int32_t *) src1)[i];
         | 
| 3705 | 
            +
             | 
| 3706 | 
            +
                const int64_t i10 = tgpig.x;
         | 
| 3707 | 
            +
                const int64_t i11 = tgpig.y;
         | 
| 3708 | 
            +
             | 
| 3709 | 
            +
                const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
         | 
| 3232 3710 |  | 
| 3233 | 
            -
                 | 
| 3711 | 
            +
                const int64_t i02 = i11;
         | 
| 3712 | 
            +
             | 
| 3713 | 
            +
                for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
         | 
| 3234 3714 | 
             
                    float4x4 temp;
         | 
| 3235 3715 | 
             
                    dequantize_func(
         | 
| 3236 | 
            -
                        ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
         | 
| 3237 | 
            -
                    *(((device float4x4 *) ((device char *) dst +  | 
| 3716 | 
            +
                        ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
         | 
| 3717 | 
            +
                    *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
         | 
| 3718 | 
            +
                }
         | 
| 3719 | 
            +
            }
         | 
| 3720 | 
            +
             | 
| 3721 | 
            +
            kernel void kernel_get_rows_f32(
         | 
| 3722 | 
            +
                    device const  void * src0,
         | 
| 3723 | 
            +
                    device const  char * src1,
         | 
| 3724 | 
            +
                    device       float * dst,
         | 
| 3725 | 
            +
                    constant   int64_t & ne00,
         | 
| 3726 | 
            +
                    constant  uint64_t & nb01,
         | 
| 3727 | 
            +
                    constant  uint64_t & nb02,
         | 
| 3728 | 
            +
                    constant   int64_t & ne10,
         | 
| 3729 | 
            +
                    constant  uint64_t & nb10,
         | 
| 3730 | 
            +
                    constant  uint64_t & nb11,
         | 
| 3731 | 
            +
                    constant  uint64_t & nb1,
         | 
| 3732 | 
            +
                    constant  uint64_t & nb2,
         | 
| 3733 | 
            +
                    uint3                tgpig[[threadgroup_position_in_grid]],
         | 
| 3734 | 
            +
                    uint                 tiitg[[thread_index_in_threadgroup]],
         | 
| 3735 | 
            +
                    uint3                tptg [[threads_per_threadgroup]]) {
         | 
| 3736 | 
            +
                const int64_t i10 = tgpig.x;
         | 
| 3737 | 
            +
                const int64_t i11 = tgpig.y;
         | 
| 3738 | 
            +
             | 
| 3739 | 
            +
                const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
         | 
| 3740 | 
            +
             | 
| 3741 | 
            +
                const int64_t i02 = i11;
         | 
| 3742 | 
            +
             | 
| 3743 | 
            +
                for (int ind = tiitg; ind < ne00; ind += tptg.x) {
         | 
| 3744 | 
            +
                    ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
         | 
| 3745 | 
            +
                        ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
         | 
| 3746 | 
            +
                }
         | 
| 3747 | 
            +
            }
         | 
| 3748 | 
            +
             | 
| 3749 | 
            +
            // Row gather for a half-precision source: every threadgroup copies one row of
            // src0 into dst, converting half -> float while copying.
            //
            //   src0       - source data, read as half elements
            //   src1       - row indices, read as int32_t
            //   dst        - destination, written as float
            //   ne00       - number of elements copied per row
            //   nb01, nb02 - byte strides applied to src0 (by the fetched row index and by i02)
            //   ne10       - not referenced by this kernel
            //   nb10, nb11 - byte strides applied to src1 (by tgpig.x and tgpig.y)
            //   nb1,  nb2  - byte strides applied to dst  (by tgpig.x and tgpig.y)
            //
            // The threads of a threadgroup share the row: thread tiitg handles elements
            // tiitg, tiitg + tptg.x, tiitg + 2*tptg.x, ...
            kernel void kernel_get_rows_f16(
                    device const  void * src0,
                    device const  char * src1,
                    device       float * dst,
                    constant   int64_t & ne00,
                    constant  uint64_t & nb01,
                    constant  uint64_t & nb02,
                    constant   int64_t & ne10,
                    constant  uint64_t & nb10,
                    constant  uint64_t & nb11,
                    constant  uint64_t & nb1,
                    constant  uint64_t & nb2,
                    uint3                tgpig[[threadgroup_position_in_grid]],
                    uint                 tiitg[[thread_index_in_threadgroup]],
                    uint3                tptg [[threads_per_threadgroup]]) {
                const int64_t i10 = tgpig.x;
                const int64_t i11 = tgpig.y;

                // index of the src0 row to fetch; src1 is only read, so the const qualifier is kept
                const int64_t r = ((device const int32_t *) (src1 + i11*nb11 + i10*nb10))[0];

                // the second src0 index follows the second src1 index
                const int64_t i02 = i11;

                // the row base addresses do not depend on the element index -> compute them once
                device const half  * src_row = (device const half  *) ((device const char *) src0 + r*nb01 + i02*nb02);
                device       float * dst_row = (device       float *) ((device       char *) dst  + i11*nb2 + i10*nb1);

                // 64-bit counter: the bound ne00 is int64_t, a 32-bit int could overflow before reaching it
                for (int64_t ind = tiitg; ind < ne00; ind += tptg.x) {
                    dst_row[ind] = src_row[ind];
                }
            }
         | 
| 3240 3776 |  | 
| @@ -3426,19 +3962,22 @@ kernel void kernel_mul_mm(device const  uchar * src0, | |
| 3426 3962 |  | 
| 3427 3963 | 
             
            template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
         | 
| 3428 3964 | 
             
            kernel void kernel_mul_mm_id(
         | 
| 3429 | 
            -
                    device const  | 
| 3965 | 
            +
                    device const   uchar * ids,
         | 
| 3430 3966 | 
             
                    device const   uchar * src1,
         | 
| 3431 | 
            -
                    device          | 
| 3967 | 
            +
                    device         uchar * dst,
         | 
| 3968 | 
            +
                    constant     int64_t & nbi1,
         | 
| 3432 3969 | 
             
                    constant     int64_t & ne00,
         | 
| 3433 3970 | 
             
                    constant     int64_t & ne02,
         | 
| 3434 3971 | 
             
                    constant     int64_t & nb01,
         | 
| 3435 3972 | 
             
                    constant     int64_t & nb02,
         | 
| 3436 3973 | 
             
                    constant     int64_t & ne12,
         | 
| 3974 | 
            +
                    constant     int64_t & ne13,
         | 
| 3437 3975 | 
             
                    constant     int64_t & nb10,
         | 
| 3438 3976 | 
             
                    constant     int64_t & nb11,
         | 
| 3439 3977 | 
             
                    constant     int64_t & nb12,
         | 
| 3440 3978 | 
             
                    constant     int64_t & ne0,
         | 
| 3441 3979 | 
             
                    constant     int64_t & ne1,
         | 
| 3980 | 
            +
                    constant     int64_t & nb1,
         | 
| 3442 3981 | 
             
                    constant        uint & r2,
         | 
| 3443 3982 | 
             
                    constant        uint & r3,
         | 
| 3444 3983 | 
             
                    constant         int & idx,
         | 
| @@ -3456,10 +3995,16 @@ kernel void kernel_mul_mm_id( | |
| 3456 3995 | 
             
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 3457 3996 | 
             
                device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
         | 
| 3458 3997 |  | 
| 3998 | 
            +
                const int64_t bid = tgpig.z/(ne12*ne13);
         | 
| 3999 | 
            +
             | 
| 4000 | 
            +
                tgpig.z = tgpig.z%(ne12*ne13);
         | 
| 4001 | 
            +
             | 
| 4002 | 
            +
                const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
         | 
| 4003 | 
            +
             | 
| 3459 4004 | 
             
                kernel_mul_mm_impl<block_q, nl, dequantize_func>(
         | 
| 3460 | 
            -
                    src0[ | 
| 3461 | 
            -
                    src1,
         | 
| 3462 | 
            -
                    dst,
         | 
| 4005 | 
            +
                    src0[id],
         | 
| 4006 | 
            +
                    src1 + bid*nb11,
         | 
| 4007 | 
            +
                    (device float *) (dst + bid*nb1),
         | 
| 3463 4008 | 
             
                    ne00,
         | 
| 3464 4009 | 
             
                    ne02,
         | 
| 3465 4010 | 
             
                    nb01,
         | 
| @@ -3484,17 +4029,26 @@ kernel void kernel_mul_mm_id( | |
| 3484 4029 | 
             
            #define QK_NL 4
         | 
| 3485 4030 | 
             
            #endif
         | 
| 3486 4031 |  | 
| 4032 | 
            +
            //
         | 
| 4033 | 
            +
            // get rows
         | 
| 4034 | 
            +
            //
         | 
| 4035 | 
            +
             | 
| 3487 4036 | 
             
            typedef void (get_rows_t)(
         | 
| 3488 4037 | 
             
                    device const void * src0,
         | 
| 3489 | 
            -
                    device const | 
| 4038 | 
            +
                    device const char * src1,
         | 
| 3490 4039 | 
             
                    device      float * dst,
         | 
| 3491 4040 | 
             
                    constant  int64_t & ne00,
         | 
| 3492 4041 | 
             
                    constant uint64_t & nb01,
         | 
| 4042 | 
            +
                    constant uint64_t & nb02,
         | 
| 4043 | 
            +
                    constant  int64_t & ne10,
         | 
| 4044 | 
            +
                    constant uint64_t & nb10,
         | 
| 4045 | 
            +
                    constant uint64_t & nb11,
         | 
| 3493 4046 | 
             
                    constant uint64_t & nb1,
         | 
| 3494 | 
            -
                     | 
| 4047 | 
            +
                    constant uint64_t & nb2,
         | 
| 4048 | 
            +
                    uint3, uint, uint3);
         | 
| 3495 4049 |  | 
| 3496 | 
            -
            template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
         | 
| 3497 | 
            -
            template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
         | 
| 4050 | 
            +
            //template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
         | 
| 4051 | 
            +
            //template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
         | 
| 3498 4052 | 
             
            template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
         | 
| 3499 4053 | 
             
            template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
         | 
| 3500 4054 | 
             
            template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
         | 
| @@ -3506,6 +4060,10 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows | |
| 3506 4060 | 
             
            template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
         | 
| 3507 4061 | 
             
            template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
         | 
| 3508 4062 |  | 
| 4063 | 
            +
            //
         | 
| 4064 | 
            +
            // matrix-matrix multiplication
         | 
| 4065 | 
            +
            //
         | 
| 4066 | 
            +
             | 
| 3509 4067 | 
             
            typedef void (mat_mm_t)(
         | 
| 3510 4068 | 
             
                    device const  uchar * src0,
         | 
| 3511 4069 | 
             
                    device const  uchar * src1,
         | 
| @@ -3538,20 +4096,27 @@ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<b | |
| 3538 4096 | 
             
            template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
         | 
| 3539 4097 | 
             
            template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
         | 
| 3540 4098 |  | 
| 4099 | 
            +
            //
         | 
| 4100 | 
            +
            // indirect matrix-matrix multiplication
         | 
| 4101 | 
            +
            //
         | 
| 4102 | 
            +
             | 
| 3541 4103 | 
             
            typedef void (mat_mm_id_t)(
         | 
| 3542 | 
            -
                    device const  | 
| 4104 | 
            +
                    device const   uchar * ids,
         | 
| 3543 4105 | 
             
                    device const   uchar * src1,
         | 
| 3544 | 
            -
                    device          | 
| 4106 | 
            +
                    device         uchar * dst,
         | 
| 4107 | 
            +
                    constant     int64_t & nbi1,
         | 
| 3545 4108 | 
             
                    constant     int64_t & ne00,
         | 
| 3546 4109 | 
             
                    constant     int64_t & ne02,
         | 
| 3547 4110 | 
             
                    constant     int64_t & nb01,
         | 
| 3548 4111 | 
             
                    constant     int64_t & nb02,
         | 
| 3549 4112 | 
             
                    constant     int64_t & ne12,
         | 
| 4113 | 
            +
                    constant     int64_t & ne13,
         | 
| 3550 4114 | 
             
                    constant     int64_t & nb10,
         | 
| 3551 4115 | 
             
                    constant     int64_t & nb11,
         | 
| 3552 4116 | 
             
                    constant     int64_t & nb12,
         | 
| 3553 4117 | 
             
                    constant     int64_t & ne0,
         | 
| 3554 4118 | 
             
                    constant     int64_t & ne1,
         | 
| 4119 | 
            +
                    constant     int64_t & nb1,
         | 
| 3555 4120 | 
             
                    constant        uint & r2,
         | 
| 3556 4121 | 
             
                    constant        uint & r3,
         | 
| 3557 4122 | 
             
                    constant         int & idx,
         | 
| @@ -3578,3 +4143,775 @@ template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mu | |
| 3578 4143 | 
             
            template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
         | 
| 3579 4144 | 
             
            template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
         | 
| 3580 4145 | 
             
            template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
         | 
| 4146 | 
            +
             | 
| 4147 | 
            +
            //
         | 
| 4148 | 
            +
            // indirect matrix-vector multiplication
         | 
| 4149 | 
            +
            //
         | 
| 4150 | 
            +
             | 
| 4151 | 
            +
            // Indirect matrix-vector multiplication, f32 x f32.
            //
            // Picks one of the eight candidate matrices src00..src07 using an int32 value
            // read from ids (element idx of the ids row), offsets src1 and dst to the same
            // row, and forwards everything to kernel_mul_mv_f32_f32_impl.
            //
            // tgpig.z carries two values: tgpig.z / (ne12*ne13) selects the row of
            // ids / src1 / dst, and tgpig.z % (ne12*ne13) is what the implementation
            // receives as its z position.
            //
            // tiitg and sgitg are part of the signature but are not referenced here.
            // NOTE(review): the value read from ids is used to index an 8-entry array
            // without a range check - presumably the host guarantees 0 <= id < 8; confirm.
            [[host_name("kernel_mul_mv_id_f32_f32")]]
            kernel void kernel_mul_mv_id_f32_f32(
                    device const    char * ids,
                    device const    char * src1,
                    device         uchar * dst,
                    constant     int64_t & nbi1,
                    constant     int64_t & ne00,
                    constant     int64_t & ne01,
                    constant     int64_t & ne02,
                    constant    uint64_t & nb00,
                    constant    uint64_t & nb01,
                    constant    uint64_t & nb02,
                    constant     int64_t & ne10,
                    constant     int64_t & ne11,
                    constant     int64_t & ne12,
                    constant     int64_t & ne13,
                    constant    uint64_t & nb10,
                    constant    uint64_t & nb11,
                    constant    uint64_t & nb12,
                    constant     int64_t & ne0,
                    constant     int64_t & ne1,
                    constant     int64_t & nb1,
                    constant        uint & r2,
                    constant        uint & r3,
                    constant         int & idx,
                    device const    char * src00,
                    device const    char * src01,
                    device const    char * src02,
                    device const    char * src03,
                    device const    char * src04,
                    device const    char * src05,
                    device const    char * src06,
                    device const    char * src07,
                    uint3                  tgpig[[threadgroup_position_in_grid]],
                    uint                   tiitg[[thread_index_in_threadgroup]],
                    uint                   tiisg[[thread_index_in_simdgroup]],
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
                // candidate matrices, selectable by the value stored in ids
                device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

                // split tgpig.z into (row, z position seen by the implementation)
                const int64_t nz  = ne12*ne13;
                const int64_t row = tgpig.z/nz;

                tgpig.z = tgpig.z%nz;

                // entry idx of the selected ids row names the matrix to use
                device const int32_t * ids_row = (device const int32_t *) (ids + row*nbi1);
                const int32_t which = ids_row[idx];

                // src1 and dst advance to the same row
                device const char  * src1_row = src1 + row*nb11;
                device       float * dst_row  = (device float *) (dst + row*nb1);

                kernel_mul_mv_f32_f32_impl(
                    candidates[which], src1_row, dst_row,
                    ne00, ne01, ne02,
                    nb00, nb01, nb02,
                    ne10, ne11, ne12,
                    nb10, nb11, nb12,
                    ne0,  ne1,
                    r2,   r3,
                    tgpig, tiisg);
            }
         | 
| 4219 | 
            +
             | 
| 4220 | 
            +
            // Indirect matrix-vector multiplication, f16 x f32.
            //
            // Picks one of the eight candidate matrices src00..src07 using an int32 value
            // read from ids (element idx of the ids row), offsets src1 and dst to the same
            // row, and forwards everything to kernel_mul_mv_f16_f32_impl.
            //
            // tgpig.z carries two values: tgpig.z / (ne12*ne13) selects the row of
            // ids / src1 / dst, and tgpig.z % (ne12*ne13) is what the implementation
            // receives as its z position.
            //
            // tiitg and sgitg are part of the signature but are not referenced here.
            // NOTE(review): the value read from ids is used to index an 8-entry array
            // without a range check - presumably the host guarantees 0 <= id < 8; confirm.
            [[host_name("kernel_mul_mv_id_f16_f32")]]
            kernel void kernel_mul_mv_id_f16_f32(
                    device const    char * ids,
                    device const    char * src1,
                    device         uchar * dst,
                    constant     int64_t & nbi1,
                    constant     int64_t & ne00,
                    constant     int64_t & ne01,
                    constant     int64_t & ne02,
                    constant    uint64_t & nb00,
                    constant    uint64_t & nb01,
                    constant    uint64_t & nb02,
                    constant     int64_t & ne10,
                    constant     int64_t & ne11,
                    constant     int64_t & ne12,
                    constant     int64_t & ne13,
                    constant    uint64_t & nb10,
                    constant    uint64_t & nb11,
                    constant    uint64_t & nb12,
                    constant     int64_t & ne0,
                    constant     int64_t & ne1,
                    constant     int64_t & nb1,
                    constant        uint & r2,
                    constant        uint & r3,
                    constant         int & idx,
                    device const    char * src00,
                    device const    char * src01,
                    device const    char * src02,
                    device const    char * src03,
                    device const    char * src04,
                    device const    char * src05,
                    device const    char * src06,
                    device const    char * src07,
                    uint3                  tgpig[[threadgroup_position_in_grid]],
                    uint                   tiitg[[thread_index_in_threadgroup]],
                    uint                   tiisg[[thread_index_in_simdgroup]],
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
                // candidate matrices, selectable by the value stored in ids
                device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

                // split tgpig.z into (row, z position seen by the implementation)
                const int64_t nz  = ne12*ne13;
                const int64_t row = tgpig.z/nz;

                tgpig.z = tgpig.z%nz;

                // entry idx of the selected ids row names the matrix to use
                device const int32_t * ids_row = (device const int32_t *) (ids + row*nbi1);
                const int32_t which = ids_row[idx];

                // src1 and dst advance to the same row
                device const char  * src1_row = src1 + row*nb11;
                device       float * dst_row  = (device float *) (dst + row*nb1);

                kernel_mul_mv_f16_f32_impl(
                    candidates[which], src1_row, dst_row,
                    ne00, ne01, ne02,
                    nb00, nb01, nb02,
                    ne10, ne11, ne12,
                    nb10, nb11, nb12,
                    ne0,  ne1,
                    r2,   r3,
                    tgpig, tiisg);
            }
         | 
| 4288 | 
            +
             | 
| 4289 | 
            +
            // Indirect matrix-vector multiplication, q8_0 x f32.
            //
            // Picks one of the eight candidate matrices src00..src07 using an int32 value
            // read from ids (element idx of the ids row), offsets src1 and dst to the same
            // row, and forwards to kernel_mul_mv_q8_0_f32_impl.
            //
            // tgpig.z carries two values: tgpig.z / (ne12*ne13) selects the row of
            // ids / src1 / dst, and tgpig.z % (ne12*ne13) is what the implementation
            // receives as its z position.
            //
            // Only ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3 are forwarded; the nb0x /
            // nb1x strides (apart from nb11, used for the src1 offset), ne11 and tiitg are
            // not otherwise referenced here.
            // NOTE(review): the value read from ids is used to index an 8-entry array
            // without a range check - presumably the host guarantees 0 <= id < 8; confirm.
            [[host_name("kernel_mul_mv_id_q8_0_f32")]]
            kernel void kernel_mul_mv_id_q8_0_f32(
                    device const    char * ids,
                    device const    char * src1,
                    device         uchar * dst,
                    constant     int64_t & nbi1,
                    constant     int64_t & ne00,
                    constant     int64_t & ne01,
                    constant     int64_t & ne02,
                    constant    uint64_t & nb00,
                    constant    uint64_t & nb01,
                    constant    uint64_t & nb02,
                    constant     int64_t & ne10,
                    constant     int64_t & ne11,
                    constant     int64_t & ne12,
                    constant     int64_t & ne13,
                    constant    uint64_t & nb10,
                    constant    uint64_t & nb11,
                    constant    uint64_t & nb12,
                    constant     int64_t & ne0,
                    constant     int64_t & ne1,
                    constant     int64_t & nb1,
                    constant        uint & r2,
                    constant        uint & r3,
                    constant         int & idx,
                    device const    char * src00,
                    device const    char * src01,
                    device const    char * src02,
                    device const    char * src03,
                    device const    char * src04,
                    device const    char * src05,
                    device const    char * src06,
                    device const    char * src07,
                    uint3                  tgpig[[threadgroup_position_in_grid]],
                    uint                   tiitg[[thread_index_in_threadgroup]],
                    uint                   tiisg[[thread_index_in_simdgroup]],
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
                // candidate matrices, selectable by the value stored in ids
                device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

                // split tgpig.z into (row, z position seen by the implementation)
                const int64_t nz  = ne12*ne13;
                const int64_t row = tgpig.z/nz;

                tgpig.z = tgpig.z%nz;

                // entry idx of the selected ids row names the matrix to use
                device const int32_t * ids_row = (device const int32_t *) (ids + row*nbi1);
                const int32_t which = ids_row[idx];

                // src1 and dst advance to the same row and are handed over as float data
                device const float * src1_row = (device const float *) (src1 + row*nb11);
                device       float * dst_row  = (device       float *) ( dst + row*nb1);

                kernel_mul_mv_q8_0_f32_impl(
                    candidates[which], src1_row, dst_row,
                    ne00, ne01, ne02,
                    ne10, ne12,
                    ne0,  ne1,
                    r2,   r3,
                    tgpig, tiisg, sgitg);
            }
         | 
| 4351 | 
            +
             | 
| 4352 | 
            +
            // Indirect matrix-vector multiplication, q4_0 x f32.
            //
            // Picks one of the eight candidate matrices src00..src07 using an int32 value
            // read from ids (element idx of the ids row), offsets src1 and dst to the same
            // row, and forwards to mul_vec_q_n_f32_impl instantiated for block_q4_0 with
            // N_DST, N_SIMDGROUP and N_SIMDWIDTH.
            //
            // tgpig.z carries two values: tgpig.z / (ne12*ne13) selects the row of
            // ids / src1 / dst, and tgpig.z % (ne12*ne13) is what the implementation
            // receives as its z position.
            //
            // Only ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3 are forwarded; the nb0x /
            // nb1x strides (apart from nb11, used for the src1 offset), ne11 and tiitg are
            // not otherwise referenced here.
            // NOTE(review): the value read from ids is used to index an 8-entry array
            // without a range check - presumably the host guarantees 0 <= id < 8; confirm.
            [[host_name("kernel_mul_mv_id_q4_0_f32")]]
            kernel void kernel_mul_mv_id_q4_0_f32(
                    device const    char * ids,
                    device const    char * src1,
                    device         uchar * dst,
                    constant     int64_t & nbi1,
                    constant     int64_t & ne00,
                    constant     int64_t & ne01,
                    constant     int64_t & ne02,
                    constant    uint64_t & nb00,
                    constant    uint64_t & nb01,
                    constant    uint64_t & nb02,
                    constant     int64_t & ne10,
                    constant     int64_t & ne11,
                    constant     int64_t & ne12,
                    constant     int64_t & ne13,
                    constant    uint64_t & nb10,
                    constant    uint64_t & nb11,
                    constant    uint64_t & nb12,
                    constant     int64_t & ne0,
                    constant     int64_t & ne1,
                    constant     int64_t & nb1,
                    constant        uint & r2,
                    constant        uint & r3,
                    constant         int & idx,
                    device const    char * src00,
                    device const    char * src01,
                    device const    char * src02,
                    device const    char * src03,
                    device const    char * src04,
                    device const    char * src05,
                    device const    char * src06,
                    device const    char * src07,
                    uint3                  tgpig[[threadgroup_position_in_grid]],
                    uint                   tiitg[[thread_index_in_threadgroup]],
                    uint                   tiisg[[thread_index_in_simdgroup]],
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
                // candidate matrices, selectable by the value stored in ids
                device const char * candidates[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

                // split tgpig.z into (row, z position seen by the implementation)
                const int64_t nz  = ne12*ne13;
                const int64_t row = tgpig.z/nz;

                tgpig.z = tgpig.z%nz;

                // entry idx of the selected ids row names the matrix to use
                device const int32_t * ids_row = (device const int32_t *) (ids + row*nbi1);
                const int32_t which = ids_row[idx];

                // src1 and dst advance to the same row and are handed over as float data
                device const float * src1_row = (device const float *) (src1 + row*nb11);
                device       float * dst_row  = (device       float *) ( dst + row*nb1);

                mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
                    candidates[which], src1_row, dst_row,
                    ne00, ne01, ne02,
                    ne10, ne12,
                    ne0,  ne1,
                    r2,   r3,
                    tgpig, tiisg, sgitg);
            }
         | 
| 4414 | 
            +
             | 
| 4415 | 
            +
            // Wrapper kernel: selects one of the eight src0x buffers using an int32 index
            // read from `ids`, then forwards all work to
            // mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>.
            // `ids`, `src1` and `dst` are char/uchar pointers, so the offsets bid*nbi1,
            // bid*nb11 and bid*nb1 applied below are byte offsets.
            // nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
            // NOTE(review): no range check on the selected index is visible here;
            // presumably the host guarantees 0 <= id < 8 - confirm against the caller.
            [[host_name("kernel_mul_mv_id_q4_1_f32")]]
         | 
| 4416 | 
            +
            kernel void kernel_mul_mv_id_q4_1_f32(
         | 
| 4417 | 
            +
                    device const    char * ids,
         | 
| 4418 | 
            +
                    device const    char * src1,
         | 
| 4419 | 
            +
                    device         uchar * dst,
         | 
| 4420 | 
            +
                    constant     int64_t & nbi1,
         | 
| 4421 | 
            +
                    constant     int64_t & ne00,
         | 
| 4422 | 
            +
                    constant     int64_t & ne01,
         | 
| 4423 | 
            +
                    constant     int64_t & ne02,
         | 
| 4424 | 
            +
                    constant    uint64_t & nb00,
         | 
| 4425 | 
            +
                    constant    uint64_t & nb01,
         | 
| 4426 | 
            +
                    constant    uint64_t & nb02,
         | 
| 4427 | 
            +
                    constant     int64_t & ne10,
         | 
| 4428 | 
            +
                    constant     int64_t & ne11,
         | 
| 4429 | 
            +
                    constant     int64_t & ne12,
         | 
| 4430 | 
            +
                    constant     int64_t & ne13,
         | 
| 4431 | 
            +
                    constant    uint64_t & nb10,
         | 
| 4432 | 
            +
                    constant    uint64_t & nb11,
         | 
| 4433 | 
            +
                    constant    uint64_t & nb12,
         | 
| 4434 | 
            +
                    constant     int64_t & ne0,
         | 
| 4435 | 
            +
                    constant     int64_t & ne1,
         | 
| 4436 | 
            +
                    constant     int64_t & nb1,
         | 
| 4437 | 
            +
                    constant        uint & r2,
         | 
| 4438 | 
            +
                    constant        uint & r3,
         | 
| 4439 | 
            +
                    constant         int & idx,
         | 
| 4440 | 
            +
                    device const    char * src00,
         | 
| 4441 | 
            +
                    device const    char * src01,
         | 
| 4442 | 
            +
                    device const    char * src02,
         | 
| 4443 | 
            +
                    device const    char * src03,
         | 
| 4444 | 
            +
                    device const    char * src04,
         | 
| 4445 | 
            +
                    device const    char * src05,
         | 
| 4446 | 
            +
                    device const    char * src06,
         | 
| 4447 | 
            +
                    device const    char * src07,
         | 
| 4448 | 
            +
                    uint3                  tgpig[[threadgroup_position_in_grid]],
         | 
| 4449 | 
            +
                    uint                   tiitg[[thread_index_in_threadgroup]],
         | 
| 4450 | 
            +
                    uint                   tiisg[[thread_index_in_simdgroup]],
         | 
| 4451 | 
            +
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 4452 | 
            +
                // Collect the eight candidate source pointers so one can be picked by index.
                device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
         | 
| 4453 | 
            +
             | 
| 4454 | 
            +
                // Quotient of tgpig.z over ne12*ne13; used below to offset ids, src1 and dst.
                const int64_t bid = tgpig.z/(ne12*ne13);
         | 
| 4455 | 
            +
             | 
| 4456 | 
            +
                // Keep only the remainder in tgpig.z; this modified tgpig is what the impl receives.
                tgpig.z = tgpig.z%(ne12*ne13);
         | 
| 4457 | 
            +
             | 
| 4458 | 
            +
                // Read the idx-th int32 starting bid*nbi1 bytes into ids; it selects the src0 entry.
                const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
         | 
| 4459 | 
            +
             | 
| 4460 | 
            +
                // Delegate to the shared implementation with the selected src0 and the offset src1/dst.
                mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         | 
| 4461 | 
            +
                    src0[id],
         | 
| 4462 | 
            +
                    (device const float *) (src1 + bid*nb11),
         | 
| 4463 | 
            +
                    (device       float *) ( dst + bid*nb1),
         | 
| 4464 | 
            +
                    ne00,
         | 
| 4465 | 
            +
                    ne01,
         | 
| 4466 | 
            +
                    ne02,
         | 
| 4467 | 
            +
                    ne10,
         | 
| 4468 | 
            +
                    ne12,
         | 
| 4469 | 
            +
                    ne0,
         | 
| 4470 | 
            +
                    ne1,
         | 
| 4471 | 
            +
                    r2,
         | 
| 4472 | 
            +
                    r3,
         | 
| 4473 | 
            +
                    tgpig,
         | 
| 4474 | 
            +
                    tiisg,
         | 
| 4475 | 
            +
                    sgitg);
         | 
| 4476 | 
            +
            }
         | 
| 4477 | 
            +
             | 
| 4478 | 
            +
            // Wrapper kernel: selects one of the eight src0x buffers using an int32 index
            // read from `ids`, then forwards all work to
            // mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>.
            // `ids`, `src1` and `dst` are char/uchar pointers, so the offsets bid*nbi1,
            // bid*nb11 and bid*nb1 applied below are byte offsets.
            // nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
            // NOTE(review): no range check on the selected index is visible here;
            // presumably the host guarantees 0 <= id < 8 - confirm against the caller.
            [[host_name("kernel_mul_mv_id_q5_0_f32")]]
         | 
| 4479 | 
            +
            kernel void kernel_mul_mv_id_q5_0_f32(
         | 
| 4480 | 
            +
                    device const    char * ids,
         | 
| 4481 | 
            +
                    device const    char * src1,
         | 
| 4482 | 
            +
                    device         uchar * dst,
         | 
| 4483 | 
            +
                    constant     int64_t & nbi1,
         | 
| 4484 | 
            +
                    constant     int64_t & ne00,
         | 
| 4485 | 
            +
                    constant     int64_t & ne01,
         | 
| 4486 | 
            +
                    constant     int64_t & ne02,
         | 
| 4487 | 
            +
                    constant    uint64_t & nb00,
         | 
| 4488 | 
            +
                    constant    uint64_t & nb01,
         | 
| 4489 | 
            +
                    constant    uint64_t & nb02,
         | 
| 4490 | 
            +
                    constant     int64_t & ne10,
         | 
| 4491 | 
            +
                    constant     int64_t & ne11,
         | 
| 4492 | 
            +
                    constant     int64_t & ne12,
         | 
| 4493 | 
            +
                    constant     int64_t & ne13,
         | 
| 4494 | 
            +
                    constant    uint64_t & nb10,
         | 
| 4495 | 
            +
                    constant    uint64_t & nb11,
         | 
| 4496 | 
            +
                    constant    uint64_t & nb12,
         | 
| 4497 | 
            +
                    constant     int64_t & ne0,
         | 
| 4498 | 
            +
                    constant     int64_t & ne1,
         | 
| 4499 | 
            +
                    constant     int64_t & nb1,
         | 
| 4500 | 
            +
                    constant        uint & r2,
         | 
| 4501 | 
            +
                    constant        uint & r3,
         | 
| 4502 | 
            +
                    constant         int & idx,
         | 
| 4503 | 
            +
                    device const    char * src00,
         | 
| 4504 | 
            +
                    device const    char * src01,
         | 
| 4505 | 
            +
                    device const    char * src02,
         | 
| 4506 | 
            +
                    device const    char * src03,
         | 
| 4507 | 
            +
                    device const    char * src04,
         | 
| 4508 | 
            +
                    device const    char * src05,
         | 
| 4509 | 
            +
                    device const    char * src06,
         | 
| 4510 | 
            +
                    device const    char * src07,
         | 
| 4511 | 
            +
                    uint3                  tgpig[[threadgroup_position_in_grid]],
         | 
| 4512 | 
            +
                    uint                   tiitg[[thread_index_in_threadgroup]],
         | 
| 4513 | 
            +
                    uint                   tiisg[[thread_index_in_simdgroup]],
         | 
| 4514 | 
            +
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 4515 | 
            +
                // Collect the eight candidate source pointers so one can be picked by index.
                device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
         | 
| 4516 | 
            +
             | 
| 4517 | 
            +
                // Quotient of tgpig.z over ne12*ne13; used below to offset ids, src1 and dst.
                const int64_t bid = tgpig.z/(ne12*ne13);
         | 
| 4518 | 
            +
             | 
| 4519 | 
            +
                // Keep only the remainder in tgpig.z; this modified tgpig is what the impl receives.
                tgpig.z = tgpig.z%(ne12*ne13);
         | 
| 4520 | 
            +
             | 
| 4521 | 
            +
                // Read the idx-th int32 starting bid*nbi1 bytes into ids; it selects the src0 entry.
                const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
         | 
| 4522 | 
            +
             | 
| 4523 | 
            +
                // Delegate to the shared implementation with the selected src0 and the offset src1/dst.
                mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         | 
| 4524 | 
            +
                    src0[id],
         | 
| 4525 | 
            +
                    (device const float *) (src1 + bid*nb11),
         | 
| 4526 | 
            +
                    (device       float *) ( dst + bid*nb1),
         | 
| 4527 | 
            +
                    ne00,
         | 
| 4528 | 
            +
                    ne01,
         | 
| 4529 | 
            +
                    ne02,
         | 
| 4530 | 
            +
                    ne10,
         | 
| 4531 | 
            +
                    ne12,
         | 
| 4532 | 
            +
                    ne0,
         | 
| 4533 | 
            +
                    ne1,
         | 
| 4534 | 
            +
                    r2,
         | 
| 4535 | 
            +
                    r3,
         | 
| 4536 | 
            +
                    tgpig,
         | 
| 4537 | 
            +
                    tiisg,
         | 
| 4538 | 
            +
                    sgitg);
         | 
| 4539 | 
            +
            }
         | 
| 4540 | 
            +
             | 
| 4541 | 
            +
            // Wrapper kernel: selects one of the eight src0x buffers using an int32 index
            // read from `ids`, then forwards all work to
            // mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>.
            // `ids`, `src1` and `dst` are char/uchar pointers, so the offsets bid*nbi1,
            // bid*nb11 and bid*nb1 applied below are byte offsets.
            // nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
            // NOTE(review): no range check on the selected index is visible here;
            // presumably the host guarantees 0 <= id < 8 - confirm against the caller.
            [[host_name("kernel_mul_mv_id_q5_1_f32")]]
         | 
| 4542 | 
            +
            kernel void kernel_mul_mv_id_q5_1_f32(
         | 
| 4543 | 
            +
                    device const    char * ids,
         | 
| 4544 | 
            +
                    device const    char * src1,
         | 
| 4545 | 
            +
                    device         uchar * dst,
         | 
| 4546 | 
            +
                    constant     int64_t & nbi1,
         | 
| 4547 | 
            +
                    constant     int64_t & ne00,
         | 
| 4548 | 
            +
                    constant     int64_t & ne01,
         | 
| 4549 | 
            +
                    constant     int64_t & ne02,
         | 
| 4550 | 
            +
                    constant    uint64_t & nb00,
         | 
| 4551 | 
            +
                    constant    uint64_t & nb01,
         | 
| 4552 | 
            +
                    constant    uint64_t & nb02,
         | 
| 4553 | 
            +
                    constant     int64_t & ne10,
         | 
| 4554 | 
            +
                    constant     int64_t & ne11,
         | 
| 4555 | 
            +
                    constant     int64_t & ne12,
         | 
| 4556 | 
            +
                    constant     int64_t & ne13,
         | 
| 4557 | 
            +
                    constant    uint64_t & nb10,
         | 
| 4558 | 
            +
                    constant    uint64_t & nb11,
         | 
| 4559 | 
            +
                    constant    uint64_t & nb12,
         | 
| 4560 | 
            +
                    constant     int64_t & ne0,
         | 
| 4561 | 
            +
                    constant     int64_t & ne1,
         | 
| 4562 | 
            +
                    constant     int64_t & nb1,
         | 
| 4563 | 
            +
                    constant        uint & r2,
         | 
| 4564 | 
            +
                    constant        uint & r3,
         | 
| 4565 | 
            +
                    constant         int & idx,
         | 
| 4566 | 
            +
                    device const    char * src00,
         | 
| 4567 | 
            +
                    device const    char * src01,
         | 
| 4568 | 
            +
                    device const    char * src02,
         | 
| 4569 | 
            +
                    device const    char * src03,
         | 
| 4570 | 
            +
                    device const    char * src04,
         | 
| 4571 | 
            +
                    device const    char * src05,
         | 
| 4572 | 
            +
                    device const    char * src06,
         | 
| 4573 | 
            +
                    device const    char * src07,
         | 
| 4574 | 
            +
                    uint3                  tgpig[[threadgroup_position_in_grid]],
         | 
| 4575 | 
            +
                    uint                   tiitg[[thread_index_in_threadgroup]],
         | 
| 4576 | 
            +
                    uint                   tiisg[[thread_index_in_simdgroup]],
         | 
| 4577 | 
            +
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 4578 | 
            +
                // Collect the eight candidate source pointers so one can be picked by index.
                device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
         | 
| 4579 | 
            +
             | 
| 4580 | 
            +
                // Quotient of tgpig.z over ne12*ne13; used below to offset ids, src1 and dst.
                const int64_t bid = tgpig.z/(ne12*ne13);
         | 
| 4581 | 
            +
             | 
| 4582 | 
            +
                // Keep only the remainder in tgpig.z; this modified tgpig is what the impl receives.
                tgpig.z = tgpig.z%(ne12*ne13);
         | 
| 4583 | 
            +
             | 
| 4584 | 
            +
                // Read the idx-th int32 starting bid*nbi1 bytes into ids; it selects the src0 entry.
                const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
         | 
| 4585 | 
            +
             | 
| 4586 | 
            +
                // Delegate to the shared implementation with the selected src0 and the offset src1/dst.
                mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         | 
| 4587 | 
            +
                    src0[id],
         | 
| 4588 | 
            +
                    (device const float *) (src1 + bid*nb11),
         | 
| 4589 | 
            +
                    (device       float *) ( dst + bid*nb1),
         | 
| 4590 | 
            +
                    ne00,
         | 
| 4591 | 
            +
                    ne01,
         | 
| 4592 | 
            +
                    ne02,
         | 
| 4593 | 
            +
                    ne10,
         | 
| 4594 | 
            +
                    ne12,
         | 
| 4595 | 
            +
                    ne0,
         | 
| 4596 | 
            +
                    ne1,
         | 
| 4597 | 
            +
                    r2,
         | 
| 4598 | 
            +
                    r3,
         | 
| 4599 | 
            +
                    tgpig,
         | 
| 4600 | 
            +
                    tiisg,
         | 
| 4601 | 
            +
                    sgitg);
         | 
| 4602 | 
            +
            }
         | 
| 4603 | 
            +
             | 
| 4604 | 
            +
            // Wrapper kernel: selects one of the eight src0x buffers using an int32 index
            // read from `ids`, then forwards all work to kernel_mul_mv_q2_K_f32_impl.
            // `ids`, `src1` and `dst` are char/uchar pointers, so the offsets bid*nbi1,
            // bid*nb11 and bid*nb1 applied below are byte offsets.
            // nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
            // NOTE(review): no range check on the selected index is visible here;
            // presumably the host guarantees 0 <= id < 8 - confirm against the caller.
            [[host_name("kernel_mul_mv_id_q2_K_f32")]]
         | 
| 4605 | 
            +
            kernel void kernel_mul_mv_id_q2_K_f32(
         | 
| 4606 | 
            +
                    device const    char * ids,
         | 
| 4607 | 
            +
                    device const    char * src1,
         | 
| 4608 | 
            +
                    device         uchar * dst,
         | 
| 4609 | 
            +
                    constant     int64_t & nbi1,
         | 
| 4610 | 
            +
                    constant     int64_t & ne00,
         | 
| 4611 | 
            +
                    constant     int64_t & ne01,
         | 
| 4612 | 
            +
                    constant     int64_t & ne02,
         | 
| 4613 | 
            +
                    constant    uint64_t & nb00,
         | 
| 4614 | 
            +
                    constant    uint64_t & nb01,
         | 
| 4615 | 
            +
                    constant    uint64_t & nb02,
         | 
| 4616 | 
            +
                    constant     int64_t & ne10,
         | 
| 4617 | 
            +
                    constant     int64_t & ne11,
         | 
| 4618 | 
            +
                    constant     int64_t & ne12,
         | 
| 4619 | 
            +
                    constant     int64_t & ne13,
         | 
| 4620 | 
            +
                    constant    uint64_t & nb10,
         | 
| 4621 | 
            +
                    constant    uint64_t & nb11,
         | 
| 4622 | 
            +
                    constant    uint64_t & nb12,
         | 
| 4623 | 
            +
                    constant     int64_t & ne0,
         | 
| 4624 | 
            +
                    constant     int64_t & ne1,
         | 
| 4625 | 
            +
                    constant     int64_t & nb1,
         | 
| 4626 | 
            +
                    constant        uint & r2,
         | 
| 4627 | 
            +
                    constant        uint & r3,
         | 
| 4628 | 
            +
                    constant         int & idx,
         | 
| 4629 | 
            +
                    device const    char * src00,
         | 
| 4630 | 
            +
                    device const    char * src01,
         | 
| 4631 | 
            +
                    device const    char * src02,
         | 
| 4632 | 
            +
                    device const    char * src03,
         | 
| 4633 | 
            +
                    device const    char * src04,
         | 
| 4634 | 
            +
                    device const    char * src05,
         | 
| 4635 | 
            +
                    device const    char * src06,
         | 
| 4636 | 
            +
                    device const    char * src07,
         | 
| 4637 | 
            +
                    uint3                  tgpig[[threadgroup_position_in_grid]],
         | 
| 4638 | 
            +
                    uint                   tiitg[[thread_index_in_threadgroup]],
         | 
| 4639 | 
            +
                    uint                   tiisg[[thread_index_in_simdgroup]],
         | 
| 4640 | 
            +
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 4641 | 
            +
                // Collect the eight candidate source pointers so one can be picked by index.
                device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
         | 
| 4642 | 
            +
             | 
| 4643 | 
            +
                // Quotient of tgpig.z over ne12*ne13; used below to offset ids, src1 and dst.
                const int64_t bid = tgpig.z/(ne12*ne13);
         | 
| 4644 | 
            +
             | 
| 4645 | 
            +
                // Keep only the remainder in tgpig.z; this modified tgpig is what the impl receives.
                tgpig.z = tgpig.z%(ne12*ne13);
         | 
| 4646 | 
            +
             | 
| 4647 | 
            +
                // Read the idx-th int32 starting bid*nbi1 bytes into ids; it selects the src0 entry.
                const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
         | 
| 4648 | 
            +
             | 
| 4649 | 
            +
                // Delegate to the shared implementation with the selected src0 and the offset src1/dst.
                kernel_mul_mv_q2_K_f32_impl(
         | 
| 4650 | 
            +
                    src0[id],
         | 
| 4651 | 
            +
                    (device const float *) (src1 + bid*nb11),
         | 
| 4652 | 
            +
                    (device       float *) ( dst + bid*nb1),
         | 
| 4653 | 
            +
                    ne00,
         | 
| 4654 | 
            +
                    ne01,
         | 
| 4655 | 
            +
                    ne02,
         | 
| 4656 | 
            +
                    ne10,
         | 
| 4657 | 
            +
                    ne12,
         | 
| 4658 | 
            +
                    ne0,
         | 
| 4659 | 
            +
                    ne1,
         | 
| 4660 | 
            +
                    r2,
         | 
| 4661 | 
            +
                    r3,
         | 
| 4662 | 
            +
                    tgpig,
         | 
| 4663 | 
            +
                    tiisg,
         | 
| 4664 | 
            +
                    sgitg);
         | 
| 4665 | 
            +
            }
         | 
| 4666 | 
            +
             | 
| 4667 | 
            +
            // Wrapper kernel: selects one of the eight src0x buffers using an int32 index
            // read from `ids`, then forwards all work to kernel_mul_mv_q3_K_f32_impl.
            // `ids`, `src1` and `dst` are char/uchar pointers, so the offsets bid*nbi1,
            // bid*nb11 and bid*nb1 applied below are byte offsets.
            // nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
            // NOTE(review): no range check on the selected index is visible here;
            // presumably the host guarantees 0 <= id < 8 - confirm against the caller.
            [[host_name("kernel_mul_mv_id_q3_K_f32")]]
         | 
| 4668 | 
            +
            kernel void kernel_mul_mv_id_q3_K_f32(
         | 
| 4669 | 
            +
                    device const    char * ids,
         | 
| 4670 | 
            +
                    device const    char * src1,
         | 
| 4671 | 
            +
                    device         uchar * dst,
         | 
| 4672 | 
            +
                    constant     int64_t & nbi1,
         | 
| 4673 | 
            +
                    constant     int64_t & ne00,
         | 
| 4674 | 
            +
                    constant     int64_t & ne01,
         | 
| 4675 | 
            +
                    constant     int64_t & ne02,
         | 
| 4676 | 
            +
                    constant    uint64_t & nb00,
         | 
| 4677 | 
            +
                    constant    uint64_t & nb01,
         | 
| 4678 | 
            +
                    constant    uint64_t & nb02,
         | 
| 4679 | 
            +
                    constant     int64_t & ne10,
         | 
| 4680 | 
            +
                    constant     int64_t & ne11,
         | 
| 4681 | 
            +
                    constant     int64_t & ne12,
         | 
| 4682 | 
            +
                    constant     int64_t & ne13,
         | 
| 4683 | 
            +
                    constant    uint64_t & nb10,
         | 
| 4684 | 
            +
                    constant    uint64_t & nb11,
         | 
| 4685 | 
            +
                    constant    uint64_t & nb12,
         | 
| 4686 | 
            +
                    constant     int64_t & ne0,
         | 
| 4687 | 
            +
                    constant     int64_t & ne1,
         | 
| 4688 | 
            +
                    constant     int64_t & nb1,
         | 
| 4689 | 
            +
                    constant        uint & r2,
         | 
| 4690 | 
            +
                    constant        uint & r3,
         | 
| 4691 | 
            +
                    constant         int & idx,
         | 
| 4692 | 
            +
                    device const    char * src00,
         | 
| 4693 | 
            +
                    device const    char * src01,
         | 
| 4694 | 
            +
                    device const    char * src02,
         | 
| 4695 | 
            +
                    device const    char * src03,
         | 
| 4696 | 
            +
                    device const    char * src04,
         | 
| 4697 | 
            +
                    device const    char * src05,
         | 
| 4698 | 
            +
                    device const    char * src06,
         | 
| 4699 | 
            +
                    device const    char * src07,
         | 
| 4700 | 
            +
                    uint3                  tgpig[[threadgroup_position_in_grid]],
         | 
| 4701 | 
            +
                    uint                   tiitg[[thread_index_in_threadgroup]],
         | 
| 4702 | 
            +
                    uint                   tiisg[[thread_index_in_simdgroup]],
         | 
| 4703 | 
            +
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 4704 | 
            +
                // Collect the eight candidate source pointers so one can be picked by index.
                device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
         | 
| 4705 | 
            +
             | 
| 4706 | 
            +
                // Quotient of tgpig.z over ne12*ne13; used below to offset ids, src1 and dst.
                const int64_t bid = tgpig.z/(ne12*ne13);
         | 
| 4707 | 
            +
             | 
| 4708 | 
            +
                // Keep only the remainder in tgpig.z; this modified tgpig is what the impl receives.
                tgpig.z = tgpig.z%(ne12*ne13);
         | 
| 4709 | 
            +
             | 
| 4710 | 
            +
                // Read the idx-th int32 starting bid*nbi1 bytes into ids; it selects the src0 entry.
                const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
         | 
| 4711 | 
            +
             | 
| 4712 | 
            +
                // Delegate to the shared implementation with the selected src0 and the offset src1/dst.
                kernel_mul_mv_q3_K_f32_impl(
         | 
| 4713 | 
            +
                    src0[id],
         | 
| 4714 | 
            +
                    (device const float *) (src1 + bid*nb11),
         | 
| 4715 | 
            +
                    (device       float *) ( dst + bid*nb1),
         | 
| 4716 | 
            +
                    ne00,
         | 
| 4717 | 
            +
                    ne01,
         | 
| 4718 | 
            +
                    ne02,
         | 
| 4719 | 
            +
                    ne10,
         | 
| 4720 | 
            +
                    ne12,
         | 
| 4721 | 
            +
                    ne0,
         | 
| 4722 | 
            +
                    ne1,
         | 
| 4723 | 
            +
                    r2,
         | 
| 4724 | 
            +
                    r3,
         | 
| 4725 | 
            +
                    tgpig,
         | 
| 4726 | 
            +
                    tiisg,
         | 
| 4727 | 
            +
                    sgitg);
         | 
| 4728 | 
            +
            }
         | 
| 4729 | 
            +
             | 
| 4730 | 
            +
            // Wrapper kernel: selects one of the eight src0x buffers using an int32 index
            // read from `ids`, then forwards all work to kernel_mul_mv_q4_K_f32_impl.
            // `ids`, `src1` and `dst` are char/uchar pointers, so the offsets bid*nbi1,
            // bid*nb11 and bid*nb1 applied below are byte offsets.
            // nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are not referenced in this body.
            // NOTE(review): no range check on the selected index is visible here;
            // presumably the host guarantees 0 <= id < 8 - confirm against the caller.
            [[host_name("kernel_mul_mv_id_q4_K_f32")]]
         | 
| 4731 | 
            +
            kernel void kernel_mul_mv_id_q4_K_f32(
         | 
| 4732 | 
            +
                    device const    char * ids,
         | 
| 4733 | 
            +
                    device const    char * src1,
         | 
| 4734 | 
            +
                    device         uchar * dst,
         | 
| 4735 | 
            +
                    constant     int64_t & nbi1,
         | 
| 4736 | 
            +
                    constant     int64_t & ne00,
         | 
| 4737 | 
            +
                    constant     int64_t & ne01,
         | 
| 4738 | 
            +
                    constant     int64_t & ne02,
         | 
| 4739 | 
            +
                    constant    uint64_t & nb00,
         | 
| 4740 | 
            +
                    constant    uint64_t & nb01,
         | 
| 4741 | 
            +
                    constant    uint64_t & nb02,
         | 
| 4742 | 
            +
                    constant     int64_t & ne10,
         | 
| 4743 | 
            +
                    constant     int64_t & ne11,
         | 
| 4744 | 
            +
                    constant     int64_t & ne12,
         | 
| 4745 | 
            +
                    constant     int64_t & ne13,
         | 
| 4746 | 
            +
                    constant    uint64_t & nb10,
         | 
| 4747 | 
            +
                    constant    uint64_t & nb11,
         | 
| 4748 | 
            +
                    constant    uint64_t & nb12,
         | 
| 4749 | 
            +
                    constant     int64_t & ne0,
         | 
| 4750 | 
            +
                    constant     int64_t & ne1,
         | 
| 4751 | 
            +
                    constant     int64_t & nb1,
         | 
| 4752 | 
            +
                    constant        uint & r2,
         | 
| 4753 | 
            +
                    constant        uint & r3,
         | 
| 4754 | 
            +
                    constant         int & idx,
         | 
| 4755 | 
            +
                    device const    char * src00,
         | 
| 4756 | 
            +
                    device const    char * src01,
         | 
| 4757 | 
            +
                    device const    char * src02,
         | 
| 4758 | 
            +
                    device const    char * src03,
         | 
| 4759 | 
            +
                    device const    char * src04,
         | 
| 4760 | 
            +
                    device const    char * src05,
         | 
| 4761 | 
            +
                    device const    char * src06,
         | 
| 4762 | 
            +
                    device const    char * src07,
         | 
| 4763 | 
            +
                    uint3                  tgpig[[threadgroup_position_in_grid]],
         | 
| 4764 | 
            +
                    uint                   tiitg[[thread_index_in_threadgroup]],
         | 
| 4765 | 
            +
                    uint                   tiisg[[thread_index_in_simdgroup]],
         | 
| 4766 | 
            +
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
         | 
| 4767 | 
            +
                // Collect the eight candidate source pointers so one can be picked by index.
                device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
         | 
| 4768 | 
            +
             | 
| 4769 | 
            +
                // Quotient of tgpig.z over ne12*ne13; used below to offset ids, src1 and dst.
                const int64_t bid = tgpig.z/(ne12*ne13);
         | 
| 4770 | 
            +
             | 
| 4771 | 
            +
                // Keep only the remainder in tgpig.z; this modified tgpig is what the impl receives.
                tgpig.z = tgpig.z%(ne12*ne13);
         | 
| 4772 | 
            +
             | 
| 4773 | 
            +
                // Read the idx-th int32 starting bid*nbi1 bytes into ids; it selects the src0 entry.
                const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
         | 
| 4774 | 
            +
             | 
| 4775 | 
            +
                // Delegate to the shared implementation with the selected src0 and the offset src1/dst.
                kernel_mul_mv_q4_K_f32_impl(
         | 
| 4776 | 
            +
                    src0[id],
         | 
| 4777 | 
            +
                    (device const float *) (src1 + bid*nb11),
         | 
| 4778 | 
            +
                    (device       float *) ( dst + bid*nb1),
         | 
| 4779 | 
            +
                    ne00,
         | 
| 4780 | 
            +
                    ne01,
         | 
| 4781 | 
            +
                    ne02,
         | 
| 4782 | 
            +
                    ne10,
         | 
| 4783 | 
            +
                    ne12,
         | 
| 4784 | 
            +
                    ne0,
         | 
| 4785 | 
            +
                    ne1,
         | 
| 4786 | 
            +
                    r2,
         | 
| 4787 | 
            +
                    r3,
         | 
| 4788 | 
            +
                    tgpig,
         | 
| 4789 | 
            +
                    tiisg,
         | 
| 4790 | 
            +
                    sgitg);
         | 
| 4791 | 
            +
            }
         | 
| 4792 | 
            +
             | 
| 4793 | 
            +
            // Wrapper around kernel_mul_mv_q5_K_f32_impl that chooses its src0 buffer at run time.
            //
            // An int32 id is read from `ids` and used to select one of the eight
            // src00..src07 pointers; src1 and dst are advanced by a per-batch byte
            // offset before the call is forwarded.
            //
            // Parameters nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are accepted but
            // not referenced in this body.
            [[host_name("kernel_mul_mv_id_q5_K_f32")]]
            kernel void kernel_mul_mv_id_q5_K_f32(
                    device const    char * ids,
                    device const    char * src1,
                    device         uchar * dst,
                    constant     int64_t & nbi1,
                    constant     int64_t & ne00,
                    constant     int64_t & ne01,
                    constant     int64_t & ne02,
                    constant    uint64_t & nb00,
                    constant    uint64_t & nb01,
                    constant    uint64_t & nb02,
                    constant     int64_t & ne10,
                    constant     int64_t & ne11,
                    constant     int64_t & ne12,
                    constant     int64_t & ne13,
                    constant    uint64_t & nb10,
                    constant    uint64_t & nb11,
                    constant    uint64_t & nb12,
                    constant     int64_t & ne0,
                    constant     int64_t & ne1,
                    constant     int64_t & nb1,
                    constant        uint & r2,
                    constant        uint & r3,
                    constant         int & idx,
                    device const    char * src00,
                    device const    char * src01,
                    device const    char * src02,
                    device const    char * src03,
                    device const    char * src04,
                    device const    char * src05,
                    device const    char * src06,
                    device const    char * src07,
                    uint3                  tgpig[[threadgroup_position_in_grid]],
                    uint                   tiitg[[thread_index_in_threadgroup]],
                    uint                   tiisg[[thread_index_in_simdgroup]],
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
                // Table of the eight selectable src0 buffers, indexed by the id read below.
                device const char * src0_table[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

                // tgpig.z packs two values: the quotient by ne12*ne13 is the batch index,
                // the remainder is what the impl receives as its z coordinate.
                const int64_t z_span = ne12*ne13;
                const int64_t bid    = tgpig.z/z_span;

                tgpig.z = tgpig.z%z_span;

                // Row of int32 ids for this batch, located bid*nbi1 bytes into `ids`.
                device int32_t * id_row = (device int32_t *) (ids + bid*nbi1);

                // NOTE(review): id indexes an 8-entry table without a range check;
                // assumes 0 <= id < 8 - TODO confirm against the host-side dispatch.
                const int32_t id = id_row[idx];

                // Per-batch views of src1 and dst (byte offsets bid*nb11 and bid*nb1).
                device const float * src1_batch = (device const float *) (src1 + bid*nb11);
                device       float * dst_batch  = (device       float *) ( dst + bid*nb1);

                kernel_mul_mv_q5_K_f32_impl(
                    src0_table[id],
                    src1_batch,
                    dst_batch,
                    ne00,
                    ne01,
                    ne02,
                    ne10,
                    ne12,
                    ne0,
                    ne1,
                    r2,
                    r3,
                    tgpig,
                    tiisg,
                    sgitg);
            }
         | 
| 4855 | 
            +
             | 
| 4856 | 
            +
            // Wrapper around kernel_mul_mv_q6_K_f32_impl that chooses its src0 buffer at run time.
            //
            // An int32 id is read from `ids` and used to select one of the eight
            // src00..src07 pointers; src1 and dst are advanced by a per-batch byte
            // offset before the call is forwarded.
            //
            // Parameters nb00, nb01, nb02, ne11, nb10, nb12 and tiitg are accepted but
            // not referenced in this body.
            [[host_name("kernel_mul_mv_id_q6_K_f32")]]
            kernel void kernel_mul_mv_id_q6_K_f32(
                    device const    char * ids,
                    device const    char * src1,
                    device         uchar * dst,
                    constant     int64_t & nbi1,
                    constant     int64_t & ne00,
                    constant     int64_t & ne01,
                    constant     int64_t & ne02,
                    constant    uint64_t & nb00,
                    constant    uint64_t & nb01,
                    constant    uint64_t & nb02,
                    constant     int64_t & ne10,
                    constant     int64_t & ne11,
                    constant     int64_t & ne12,
                    constant     int64_t & ne13,
                    constant    uint64_t & nb10,
                    constant    uint64_t & nb11,
                    constant    uint64_t & nb12,
                    constant     int64_t & ne0,
                    constant     int64_t & ne1,
                    constant     int64_t & nb1,
                    constant        uint & r2,
                    constant        uint & r3,
                    constant         int & idx,
                    device const    char * src00,
                    device const    char * src01,
                    device const    char * src02,
                    device const    char * src03,
                    device const    char * src04,
                    device const    char * src05,
                    device const    char * src06,
                    device const    char * src07,
                    uint3                  tgpig[[threadgroup_position_in_grid]],
                    uint                   tiitg[[thread_index_in_threadgroup]],
                    uint                   tiisg[[thread_index_in_simdgroup]],
                    uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
                // Table of the eight selectable src0 buffers, indexed by the id read below.
                device const char * src0_table[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

                // tgpig.z packs two values: the quotient by ne12*ne13 is the batch index,
                // the remainder is what the impl receives as its z coordinate.
                const int64_t z_span = ne12*ne13;
                const int64_t bid    = tgpig.z/z_span;

                tgpig.z = tgpig.z%z_span;

                // Row of int32 ids for this batch, located bid*nbi1 bytes into `ids`.
                device int32_t * id_row = (device int32_t *) (ids + bid*nbi1);

                // NOTE(review): id indexes an 8-entry table without a range check;
                // assumes 0 <= id < 8 - TODO confirm against the host-side dispatch.
                const int32_t id = id_row[idx];

                // Per-batch views of src1 and dst (byte offsets bid*nb11 and bid*nb1).
                device const float * src1_batch = (device const float *) (src1 + bid*nb11);
                device       float * dst_batch  = (device       float *) ( dst + bid*nb1);

                kernel_mul_mv_q6_K_f32_impl(
                    src0_table[id],
                    src1_batch,
                    dst_batch,
                    ne00,
                    ne01,
                    ne02,
                    ne10,
                    ne12,
                    ne0,
                    ne1,
                    r2,
                    r3,
                    tgpig,
                    tiisg,
                    sgitg);
            }
         |